Files
life-echo/api/app/agents/chat/orchestrator.py

247 lines
9.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
ChatOrchestrator: orchestration layer of the module that generates AI replies to users.

Responsibilities: routing (Profile vs Interview), invoking the specialist agents,
and centralizing Redis persistence and error handling.
"""
from datetime import datetime
from typing import TYPE_CHECKING, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.chat.helpers import save_message
from app.agents.chat.interview_agent import InterviewAgent
from app.agents.chat.profile_agent import ProfileAgent
from app.agents.state_schema import MemoirStateSchema
from app.core.logging import get_logger
from app.features.memoir.state_service import get_or_create_state
if TYPE_CHECKING:
from app.features.user.models import User
# Module-level logger; ChatOrchestrator uses it to report profile-collection failures.
logger = get_logger(__name__)
class ChatOrchestrator:
    """
    Chat orchestrator: routes each user turn to ``ProfileAgent`` or
    ``InterviewAgent`` depending on how complete the user's profile is.

    Persisting the human/AI messages to Redis is centralized here
    (``_save_messages`` / ``save_message``).
    """

    def __init__(self):
        # One instance of each specialist agent, reused by every call on this orchestrator.
        self.profile_agent = ProfileAgent()
        self.interview_agent = InterviewAgent()

    async def process_user_message(
        self,
        conversation_id: str,
        user_message: str,
        user: Optional["User"],
        conversation,  # used to update conversation_stage
        is_from_voice: bool,
        voice_session_id: Optional[str],
        db: AsyncSession,
        apply_extracted_profile_fn,
        get_missing_profile_fields_fn,
        get_filled_profile_fields_fn,
        user_message_timestamp: Optional[datetime] = None,
    ) -> List[str]:
        """
        Handle one user message and return the AI reply segments.

        Routing:

        * Profile collection -- when ``user`` is set and
          ``get_missing_profile_fields_fn(user)`` is non-empty: extract profile
          fields from the message, apply them, then ask a follow-up about the
          fields that are still missing. If any of these steps raises, the
          error is logged and the turn falls through to interview mode.
        * Interview -- otherwise: sync ``conversation.conversation_stage`` with
          the user's memoir state and let ``InterviewAgent`` reply.

        On either successful path the human message and the joined reply are
        saved exactly once. If the user cannot be identified, a fixed apology
        is returned and nothing is saved.

        Args:
            conversation_id: Conversation the messages are saved under.
            user_message: The user's message text.
            user: The current user, or ``None``.
            conversation: Object whose ``conversation_stage`` is kept in sync
                with the memoir state's ``current_stage``; may be ``None``.
            is_from_voice: Save the human message as ``"audio"`` instead of
                ``"text"``.
            voice_session_id: Forwarded to ``save_message`` for the human message.
            db: Session used for profile updates, memoir-state lookup and the
                stage commit.
            apply_extracted_profile_fn: Awaitable ``(user, extracted, db)``
                callback that applies extracted profile fields.
            get_missing_profile_fields_fn: ``(user) -> missing fields``; a
                non-empty result triggers profile collection.
            get_filled_profile_fields_fn: ``(user) -> filled fields``, passed
                to the follow-up generator.
            user_message_timestamp: Timestamp stored with the human message.

        Returns:
            The AI reply segments; they are persisted joined by a blank line.
        """
        # --- Profile-collection mode ---
        if user:
            missing = get_missing_profile_fields_fn(user)
            if missing:
                try:
                    extracted = await self.profile_agent.extract_profile_from_message(
                        user_message, missing, conversation_id=conversation_id
                    )
                    if extracted:
                        await apply_extracted_profile_fn(user, extracted, db)
                    remaining = get_missing_profile_fields_fn(user)
                    filled = get_filled_profile_fields_fn(user)
                    responses = await self.profile_agent.generate_profile_followup(
                        conversation_id=conversation_id,
                        user_message=user_message,
                        missing_fields=remaining,
                        filled_fields=filled,
                        nickname=user.nickname or "",
                    )
                except Exception as e:
                    # Best-effort: a profile-collection failure must not stall
                    # the conversation, so fall through to interview mode.
                    logger.error(f"资料收集处理失败: {e}", exc_info=True)
                else:
                    # Saved outside the ``try`` on purpose: if persisting failed
                    # inside it, the fallback would generate a second (interview)
                    # reply and save the human message a second time.
                    await self._save_messages(
                        conversation_id=conversation_id,
                        user_message=user_message,
                        response_text="\n\n".join(responses),
                        is_from_voice=is_from_voice,
                        voice_session_id=voice_session_id,
                        user_message_timestamp=user_message_timestamp,
                    )
                    return responses

        # --- Interview mode ---
        user_id = user.id if user else None
        if not user_id:
            return ["抱歉,无法识别用户。"]

        state = await get_or_create_state(user_id, db)
        # Keep the conversation's stage in sync with the memoir state.
        if conversation and conversation.conversation_stage != state.current_stage:
            conversation.conversation_stage = state.current_stage
            await db.commit()

        # Function-level import kept as in the original (possibly to avoid an
        # import cycle -- TODO confirm).
        from app.agents.chat.prompts_profile import format_user_profile_context

        # ``user`` is guaranteed truthy here: the unidentified case returned above.
        user_profile_context = format_user_profile_context(
            birth_year=user.birth_year,
            birth_place=user.birth_place,
            grew_up_place=user.grew_up_place,
            occupation=user.occupation,
        )

        # Identical generate-then-save flow to the public wrapper; reuse it.
        return await self.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=state,
            user_profile_context=user_profile_context,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            user_message_timestamp=user_message_timestamp,
        )

    async def _save_messages(
        self,
        conversation_id: str,
        user_message: str,
        response_text: str,
        is_from_voice: bool = False,
        voice_session_id: Optional[str] = None,
        user_message_timestamp: Optional[datetime] = None,
    ) -> None:
        """
        Save the human message and then the AI reply (to Redis, via ``save_message``).

        The human message is typed ``"audio"`` when ``is_from_voice`` is true,
        otherwise ``"text"``, and carries ``voice_session_id`` and
        ``user_message_timestamp``. The AI message is saved with only its text.
        """
        human_msg_type = "audio" if is_from_voice else "text"
        await save_message(
            conversation_id,
            "human",
            user_message,
            message_type=human_msg_type,
            voice_session_id=voice_session_id,
            timestamp=user_message_timestamp,
        )
        await save_message(conversation_id, "ai", response_text)

    async def extract_profile_from_message(
        self,
        user_message: str,
        missing_fields: List[str],
        conversation_id: Optional[str] = None,
    ):
        """Delegate profile-field extraction to ``ProfileAgent``; nothing is saved."""
        return await self.profile_agent.extract_profile_from_message(
            user_message, missing_fields, conversation_id=conversation_id
        )

    async def generate_profile_followup(
        self,
        conversation_id: str,
        user_message: str,
        missing_fields: List[str],
        filled_fields: dict,
        nickname: str = "",
        is_from_voice: bool = False,
        voice_session_id: Optional[str] = None,
        user_message_timestamp: Optional[datetime] = None,
    ) -> List[str]:
        """
        Ask ``ProfileAgent`` for a follow-up about missing profile fields.

        The human message and the reply segments (joined by a blank line) are
        then saved via ``_save_messages``. Returns the reply segments.
        """
        responses = await self.profile_agent.generate_profile_followup(
            conversation_id=conversation_id,
            user_message=user_message,
            missing_fields=missing_fields,
            filled_fields=filled_fields,
            nickname=nickname,
        )
        response_text = "\n\n".join(responses)
        await self._save_messages(
            conversation_id=conversation_id,
            user_message=user_message,
            response_text=response_text,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            user_message_timestamp=user_message_timestamp,
        )
        return responses

    async def generate_profile_greeting(
        self,
        conversation_id: str,
        missing_fields: List[str],
        nickname: str = "",
    ) -> List[str]:
        """
        Ask ``ProfileAgent`` for the profile-collection opening message.

        There is no human message, so only the AI reply (segments joined by a
        blank line) is saved. Returns the reply segments.
        """
        responses = await self.profile_agent.generate_profile_greeting(
            conversation_id=conversation_id,
            missing_fields=missing_fields,
            nickname=nickname,
        )
        response_text = "\n\n".join(responses)
        await save_message(conversation_id, "ai", response_text)
        return responses

    async def generate_response_with_state(
        self,
        conversation_id: str,
        user_message: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
        is_from_voice: bool = False,
        voice_session_id: Optional[str] = None,
        user_message_timestamp: Optional[datetime] = None,
    ) -> List[str]:
        """
        Ask ``InterviewAgent`` for an interview reply based on ``memoir_state``.

        The human message and the reply segments (joined by a blank line) are
        then saved via ``_save_messages``. Returns the reply segments.
        """
        responses = await self.interview_agent.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=memoir_state,
            user_profile_context=user_profile_context,
        )
        response_text = "\n\n".join(responses)
        await self._save_messages(
            conversation_id=conversation_id,
            user_message=user_message,
            response_text=response_text,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            user_message_timestamp=user_message_timestamp,
        )
        return responses

    def detect_user_stage(self, user_message: str) -> str:
        """Delegate stage detection to ``InterviewAgent._detect_user_stage`` (a private agent method)."""
        return self.interview_agent._detect_user_stage(user_message)

    async def generate_opening_message(
        self,
        conversation_id: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
    ) -> List[str]:
        """
        Ask ``InterviewAgent`` for the interview opening message.

        There is no human message, so only the AI reply (segments joined by a
        blank line) is saved. Returns the reply segments.
        """
        responses = await self.interview_agent.generate_opening_message(
            conversation_id=conversation_id,
            memoir_state=memoir_state,
            user_profile_context=user_profile_context,
        )
        response_text = "\n\n".join(responses)
        await save_message(conversation_id, "ai", response_text)
        return responses