Files
life-echo/api/agents/conversation_agent.py
penghanyuan 7fe0b70d5c feat: 增强对话代理以检测用户阶段并更新章节排序
- 在 api/agents/conversation_agent.py 中添加 _detect_user_stage 方法,以通过关键词检测用户谈论的人生阶段。
- 在 api/agents/memory_agent.py 中更新章节排序逻辑,使用 STAGE_TO_ORDER 替代 CHAPTER_ORDER。
- 在 api/agents/state_schema.py 中添加方法以获取各阶段的填充情况。
- 在 api/agents/prompts/conversation_prompts.py 中更新对话提示,包含用户阶段检测和整体进度信息。
- 在 api/migrations/fix_chapter_order_index.sql 中添加 SQL 脚本以修复章节 order_index 的问题。
- 更新相关文档和提示以反映新功能。
2026-02-13 21:45:56 +01:00

260 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
对话 Agent基于访谈问题清单动态选择问题实时生成回应
支持异步调用和 Redis 会话存储
"""
import logging
from typing import List, Optional, Dict, Any
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from services.llm_service import llm_service
from services.redis_service import redis_service
from .prompts import ConversationStage, get_conversation_prompt, get_guided_conversation_prompt
from .state_schema import MemoirStateSchema
logger = logging.getLogger(__name__)
class ConversationAgent:
"""对话 Agent支持异步和 Redis 存储)"""
def __init__(self):
# 使用 LLM 服务获取 LLM 实例
self.llm = llm_service.get_llm()
async def _get_history_messages(self, conversation_id: str) -> List[Any]:
"""从 Redis 获取对话历史并转换为 LangChain 消息格式"""
history = await redis_service.get_conversation_history(conversation_id)
messages = []
for msg in history:
if msg["role"] == "human":
messages.append(HumanMessage(content=msg["content"]))
elif msg["role"] == "ai":
messages.append(AIMessage(content=msg["content"]))
return messages
async def _save_message(self, conversation_id: str, role: str, content: str):
"""保存消息到 Redis"""
await redis_service.add_message(conversation_id, role, content)
def _format_history_string(self, messages: List[Any]) -> str:
"""将消息列表格式化为字符串(用于 prompt"""
history_parts = []
for msg in messages:
if isinstance(msg, HumanMessage):
history_parts.append(f"Human: {msg.content}")
elif isinstance(msg, AIMessage):
history_parts.append(f"Assistant: {msg.content}")
return "\n\n".join(history_parts)
async def generate_response(
self,
conversation_id: str,
user_message: str,
current_stage: Optional[ConversationStage] = None,
covered_topics: Optional[List[str]] = None
) -> str:
"""
异步生成 Agent 回应
Args:
conversation_id: 对话 ID
user_message: 用户消息
current_stage: 当前对话阶段
covered_topics: 已聊过的话题列表
Returns:
Agent 回应文本
"""
if current_stage is None:
current_stage = ConversationStage.CHILDHOOD
if covered_topics is None:
covered_topics = []
# 如果没有配置 LLM返回默认回应
if not self.llm:
return "抱歉LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"
try:
# 获取系统提示词
system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
# 从 Redis 获取对话历史
history_messages = await self._get_history_messages(conversation_id)
history_string = self._format_history_string(history_messages)
# 构建完整 prompt
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
# 异步调用 LLM
response = await self.llm.ainvoke(full_prompt)
response_text = response.content if hasattr(response, 'content') else str(response)
# 保存对话到 Redis
await self._save_message(conversation_id, "human", user_message)
await self._save_message(conversation_id, "ai", response_text)
return response_text
except Exception as e:
logger.error(f"生成回应失败: {e}")
return f"抱歉,生成回应时出现错误: {str(e)}"
def _detect_user_stage(self, user_message: str) -> str:
"""
通过关键词检测用户当前正在谈论的人生阶段。
返回阶段名称字符串,未检测到返回空字符串。
"""
message = user_message.lower()
stage_keywords = {
"childhood": ["童年", "小时候", "出生", "家乡", "小镇", "爸妈", "父亲", "母亲", "爷爷", "奶奶", "外公", "外婆", "幼儿园"],
"education": ["上学", "学校", "老师", "同学", "教育", "大学", "高中", "初中", "小学", "考试", "毕业", "读书", "高考", "课堂"],
"career": ["工作", "职业", "事业", "公司", "同事", "创业", "升职", "跳槽", "老板", "行业", "项目", "加班", "薪水", "面试"],
"family": ["伴侣", "孩子", "家庭", "家人", "结婚", "爱人", "老婆", "老公", "丈夫", "妻子", "儿子", "女儿", "婚礼", "恋爱"],
"belief": ["信念", "价值观", "座右铭", "坚持", "原则", "信仰", "意义", "感悟", "遗憾", "骄傲"],
}
for stage, keywords in stage_keywords.items():
if any(word in message for word in keywords):
return stage
return ""
async def generate_response_with_state(
self,
conversation_id: str,
user_message: str,
memoir_state: MemoirStateSchema
) -> List[str]:
"""
基于共享状态异步生成引导式回复
Args:
conversation_id: 对话 ID
user_message: 用户消息
memoir_state: 共享状态
Returns:
Agent 回应文本列表(支持多条消息)
"""
if not self.llm:
return ["抱歉LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"]
try:
empty_slots = memoir_state.empty_slots_for_current_stage()
filled_slots = {
key: value.snippet
for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items()
if value.snippet
}
# 检测用户当前正在谈论的阶段
detected_user_stage = self._detect_user_stage(user_message)
# 从 Redis 获取对话历史,用于计算对话轮数
history_messages = await self._get_history_messages(conversation_id)
conversation_turn = len(history_messages) // 2 # 每轮包括一个用户消息和一个AI回复
# 计算同一话题的轮数(简单估算:基于已填充槽位的变化)
# 如果槽位数量没有增加,说明还在同一话题深入
same_topic_turns = self._estimate_same_topic_turns(history_messages, filled_slots)
# 获取所有阶段的覆盖情况
all_stages_coverage = memoir_state.all_stages_coverage()
system_prompt = get_guided_conversation_prompt(
current_stage=memoir_state.current_stage,
empty_slots=empty_slots,
filled_slots=filled_slots,
user_message=user_message,
conversation_turn=conversation_turn,
same_topic_turns=same_topic_turns,
all_stages_coverage=all_stages_coverage,
detected_user_stage=detected_user_stage,
)
history_string = self._format_history_string(history_messages)
# 构建完整 prompt
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
# 异步调用 LLM
response = await self.llm.ainvoke(full_prompt)
response_text = response.content if hasattr(response, 'content') else str(response)
# 保存对话到 Redis
await self._save_message(conversation_id, "human", user_message)
await self._save_message(conversation_id, "ai", response_text)
# 支持多条消息,用 [SPLIT] 分隔
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
# 最多返回 3 条
return messages[:3] if messages else [response_text]
except Exception as e:
logger.error(f"生成回应失败: {e}")
return [f"抱歉,生成回应时出现错误: {str(e)}"]
def _estimate_same_topic_turns(self, history_messages: List[Any], current_filled_slots: dict) -> int:
"""
估算同一话题的轮数
通过分析最近几轮对话来判断是否一直在同一个话题上
"""
if len(history_messages) < 4:
return len(history_messages) // 2
# 简单策略:检查最近的对话是否有重复关键词
recent_messages = history_messages[-6:] # 最近3轮
# 提取关键词(简单实现)
keywords_per_turn = []
for i in range(0, len(recent_messages), 2):
if i + 1 < len(recent_messages):
human_msg = recent_messages[i].content if hasattr(recent_messages[i], 'content') else str(recent_messages[i])
ai_msg = recent_messages[i+1].content if hasattr(recent_messages[i+1], 'content') else str(recent_messages[i+1])
combined = human_msg + ai_msg
keywords_per_turn.append(combined[:100]) # 取前100字作为特征
# 如果连续3轮都在讨论相似内容认为同一话题
if len(keywords_per_turn) >= 3:
return 3
return len(keywords_per_turn)
def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage:
"""
检测对话阶段
Args:
conversation_id: 对话 ID
user_message: 用户消息
Returns:
检测到的对话阶段
"""
# 简单的关键词检测(实际应该使用更智能的方法)
message_lower = user_message.lower()
if any(word in message_lower for word in ["童年", "小时候", "出生", "家庭背景"]):
return ConversationStage.CHILDHOOD
elif any(word in message_lower for word in ["上学", "学校", "老师", "同学", "教育"]):
return ConversationStage.EDUCATION
elif any(word in message_lower for word in ["工作", "职业", "事业", "公司", "同事"]):
return ConversationStage.CAREER
elif any(word in message_lower for word in ["伴侣", "孩子", "家庭", "家人", "结婚"]):
return ConversationStage.FAMILY
elif any(word in message_lower for word in ["信念", "价值观", "座右铭", "坚持", "原则"]):
return ConversationStage.BELIEFS
elif any(word in message_lower for word in ["总结", "回顾", "感激", "希望", "未来"]):
return ConversationStage.SUMMARY
else:
# 默认返回当前阶段或童年阶段
return ConversationStage.CHILDHOOD
async def clear_memory(self, conversation_id: str):
"""清除对话记忆(从 Redis"""
await redis_service.clear_conversation_history(conversation_id)