feat & refactor: 重构agents目录结构;AI回复模块agent结构封装

This commit is contained in:
yangshilin
2026-03-19 10:36:55 +08:00
parent b56fc859cc
commit b16bb2b96c
14 changed files with 754 additions and 781 deletions

View File

@@ -1,10 +1,19 @@
"""
Agent 模块(app 内,符合架构计划
Agent 模块(按功能拆分chat / memoir
"""
from app.agents.conversation_agent import ConversationAgent
from app.agents.memory_agent import MemoryAgent
from app.agents.chat import (
ChatOrchestrator,
ConversationAgent,
InterviewAgent,
ProfileAgent,
)
from app.agents.memoir import BackgroundTaskRunner, MemoryAgent
# Names re-exported at the top level of the agents package.
__all__ = [
    "ConversationAgent",
    "MemoryAgent",
    "ChatOrchestrator",
    "ProfileAgent",
    "InterviewAgent",
    "BackgroundTaskRunner",
]

View File

@@ -0,0 +1,12 @@
"""聊天模块AI 回复用户ProfileAgent + InterviewAgent + ChatOrchestrator"""
from app.agents.chat.conversation_agent import ConversationAgent
from app.agents.chat.orchestrator import ChatOrchestrator
from app.agents.chat.profile_agent import ProfileAgent
from app.agents.chat.interview_agent import InterviewAgent
# Names re-exported by the chat sub-package.
__all__ = [
    "ConversationAgent",
    "ChatOrchestrator",
    "ProfileAgent",
    "InterviewAgent",
]

View File

