140 lines
5.2 KiB
Python
140 lines
5.2 KiB
Python
"""
Conversation agent: a facade that delegates internally to ChatOrchestrator + ProfileAgent + InterviewAgent.
Keeps the original public API so that existing callers such as the router remain compatible.
"""
|
||
from datetime import datetime
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from app.agents.chat.orchestrator import ChatOrchestrator
|
||
from app.agents.chat.prompts_conversation import ConversationStage
|
||
from app.agents.state_schema import MemoirStateSchema
|
||
from app.core.redis import redis_service
|
||
|
||
|
||
class ConversationAgent:
|
||
"""对话 Agent Facade,委托 ChatOrchestrator 实现多 Agent 协同"""
|
||
|
||
def __init__(self):
|
||
self._orchestrator = ChatOrchestrator()
|
||
|
||
async def extract_profile_from_message(
|
||
self,
|
||
user_message: str,
|
||
missing_fields: List[str],
|
||
conversation_id: Optional[str] = None,
|
||
) -> Dict[str, Any]:
|
||
"""委托 ChatOrchestrator/ProfileAgent 提取资料"""
|
||
return await self._orchestrator.extract_profile_from_message(
|
||
user_message, missing_fields, conversation_id=conversation_id
|
||
)
|
||
|
||
async def generate_profile_followup(
|
||
self,
|
||
conversation_id: str,
|
||
user_message: str,
|
||
missing_fields: List[str],
|
||
filled_fields: Dict[str, str],
|
||
nickname: str = "",
|
||
is_from_voice: bool = False,
|
||
voice_session_id: str | None = None,
|
||
user_message_timestamp: datetime | None = None,
|
||
) -> List[str]:
|
||
"""委托 ChatOrchestrator/ProfileAgent 生成资料追问"""
|
||
return await self._orchestrator.generate_profile_followup(
|
||
conversation_id=conversation_id,
|
||
user_message=user_message,
|
||
missing_fields=missing_fields,
|
||
filled_fields=filled_fields,
|
||
nickname=nickname,
|
||
is_from_voice=is_from_voice,
|
||
voice_session_id=voice_session_id,
|
||
user_message_timestamp=user_message_timestamp,
|
||
)
|
||
|
||
async def generate_profile_greeting(
|
||
self,
|
||
conversation_id: str,
|
||
missing_fields: List[str],
|
||
nickname: str = "",
|
||
) -> List[str]:
|
||
"""委托 ChatOrchestrator/ProfileAgent 生成资料收集开场白"""
|
||
return await self._orchestrator.generate_profile_greeting(
|
||
conversation_id=conversation_id,
|
||
missing_fields=missing_fields,
|
||
nickname=nickname,
|
||
)
|
||
|
||
async def generate_response_with_state(
|
||
self,
|
||
conversation_id: str,
|
||
user_message: str,
|
||
memoir_state: MemoirStateSchema,
|
||
user_profile_context: str = "",
|
||
is_from_voice: bool = False,
|
||
voice_session_id: str | None = None,
|
||
user_message_timestamp: datetime | None = None,
|
||
) -> List[str]:
|
||
"""委托 ChatOrchestrator/InterviewAgent 生成访谈回复"""
|
||
return await self._orchestrator.generate_response_with_state(
|
||
conversation_id=conversation_id,
|
||
user_message=user_message,
|
||
memoir_state=memoir_state,
|
||
user_profile_context=user_profile_context,
|
||
is_from_voice=is_from_voice,
|
||
voice_session_id=voice_session_id,
|
||
user_message_timestamp=user_message_timestamp,
|
||
)
|
||
|
||
async def generate_opening_message(
|
||
self,
|
||
conversation_id: str,
|
||
memoir_state: MemoirStateSchema,
|
||
user_profile_context: str = "",
|
||
) -> List[str]:
|
||
"""委托 ChatOrchestrator/InterviewAgent 生成开场白"""
|
||
return await self._orchestrator.generate_opening_message(
|
||
conversation_id=conversation_id,
|
||
memoir_state=memoir_state,
|
||
user_profile_context=user_profile_context,
|
||
)
|
||
|
||
async def generate_response(
|
||
self,
|
||
conversation_id: str,
|
||
user_message: str,
|
||
current_stage: Optional[ConversationStage] = None,
|
||
covered_topics: Optional[List[str]] = None,
|
||
) -> str:
|
||
"""兼容旧 API:生成简单回复(无状态感知),委托 InterviewAgent 的等价逻辑"""
|
||
from app.agents.state_schema import default_state
|
||
|
||
state = default_state()
|
||
state.current_stage = (current_stage or ConversationStage.CHILDHOOD).value
|
||
state.covered_stages = covered_topics or []
|
||
responses = await self._orchestrator.generate_response_with_state(
|
||
conversation_id=conversation_id,
|
||
user_message=user_message,
|
||
memoir_state=state,
|
||
user_profile_context="",
|
||
)
|
||
return responses[0] if responses else ""
|
||
|
||
def detect_stage(self, conversation_id: str, user_message: str) -> ConversationStage:
|
||
"""根据关键词检测用户阶段(兼容 API)"""
|
||
detected = self._orchestrator.detect_user_stage(user_message)
|
||
if detected == "childhood":
|
||
return ConversationStage.CHILDHOOD
|
||
if detected == "education":
|
||
return ConversationStage.EDUCATION
|
||
if detected == "career":
|
||
return ConversationStage.CAREER
|
||
if detected == "family":
|
||
return ConversationStage.FAMILY
|
||
if detected == "belief":
|
||
return ConversationStage.BELIEFS
|
||
return ConversationStage.CHILDHOOD
|
||
|
||
async def clear_memory(self, conversation_id: str) -> None:
|
||
"""清除 Redis 中的对话历史"""
|
||
await redis_service.clear_conversation_history(conversation_id)
|