Files
life-echo/api/app/agents/chat/conversation_agent.py
2026-03-20 17:25:42 +08:00

148 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
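Conversation agent (Facade): internally delegates to ChatOrchestrator + ProfileAgent + InterviewAgent,
while keeping the original public API so that callers such as the router remain compatible.

对话 Agent（Facade）：内部委托 ChatOrchestrator + ProfileAgent + InterviewAgent；
保留原有对外 API，供 router 等调用方兼容使用。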
"""
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.agents.chat.agent_turn import AgentChatTurn
from app.agents.chat.orchestrator import ChatOrchestrator
from app.agents.chat.prompts_conversation import ConversationStage
from app.agents.state_schema import MemoirStateSchema
from app.core.redis import redis_service
class ConversationAgent:
    """Conversation agent facade.

    Preserves the legacy public API used by routers and other callers, and
    forwards every call to a single ``ChatOrchestrator`` instance, which
    coordinates the profile and interview sub-agents.
    """

    def __init__(self):
        # One orchestrator per facade instance; every method delegates to it.
        self._orchestrator = ChatOrchestrator()

    async def extract_profile_from_message(
        self,
        user_message: str,
        missing_fields: List[str],
        conversation_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Extract profile fields from a user message.

        Delegates to ``ChatOrchestrator.extract_profile_from_message``
        (profile agent path).

        Args:
            user_message: Text the user sent.
            missing_fields: Names of profile fields still to be collected.
            conversation_id: Optional conversation identifier, passed through
                unchanged.

        Returns:
            The mapping returned by the orchestrator (presumably field name ->
            extracted value; confirm against ``ChatOrchestrator``).
        """
        return await self._orchestrator.extract_profile_from_message(
            user_message, missing_fields, conversation_id=conversation_id
        )

    async def generate_profile_followup(
        self,
        conversation_id: str,
        user_message: str,
        missing_fields: List[str],
        filled_fields: Dict[str, str],
        nickname: str = "",
        is_from_voice: bool = False,
        voice_session_id: str | None = None,
        user_message_timestamp: datetime | None = None,
        audio_duration_seconds: int | None = None,
    ) -> List[str]:
        """Generate follow-up questions for collecting missing profile fields.

        Delegates to ``ChatOrchestrator.generate_profile_followup``; every
        argument is forwarded by keyword without modification.

        Args:
            conversation_id: Conversation identifier.
            user_message: The user's latest message.
            missing_fields: Profile fields not yet collected.
            filled_fields: Profile fields already collected.
            nickname: Display name for the user (may be empty).
            is_from_voice: Whether the message originated from voice input.
            voice_session_id: Voice session identifier, if any.
            user_message_timestamp: When the user message was sent, if known.
            audio_duration_seconds: Length of the voice recording, if any.

        Returns:
            The list of reply messages produced by the orchestrator.
        """
        return await self._orchestrator.generate_profile_followup(
            conversation_id=conversation_id,
            user_message=user_message,
            missing_fields=missing_fields,
            filled_fields=filled_fields,
            nickname=nickname,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            user_message_timestamp=user_message_timestamp,
            audio_duration_seconds=audio_duration_seconds,
        )

    async def generate_profile_greeting(
        self,
        conversation_id: str,
        missing_fields: List[str],
        nickname: str = "",
    ) -> List[str]:
        """Generate the opening messages for the profile-collection phase.

        Delegates to ``ChatOrchestrator.generate_profile_greeting``.

        Args:
            conversation_id: Conversation identifier.
            missing_fields: Profile fields not yet collected.
            nickname: Display name for the user (may be empty).

        Returns:
            The list of greeting messages produced by the orchestrator.
        """
        return await self._orchestrator.generate_profile_greeting(
            conversation_id=conversation_id,
            missing_fields=missing_fields,
            nickname=nickname,
        )

    async def generate_response_with_state(
        self,
        conversation_id: str,
        user_message: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
        is_from_voice: bool = False,
        voice_session_id: str | None = None,
        user_message_timestamp: datetime | None = None,
        audio_duration_seconds: int | None = None,
    ) -> AgentChatTurn:
        """Generate an interview reply using the supplied memoir state.

        Delegates to ``ChatOrchestrator.generate_response_with_state``
        (interview agent path); every argument is forwarded by keyword.

        Args:
            conversation_id: Conversation identifier.
            user_message: The user's latest message.
            memoir_state: Current memoir/interview state.
            user_profile_context: Optional profile text to give the model.
            is_from_voice: Whether the message originated from voice input.
            voice_session_id: Voice session identifier, if any.
            user_message_timestamp: When the user message was sent, if known.
            audio_duration_seconds: Length of the voice recording, if any.

        Returns:
            The ``AgentChatTurn`` produced by the orchestrator.
        """
        return await self._orchestrator.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=memoir_state,
            user_profile_context=user_profile_context,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            user_message_timestamp=user_message_timestamp,
            audio_duration_seconds=audio_duration_seconds,
        )

    async def generate_opening_message(
        self,
        conversation_id: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
    ) -> List[str]:
        """Generate the opening messages for the interview phase.

        Delegates to ``ChatOrchestrator.generate_opening_message``.

        Args:
            conversation_id: Conversation identifier.
            memoir_state: Current memoir/interview state.
            user_profile_context: Optional profile text to give the model.

        Returns:
            The list of opening messages produced by the orchestrator.
        """
        return await self._orchestrator.generate_opening_message(
            conversation_id=conversation_id,
            memoir_state=memoir_state,
            user_profile_context=user_profile_context,
        )

    async def generate_response(
        self,
        conversation_id: str,
        user_message: str,
        current_stage: Optional[ConversationStage] = None,
        covered_topics: Optional[List[str]] = None,
    ) -> str:
        """Legacy API: produce a single reply without caller-managed state.

        Builds a fresh default memoir state, seeds it with the given stage
        and covered topics, and runs it through the same orchestrator path
        as ``generate_response_with_state`` (with an empty profile context).

        Args:
            conversation_id: Conversation identifier.
            user_message: The user's latest message.
            current_stage: Stage to start from; ``None`` means
                ``ConversationStage.CHILDHOOD``.
            covered_topics: Stages already covered; stored on the state as
                ``covered_stages``. ``None`` or empty means none.

        Returns:
            The first message of the generated turn, or ``""`` when the turn
            contains no messages.
        """
        from app.agents.state_schema import default_state

        stage = (
            current_stage
            if current_stage is not None
            else ConversationStage.CHILDHOOD
        )
        state = default_state()
        state.current_stage = stage.value
        # Store a copy so that any later mutation of the state's list cannot
        # leak back into the caller's ``covered_topics`` list (and vice versa).
        state.covered_stages = list(covered_topics) if covered_topics else []

        turn = await self._orchestrator.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=state,
            user_profile_context="",
        )
        # This legacy API returns one string, so any additional messages are dropped.
        return turn.messages[0] if turn.messages else ""

    def detect_stage(
        self, conversation_id: str, user_message: str
    ) -> ConversationStage:
        """Legacy API: map the orchestrator's keyword-based stage to an enum.

        ``conversation_id`` is accepted only for backward compatibility and
        is not used.

        Args:
            conversation_id: Unused.
            user_message: The message to classify.

        Returns:
            The ``ConversationStage`` matching the key returned by
            ``ChatOrchestrator.detect_user_stage``; any unrecognized key
            (including ``None``) falls back to ``ConversationStage.CHILDHOOD``.
        """
        detected = self._orchestrator.detect_user_stage(user_message)
        # Note: the orchestrator's key is the singular "belief", while the
        # enum member is BELIEFS.
        stage_by_key = {
            "childhood": ConversationStage.CHILDHOOD,
            "education": ConversationStage.EDUCATION,
            "career": ConversationStage.CAREER,
            "family": ConversationStage.FAMILY,
            "belief": ConversationStage.BELIEFS,
        }
        return stage_by_key.get(detected, ConversationStage.CHILDHOOD)

    async def clear_memory(self, conversation_id: str) -> None:
        """Clear the stored conversation history for ``conversation_id``.

        Delegates to ``redis_service.clear_conversation_history``.
        """
        await redis_service.clear_conversation_history(conversation_id)