@@ -0,0 +1,139 @@
"""
对话 AgentFacade内部委托 ChatOrchestrator + ProfileAgent + InterviewAgent
保留原有对外 API供 router 等调用方兼容使用
"""
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.agents.chat.orchestrator import ChatOrchestrator
from app.agents.prompts import ConversationStage
from app.agents.state_schema import MemoirStateSchema
from app.core.redis import redis_service
class ConversationAgent:
    """Facade that keeps the historical conversation API and forwards to ChatOrchestrator.

    Each public coroutine passes its arguments straight to the orchestrator
    method of the same name; only ``generate_response``, ``detect_stage`` and
    ``clear_memory`` contain logic of their own.
    """

    def __init__(self):
        self._orchestrator = ChatOrchestrator()

    async def extract_profile_from_message(
        self,
        user_message: str,
        missing_fields: List[str],
        conversation_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Forward profile extraction to the orchestrator and return its result."""
        return await self._orchestrator.extract_profile_from_message(
            user_message, missing_fields, conversation_id=conversation_id
        )

    async def generate_profile_followup(
        self,
        conversation_id: str,
        user_message: str,
        missing_fields: List[str],
        filled_fields: Dict[str, str],
        nickname: str = "",
        is_from_voice: bool = False,
        voice_session_id: str | None = None,
        user_message_timestamp: datetime | None = None,
    ) -> List[str]:
        """Forward a profile follow-up request to the orchestrator."""
        return await self._orchestrator.generate_profile_followup(
            conversation_id=conversation_id,
            user_message=user_message,
            missing_fields=missing_fields,
            filled_fields=filled_fields,
            nickname=nickname,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            user_message_timestamp=user_message_timestamp,
        )

    async def generate_profile_greeting(
        self,
        conversation_id: str,
        missing_fields: List[str],
        nickname: str = "",
    ) -> List[str]:
        """Forward a profile-collection greeting request to the orchestrator."""
        return await self._orchestrator.generate_profile_greeting(
            conversation_id=conversation_id,
            missing_fields=missing_fields,
            nickname=nickname,
        )

    async def generate_response_with_state(
        self,
        conversation_id: str,
        user_message: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
        is_from_voice: bool = False,
        voice_session_id: str | None = None,
        user_message_timestamp: datetime | None = None,
    ) -> List[str]:
        """Forward a state-aware interview reply request to the orchestrator."""
        return await self._orchestrator.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=memoir_state,
            user_profile_context=user_profile_context,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            user_message_timestamp=user_message_timestamp,
        )

    async def generate_opening_message(
        self,
        conversation_id: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
    ) -> List[str]:
        """Forward an interview opening-message request to the orchestrator."""
        return await self._orchestrator.generate_opening_message(
            conversation_id=conversation_id,
            memoir_state=memoir_state,
            user_profile_context=user_profile_context,
        )

    async def generate_response(
        self,
        conversation_id: str,
        user_message: str,
        current_stage: Optional[ConversationStage] = None,
        covered_topics: Optional[List[str]] = None,
    ) -> str:
        """Legacy single-string reply API built on the state-aware interview path.

        A default state is created, its stage set from ``current_stage``
        (``CHILDHOOD`` when omitted) and ``covered_stages`` set from
        ``covered_topics``. All reply segments are returned joined by a blank
        line so that nothing the orchestrator produced is dropped; an empty
        string is returned when there are no segments.
        """
        from app.agents.state_schema import default_state

        state = default_state()
        state.current_stage = (current_stage or ConversationStage.CHILDHOOD).value
        state.covered_stages = covered_topics or []
        responses = await self._orchestrator.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=state,
            user_profile_context="",
        )
        # Joining mirrors how the orchestrator concatenates segments when it
        # stores the AI message, so the caller sees the complete reply.
        return "\n\n".join(responses) if responses else ""

    def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage:
        """Map the orchestrator's keyword-based stage key to a ConversationStage.

        ``conversation_id`` is accepted for API compatibility and not used.
        Any key outside the table (including the empty string) yields
        ``ConversationStage.CHILDHOOD``.
        """
        stage_by_key = {
            "childhood": ConversationStage.CHILDHOOD,
            "education": ConversationStage.EDUCATION,
            "career": ConversationStage.CAREER,
            "family": ConversationStage.FAMILY,
            "belief": ConversationStage.BELIEFS,
        }
        detected = self._orchestrator.detect_user_stage(user_message)
        return stage_by_key.get(detected, ConversationStage.CHILDHOOD)

    async def clear_memory(self, conversation_id: str) -> None:
        """Delete the conversation history stored in Redis for this conversation."""
        await redis_service.clear_conversation_history(conversation_id)

View File

@@ -0,0 +1,49 @@
"""聊天 Agent 共享工具:历史获取、格式化、存储"""
from datetime import datetime
from typing import Any, List
from langchain_core.messages import AIMessage, HumanMessage
from app.core.redis import redis_service
async def get_history_messages(conversation_id: str) -> List[Any]:
    """Load the stored conversation from Redis as LangChain message objects.

    Entries with role "human" become ``HumanMessage`` and entries with role
    "ai" become ``AIMessage``; entries with any other role are skipped.
    """
    stored = await redis_service.get_conversation_history(conversation_id)
    converted = []
    for entry in stored:
        role = entry["role"]
        if role == "human":
            message_cls = HumanMessage
        elif role == "ai":
            message_cls = AIMessage
        else:
            continue
        converted.append(message_cls(content=entry["content"]))
    return converted
def format_history_string(messages: List[Any]) -> str:
    """Render messages as a ``Human:`` / ``Assistant:`` transcript.

    Lines are separated by a blank line. Objects that are neither
    ``HumanMessage`` nor ``AIMessage`` are left out.
    """

    def _speaker(message):
        # Human is checked first, matching the original branch order.
        if isinstance(message, HumanMessage):
            return "Human"
        if isinstance(message, AIMessage):
            return "Assistant"
        return None

    labelled = ((_speaker(message), message) for message in messages)
    return "\n\n".join(
        f"{speaker}: {message.content}"
        for speaker, message in labelled
        if speaker is not None
    )
async def save_message(
    conversation_id: str,
    role: str,
    content: str,
    message_type: str = "text",
    voice_session_id: str | None = None,
    timestamp: datetime | str | int | None = None,
) -> None:
    """Append one message to the conversation history kept in Redis.

    A ``datetime`` timestamp is converted with ``isoformat()``; any other
    timestamp value (including ``None``) is forwarded unchanged.
    """
    stamp = timestamp
    if isinstance(stamp, datetime):
        stamp = stamp.isoformat()
    await redis_service.add_message(
        conversation_id,
        role,
        content,
        message_type=message_type,
        voice_session_id=voice_session_id,
        timestamp=stamp,
    )

View File

@@ -0,0 +1,143 @@
"""
InterviewAgent正式访谈 Specialist
负责状态感知回复、开场白,不负责 Redis 持久化(由 Orchestrator 统一处理)
"""
from typing import Any, List
from app.core.dependencies import get_llm_provider
from app.core.logging import get_logger
from app.agents.chat.helpers import format_history_string, get_history_messages
from app.agents.prompts import get_guided_conversation_prompt, get_opening_prompt
from app.agents.prompts.conversation_prompts import SLOT_NAME_MAP
from app.agents.state_schema import MemoirStateSchema
logger = get_logger(__name__)
def _get_langchain_llm():
    """Return the provider's ``langchain_llm`` attribute, or ``None``.

    ``None`` is returned both when the provider has no such attribute and
    when obtaining the provider raises.
    """
    try:
        llm = getattr(get_llm_provider(), "langchain_llm", None)
    except Exception:
        llm = None
    return llm
class InterviewAgent:
    """Specialist agent for the formal interview: state-aware replies and openings.

    The agent only generates text; it never writes to Redis (the orchestrator
    persists messages).
    """

    # Keywords used to guess which life stage a message is about. Stages are
    # checked in this order and the first stage with a hit wins.
    _STAGE_KEYWORDS = {
        "childhood": ["童年", "小时候", "出生", "家乡", "小镇", "爸妈", "父亲", "母亲", "爷爷", "奶奶", "外公", "外婆", "幼儿园"],
        "education": ["上学", "学校", "老师", "同学", "教育", "大学", "高中", "初中", "小学", "考试", "毕业", "读书", "高考", "课堂"],
        "career": ["工作", "职业", "事业", "公司", "同事", "创业", "升职", "跳槽", "老板", "行业", "项目", "加班", "薪水", "面试"],
        "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "爱人", "老婆", "老公", "丈夫", "妻子", "儿子", "女儿", "婚礼", "恋爱"],
        "belief": ["信念", "价值观", "座右铭", "坚持", "原则", "信仰", "意义", "感悟", "遗憾", "骄傲"],
    }

    def __init__(self):
        self.llm = _get_langchain_llm()

    @staticmethod
    def _split_reply(response: Any, limit: int) -> List[str]:
        """Split an LLM reply on ``[SPLIT]`` into at most ``limit`` non-empty segments.

        The reply text is ``response.content`` when that attribute exists,
        otherwise ``str(response)``. If no non-empty segment remains, the raw
        text is returned as the only element.
        """
        text = response.content if hasattr(response, "content") else str(response)
        segments = [part.strip() for part in text.split("[SPLIT]") if part.strip()]
        return segments[:limit] if segments else [text]

    def _detect_user_stage(self, user_message: str) -> str:
        """Return the first stage key whose keyword occurs in the message, else ``""``."""
        message = user_message.lower()
        for stage, keywords in self._STAGE_KEYWORDS.items():
            if any(word in message for word in keywords):
                return stage
        return ""

    def _estimate_same_topic_turns(
        self, history_messages: List[Any], current_filled_slots: dict
    ) -> int:
        """Estimate consecutive turns on the current topic as a value from 0 to 3.

        The estimate is the number of complete two-message turns among the
        last six history messages. ``current_filled_slots`` is accepted for
        signature compatibility and is not used.
        """
        # Same result as pairing up the last six messages and counting the
        # pairs, without concatenating message contents (which raises when a
        # content value is not a string).
        return min(len(history_messages), 6) // 2

    async def generate_response_with_state(
        self,
        conversation_id: str,
        user_message: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
    ) -> List[str]:
        """Generate up to three state-aware interview reply segments.

        Returns a single fixed message when no LLM is configured, and a single
        error message when generation raises. Nothing is persisted here.
        """
        if not self.llm:
            return ["抱歉LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"]
        try:
            empty_slots = memoir_state.empty_slots_for_current_stage()
            filled_slots = {
                key: value.snippet
                for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items()
                if value.snippet
            }
            detected_user_stage = self._detect_user_stage(user_message)
            history_messages = await get_history_messages(conversation_id)
            conversation_turn = len(history_messages) // 2
            same_topic_turns = self._estimate_same_topic_turns(
                history_messages, filled_slots
            )
            all_stages_coverage = memoir_state.all_stages_coverage()
            system_prompt = get_guided_conversation_prompt(
                current_stage=memoir_state.current_stage,
                empty_slots=empty_slots,
                filled_slots=filled_slots,
                user_message=user_message,
                conversation_turn=conversation_turn,
                same_topic_turns=same_topic_turns,
                all_stages_coverage=all_stages_coverage,
                detected_user_stage=detected_user_stage,
                user_profile_context=user_profile_context,
            )
            history_string = format_history_string(history_messages)
            full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
            response = await self.llm.ainvoke(full_prompt)
            return self._split_reply(response, limit=3)
        except Exception as e:
            # exc_info keeps the traceback, as the opening-message path does.
            logger.error("生成回应失败: %s", e, exc_info=True)
            return [f"抱歉,生成回应时出现错误: {str(e)}"]

    async def generate_opening_message(
        self,
        conversation_id: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
    ) -> List[str]:
        """Generate up to two opening segments for an empty conversation.

        Returns a fixed greeting when no LLM is configured or when generation
        raises. Nothing is persisted here.
        """
        if not self.llm:
            return ["你好呀~ 有空聊聊你的人生故事吗?你小时候是在哪儿长大的?"]
        try:
            empty_slots = memoir_state.empty_slots_for_current_stage()
            empty_slots_readable = [SLOT_NAME_MAP.get(s, s) for s in empty_slots]
            if not empty_slots_readable:
                # No empty slots for the stage: fall back to generic topics.
                empty_slots_readable = ["成长的地方", "难忘的事", "重要的人"]
            prompt = get_opening_prompt(
                current_stage=memoir_state.current_stage,
                empty_slots_readable=empty_slots_readable,
                user_profile_context=user_profile_context,
            )
            full_prompt = f"{prompt}\n\nAssistant:"
            response = await self.llm.ainvoke(full_prompt)
            return self._split_reply(response, limit=2)
        except Exception as e:
            logger.error("生成开场白失败: %s", e, exc_info=True)
            return ["你好呀~ 有空聊聊你的人生故事吗?你童年里印象最深的一件事是什么?"]

View File

@@ -0,0 +1,246 @@
"""
ChatOrchestratorAI 回复用户模块的编排层
负责路由Profile vs Interview、调用 Specialist Agent、统一 Redis 持久化与错误处理
"""
from datetime import datetime
from typing import TYPE_CHECKING, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.chat.helpers import save_message
from app.agents.chat.interview_agent import InterviewAgent
from app.agents.chat.profile_agent import ProfileAgent
from app.agents.state_schema import MemoirStateSchema
from app.core.logging import get_logger
from app.features.memoir.state_service import get_or_create_state
if TYPE_CHECKING:
from app.features.user.models import User
logger = get_logger(__name__)
class ChatOrchestrator:
    """Orchestration layer for AI replies.

    Routes a user message to ProfileAgent or InterviewAgent depending on
    whether the user still has missing profile fields, and is the single
    place where chat messages are written to Redis.
    """

    def __init__(self):
        self.profile_agent = ProfileAgent()
        self.interview_agent = InterviewAgent()

    async def process_user_message(
        self,
        conversation_id: str,
        user_message: str,
        user: Optional["User"],
        conversation,  # used to update conversation_stage
        is_from_voice: bool,
        voice_session_id: Optional[str],
        db: AsyncSession,
        apply_extracted_profile_fn,
        get_missing_profile_fields_fn,
        get_filled_profile_fields_fn,
        user_message_timestamp: Optional[datetime] = None,
    ) -> List[str]:
        """Handle one user message and return the AI reply segments.

        When the user has missing profile fields, the profile flow runs:
        extract fields from the message, apply them, then generate and store
        a follow-up. If that flow raises, the error is logged and processing
        falls through to the interview flow. Without a usable user id a fixed
        apology is returned and nothing is stored.
        """
        # --- profile collection mode ---
        if user:
            missing = get_missing_profile_fields_fn(user)
            if missing:
                try:
                    extracted = await self.profile_agent.extract_profile_from_message(
                        user_message, missing, conversation_id=conversation_id
                    )
                    if extracted:
                        await apply_extracted_profile_fn(user, extracted, db)
                    # Field lists are re-read after applying the extraction;
                    # generation plus Redis persistence is shared with the
                    # public follow-up method.
                    return await self.generate_profile_followup(
                        conversation_id=conversation_id,
                        user_message=user_message,
                        missing_fields=get_missing_profile_fields_fn(user),
                        filled_fields=get_filled_profile_fields_fn(user),
                        nickname=user.nickname or "",
                        is_from_voice=is_from_voice,
                        voice_session_id=voice_session_id,
                        user_message_timestamp=user_message_timestamp,
                    )
                except Exception as e:
                    logger.error("资料收集处理失败: %s", e, exc_info=True)

        # --- formal interview mode ---
        user_id = user.id if user else None
        if not user_id:
            return ["抱歉,无法识别用户。"]

        state = await get_or_create_state(user_id, db)
        if conversation and conversation.conversation_stage != state.current_stage:
            conversation.conversation_stage = state.current_stage
            await db.commit()

        from app.agents.prompts.profile_prompts import format_user_profile_context

        # ``user`` is known to be truthy here because ``user_id`` is.
        user_profile_context = format_user_profile_context(
            birth_year=user.birth_year,
            birth_place=user.birth_place,
            grew_up_place=user.grew_up_place,
            occupation=user.occupation,
        )
        return await self.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=state,
            user_profile_context=user_profile_context,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            user_message_timestamp=user_message_timestamp,
        )

    async def _save_messages(
        self,
        conversation_id: str,
        user_message: str,
        response_text: str,
        is_from_voice: bool = False,
        voice_session_id: Optional[str] = None,
        user_message_timestamp: Optional[datetime] = None,
    ) -> None:
        """Write the human message and then the AI message to Redis.

        The human message is stored as type "audio" when it came from voice,
        otherwise "text", together with the voice session id and timestamp.
        """
        human_msg_type = "audio" if is_from_voice else "text"
        await save_message(
            conversation_id,
            "human",
            user_message,
            message_type=human_msg_type,
            voice_session_id=voice_session_id,
            timestamp=user_message_timestamp,
        )
        await save_message(conversation_id, "ai", response_text)

    async def extract_profile_from_message(
        self,
        user_message: str,
        missing_fields: List[str],
        conversation_id: Optional[str] = None,
    ):
        """Delegate profile extraction to ProfileAgent and return its result."""
        return await self.profile_agent.extract_profile_from_message(
            user_message, missing_fields, conversation_id=conversation_id
        )

    async def generate_profile_followup(
        self,
        conversation_id: str,
        user_message: str,
        missing_fields: List[str],
        filled_fields: dict,
        nickname: str = "",
        is_from_voice: bool = False,
        voice_session_id: str | None = None,
        user_message_timestamp: datetime | None = None,
    ) -> List[str]:
        """Generate a profile follow-up via ProfileAgent and store both messages.

        The AI message is stored as the segments joined by a blank line.
        """
        responses = await self.profile_agent.generate_profile_followup(
            conversation_id=conversation_id,
            user_message=user_message,
            missing_fields=missing_fields,
            filled_fields=filled_fields,
            nickname=nickname,
        )
        await self._save_messages(
            conversation_id=conversation_id,
            user_message=user_message,
            response_text="\n\n".join(responses),
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            user_message_timestamp=user_message_timestamp,
        )
        return responses

    async def generate_profile_greeting(
        self,
        conversation_id: str,
        missing_fields: List[str],
        nickname: str = "",
    ) -> List[str]:
        """Generate a profile-collection greeting and store it as an AI message."""
        responses = await self.profile_agent.generate_profile_greeting(
            conversation_id=conversation_id,
            missing_fields=missing_fields,
            nickname=nickname,
        )
        await save_message(conversation_id, "ai", "\n\n".join(responses))
        return responses

    async def generate_response_with_state(
        self,
        conversation_id: str,
        user_message: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
        is_from_voice: bool = False,
        voice_session_id: str | None = None,
        user_message_timestamp: datetime | None = None,
    ) -> List[str]:
        """Generate an interview reply via InterviewAgent and store both messages.

        The AI message is stored as the segments joined by a blank line.
        """
        responses = await self.interview_agent.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=memoir_state,
            user_profile_context=user_profile_context,
        )
        await self._save_messages(
            conversation_id=conversation_id,
            user_message=user_message,
            response_text="\n\n".join(responses),
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            user_message_timestamp=user_message_timestamp,
        )
        return responses

    def detect_user_stage(self, user_message: str) -> str:
        """Return InterviewAgent's keyword-based stage key for the message."""
        return self.interview_agent._detect_user_stage(user_message)

    async def generate_opening_message(
        self,
        conversation_id: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
    ) -> List[str]:
        """Generate an interview opening and store it as an AI message."""
        responses = await self.interview_agent.generate_opening_message(
            conversation_id=conversation_id,
            memoir_state=memoir_state,
            user_profile_context=user_profile_context,
        )
        await save_message(conversation_id, "ai", "\n\n".join(responses))
        return responses

View File

@@ -0,0 +1,132 @@
"""
ProfileAgent用户资料收集 Specialist
负责提取资料、资料追问、资料收集开场白,不负责 Redis 持久化(由 Orchestrator 统一处理)
"""
import json
from typing import Any, Dict, List, Optional
from langchain_core.messages import AIMessage, HumanMessage
from app.core.dependencies import get_llm_provider
from app.core.logging import get_logger
from app.agents.chat.helpers import format_history_string, get_history_messages
from app.agents.prompts.profile_prompts import (
get_profile_extraction_prompt,
get_profile_followup_prompt,
get_profile_greeting_prompt,
)
logger = get_logger(__name__)
def _get_langchain_llm():
    """Return the provider's ``langchain_llm`` attribute, or ``None``.

    ``None`` is returned both when the provider has no such attribute and
    when obtaining the provider raises.
    """
    try:
        llm = getattr(get_llm_provider(), "langchain_llm", None)
    except Exception:
        llm = None
    return llm
class ProfileAgent:
    """Specialist agent for collecting basic user profile data.

    It extracts profile fields from messages and generates follow-ups and
    greetings. It never writes to Redis (the orchestrator persists messages).
    """

    # Free-text profile fields copied from the LLM reply when truthy.
    _TEXT_FIELDS = ("birth_place", "grew_up_place", "occupation")

    def __init__(self):
        self.llm = _get_langchain_llm()

    @staticmethod
    def _split_reply(response: Any, limit: int) -> List[str]:
        """Split an LLM reply on ``[SPLIT]`` into at most ``limit`` non-empty segments.

        The reply text is ``response.content`` when that attribute exists,
        otherwise ``str(response)``. If no non-empty segment remains, the raw
        text is returned as the only element.
        """
        text = response.content if hasattr(response, "content") else str(response)
        segments = [part.strip() for part in text.split("[SPLIT]") if part.strip()]
        return segments[:limit] if segments else [text]

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Return ``text`` stripped, without a surrounding Markdown code fence.

        Handles both a bare fence and one with a language tag (for example
        ``json``) on the opening line, so a fenced JSON reply can be parsed.
        """
        stripped = text.strip()
        if len(stripped) >= 6 and stripped.startswith("```") and stripped.endswith("```"):
            inner = stripped[3:-3]
            first_line, newline, rest = inner.partition("\n")
            # Drop the opening-fence line unless the payload starts on it.
            if newline and not first_line.strip().startswith(("{", "[")):
                inner = rest
            return inner.strip()
        return stripped

    @staticmethod
    def _parse_birth_year(raw: Any) -> Optional[int]:
        """Normalise a birth year to an int in 1900..2100, or ``None``.

        Accepts ints and ASCII digit strings. Values below 100 are treated as
        two-digit years: 50-99 map to 19xx and 0-49 to 20xx.
        """
        if isinstance(raw, bool):
            return None
        if isinstance(raw, str):
            digits = raw.strip()
            # isdigit() alone also accepts characters such as superscripts,
            # which int() rejects.
            if not (digits.isascii() and digits.isdigit()):
                return None
            year = int(digits)
        elif isinstance(raw, int):
            year = raw
        else:
            return None
        if 0 <= year < 100:
            year = 1900 + year if year >= 50 else 2000 + year
        return year if 1900 <= year <= 2100 else None

    async def extract_profile_from_message(
        self,
        user_message: str,
        missing_fields: List[str],
        conversation_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Extract profile fields from a user message via the LLM.

        Returns a dict containing any of ``birth_year``, ``birth_place``,
        ``grew_up_place`` and ``occupation`` that the reply supplied. Returns
        an empty dict when no LLM is configured, ``missing_fields`` is empty,
        or prompting/parsing fails. Nothing is persisted.
        """
        if not self.llm or not missing_fields:
            return {}

        recent_dialogue = ""
        if conversation_id:
            history_messages = await get_history_messages(conversation_id)
            parts = []
            # Only the last four messages are given to the LLM as context.
            for msg in history_messages[-4:]:
                if isinstance(msg, HumanMessage):
                    parts.append(f"用户: {msg.content}")
                elif isinstance(msg, AIMessage):
                    parts.append(f"助手: {msg.content}")
            recent_dialogue = "\n".join(parts)

        try:
            prompt = get_profile_extraction_prompt(
                user_message, missing_fields, recent_dialogue=recent_dialogue or None
            )
            response = await self.llm.ainvoke(prompt)
            raw_text = response.content if hasattr(response, "content") else str(response)
            parsed = json.loads(self._strip_code_fence(raw_text))
            if not isinstance(parsed, dict):
                raise ValueError("profile extraction reply is not a JSON object")

            result: Dict[str, Any] = {}
            birth_year = self._parse_birth_year(parsed.get("birth_year"))
            if birth_year is not None:
                result["birth_year"] = birth_year
            for field in self._TEXT_FIELDS:
                value = parsed.get(field)
                if value:
                    result[field] = str(value)
            return result
        except Exception as e:
            logger.error("提取资料信息失败: %s", e)
            return {}

    async def generate_profile_followup(
        self,
        conversation_id: str,
        user_message: str,
        missing_fields: List[str],
        filled_fields: Dict[str, str],
        nickname: str = "",
    ) -> List[str]:
        """Generate up to three follow-up segments for profile collection.

        Returns a fixed message when no LLM is configured or when generation
        raises. Nothing is persisted here.
        """
        if not self.llm:
            return ["谢谢!还能告诉我更多吗?"]
        try:
            prompt = get_profile_followup_prompt(
                missing_fields, filled_fields, user_message, nickname
            )
            history_messages = await get_history_messages(conversation_id)
            history_string = format_history_string(history_messages)
            full_prompt = f"{prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
            response = await self.llm.ainvoke(full_prompt)
            return self._split_reply(response, limit=3)
        except Exception as e:
            logger.error("生成资料跟进回复失败: %s", e)
            return ["谢谢分享!能再告诉我一些吗?"]

    async def generate_profile_greeting(
        self,
        conversation_id: str,
        missing_fields: List[str],
        nickname: str = "",
    ) -> List[str]:
        """Generate up to two greeting segments that start profile collection.

        Returns a fixed message when no LLM is configured or when generation
        raises. Nothing is persisted here.
        """
        if not self.llm:
            return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"]
        try:
            prompt = get_profile_greeting_prompt(missing_fields, nickname)
            history_messages = await get_history_messages(conversation_id)
            history_string = format_history_string(history_messages)
            # Existing history, if any, is appended after the prompt.
            full_prompt = f"{prompt}\n\n{history_string}" if history_string else prompt
            response = await self.llm.ainvoke(full_prompt)
            return self._split_reply(response, limit=2)
        except Exception as e:
            logger.error("生成资料收集开场白失败: %s", e)
            return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"]

View File

@@ -1,355 +0,0 @@
"""
对话 Agent基于访谈问题清单动态选择问题实时生成回应
支持异步调用和 Redis 会话存储,支持用户基础资料收集和时代背景融入
"""
import json
from app.core.logging import get_logger
from datetime import datetime
from typing import Any, Dict, List, Optional
from langchain_core.messages import AIMessage, HumanMessage
from app.core.redis import redis_service
from app.core.dependencies import get_llm_provider
from app.agents.prompts import (
ConversationStage,
get_conversation_prompt,
get_guided_conversation_prompt,
get_opening_prompt,
)
from app.agents.prompts.profile_prompts import (
get_profile_greeting_prompt,
get_profile_extraction_prompt,
get_profile_followup_prompt,
format_user_profile_context,
get_missing_profile_fields,
)
from app.agents.prompts.conversation_prompts import SLOT_NAME_MAP
from app.agents.state_schema import MemoirStateSchema
logger = get_logger(__name__)
def _get_langchain_llm():
    """Return the provider's ``langchain_llm`` attribute for agent use, or ``None``.

    ``None`` is returned both when the provider has no such attribute and
    when obtaining the provider raises.
    """
    try:
        llm = getattr(get_llm_provider(), "langchain_llm", None)
    except Exception:
        llm = None
    return llm
class ConversationAgent:
"""对话 Agent支持异步和 Redis 存储)"""
def __init__(self):
    """Resolve the LangChain LLM once; ``self.llm`` is ``None`` when unavailable."""
    llm = _get_langchain_llm()
    self.llm = llm
async def _get_history_messages(self, conversation_id: str) -> List[Any]:
    """Load the stored conversation from Redis as LangChain message objects.

    Entries with role "human" become ``HumanMessage`` and entries with role
    "ai" become ``AIMessage``; entries with any other role are skipped.
    """
    stored = await redis_service.get_conversation_history(conversation_id)
    converted = []
    for entry in stored:
        role = entry["role"]
        if role == "human":
            message_cls = HumanMessage
        elif role == "ai":
            message_cls = AIMessage
        else:
            continue
        converted.append(message_cls(content=entry["content"]))
    return converted
async def _save_message(
    self,
    conversation_id: str,
    role: str,
    content: str,
    message_type: str = "text",
    voice_session_id: str | None = None,
    timestamp: datetime | str | int | None = None,
):
    """Append one message to the conversation history kept in Redis.

    A ``datetime`` timestamp is converted with ``isoformat()``; any other
    timestamp value (including ``None``) is forwarded unchanged.
    """
    stamp = timestamp
    if isinstance(stamp, datetime):
        stamp = stamp.isoformat()
    await redis_service.add_message(
        conversation_id,
        role,
        content,
        message_type=message_type,
        voice_session_id=voice_session_id,
        timestamp=stamp,
    )
def _format_history_string(self, messages: List[Any]) -> str:
history_parts = []
for msg in messages:
if isinstance(msg, HumanMessage):
history_parts.append(f"Human: {msg.content}")
elif isinstance(msg, AIMessage):
history_parts.append(f"Assistant: {msg.content}")
return "\n\n".join(history_parts)
async def generate_response(
    self,
    conversation_id: str,
    user_message: str,
    current_stage: Optional[ConversationStage] = None,
    covered_topics: Optional[List[str]] = None,
) -> str:
    """Generate one reply for the given stage and store both messages in Redis.

    ``current_stage`` defaults to ``CHILDHOOD`` and ``covered_topics`` to an
    empty list. Returns a fixed message when no LLM is configured and an
    error message when generation or storage raises.
    """
    stage = ConversationStage.CHILDHOOD if current_stage is None else current_stage
    topics = [] if covered_topics is None else covered_topics
    if not self.llm:
        return "抱歉LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"
    try:
        system_prompt = get_conversation_prompt(stage, topics, user_message)
        history = await self._get_history_messages(conversation_id)
        transcript = self._format_history_string(history)
        reply = await self.llm.ainvoke(
            f"{system_prompt}\n\n{transcript}\n\nHuman: {user_message}\n\nAssistant:"
        )
        if hasattr(reply, "content"):
            reply_text = reply.content
        else:
            reply_text = str(reply)
        # Human message first, then the AI reply.
        for role, content in (("human", user_message), ("ai", reply_text)):
            await self._save_message(conversation_id, role, content)
        return reply_text
    except Exception as e:
        logger.error("生成回应失败: %s", e)
        return f"抱歉,生成回应时出现错误: {str(e)}"
async def generate_profile_greeting(
    self,
    conversation_id: str,
    missing_fields: List[str],
    nickname: str = "",
) -> List[str]:
    """Generate a profile-collection greeting, store it, and return up to two segments.

    The raw reply text is stored as an AI message before it is split on
    ``[SPLIT]``. Returns a fixed message when no LLM is configured or when
    anything raises.
    """
    if not self.llm:
        return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"]
    try:
        base_prompt = get_profile_greeting_prompt(missing_fields, nickname)
        history = await self._get_history_messages(conversation_id)
        transcript = self._format_history_string(history)
        full_prompt = base_prompt
        if transcript:
            full_prompt = f"{base_prompt}\n\n{transcript}"
        reply = await self.llm.ainvoke(full_prompt)
        if hasattr(reply, "content"):
            reply_text = reply.content
        else:
            reply_text = str(reply)
        await self._save_message(conversation_id, "ai", reply_text)
        segments = []
        for piece in reply_text.split("[SPLIT]"):
            cleaned = piece.strip()
            if cleaned:
                segments.append(cleaned)
        if not segments:
            return [reply_text]
        return segments[:2]
    except Exception as e:
        logger.error("生成资料收集开场白失败: %s", e)
        return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"]
async def generate_opening_message(
    self,
    conversation_id: str,
    memoir_state: MemoirStateSchema,
    user_profile_context: str = "",
) -> List[str]:
    """Generate an interview opening, store it, and return up to two segments.

    The raw reply text is stored as an AI message before it is split on
    ``[SPLIT]``. Returns a fixed greeting when no LLM is configured or when
    anything raises.
    """
    if not self.llm:
        return ["你好呀~ 有空聊聊你的人生故事吗?你小时候是在哪儿长大的?"]
    try:
        readable_slots = [
            SLOT_NAME_MAP.get(slot, slot)
            for slot in memoir_state.empty_slots_for_current_stage()
        ]
        if not readable_slots:
            # No empty slots for the stage: fall back to generic topics.
            readable_slots = ["成长的地方", "难忘的事", "重要的人"]
        opening_prompt = get_opening_prompt(
            current_stage=memoir_state.current_stage,
            empty_slots_readable=readable_slots,
            user_profile_context=user_profile_context,
        )
        reply = await self.llm.ainvoke(f"{opening_prompt}\n\nAssistant:")
        if hasattr(reply, "content"):
            reply_text = reply.content
        else:
            reply_text = str(reply)
        await self._save_message(conversation_id, "ai", reply_text)
        segments = []
        for piece in reply_text.split("[SPLIT]"):
            cleaned = piece.strip()
            if cleaned:
                segments.append(cleaned)
        if not segments:
            return [reply_text]
        return segments[:2]
    except Exception as e:
        logger.error("生成开场白失败: %s", e, exc_info=True)
        return ["你好呀~ 有空聊聊你的人生故事吗?你童年里印象最深的一件事是什么?"]
async def extract_profile_from_message(
    self,
    user_message: str,
    missing_fields: List[str],
    conversation_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Extract profile fields from a user message via the LLM.

    Returns a dict containing any of ``birth_year``, ``birth_place``,
    ``grew_up_place`` and ``occupation`` that the reply supplied. Returns an
    empty dict when no LLM is configured, ``missing_fields`` is empty, or
    prompting/parsing fails.
    """
    if not self.llm or not missing_fields:
        return {}

    recent_dialogue = ""
    if conversation_id:
        history_messages = await self._get_history_messages(conversation_id)
        parts = []
        # Only the last four messages are given to the LLM as context.
        for msg in history_messages[-4:]:
            if isinstance(msg, HumanMessage):
                parts.append(f"用户: {msg.content}")
            elif isinstance(msg, AIMessage):
                parts.append(f"助手: {msg.content}")
        recent_dialogue = "\n".join(parts)

    try:
        prompt = get_profile_extraction_prompt(
            user_message, missing_fields, recent_dialogue=recent_dialogue or None
        )
        response = await self.llm.ainvoke(prompt)
        parsed = json.loads(response.content.strip())
        if not isinstance(parsed, dict):
            raise ValueError("profile extraction reply is not a JSON object")

        result = {}
        raw = parsed.get("birth_year")
        year = None
        if isinstance(raw, int) and not isinstance(raw, bool):
            year = raw
        elif isinstance(raw, str) and raw.isascii() and raw.isdigit():
            # The ASCII check rejects digit-like characters that int() cannot
            # parse, so one bad value no longer discards the other fields.
            year = int(raw)
            if year < 100:
                # Two-digit years: 50-99 map to 19xx, 0-49 to 20xx.
                year = 1900 + year if year >= 50 else 2000 + year
        if year is not None and 1900 <= year <= 2100:
            result["birth_year"] = year

        for field in ("birth_place", "grew_up_place", "occupation"):
            value = parsed.get(field)
            if value:
                result[field] = str(value)
        return result
    except Exception as e:
        logger.error("提取资料信息失败: %s", e)
        return {}
async def generate_profile_followup(
    self,
    conversation_id: str,
    user_message: str,
    missing_fields: List[str],
    filled_fields: Dict[str, str],
    nickname: str = "",
    is_from_voice: bool = False,
    voice_session_id: str | None = None,
    user_message_timestamp: datetime | None = None,
) -> List[str]:
    """Generate a profile follow-up, store both messages, and return up to three segments.

    The human message is stored as type "audio" when it came from voice,
    otherwise "text"; the raw reply text is stored as the AI message. Returns
    a fixed message when no LLM is configured or when anything raises.
    """
    if not self.llm:
        return ["谢谢!还能告诉我更多吗?"]
    try:
        followup_prompt = get_profile_followup_prompt(
            missing_fields, filled_fields, user_message, nickname
        )
        history = await self._get_history_messages(conversation_id)
        transcript = self._format_history_string(history)
        reply = await self.llm.ainvoke(
            f"{followup_prompt}\n\n{transcript}\n\nHuman: {user_message}\n\nAssistant:"
        )
        if hasattr(reply, "content"):
            reply_text = reply.content
        else:
            reply_text = str(reply)
        await self._save_message(
            conversation_id,
            "human",
            user_message,
            message_type="audio" if is_from_voice else "text",
            voice_session_id=voice_session_id,
            timestamp=user_message_timestamp,
        )
        await self._save_message(conversation_id, "ai", reply_text)
        segments = []
        for piece in reply_text.split("[SPLIT]"):
            cleaned = piece.strip()
            if cleaned:
                segments.append(cleaned)
        if not segments:
            return [reply_text]
        return segments[:3]
    except Exception as e:
        logger.error("生成资料跟进回复失败: %s", e)
        return ["谢谢分享!能再告诉我一些吗?"]
def _detect_user_stage(self, user_message: str) -> str:
message = user_message.lower()
stage_keywords = {
"childhood": ["童年", "小时候", "出生", "家乡", "小镇", "爸妈", "父亲", "母亲", "爷爷", "奶奶", "外公", "外婆", "幼儿园"],
"education": ["上学", "学校", "老师", "同学", "教育", "大学", "高中", "初中", "小学", "考试", "毕业", "读书", "高考", "课堂"],
"career": ["工作", "职业", "事业", "公司", "同事", "创业", "升职", "跳槽", "老板", "行业", "项目", "加班", "薪水", "面试"],
"family": ["伴侣", "孩子", "家庭", "家人", "结婚", "爱人", "老婆", "老公", "丈夫", "妻子", "儿子", "女儿", "婚礼", "恋爱"],
"belief": ["信念", "价值观", "座右铭", "坚持", "原则", "信仰", "意义", "感悟", "遗憾", "骄傲"],
}
for stage, keywords in stage_keywords.items():
if any(word in message for word in keywords):
return stage
return ""
async def generate_response_with_state(
    self,
    conversation_id: str,
    user_message: str,
    memoir_state: MemoirStateSchema,
    user_profile_context: str = "",
    is_from_voice: bool = False,
    voice_session_id: str | None = None,
    user_message_timestamp: datetime | None = None,
) -> List[str]:
    """Generate the guided interview reply for the current memoir state.

    Collects the empty and filled slots of the current stage, the stage hinted
    at by the user message, turn counters derived from the stored history and
    the coverage of all stages, feeds them to the guided-conversation prompt,
    and asks the LLM for a reply. After the LLM call succeeds, the user message
    and the raw reply are saved through ``self._save_message``.

    Args:
        conversation_id: Conversation whose history is read and appended to.
        user_message: Latest message from the user.
        memoir_state: Current memoir state (stage, slots, coverage).
        user_profile_context: Passed through to the prompt builder.
        is_from_voice: When true, the user message is stored with type
            ``"audio"`` instead of ``"text"``.
        voice_session_id: Stored alongside the user message.
        user_message_timestamp: Stored as the user message timestamp.

    Returns:
        At most three non-empty reply parts obtained by splitting the reply on
        ``[SPLIT]``; the unsplit reply if no part survives; a single apology
        message when no LLM is configured or any step raises.
    """
    if not self.llm:
        return ["抱歉LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"]

    try:
        stage = memoir_state.current_stage
        open_slots = memoir_state.empty_slots_for_current_stage()
        # Only slots that already carry a snippet count as filled.
        known_slots = {}
        for slot_name, slot in memoir_state.slots.get(stage, {}).items():
            if slot.snippet:
                known_slots[slot_name] = slot.snippet
        hinted_stage = self._detect_user_stage(user_message)

        history = await self._get_history_messages(conversation_id)
        turn_count = len(history) // 2
        topic_turns = self._estimate_same_topic_turns(history, known_slots)
        coverage = memoir_state.all_stages_coverage()

        system_prompt = get_guided_conversation_prompt(
            current_stage=memoir_state.current_stage,
            empty_slots=open_slots,
            filled_slots=known_slots,
            user_message=user_message,
            conversation_turn=turn_count,
            same_topic_turns=topic_turns,
            all_stages_coverage=coverage,
            detected_user_stage=hinted_stage,
            user_profile_context=user_profile_context,
        )
        history_text = self._format_history_string(history)
        llm_input = f"{system_prompt}\n\n{history_text}\n\nHuman: {user_message}\n\nAssistant:"

        llm_output = await self.llm.ainvoke(llm_input)
        if hasattr(llm_output, "content"):
            reply = llm_output.content
        else:
            reply = str(llm_output)

        # Persist the exchange only once the LLM has answered.
        await self._save_message(
            conversation_id,
            "human",
            user_message,
            message_type="audio" if is_from_voice else "text",
            voice_session_id=voice_session_id,
            timestamp=user_message_timestamp,
        )
        await self._save_message(conversation_id, "ai", reply)

        parts = []
        for piece in reply.split("[SPLIT]"):
            piece = piece.strip()
            if piece:
                parts.append(piece)
        if not parts:
            return [reply]
        return parts[:3]
    except Exception as e:
        logger.error("生成回应失败: %s", e)
        # NOTE(review): the raw exception text is surfaced to the end user here.
        return [f"抱歉,生成回应时出现错误: {str(e)}"]
def _estimate_same_topic_turns(self, history_messages: List[Any], current_filled_slots: dict) -> int:
if len(history_messages) < 4:
return len(history_messages) // 2
recent_messages = history_messages[-6:]
keywords_per_turn = []
for i in range(0, len(recent_messages), 2):
if i + 1 < len(recent_messages):
human_msg = (
recent_messages[i].content
if hasattr(recent_messages[i], "content")
else str(recent_messages[i])
)
ai_msg = (
recent_messages[i + 1].content
if hasattr(recent_messages[i + 1], "content")
else str(recent_messages[i + 1])
)
keywords_per_turn.append((human_msg + ai_msg)[:100])
if len(keywords_per_turn) >= 3:
return 3
return len(keywords_per_turn)
def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage:
    """Map a user message to a ``ConversationStage`` by keyword matching.

    Rules are evaluated top to bottom and the first rule with a keyword
    contained in the lower-cased message wins. ``conversation_id`` is accepted
    but not used.

    Returns:
        The matched stage, or ``ConversationStage.CHILDHOOD`` when no keyword
        matches.
    """
    ordered_rules = (
        (("童年", "小时候", "出生", "家庭背景"), ConversationStage.CHILDHOOD),
        (("上学", "学校", "老师", "同学", "教育"), ConversationStage.EDUCATION),
        (("工作", "职业", "事业", "公司", "同事"), ConversationStage.CAREER),
        (("伴侣", "孩子", "家庭", "家人", "结婚"), ConversationStage.FAMILY),
        (("信念", "价值观", "座右铭", "坚持", "原则"), ConversationStage.BELIEFS),
        (("总结", "回顾", "感激", "希望", "未来"), ConversationStage.SUMMARY),
    )
    text = user_message.lower()
    for markers, stage in ordered_rules:
        for marker in markers:
            if marker in text:
                return stage
    return ConversationStage.CHILDHOOD
async def clear_memory(self, conversation_id: str) -> None:
    """Clear the stored history of ``conversation_id`` via ``redis_service``."""
    await redis_service.clear_conversation_history(conversation_id)

View File

@@ -1,211 +0,0 @@
"""
回忆录后台处理器:分析对话、更新状态、生成章节、创意标题
使用 Celery 进行后台任务处理
"""
from __future__ import annotations
import json
from app.core.logging import get_logger
from dataclasses import dataclass
from typing import Dict, List
from app.core.dependencies import get_llm_provider
from app.core.task_tracker import task_tracker
from app.agents.state_schema import MemoirStateSchema
from app.agents.prompts.memory_prompts import (
get_creative_title_prompt,
get_narrative_prompt,
get_state_extraction_prompt,
)
logger = get_logger(__name__)
# Keyword lists per memoir stage. ContentAnalyzer._detect_stage walks this
# mapping in insertion order and returns the first stage that has a keyword
# contained in the user message, so the order of entries is significant.
STAGE_KEYWORDS = {
    "childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
    "education": ["上学", "学校", "老师", "同学", "教育", "大学"],
    "career": ["工作", "职业", "事业", "公司", "同事", "创业"],
    "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
    "belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
}
def _get_langchain_llm():
    """Return the LLM provider's ``langchain_llm`` attribute, or ``None``.

    ``None`` is returned when the provider has no ``langchain_llm`` attribute
    or when obtaining the provider raises. The failure is logged so that a
    missing LLM (and the resulting fallback behaviour of the callers) can be
    diagnosed instead of happening silently.
    """
    try:
        provider = get_llm_provider()
    except Exception as e:
        # Best effort: callers treat None as "no LLM available".
        logger.warning("获取 LLM provider 失败: %s", e)
        return None
    return getattr(provider, "langchain_llm", None)
@dataclass
class AnalysisResult:
    """Result of analysing one user message, as built by ContentAnalyzer.analyze_message."""

    # Stage key the message was attributed to (keyword match or LLM output).
    detected_stage: str
    # Slot name -> extracted text for the detected stage; may be empty.
    extracted_slots: Dict[str, str]
    # Emotion label; analyze_message starts from "neutral".
    emotion: str
    # Whether the message is considered the start of a new chapter.
    is_new_chapter: bool
class ContentAnalyzer:
    """Extract memoir state (stage, slots, emotion) from a single user message.

    Uses the LLM returned by ``_get_langchain_llm`` when one is available and
    falls back to keyword matching plus "fill the first empty slot" otherwise.
    """

    def __init__(self) -> None:
        # May be None; analyze_message then relies on the heuristics only.
        self.llm = _get_langchain_llm()

    def _detect_stage(self, user_message: str, fallback_stage: str) -> str:
        """Return the first stage in STAGE_KEYWORDS whose keyword occurs in the message.

        ``fallback_stage`` is returned when no keyword matches.
        """
        message = user_message.lower()
        for stage, keywords in STAGE_KEYWORDS.items():
            if any(word in message for word in keywords):
                return stage
        return fallback_stage

    def _fallback_slots(
        self, state: MemoirStateSchema, stage: str, user_message: str
    ) -> Dict[str, str]:
        """Put the message (stripped, max 200 chars) into the first slot of ``stage`` without a snippet.

        Returns an empty dict when the stage has no slots or all are filled.
        """
        for key, value in state.slots.get(stage, {}).items():
            if not value.snippet:
                return {key: user_message.strip()[:200]}
        return {}

    @staticmethod
    def _parse_llm_json(raw: str) -> dict:
        """Parse the LLM reply as a JSON object, tolerating a Markdown code fence.

        Raises:
            json.JSONDecodeError: The text is not valid JSON.
            ValueError: The JSON value is not an object.
        """
        text = raw.strip()
        if text.startswith("```"):
            # Models frequently wrap JSON in ``` or ```json fences.
            text = text[3:]
            if text[:4].lower() == "json":
                text = text[4:]
            if text.endswith("```"):
                text = text[:-3]
            text = text.strip()
        parsed = json.loads(text)
        if not isinstance(parsed, dict):
            raise ValueError("LLM response is not a JSON object")
        return parsed

    async def analyze_message(
        self, user_message: str, current_state: MemoirStateSchema
    ) -> AnalysisResult:
        """Analyse ``user_message`` against ``current_state``.

        The stage is first guessed by keywords. When an LLM is configured it is
        asked for a JSON object with ``detected_stage``, ``slots``, ``emotion``
        and ``is_new_chapter``; missing, ``null`` or wrongly typed values keep
        the defaults. If no LLM is configured, or the call or parsing fails,
        the slots come from ``_fallback_slots`` for the keyword-detected stage.

        Returns:
            The populated ``AnalysisResult``.
        """
        detected_stage = self._detect_stage(
            user_message, current_state.current_stage
        )
        emotion = "neutral"
        is_new_chapter = False

        parsed = None
        if self.llm:
            try:
                prompt = get_state_extraction_prompt(
                    user_message=user_message,
                    current_stage=current_state.current_stage,
                    stage_slots=current_state.slots.get(detected_stage, {}),
                )
                response = await self.llm.ainvoke(prompt)
                parsed = self._parse_llm_json(response.content)
            except json.JSONDecodeError:
                # Malformed JSON from the model is handled by the fallback below.
                parsed = None
            except Exception as e:
                logger.error("分析消息失败: %s", e)
                parsed = None

        if parsed is None:
            extracted_slots = self._fallback_slots(
                current_state, detected_stage, user_message
            )
        else:
            # Guard every field: JSON null or a wrong type must not replace the defaults.
            stage = parsed.get("detected_stage")
            if isinstance(stage, str) and stage:
                detected_stage = stage
            slots = parsed.get("slots")
            extracted_slots = slots if isinstance(slots, dict) else {}
            mood = parsed.get("emotion")
            if isinstance(mood, str) and mood:
                emotion = mood
            is_new_chapter = bool(parsed.get("is_new_chapter", False))

        return AnalysisResult(
            detected_stage=detected_stage,
            extracted_slots=extracted_slots,
            emotion=emotion,
            is_new_chapter=is_new_chapter,
        )
class MemoirGenerator:
    """Generate chapter titles and narrative text, with non-LLM fallbacks."""

    def __init__(self) -> None:
        # May be None; both generators then return their fallback value.
        self.llm = _get_langchain_llm()

    @staticmethod
    def _append_content(existing_content: str, new_content: str) -> str:
        """Return ``new_content`` appended to ``existing_content`` (blank line between), or alone."""
        if existing_content:
            return f"{existing_content}\n\n{new_content}"
        return new_content

    async def generate_chapter_title(
        self, stage: str, slots: Dict[str, str], emotion: str
    ) -> str:
        """Return a creative chapter title for ``stage``.

        Falls back to ``"<stage> 回忆"`` when no LLM is configured, the call
        fails, or the model returns an empty title.
        """
        default_title = f"{stage} 回忆"
        if not self.llm:
            return default_title
        try:
            prompt = get_creative_title_prompt(
                stage=stage, emotion=emotion, slots=slots
            )
            response = await self.llm.ainvoke(prompt)
            title = response.content.strip().strip('"').strip()
        except Exception as e:
            logger.error("生成标题失败: %s", e)
            return default_title
        # An empty reply must not become an empty chapter title.
        return title or default_title

    async def generate_narrative(
        self,
        stage: str,
        slots: Dict[str, str],
        new_content: str,
        existing_content: str,
    ) -> str:
        """Return the chapter narrative that merges ``new_content`` into ``existing_content``.

        Falls back to plain concatenation when no LLM is configured, the call
        fails, or the model returns an empty text (which would otherwise
        discard the existing narrative).
        """
        if not self.llm:
            return self._append_content(existing_content, new_content)
        try:
            prompt = get_narrative_prompt(
                stage=stage,
                slots=slots,
                new_content=new_content,
                existing_content=existing_content,
            )
            response = await self.llm.ainvoke(prompt)
            narrative = response.content.strip()
        except Exception as e:
            logger.error("生成叙事失败: %s", e)
            return self._append_content(existing_content, new_content)
        return narrative or self._append_content(existing_content, new_content)
class BackgroundTaskRunner:
    """Debounce per-user memoir segments and submit them as one Celery task.

    ``queue_message`` collects segment ids per user and (re)starts a timer;
    when the timer elapses without further messages the collected ids are
    submitted together. ``flush_pending`` submits immediately.
    """

    def __init__(self, debounce_seconds: int = 5) -> None:
        self.debounce_seconds = debounce_seconds
        # user_id -> segment ids waiting for submission
        self._pending: Dict[str, List[str]] = {}
        # user_id -> asyncio task of the debounce timer that is still waiting
        self._timers: Dict[str, object] = {}
        self.analyzer = ContentAnalyzer()
        self.generator = MemoirGenerator()

    async def _submit_task(self, user_id: str, segment_ids: List[str]) -> str | None:
        """Submit the segments to Celery and register the task with the tracker.

        Returns:
            The Celery task id, or ``None`` when importing, submitting or
            tracking raised (the error is logged).
        """
        try:
            # Imported lazily, as in the original code.
            from app.tasks.memoir_tasks import process_memoir_segments

            result = process_memoir_segments.delay(user_id, segment_ids)
            task_id = result.id
            await task_tracker.add_task(user_id, task_id, "memoir")
            logger.info(
                "已提交 Celery 任务: user_id=%s, task_id=%s, segments=%s",
                user_id,
                task_id,
                len(segment_ids),
            )
            return task_id
        except Exception as e:
            logger.error("提交 Celery 任务失败: %s", e)
            return None

    async def queue_message(self, user_id: str, segment_id: str) -> None:
        """Queue ``segment_id`` for ``user_id`` and restart the debounce timer."""
        import asyncio

        self._pending.setdefault(user_id, []).append(segment_id)
        previous_timer = self._timers.pop(user_id, None)
        if previous_timer is not None:
            previous_timer.cancel()

        async def delayed_submit() -> None:
            try:
                await asyncio.sleep(self.debounce_seconds)
                # The debounce window has elapsed: unregister this timer before
                # submitting, so that finished tasks are not retained and a
                # later queue_message/flush_pending cannot cancel a submission
                # that is already in flight.
                if self._timers.get(user_id) is asyncio.current_task():
                    del self._timers[user_id]
                segment_ids = self._pending.pop(user_id, [])
                if segment_ids:
                    await self._submit_task(user_id, segment_ids)
            except asyncio.CancelledError:
                # Cancelled by a newer message or by flush_pending; the pending
                # ids stay queued for whoever cancelled us.
                pass
            except Exception as e:
                logger.error("延迟提交任务失败: %s", e)

        self._timers[user_id] = asyncio.create_task(delayed_submit())

    async def flush_pending(self, user_id: str) -> str | None:
        """Cancel the debounce timer and submit the user's pending segments now.

        Returns:
            The task id from ``_submit_task``, or ``None`` when nothing was
            pending or submission failed.
        """
        timer = self._timers.pop(user_id, None)
        if timer is not None:
            timer.cancel()
        segment_ids = self._pending.pop(user_id, [])
        if segment_ids:
            return await self._submit_task(user_id, segment_ids)
        return None

View File

@@ -1,130 +0,0 @@
"""
回忆录整理 Agent:基于传记结构,将口语改写为书面语并归类到章节
支持异步调用
"""
import json
from app.core.logging import get_logger
from typing import Dict, List, Optional
from app.core.dependencies import get_llm_provider
from app.agents.prompts import (
get_chapter_classification_prompt,
get_text_rewrite_prompt,
inject_image_placeholder_template,
CHAPTER_CATEGORIES,
STAGE_TO_ORDER,
)
logger = get_logger(__name__)
def _get_langchain_llm():
    """Return the LLM provider's ``langchain_llm`` attribute, or ``None``.

    ``None`` is returned when the provider has no ``langchain_llm`` attribute
    or when obtaining the provider raises. The failure is logged so that a
    missing LLM (and the resulting fallback behaviour of the callers) can be
    diagnosed instead of happening silently.
    """
    try:
        provider = get_llm_provider()
    except Exception as e:
        # Best effort: callers treat None as "no LLM available".
        logger.warning("获取 LLM provider 失败: %s", e)
        return None
    return getattr(provider, "langchain_llm", None)
class MemoryAgent:
    """Memoir organising agent: classifies transcript segments into chapters
    and rewrites them into written prose (async)."""

    def __init__(self):
        # May be None; every method then returns its non-LLM fallback.
        self.llm = _get_langchain_llm()

    @staticmethod
    def _response_text(response) -> str:
        """Return ``response.content`` when present, else ``str(response)``."""
        return response.content if hasattr(response, "content") else str(response)

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding ``` / ```json Markdown fence and outer whitespace."""
        text = text.strip()
        if text.startswith("```json"):
            text = text[7:]
        if text.startswith("```"):
            text = text[3:]
        if text.endswith("```"):
            text = text[:-3]
        return text.strip()

    @staticmethod
    def _plain_chapter(chapter_category: str, content: str) -> Dict:
        """Build the fallback chapter dict: default title, given content, no summary or images."""
        return {
            "title": CHAPTER_CATEGORIES.get(chapter_category, "章节"),
            "content": content,
            "summary": "",
            "image_suggestions": [],
        }

    async def classify_chapter(self, segments_text: str) -> str:
        """Return the chapter category for ``segments_text``.

        The LLM answer is stripped and lower-cased and accepted only if it is
        a key of ``CHAPTER_CATEGORIES``. ``"childhood"`` is returned when no
        LLM is configured, the answer is not a known category, or the call
        fails.
        """
        if not self.llm:
            return "childhood"
        try:
            prompt = get_chapter_classification_prompt(segments_text)
            response = await self.llm.ainvoke(prompt)
            category = self._response_text(response).strip().lower()
            if category in CHAPTER_CATEGORIES:
                return category
        except Exception as e:
            logger.error("分类章节失败: %s", e)
        return "childhood"

    async def rewrite_to_literary(
        self,
        segments_text: str,
        chapter_category: str,
        existing_content: Optional[str] = None,
    ) -> Dict:
        """Rewrite spoken-style segments into chapter prose.

        Returns:
            The JSON object produced by the LLM, with its ``content`` passed
            through ``inject_image_placeholder_template``. If the reply is not
            valid JSON, a fallback chapter whose content is the raw reply
            (also passed through the placeholder template). If no LLM is
            configured or anything else fails, a fallback chapter whose
            content is ``segments_text`` unchanged.
        """
        if not self.llm:
            return self._plain_chapter(chapter_category, segments_text)
        try:
            prompt = get_text_rewrite_prompt(
                segments_text, chapter_category, existing_content or ""
            )
            response = await self.llm.ainvoke(prompt)
            raw = self._response_text(response)
            # Only JSON parsing is guarded by the JSONDecodeError handler, so a
            # decode error raised by the LLM call itself can no longer reach a
            # handler that reads the (then unbound) response.
            try:
                result = json.loads(self._strip_code_fence(raw))
            except json.JSONDecodeError:
                return self._plain_chapter(
                    chapter_category, inject_image_placeholder_template(raw)
                )
            if not isinstance(result, dict):
                raise ValueError("LLM response is not a JSON object")
            result["content"] = inject_image_placeholder_template(
                result.get("content") or ""
            )
            return result
        except Exception as e:
            logger.error("改写文本失败: %s", e)
            return self._plain_chapter(chapter_category, segments_text)

    async def process_segments(
        self,
        segments: List[Dict],
        existing_chapters: Optional[Dict[str, Dict]] = None,
    ) -> Dict[str, Dict]:
        """Classify the segments, rewrite each category and merge into the chapters.

        Segments without ``transcript_text`` are skipped. Texts of the same
        category are joined with blank lines and rewritten together with the
        existing chapter content of that category.

        Returns:
            A shallow copy of ``existing_chapters`` in which every category
            that received new text is replaced by the rewritten chapter.
        """
        if existing_chapters is None:
            existing_chapters = {}

        segments_by_category: Dict[str, List[str]] = {}
        for segment in segments:
            text = segment.get("transcript_text", "")
            if not text:
                continue
            category = await self.classify_chapter(text)
            segments_by_category.setdefault(category, []).append(text)

        updated_chapters = existing_chapters.copy()
        for category, texts in segments_by_category.items():
            combined_text = "\n\n".join(texts)
            existing_content = existing_chapters.get(category, {}).get("content", "")
            result = await self.rewrite_to_literary(
                combined_text, category, existing_content
            )
            updated_chapters[category] = {
                "title": result.get("title", CHAPTER_CATEGORIES.get(category, "章节")),
                "content": result.get("content", ""),
                "summary": result.get("summary", ""),
                "image_suggestions": result.get("image_suggestions", []),
                "category": category,
                "order_index": STAGE_TO_ORDER.get(category, 999),
            }
        return updated_chapters