356 lines
16 KiB
Python
356 lines
16 KiB
Python
"""
|
||
对话 Agent:基于访谈问题清单,动态选择问题,实时生成回应
|
||
支持异步调用和 Redis 会话存储,支持用户基础资料收集和时代背景融入
|
||
"""
|
||
import json
|
||
from app.core.logging import get_logger
|
||
from datetime import datetime
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from langchain_core.messages import AIMessage, HumanMessage
|
||
|
||
from app.core.redis import redis_service
|
||
from app.core.dependencies import get_llm_provider
|
||
|
||
from app.agents.prompts import (
|
||
ConversationStage,
|
||
get_conversation_prompt,
|
||
get_guided_conversation_prompt,
|
||
get_opening_prompt,
|
||
)
|
||
from app.agents.prompts.profile_prompts import (
|
||
get_profile_greeting_prompt,
|
||
get_profile_extraction_prompt,
|
||
get_profile_followup_prompt,
|
||
format_user_profile_context,
|
||
get_missing_profile_fields,
|
||
)
|
||
from app.agents.prompts.conversation_prompts import SLOT_NAME_MAP
|
||
from app.agents.state_schema import MemoirStateSchema
|
||
|
||
# Module-level logger, named after this module via the project's logging helper.
logger = get_logger(__name__)
|
||
|
||
|
||
def _get_langchain_llm():
    """Return the LangChain LLM exposed by the configured provider, or None.

    The provider is obtained from ``get_llm_provider()`` and its
    ``langchain_llm`` attribute is returned when present.

    Returns:
        The provider's ``langchain_llm`` attribute, or ``None`` when the
        provider has no such attribute or the lookup raises.
    """
    try:
        provider = get_llm_provider()
        return getattr(provider, "langchain_llm", None)
    except Exception as e:
        # Best-effort lookup: callers treat None as "LLM unavailable". Log the
        # cause so a misconfiguration is diagnosable instead of silent.
        logger.warning("获取 LangChain LLM 实例失败: %s", e, exc_info=True)
        return None
