Files
life-echo/api/app/agents/conversation_agent.py

356 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
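Conversation agent: dynamically selects questions from the interview question list and generates replies in real time.
Supports asynchronous calls and Redis-backed session storage, as well as collecting the user's basic profile and weaving in era background.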
"""
import json
from app.core.logging import get_logger
from datetime import datetime
from typing import Any, Dict, List, Optional
from langchain_core.messages import AIMessage, HumanMessage
from app.core.redis import redis_service
from app.core.dependencies import get_llm_provider
from app.agents.prompts import (
ConversationStage,
get_conversation_prompt,
get_guided_conversation_prompt,
get_opening_prompt,
)
from app.agents.prompts.profile_prompts import (
get_profile_greeting_prompt,
get_profile_extraction_prompt,
get_profile_followup_prompt,
format_user_profile_context,
get_missing_profile_fields,
)
from app.agents.prompts.conversation_prompts import SLOT_NAME_MAP
from app.agents.state_schema import MemoirStateSchema
logger = get_logger(__name__)
def _get_langchain_llm():
    """Return the LangChain LLM exposed by the configured provider, if any.

    The provider is obtained from ``get_llm_provider()`` and its
    ``langchain_llm`` attribute is returned.

    Returns:
        The provider's ``langchain_llm`` attribute, or ``None`` when the
        provider has no such attribute or when resolving it raises.
    """
    try:
        provider = get_llm_provider()
        return getattr(provider, "langchain_llm", None)
    except Exception as exc:
        # Best-effort lookup: the agent must still be constructible without an
        # LLM, but the reason is logged so a misconfiguration is diagnosable
        # instead of silently surfacing as "LLM not configured" later on.
        logger.warning("获取 LangChain LLM 实例失败: %s", exc, exc_info=True)
        return None
class ConversationAgent:
"""对话 Agent支持异步和 Redis 存储)"""
def __init__(self):
self.llm = _get_langchain_llm()
async def _get_history_messages(self, conversation_id: str) -> List[Any]:
history = await redis_service.get_conversation_history(conversation_id)
messages = []
for msg in history:
if msg["role"] == "human":
messages.append(HumanMessage(content=msg["content"]))
elif msg["role"] == "ai":
messages.append(AIMessage(content=msg["content"]))
return messages
async def _save_message(
self,
conversation_id: str,
role: str,
content: str,
message_type: str = "text",
voice_session_id: str | None = None,
timestamp: datetime | str | int | None = None,
):
await redis_service.add_message(
conversation_id,
role,
content,
message_type=message_type,
voice_session_id=voice_session_id,
timestamp=timestamp.isoformat() if isinstance(timestamp, datetime) else timestamp,
)
def _format_history_string(self, messages: List[Any]) -> str:
history_parts = []
for msg in messages:
if isinstance(msg, HumanMessage):
history_parts.append(f"Human: {msg.content}")
elif isinstance(msg, AIMessage):
history_parts.append(f"Assistant: {msg.content}")
return "\n\n".join(history_parts)
async def generate_response(
self,
conversation_id: str,
user_message: str,
current_stage: Optional[ConversationStage] = None,
covered_topics: Optional[List[str]] = None,
) -> str:
if current_stage is None:
current_stage = ConversationStage.CHILDHOOD
if covered_topics is None:
covered_topics = []
if not self.llm:
return "抱歉LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"
try:
system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
history_messages = await self._get_history_messages(conversation_id)
history_string = self._format_history_string(history_messages)
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
response = await self.llm.ainvoke(full_prompt)
response_text = response.content if hasattr(response, "content") else str(response)
await self._save_message(conversation_id, "human", user_message)
await self._save_message(conversation_id, "ai", response_text)
return response_text
except Exception as e:
logger.error("生成回应失败: %s", e)
return f"抱歉,生成回应时出现错误: {str(e)}"
async def generate_profile_greeting(
self,
conversation_id: str,
missing_fields: List[str],
nickname: str = "",
) -> List[str]:
if not self.llm:
return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"]
try:
prompt = get_profile_greeting_prompt(missing_fields, nickname)
history_messages = await self._get_history_messages(conversation_id)
history_string = self._format_history_string(history_messages)
full_prompt = f"{prompt}\n\n{history_string}" if history_string else prompt
response = await self.llm.ainvoke(full_prompt)
response_text = response.content if hasattr(response, "content") else str(response)
await self._save_message(conversation_id, "ai", response_text)
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
return messages[:2] if messages else [response_text]
except Exception as e:
logger.error("生成资料收集开场白失败: %s", e)
return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"]
async def generate_opening_message(
self,
conversation_id: str,
memoir_state: MemoirStateSchema,
user_profile_context: str = "",
) -> List[str]:
if not self.llm:
return ["你好呀~ 有空聊聊你的人生故事吗?你小时候是在哪儿长大的?"]
try:
empty_slots = memoir_state.empty_slots_for_current_stage()
empty_slots_readable = [SLOT_NAME_MAP.get(s, s) for s in empty_slots]
if not empty_slots_readable:
empty_slots_readable = ["成长的地方", "难忘的事", "重要的人"]
prompt = get_opening_prompt(
current_stage=memoir_state.current_stage,
empty_slots_readable=empty_slots_readable,
user_profile_context=user_profile_context,
)
full_prompt = f"{prompt}\n\nAssistant:"
response = await self.llm.ainvoke(full_prompt)
response_text = response.content if hasattr(response, "content") else str(response)
await self._save_message(conversation_id, "ai", response_text)
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
return messages[:2] if messages else [response_text]
except Exception as e:
logger.error("生成开场白失败: %s", e, exc_info=True)
return ["你好呀~ 有空聊聊你的人生故事吗?你童年里印象最深的一件事是什么?"]
async def extract_profile_from_message(
self,
user_message: str,
missing_fields: List[str],
conversation_id: Optional[str] = None,
) -> Dict[str, Any]:
if not self.llm or not missing_fields:
return {}
recent_dialogue = ""
if conversation_id:
history_messages = await self._get_history_messages(conversation_id)
recent = history_messages[-4:] if len(history_messages) > 4 else history_messages
parts = []
for msg in recent:
if isinstance(msg, HumanMessage):
parts.append(f"用户: {msg.content}")
elif isinstance(msg, AIMessage):
parts.append(f"助手: {msg.content}")
recent_dialogue = "\n".join(parts) if parts else ""
try:
prompt = get_profile_extraction_prompt(
user_message, missing_fields, recent_dialogue=recent_dialogue or None
)
response = await self.llm.ainvoke(prompt)
content = response.content.strip()
parsed = json.loads(content)
result = {}
if "birth_year" in parsed and parsed["birth_year"] is not None:
raw = parsed["birth_year"]
if isinstance(raw, int) and 1900 <= raw <= 2100:
result["birth_year"] = raw
elif isinstance(raw, str) and raw.isdigit():
y = int(raw)
if y < 100:
y = 1900 + y if y >= 50 else 2000 + y
if 1900 <= y <= 2100:
result["birth_year"] = y
if "birth_place" in parsed and parsed["birth_place"]:
result["birth_place"] = str(parsed["birth_place"])
if "grew_up_place" in parsed and parsed["grew_up_place"]:
result["grew_up_place"] = str(parsed["grew_up_place"])
if "occupation" in parsed and parsed["occupation"]:
result["occupation"] = str(parsed["occupation"])
return result
except (json.JSONDecodeError, Exception) as e:
logger.error("提取资料信息失败: %s", e)
return {}
async def generate_profile_followup(
self,
conversation_id: str,
user_message: str,
missing_fields: List[str],
filled_fields: Dict[str, str],
nickname: str = "",
is_from_voice: bool = False,
voice_session_id: str | None = None,
user_message_timestamp: datetime | None = None,
) -> List[str]:
if not self.llm:
return ["谢谢!还能告诉我更多吗?"]
try:
prompt = get_profile_followup_prompt(missing_fields, filled_fields, user_message, nickname)
history_messages = await self._get_history_messages(conversation_id)
history_string = self._format_history_string(history_messages)
full_prompt = f"{prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
response = await self.llm.ainvoke(full_prompt)
response_text = response.content if hasattr(response, "content") else str(response)
human_msg_type = "audio" if is_from_voice else "text"
await self._save_message(
conversation_id,
"human",
user_message,
message_type=human_msg_type,
voice_session_id=voice_session_id,
timestamp=user_message_timestamp,
)
await self._save_message(conversation_id, "ai", response_text)
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
return messages[:3] if messages else [response_text]
except Exception as e:
logger.error("生成资料跟进回复失败: %s", e)
return ["谢谢分享!能再告诉我一些吗?"]
def _detect_user_stage(self, user_message: str) -> str:
message = user_message.lower()
stage_keywords = {
"childhood": ["童年", "小时候", "出生", "家乡", "小镇", "爸妈", "父亲", "母亲", "爷爷", "奶奶", "外公", "外婆", "幼儿园"],
"education": ["上学", "学校", "老师", "同学", "教育", "大学", "高中", "初中", "小学", "考试", "毕业", "读书", "高考", "课堂"],
"career": ["工作", "职业", "事业", "公司", "同事", "创业", "升职", "跳槽", "老板", "行业", "项目", "加班", "薪水", "面试"],
"family": ["伴侣", "孩子", "家庭", "家人", "结婚", "爱人", "老婆", "老公", "丈夫", "妻子", "儿子", "女儿", "婚礼", "恋爱"],
"belief": ["信念", "价值观", "座右铭", "坚持", "原则", "信仰", "意义", "感悟", "遗憾", "骄傲"],
}
for stage, keywords in stage_keywords.items():
if any(word in message for word in keywords):
return stage
return ""
async def generate_response_with_state(
self,
conversation_id: str,
user_message: str,
memoir_state: MemoirStateSchema,
user_profile_context: str = "",
is_from_voice: bool = False,
voice_session_id: str | None = None,
user_message_timestamp: datetime | None = None,
) -> List[str]:
if not self.llm:
return ["抱歉LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"]
try:
empty_slots = memoir_state.empty_slots_for_current_stage()
filled_slots = {
key: value.snippet
for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items()
if value.snippet
}
detected_user_stage = self._detect_user_stage(user_message)
history_messages = await self._get_history_messages(conversation_id)
conversation_turn = len(history_messages) // 2
same_topic_turns = self._estimate_same_topic_turns(history_messages, filled_slots)
all_stages_coverage = memoir_state.all_stages_coverage()
system_prompt = get_guided_conversation_prompt(
current_stage=memoir_state.current_stage,
empty_slots=empty_slots,
filled_slots=filled_slots,
user_message=user_message,
conversation_turn=conversation_turn,
same_topic_turns=same_topic_turns,
all_stages_coverage=all_stages_coverage,
detected_user_stage=detected_user_stage,
user_profile_context=user_profile_context,
)
history_string = self._format_history_string(history_messages)
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
response = await self.llm.ainvoke(full_prompt)
response_text = response.content if hasattr(response, "content") else str(response)
human_msg_type = "audio" if is_from_voice else "text"
await self._save_message(
conversation_id,
"human",
user_message,
message_type=human_msg_type,
voice_session_id=voice_session_id,
timestamp=user_message_timestamp,
)
await self._save_message(conversation_id, "ai", response_text)
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
return messages[:3] if messages else [response_text]
except Exception as e:
logger.error("生成回应失败: %s", e)
return [f"抱歉,生成回应时出现错误: {str(e)}"]
def _estimate_same_topic_turns(self, history_messages: List[Any], current_filled_slots: dict) -> int:
if len(history_messages) < 4:
return len(history_messages) // 2
recent_messages = history_messages[-6:]
keywords_per_turn = []
for i in range(0, len(recent_messages), 2):
if i + 1 < len(recent_messages):
human_msg = (
recent_messages[i].content
if hasattr(recent_messages[i], "content")
else str(recent_messages[i])
)
ai_msg = (
recent_messages[i + 1].content
if hasattr(recent_messages[i + 1], "content")
else str(recent_messages[i + 1])
)
keywords_per_turn.append((human_msg + ai_msg)[:100])
if len(keywords_per_turn) >= 3:
return 3
return len(keywords_per_turn)
def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage:
message_lower = user_message.lower()
if any(word in message_lower for word in ["童年", "小时候", "出生", "家庭背景"]):
return ConversationStage.CHILDHOOD
if any(word in message_lower for word in ["上学", "学校", "老师", "同学", "教育"]):
return ConversationStage.EDUCATION
if any(word in message_lower for word in ["工作", "职业", "事业", "公司", "同事"]):
return ConversationStage.CAREER
if any(word in message_lower for word in ["伴侣", "孩子", "家庭", "家人", "结婚"]):
return ConversationStage.FAMILY
if any(word in message_lower for word in ["信念", "价值观", "座右铭", "坚持", "原则"]):
return ConversationStage.BELIEFS
if any(word in message_lower for word in ["总结", "回顾", "感激", "希望", "未来"]):
return ConversationStage.SUMMARY
return ConversationStage.CHILDHOOD
async def clear_memory(self, conversation_id: str):
await redis_service.clear_conversation_history(conversation_id)