176 lines
6.0 KiB
Python
176 lines
6.0 KiB
Python
"""
对话 Agent:基于访谈问题清单,动态选择问题,实时生成回应
"""
|
||
from typing import List, Optional
|
||
|
||
from langchain.chains import ConversationChain
|
||
from langchain.memory import ConversationBufferMemory
|
||
from langchain.prompts import PromptTemplate
|
||
|
||
from services.llm_service import llm_service
|
||
from .prompts import ConversationStage, get_conversation_prompt, get_guided_conversation_prompt
|
||
from .state_schema import MemoirStateSchema
|
||
|
||
|
||
class ConversationAgent:
    """Conversation agent: selects interview questions and generates replies.

    One ``ConversationBufferMemory`` is kept per ``conversation_id`` in
    ``self.memories`` so that successive turns of the same conversation share
    history.
    """

    # Reply returned when no LLM instance is available.
    _LLM_NOT_CONFIGURED_MSG = "抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"

    # Tail of every prompt: stored history, the new user turn, and the cue for
    # the model's reply. This is deliberately NOT an f-string: the braces are
    # PromptTemplate variables that ConversationChain fills in.
    _TURN_TEMPLATE = "\n\n{history}\n\nHuman: {input}\n\nAssistant:"

    # Separator the guided prompt may use to split one LLM reply into several
    # chat messages, and the maximum number of such messages returned per turn.
    SPLIT_MARKER = "[SPLIT]"
    MAX_SPLIT_MESSAGES = 3

    # Ordered (stage, keywords) table used by detect_stage. Order matters: the
    # first stage with a matching keyword wins, e.g. "家庭背景" maps to
    # CHILDHOOD before the broader "家庭" can map to FAMILY.
    _STAGE_KEYWORDS = (
        (ConversationStage.CHILDHOOD, ("童年", "小时候", "出生", "家庭背景")),
        (ConversationStage.EDUCATION, ("上学", "学校", "老师", "同学", "教育")),
        (ConversationStage.CAREER, ("工作", "职业", "事业", "公司", "同事")),
        (ConversationStage.FAMILY, ("伴侣", "孩子", "家庭", "家人", "结婚")),
        (ConversationStage.BELIEFS, ("信念", "价值观", "座右铭", "坚持", "原则")),
        (ConversationStage.SUMMARY, ("总结", "回顾", "感激", "希望", "未来")),
    )

    def __init__(self):
        # LLM instance from the shared LLM service. The public methods treat a
        # falsy value as "LLM not configured" and return a fixed message.
        self.llm = llm_service.get_llm()

        # Conversation memory, keyed by conversation_id.
        self.memories: dict[str, ConversationBufferMemory] = {}

    def _get_memory(self, conversation_id: str) -> ConversationBufferMemory:
        """Return the memory for ``conversation_id``, creating it on first use."""
        memory = self.memories.get(conversation_id)
        if memory is None:
            memory = ConversationBufferMemory(
                # The prompt is a plain string template, so history must be
                # rendered as text ("Human: ...\nAssistant: ..."). With
                # return_messages=True, {history} would receive a list of
                # message objects and be interpolated as its Python repr.
                return_messages=False,
                memory_key="history",
                # Match the "Assistant:" cue used in _TURN_TEMPLATE.
                ai_prefix="Assistant",
            )
            self.memories[conversation_id] = memory
        return memory

    @staticmethod
    def _escape_braces(text: str) -> str:
        """Escape ``{`` and ``}`` so *text* is literal inside an f-string PromptTemplate."""
        return text.replace("{", "{{").replace("}", "}}")

    def _run_chain(self, conversation_id: str, system_prompt: str, user_message: str) -> str:
        """Run one conversation turn through a ConversationChain.

        Args:
            conversation_id: Conversation ID whose memory is used.
            system_prompt: Instructions placed before the stored history.
            user_message: The user's message for this turn.

        Returns:
            The raw reply text produced by the LLM.
        """
        memory = self._get_memory(conversation_id)

        # The system prompt is built from dynamic text (the prompt builders
        # receive user_message). Literal braces in it must be escaped;
        # otherwise PromptTemplate parses them as variables and fails at
        # format time.
        # NOTE(review): assumes the prompt builders return final text, not a
        # pre-escaped template; confirm in .prompts.
        template = self._escape_braces(system_prompt) + self._TURN_TEMPLATE
        prompt_template = PromptTemplate(
            input_variables=["history", "input"],
            template=template,
        )

        chain = ConversationChain(
            llm=self.llm,
            prompt=prompt_template,
            memory=memory,
            verbose=False,
        )
        return chain.predict(input=user_message)

    def _split_reply(self, response: str) -> List[str]:
        """Split an LLM reply on ``SPLIT_MARKER`` into at most ``MAX_SPLIT_MESSAGES`` parts.

        Blank parts are dropped. If nothing non-blank remains, the raw reply is
        returned as the single message.
        """
        parts = [part.strip() for part in response.split(self.SPLIT_MARKER) if part.strip()]
        return parts[: self.MAX_SPLIT_MESSAGES] if parts else [response]

    def generate_response(
        self,
        conversation_id: str,
        user_message: str,
        current_stage: Optional[ConversationStage] = None,
        covered_topics: Optional[List[str]] = None
    ) -> str:
        """Generate the agent's reply.

        Args:
            conversation_id: Conversation ID.
            user_message: The user's message.
            current_stage: Current conversation stage; defaults to
                ``ConversationStage.CHILDHOOD``.
            covered_topics: Topics already discussed; defaults to an empty list.

        Returns:
            The agent's reply text, or a fixed notice if no LLM is configured.
        """
        if current_stage is None:
            current_stage = ConversationStage.CHILDHOOD
        if covered_topics is None:
            covered_topics = []

        # Without an LLM, return a fixed notice instead of calling the chain.
        if not self.llm:
            return self._LLM_NOT_CONFIGURED_MSG

        system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
        return self._run_chain(conversation_id, system_prompt, user_message)

    def generate_response_with_state(
        self,
        conversation_id: str,
        user_message: str,
        memoir_state: MemoirStateSchema
    ) -> List[str]:
        """Generate guided replies based on the shared memoir state.

        The prompt is built from the slots of the state's current stage: empty
        slots to ask about, and filled slots (only those with a non-empty
        ``snippet``).

        Args:
            conversation_id: Conversation ID.
            user_message: The user's message.
            memoir_state: Shared state.

        Returns:
            List of reply messages. The LLM reply is split on ``SPLIT_MARKER``
            and capped at ``MAX_SPLIT_MESSAGES`` entries.
        """
        if not self.llm:
            return [self._LLM_NOT_CONFIGURED_MSG]

        stage = memoir_state.current_stage
        empty_slots = memoir_state.empty_slots_for_current_stage()
        filled_slots = {
            key: value.snippet
            for key, value in memoir_state.slots.get(stage, {}).items()
            if value.snippet
        }

        system_prompt = get_guided_conversation_prompt(
            current_stage=stage,
            empty_slots=empty_slots,
            filled_slots=filled_slots,
            user_message=user_message,
        )

        response = self._run_chain(conversation_id, system_prompt, user_message)
        return self._split_reply(response)

    def detect_stage(
        self,
        conversation_id: str,
        user_message: str,
        default: Optional[ConversationStage] = None,
    ) -> ConversationStage:
        """Guess the conversation stage from keywords in the user's message.

        This is a simple keyword heuristic; a smarter classifier would be more
        accurate. Stages are checked in a fixed order, and the first one with a
        matching keyword is returned.

        Args:
            conversation_id: Conversation ID (currently unused; kept for
                interface compatibility).
            user_message: The user's message.
            default: Stage to return when no keyword matches, e.g. the caller's
                current stage. ``None`` falls back to
                ``ConversationStage.CHILDHOOD``, as before.

        Returns:
            The detected conversation stage.
        """
        # lower() has no effect on the (Chinese) keywords; it only matters for
        # Latin letters in the message.
        text = user_message.lower()
        for stage, keywords in self._STAGE_KEYWORDS:
            if any(keyword in text for keyword in keywords):
                return stage
        return ConversationStage.CHILDHOOD if default is None else default

    def clear_memory(self, conversation_id: str):
        """Drop the memory for ``conversation_id``; a no-op if none exists."""
        self.memories.pop(conversation_id, None)
|
||
|