feat & refactor: 重构agents目录结构;AI回复模块agent结构封装
This commit is contained in:
@@ -1,10 +1,19 @@
|
||||
"""
|
||||
Agent 模块(app 内,符合架构计划)
|
||||
Agent 模块(按功能拆分:chat / memoir)
|
||||
"""
|
||||
from app.agents.conversation_agent import ConversationAgent
|
||||
from app.agents.memory_agent import MemoryAgent
|
||||
from app.agents.chat import (
|
||||
ChatOrchestrator,
|
||||
ConversationAgent,
|
||||
InterviewAgent,
|
||||
ProfileAgent,
|
||||
)
|
||||
from app.agents.memoir import BackgroundTaskRunner, MemoryAgent
|
||||
|
||||
__all__ = [
|
||||
"ConversationAgent",
|
||||
"MemoryAgent",
|
||||
"ChatOrchestrator",
|
||||
"ProfileAgent",
|
||||
"InterviewAgent",
|
||||
"BackgroundTaskRunner",
|
||||
]
|
||||
|
||||
12
api/app/agents/chat/__init__.py
Normal file
12
api/app/agents/chat/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""聊天模块:AI 回复用户(ProfileAgent + InterviewAgent + ChatOrchestrator)"""
|
||||
from app.agents.chat.conversation_agent import ConversationAgent
|
||||
from app.agents.chat.orchestrator import ChatOrchestrator
|
||||
from app.agents.chat.profile_agent import ProfileAgent
|
||||
from app.agents.chat.interview_agent import InterviewAgent
|
||||
|
||||
__all__ = [
|
||||
"ConversationAgent",
|
||||
"ChatOrchestrator",
|
||||
"ProfileAgent",
|
||||
"InterviewAgent",
|
||||
]
|
||||
139
api/app/agents/chat/conversation_agent.py
Normal file
139
api/app/agents/chat/conversation_agent.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
对话 Agent:Facade,内部委托 ChatOrchestrator + ProfileAgent + InterviewAgent
|
||||
保留原有对外 API,供 router 等调用方兼容使用
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.agents.chat.orchestrator import ChatOrchestrator
|
||||
from app.agents.prompts import ConversationStage
|
||||
from app.agents.state_schema import MemoirStateSchema
|
||||
from app.core.redis import redis_service
|
||||
|
||||
|
||||
class ConversationAgent:
    """Backward-compatible facade over ``ChatOrchestrator``.

    Every method forwards to the orchestrator so that existing callers can
    keep using the ``ConversationAgent`` API.
    """

    def __init__(self):
        self._orchestrator = ChatOrchestrator()

    async def extract_profile_from_message(
        self,
        user_message: str,
        missing_fields: List[str],
        conversation_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Forward profile extraction to the orchestrator and return its result."""
        return await self._orchestrator.extract_profile_from_message(
            user_message, missing_fields, conversation_id=conversation_id
        )

    async def generate_profile_followup(
        self,
        conversation_id: str,
        user_message: str,
        missing_fields: List[str],
        filled_fields: Dict[str, str],
        nickname: str = "",
        is_from_voice: bool = False,
        voice_session_id: str | None = None,
        user_message_timestamp: datetime | None = None,
    ) -> List[str]:
        """Forward a profile follow-up request to the orchestrator."""
        return await self._orchestrator.generate_profile_followup(
            conversation_id=conversation_id,
            user_message=user_message,
            missing_fields=missing_fields,
            filled_fields=filled_fields,
            nickname=nickname,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            user_message_timestamp=user_message_timestamp,
        )

    async def generate_profile_greeting(
        self,
        conversation_id: str,
        missing_fields: List[str],
        nickname: str = "",
    ) -> List[str]:
        """Forward a profile-collection greeting request to the orchestrator."""
        return await self._orchestrator.generate_profile_greeting(
            conversation_id=conversation_id,
            missing_fields=missing_fields,
            nickname=nickname,
        )

    async def generate_response_with_state(
        self,
        conversation_id: str,
        user_message: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
        is_from_voice: bool = False,
        voice_session_id: str | None = None,
        user_message_timestamp: datetime | None = None,
    ) -> List[str]:
        """Forward a state-aware interview reply request to the orchestrator."""
        return await self._orchestrator.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=memoir_state,
            user_profile_context=user_profile_context,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            user_message_timestamp=user_message_timestamp,
        )

    async def generate_opening_message(
        self,
        conversation_id: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
    ) -> List[str]:
        """Forward an interview opening-message request to the orchestrator."""
        return await self._orchestrator.generate_opening_message(
            conversation_id=conversation_id,
            memoir_state=memoir_state,
            user_profile_context=user_profile_context,
        )

    async def generate_response(
        self,
        conversation_id: str,
        user_message: str,
        current_stage: Optional[ConversationStage] = None,
        covered_topics: Optional[List[str]] = None,
    ) -> str:
        """Legacy API: produce a single reply string without a stored memoir state.

        A default state is built from ``current_stage`` (childhood when not
        given) and ``covered_topics`` and handed to the state-aware path.

        Returns:
            All reply segments joined by a blank line, i.e. the same text the
            orchestrator stores as the AI message; ``""`` when there is no reply.
        """
        # Imported at call time, as in the original code.
        from app.agents.state_schema import default_state

        state = default_state()
        state.current_stage = (current_stage or ConversationStage.CHILDHOOD).value
        # NOTE(review): covered_topics is stored as covered_stages; confirm the
        # two carry the same kind of values.
        state.covered_stages = covered_topics or []
        responses = await self._orchestrator.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=state,
            user_profile_context="",
        )
        # Return every segment, not just the first one, so the caller sees the
        # same reply that was written to the conversation history.
        return "\n\n".join(responses)

    def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage:
        """Map the orchestrator's keyword-based stage name to a ``ConversationStage``.

        ``conversation_id`` is accepted for API compatibility and is not used.
        Unknown or empty stage names fall back to ``CHILDHOOD``.
        """
        stage_by_name = {
            "childhood": ConversationStage.CHILDHOOD,
            "education": ConversationStage.EDUCATION,
            "career": ConversationStage.CAREER,
            "family": ConversationStage.FAMILY,
            "belief": ConversationStage.BELIEFS,
        }
        detected = self._orchestrator.detect_user_stage(user_message)
        return stage_by_name.get(detected, ConversationStage.CHILDHOOD)

    async def clear_memory(self, conversation_id: str) -> None:
        """Delete the conversation history kept in Redis for ``conversation_id``."""
        await redis_service.clear_conversation_history(conversation_id)
