refactor: 更新Agent模块

- 优化conversation_agent代码结构
- 优化memory_agent代码结构
- 改进错误处理和代码可读性
This commit is contained in:
徐在坤
2026-01-18 15:57:53 +08:00
parent 802f5a3833
commit dfe41a727a
2 changed files with 8 additions and 74 deletions

View File

@@ -1,14 +1,13 @@
""" """
对话 Agent基于访谈问题清单动态选择问题实时生成回应 对话 Agent基于访谈问题清单动态选择问题实时生成回应
""" """
import os
from typing import List, Optional from typing import List, Optional
from langchain.chains import ConversationChain from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from services.llm_service import llm_service
from .prompts import ConversationStage, get_conversation_prompt from .prompts import ConversationStage, get_conversation_prompt
@@ -16,42 +15,8 @@ class ConversationAgent:
"""对话 Agent""" """对话 Agent"""
def __init__(self): def __init__(self):
# 初始化 LLM使用环境变量配置 # 使用 LLM 服务获取 LLM 实例
# 优先使用 LLM_API_KEY 和 LLM_BASE_URL如果没有则使用 OPENAI_API_KEY self.llm = llm_service.get_llm()
api_key = os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY", "")
base_url = os.getenv("LLM_BASE_URL", "")
model_name = os.getenv("OPENAI_MODEL", "gpt-4o")
if not api_key:
self.llm = None
self.memories: dict[str, ConversationBufferMemory] = {}
return
# 如果提供了 base_url需要处理路径langchain 会自动添加 /v1/chat/completions
llm_kwargs = {
"temperature": 0.7,
"model": model_name,
"openai_api_key": api_key,
}
if base_url:
# 移除可能的 /v1/chat/completions 路径langchain 会自动添加
if base_url.endswith("/v1/chat/completions"):
base_url = base_url[:-20] # 移除 "/v1/chat/completions"
elif base_url.endswith("/v1"):
base_url = base_url[:-3] # 移除 "/v1"
# 确保 base_url 以 / 结尾(如果没有)
if base_url and not base_url.endswith("/"):
base_url += "/"
llm_kwargs["openai_api_base"] = base_url
try:
self.llm = ChatOpenAI(**llm_kwargs)
except Exception:
self.llm = None
# 对话记忆
self.memories: dict[str, ConversationBufferMemory] = {}
# 对话记忆 # 对话记忆
self.memories: dict[str, ConversationBufferMemory] = {} self.memories: dict[str, ConversationBufferMemory] = {}
@@ -92,7 +57,7 @@ class ConversationAgent:
# 如果没有配置 LLM返回默认回应 # 如果没有配置 LLM返回默认回应
if not self.llm: if not self.llm:
return "抱歉LLM 服务未配置。请设置 LLM_API_KEY 或 OPENAI_API_KEY 环境变量。" return "抱歉LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"
# 获取系统提示词 # 获取系统提示词
system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message) system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)

View File

@@ -1,11 +1,10 @@
""" """
回忆录整理 Agent基于传记结构将口语改写为书面语归类到章节 回忆录整理 Agent基于传记结构将口语改写为书面语归类到章节
""" """
import os
import json import json
from typing import List, Dict, Optional from typing import List, Dict, Optional
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate from services.llm_service import llm_service
from .prompts import ( from .prompts import (
get_memory_prompt, get_memory_prompt,
@@ -20,38 +19,8 @@ class MemoryAgent:
"""回忆录整理 Agent""" """回忆录整理 Agent"""
def __init__(self): def __init__(self):
# 初始化 LLM # 使用 LLM 服务获取 LLM 实例
# 优先使用 LLM_API_KEY 和 LLM_BASE_URL如果没有则使用 OPENAI_API_KEY self.llm = llm_service.get_llm()
api_key = os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY", "")
base_url = os.getenv("LLM_BASE_URL", "")
model_name = os.getenv("OPENAI_MODEL", "gpt-4o")
if not api_key:
self.llm = None
return
# 如果提供了 base_url需要处理路径langchain 会自动添加 /v1/chat/completions
llm_kwargs = {
"temperature": 0.3, # 较低温度,更稳定
"model": model_name,
"openai_api_key": api_key,
}
if base_url:
# 移除可能的 /v1/chat/completions 路径langchain 会自动添加
if base_url.endswith("/v1/chat/completions"):
base_url = base_url[:-20] # 移除 "/v1/chat/completions"
elif base_url.endswith("/v1"):
base_url = base_url[:-3] # 移除 "/v1"
# 确保 base_url 以 / 结尾(如果没有)
if base_url and not base_url.endswith("/"):
base_url += "/"
llm_kwargs["openai_api_base"] = base_url
try:
self.llm = ChatOpenAI(**llm_kwargs)
except Exception:
self.llm = None
def classify_chapter(self, segments_text: str) -> str: def classify_chapter(self, segments_text: str) -> str:
""" """