|
||
|
||
|
||
class ConversationAgent:
|
||
"""对话 Agent(支持异步和 Redis 存储)"""
|
||
|
||
def __init__(self):
|
||
self.llm = _get_langchain_llm()
|
||
|
||
async def _get_history_messages(self, conversation_id: str) -> List[Any]:
|
||
history = await redis_service.get_conversation_history(conversation_id)
|
||
messages = []
|
||
for msg in history:
|
||
if msg["role"] == "human":
|
||
messages.append(HumanMessage(content=msg["content"]))
|
||
elif msg["role"] == "ai":
|
||
messages.append(AIMessage(content=msg["content"]))
|
||
return messages
|
||
|
||
async def _save_message(
|
||
self,
|
||
conversation_id: str,
|
||
role: str,
|
||
content: str,
|
||
message_type: str = "text",
|
||
voice_session_id: str | None = None,
|
||
timestamp: datetime | str | int | None = None,
|
||
):
|
||
await redis_service.add_message(
|
||
conversation_id,
|
||
role,
|
||
content,
|
||
message_type=message_type,
|
||
voice_session_id=voice_session_id,
|
||
timestamp=timestamp.isoformat() if isinstance(timestamp, datetime) else timestamp,
|
||
)
|
||
|
||
def _format_history_string(self, messages: List[Any]) -> str:
|
||
history_parts = []
|
||
for msg in messages:
|
||
if isinstance(msg, HumanMessage):
|
||
history_parts.append(f"Human: {msg.content}")
|
||
elif isinstance(msg, AIMessage):
|
||
history_parts.append(f"Assistant: {msg.content}")
|
||
return "\n\n".join(history_parts)
|
||
|
||
async def generate_response(
|
||
self,
|
||
conversation_id: str,
|
||
user_message: str,
|
||
current_stage: Optional[ConversationStage] = None,
|
||
covered_topics: Optional[List[str]] = None,
|
||
) -> str:
|
||
if current_stage is None:
|
||
current_stage = ConversationStage.CHILDHOOD
|
||
if covered_topics is None:
|
||
covered_topics = []
|
||
if not self.llm:
|
||
return "抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"
|
||
try:
|
||
system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
|
||
history_messages = await self._get_history_messages(conversation_id)
|
||
history_string = self._format_history_string(history_messages)
|
||
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
|
||
response = await self.llm.ainvoke(full_prompt)
|
||
response_text = response.content if hasattr(response, "content") else str(response)
|
||
await self._save_message(conversation_id, "human", user_message)
|
||
await self._save_message(conversation_id, "ai", response_text)
|
||
return response_text
|
||
except Exception as e:
|
||
logger.error("生成回应失败: %s", e)
|
||
return f"抱歉,生成回应时出现错误: {str(e)}"
|
||
|
||
async def generate_profile_greeting(
|
||
self,
|
||
conversation_id: str,
|
||
missing_fields: List[str],
|
||
nickname: str = "",
|
||
) -> List[str]:
|
||
if not self.llm:
|
||
return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"]
|
||
try:
|
||
prompt = get_profile_greeting_prompt(missing_fields, nickname)
|
||
history_messages = await self._get_history_messages(conversation_id)
|
||
history_string = self._format_history_string(history_messages)
|
||
full_prompt = f"{prompt}\n\n{history_string}" if history_string else prompt
|
||
response = await self.llm.ainvoke(full_prompt)
|
||
response_text = response.content if hasattr(response, "content") else str(response)
|
||
await self._save_message(conversation_id, "ai", response_text)
|
||
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
|
||
return messages[:2] if messages else [response_text]
|
||
except Exception as e:
|
||
logger.error("生成资料收集开场白失败: %s", e)
|
||
return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"]
|
||
|
||
async def generate_opening_message(
|
||
self,
|
||
conversation_id: str,
|
||
memoir_state: MemoirStateSchema,
|
||
user_profile_context: str = "",
|
||
) -> List[str]:
|
||
if not self.llm:
|
||
return ["你好呀~ 有空聊聊你的人生故事吗?你小时候是在哪儿长大的?"]
|
||
try:
|
||
empty_slots = memoir_state.empty_slots_for_current_stage()
|
||
empty_slots_readable = [SLOT_NAME_MAP.get(s, s) for s in empty_slots]
|
||
if not empty_slots_readable:
|
||
empty_slots_readable = ["成长的地方", "难忘的事", "重要的人"]
|
||
prompt = get_opening_prompt(
|
||
current_stage=memoir_state.current_stage,
|
||
empty_slots_readable=empty_slots_readable,
|
||
user_profile_context=user_profile_context,
|
||
)
|
||
full_prompt = f"{prompt}\n\nAssistant:"
|
||
response = await self.llm.ainvoke(full_prompt)
|
||
response_text = response.content if hasattr(response, "content") else str(response)
|
||
await self._save_message(conversation_id, "ai", response_text)
|
||
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
|
||
return messages[:2] if messages else [response_text]
|
||
except Exception as e:
|
||
logger.error("生成开场白失败: %s", e, exc_info=True)
|
||
return ["你好呀~ 有空聊聊你的人生故事吗?你童年里印象最深的一件事是什么?"]
|
||
|
||
async def extract_profile_from_message(
|
||
self,
|
||
user_message: str,
|
||
missing_fields: List[str],
|
||
conversation_id: Optional[str] = None,
|
||
) -> Dict[str, Any]:
|
||
if not self.llm or not missing_fields:
|
||
return {}
|
||
recent_dialogue = ""
|
||
if conversation_id:
|
||
history_messages = await self._get_history_messages(conversation_id)
|
||
recent = history_messages[-4:] if len(history_messages) > 4 else history_messages
|
||
parts = []
|
||
for msg in recent:
|
||
if isinstance(msg, HumanMessage):
|
||
parts.append(f"用户: {msg.content}")
|
||
elif isinstance(msg, AIMessage):
|
||
parts.append(f"助手: {msg.content}")
|
||
recent_dialogue = "\n".join(parts) if parts else ""
|
||
try:
|
||
prompt = get_profile_extraction_prompt(
|
||
user_message, missing_fields, recent_dialogue=recent_dialogue or None
|
||
)
|
||
response = await self.llm.ainvoke(prompt)
|
||
content = response.content.strip()
|
||
parsed = json.loads(content)
|
||
result = {}
|
||
if "birth_year" in parsed and parsed["birth_year"] is not None:
|
||
raw = parsed["birth_year"]
|
||
if isinstance(raw, int) and 1900 <= raw <= 2100:
|
||
result["birth_year"] = raw
|
||
elif isinstance(raw, str) and raw.isdigit():
|
||
y = int(raw)
|
||
if y < 100:
|
||
y = 1900 + y if y >= 50 else 2000 + y
|
||
if 1900 <= y <= 2100:
|
||
result["birth_year"] = y
|
||
if "birth_place" in parsed and parsed["birth_place"]:
|
||
result["birth_place"] = str(parsed["birth_place"])
|
||
if "grew_up_place" in parsed and parsed["grew_up_place"]:
|
||
result["grew_up_place"] = str(parsed["grew_up_place"])
|
||
if "occupation" in parsed and parsed["occupation"]:
|
||
result["occupation"] = str(parsed["occupation"])
|
||
return result
|
||
except (json.JSONDecodeError, Exception) as e:
|
||
logger.error("提取资料信息失败: %s", e)
|
||
return {}
|
||
|
||
async def generate_profile_followup(
|
||
self,
|
||
conversation_id: str,
|
||
user_message: str,
|
||
missing_fields: List[str],
|
||
filled_fields: Dict[str, str],
|
||
nickname: str = "",
|
||
is_from_voice: bool = False,
|
||
voice_session_id: str | None = None,
|
||
user_message_timestamp: datetime | None = None,
|
||
) -> List[str]:
|
||
if not self.llm:
|
||
return ["谢谢!还能告诉我更多吗?"]
|
||
try:
|
||
prompt = get_profile_followup_prompt(missing_fields, filled_fields, user_message, nickname)
|
||
history_messages = await self._get_history_messages(conversation_id)
|
||
history_string = self._format_history_string(history_messages)
|
||
full_prompt = f"{prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
|
||
response = await self.llm.ainvoke(full_prompt)
|
||
response_text = response.content if hasattr(response, "content") else str(response)
|
||
human_msg_type = "audio" if is_from_voice else "text"
|
||
await self._save_message(
|
||
conversation_id,
|
||
"human",
|
||
user_message,
|
||
message_type=human_msg_type,
|
||
voice_session_id=voice_session_id,
|
||
timestamp=user_message_timestamp,
|
||
)
|
||
await self._save_message(conversation_id, "ai", response_text)
|
||
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
|
||
return messages[:3] if messages else [response_text]
|
||
except Exception as e:
|
||
logger.error("生成资料跟进回复失败: %s", e)
|
||
return ["谢谢分享!能再告诉我一些吗?"]
|
||
|
||
def _detect_user_stage(self, user_message: str) -> str:
|
||
message = user_message.lower()
|
||
stage_keywords = {
|
||
"childhood": ["童年", "小时候", "出生", "家乡", "小镇", "爸妈", "父亲", "母亲", "爷爷", "奶奶", "外公", "外婆", "幼儿园"],
|
||
"education": ["上学", "学校", "老师", "同学", "教育", "大学", "高中", "初中", "小学", "考试", "毕业", "读书", "高考", "课堂"],
|
||
"career": ["工作", "职业", "事业", "公司", "同事", "创业", "升职", "跳槽", "老板", "行业", "项目", "加班", "薪水", "面试"],
|
||
"family": ["伴侣", "孩子", "家庭", "家人", "结婚", "爱人", "老婆", "老公", "丈夫", "妻子", "儿子", "女儿", "婚礼", "恋爱"],
|
||
"belief": ["信念", "价值观", "座右铭", "坚持", "原则", "信仰", "意义", "感悟", "遗憾", "骄傲"],
|
||
}
|
||
for stage, keywords in stage_keywords.items():
|
||
if any(word in message for word in keywords):
|
||
return stage
|
||
return ""
|
||
|
||
async def generate_response_with_state(
|
||
self,
|
||
conversation_id: str,
|
||
user_message: str,
|
||
memoir_state: MemoirStateSchema,
|
||
user_profile_context: str = "",
|
||
is_from_voice: bool = False,
|
||
voice_session_id: str | None = None,
|
||
user_message_timestamp: datetime | None = None,
|
||
) -> List[str]:
|
||
if not self.llm:
|
||
return ["抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"]
|
||
try:
|
||
empty_slots = memoir_state.empty_slots_for_current_stage()
|
||
filled_slots = {
|
||
key: value.snippet
|
||
for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items()
|
||
if value.snippet
|
||
}
|
||
detected_user_stage = self._detect_user_stage(user_message)
|
||
history_messages = await self._get_history_messages(conversation_id)
|
||
conversation_turn = len(history_messages) // 2
|
||
same_topic_turns = self._estimate_same_topic_turns(history_messages, filled_slots)
|
||
all_stages_coverage = memoir_state.all_stages_coverage()
|
||
system_prompt = get_guided_conversation_prompt(
|
||
current_stage=memoir_state.current_stage,
|
||
empty_slots=empty_slots,
|
||
filled_slots=filled_slots,
|
||
user_message=user_message,
|
||
conversation_turn=conversation_turn,
|
||
same_topic_turns=same_topic_turns,
|
||
all_stages_coverage=all_stages_coverage,
|
||
detected_user_stage=detected_user_stage,
|
||
user_profile_context=user_profile_context,
|
||
)
|
||
history_string = self._format_history_string(history_messages)
|
||
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
|
||
response = await self.llm.ainvoke(full_prompt)
|
||
response_text = response.content if hasattr(response, "content") else str(response)
|
||
human_msg_type = "audio" if is_from_voice else "text"
|
||
await self._save_message(
|
||
conversation_id,
|
||
"human",
|
||
user_message,
|
||
message_type=human_msg_type,
|
||
voice_session_id=voice_session_id,
|
||
timestamp=user_message_timestamp,
|
||
)
|
||
await self._save_message(conversation_id, "ai", response_text)
|
||
messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
|
||
return messages[:3] if messages else [response_text]
|
||
except Exception as e:
|
||
logger.error("生成回应失败: %s", e)
|
||
return [f"抱歉,生成回应时出现错误: {str(e)}"]
|
||
|
||
def _estimate_same_topic_turns(self, history_messages: List[Any], current_filled_slots: dict) -> int:
|
||
if len(history_messages) < 4:
|
||
return len(history_messages) // 2
|
||
recent_messages = history_messages[-6:]
|
||
keywords_per_turn = []
|
||
for i in range(0, len(recent_messages), 2):
|
||
if i + 1 < len(recent_messages):
|
||
human_msg = (
|
||
recent_messages[i].content
|
||
if hasattr(recent_messages[i], "content")
|
||
else str(recent_messages[i])
|
||
)
|
||
ai_msg = (
|
||
recent_messages[i + 1].content
|
||
if hasattr(recent_messages[i + 1], "content")
|
||
else str(recent_messages[i + 1])
|
||
)
|
||
keywords_per_turn.append((human_msg + ai_msg)[:100])
|
||
if len(keywords_per_turn) >= 3:
|
||
return 3
|
||
return len(keywords_per_turn)
|
||
|
||
def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage:
|
||
message_lower = user_message.lower()
|
||
if any(word in message_lower for word in ["童年", "小时候", "出生", "家庭背景"]):
|
||
return ConversationStage.CHILDHOOD
|
||
if any(word in message_lower for word in ["上学", "学校", "老师", "同学", "教育"]):
|
||
return ConversationStage.EDUCATION
|
||
if any(word in message_lower for word in ["工作", "职业", "事业", "公司", "同事"]):
|
||
return ConversationStage.CAREER
|
||
if any(word in message_lower for word in ["伴侣", "孩子", "家庭", "家人", "结婚"]):
|
||
return ConversationStage.FAMILY
|
||
if any(word in message_lower for word in ["信念", "价值观", "座右铭", "坚持", "原则"]):
|
||
return ConversationStage.BELIEFS
|
||
if any(word in message_lower for word in ["总结", "回顾", "感激", "希望", "未来"]):
|
||
return ConversationStage.SUMMARY
|
||
return ConversationStage.CHILDHOOD
|
||
|
||
async def clear_memory(self, conversation_id: str):
|
||
await redis_service.clear_conversation_history(conversation_id)
|