|
||||
49
api/app/agents/chat/helpers.py
Normal file
49
api/app/agents/chat/helpers.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""聊天 Agent 共享工具:历史获取、格式化、存储"""
|
||||
from datetime import datetime
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from app.core.redis import redis_service
|
||||
|
||||
|
||||
async def get_history_messages(conversation_id: str) -> List[Any]:
    """Load a conversation's history from Redis as LangChain message objects.

    Records whose ``role`` is ``"human"`` become ``HumanMessage`` and ``"ai"``
    become ``AIMessage``; records with any other role are dropped.
    """
    records = await redis_service.get_conversation_history(conversation_id)
    message_cls_by_role = {"human": HumanMessage, "ai": AIMessage}
    converted = []
    for record in records:
        message_cls = message_cls_by_role.get(record["role"])
        if message_cls is not None:
            converted.append(message_cls(content=record["content"]))
    return converted
|
||||
|
||||
|
||||
def format_history_string(messages: List[Any]) -> str:
    """Render messages as ``Human: ...`` / ``Assistant: ...`` paragraphs.

    Entries that are neither ``HumanMessage`` nor ``AIMessage`` are skipped;
    paragraphs are separated by a blank line.
    """
    lines = []
    for message in messages:
        if isinstance(message, HumanMessage):
            speaker = "Human"
        elif isinstance(message, AIMessage):
            speaker = "Assistant"
        else:
            continue
        lines.append(f"{speaker}: {message.content}")
    return "\n\n".join(lines)
|
||||
|
||||
|
||||
async def save_message(
    conversation_id: str,
    role: str,
    content: str,
    message_type: str = "text",
    voice_session_id: str | None = None,
    timestamp: datetime | str | int | None = None,
) -> None:
    """Append one message to the conversation history in Redis.

    A ``datetime`` timestamp is converted to its ISO-8601 string first; any
    other timestamp value (including ``None``) is passed through unchanged.
    """
    stored_timestamp = timestamp
    if isinstance(timestamp, datetime):
        stored_timestamp = timestamp.isoformat()
    await redis_service.add_message(
        conversation_id,
        role,
        content,
        message_type=message_type,
        voice_session_id=voice_session_id,
        timestamp=stored_timestamp,
    )
|
||||
143
api/app/agents/chat/interview_agent.py
Normal file
143
api/app/agents/chat/interview_agent.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
InterviewAgent:正式访谈 Specialist
|
||||
负责状态感知回复、开场白,不负责 Redis 持久化(由 Orchestrator 统一处理)
|
||||
"""
|
||||
from typing import Any, List
|
||||
|
||||
from app.core.dependencies import get_llm_provider
|
||||
from app.core.logging import get_logger
|
||||
|
||||
from app.agents.chat.helpers import format_history_string, get_history_messages
|
||||
from app.agents.prompts import get_guided_conversation_prompt, get_opening_prompt
|
||||
from app.agents.prompts.conversation_prompts import SLOT_NAME_MAP
|
||||
from app.agents.state_schema import MemoirStateSchema
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _get_langchain_llm():
    """Return the provider's ``langchain_llm`` attribute, or ``None``.

    ``None`` is returned when the provider has no such attribute or when
    obtaining it raises.  The failure is logged instead of being dropped
    silently, so a misconfigured LLM is visible in the logs.
    """
    try:
        provider = get_llm_provider()
        return getattr(provider, "langchain_llm", None)
    except Exception as e:
        # Best-effort by design: callers check for a falsy LLM, so do not raise.
        logger.warning("获取 LangChain LLM 失败: %s", e, exc_info=True)
        return None
|
||||
|
||||
|
||||
class InterviewAgent:
    """Specialist agent for the formal interview conversation.

    Builds prompts from the memoir state and the stored history, asks the LLM
    for a reply and splits it into separate messages.  This class never writes
    to Redis itself.
    """

    # Keywords per life stage; stages are checked in this order and the first
    # stage with a matching keyword wins.
    _STAGE_KEYWORDS = {
        "childhood": ["童年", "小时候", "出生", "家乡", "小镇", "爸妈", "父亲", "母亲", "爷爷", "奶奶", "外公", "外婆", "幼儿园"],
        "education": ["上学", "学校", "老师", "同学", "教育", "大学", "高中", "初中", "小学", "考试", "毕业", "读书", "高考", "课堂"],
        "career": ["工作", "职业", "事业", "公司", "同事", "创业", "升职", "跳槽", "老板", "行业", "项目", "加班", "薪水", "面试"],
        "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "爱人", "老婆", "老公", "丈夫", "妻子", "儿子", "女儿", "婚礼", "恋爱"],
        "belief": ["信念", "价值观", "座右铭", "坚持", "原则", "信仰", "意义", "感悟", "遗憾", "骄傲"],
    }

    # Marker on which the LLM output is split into separate messages.
    _SPLIT_MARKER = "[SPLIT]"

    def __init__(self):
        self.llm = _get_langchain_llm()

    def _detect_user_stage(self, user_message: str) -> str:
        """Return the first life stage whose keyword occurs in the message.

        Returns ``""`` when no keyword matches.
        """
        message = user_message.lower()
        for stage, keywords in self._STAGE_KEYWORDS.items():
            if any(word in message for word in keywords):
                return stage
        return ""

    def _estimate_same_topic_turns(
        self, history_messages: List[Any], current_filled_slots: dict
    ) -> int:
        """Estimate consecutive turns on the current topic, capped at 3.

        The estimate is simply the number of complete message pairs among the
        last six history entries; message text is not inspected.
        ``current_filled_slots`` is accepted for interface compatibility and
        is not used.
        """
        return min(len(history_messages), 6) // 2

    async def _ask_llm(self, full_prompt: str, max_messages: int) -> List[str]:
        """Invoke the LLM and split its reply into at most ``max_messages`` parts.

        When splitting leaves no non-empty part, the raw reply text is
        returned as the only element.
        """
        response = await self.llm.ainvoke(full_prompt)
        response_text = response.content if hasattr(response, "content") else str(response)
        messages = [
            part.strip()
            for part in response_text.split(self._SPLIT_MARKER)
            if part.strip()
        ]
        return messages[:max_messages] if messages else [response_text]

    async def generate_response_with_state(
        self,
        conversation_id: str,
        user_message: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
    ) -> List[str]:
        """Generate a state-aware interview reply (up to three messages).

        Nothing is persisted here.  When no LLM is configured, or when
        generation fails, a single apology message is returned instead of
        raising.
        """
        if not self.llm:
            return ["抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"]
        try:
            empty_slots = memoir_state.empty_slots_for_current_stage()
            stage_slots = memoir_state.slots.get(memoir_state.current_stage, {})
            filled_slots = {
                key: value.snippet
                for key, value in stage_slots.items()
                if value.snippet
            }
            detected_user_stage = self._detect_user_stage(user_message)
            history_messages = await get_history_messages(conversation_id)
            conversation_turn = len(history_messages) // 2
            same_topic_turns = self._estimate_same_topic_turns(
                history_messages, filled_slots
            )
            all_stages_coverage = memoir_state.all_stages_coverage()
            system_prompt = get_guided_conversation_prompt(
                current_stage=memoir_state.current_stage,
                empty_slots=empty_slots,
                filled_slots=filled_slots,
                user_message=user_message,
                conversation_turn=conversation_turn,
                same_topic_turns=same_topic_turns,
                all_stages_coverage=all_stages_coverage,
                detected_user_stage=detected_user_stage,
                user_profile_context=user_profile_context,
            )
            history_string = format_history_string(history_messages)
            full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
            return await self._ask_llm(full_prompt, 3)
        except Exception as e:
            logger.error("生成回应失败: %s", e, exc_info=True)
            # NOTE(review): the exception text is shown to the end user;
            # consider replacing it with a generic message.
            return [f"抱歉,生成回应时出现错误: {e}"]

    async def generate_opening_message(
        self,
        conversation_id: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
    ) -> List[str]:
        """Generate the opening for an empty conversation (up to two messages).

        Nothing is persisted here and ``conversation_id`` is not used by this
        method.  A canned greeting is returned when no LLM is configured or
        when generation fails.
        """
        if not self.llm:
            return ["你好呀~ 有空聊聊你的人生故事吗?你小时候是在哪儿长大的?"]
        try:
            empty_slots = memoir_state.empty_slots_for_current_stage()
            empty_slots_readable = [SLOT_NAME_MAP.get(slot, slot) for slot in empty_slots]
            if not empty_slots_readable:
                # Nothing left to ask for this stage: use generic topics.
                empty_slots_readable = ["成长的地方", "难忘的事", "重要的人"]
            prompt = get_opening_prompt(
                current_stage=memoir_state.current_stage,
                empty_slots_readable=empty_slots_readable,
                user_profile_context=user_profile_context,
            )
            return await self._ask_llm(f"{prompt}\n\nAssistant:", 2)
        except Exception as e:
            logger.error("生成开场白失败: %s", e, exc_info=True)
            return ["你好呀~ 有空聊聊你的人生故事吗?你童年里印象最深的一件事是什么?"]
