156 lines
5.6 KiB
Python
156 lines
5.6 KiB
Python
|
|
"""
|
|||
|
|
对话 Agent:基于访谈问题清单,动态选择问题,实时生成回应
|
|||
|
|
"""
|
|||
|
|
import os
|
|||
|
|
from typing import List, Optional
|
|||
|
|
|
|||
|
|
from langchain.chains import ConversationChain
|
|||
|
|
from langchain.memory import ConversationBufferMemory
|
|||
|
|
from langchain.prompts import PromptTemplate
|
|||
|
|
from langchain_openai import ChatOpenAI
|
|||
|
|
|
|||
|
|
from .prompts import ConversationStage, get_conversation_prompt
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ConversationAgent:
|
|||
|
|
"""对话 Agent"""
|
|||
|
|
|
|||
|
|
def __init__(self):
|
|||
|
|
# 初始化 LLM(使用环境变量配置)
|
|||
|
|
# 优先使用 LLM_API_KEY 和 LLM_BASE_URL,如果没有则使用 OPENAI_API_KEY
|
|||
|
|
api_key = os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY", "")
|
|||
|
|
base_url = os.getenv("LLM_BASE_URL", "")
|
|||
|
|
model_name = os.getenv("OPENAI_MODEL", "gpt-4o")
|
|||
|
|
|
|||
|
|
if not api_key:
|
|||
|
|
self.llm = None
|
|||
|
|
self.memories: dict[str, ConversationBufferMemory] = {}
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
# 如果提供了 base_url,需要处理路径(langchain 会自动添加 /v1/chat/completions)
|
|||
|
|
llm_kwargs = {
|
|||
|
|
"temperature": 0.7,
|
|||
|
|
"model": model_name,
|
|||
|
|
"openai_api_key": api_key,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if base_url:
|
|||
|
|
# 移除可能的 /v1/chat/completions 路径,langchain 会自动添加
|
|||
|
|
if base_url.endswith("/v1/chat/completions"):
|
|||
|
|
base_url = base_url[:-20] # 移除 "/v1/chat/completions"
|
|||
|
|
elif base_url.endswith("/v1"):
|
|||
|
|
base_url = base_url[:-3] # 移除 "/v1"
|
|||
|
|
# 确保 base_url 以 / 结尾(如果没有)
|
|||
|
|
if base_url and not base_url.endswith("/"):
|
|||
|
|
base_url += "/"
|
|||
|
|
llm_kwargs["openai_api_base"] = base_url
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
self.llm = ChatOpenAI(**llm_kwargs)
|
|||
|
|
except Exception:
|
|||
|
|
self.llm = None
|
|||
|
|
|
|||
|
|
# 对话记忆
|
|||
|
|
self.memories: dict[str, ConversationBufferMemory] = {}
|
|||
|
|
|
|||
|
|
# 对话记忆
|
|||
|
|
self.memories: dict[str, ConversationBufferMemory] = {}
|
|||
|
|
|
|||
|
|
def _get_memory(self, conversation_id: str) -> ConversationBufferMemory:
|
|||
|
|
"""获取或创建对话记忆"""
|
|||
|
|
if conversation_id not in self.memories:
|
|||
|
|
self.memories[conversation_id] = ConversationBufferMemory(
|
|||
|
|
return_messages=True,
|
|||
|
|
memory_key="history"
|
|||
|
|
)
|
|||
|
|
return self.memories[conversation_id]
|
|||
|
|
|
|||
|
|
def generate_response(
|
|||
|
|
self,
|
|||
|
|
conversation_id: str,
|
|||
|
|
user_message: str,
|
|||
|
|
current_stage: Optional[ConversationStage] = None,
|
|||
|
|
covered_topics: Optional[List[str]] = None
|
|||
|
|
) -> str:
|
|||
|
|
"""
|
|||
|
|
生成 Agent 回应
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
conversation_id: 对话 ID
|
|||
|
|
user_message: 用户消息
|
|||
|
|
current_stage: 当前对话阶段
|
|||
|
|
covered_topics: 已聊过的话题列表
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
Agent 回应文本
|
|||
|
|
"""
|
|||
|
|
if current_stage is None:
|
|||
|
|
current_stage = ConversationStage.CHILDHOOD
|
|||
|
|
|
|||
|
|
if covered_topics is None:
|
|||
|
|
covered_topics = []
|
|||
|
|
|
|||
|
|
# 如果没有配置 LLM,返回默认回应
|
|||
|
|
if not self.llm:
|
|||
|
|
return "抱歉,LLM 服务未配置。请设置 LLM_API_KEY 或 OPENAI_API_KEY 环境变量。"
|
|||
|
|
|
|||
|
|
# 获取系统提示词
|
|||
|
|
system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
|
|||
|
|
|
|||
|
|
# 获取对话记忆
|
|||
|
|
memory = self._get_memory(conversation_id)
|
|||
|
|
|
|||
|
|
# 创建对话链
|
|||
|
|
prompt_template = PromptTemplate(
|
|||
|
|
input_variables=["history", "input"],
|
|||
|
|
template=f"{system_prompt}\n\n{{history}}\n\nHuman: {{input}}\n\nAssistant:"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
chain = ConversationChain(
|
|||
|
|
llm=self.llm,
|
|||
|
|
prompt=prompt_template,
|
|||
|
|
memory=memory,
|
|||
|
|
verbose=False
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 生成回应
|
|||
|
|
response = chain.predict(input=user_message)
|
|||
|
|
|
|||
|
|
return response
|
|||
|
|
|
|||
|
|
def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage:
|
|||
|
|
"""
|
|||
|
|
检测对话阶段
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
conversation_id: 对话 ID
|
|||
|
|
user_message: 用户消息
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
检测到的对话阶段
|
|||
|
|
"""
|
|||
|
|
# 简单的关键词检测(实际应该使用更智能的方法)
|
|||
|
|
message_lower = user_message.lower()
|
|||
|
|
|
|||
|
|
if any(word in message_lower for word in ["童年", "小时候", "出生", "家庭背景"]):
|
|||
|
|
return ConversationStage.CHILDHOOD
|
|||
|
|
elif any(word in message_lower for word in ["上学", "学校", "老师", "同学", "教育"]):
|
|||
|
|
return ConversationStage.EDUCATION
|
|||
|
|
elif any(word in message_lower for word in ["工作", "职业", "事业", "公司", "同事"]):
|
|||
|
|
return ConversationStage.CAREER
|
|||
|
|
elif any(word in message_lower for word in ["伴侣", "孩子", "家庭", "家人", "结婚"]):
|
|||
|
|
return ConversationStage.FAMILY
|
|||
|
|
elif any(word in message_lower for word in ["信念", "价值观", "座右铭", "坚持", "原则"]):
|
|||
|
|
return ConversationStage.BELIEFS
|
|||
|
|
elif any(word in message_lower for word in ["总结", "回顾", "感激", "希望", "未来"]):
|
|||
|
|
return ConversationStage.SUMMARY
|
|||
|
|
else:
|
|||
|
|
# 默认返回当前阶段或童年阶段
|
|||
|
|
return ConversationStage.CHILDHOOD
|
|||
|
|
|
|||
|
|
def clear_memory(self, conversation_id: str):
|
|||
|
|
"""清除对话记忆"""
|
|||
|
|
if conversation_id in self.memories:
|
|||
|
|
del self.memories[conversation_id]
|
|||
|
|
|