diff --git a/api/app/agents/__init__.py b/api/app/agents/__init__.py index 988a061..b1b613c 100644 --- a/api/app/agents/__init__.py +++ b/api/app/agents/__init__.py @@ -1,10 +1,19 @@ """ -Agent 模块(app 内,符合架构计划) +Agent 模块(按功能拆分:chat / memoir) """ -from app.agents.conversation_agent import ConversationAgent -from app.agents.memory_agent import MemoryAgent +from app.agents.chat import ( + ChatOrchestrator, + ConversationAgent, + InterviewAgent, + ProfileAgent, +) +from app.agents.memoir import BackgroundTaskRunner, MemoryAgent __all__ = [ "ConversationAgent", "MemoryAgent", + "ChatOrchestrator", + "ProfileAgent", + "InterviewAgent", + "BackgroundTaskRunner", ] diff --git a/api/app/agents/chat/__init__.py b/api/app/agents/chat/__init__.py new file mode 100644 index 0000000..32305dd --- /dev/null +++ b/api/app/agents/chat/__init__.py @@ -0,0 +1,12 @@ +"""聊天模块:AI 回复用户(ProfileAgent + InterviewAgent + ChatOrchestrator)""" +from app.agents.chat.conversation_agent import ConversationAgent +from app.agents.chat.orchestrator import ChatOrchestrator +from app.agents.chat.profile_agent import ProfileAgent +from app.agents.chat.interview_agent import InterviewAgent + +__all__ = [ + "ConversationAgent", + "ChatOrchestrator", + "ProfileAgent", + "InterviewAgent", +] diff --git a/api/app/agents/chat/conversation_agent.py b/api/app/agents/chat/conversation_agent.py new file mode 100644 index 0000000..27d8680 --- /dev/null +++ b/api/app/agents/chat/conversation_agent.py @@ -0,0 +1,139 @@ +""" +对话 Agent:Facade,内部委托 ChatOrchestrator + ProfileAgent + InterviewAgent +保留原有对外 API,供 router 等调用方兼容使用 +""" +from datetime import datetime +from typing import Any, Dict, List, Optional + +from app.agents.chat.orchestrator import ChatOrchestrator +from app.agents.prompts import ConversationStage +from app.agents.state_schema import MemoirStateSchema +from app.core.redis import redis_service + + +class ConversationAgent: + """对话 Agent Facade,委托 ChatOrchestrator 实现多 Agent 协同""" + + def __init__(self): + self._orchestrator = ChatOrchestrator() + + async def extract_profile_from_message( + self, + user_message: str, + missing_fields: List[str], + conversation_id: Optional[str] = None, + ) -> Dict[str, Any]: + """委托 ChatOrchestrator/ProfileAgent 提取资料""" + return await self._orchestrator.extract_profile_from_message( + user_message, missing_fields, conversation_id=conversation_id + ) + + async def generate_profile_followup( + self, + conversation_id: str, + user_message: str, + missing_fields: List[str], + filled_fields: Dict[str, str], + nickname: str = "", + is_from_voice: bool = False, + voice_session_id: str | None = None, + user_message_timestamp: datetime | None = None, + ) -> List[str]: + """委托 ChatOrchestrator/ProfileAgent 生成资料追问""" + return await self._orchestrator.generate_profile_followup( + conversation_id=conversation_id, + user_message=user_message, + missing_fields=missing_fields, + filled_fields=filled_fields, + nickname=nickname, + is_from_voice=is_from_voice, + voice_session_id=voice_session_id, + user_message_timestamp=user_message_timestamp, + ) + + async def generate_profile_greeting( + self, + conversation_id: str, + missing_fields: List[str], + nickname: str = "", + ) -> List[str]: + """委托 ChatOrchestrator/ProfileAgent 生成资料收集开场白""" + return await self._orchestrator.generate_profile_greeting( + conversation_id=conversation_id, + missing_fields=missing_fields, + nickname=nickname, + ) + + async def generate_response_with_state( + self, + conversation_id: str, + user_message: str, + memoir_state: MemoirStateSchema, + user_profile_context: str = "", + is_from_voice: bool = False, + voice_session_id: str | None = None, + user_message_timestamp: datetime | None = None, + ) -> List[str]: + """委托 ChatOrchestrator/InterviewAgent 生成访谈回复""" + return await self._orchestrator.generate_response_with_state( + conversation_id=conversation_id, + user_message=user_message, + memoir_state=memoir_state, + user_profile_context=user_profile_context, + is_from_voice=is_from_voice, + voice_session_id=voice_session_id, + user_message_timestamp=user_message_timestamp, + ) + + async def generate_opening_message( + self, + conversation_id: str, + memoir_state: MemoirStateSchema, + user_profile_context: str = "", + ) -> List[str]: + """委托 ChatOrchestrator/InterviewAgent 生成开场白""" + return await self._orchestrator.generate_opening_message( + conversation_id=conversation_id, + memoir_state=memoir_state, + user_profile_context=user_profile_context, + ) + + async def generate_response( + self, + conversation_id: str, + user_message: str, + current_stage: Optional[ConversationStage] = None, + covered_topics: Optional[List[str]] = None, + ) -> str: + """兼容旧 API:生成简单回复(无状态感知),委托 InterviewAgent 的等价逻辑""" + from app.agents.state_schema import default_state + + state = default_state() + state.current_stage = (current_stage or ConversationStage.CHILDHOOD).value + state.covered_stages = covered_topics or [] + responses = await self._orchestrator.generate_response_with_state( + conversation_id=conversation_id, + user_message=user_message, + memoir_state=state, + user_profile_context="", + ) + return responses[0] if responses else "" + + def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage: + """根据关键词检测用户阶段(兼容 API)""" + detected = self._orchestrator.detect_user_stage(user_message) + if detected == "childhood": + return ConversationStage.CHILDHOOD + if detected == "education": + return ConversationStage.EDUCATION + if detected == "career": + return ConversationStage.CAREER + if detected == "family": + return ConversationStage.FAMILY + if detected == "belief": + return ConversationStage.BELIEFS + return ConversationStage.CHILDHOOD + + async def clear_memory(self, conversation_id: str) -> None: + """清除 Redis 中的对话历史""" + await redis_service.clear_conversation_history(conversation_id) diff --git a/api/app/agents/chat/helpers.py b/api/app/agents/chat/helpers.py new file mode 100644 index 0000000..db2c666 --- /dev/null +++ b/api/app/agents/chat/helpers.py @@ -0,0 +1,49 @@ +"""聊天 Agent 共享工具:历史获取、格式化、存储""" +from datetime import datetime +from typing import Any, List + +from langchain_core.messages import AIMessage, HumanMessage + +from app.core.redis import redis_service + + +async def get_history_messages(conversation_id: str) -> List[Any]: + """从 Redis 获取对话历史""" + history = await redis_service.get_conversation_history(conversation_id) + messages = [] + for msg in history: + if msg["role"] == "human": + messages.append(HumanMessage(content=msg["content"])) + elif msg["role"] == "ai": + messages.append(AIMessage(content=msg["content"])) + return messages + + +def format_history_string(messages: List[Any]) -> str: + """将消息列表格式化为 Human/Assistant 字符串""" + history_parts = [] + for msg in messages: + if isinstance(msg, HumanMessage): + history_parts.append(f"Human: {msg.content}") + elif isinstance(msg, AIMessage): + history_parts.append(f"Assistant: {msg.content}") + return "\n\n".join(history_parts) + + +async def save_message( + conversation_id: str, + role: str, + content: str, + message_type: str = "text", + voice_session_id: str | None = None, + timestamp: datetime | str | int | None = None, +) -> None: + """保存消息到 Redis""" + await redis_service.add_message( + conversation_id, + role, + content, + message_type=message_type, + voice_session_id=voice_session_id, + timestamp=timestamp.isoformat() if isinstance(timestamp, datetime) else timestamp, + ) diff --git a/api/app/agents/chat/interview_agent.py b/api/app/agents/chat/interview_agent.py new file mode 100644 index 0000000..4bd045d --- /dev/null +++ b/api/app/agents/chat/interview_agent.py @@ -0,0 +1,143 @@ +""" +InterviewAgent:正式访谈 Specialist +负责状态感知回复、开场白,不负责 Redis 持久化(由 Orchestrator 统一处理) +""" +from typing import Any, List + +from app.core.dependencies import get_llm_provider +from app.core.logging import get_logger + +from app.agents.chat.helpers import format_history_string, get_history_messages +from app.agents.prompts import get_guided_conversation_prompt, get_opening_prompt +from app.agents.prompts.conversation_prompts import SLOT_NAME_MAP +from app.agents.state_schema import MemoirStateSchema + +logger = get_logger(__name__) + + +def _get_langchain_llm(): + try: + provider = get_llm_provider() + return getattr(provider, "langchain_llm", None) + except Exception: + return None + + +class InterviewAgent: + """正式访谈 Specialist Agent""" + + def __init__(self): + self.llm = _get_langchain_llm() + + def _detect_user_stage(self, user_message: str) -> str: + """根据关键词检测用户正在谈论的人生阶段""" + message = user_message.lower() + stage_keywords = { + "childhood": ["童年", "小时候", "出生", "家乡", "小镇", "爸妈", "父亲", "母亲", "爷爷", "奶奶", "外公", "外婆", "幼儿园"], + "education": ["上学", "学校", "老师", "同学", "教育", "大学", "高中", "初中", "小学", "考试", "毕业", "读书", "高考", "课堂"], + "career": ["工作", "职业", "事业", "公司", "同事", "创业", "升职", "跳槽", "老板", "行业", "项目", "加班", "薪水", "面试"], + "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "爱人", "老婆", "老公", "丈夫", "妻子", "儿子", "女儿", "婚礼", "恋爱"], + "belief": ["信念", "价值观", "座右铭", "坚持", "原则", "信仰", "意义", "感悟", "遗憾", "骄傲"], + } + for stage, keywords in stage_keywords.items(): + if any(word in message for word in keywords): + return stage + return "" + + def _estimate_same_topic_turns( + self, history_messages: List[Any], current_filled_slots: dict + ) -> int: + """估算同一话题的连续轮数""" + if len(history_messages) < 4: + return len(history_messages) // 2 + recent_messages = history_messages[-6:] + keywords_per_turn = [] + for i in range(0, len(recent_messages), 2): + if i + 1 < len(recent_messages): + human_msg = ( + recent_messages[i].content + if hasattr(recent_messages[i], "content") + else str(recent_messages[i]) + ) + ai_msg = ( + recent_messages[i + 1].content + if hasattr(recent_messages[i + 1], "content") + else str(recent_messages[i + 1]) + ) + keywords_per_turn.append((human_msg + ai_msg)[:100]) + if len(keywords_per_turn) >= 3: + return 3 + return len(keywords_per_turn) + + async def generate_response_with_state( + self, + conversation_id: str, + user_message: str, + memoir_state: MemoirStateSchema, + user_profile_context: str = "", + ) -> List[str]: + """生成状态感知的访谈回复,不持久化(由 Orchestrator 负责)""" + if not self.llm: + return ["抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"] + try: + empty_slots = memoir_state.empty_slots_for_current_stage() + filled_slots = { + key: value.snippet + for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items() + if value.snippet + } + detected_user_stage = self._detect_user_stage(user_message) + history_messages = await get_history_messages(conversation_id) + conversation_turn = len(history_messages) // 2 + same_topic_turns = self._estimate_same_topic_turns( + history_messages, filled_slots + ) + all_stages_coverage = memoir_state.all_stages_coverage() + system_prompt = get_guided_conversation_prompt( + current_stage=memoir_state.current_stage, + empty_slots=empty_slots, + filled_slots=filled_slots, + user_message=user_message, + conversation_turn=conversation_turn, + same_topic_turns=same_topic_turns, + all_stages_coverage=all_stages_coverage, + detected_user_stage=detected_user_stage, + user_profile_context=user_profile_context, + ) + history_string = format_history_string(history_messages) + full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:" + response = await self.llm.ainvoke(full_prompt) + response_text = response.content if hasattr(response, "content") else str(response) + messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] + return messages[:3] if messages else [response_text] + except Exception as e: + logger.error("生成回应失败: %s", e) + return [f"抱歉,生成回应时出现错误: {str(e)}"] + + async def generate_opening_message( + self, + conversation_id: str, + memoir_state: MemoirStateSchema, + user_profile_context: str = "", + ) -> List[str]: + """生成空对话开场白,不持久化(由 Orchestrator 负责)""" + if not self.llm: + return ["你好呀~ 有空聊聊你的人生故事吗?你小时候是在哪儿长大的?"] + try: + empty_slots = memoir_state.empty_slots_for_current_stage() + empty_slots_readable = [SLOT_NAME_MAP.get(s, s) for s in empty_slots] + if not empty_slots_readable: + empty_slots_readable = ["成长的地方", "难忘的事", "重要的人"] + prompt = get_opening_prompt( + current_stage=memoir_state.current_stage, + empty_slots_readable=empty_slots_readable, + user_profile_context=user_profile_context, + ) + full_prompt = f"{prompt}\n\nAssistant:" + response = await self.llm.ainvoke(full_prompt) + response_text = response.content if hasattr(response, "content") else str(response) + messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] + return messages[:2] if messages else [response_text] + except Exception as e: + logger.error("生成开场白失败: %s", e, exc_info=True) + return ["你好呀~ 有空聊聊你的人生故事吗?你童年里印象最深的一件事是什么?"] diff --git a/api/app/agents/chat/orchestrator.py b/api/app/agents/chat/orchestrator.py new file mode 100644 index 0000000..1a5e66d --- /dev/null +++ b/api/app/agents/chat/orchestrator.py @@ -0,0 +1,246 @@ +""" +ChatOrchestrator:AI 回复用户模块的编排层 +负责路由(Profile vs Interview)、调用 Specialist Agent、统一 Redis 持久化与错误处理 +""" +from datetime import datetime +from typing import TYPE_CHECKING, List, Optional + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.agents.chat.helpers import save_message +from app.agents.chat.interview_agent import InterviewAgent +from app.agents.chat.profile_agent import ProfileAgent +from app.agents.state_schema import MemoirStateSchema +from app.core.logging import get_logger +from app.features.memoir.state_service import get_or_create_state + +if TYPE_CHECKING: + from app.features.user.models import User + +logger = get_logger(__name__) + + +class ChatOrchestrator: + """ + 聊天编排器:根据用户资料完成度路由到 ProfileAgent 或 InterviewAgent, + 统一管理 Redis 写入。 + """ + + def __init__(self): + self.profile_agent = ProfileAgent() + self.interview_agent = InterviewAgent() + + async def process_user_message( + self, + conversation_id: str, + user_message: str, + user: Optional["User"], + conversation, # 用于更新 conversation_stage + is_from_voice: bool, + voice_session_id: Optional[str], + db: AsyncSession, + apply_extracted_profile_fn, + get_missing_profile_fields_fn, + get_filled_profile_fields_fn, + user_message_timestamp: Optional[datetime] = None, + ) -> List[str]: + """ + 处理用户消息,返回 AI 回复列表。 + 根据 missing_fields 路由到 ProfileAgent 或 InterviewAgent, + 统一写入 Redis。 + """ + + # --- 资料收集模式 --- + if user: + missing = get_missing_profile_fields_fn(user) + if missing: + try: + extracted = await self.profile_agent.extract_profile_from_message( + user_message, missing, conversation_id=conversation_id + ) + if extracted: + await apply_extracted_profile_fn(user, extracted, db) + + remaining = get_missing_profile_fields_fn(user) + filled = get_filled_profile_fields_fn(user) + responses = await self.profile_agent.generate_profile_followup( + conversation_id=conversation_id, + user_message=user_message, + missing_fields=remaining, + filled_fields=filled, + nickname=user.nickname or "", + ) + await self._save_messages( + conversation_id=conversation_id, + user_message=user_message, + response_text="\n\n".join(responses), + is_from_voice=is_from_voice, + voice_session_id=voice_session_id, + user_message_timestamp=user_message_timestamp, + ) + return responses + except Exception as e: + logger.error(f"资料收集处理失败: {e}", exc_info=True) + + # --- 正式访谈模式 --- + user_id = user.id if user else None + if not user_id: + return ["抱歉,无法识别用户。"] + + state = await get_or_create_state(user_id, db) + if conversation and conversation.conversation_stage != state.current_stage: + conversation.conversation_stage = state.current_stage + await db.commit() + + from app.agents.prompts.profile_prompts import format_user_profile_context + + user_profile_context = "" + if user: + user_profile_context = format_user_profile_context( + birth_year=user.birth_year, + birth_place=user.birth_place, + grew_up_place=user.grew_up_place, + occupation=user.occupation, + ) + + responses = await self.interview_agent.generate_response_with_state( + conversation_id=conversation_id, + user_message=user_message, + memoir_state=state, + user_profile_context=user_profile_context, + ) + await self._save_messages( + conversation_id=conversation_id, + user_message=user_message, + response_text="\n\n".join(responses), + is_from_voice=is_from_voice, + voice_session_id=voice_session_id, + user_message_timestamp=user_message_timestamp, + ) + return responses + + async def _save_messages( + self, + conversation_id: str, + user_message: str, + response_text: str, + is_from_voice: bool = False, + voice_session_id: Optional[str] = None, + user_message_timestamp: Optional[datetime] = None, + ) -> None: + """统一写入 Human + AI 消息到 Redis""" + human_msg_type = "audio" if is_from_voice else "text" + await save_message( + conversation_id, + "human", + user_message, + message_type=human_msg_type, + voice_session_id=voice_session_id, + timestamp=user_message_timestamp, + ) + await save_message(conversation_id, "ai", response_text) + + async def extract_profile_from_message( + self, + user_message: str, + missing_fields: List[str], + conversation_id: Optional[str] = None, + ): + """委托 ProfileAgent 提取资料""" + return await self.profile_agent.extract_profile_from_message( + user_message, missing_fields, conversation_id=conversation_id + ) + + async def generate_profile_followup( + self, + conversation_id: str, + user_message: str, + missing_fields: List[str], + filled_fields: dict, + nickname: str = "", + is_from_voice: bool = False, + voice_session_id: str | None = None, + user_message_timestamp: datetime | None = None, + ) -> List[str]: + """委托 ProfileAgent 生成资料追问,并写入 Redis""" + responses = await self.profile_agent.generate_profile_followup( + conversation_id=conversation_id, + user_message=user_message, + missing_fields=missing_fields, + filled_fields=filled_fields, + nickname=nickname, + ) + response_text = "\n\n".join(responses) + await self._save_messages( + conversation_id=conversation_id, + user_message=user_message, + response_text=response_text, + is_from_voice=is_from_voice, + voice_session_id=voice_session_id, + user_message_timestamp=user_message_timestamp, + ) + return responses + + async def generate_profile_greeting( + self, + conversation_id: str, + missing_fields: List[str], + nickname: str = "", + ) -> List[str]: + """委托 ProfileAgent 生成资料收集开场白,并写入 Redis""" + responses = await self.profile_agent.generate_profile_greeting( + conversation_id=conversation_id, + missing_fields=missing_fields, + nickname=nickname, + ) + response_text = "\n\n".join(responses) + await save_message(conversation_id, "ai", response_text) + return responses + + async def generate_response_with_state( + self, + conversation_id: str, + user_message: str, + memoir_state: MemoirStateSchema, + user_profile_context: str = "", + is_from_voice: bool = False, + voice_session_id: str | None = None, + user_message_timestamp: datetime | None = None, + ) -> List[str]: + """委托 InterviewAgent 生成访谈回复,并写入 Redis""" + responses = await self.interview_agent.generate_response_with_state( + conversation_id=conversation_id, + user_message=user_message, + memoir_state=memoir_state, + user_profile_context=user_profile_context, + ) + response_text = "\n\n".join(responses) + await self._save_messages( + conversation_id=conversation_id, + user_message=user_message, + response_text=response_text, + is_from_voice=is_from_voice, + voice_session_id=voice_session_id, + user_message_timestamp=user_message_timestamp, + ) + return responses + + def detect_user_stage(self, user_message: str) -> str: + """委托 InterviewAgent 检测用户阶段""" + return self.interview_agent._detect_user_stage(user_message) + + async def generate_opening_message( + self, + conversation_id: str, + memoir_state: MemoirStateSchema, + user_profile_context: str = "", + ) -> List[str]: + """委托 InterviewAgent 生成开场白,并写入 Redis""" + responses = await self.interview_agent.generate_opening_message( + conversation_id=conversation_id, + memoir_state=memoir_state, + user_profile_context=user_profile_context, + ) + response_text = "\n\n".join(responses) + await save_message(conversation_id, "ai", response_text) + return responses diff --git a/api/app/agents/chat/profile_agent.py b/api/app/agents/chat/profile_agent.py new file mode 100644 index 0000000..ae7688b --- /dev/null +++ b/api/app/agents/chat/profile_agent.py @@ -0,0 +1,132 @@ +""" +ProfileAgent:用户资料收集 Specialist +负责提取资料、资料追问、资料收集开场白,不负责 Redis 持久化(由 Orchestrator 统一处理) +""" +import json +from typing import Any, Dict, List, Optional + +from langchain_core.messages import AIMessage, HumanMessage + +from app.core.dependencies import get_llm_provider +from app.core.logging import get_logger + +from app.agents.chat.helpers import format_history_string, get_history_messages +from app.agents.prompts.profile_prompts import ( + get_profile_extraction_prompt, + get_profile_followup_prompt, + get_profile_greeting_prompt, +) + +logger = get_logger(__name__) + + +def _get_langchain_llm(): + try: + provider = get_llm_provider() + return getattr(provider, "langchain_llm", None) + except Exception: + return None + + +class ProfileAgent: + """用户资料收集 Specialist Agent""" + + def __init__(self): + self.llm = _get_langchain_llm() + + async def extract_profile_from_message( + self, + user_message: str, + missing_fields: List[str], + conversation_id: Optional[str] = None, + ) -> Dict[str, Any]: + """从用户消息中提取资料字段,不持久化""" + if not self.llm or not missing_fields: + return {} + recent_dialogue = "" + if conversation_id: + history_messages = await get_history_messages(conversation_id) + recent = history_messages[-4:] if len(history_messages) > 4 else history_messages + parts = [] + for msg in recent: + if isinstance(msg, HumanMessage): + parts.append(f"用户: {msg.content}") + elif isinstance(msg, AIMessage): + parts.append(f"助手: {msg.content}") + recent_dialogue = "\n".join(parts) if parts else "" + try: + prompt = get_profile_extraction_prompt( + user_message, missing_fields, recent_dialogue=recent_dialogue or None + ) + response = await self.llm.ainvoke(prompt) + content = response.content.strip() + parsed = json.loads(content) + result = {} + if "birth_year" in parsed and parsed["birth_year"] is not None: + raw = parsed["birth_year"] + if isinstance(raw, int) and 1900 <= raw <= 2100: + result["birth_year"] = raw + elif isinstance(raw, str) and raw.isdigit(): + y = int(raw) + if y < 100: + y = 1900 + y if y >= 50 else 2000 + y + if 1900 <= y <= 2100: + result["birth_year"] = y + if "birth_place" in parsed and parsed["birth_place"]: + result["birth_place"] = str(parsed["birth_place"]) + if "grew_up_place" in parsed and parsed["grew_up_place"]: + result["grew_up_place"] = str(parsed["grew_up_place"]) + if "occupation" in parsed and parsed["occupation"]: + result["occupation"] = str(parsed["occupation"]) + return result + except (json.JSONDecodeError, Exception) as e: + logger.error("提取资料信息失败: %s", e) + return {} + + async def generate_profile_followup( + self, + conversation_id: str, + user_message: str, + missing_fields: List[str], + filled_fields: Dict[str, str], + nickname: str = "", + ) -> List[str]: + """生成资料追问回复,不持久化(由 Orchestrator 负责)""" + if not self.llm: + return ["谢谢!还能告诉我更多吗?"] + try: + prompt = get_profile_followup_prompt( + missing_fields, filled_fields, user_message, nickname + ) + history_messages = await get_history_messages(conversation_id) + history_string = format_history_string(history_messages) + full_prompt = f"{prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:" + response = await self.llm.ainvoke(full_prompt) + response_text = response.content if hasattr(response, "content") else str(response) + messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] + return messages[:3] if messages else [response_text] + except Exception as e: + logger.error("生成资料跟进回复失败: %s", e) + return ["谢谢分享!能再告诉我一些吗?"] + + async def generate_profile_greeting( + self, + conversation_id: str, + missing_fields: List[str], + nickname: str = "", + ) -> List[str]: + """生成资料收集开场白,不持久化(由 Orchestrator 负责)""" + if not self.llm: + return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"] + try: + prompt = get_profile_greeting_prompt(missing_fields, nickname) + history_messages = await get_history_messages(conversation_id) + history_string = format_history_string(history_messages) + full_prompt = f"{prompt}\n\n{history_string}" if history_string else prompt + response = await self.llm.ainvoke(full_prompt) + response_text = response.content if hasattr(response, "content") else str(response) + messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] + return messages[:2] if messages else [response_text] + except Exception as e: + logger.error("生成资料收集开场白失败: %s", e) + return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"] diff --git a/api/app/agents/conversation_agent.py b/api/app/agents/conversation_agent.py deleted file mode 100644 index 3430f51..0000000 --- a/api/app/agents/conversation_agent.py +++ /dev/null @@ -1,355 +0,0 @@ -""" -对话 Agent:基于访谈问题清单,动态选择问题,实时生成回应 -支持异步调用和 Redis 会话存储,支持用户基础资料收集和时代背景融入 -""" -import json -from app.core.logging import get_logger -from datetime import datetime -from typing import Any, Dict, List, Optional - -from langchain_core.messages import AIMessage, HumanMessage - -from app.core.redis import redis_service -from app.core.dependencies import get_llm_provider - -from app.agents.prompts import ( - ConversationStage, - get_conversation_prompt, - get_guided_conversation_prompt, - get_opening_prompt, -) -from app.agents.prompts.profile_prompts import ( - get_profile_greeting_prompt, - get_profile_extraction_prompt, - get_profile_followup_prompt, - format_user_profile_context, - get_missing_profile_fields, -) -from app.agents.prompts.conversation_prompts import SLOT_NAME_MAP -from app.agents.state_schema import MemoirStateSchema - -logger = get_logger(__name__) - - -def _get_langchain_llm(): - """从 port 获取 LangChain LLM 实例(供 Agent 使用)""" - try: - provider = get_llm_provider() - return getattr(provider, "langchain_llm", None) - except Exception: - return None - - -class ConversationAgent: - """对话 Agent(支持异步和 Redis 存储)""" - - def __init__(self): - self.llm = _get_langchain_llm() - - async def _get_history_messages(self, conversation_id: str) -> List[Any]: - history = await redis_service.get_conversation_history(conversation_id) - messages = [] - for msg in history: - if msg["role"] == "human": - messages.append(HumanMessage(content=msg["content"])) - elif msg["role"] == "ai": - messages.append(AIMessage(content=msg["content"])) - return messages - - async def _save_message( - self, - conversation_id: str, - role: str, - content: str, - message_type: str = "text", - voice_session_id: str | None = None, - timestamp: datetime | str | int | None = None, - ): - await redis_service.add_message( - conversation_id, - role, - content, - message_type=message_type, - voice_session_id=voice_session_id, - timestamp=timestamp.isoformat() if isinstance(timestamp, datetime) else timestamp, - ) - - def _format_history_string(self, messages: List[Any]) -> str: - history_parts = [] - for msg in messages: - if isinstance(msg, HumanMessage): - history_parts.append(f"Human: {msg.content}") - elif isinstance(msg, AIMessage): - history_parts.append(f"Assistant: {msg.content}") - return "\n\n".join(history_parts) - - async def generate_response( - self, - conversation_id: str, - user_message: str, - current_stage: Optional[ConversationStage] = None, - covered_topics: Optional[List[str]] = None, - ) -> str: - if current_stage is None: - current_stage = ConversationStage.CHILDHOOD - if covered_topics is None: - covered_topics = [] - if not self.llm: - return "抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。" - try: - system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message) - history_messages = await self._get_history_messages(conversation_id) - history_string = self._format_history_string(history_messages) - full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:" - response = await self.llm.ainvoke(full_prompt) - response_text = response.content if hasattr(response, "content") else str(response) - await self._save_message(conversation_id, "human", user_message) - await self._save_message(conversation_id, "ai", response_text) - return response_text - except Exception as e: - logger.error("生成回应失败: %s", e) - return f"抱歉,生成回应时出现错误: {str(e)}" - - async def generate_profile_greeting( - self, - conversation_id: str, - missing_fields: List[str], - nickname: str = "", - ) -> List[str]: - if not self.llm: - return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"] - try: - prompt = get_profile_greeting_prompt(missing_fields, nickname) - history_messages = await self._get_history_messages(conversation_id) - history_string = self._format_history_string(history_messages) - full_prompt = f"{prompt}\n\n{history_string}" if history_string else prompt - response = await self.llm.ainvoke(full_prompt) - response_text = response.content if hasattr(response, "content") else str(response) - await self._save_message(conversation_id, "ai", response_text) - messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] - return messages[:2] if messages else [response_text] - except Exception as e: - logger.error("生成资料收集开场白失败: %s", e) - return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"] - - async def generate_opening_message( - self, - conversation_id: str, - memoir_state: MemoirStateSchema, - user_profile_context: str = "", - ) -> List[str]: - if not self.llm: - return ["你好呀~ 有空聊聊你的人生故事吗?你小时候是在哪儿长大的?"] - try: - empty_slots = memoir_state.empty_slots_for_current_stage() - empty_slots_readable = [SLOT_NAME_MAP.get(s, s) for s in empty_slots] - if not empty_slots_readable: - empty_slots_readable = ["成长的地方", "难忘的事", "重要的人"] - prompt = get_opening_prompt( - current_stage=memoir_state.current_stage, - empty_slots_readable=empty_slots_readable, - user_profile_context=user_profile_context, - ) - full_prompt = f"{prompt}\n\nAssistant:" - response = await self.llm.ainvoke(full_prompt) - response_text = response.content if hasattr(response, "content") else str(response) - await self._save_message(conversation_id, "ai", response_text) - messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] - return messages[:2] if messages else [response_text] - except Exception as e: - logger.error("生成开场白失败: %s", e, exc_info=True) - return ["你好呀~ 有空聊聊你的人生故事吗?你童年里印象最深的一件事是什么?"] - - async def extract_profile_from_message( - self, - user_message: str, - missing_fields: List[str], - conversation_id: Optional[str] = None, - ) -> Dict[str, Any]: - if not self.llm or not missing_fields: - return {} - recent_dialogue = "" - if conversation_id: - history_messages = await self._get_history_messages(conversation_id) - recent = history_messages[-4:] if len(history_messages) > 4 else history_messages - parts = [] - for msg in recent: - if isinstance(msg, HumanMessage): - parts.append(f"用户: {msg.content}") - elif isinstance(msg, AIMessage): - parts.append(f"助手: {msg.content}") - recent_dialogue = "\n".join(parts) if parts else "" - try: - prompt = get_profile_extraction_prompt( - user_message, missing_fields, recent_dialogue=recent_dialogue or None - ) - response = await self.llm.ainvoke(prompt) - content = response.content.strip() - parsed = json.loads(content) - result = {} - if "birth_year" in parsed and parsed["birth_year"] is not None: - raw = parsed["birth_year"] - if isinstance(raw, int) and 1900 <= raw <= 2100: - result["birth_year"] = raw - elif isinstance(raw, str) and raw.isdigit(): - y = int(raw) - if y < 100: - y = 1900 + y if y >= 50 else 2000 + y - if 1900 <= y <= 2100: - result["birth_year"] = y - if "birth_place" in parsed and parsed["birth_place"]: - result["birth_place"] = str(parsed["birth_place"]) - if "grew_up_place" in parsed and parsed["grew_up_place"]: - result["grew_up_place"] = str(parsed["grew_up_place"]) - if "occupation" in parsed and parsed["occupation"]: - result["occupation"] = str(parsed["occupation"]) - return result - except (json.JSONDecodeError, Exception) as e: - logger.error("提取资料信息失败: %s", e) - return {} - - async def generate_profile_followup( - self, - conversation_id: str, - user_message: str, - missing_fields: List[str], - filled_fields: Dict[str, str], - nickname: str = "", - is_from_voice: bool = False, - voice_session_id: str | None = None, - user_message_timestamp: datetime | None = None, - ) -> List[str]: - if not self.llm: - return ["谢谢!还能告诉我更多吗?"] - try: - prompt = get_profile_followup_prompt(missing_fields, filled_fields, user_message, nickname) - history_messages = await self._get_history_messages(conversation_id) - history_string = self._format_history_string(history_messages) - full_prompt = f"{prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:" - response = await self.llm.ainvoke(full_prompt) - response_text = response.content if hasattr(response, "content") else str(response) - human_msg_type = "audio" if is_from_voice else "text" - await self._save_message( - conversation_id, - "human", - user_message, - message_type=human_msg_type, - voice_session_id=voice_session_id, - timestamp=user_message_timestamp, - ) - await self._save_message(conversation_id, "ai", response_text) - messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] - return messages[:3] if messages else [response_text] - except Exception as e: - logger.error("生成资料跟进回复失败: %s", e) - return ["谢谢分享!能再告诉我一些吗?"] - - def _detect_user_stage(self, user_message: str) -> str: - message = user_message.lower() - stage_keywords = { - "childhood": ["童年", "小时候", "出生", "家乡", "小镇", "爸妈", "父亲", "母亲", "爷爷", "奶奶", "外公", "外婆", "幼儿园"], - "education": ["上学", "学校", "老师", "同学", "教育", "大学", "高中", "初中", "小学", "考试", "毕业", "读书", "高考", "课堂"], - "career": ["工作", "职业", "事业", "公司", "同事", "创业", "升职", "跳槽", "老板", "行业", "项目", "加班", "薪水", "面试"], - "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "爱人", "老婆", "老公", "丈夫", "妻子", "儿子", "女儿", "婚礼", "恋爱"], - "belief": ["信念", "价值观", "座右铭", "坚持", "原则", "信仰", "意义", "感悟", "遗憾", "骄傲"], - } - for stage, keywords in stage_keywords.items(): - if any(word in message for word in keywords): - return stage - return "" - - async def generate_response_with_state( - self, - conversation_id: str, - user_message: str, - memoir_state: MemoirStateSchema, - user_profile_context: str = "", - is_from_voice: bool = False, - voice_session_id: str | None = None, - user_message_timestamp: datetime | None = None, - ) -> List[str]: - if not self.llm: - return ["抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"] - try: - empty_slots = memoir_state.empty_slots_for_current_stage() - filled_slots = { - key: value.snippet - for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items() - if value.snippet - } - detected_user_stage = self._detect_user_stage(user_message) - history_messages = await self._get_history_messages(conversation_id) - conversation_turn = len(history_messages) // 2 - same_topic_turns = self._estimate_same_topic_turns(history_messages, filled_slots) - all_stages_coverage = memoir_state.all_stages_coverage() - system_prompt = get_guided_conversation_prompt( - current_stage=memoir_state.current_stage, - empty_slots=empty_slots, - filled_slots=filled_slots, - user_message=user_message, - conversation_turn=conversation_turn, - same_topic_turns=same_topic_turns, - all_stages_coverage=all_stages_coverage, - detected_user_stage=detected_user_stage, - user_profile_context=user_profile_context, - ) - history_string = self._format_history_string(history_messages) - full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:" - response = await self.llm.ainvoke(full_prompt) - response_text = response.content if hasattr(response, "content") else str(response) - human_msg_type = "audio" if is_from_voice else "text" - await self._save_message( - conversation_id, - "human", - user_message, - message_type=human_msg_type, - voice_session_id=voice_session_id, - timestamp=user_message_timestamp, - ) - await self._save_message(conversation_id, "ai", response_text) - messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()] - return messages[:3] if messages else [response_text] - except Exception as e: - logger.error("生成回应失败: %s", e) - return [f"抱歉,生成回应时出现错误: {str(e)}"] - - def _estimate_same_topic_turns(self, history_messages: List[Any], current_filled_slots: dict) -> int: - if len(history_messages) < 4: - return len(history_messages) // 2 - recent_messages = history_messages[-6:] - keywords_per_turn = [] - for i in range(0, len(recent_messages), 2): - if i + 1 < len(recent_messages): - human_msg = ( - recent_messages[i].content - if hasattr(recent_messages[i], "content") - else str(recent_messages[i]) - ) - ai_msg = ( - recent_messages[i + 1].content - if hasattr(recent_messages[i + 1], "content") - else str(recent_messages[i + 1]) - ) - keywords_per_turn.append((human_msg + ai_msg)[:100]) - if len(keywords_per_turn) >= 3: - return 3 - return len(keywords_per_turn) - - def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage: - message_lower = user_message.lower() - if any(word in message_lower for word in ["童年", "小时候", "出生", "家庭背景"]): - return ConversationStage.CHILDHOOD - if any(word in message_lower for word in ["上学", "学校", "老师", "同学", "教育"]): - return ConversationStage.EDUCATION - if any(word in message_lower for word in ["工作", "职业", "事业", "公司", "同事"]): - return ConversationStage.CAREER - if any(word in message_lower for word in ["伴侣", "孩子", "家庭", "家人", "结婚"]): - return ConversationStage.FAMILY - if any(word in message_lower for word in ["信念", "价值观", "座右铭", "坚持", "原则"]): - return ConversationStage.BELIEFS - if any(word in message_lower for word in ["总结", "回顾", "感激", "希望", "未来"]): - return ConversationStage.SUMMARY - return ConversationStage.CHILDHOOD - - async def clear_memory(self, conversation_id: str): - await redis_service.clear_conversation_history(conversation_id) diff --git a/api/app/agents/memoir_processor.py b/api/app/agents/memoir_processor.py deleted file mode 100644 index 18eece3..0000000 --- a/api/app/agents/memoir_processor.py +++ /dev/null @@ -1,211 +0,0 @@ -""" -回忆录后台处理器:分析对话、更新状态、生成章节、创意标题 -使用 Celery 进行后台任务处理 -""" -from __future__ import annotations - -import json -from app.core.logging import get_logger -from dataclasses import dataclass -from typing import Dict, List - -from app.core.dependencies import get_llm_provider -from app.core.task_tracker import task_tracker -from app.agents.state_schema import MemoirStateSchema -from app.agents.prompts.memory_prompts import ( - get_creative_title_prompt, - get_narrative_prompt, - get_state_extraction_prompt, -) - -logger = get_logger(__name__) - -STAGE_KEYWORDS = { - "childhood": ["童年", "小时候", "出生", "家乡", "小镇"], - "education": ["上学", "学校", "老师", "同学", "教育", "大学"], - "career": ["工作", "职业", "事业", "公司", "同事", "创业"], - "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"], - "belief": ["信念", "价值观", "座右铭", "坚持", "原则"], -} - - -def _get_langchain_llm(): - try: - provider = get_llm_provider() - return getattr(provider, "langchain_llm", None) - except Exception: - return None - - -@dataclass -class AnalysisResult: - detected_stage: str - extracted_slots: Dict[str, str] - emotion: str - is_new_chapter: bool - - -class ContentAnalyzer: - def __init__(self) -> None: - self.llm = _get_langchain_llm() - - def _detect_stage(self, user_message: str, fallback_stage: str) -> str: - message = user_message.lower() - for stage, keywords in STAGE_KEYWORDS.items(): - if any(word in message for word in keywords): - return stage - return fallback_stage - - def _fallback_slots( - self, state: MemoirStateSchema, stage: str, user_message: str - ) -> Dict[str, str]: - stage_slots = state.slots.get(stage, {}) - for key, value in stage_slots.items(): - if not value.snippet: - return {key: user_message.strip()[:200]} - return {} - - async def analyze_message( - self, user_message: str, current_state: MemoirStateSchema - ) -> AnalysisResult: - detected_stage = self._detect_stage( - user_message, current_state.current_stage - ) - extracted_slots: Dict[str, str] = {} - emotion = "neutral" - is_new_chapter = False - if self.llm: - try: - prompt = get_state_extraction_prompt( - user_message=user_message, - current_stage=current_state.current_stage, - stage_slots=current_state.slots.get(detected_stage, {}), - ) - response = await self.llm.ainvoke(prompt) - content = response.content.strip() - parsed = json.loads(content) - detected_stage = parsed.get("detected_stage", detected_stage) - extracted_slots = parsed.get("slots", {}) or {} - emotion = parsed.get("emotion", emotion) - is_new_chapter = bool(parsed.get("is_new_chapter", is_new_chapter)) - except json.JSONDecodeError: - extracted_slots = self._fallback_slots( - current_state, detected_stage, user_message - ) - except Exception as e: - logger.error("分析消息失败: %s", e) - extracted_slots = self._fallback_slots( - current_state, detected_stage, user_message - ) - else: - extracted_slots = self._fallback_slots( - current_state, detected_stage, user_message - ) - return AnalysisResult( - detected_stage=detected_stage, - extracted_slots=extracted_slots, - emotion=emotion, - is_new_chapter=is_new_chapter, - ) - - -class MemoirGenerator: - def __init__(self) -> None: - self.llm = _get_langchain_llm() - - async def generate_chapter_title( - self, stage: str, slots: Dict[str, str], emotion: str - ) -> str: - if not self.llm: - return f"{stage} 回忆" - try: - prompt = get_creative_title_prompt( - stage=stage, emotion=emotion, slots=slots - ) - response = await self.llm.ainvoke(prompt) - return response.content.strip().strip('"') - except Exception as e: - logger.error("生成标题失败: %s", e) - return f"{stage} 回忆" - - async def generate_narrative( - self, - stage: str, - slots: Dict[str, str], - new_content: str, - existing_content: str, - ) -> str: - if not self.llm: - if existing_content: - return f"{existing_content}\n\n{new_content}" - return new_content - try: - prompt = get_narrative_prompt( - stage=stage, - slots=slots, - new_content=new_content, - existing_content=existing_content, - ) - response = await self.llm.ainvoke(prompt) - return response.content.strip() - except Exception as e: - logger.error("生成叙事失败: %s", e) - if existing_content: - return f"{existing_content}\n\n{new_content}" - return new_content - - -class BackgroundTaskRunner: - def __init__(self, debounce_seconds: int = 5) -> None: - self.debounce_seconds = debounce_seconds - self._pending: Dict[str, List[str]] = {} - self._timers: Dict[str, object] = {} - self.analyzer = ContentAnalyzer() - self.generator = MemoirGenerator() - - async def _submit_task(self, user_id: str, segment_ids: List[str]) -> str | None: - try: - from app.tasks.memoir_tasks import process_memoir_segments - - result = process_memoir_segments.delay(user_id, segment_ids) - task_id = result.id - await task_tracker.add_task(user_id, task_id, "memoir") - logger.info( - "已提交 Celery 任务: user_id=%s, task_id=%s, segments=%s", - user_id, - task_id, - len(segment_ids), - ) - return task_id - except Exception as e: - logger.error("提交 Celery 任务失败: %s", e) - return None - - async def queue_message(self, user_id: str, segment_id: str) -> None: - import asyncio - - self._pending.setdefault(user_id, []).append(segment_id) - if user_id in self._timers: - self._timers[user_id].cancel() - - async def delayed_submit(): - try: - await asyncio.sleep(self.debounce_seconds) - segment_ids = self._pending.pop(user_id, []) - if segment_ids: - await self._submit_task(user_id, segment_ids) - except asyncio.CancelledError: - pass - except Exception as e: - logger.error("延迟提交任务失败: %s", e) - - self._timers[user_id] = asyncio.create_task(delayed_submit()) - - async def flush_pending(self, user_id: str) -> str | None: - if user_id in self._timers: - self._timers[user_id].cancel() - del self._timers[user_id] - segment_ids = self._pending.pop(user_id, []) - if segment_ids: - return await self._submit_task(user_id, segment_ids) - return None diff --git a/api/app/agents/memory_agent.py b/api/app/agents/memory_agent.py deleted file mode 100644 index 2970a92..0000000 --- a/api/app/agents/memory_agent.py +++ /dev/null @@ -1,130 +0,0 @@ -""" -回忆录整理 Agent:基于传记结构,将口语改写为书面语,归类到章节 -支持异步调用 -""" -import json -from app.core.logging import get_logger -from typing import Dict, List, Optional - -from app.core.dependencies import get_llm_provider - -from app.agents.prompts import ( - get_chapter_classification_prompt, - get_text_rewrite_prompt, - inject_image_placeholder_template, - CHAPTER_CATEGORIES, - STAGE_TO_ORDER, -) - -logger = get_logger(__name__) - - -def _get_langchain_llm(): - try: - provider = get_llm_provider() - return getattr(provider, "langchain_llm", None) - except Exception: - return None - - -class MemoryAgent: - """回忆录整理 Agent(支持异步)""" - - def __init__(self): - self.llm = _get_langchain_llm() - - async def classify_chapter(self, segments_text: str) -> str: - if not self.llm: - return "childhood" - try: - prompt = get_chapter_classification_prompt(segments_text) - response = await self.llm.ainvoke(prompt) - content = response.content if hasattr(response, "content") else str(response) - category = content.strip().lower() - if category in CHAPTER_CATEGORIES: - return category - except Exception as e: - logger.error("分类章节失败: %s", e) - return "childhood" - - async def rewrite_to_literary( - self, - segments_text: str, - chapter_category: str, - existing_content: Optional[str] = None, - ) -> Dict: - if not self.llm: - return { - "title": CHAPTER_CATEGORIES.get(chapter_category, "章节"), - "content": segments_text, - "summary": "", - "image_suggestions": [], - } - try: - prompt = get_text_rewrite_prompt( - segments_text, chapter_category, existing_content or "" - ) - response = await self.llm.ainvoke(prompt) - content = response.content if hasattr(response, "content") else str(response) - content = content.strip() - if content.startswith("```json"): - content = content[7:] - if content.startswith("```"): - content = content[3:] - if content.endswith("```"): - content = content[:-3] - content = content.strip() - result = json.loads(content) - result["content"] = inject_image_placeholder_template( - result.get("content") or "" - ) - return result - except json.JSONDecodeError: - raw = response.content if hasattr(response, "content") else str(response) - return { - "title": CHAPTER_CATEGORIES.get(chapter_category, "章节"), - "content": inject_image_placeholder_template(raw), - "summary": "", - "image_suggestions": [], - } - except Exception as e: - logger.error("改写文本失败: %s", e) - return { - "title": CHAPTER_CATEGORIES.get(chapter_category, "章节"), - "content": segments_text, - "summary": "", - "image_suggestions": [], - } - - async def process_segments( - self, - segments: List[Dict], - existing_chapters: Optional[Dict[str, Dict]] = None, - ) -> Dict[str, Dict]: - if existing_chapters is None: - existing_chapters = {} - segments_by_category: Dict[str, List[str]] = {} - for segment in segments: - text = segment.get("transcript_text", "") - if not text: - continue - category = await self.classify_chapter(text) - if category not in segments_by_category: - segments_by_category[category] = [] - segments_by_category[category].append(text) - updated_chapters = existing_chapters.copy() - for category, texts in segments_by_category.items(): - combined_text = "\n\n".join(texts) - existing_content = existing_chapters.get(category, {}).get("content", "") - result = await self.rewrite_to_literary( - combined_text, category, existing_content - ) - updated_chapters[category] = { - "title": result.get("title", CHAPTER_CATEGORIES.get(category, "章节")), - "content": result.get("content", ""), - "summary": result.get("summary", ""), - "image_suggestions": result.get("image_suggestions", []), - "category": category, - "order_index": STAGE_TO_ORDER.get(category, 999), - } - return updated_chapters diff --git a/api/app/features/conversation/ws/pipeline.py b/api/app/features/conversation/ws/pipeline.py index b3e0c04..b04555f 100644 --- a/api/app/features/conversation/ws/pipeline.py +++ b/api/app/features/conversation/ws/pipeline.py @@ -14,8 +14,8 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.agents import ConversationAgent, MemoryAgent -from app.agents.memoir_processor import BackgroundTaskRunner -from app.agents.prompts.profile_prompts import format_user_profile_context +from app.agents.chat import ChatOrchestrator +from app.agents.memoir import BackgroundTaskRunner from app.core.db import AsyncSessionLocal from app.features.conversation.models import Conversation, Segment from app.features.conversation.ws.connection_manager import manager @@ -28,7 +28,6 @@ from app.features.conversation.ws.profile_collector import ( from app.features.user.models import User from app.core.config import settings from app.core.dependencies import get_asr_provider, get_tts_provider -from app.features.memoir.state_service import get_or_create_state logger = get_logger(__name__) @@ -53,6 +52,7 @@ async def _send_tts_audio(conversation_id: str, text: str) -> None: # ── Agent 实例(从 ConnectionManager 移出) ───────────────────── conversation_agent = ConversationAgent() +chat_orchestrator = ChatOrchestrator() memory_agent = MemoryAgent() background_runner = BackgroundTaskRunner() @@ -429,82 +429,21 @@ async def process_user_message( user: User = None, user_message_timestamp: Optional[datetime] = None, ) -> None: - """处理用户消息,生成 Agent 回应。支持资料收集模式和正式访谈模式。""" - agent = conversation_agent - - if user: - missing = get_missing_profile_fields(user) - if missing: - try: - extracted = await agent.extract_profile_from_message( - user_message, missing, conversation_id=conversation_id - ) - if extracted: - await apply_extracted_profile(user, extracted, db) - - remaining = get_missing_profile_fields(user) - filled = get_filled_profile_fields(user) - is_from_voice = bool(segment.audio_url) - responses = await agent.generate_profile_followup( - conversation_id=conversation_id, - user_message=user_message, - missing_fields=remaining, - filled_fields=filled, - nickname=user.nickname or "", - is_from_voice=is_from_voice, - voice_session_id=_voice_session_id_from_audio_url(segment.audio_url), - user_message_timestamp=user_message_timestamp, - ) - - segment.agent_response = "\n\n".join(responses) - _mark_conversation_active(conversation) - await db.commit() - - for i, response_text in enumerate(responses): - await manager.send_message(conversation_id, { - "type": MessageType.AGENT_RESPONSE, - "conversation_id": conversation_id, - "data": {"text": response_text, "index": i, "total": len(responses)}, - "timestamp": datetime.now(timezone.utc).isoformat(), - }) - await _send_tts_audio(conversation_id, response_text) - if i < len(responses) - 1: - await asyncio.sleep(0.5) - return - except Exception as e: - logger.error(f"资料收集处理失败: {e}", exc_info=True) - - state = await get_or_create_state(conversation.user_id, db) - - if conversation.conversation_stage != state.current_stage: - conversation.conversation_stage = state.current_stage - await db.commit() - - stmt_segments = select(Segment).where( - Segment.conversation_id == conversation_id - ).order_by(Segment.created_at) - result_segments = await db.execute(stmt_segments) - previous_segments = result_segments.scalars().all() - covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category] - - user_profile_context = "" - if user: - user_profile_context = format_user_profile_context( - birth_year=user.birth_year, - birth_place=user.birth_place, - grew_up_place=user.grew_up_place, - occupation=user.occupation, - ) - + """处理用户消息,生成 Agent 回应。由 ChatOrchestrator 路由到 ProfileAgent 或 InterviewAgent。""" try: is_from_voice = bool(segment.audio_url) - responses = await agent.generate_response_with_state( + voice_session_id = _voice_session_id_from_audio_url(segment.audio_url) + responses = await chat_orchestrator.process_user_message( conversation_id=conversation_id, user_message=user_message, - memoir_state=state, - user_profile_context=user_profile_context, + user=user, + conversation=conversation, is_from_voice=is_from_voice, - voice_session_id=_voice_session_id_from_audio_url(segment.audio_url), + voice_session_id=voice_session_id, + db=db, + apply_extracted_profile_fn=apply_extracted_profile, + get_missing_profile_fields_fn=get_missing_profile_fields, + get_filled_profile_fields_fn=get_filled_profile_fields, user_message_timestamp=user_message_timestamp, ) diff --git a/api/app/features/memoir/processor.py b/api/app/features/memoir/processor.py index ad367a2..0d19bf0 100644 --- a/api/app/features/memoir/processor.py +++ b/api/app/features/memoir/processor.py @@ -1,6 +1,6 @@ """Memoir processor — 从 agents/memoir_processor.py 迁入的占位。 实际逻辑仍由 agents/memoir_processor.py 提供,后续迁入。""" -from app.agents.memoir_processor import BackgroundTaskRunner +from app.agents.memoir import BackgroundTaskRunner __all__ = ["BackgroundTaskRunner"] diff --git a/api/routers/chapters.py b/api/routers/chapters.py index 7f2bb58..a357f06 100644 --- a/api/routers/chapters.py +++ b/api/routers/chapters.py @@ -14,7 +14,7 @@ from database import get_async_db from database.models import Chapter as ChapterModel, ChapterSection from database.models import User as UserModel from middleware.auth import get_current_user -from agents.prompts.memory_prompts import CHAPTER_CATEGORIES, CHAPTER_ORDER, STAGE_TO_ORDER +from app.agents.prompts.memory_prompts import CHAPTER_CATEGORIES, CHAPTER_ORDER, STAGE_TO_ORDER from services.memoir_images.schema import ( completed_image_assets, IMAGE_STATUS_COMPLETED, diff --git a/api/routers/websocket.py b/api/routers/websocket.py index f6c9afa..4559842 100644 --- a/api/routers/websocket.py +++ b/api/routers/websocket.py @@ -15,15 +15,15 @@ from starlette.websockets import WebSocketState from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from agents import ConversationAgent, MemoryAgent -from agents.memoir_processor import BackgroundTaskRunner +from app.agents import ConversationAgent, MemoryAgent +from app.agents.memoir import BackgroundTaskRunner from database import get_async_db from database.models import Conversation, Segment from database.models import User as UserModel from services.auth_service import verify_token from services.memoir_state_service import get_or_create_state from services import asr_service, redis_service -from agents.prompts.profile_prompts import format_user_profile_context +from app.agents.prompts.profile_prompts import format_user_profile_context logger = logging.getLogger(__name__) LEGACY_VOICE_SESSION_ID = "legacy" @@ -924,7 +924,7 @@ async def websocket_endpoint( def _get_missing_profile_fields(user: UserModel) -> list: """检查用户缺失的资料字段""" - from agents.prompts.profile_prompts import get_missing_profile_fields + from app.agents.prompts.profile_prompts import get_missing_profile_fields return get_missing_profile_fields( birth_year=user.birth_year, birth_place=user.birth_place, @@ -935,7 +935,7 @@ def _get_missing_profile_fields(user: UserModel) -> list: def _get_filled_profile_fields(user: UserModel) -> dict: """获取用户已有的资料字段(中文展示)""" - from agents.prompts.profile_prompts import PROFILE_FIELD_NAMES + from app.agents.prompts.profile_prompts import PROFILE_FIELD_NAMES filled = {} if user.birth_year: filled["birth_year"] = str(user.birth_year) @@ -1045,7 +1045,7 @@ async def process_user_message( # 构建用户资料上下文 user_profile_context = "" if user: - from agents.prompts.profile_prompts import format_user_profile_context + from app.agents.prompts.profile_prompts import format_user_profile_context user_profile_context = format_user_profile_context( birth_year=user.birth_year, birth_place=user.birth_place,