|
||||
246
api/app/agents/chat/orchestrator.py
Normal file
246
api/app/agents/chat/orchestrator.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""
|
||||
ChatOrchestrator:AI 回复用户模块的编排层
|
||||
负责路由(Profile vs Interview)、调用 Specialist Agent、统一 Redis 持久化与错误处理
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.chat.helpers import save_message
|
||||
from app.agents.chat.interview_agent import InterviewAgent
|
||||
from app.agents.chat.profile_agent import ProfileAgent
|
||||
from app.agents.state_schema import MemoirStateSchema
|
||||
from app.core.logging import get_logger
|
||||
from app.features.memoir.state_service import get_or_create_state
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.features.user.models import User
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ChatOrchestrator:
    """Route a chat turn to ``ProfileAgent`` or ``InterviewAgent`` and persist it.

    The profile agent is used while the user still has missing profile
    fields; otherwise the interview agent answers.  All Redis writes for a
    turn go through this class.
    """

    def __init__(self):
        self.profile_agent = ProfileAgent()
        self.interview_agent = InterviewAgent()

    async def process_user_message(
        self,
        conversation_id: str,
        user_message: str,
        user: Optional["User"],
        conversation,  # used to keep conversation_stage in sync with the state
        is_from_voice: bool,
        voice_session_id: Optional[str],
        db: AsyncSession,
        apply_extracted_profile_fn,
        get_missing_profile_fields_fn,
        get_filled_profile_fields_fn,
        user_message_timestamp: Optional[datetime] = None,
    ) -> List[str]:
        """Handle one user message and return the AI reply messages.

        Flow:
          1. If the user has missing profile fields, extract what the message
             provides, apply it, and reply with a profile follow-up.
          2. Otherwise, or if step 1 raised, answer in interview mode using
             the user's memoir state.

        The human message and the joined reply are written to Redis on both
        paths.  Without an identifiable user a single apology is returned and
        nothing is stored.
        """
        # --- profile collection mode ---
        if user:
            missing = get_missing_profile_fields_fn(user)
            if missing:
                try:
                    extracted = await self.profile_agent.extract_profile_from_message(
                        user_message, missing, conversation_id=conversation_id
                    )
                    if extracted:
                        await apply_extracted_profile_fn(user, extracted, db)

                    # Re-read the fields: applying the extraction may have
                    # filled some of them.
                    remaining = get_missing_profile_fields_fn(user)
                    filled = get_filled_profile_fields_fn(user)
                    return await self.generate_profile_followup(
                        conversation_id=conversation_id,
                        user_message=user_message,
                        missing_fields=remaining,
                        filled_fields=filled,
                        nickname=user.nickname or "",
                        is_from_voice=is_from_voice,
                        voice_session_id=voice_session_id,
                        user_message_timestamp=user_message_timestamp,
                    )
                except Exception as e:
                    # Deliberate fall-through to interview mode on failure.
                    logger.error("资料收集处理失败: %s", e, exc_info=True)

        # --- formal interview mode ---
        user_id = user.id if user else None
        if not user_id:
            return ["抱歉,无法识别用户。"]

        state = await get_or_create_state(user_id, db)
        if conversation and conversation.conversation_stage != state.current_stage:
            conversation.conversation_stage = state.current_stage
            await db.commit()

        # Imported at call time, as in the original code.
        from app.agents.prompts.profile_prompts import format_user_profile_context

        # user is known to be set here because user_id was truthy above.
        user_profile_context = format_user_profile_context(
            birth_year=user.birth_year,
            birth_place=user.birth_place,
            grew_up_place=user.grew_up_place,
            occupation=user.occupation,
        )
        return await self.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=state,
            user_profile_context=user_profile_context,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            user_message_timestamp=user_message_timestamp,
        )

    async def _save_messages(
        self,
        conversation_id: str,
        user_message: str,
        response_text: str,
        is_from_voice: bool = False,
        voice_session_id: Optional[str] = None,
        user_message_timestamp: Optional[datetime] = None,
    ) -> None:
        """Write the human message, then the AI reply, to Redis.

        The human message is stored as ``"audio"`` when it came from voice and
        carries the voice session id and timestamp; the AI reply is stored
        with the defaults of ``save_message``.
        """
        human_msg_type = "audio" if is_from_voice else "text"
        await save_message(
            conversation_id,
            "human",
            user_message,
            message_type=human_msg_type,
            voice_session_id=voice_session_id,
            timestamp=user_message_timestamp,
        )
        await save_message(conversation_id, "ai", response_text)

    async def extract_profile_from_message(
        self,
        user_message: str,
        missing_fields: List[str],
        conversation_id: Optional[str] = None,
    ) -> dict:
        """Delegate profile extraction to ``ProfileAgent``; nothing is stored."""
        return await self.profile_agent.extract_profile_from_message(
            user_message, missing_fields, conversation_id=conversation_id
        )

    async def generate_profile_followup(
        self,
        conversation_id: str,
        user_message: str,
        missing_fields: List[str],
        filled_fields: dict,
        nickname: str = "",
        is_from_voice: bool = False,
        voice_session_id: str | None = None,
        user_message_timestamp: datetime | None = None,
    ) -> List[str]:
        """Generate a profile follow-up via ``ProfileAgent`` and store the turn."""
        responses = await self.profile_agent.generate_profile_followup(
            conversation_id=conversation_id,
            user_message=user_message,
            missing_fields=missing_fields,
            filled_fields=filled_fields,
            nickname=nickname,
        )
        await self._save_messages(
            conversation_id=conversation_id,
            user_message=user_message,
            response_text="\n\n".join(responses),
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            user_message_timestamp=user_message_timestamp,
        )
        return responses

    async def generate_profile_greeting(
        self,
        conversation_id: str,
        missing_fields: List[str],
        nickname: str = "",
    ) -> List[str]:
        """Generate the profile-collection greeting and store it as an AI message."""
        responses = await self.profile_agent.generate_profile_greeting(
            conversation_id=conversation_id,
            missing_fields=missing_fields,
            nickname=nickname,
        )
        await save_message(conversation_id, "ai", "\n\n".join(responses))
        return responses

    async def generate_response_with_state(
        self,
        conversation_id: str,
        user_message: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
        is_from_voice: bool = False,
        voice_session_id: str | None = None,
        user_message_timestamp: datetime | None = None,
    ) -> List[str]:
        """Generate an interview reply via ``InterviewAgent`` and store the turn."""
        responses = await self.interview_agent.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=memoir_state,
            user_profile_context=user_profile_context,
        )
        await self._save_messages(
            conversation_id=conversation_id,
            user_message=user_message,
            response_text="\n\n".join(responses),
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            user_message_timestamp=user_message_timestamp,
        )
        return responses

    def detect_user_stage(self, user_message: str) -> str:
        """Return the life-stage name detected by ``InterviewAgent`` for the message."""
        return self.interview_agent._detect_user_stage(user_message)

    async def generate_opening_message(
        self,
        conversation_id: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
    ) -> List[str]:
        """Generate the interview opening and store it as an AI message."""
        responses = await self.interview_agent.generate_opening_message(
            conversation_id=conversation_id,
            memoir_state=memoir_state,
            user_profile_context=user_profile_context,
        )
        await save_message(conversation_id, "ai", "\n\n".join(responses))
        return responses
