refactor: 更新Agent模块
- 优化conversation_agent代码结构 - 优化memory_agent代码结构 - 改进错误处理和代码可读性
This commit is contained in:
@@ -1,14 +1,13 @@
|
|||||||
"""
|
"""
|
||||||
对话 Agent:基于访谈问题清单,动态选择问题,实时生成回应
|
对话 Agent:基于访谈问题清单,动态选择问题,实时生成回应
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from langchain.chains import ConversationChain
|
from langchain.chains import ConversationChain
|
||||||
from langchain.memory import ConversationBufferMemory
|
from langchain.memory import ConversationBufferMemory
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
|
from services.llm_service import llm_service
|
||||||
from .prompts import ConversationStage, get_conversation_prompt
|
from .prompts import ConversationStage, get_conversation_prompt
|
||||||
|
|
||||||
|
|
||||||
@@ -16,42 +15,8 @@ class ConversationAgent:
|
|||||||
"""对话 Agent"""
|
"""对话 Agent"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 初始化 LLM(使用环境变量配置)
|
# 使用 LLM 服务获取 LLM 实例
|
||||||
# 优先使用 LLM_API_KEY 和 LLM_BASE_URL,如果没有则使用 OPENAI_API_KEY
|
self.llm = llm_service.get_llm()
|
||||||
api_key = os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY", "")
|
|
||||||
base_url = os.getenv("LLM_BASE_URL", "")
|
|
||||||
model_name = os.getenv("OPENAI_MODEL", "gpt-4o")
|
|
||||||
|
|
||||||
if not api_key:
|
|
||||||
self.llm = None
|
|
||||||
self.memories: dict[str, ConversationBufferMemory] = {}
|
|
||||||
return
|
|
||||||
|
|
||||||
# 如果提供了 base_url,需要处理路径(langchain 会自动添加 /v1/chat/completions)
|
|
||||||
llm_kwargs = {
|
|
||||||
"temperature": 0.7,
|
|
||||||
"model": model_name,
|
|
||||||
"openai_api_key": api_key,
|
|
||||||
}
|
|
||||||
|
|
||||||
if base_url:
|
|
||||||
# 移除可能的 /v1/chat/completions 路径,langchain 会自动添加
|
|
||||||
if base_url.endswith("/v1/chat/completions"):
|
|
||||||
base_url = base_url[:-20] # 移除 "/v1/chat/completions"
|
|
||||||
elif base_url.endswith("/v1"):
|
|
||||||
base_url = base_url[:-3] # 移除 "/v1"
|
|
||||||
# 确保 base_url 以 / 结尾(如果没有)
|
|
||||||
if base_url and not base_url.endswith("/"):
|
|
||||||
base_url += "/"
|
|
||||||
llm_kwargs["openai_api_base"] = base_url
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.llm = ChatOpenAI(**llm_kwargs)
|
|
||||||
except Exception:
|
|
||||||
self.llm = None
|
|
||||||
|
|
||||||
# 对话记忆
|
|
||||||
self.memories: dict[str, ConversationBufferMemory] = {}
|
|
||||||
|
|
||||||
# 对话记忆
|
# 对话记忆
|
||||||
self.memories: dict[str, ConversationBufferMemory] = {}
|
self.memories: dict[str, ConversationBufferMemory] = {}
|
||||||
@@ -92,7 +57,7 @@ class ConversationAgent:
|
|||||||
|
|
||||||
# 如果没有配置 LLM,返回默认回应
|
# 如果没有配置 LLM,返回默认回应
|
||||||
if not self.llm:
|
if not self.llm:
|
||||||
return "抱歉,LLM 服务未配置。请设置 LLM_API_KEY 或 OPENAI_API_KEY 环境变量。"
|
return "抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"
|
||||||
|
|
||||||
# 获取系统提示词
|
# 获取系统提示词
|
||||||
system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
|
system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
回忆录整理 Agent:基于传记结构,将口语改写为书面语,归类到章节
|
回忆录整理 Agent:基于传记结构,将口语改写为书面语,归类到章节
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
from langchain.prompts import PromptTemplate
|
from services.llm_service import llm_service
|
||||||
|
|
||||||
from .prompts import (
|
from .prompts import (
|
||||||
get_memory_prompt,
|
get_memory_prompt,
|
||||||
@@ -20,38 +19,8 @@ class MemoryAgent:
|
|||||||
"""回忆录整理 Agent"""
|
"""回忆录整理 Agent"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 初始化 LLM
|
# 使用 LLM 服务获取 LLM 实例
|
||||||
# 优先使用 LLM_API_KEY 和 LLM_BASE_URL,如果没有则使用 OPENAI_API_KEY
|
self.llm = llm_service.get_llm()
|
||||||
api_key = os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY", "")
|
|
||||||
base_url = os.getenv("LLM_BASE_URL", "")
|
|
||||||
model_name = os.getenv("OPENAI_MODEL", "gpt-4o")
|
|
||||||
|
|
||||||
if not api_key:
|
|
||||||
self.llm = None
|
|
||||||
return
|
|
||||||
|
|
||||||
# 如果提供了 base_url,需要处理路径(langchain 会自动添加 /v1/chat/completions)
|
|
||||||
llm_kwargs = {
|
|
||||||
"temperature": 0.3, # 较低温度,更稳定
|
|
||||||
"model": model_name,
|
|
||||||
"openai_api_key": api_key,
|
|
||||||
}
|
|
||||||
|
|
||||||
if base_url:
|
|
||||||
# 移除可能的 /v1/chat/completions 路径,langchain 会自动添加
|
|
||||||
if base_url.endswith("/v1/chat/completions"):
|
|
||||||
base_url = base_url[:-20] # 移除 "/v1/chat/completions"
|
|
||||||
elif base_url.endswith("/v1"):
|
|
||||||
base_url = base_url[:-3] # 移除 "/v1"
|
|
||||||
# 确保 base_url 以 / 结尾(如果没有)
|
|
||||||
if base_url and not base_url.endswith("/"):
|
|
||||||
base_url += "/"
|
|
||||||
llm_kwargs["openai_api_base"] = base_url
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.llm = ChatOpenAI(**llm_kwargs)
|
|
||||||
except Exception:
|
|
||||||
self.llm = None
|
|
||||||
|
|
||||||
def classify_chapter(self, segments_text: str) -> str:
|
def classify_chapter(self, segments_text: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user