""" 对话 Agent:基于访谈问题清单,动态选择问题,实时生成回应 """ import os from typing import List, Optional from langchain.chains import ConversationChain from langchain.memory import ConversationBufferMemory from langchain.prompts import PromptTemplate from langchain_openai import ChatOpenAI from .prompts import ConversationStage, get_conversation_prompt class ConversationAgent: """对话 Agent""" def __init__(self): # 初始化 LLM(使用环境变量配置) # 优先使用 LLM_API_KEY 和 LLM_BASE_URL,如果没有则使用 OPENAI_API_KEY api_key = os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY", "") base_url = os.getenv("LLM_BASE_URL", "") model_name = os.getenv("OPENAI_MODEL", "gpt-4o") if not api_key: self.llm = None self.memories: dict[str, ConversationBufferMemory] = {} return # 如果提供了 base_url,需要处理路径(langchain 会自动添加 /v1/chat/completions) llm_kwargs = { "temperature": 0.7, "model": model_name, "openai_api_key": api_key, } if base_url: # 移除可能的 /v1/chat/completions 路径,langchain 会自动添加 if base_url.endswith("/v1/chat/completions"): base_url = base_url[:-20] # 移除 "/v1/chat/completions" elif base_url.endswith("/v1"): base_url = base_url[:-3] # 移除 "/v1" # 确保 base_url 以 / 结尾(如果没有) if base_url and not base_url.endswith("/"): base_url += "/" llm_kwargs["openai_api_base"] = base_url try: self.llm = ChatOpenAI(**llm_kwargs) except Exception: self.llm = None # 对话记忆 self.memories: dict[str, ConversationBufferMemory] = {} # 对话记忆 self.memories: dict[str, ConversationBufferMemory] = {} def _get_memory(self, conversation_id: str) -> ConversationBufferMemory: """获取或创建对话记忆""" if conversation_id not in self.memories: self.memories[conversation_id] = ConversationBufferMemory( return_messages=True, memory_key="history" ) return self.memories[conversation_id] def generate_response( self, conversation_id: str, user_message: str, current_stage: Optional[ConversationStage] = None, covered_topics: Optional[List[str]] = None ) -> str: """ 生成 Agent 回应 Args: conversation_id: 对话 ID user_message: 用户消息 current_stage: 当前对话阶段 covered_topics: 已聊过的话题列表 Returns: Agent 回应文本 """ if current_stage is None: current_stage = ConversationStage.CHILDHOOD if covered_topics is None: covered_topics = [] # 如果没有配置 LLM,返回默认回应 if not self.llm: return "抱歉,LLM 服务未配置。请设置 LLM_API_KEY 或 OPENAI_API_KEY 环境变量。" # 获取系统提示词 system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message) # 获取对话记忆 memory = self._get_memory(conversation_id) # 创建对话链 prompt_template = PromptTemplate( input_variables=["history", "input"], template=f"{system_prompt}\n\n{{history}}\n\nHuman: {{input}}\n\nAssistant:" ) chain = ConversationChain( llm=self.llm, prompt=prompt_template, memory=memory, verbose=False ) # 生成回应 response = chain.predict(input=user_message) return response def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage: """ 检测对话阶段 Args: conversation_id: 对话 ID user_message: 用户消息 Returns: 检测到的对话阶段 """ # 简单的关键词检测(实际应该使用更智能的方法) message_lower = user_message.lower() if any(word in message_lower for word in ["童年", "小时候", "出生", "家庭背景"]): return ConversationStage.CHILDHOOD elif any(word in message_lower for word in ["上学", "学校", "老师", "同学", "教育"]): return ConversationStage.EDUCATION elif any(word in message_lower for word in ["工作", "职业", "事业", "公司", "同事"]): return ConversationStage.CAREER elif any(word in message_lower for word in ["伴侣", "孩子", "家庭", "家人", "结婚"]): return ConversationStage.FAMILY elif any(word in message_lower for word in ["信念", "价值观", "座右铭", "坚持", "原则"]): return ConversationStage.BELIEFS elif any(word in message_lower for word in ["总结", "回顾", "感激", "希望", "未来"]): return ConversationStage.SUMMARY else: # 默认返回当前阶段或童年阶段 return ConversationStage.CHILDHOOD def clear_memory(self, conversation_id: str): """清除对话记忆""" if conversation_id in self.memories: del self.memories[conversation_id]