|
||||
132
api/app/agents/chat/profile_agent.py
Normal file
132
api/app/agents/chat/profile_agent.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
ProfileAgent:用户资料收集 Specialist
|
||||
负责提取资料、资料追问、资料收集开场白,不负责 Redis 持久化(由 Orchestrator 统一处理)
|
||||
"""
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from app.core.dependencies import get_llm_provider
|
||||
from app.core.logging import get_logger
|
||||
|
||||
from app.agents.chat.helpers import format_history_string, get_history_messages
|
||||
from app.agents.prompts.profile_prompts import (
|
||||
get_profile_extraction_prompt,
|
||||
get_profile_followup_prompt,
|
||||
get_profile_greeting_prompt,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _get_langchain_llm():
    """Return the provider's ``langchain_llm`` attribute, or ``None``.

    ``None`` is returned when the provider has no such attribute or when
    obtaining it raises.  The failure is logged instead of being dropped
    silently, so a misconfigured LLM is visible in the logs.
    """
    try:
        provider = get_llm_provider()
        return getattr(provider, "langchain_llm", None)
    except Exception as e:
        # Best-effort by design: callers check for a falsy LLM, so do not raise.
        logger.warning("获取 LangChain LLM 失败: %s", e, exc_info=True)
        return None
|
||||
|
||||
|
||||
class ProfileAgent:
    """Specialist agent for collecting basic user profile fields.

    Extracts profile values from user messages and generates follow-up and
    greeting messages.  This class never writes to Redis itself.
    """

    # Marker on which the LLM output is split into separate messages.
    _SPLIT_MARKER = "[SPLIT]"

    # Free-text profile fields copied from the LLM output as strings.
    _TEXT_FIELDS = ("birth_place", "grew_up_place", "occupation")

    def __init__(self):
        self.llm = _get_langchain_llm()

    @staticmethod
    def _normalize_birth_year(raw: Any) -> Optional[int]:
        """Turn an LLM-provided birth year into an int in 1900..2100, or ``None``.

        Accepts ints and digit-only strings.  Two-digit values are expanded:
        50-99 become 19xx and smaller values become 20xx.
        """
        if isinstance(raw, bool):
            # bool is a subclass of int; True/False is never a year.
            return None
        if isinstance(raw, int):
            year = raw
            # An int 0 is left unexpanded (and so rejected by the range check),
            # presumably an "unknown" placeholder rather than the year 2000.
            expandable = 0 < year < 100
        elif isinstance(raw, str) and raw.strip().isdecimal():
            year = int(raw.strip())
            expandable = year < 100
        else:
            return None
        if expandable:
            year = 1900 + year if year >= 50 else 2000 + year
        return year if 1900 <= year <= 2100 else None

    @staticmethod
    def _parse_json_object(content: str) -> Dict[str, Any]:
        """Parse the JSON object contained in ``content``.

        Text around the outermost ``{...}`` (for example a Markdown code
        fence) is ignored.  A valid JSON value that is not an object yields
        ``{}``; invalid JSON raises ``json.JSONDecodeError``.
        """
        text = content.strip()
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end > start:
            text = text[start:end + 1]
        parsed = json.loads(text)
        return parsed if isinstance(parsed, dict) else {}

    async def _ask_llm(self, full_prompt: str, max_messages: int) -> List[str]:
        """Invoke the LLM and split its reply into at most ``max_messages`` parts.

        When splitting leaves no non-empty part, the raw reply text is
        returned as the only element.
        """
        response = await self.llm.ainvoke(full_prompt)
        response_text = response.content if hasattr(response, "content") else str(response)
        messages = [
            part.strip()
            for part in response_text.split(self._SPLIT_MARKER)
            if part.strip()
        ]
        return messages[:max_messages] if messages else [response_text]

    async def extract_profile_from_message(
        self,
        user_message: str,
        missing_fields: List[str],
        conversation_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Extract profile fields from a user message; nothing is persisted.

        When ``conversation_id`` is given, the last four history messages are
        passed to the prompt as recent dialogue.

        Returns:
            A dict with any of ``birth_year`` (int), ``birth_place``,
            ``grew_up_place`` and ``occupation`` (str) that the LLM supplied.
            ``{}`` when no LLM is configured, no fields are missing, or the
            LLM call / parsing fails.
        """
        if not self.llm or not missing_fields:
            return {}
        recent_dialogue = ""
        if conversation_id:
            history_messages = await get_history_messages(conversation_id)
            parts = []
            for msg in history_messages[-4:]:
                if isinstance(msg, HumanMessage):
                    parts.append(f"用户: {msg.content}")
                elif isinstance(msg, AIMessage):
                    parts.append(f"助手: {msg.content}")
            recent_dialogue = "\n".join(parts)
        try:
            prompt = get_profile_extraction_prompt(
                user_message, missing_fields, recent_dialogue=recent_dialogue or None
            )
            response = await self.llm.ainvoke(prompt)
            content = response.content if hasattr(response, "content") else str(response)
            parsed = self._parse_json_object(content)
            result = {}
            birth_year = self._normalize_birth_year(parsed.get("birth_year"))
            if birth_year is not None:
                result["birth_year"] = birth_year
            for field in self._TEXT_FIELDS:
                if parsed.get(field):
                    result[field] = str(parsed[field])
            return result
        except Exception as e:
            logger.error("提取资料信息失败: %s", e)
            return {}

    async def generate_profile_followup(
        self,
        conversation_id: str,
        user_message: str,
        missing_fields: List[str],
        filled_fields: Dict[str, str],
        nickname: str = "",
    ) -> List[str]:
        """Generate a profile follow-up reply (up to three messages).

        Nothing is persisted here.  A canned reply is returned when no LLM is
        configured or when generation fails.
        """
        if not self.llm:
            return ["谢谢!还能告诉我更多吗?"]
        try:
            prompt = get_profile_followup_prompt(
                missing_fields, filled_fields, user_message, nickname
            )
            history_messages = await get_history_messages(conversation_id)
            history_string = format_history_string(history_messages)
            full_prompt = f"{prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
            return await self._ask_llm(full_prompt, 3)
        except Exception as e:
            logger.error("生成资料跟进回复失败: %s", e)
            return ["谢谢分享!能再告诉我一些吗?"]

    async def generate_profile_greeting(
        self,
        conversation_id: str,
        missing_fields: List[str],
        nickname: str = "",
    ) -> List[str]:
        """Generate the profile-collection greeting (up to two messages).

        Nothing is persisted here.  Existing history, if any, is appended to
        the prompt.  A canned greeting is returned when no LLM is configured
        or when generation fails.
        """
        if not self.llm:
            return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"]
        try:
            prompt = get_profile_greeting_prompt(missing_fields, nickname)
            history_messages = await get_history_messages(conversation_id)
            history_string = format_history_string(history_messages)
            full_prompt = f"{prompt}\n\n{history_string}" if history_string else prompt
            return await self._ask_llm(full_prompt, 2)
        except Exception as e:
            logger.error("生成资料收集开场白失败: %s", e)
            return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"]
|
||||
@@ -1,355 +0,0 @@
|
||||
"""
|
||||
对话 Agent:基于访谈问题清单,动态选择问题,实时生成回应
|
||||
支持异步调用和 Redis 会话存储,支持用户基础资料收集和时代背景融入
|
||||
"""
|
||||
import json
|
||||
from app.core.logging import get_logger
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from app.core.redis import redis_service
|
||||
from app.core.dependencies import get_llm_provider
|
||||
|
||||
from app.agents.prompts import (
|
||||
ConversationStage,
|
||||
get_conversation_prompt,
|
||||
get_guided_conversation_prompt,
|
||||
get_opening_prompt,
|
||||
)
|
||||
from app.agents.prompts.profile_prompts import (
|
||||
get_profile_greeting_prompt,
|
||||
get_profile_extraction_prompt,
|
||||
get_profile_followup_prompt,
|
||||
format_user_profile_context,
|
||||
get_missing_profile_fields,
|
||||
)
|
||||
from app.agents.prompts.conversation_prompts import SLOT_NAME_MAP
|
||||
from app.agents.state_schema import MemoirStateSchema
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _get_langchain_llm():
    """Return the provider's ``langchain_llm`` attribute for agents, or ``None``.

    ``None`` is returned when the provider has no such attribute or when
    obtaining it raises.  The failure is logged instead of being dropped
    silently, so a misconfigured LLM is visible in the logs.
    """
    try:
        provider = get_llm_provider()
        return getattr(provider, "langchain_llm", None)
    except Exception as e:
        # Best-effort by design: callers check for a falsy LLM, so do not raise.
        logger.warning("获取 LangChain LLM 失败: %s", e, exc_info=True)
        return None
