Files
life-echo/api/agents/conversation_agent.py

405 lines
18 KiB
Python
Raw Normal View History

2026-01-07 11:56:53 +08:00
"""
对话 Agent基于访谈问题清单动态选择问题实时生成回应
支持异步调用和 Redis 会话存储
支持用户基础资料收集和时代背景融入
2026-01-07 11:56:53 +08:00
"""
import json
import logging
from typing import List, Optional, Dict, Any
2026-01-07 11:56:53 +08:00
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
2026-01-07 11:56:53 +08:00
from services.llm_service import llm_service
from services.redis_service import redis_service
from .prompts import ConversationStage, get_conversation_prompt, get_guided_conversation_prompt, get_opening_prompt
from .prompts.profile_prompts import (
get_profile_greeting_prompt,
get_profile_extraction_prompt,
get_profile_followup_prompt,
format_user_profile_context,
get_missing_profile_fields,
)
from .prompts.conversation_prompts import SLOT_NAME_MAP
2026-01-21 22:31:03 +01:00
from .state_schema import MemoirStateSchema
2026-01-07 11:56:53 +08:00
logger = logging.getLogger(__name__)
2026-01-07 11:56:53 +08:00
class ConversationAgent:
"""对话 Agent支持异步和 Redis 存储)"""
2026-01-07 11:56:53 +08:00
def __init__(self):
# 使用 LLM 服务获取 LLM 实例
self.llm = llm_service.get_llm()
2026-01-07 11:56:53 +08:00
async def _get_history_messages(self, conversation_id: str) -> List[Any]:
"""从 Redis 获取对话历史并转换为 LangChain 消息格式"""
history = await redis_service.get_conversation_history(conversation_id)
messages = []
for msg in history:
if msg["role"] == "human":
messages.append(HumanMessage(content=msg["content"]))
elif msg["role"] == "ai":
messages.append(AIMessage(content=msg["content"]))
return messages
async def _save_message(self, conversation_id: str, role: str, content: str):
"""保存消息到 Redis"""
await redis_service.add_message(conversation_id, role, content)
2026-01-07 11:56:53 +08:00
def _format_history_string(self, messages: List[Any]) -> str:
"""将消息列表格式化为字符串(用于 prompt"""
history_parts = []
for msg in messages:
if isinstance(msg, HumanMessage):
history_parts.append(f"Human: {msg.content}")
elif isinstance(msg, AIMessage):
history_parts.append(f"Assistant: {msg.content}")
return "\n\n".join(history_parts)
async def generate_response(
2026-01-07 11:56:53 +08:00
self,
conversation_id: str,
user_message: str,
current_stage: Optional[ConversationStage] = None,
covered_topics: Optional[List[str]] = None
) -> str:
"""
异步生成 Agent 回应
2026-01-07 11:56:53 +08:00
Args:
conversation_id: 对话 ID
user_message: 用户消息
current_stage: 当前对话阶段
covered_topics: 已聊过的话题列表
Returns:
Agent 回应文本
"""
if current_stage is None:
current_stage = ConversationStage.CHILDHOOD
if covered_topics is None:
covered_topics = []
# 如果没有配置 LLM返回默认回应
if not self.llm:
return "抱歉LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"
2026-01-07 11:56:53 +08:00
try:
# 获取系统提示词
system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
# 从 Redis 获取对话历史
history_messages = await self._get_history_messages(conversation_id)
history_string = self._format_history_string(history_messages)
# 构建完整 prompt
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
# 异步调用 LLM
response = await self.llm.ainvoke(full_prompt)
response_text = response.content if hasattr(response, 'content') else str(response)
# 保存对话到 Redis
await self._save_message(conversation_id, "human", user_message)
await self._save_message(conversation_id, "ai", response_text)
return response_text
except Exception as e:
logger.error(f"生成回应失败: {e}")
return f"抱歉,生成回应时出现错误: {str(e)}"
2026-01-21 22:31:03 +01:00
async def generate_profile_greeting(
self,
conversation_id: str,
missing_fields: List[str],
nickname: str = "",
) -> List[str]:
"""生成资料收集的开场白(首次对话时使用)"""
if not self.llm:
return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"]
try:
prompt = get_profile_greeting_prompt(missing_fields, nickname)
history_messages = await self._get_history_messages(conversation_id)
history_string = self._format_history_string(history_messages)
full_prompt = f"{prompt}\n\n{history_string}" if history_string else prompt
response = await self.llm.ainvoke(full_prompt)
response_text = response.content if hasattr(response, 'content') else str(response)
await self._save_message(conversation_id, "ai", response_text)
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
return messages[:3] if messages else [response_text]
except Exception as e:
logger.error(f"生成资料收集开场白失败: {e}")
return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"]
async def generate_opening_message(
self,
conversation_id: str,
memoir_state: MemoirStateSchema,
user_profile_context: str = "",
) -> List[str]:
"""
空对话时 AI 先开口用户通过打个招呼进入尚未发任何消息
生成问候 + 一个引导问题写入 Redis 并返回消息列表
"""
if not self.llm:
return ["你好呀~ 有空聊聊你的人生故事吗?你小时候是在哪儿长大的?"]
try:
empty_slots = memoir_state.empty_slots_for_current_stage()
empty_slots_readable = [SLOT_NAME_MAP.get(s, s) for s in empty_slots]
if not empty_slots_readable:
empty_slots_readable = ["成长的地方", "难忘的事", "重要的人"]
prompt = get_opening_prompt(
current_stage=memoir_state.current_stage,
empty_slots_readable=empty_slots_readable,
user_profile_context=user_profile_context,
)
full_prompt = f"{prompt}\n\nAssistant:"
response = await self.llm.ainvoke(full_prompt)
response_text = response.content if hasattr(response, "content") else str(response)
await self._save_message(conversation_id, "ai", response_text)
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
return messages[:3] if messages else [response_text]
except Exception as e:
logger.error(f"生成开场白失败: {e}", exc_info=True)
return ["你好呀~ 有空聊聊你的人生故事吗?你童年里印象最深的一件事是什么?"]
async def extract_profile_from_message(
self,
user_message: str,
missing_fields: List[str],
conversation_id: Optional[str] = None,
) -> Dict[str, Any]:
"""从用户消息中提取基础资料信息;若提供 conversation_id会结合最近几轮对话一起提取避免漏提。"""
if not self.llm or not missing_fields:
return {}
recent_dialogue = ""
if conversation_id:
history_messages = await self._get_history_messages(conversation_id)
# 取最近 4 条2 轮),不包含本轮;本轮由 user_message 单独传入
recent = history_messages[-4:] if len(history_messages) > 4 else history_messages
parts = []
for msg in recent:
if isinstance(msg, HumanMessage):
parts.append(f"用户: {msg.content}")
elif isinstance(msg, AIMessage):
parts.append(f"助手: {msg.content}")
recent_dialogue = "\n".join(parts) if parts else ""
try:
prompt = get_profile_extraction_prompt(
user_message, missing_fields, recent_dialogue=recent_dialogue or None
)
response = await self.llm.ainvoke(prompt)
content = response.content.strip()
parsed = json.loads(content)
result = {}
if "birth_year" in parsed and parsed["birth_year"] is not None:
raw = parsed["birth_year"]
if isinstance(raw, int) and 1900 <= raw <= 2100:
result["birth_year"] = raw
elif isinstance(raw, str) and raw.isdigit():
y = int(raw)
if y < 100: # "65" -> 1965
y = 1900 + y if y >= 50 else 2000 + y
if 1900 <= y <= 2100:
result["birth_year"] = y
if "birth_place" in parsed and parsed["birth_place"]:
result["birth_place"] = str(parsed["birth_place"])
if "grew_up_place" in parsed and parsed["grew_up_place"]:
result["grew_up_place"] = str(parsed["grew_up_place"])
if "occupation" in parsed and parsed["occupation"]:
result["occupation"] = str(parsed["occupation"])
return result
except (json.JSONDecodeError, Exception) as e:
logger.error(f"提取资料信息失败: {e}")
return {}
async def generate_profile_followup(
self,
conversation_id: str,
user_message: str,
missing_fields: List[str],
filled_fields: Dict[str, str],
nickname: str = "",
) -> List[str]:
"""在资料收集过程中生成跟进回复"""
if not self.llm:
return ["谢谢!还能告诉我更多吗?"]
try:
prompt = get_profile_followup_prompt(missing_fields, filled_fields, user_message, nickname)
history_messages = await self._get_history_messages(conversation_id)
history_string = self._format_history_string(history_messages)
full_prompt = f"{prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
response = await self.llm.ainvoke(full_prompt)
response_text = response.content if hasattr(response, 'content') else str(response)
await self._save_message(conversation_id, "human", user_message)
await self._save_message(conversation_id, "ai", response_text)
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
return messages[:3] if messages else [response_text]
except Exception as e:
logger.error(f"生成资料跟进回复失败: {e}")
return ["谢谢分享!能再告诉我一些吗?"]
def _detect_user_stage(self, user_message: str) -> str:
"""
通过关键词检测用户当前正在谈论的人生阶段
返回阶段名称字符串未检测到返回空字符串
"""
message = user_message.lower()
stage_keywords = {
"childhood": ["童年", "小时候", "出生", "家乡", "小镇", "爸妈", "父亲", "母亲", "爷爷", "奶奶", "外公", "外婆", "幼儿园"],
"education": ["上学", "学校", "老师", "同学", "教育", "大学", "高中", "初中", "小学", "考试", "毕业", "读书", "高考", "课堂"],
"career": ["工作", "职业", "事业", "公司", "同事", "创业", "升职", "跳槽", "老板", "行业", "项目", "加班", "薪水", "面试"],
"family": ["伴侣", "孩子", "家庭", "家人", "结婚", "爱人", "老婆", "老公", "丈夫", "妻子", "儿子", "女儿", "婚礼", "恋爱"],
"belief": ["信念", "价值观", "座右铭", "坚持", "原则", "信仰", "意义", "感悟", "遗憾", "骄傲"],
}
for stage, keywords in stage_keywords.items():
if any(word in message for word in keywords):
return stage
return ""
async def generate_response_with_state(
2026-01-21 22:31:03 +01:00
self,
conversation_id: str,
user_message: str,
memoir_state: MemoirStateSchema,
user_profile_context: str = "",
2026-01-21 22:31:03 +01:00
) -> List[str]:
"""
基于共享状态异步生成引导式回复
2026-01-21 22:31:03 +01:00
Args:
conversation_id: 对话 ID
user_message: 用户消息
memoir_state: 共享状态
user_profile_context: 用户基础资料上下文
2026-01-21 22:31:03 +01:00
Returns:
Agent 回应文本列表支持多条消息
"""
if not self.llm:
return ["抱歉LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"]
try:
empty_slots = memoir_state.empty_slots_for_current_stage()
filled_slots = {
key: value.snippet
for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items()
if value.snippet
}
2026-01-21 22:31:03 +01:00
detected_user_stage = self._detect_user_stage(user_message)
history_messages = await self._get_history_messages(conversation_id)
conversation_turn = len(history_messages) // 2
same_topic_turns = self._estimate_same_topic_turns(history_messages, filled_slots)
all_stages_coverage = memoir_state.all_stages_coverage()
system_prompt = get_guided_conversation_prompt(
current_stage=memoir_state.current_stage,
empty_slots=empty_slots,
filled_slots=filled_slots,
user_message=user_message,
conversation_turn=conversation_turn,
same_topic_turns=same_topic_turns,
all_stages_coverage=all_stages_coverage,
detected_user_stage=detected_user_stage,
user_profile_context=user_profile_context,
)
2026-01-21 22:31:03 +01:00
history_string = self._format_history_string(history_messages)
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
response = await self.llm.ainvoke(full_prompt)
response_text = response.content if hasattr(response, 'content') else str(response)
await self._save_message(conversation_id, "human", user_message)
await self._save_message(conversation_id, "ai", response_text)
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
return messages[:3] if messages else [response_text]
except Exception as e:
logger.error(f"生成回应失败: {e}")
return [f"抱歉,生成回应时出现错误: {str(e)}"]
2026-01-07 11:56:53 +08:00
def _estimate_same_topic_turns(self, history_messages: List[Any], current_filled_slots: dict) -> int:
"""
估算同一话题的轮数
通过分析最近几轮对话来判断是否一直在同一个话题上
"""
if len(history_messages) < 4:
return len(history_messages) // 2
# 简单策略:检查最近的对话是否有重复关键词
recent_messages = history_messages[-6:] # 最近3轮
# 提取关键词(简单实现)
keywords_per_turn = []
for i in range(0, len(recent_messages), 2):
if i + 1 < len(recent_messages):
human_msg = recent_messages[i].content if hasattr(recent_messages[i], 'content') else str(recent_messages[i])
ai_msg = recent_messages[i+1].content if hasattr(recent_messages[i+1], 'content') else str(recent_messages[i+1])
combined = human_msg + ai_msg
keywords_per_turn.append(combined[:100]) # 取前100字作为特征
# 如果连续3轮都在讨论相似内容认为同一话题
if len(keywords_per_turn) >= 3:
return 3
return len(keywords_per_turn)
2026-01-07 11:56:53 +08:00
def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage:
"""
检测对话阶段
Args:
conversation_id: 对话 ID
user_message: 用户消息
Returns:
检测到的对话阶段
"""
# 简单的关键词检测(实际应该使用更智能的方法)
message_lower = user_message.lower()
if any(word in message_lower for word in ["童年", "小时候", "出生", "家庭背景"]):
return ConversationStage.CHILDHOOD
elif any(word in message_lower for word in ["上学", "学校", "老师", "同学", "教育"]):
return ConversationStage.EDUCATION
elif any(word in message_lower for word in ["工作", "职业", "事业", "公司", "同事"]):
return ConversationStage.CAREER
elif any(word in message_lower for word in ["伴侣", "孩子", "家庭", "家人", "结婚"]):
return ConversationStage.FAMILY
elif any(word in message_lower for word in ["信念", "价值观", "座右铭", "坚持", "原则"]):
return ConversationStage.BELIEFS
elif any(word in message_lower for word in ["总结", "回顾", "感激", "希望", "未来"]):
return ConversationStage.SUMMARY
else:
# 默认返回当前阶段或童年阶段
return ConversationStage.CHILDHOOD
async def clear_memory(self, conversation_id: str):
"""清除对话记忆(从 Redis"""
await redis_service.clear_conversation_history(conversation_id)
2026-01-07 11:56:53 +08:00