Files
life-echo/api/agents/conversation_agent.py
2026-01-07 11:56:53 +08:00

156 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
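Conversation agent: dynamically selects questions from the interview question list and generates responses in real time.
对话 Agent：基于访谈问题清单，动态选择问题，实时生成回应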
"""
import logging
import os
from typing import List, Optional

from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

from .prompts import ConversationStage, get_conversation_prompt
class ConversationAgent:
    """Conversation agent that drives an interview-style dialogue through an LLM.

    One ``ConversationBufferMemory`` is kept per conversation ID in
    ``self.memories``. When no API key is configured, or the chat model cannot
    be constructed, ``self.llm`` is ``None`` and ``generate_response`` returns a
    fixed "not configured" message instead of calling a model.
    """

    # Ordered (stage, keywords) table used by ``detect_stage``. The first stage
    # with a matching keyword wins, so order matters: e.g. "家庭背景" is matched
    # as CHILDHOOD before FAMILY's broader keyword "家庭" is considered.
    _STAGE_KEYWORDS = (
        (ConversationStage.CHILDHOOD, ("童年", "小时候", "出生", "家庭背景")),
        (ConversationStage.EDUCATION, ("上学", "学校", "老师", "同学", "教育")),
        (ConversationStage.CAREER, ("工作", "职业", "事业", "公司", "同事")),
        (ConversationStage.FAMILY, ("伴侣", "孩子", "家庭", "家人", "结婚")),
        (ConversationStage.BELIEFS, ("信念", "价值观", "座右铭", "坚持", "原则")),
        (ConversationStage.SUMMARY, ("总结", "回顾", "感激", "希望", "未来")),
    )

    def __init__(self):
        """Configure the chat model from environment variables.

        Reads ``LLM_API_KEY`` (falling back to ``OPENAI_API_KEY``),
        ``LLM_BASE_URL`` and ``OPENAI_MODEL`` (default ``"gpt-4o"``). Without
        an API key, ``self.llm`` stays ``None``.
        """
        # Per-conversation memory, keyed by conversation ID.
        self.memories: dict[str, ConversationBufferMemory] = {}
        self.llm = None

        # LLM_API_KEY takes precedence; OPENAI_API_KEY is the fallback.
        api_key = os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY", "")
        base_url = os.getenv("LLM_BASE_URL", "")
        model_name = os.getenv("OPENAI_MODEL", "gpt-4o")
        if not api_key:
            return

        llm_kwargs = {
            "temperature": 0.7,
            "model": model_name,
            "openai_api_key": api_key,
        }
        normalized_base_url = self._normalize_base_url(base_url)
        if normalized_base_url:
            llm_kwargs["openai_api_base"] = normalized_base_url

        try:
            self.llm = ChatOpenAI(**llm_kwargs)
        except Exception:
            # Best effort: the agent stays usable (it answers with the
            # "not configured" message), but record why the model is missing
            # instead of failing silently.
            logging.getLogger(__name__).exception(
                "Failed to initialise ChatOpenAI; LLM responses are disabled"
            )
            self.llm = None

    @staticmethod
    def _normalize_base_url(base_url: str) -> str:
        """Normalise an OpenAI-compatible base URL for ``ChatOpenAI``.

        The OpenAI client appends ``/chat/completions`` to its base URL, so a
        full endpoint URL has that suffix removed. A ``/v1`` segment is kept:
        it is part of the API base (the OpenAI default is
        ``https://api.openai.com/v1``), and stripping it would send requests
        to ``<host>/chat/completions``.

        Args:
            base_url: Raw value of ``LLM_BASE_URL`` (may be empty).

        Returns:
            The URL with a single trailing ``/``, or ``""`` if nothing remains.
        """
        url = base_url.strip().rstrip("/")
        endpoint_suffix = "/chat/completions"
        if url.endswith(endpoint_suffix):
            url = url[: -len(endpoint_suffix)]
        return f"{url}/" if url else ""

    def _get_memory(self, conversation_id: str) -> ConversationBufferMemory:
        """Return the memory for ``conversation_id``, creating it on first use."""
        if conversation_id not in self.memories:
            self.memories[conversation_id] = ConversationBufferMemory(
                return_messages=True,
                memory_key="history"
            )
        return self.memories[conversation_id]

    def generate_response(
        self,
        conversation_id: str,
        user_message: str,
        current_stage: Optional[ConversationStage] = None,
        covered_topics: Optional[List[str]] = None
    ) -> str:
        """Generate the agent's reply to ``user_message``.

        Args:
            conversation_id: Conversation ID; selects the memory to use.
            user_message: The user's latest message.
            current_stage: Current conversation stage (default: CHILDHOOD).
            covered_topics: Topics already discussed (default: empty list).

        Returns:
            The reply text, or a fixed message when no LLM is configured.
        """
        if current_stage is None:
            current_stage = ConversationStage.CHILDHOOD
        if covered_topics is None:
            covered_topics = []

        # No LLM configured: return a fixed explanatory reply.
        if not self.llm:
            return "抱歉LLM 服务未配置。请设置 LLM_API_KEY 或 OPENAI_API_KEY 环境变量。"

        system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
        # PromptTemplate parses its template in f-string format. The system
        # prompt is built with user_message (and topics), so it may contain
        # literal braces; escape them so they are not treated as template
        # variables, which would break formatting.
        escaped_prompt = system_prompt.replace("{", "{{").replace("}", "}}")
        prompt_template = PromptTemplate(
            input_variables=["history", "input"],
            template=f"{escaped_prompt}\n\n{{history}}\n\nHuman: {{input}}\n\nAssistant:",
        )
        chain = ConversationChain(
            llm=self.llm,
            prompt=prompt_template,
            memory=self._get_memory(conversation_id),
            verbose=False,
        )
        return chain.predict(input=user_message)

    def detect_stage(
        self,
        conversation_id: str,
        user_message: str,
        current_stage: Optional[ConversationStage] = None,
    ) -> ConversationStage:
        """Detect the conversation stage from keywords in ``user_message``.

        Simple substring keyword matching; a smarter classifier would be
        better. ``conversation_id`` is currently unused.

        Args:
            conversation_id: Conversation ID (unused).
            user_message: The user's message.
            current_stage: Stage to keep when no keyword matches. If omitted,
                CHILDHOOD is returned in that case.

        Returns:
            The first stage in ``_STAGE_KEYWORDS`` whose keyword occurs in the
            message, otherwise ``current_stage`` or CHILDHOOD.
        """
        message_lower = user_message.lower()
        for stage, keywords in self._STAGE_KEYWORDS:
            if any(word in message_lower for word in keywords):
                return stage
        # No keyword matched: stay in the caller's stage, else start at childhood.
        if current_stage is not None:
            return current_stage
        return ConversationStage.CHILDHOOD

    def clear_memory(self, conversation_id: str):
        """Discard the memory for ``conversation_id``; unknown IDs are ignored."""
        self.memories.pop(conversation_id, None)