|
||||
|
||||
|
||||
class ConversationAgent:
|
||||
"""对话 Agent(支持异步和 Redis 存储)"""
|
||||
|
||||
def __init__(self):
|
||||
self.llm = _get_langchain_llm()
|
||||
|
||||
async def _get_history_messages(self, conversation_id: str) -> List[Any]:
|
||||
history = await redis_service.get_conversation_history(conversation_id)
|
||||
messages = []
|
||||
for msg in history:
|
||||
if msg["role"] == "human":
|
||||
messages.append(HumanMessage(content=msg["content"]))
|
||||
elif msg["role"] == "ai":
|
||||
messages.append(AIMessage(content=msg["content"]))
|
||||
return messages
|
||||
|
||||
async def _save_message(
|
||||
self,
|
||||
conversation_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
message_type: str = "text",
|
||||
voice_session_id: str | None = None,
|
||||
timestamp: datetime | str | int | None = None,
|
||||
):
|
||||
await redis_service.add_message(
|
||||
conversation_id,
|
||||
role,
|
||||
content,
|
||||
message_type=message_type,
|
||||
voice_session_id=voice_session_id,
|
||||
timestamp=timestamp.isoformat() if isinstance(timestamp, datetime) else timestamp,
|
||||
)
|
||||
|
||||
def _format_history_string(self, messages: List[Any]) -> str:
|
||||
history_parts = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, HumanMessage):
|
||||
history_parts.append(f"Human: {msg.content}")
|
||||
elif isinstance(msg, AIMessage):
|
||||
history_parts.append(f"Assistant: {msg.content}")
|
||||
return "\n\n".join(history_parts)
|
||||
|
||||
async def generate_response(
|
||||
self,
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
current_stage: Optional[ConversationStage] = None,
|
||||
covered_topics: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
if current_stage is None:
|
||||
current_stage = ConversationStage.CHILDHOOD
|
||||
if covered_topics is None:
|
||||
covered_topics = []
|
||||
if not self.llm:
|
||||
return "抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"
|
||||
try:
|
||||
system_prompt = get_conversation_prompt(current_stage, covered_topics, user_message)
|
||||
history_messages = await self._get_history_messages(conversation_id)
|
||||
history_string = self._format_history_string(history_messages)
|
||||
full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
|
||||
response = await self.llm.ainvoke(full_prompt)
|
||||
response_text = response.content if hasattr(response, "content") else str(response)
|
||||
await self._save_message(conversation_id, "human", user_message)
|
||||
await self._save_message(conversation_id, "ai", response_text)
|
||||
return response_text
|
||||
except Exception as e:
|
||||
logger.error("生成回应失败: %s", e)
|
||||
return f"抱歉,生成回应时出现错误: {str(e)}"
|
||||
|
||||
async def generate_profile_greeting(
    self,
    conversation_id: str,
    missing_fields: List[str],
    nickname: str = "",
) -> List[str]:
    """Generate the opening messages that start profile collection.

    Args:
        conversation_id: Key under which history is read and the greeting
            is stored.
        missing_fields: Profile fields that still need to be collected.
        nickname: Optional name passed to the prompt builder.

    Returns:
        At most two message bubbles, obtained by splitting the model output
        on the ``[SPLIT]`` marker. A canned greeting is returned when no
        LLM is configured or when generation fails.
    """
    if not self.llm:
        return ["你好!在开始之前,能告诉我你是哪一年出生的吗?"]
    try:
        prompt = get_profile_greeting_prompt(missing_fields, nickname)
        history_messages = await self._get_history_messages(conversation_id)
        history_string = self._format_history_string(history_messages)
        # Only append the transcript when there is one, to avoid trailing
        # blank lines in the prompt.
        full_prompt = f"{prompt}\n\n{history_string}" if history_string else prompt
        response = await self.llm.ainvoke(full_prompt)
        response_text = response.content if hasattr(response, "content") else str(response)
        # The unsplit reply (markers included) is what gets persisted.
        await self._save_message(conversation_id, "ai", response_text)
        messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
        return messages[:2] if messages else [response_text]
    except Exception as e:
        # Keep the traceback, consistent with generate_opening_message.
        logger.error("生成资料收集开场白失败: %s", e, exc_info=True)
        return ["你好!在我们开始聊人生故事之前,能先简单介绍一下你自己吗?比如你是哪一年出生的?"]
|
||||
|
||||
async def generate_opening_message(
    self,
    conversation_id: str,
    memoir_state: MemoirStateSchema,
    user_profile_context: str = "",
) -> List[str]:
    """Generate the first interview messages for a conversation.

    Args:
        conversation_id: Key under which the opening is stored.
        memoir_state: Current memoir state; supplies the stage and its
            still-empty slots.
        user_profile_context: Optional profile text passed to the prompt
            builder.

    Returns:
        At most two message bubbles split on ``[SPLIT]``. A canned opening
        is returned when no LLM is configured or when generation fails.
    """
    if not self.llm:
        return ["你好呀~ 有空聊聊你的人生故事吗?你小时候是在哪儿长大的?"]
    try:
        # Map slot keys to readable labels; unknown keys pass through as-is.
        # Fall back to generic talking points when the stage has no open slot.
        readable_slots = [
            SLOT_NAME_MAP.get(slot, slot)
            for slot in memoir_state.empty_slots_for_current_stage()
        ] or ["成长的地方", "难忘的事", "重要的人"]
        opening_prompt = get_opening_prompt(
            current_stage=memoir_state.current_stage,
            empty_slots_readable=readable_slots,
            user_profile_context=user_profile_context,
        )
        reply = await self.llm.ainvoke(f"{opening_prompt}\n\nAssistant:")
        reply_text = reply.content if hasattr(reply, "content") else str(reply)
        # The unsplit reply is what gets persisted.
        await self._save_message(conversation_id, "ai", reply_text)
        bubbles = [part.strip() for part in reply_text.split("[SPLIT]") if part.strip()]
        if not bubbles:
            return [reply_text]
        return bubbles[:2]
    except Exception as e:
        logger.error("生成开场白失败: %s", e, exc_info=True)
        return ["你好呀~ 有空聊聊你的人生故事吗?你童年里印象最深的一件事是什么?"]
|
||||
|
||||
async def extract_profile_from_message(
    self,
    user_message: str,
    missing_fields: List[str],
    conversation_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Ask the LLM to extract profile fields from a user message.

    Args:
        user_message: The user's latest utterance.
        missing_fields: Profile fields still to be collected; when empty
            nothing is extracted.
        conversation_id: When given, the last four history messages are
            passed to the prompt as recent dialogue.

    Returns:
        A dict containing any of ``birth_year`` (int within 1900-2100),
        ``birth_place``, ``grew_up_place`` and ``occupation`` (strings)
        that the model reported. Empty when no LLM is configured, nothing
        is missing, the reply is not a JSON object, or the call fails.
    """
    if not self.llm or not missing_fields:
        return {}

    recent_dialogue = ""
    if conversation_id:
        history_messages = await self._get_history_messages(conversation_id)
        parts = []
        for msg in history_messages[-4:]:
            if isinstance(msg, HumanMessage):
                parts.append(f"用户: {msg.content}")
            elif isinstance(msg, AIMessage):
                parts.append(f"助手: {msg.content}")
        recent_dialogue = "\n".join(parts)

    try:
        prompt = get_profile_extraction_prompt(
            user_message, missing_fields, recent_dialogue=recent_dialogue or None
        )
        response = await self.llm.ainvoke(prompt)
        content = response.content if hasattr(response, "content") else str(response)
        content = content.strip()
        # Models frequently wrap JSON in a Markdown code fence; remove it
        # so json.loads sees the bare object.
        if content.startswith("```json"):
            content = content[7:]
        elif content.startswith("```"):
            content = content[3:]
        if content.endswith("```"):
            content = content[:-3]
        parsed = json.loads(content.strip())
        if not isinstance(parsed, dict):
            # A list/scalar cannot carry named fields.
            return {}

        result: Dict[str, Any] = {}

        raw_year = parsed.get("birth_year")
        year = None
        if isinstance(raw_year, bool):
            # bool is a subclass of int; never treat it as a year.
            year = None
        elif isinstance(raw_year, int):
            year = raw_year
        elif isinstance(raw_year, str) and raw_year.strip().isdecimal():
            # isdecimal() only admits characters int() can parse.
            year = int(raw_year.strip())
        if year is not None:
            if 0 <= year < 100:
                # Two-digit shorthand: 50-99 -> 19xx, 00-49 -> 20xx.
                year = 1900 + year if year >= 50 else 2000 + year
            if 1900 <= year <= 2100:
                result["birth_year"] = year

        for field in ("birth_place", "grew_up_place", "occupation"):
            value = parsed.get(field)
            if value:
                result[field] = str(value)
        return result
    except json.JSONDecodeError as e:
        # Malformed model output is an expected failure; no traceback needed.
        logger.error("提取资料信息失败: %s", e)
        return {}
    except Exception as e:
        logger.error("提取资料信息失败: %s", e, exc_info=True)
        return {}
|
||||
|
||||
async def generate_profile_followup(
    self,
    conversation_id: str,
    user_message: str,
    missing_fields: List[str],
    filled_fields: Dict[str, str],
    nickname: str = "",
    is_from_voice: bool = False,
    voice_session_id: str | None = None,
    user_message_timestamp: datetime | None = None,
) -> List[str]:
    """Reply to the user during profile collection and persist the turn.

    Args:
        conversation_id: Key under which history is read and written.
        user_message: The user's latest utterance.
        missing_fields: Profile fields still to be collected.
        filled_fields: Profile fields already known.
        nickname: Optional name passed to the prompt builder.
        is_from_voice: When true the user message is stored with message
            type ``"audio"`` instead of ``"text"``.
        voice_session_id: Stored alongside the user message.
        user_message_timestamp: Stored alongside the user message.

    Returns:
        At most three message bubbles split on ``[SPLIT]``. A canned reply
        is returned when no LLM is configured or when generation fails.
    """
    if not self.llm:
        return ["谢谢!还能告诉我更多吗?"]
    try:
        prompt = get_profile_followup_prompt(missing_fields, filled_fields, user_message, nickname)
        history_messages = await self._get_history_messages(conversation_id)
        history_string = self._format_history_string(history_messages)
        full_prompt = f"{prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
        response = await self.llm.ainvoke(full_prompt)
        response_text = response.content if hasattr(response, "content") else str(response)
        human_msg_type = "audio" if is_from_voice else "text"
        # The turn is persisted only after the model call succeeded.
        await self._save_message(
            conversation_id,
            "human",
            user_message,
            message_type=human_msg_type,
            voice_session_id=voice_session_id,
            timestamp=user_message_timestamp,
        )
        await self._save_message(conversation_id, "ai", response_text)
        messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
        return messages[:3] if messages else [response_text]
    except Exception as e:
        # Keep the traceback, consistent with generate_opening_message.
        logger.error("生成资料跟进回复失败: %s", e, exc_info=True)
        return ["谢谢分享!能再告诉我一些吗?"]
|
||||
|
||||
def _detect_user_stage(self, user_message: str) -> str:
|
||||
message = user_message.lower()
|
||||
stage_keywords = {
|
||||
"childhood": ["童年", "小时候", "出生", "家乡", "小镇", "爸妈", "父亲", "母亲", "爷爷", "奶奶", "外公", "外婆", "幼儿园"],
|
||||
"education": ["上学", "学校", "老师", "同学", "教育", "大学", "高中", "初中", "小学", "考试", "毕业", "读书", "高考", "课堂"],
|
||||
"career": ["工作", "职业", "事业", "公司", "同事", "创业", "升职", "跳槽", "老板", "行业", "项目", "加班", "薪水", "面试"],
|
||||
"family": ["伴侣", "孩子", "家庭", "家人", "结婚", "爱人", "老婆", "老公", "丈夫", "妻子", "儿子", "女儿", "婚礼", "恋爱"],
|
||||
"belief": ["信念", "价值观", "座右铭", "坚持", "原则", "信仰", "意义", "感悟", "遗憾", "骄傲"],
|
||||
}
|
||||
for stage, keywords in stage_keywords.items():
|
||||
if any(word in message for word in keywords):
|
||||
return stage
|
||||
return ""
|
||||
|
||||
async def generate_response_with_state(
    self,
    conversation_id: str,
    user_message: str,
    memoir_state: MemoirStateSchema,
    user_profile_context: str = "",
    is_from_voice: bool = False,
    voice_session_id: str | None = None,
    user_message_timestamp: datetime | None = None,
) -> List[str]:
    """Generate a state-guided interview reply and persist the turn.

    The prompt is built from the memoir state (current stage, empty and
    filled slots, coverage of all stages), a keyword-based guess of the
    stage the user is talking about, and turn counters derived from the
    stored history.

    Args:
        conversation_id: Key under which history is read and written.
        user_message: The user's latest utterance.
        memoir_state: Current memoir state.
        user_profile_context: Optional profile text passed to the prompt
            builder.
        is_from_voice: When true the user message is stored with message
            type ``"audio"`` instead of ``"text"``.
        voice_session_id: Stored alongside the user message.
        user_message_timestamp: Stored alongside the user message.

    Returns:
        At most three message bubbles split on ``[SPLIT]``. A one-element
        list with a notice is returned when no LLM is configured, and a
        one-element list with an apology containing the error text when
        generation fails.
    """
    if not self.llm:
        return ["抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"]
    try:
        empty_slots = memoir_state.empty_slots_for_current_stage()
        # Only slots of the current stage that already carry a snippet.
        filled_slots = {
            key: value.snippet
            for key, value in memoir_state.slots.get(memoir_state.current_stage, {}).items()
            if value.snippet
        }
        detected_user_stage = self._detect_user_stage(user_message)
        history_messages = await self._get_history_messages(conversation_id)
        # One turn is counted per pair of stored messages.
        conversation_turn = len(history_messages) // 2
        same_topic_turns = self._estimate_same_topic_turns(history_messages, filled_slots)
        all_stages_coverage = memoir_state.all_stages_coverage()
        system_prompt = get_guided_conversation_prompt(
            current_stage=memoir_state.current_stage,
            empty_slots=empty_slots,
            filled_slots=filled_slots,
            user_message=user_message,
            conversation_turn=conversation_turn,
            same_topic_turns=same_topic_turns,
            all_stages_coverage=all_stages_coverage,
            detected_user_stage=detected_user_stage,
            user_profile_context=user_profile_context,
        )
        history_string = self._format_history_string(history_messages)
        full_prompt = f"{system_prompt}\n\n{history_string}\n\nHuman: {user_message}\n\nAssistant:"
        response = await self.llm.ainvoke(full_prompt)
        response_text = response.content if hasattr(response, "content") else str(response)
        human_msg_type = "audio" if is_from_voice else "text"
        # The turn is persisted only after the model call succeeded.
        await self._save_message(
            conversation_id,
            "human",
            user_message,
            message_type=human_msg_type,
            voice_session_id=voice_session_id,
            timestamp=user_message_timestamp,
        )
        await self._save_message(conversation_id, "ai", response_text)
        messages = [msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()]
        return messages[:3] if messages else [response_text]
    except Exception as e:
        # Keep the traceback, consistent with generate_opening_message.
        logger.error("生成回应失败: %s", e, exc_info=True)
        # NOTE(review): the raw exception text is returned to the caller;
        # confirm it is acceptable to surface it to end users.
        return [f"抱歉,生成回应时出现错误: {str(e)}"]
|
||||
|
||||
def _estimate_same_topic_turns(self, history_messages: List[Any], current_filled_slots: dict) -> int:
|
||||
if len(history_messages) < 4:
|
||||
return len(history_messages) // 2
|
||||
recent_messages = history_messages[-6:]
|
||||
keywords_per_turn = []
|
||||
for i in range(0, len(recent_messages), 2):
|
||||
if i + 1 < len(recent_messages):
|
||||
human_msg = (
|
||||
recent_messages[i].content
|
||||
if hasattr(recent_messages[i], "content")
|
||||
else str(recent_messages[i])
|
||||
)
|
||||
ai_msg = (
|
||||
recent_messages[i + 1].content
|
||||
if hasattr(recent_messages[i + 1], "content")
|
||||
else str(recent_messages[i + 1])
|
||||
)
|
||||
keywords_per_turn.append((human_msg + ai_msg)[:100])
|
||||
if len(keywords_per_turn) >= 3:
|
||||
return 3
|
||||
return len(keywords_per_turn)
|
||||
|
||||
def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage:
    """Pick a conversation stage from keywords in the user's message.

    Stages are tested in a fixed order and the first one with a keyword
    contained in the lower-cased message is returned.

    Args:
        conversation_id: Not used by the lookup.
        user_message: The user's utterance.

    Returns:
        The matching ``ConversationStage``; ``CHILDHOOD`` when no keyword
        matches.
    """
    text = user_message.lower()

    def mentions(*keywords: str) -> bool:
        # True when at least one keyword occurs anywhere in the message.
        return any(keyword in text for keyword in keywords)

    if mentions("童年", "小时候", "出生", "家庭背景"):
        return ConversationStage.CHILDHOOD
    if mentions("上学", "学校", "老师", "同学", "教育"):
        return ConversationStage.EDUCATION
    if mentions("工作", "职业", "事业", "公司", "同事"):
        return ConversationStage.CAREER
    if mentions("伴侣", "孩子", "家庭", "家人", "结婚"):
        return ConversationStage.FAMILY
    if mentions("信念", "价值观", "座右铭", "坚持", "原则"):
        return ConversationStage.BELIEFS
    if mentions("总结", "回顾", "感激", "希望", "未来"):
        return ConversationStage.SUMMARY
    return ConversationStage.CHILDHOOD
|
||||
|
||||
async def clear_memory(self, conversation_id: str):
    """Delete the stored history of one conversation.

    Args:
        conversation_id: Conversation whose history is cleared through
            ``redis_service.clear_conversation_history``.
    """
    await redis_service.clear_conversation_history(conversation_id)
|
||||
@@ -1,211 +0,0 @@
|
||||
"""
|
||||
回忆录后台处理器:分析对话、更新状态、生成章节、创意标题
|
||||
使用 Celery 进行后台任务处理
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from app.core.logging import get_logger
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
from app.core.dependencies import get_llm_provider
|
||||
from app.core.task_tracker import task_tracker
|
||||
from app.agents.state_schema import MemoirStateSchema
|
||||
from app.agents.prompts.memory_prompts import (
|
||||
get_creative_title_prompt,
|
||||
get_narrative_prompt,
|
||||
get_state_extraction_prompt,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Keyword lists used for substring-based stage detection. Dict order matters:
# the first stage with a matching keyword wins.
STAGE_KEYWORDS = dict(
    childhood=["童年", "小时候", "出生", "家乡", "小镇"],
    education=["上学", "学校", "老师", "同学", "教育", "大学"],
    career=["工作", "职业", "事业", "公司", "同事", "创业"],
    family=["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
    belief=["信念", "价值观", "座右铭", "坚持", "原则"],
)
|
||||
|
||||
|
||||
def _get_langchain_llm():
    """Return the provider's LangChain LLM, or ``None`` when unavailable.

    Returns:
        The ``langchain_llm`` attribute of the object returned by
        ``get_llm_provider()``; ``None`` when the provider has no such
        attribute or when obtaining the provider raises.
    """
    try:
        provider = get_llm_provider()
    except Exception as e:
        # Best effort: callers fall back to non-LLM behaviour on None, but
        # the reason must be visible in the logs instead of vanishing.
        logger.warning("获取 LLM provider 失败: %s", e, exc_info=True)
        return None
    return getattr(provider, "langchain_llm", None)
|
||||
|
||||
|
||||
@dataclass
class AnalysisResult:
    """Outcome of analysing one user message."""

    # Stage key the message was attributed to.
    detected_stage: str
    # Slot name -> text extracted for that slot.
    extracted_slots: Dict[str, str]
    # Emotion label; the analyzer starts from "neutral".
    emotion: str
    # Whether the analysis flagged the message as starting a new chapter.
    is_new_chapter: bool
|
||||
|
||||
|
||||
class ContentAnalyzer:
    """Analyse a user message into stage, slot values, emotion and chapter flag."""

    def __init__(self) -> None:
        # May be None; analysis then relies on the keyword/slot fallback only.
        self.llm = _get_langchain_llm()

    def _detect_stage(self, user_message: str, fallback_stage: str) -> str:
        """Return the first stage whose keywords occur in the message.

        Stages are tried in ``STAGE_KEYWORDS`` order; ``fallback_stage`` is
        returned when none matches.
        """
        text = user_message.lower()
        return next(
            (
                stage
                for stage, keywords in STAGE_KEYWORDS.items()
                if any(keyword in text for keyword in keywords)
            ),
            fallback_stage,
        )

    def _fallback_slots(
        self, state: MemoirStateSchema, stage: str, user_message: str
    ) -> Dict[str, str]:
        """Put the message into the first empty slot of ``stage``.

        Returns:
            A single-entry dict mapping that slot to the stripped message,
            truncated to 200 characters; empty when the stage has no slot
            without a snippet.
        """
        for slot_name, slot in state.slots.get(stage, {}).items():
            if slot.snippet:
                continue
            return {slot_name: user_message.strip()[:200]}
        return {}

    async def analyze_message(
        self, user_message: str, current_state: MemoirStateSchema
    ) -> AnalysisResult:
        """Analyse ``user_message`` against ``current_state``.

        With an LLM, its JSON reply supplies the stage, slots, emotion and
        new-chapter flag. Without an LLM, or when the call or the parsing
        fails, the stage comes from keyword detection (or whatever was
        parsed before the failure) and the slots from ``_fallback_slots``.
        """
        stage = self._detect_stage(user_message, current_state.current_stage)
        slots: Dict[str, str] = {}
        emotion = "neutral"
        new_chapter = False

        use_fallback = not self.llm
        if not use_fallback:
            try:
                prompt = get_state_extraction_prompt(
                    user_message=user_message,
                    current_stage=current_state.current_stage,
                    stage_slots=current_state.slots.get(stage, {}),
                )
                reply = await self.llm.ainvoke(prompt)
                payload = json.loads(reply.content.strip())
                stage = payload.get("detected_stage", stage)
                slots = payload.get("slots", {}) or {}
                emotion = payload.get("emotion", emotion)
                new_chapter = bool(payload.get("is_new_chapter", new_chapter))
            except json.JSONDecodeError:
                # Unparseable reply: silently degrade to the slot fallback.
                use_fallback = True
            except Exception as e:
                logger.error("分析消息失败: %s", e)
                use_fallback = True

        if use_fallback:
            slots = self._fallback_slots(current_state, stage, user_message)

        return AnalysisResult(
            detected_stage=stage,
            extracted_slots=slots,
            emotion=emotion,
            is_new_chapter=new_chapter,
        )
|
||||
|
||||
|
||||
class MemoirGenerator:
    """Produce chapter titles and narrative text, with plain-text fallbacks."""

    def __init__(self) -> None:
        # May be None; every method then returns its non-LLM fallback.
        self.llm = _get_langchain_llm()

    async def generate_chapter_title(
        self, stage: str, slots: Dict[str, str], emotion: str
    ) -> str:
        """Return a title for the chapter of ``stage``.

        Returns:
            The model's reply with surrounding whitespace and double quotes
            removed, or ``"<stage> 回忆"`` when no LLM is configured or the
            call fails.
        """
        default_title = f"{stage} 回忆"
        if not self.llm:
            return default_title
        try:
            reply = await self.llm.ainvoke(
                get_creative_title_prompt(stage=stage, emotion=emotion, slots=slots)
            )
            return reply.content.strip().strip('"')
        except Exception as e:
            logger.error("生成标题失败: %s", e)
            return default_title

    async def generate_narrative(
        self,
        stage: str,
        slots: Dict[str, str],
        new_content: str,
        existing_content: str,
    ) -> str:
        """Return narrative text combining existing and new content.

        Returns:
            The model's stripped reply. When no LLM is configured or the
            call fails, the new content is appended to the existing content
            after a blank line (or returned alone when nothing exists yet).
        """
        # Plain concatenation used whenever the model cannot be consulted.
        if existing_content:
            plain_merge = f"{existing_content}\n\n{new_content}"
        else:
            plain_merge = new_content

        if not self.llm:
            return plain_merge
        try:
            reply = await self.llm.ainvoke(
                get_narrative_prompt(
                    stage=stage,
                    slots=slots,
                    new_content=new_content,
                    existing_content=existing_content,
                )
            )
            return reply.content.strip()
        except Exception as e:
            logger.error("生成叙事失败: %s", e)
            return plain_merge
|
||||
|
||||
|
||||
class BackgroundTaskRunner:
    """Debounce per-user segment ids and submit them as one Celery task.

    Segment ids queued for a user within ``debounce_seconds`` of each other
    are collected and handed to ``process_memoir_segments`` in one batch.
    """

    def __init__(self, debounce_seconds: int = 5) -> None:
        self.debounce_seconds = debounce_seconds
        # user_id -> segment ids waiting to be submitted.
        self._pending: Dict[str, List[str]] = {}
        # user_id -> asyncio task that is still inside its debounce sleep.
        self._timers: Dict[str, object] = {}
        self.analyzer = ContentAnalyzer()
        self.generator = MemoirGenerator()

    async def _submit_task(self, user_id: str, segment_ids: List[str]) -> str | None:
        """Submit the Celery task and register it with the task tracker.

        Returns:
            The Celery task id, or ``None`` when submission fails.
        """
        try:
            # Imported lazily — presumably to avoid a circular import with
            # the tasks module; TODO confirm.
            from app.tasks.memoir_tasks import process_memoir_segments

            result = process_memoir_segments.delay(user_id, segment_ids)
            task_id = result.id
            await task_tracker.add_task(user_id, task_id, "memoir")
            logger.info(
                "已提交 Celery 任务: user_id=%s, task_id=%s, segments=%s",
                user_id,
                task_id,
                len(segment_ids),
            )
            return task_id
        except Exception as e:
            logger.error("提交 Celery 任务失败: %s", e)
            return None

    async def queue_message(self, user_id: str, segment_id: str) -> None:
        """Queue a segment id and (re)start the user's debounce timer."""
        import asyncio

        self._pending.setdefault(user_id, []).append(segment_id)
        previous = self._timers.get(user_id)
        if previous is not None:
            # Restart the debounce window for this user.
            previous.cancel()

        async def delayed_submit():
            try:
                await asyncio.sleep(self.debounce_seconds)
                # The debounce window is over: detach this timer before
                # submitting. This keeps _timers from accumulating finished
                # tasks and prevents a later queue_message/flush_pending
                # from cancelling a submission that is already in flight.
                if self._timers.get(user_id) is asyncio.current_task():
                    del self._timers[user_id]
                segment_ids = self._pending.pop(user_id, [])
                if segment_ids:
                    await self._submit_task(user_id, segment_ids)
            except asyncio.CancelledError:
                # Superseded by a newer timer or by flush_pending; the
                # pending ids stay queued for whoever cancelled us.
                pass
            except Exception as e:
                logger.error("延迟提交任务失败: %s", e)

        self._timers[user_id] = asyncio.create_task(delayed_submit())

    async def flush_pending(self, user_id: str) -> str | None:
        """Submit the user's pending segment ids immediately.

        Returns:
            The Celery task id, or ``None`` when nothing was pending or
            submission failed.
        """
        timer = self._timers.pop(user_id, None)
        if timer is not None:
            timer.cancel()
        segment_ids = self._pending.pop(user_id, [])
        if segment_ids:
            return await self._submit_task(user_id, segment_ids)
        return None
|
||||
@@ -1,130 +0,0 @@
|
||||
"""
|
||||
回忆录整理 Agent:基于传记结构,将口语改写为书面语,归类到章节
|
||||
支持异步调用
|
||||
"""
|
||||
import json
|
||||
from app.core.logging import get_logger
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from app.core.dependencies import get_llm_provider
|
||||
|
||||
from app.agents.prompts import (
|
||||
get_chapter_classification_prompt,
|
||||
get_text_rewrite_prompt,
|
||||
inject_image_placeholder_template,
|
||||
CHAPTER_CATEGORIES,
|
||||
STAGE_TO_ORDER,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _get_langchain_llm():
    """Return the provider's LangChain LLM, or ``None`` when unavailable.

    Returns:
        The ``langchain_llm`` attribute of the object returned by
        ``get_llm_provider()``; ``None`` when the provider has no such
        attribute or when obtaining the provider raises.
    """
    try:
        provider = get_llm_provider()
    except Exception as e:
        # Best effort: callers fall back to non-LLM behaviour on None, but
        # the reason must be visible in the logs instead of vanishing.
        logger.warning("获取 LLM provider 失败: %s", e, exc_info=True)
        return None
    return getattr(provider, "langchain_llm", None)
|
||||
|
||||
|
||||
class MemoryAgent:
    """Memoir organising agent (async): classifies segments and rewrites them into chapters."""

    def __init__(self):
        # May be None; every method then returns its non-LLM fallback.
        self.llm = _get_langchain_llm()

    async def classify_chapter(self, segments_text: str) -> str:
        """Return the chapter category for ``segments_text``.

        Returns:
            The model's answer (stripped, lower-cased) when it is a key of
            ``CHAPTER_CATEGORIES``; otherwise ``"childhood"``, which is also
            used when no LLM is configured or the call fails.
        """
        default_category = "childhood"
        if not self.llm:
            return default_category
        try:
            reply = await self.llm.ainvoke(get_chapter_classification_prompt(segments_text))
            reply_text = reply.content if hasattr(reply, "content") else str(reply)
            label = reply_text.strip().lower()
            if label in CHAPTER_CATEGORIES:
                return label
        except Exception as e:
            logger.error("分类章节失败: %s", e)
        return default_category

    async def rewrite_to_literary(
        self,
        segments_text: str,
        chapter_category: str,
        existing_content: Optional[str] = None,
    ) -> Dict:
        """Rewrite spoken-style text into chapter prose.

        Returns:
            The JSON object parsed from the model's reply, with its
            ``content`` passed through ``inject_image_placeholder_template``.
            When the reply is not valid JSON, a default chapter dict whose
            content is the raw reply (also passed through the template
            injector). When no LLM is configured or anything else fails, a
            default chapter dict whose content is the untouched input text.
        """

        def default_chapter(body: str) -> Dict:
            # Chapter skeleton used whenever no structured reply is available.
            return {
                "title": CHAPTER_CATEGORIES.get(chapter_category, "章节"),
                "content": body,
                "summary": "",
                "image_suggestions": [],
            }

        if not self.llm:
            return default_chapter(segments_text)
        try:
            prompt = get_text_rewrite_prompt(
                segments_text, chapter_category, existing_content or ""
            )
            response = await self.llm.ainvoke(prompt)
            payload = response.content if hasattr(response, "content") else str(response)
            payload = payload.strip()
            # Peel off a Markdown code fence around the JSON, if present.
            if payload.startswith("```json"):
                payload = payload[7:]
            if payload.startswith("```"):
                payload = payload[3:]
            if payload.endswith("```"):
                payload = payload[:-3]
            chapter = json.loads(payload.strip())
            chapter["content"] = inject_image_placeholder_template(
                chapter.get("content") or ""
            )
            return chapter
        except json.JSONDecodeError:
            # Not JSON: treat the whole raw reply as the chapter text.
            raw = response.content if hasattr(response, "content") else str(response)
            return default_chapter(inject_image_placeholder_template(raw))
        except Exception as e:
            logger.error("改写文本失败: %s", e)
            return default_chapter(segments_text)

    async def process_segments(
        self,
        segments: List[Dict],
        existing_chapters: Optional[Dict[str, Dict]] = None,
    ) -> Dict[str, Dict]:
        """Classify segments, rewrite them per category and merge into chapters.

        Args:
            segments: Dicts whose ``transcript_text`` holds the text;
                segments without text are ignored.
            existing_chapters: Chapters keyed by category; their ``content``
                is handed to the rewrite step as existing content.

        Returns:
            A shallow copy of ``existing_chapters`` in which every category
            that received new text is replaced by the rewritten chapter.
        """
        chapters_before = {} if existing_chapters is None else existing_chapters

        # Group segment texts by the category the classifier assigns.
        grouped: Dict[str, List[str]] = {}
        for segment in segments:
            text = segment.get("transcript_text", "")
            if text:
                category = await self.classify_chapter(text)
                grouped.setdefault(category, []).append(text)

        chapters_after = chapters_before.copy()
        for category, texts in grouped.items():
            previous_text = chapters_before.get(category, {}).get("content", "")
            rewritten = await self.rewrite_to_literary(
                "\n\n".join(texts), category, previous_text
            )
            chapters_after[category] = {
                "title": rewritten.get("title", CHAPTER_CATEGORIES.get(category, "章节")),
                "content": rewritten.get("content", ""),
                "summary": rewritten.get("summary", ""),
                "image_suggestions": rewritten.get("image_suggestions", []),
                "category": category,
                "order_index": STAGE_TO_ORDER.get(category, 999),
            }
        return chapters_after
|
||||
Reference in New Issue
Block a user