修复一些已知问题
This commit is contained in:
14
api/app/agents/chat/agent_turn.py
Normal file
14
api/app/agents/chat/agent_turn.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""一轮 AI 对话输出:分段文案 + 是否整轮跳过 TTS(如失败兜底)。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentChatTurn:
|
||||
"""与 WebSocket pipeline 对齐:messages 为气泡分段;skip_tts 为 True 时不合成语音。"""
|
||||
|
||||
messages: List[str]
|
||||
skip_tts: bool = False
|
||||
@@ -6,6 +6,7 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.agents.chat.agent_turn import AgentChatTurn
|
||||
from app.agents.chat.orchestrator import ChatOrchestrator
|
||||
from app.agents.chat.prompts_conversation import ConversationStage
|
||||
from app.agents.state_schema import MemoirStateSchema
|
||||
@@ -77,7 +78,7 @@ class ConversationAgent:
|
||||
voice_session_id: str | None = None,
|
||||
user_message_timestamp: datetime | None = None,
|
||||
audio_duration_seconds: int | None = None,
|
||||
) -> List[str]:
|
||||
) -> AgentChatTurn:
|
||||
"""委托 ChatOrchestrator/InterviewAgent 生成访谈回复"""
|
||||
return await self._orchestrator.generate_response_with_state(
|
||||
conversation_id=conversation_id,
|
||||
@@ -116,13 +117,13 @@ class ConversationAgent:
|
||||
state = default_state()
|
||||
state.current_stage = (current_stage or ConversationStage.CHILDHOOD).value
|
||||
state.covered_stages = covered_topics or []
|
||||
responses = await self._orchestrator.generate_response_with_state(
|
||||
turn = await self._orchestrator.generate_response_with_state(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
memoir_state=state,
|
||||
user_profile_context="",
|
||||
)
|
||||
return responses[0] if responses else ""
|
||||
return turn.messages[0] if turn.messages else ""
|
||||
|
||||
def detect_stage(
|
||||
self, conversation_id: str, user_message: str
|
||||
|
||||
@@ -5,6 +5,7 @@ InterviewAgent:正式访谈 Specialist
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
from app.agents.chat.agent_turn import AgentChatTurn
|
||||
from app.core.dependencies import get_llm_provider
|
||||
from app.core.logging import get_logger
|
||||
|
||||
@@ -18,6 +19,9 @@ from app.agents.state_schema import MemoirStateSchema
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# LLM 不可用或调用失败时对用户展示(不暴露异常细节、不触发 TTS)
|
||||
_FALLBACK_REPLY = "刚才网络不太稳,没接上。你可以再说一遍,或稍后再试。"
|
||||
|
||||
|
||||
def _get_langchain_llm():
|
||||
try:
|
||||
@@ -149,12 +153,11 @@ class InterviewAgent:
|
||||
user_message: str,
|
||||
memoir_state: MemoirStateSchema,
|
||||
user_profile_context: str = "",
|
||||
) -> List[str]:
|
||||
) -> AgentChatTurn:
|
||||
"""生成状态感知的访谈回复,不持久化(由 Orchestrator 负责)"""
|
||||
if not self.llm:
|
||||
return [
|
||||
"抱歉,LLM 服务未配置。请设置 DEEPSEEK_API_KEY 或 LLM_API_KEY 环境变量。"
|
||||
]
|
||||
logger.warning("InterviewAgent: LLM 未配置,返回兜底文案")
|
||||
return AgentChatTurn(messages=[_FALLBACK_REPLY], skip_tts=True)
|
||||
try:
|
||||
empty_slots = memoir_state.empty_slots_for_current_stage()
|
||||
filled_slots = {
|
||||
@@ -191,10 +194,11 @@ class InterviewAgent:
|
||||
messages = [
|
||||
msg.strip() for msg in response_text.split("[SPLIT]") if msg.strip()
|
||||
]
|
||||
return messages[:3] if messages else [response_text]
|
||||
out = messages[:3] if messages else [response_text]
|
||||
return AgentChatTurn(messages=out, skip_tts=False)
|
||||
except Exception as e:
|
||||
logger.error("生成回应失败: %s", e)
|
||||
return [f"抱歉,生成回应时出现错误: {str(e)}"]
|
||||
logger.error("生成回应失败: %s", e, exc_info=True)
|
||||
return AgentChatTurn(messages=[_FALLBACK_REPLY], skip_tts=True)
|
||||
|
||||
async def generate_opening_message(
|
||||
self,
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.agents.chat.agent_turn import AgentChatTurn
|
||||
from app.agents.chat.helpers import save_message
|
||||
from app.agents.chat.interview_agent import InterviewAgent
|
||||
from app.agents.chat.profile_agent import ProfileAgent
|
||||
@@ -20,6 +21,10 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
_UNAUTH_TURN = AgentChatTurn(
|
||||
messages=["暂时没法继续对话,请先登录后再试。"], skip_tts=True
|
||||
)
|
||||
|
||||
|
||||
class ChatOrchestrator:
|
||||
"""
|
||||
@@ -45,9 +50,9 @@ class ChatOrchestrator:
|
||||
get_filled_profile_fields_fn,
|
||||
user_message_timestamp: Optional[datetime] = None,
|
||||
audio_duration_seconds: Optional[int] = None,
|
||||
) -> List[str]:
|
||||
) -> AgentChatTurn:
|
||||
"""
|
||||
处理用户消息,返回 AI 回复列表。
|
||||
处理用户消息,返回 AI 回复(分段 + 是否跳过 TTS)。
|
||||
根据 missing_fields 路由到 ProfileAgent 或 InterviewAgent,
|
||||
统一写入 Redis。
|
||||
"""
|
||||
@@ -81,14 +86,14 @@ class ChatOrchestrator:
|
||||
user_message_timestamp=user_message_timestamp,
|
||||
audio_duration_seconds=audio_duration_seconds,
|
||||
)
|
||||
return responses
|
||||
return AgentChatTurn(messages=responses, skip_tts=False)
|
||||
except Exception as e:
|
||||
logger.error(f"资料收集处理失败: {e}", exc_info=True)
|
||||
|
||||
# --- 正式访谈模式 ---
|
||||
user_id = user.id if user else None
|
||||
if not user_id:
|
||||
return ["抱歉,无法识别用户。"]
|
||||
return _UNAUTH_TURN
|
||||
|
||||
state = await get_or_create_state(user_id, db)
|
||||
if conversation and conversation.conversation_stage != state.current_stage:
|
||||
@@ -106,7 +111,7 @@ class ChatOrchestrator:
|
||||
occupation=user.occupation,
|
||||
)
|
||||
|
||||
responses = await self.interview_agent.generate_response_with_state(
|
||||
turn = await self.interview_agent.generate_response_with_state(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
memoir_state=state,
|
||||
@@ -115,13 +120,13 @@ class ChatOrchestrator:
|
||||
await self._save_messages(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
response_text="\n\n".join(responses),
|
||||
response_text="\n\n".join(turn.messages),
|
||||
is_from_voice=is_from_voice,
|
||||
voice_session_id=voice_session_id,
|
||||
user_message_timestamp=user_message_timestamp,
|
||||
audio_duration_seconds=audio_duration_seconds,
|
||||
)
|
||||
return responses
|
||||
return turn
|
||||
|
||||
async def _save_messages(
|
||||
self,
|
||||
@@ -222,15 +227,15 @@ class ChatOrchestrator:
|
||||
voice_session_id: str | None = None,
|
||||
user_message_timestamp: datetime | None = None,
|
||||
audio_duration_seconds: int | None = None,
|
||||
) -> List[str]:
|
||||
) -> AgentChatTurn:
|
||||
"""委托 InterviewAgent 生成访谈回复,并写入 Redis"""
|
||||
responses = await self.interview_agent.generate_response_with_state(
|
||||
turn = await self.interview_agent.generate_response_with_state(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
memoir_state=memoir_state,
|
||||
user_profile_context=user_profile_context,
|
||||
)
|
||||
response_text = "\n\n".join(responses)
|
||||
response_text = "\n\n".join(turn.messages)
|
||||
await self._save_messages(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
@@ -240,7 +245,7 @@ class ChatOrchestrator:
|
||||
user_message_timestamp=user_message_timestamp,
|
||||
audio_duration_seconds=audio_duration_seconds,
|
||||
)
|
||||
return responses
|
||||
return turn
|
||||
|
||||
def detect_user_stage(self, user_message: str) -> str:
|
||||
"""委托 InterviewAgent 检测用户阶段"""
|
||||
|
||||
Reference in New Issue
Block a user