- Drop interview_reply_length and utterance_substance; always run stage LLM and memory retrieval when enabled; trim Settings fields and .env.example. - Replace guided/opening prompts with compact fact blocks plus unified behavior guidance; slim background_voice and persona to tone hints. - InterviewAgent uses fixed chat_interview max_tokens/chars/segments. Also includes stacked work: profile followup/extract path, evaluation rubric and judge schema updates, transcript SPLIT handling in execution service, user export markdown split tests, and golden case fixture.
388 lines
15 KiB
Python
388 lines
15 KiB
Python
"""
|
||
ChatOrchestrator:AI 回复用户模块的编排层
|
||
负责路由(Profile vs Interview)、调用 Specialist Agent;持久化由 feature 层 ConversationHistoryStore 完成。
|
||
"""
|
||
|
||
import time
|
||
from datetime import datetime
|
||
from typing import TYPE_CHECKING, List, Optional
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.agents.chat.agent_turn import AgentChatTurn
|
||
from app.agents.chat.helpers import get_history_with_window
|
||
from app.agents.chat.interview_agent import InterviewAgent
|
||
from app.agents.chat.profile_agent import ProfileAgent
|
||
from app.agents.state_schema import MemoirStateSchema
|
||
from app.core.agent_logging import agent_summary_enabled, log_agent_detail
|
||
from app.core.logging import get_logger
|
||
from app.agents.chat.stage_detection import (
|
||
detect_primary_life_stage,
|
||
life_stage_display_name,
|
||
)
|
||
from app.core.config import settings
|
||
from app.core.dependencies import get_llm_provider
|
||
from app.features.conversation.input_normalize import normalize_chat_input_for_agent
|
||
from app.features.memoir.state_service import get_or_create_state, switch_stage
|
||
|
||
|
||
def _llm_for_chat_input_normalize():
    """Return the LangChain LLM used for LLM-mode chat input normalization.

    Looks up the shared provider via ``get_llm_provider()`` and returns its
    ``langchain_llm`` attribute, or ``None`` when the attribute is absent or
    the provider cannot be obtained. Normalization is best-effort, so a
    provider failure must not break the chat turn. The failure is logged
    rather than silently swallowed, so a misconfigured "llm" normalization
    mode is visible in the logs.
    """
    try:
        provider = get_llm_provider()
        # Kept inside the try: accessing the attribute may itself raise
        # (e.g. if it is a lazily-initialized property).
        return getattr(provider, "langchain_llm", None)
    except Exception as e:
        # `logger` is the module-level logger; it is resolved at call time,
        # after module import has finished.
        logger.warning("event=chat_input_normalize_llm_unavailable error={}", e)
        return None
|
||
|
||
|
||
if TYPE_CHECKING:
|
||
from app.features.user.models import User
|
||
|
||
logger = get_logger(__name__)

# Canned reply used by ChatOrchestrator.process_user_message when no user id
# is available; TTS is skipped for this notice.
_UNAUTH_TURN = AgentChatTurn(
    skip_tts=True,
    messages=["暂时没法继续对话,请先登录后再试。"],
)
|
||
|
||
|
||
async def _fetch_interview_memory_evidence(
    db: AsyncSession,
    user_id: str,
    user_message: str,
) -> str:
    """Retrieve memory evidence for this turn and format it for the prompt.

    Runs a memory retrieval for ``user_message`` (``settings.chat_memory_top_k``
    hits) and formats the result with ``format_evidence_chunks_for_prompt``.
    The text is then capped at ``settings.chat_memory_evidence_max_chars``
    characters, including a trailing "..." marker when there is room for it.

    Args:
        db: Session handed to ``MemoryService``; rolled back if retrieval fails.
        user_id: Owner of the memories to search.
        user_message: Query text for this turn (``None``/blank is allowed).

    Returns:
        The formatted evidence text, or ``""`` in any of these cases:
        retrieval is disabled, the message is blank, nothing was found, the
        character budget is not positive, or any error occurs. Errors are
        logged, never raised.
    """
    # Imported lazily as in the original; presumably to avoid import cycles
    # or startup cost (TODO confirm).
    from app.core.dependencies import get_embedding_provider
    from app.features.memory.evidence_format import format_evidence_chunks_for_prompt
    from app.features.memory.service import MemoryService

    if not settings.chat_memory_retrieval_enabled:
        logger.debug(
            "event=chat_memory_retrieval_skip reason=disabled user_id={}", user_id
        )
        return ""
    msg = (user_message or "").strip()
    if not msg:
        logger.debug(
            "event=chat_memory_retrieval_skip reason=empty user_id={}", user_id
        )
        return ""
    try:
        emb = get_embedding_provider()
        ms = MemoryService(db, embedding_provider=emb)
        bundle = await ms.retrieve(user_id, msg, top_k=settings.chat_memory_top_k)
        text = format_evidence_chunks_for_prompt(bundle.model_dump())
        t = (text or "").strip()

        max_c = settings.chat_memory_evidence_max_chars
        if len(t) > max_c:
            # Never exceed max_c characters. Budgets too small for the "..."
            # marker get a hard cut. A non-positive budget yields no evidence;
            # previously it produced a negative slice that kept almost
            # everything.
            t = t[: max_c - 3] + "..." if max_c >= 3 else t[: max(max_c, 0)]

        if not t:
            logger.debug(
                "event=memory_evidence_for_prompt user_id={} formatted_chars=0",
                user_id,
            )
            return ""
        logger.info(
            "event=memory_evidence_for_prompt user_id={} formatted_chars={}",
            user_id,
            len(t),
        )
        return t
    except Exception as e:
        # Log the primary failure first so it is not buried behind a possible
        # secondary rollback error.
        logger.warning("访谈记忆检索失败: {}", e)
        try:
            # Reset the shared session after a failure (presumably so later
            # queries in this turn do not hit an aborted transaction — confirm).
            # NOTE(review): this also discards any uncommitted changes the
            # caller made on the same session earlier in the turn.
            await db.rollback()
        except Exception as rollback_error:
            logger.warning("访谈记忆检索失败后回滚也失败: {}", rollback_error)
        return ""
|
||
|
||
|
||
class ChatOrchestrator:
    """Route each chat turn to ProfileAgent or InterviewAgent.

    Turns go to ``ProfileAgent`` while both of these hold:
    * the user still has missing profile fields;
    * the per-conversation profile turn cap has not been reached.

    Otherwise they go to ``InterviewAgent``.

    Chat messages are not persisted here; the WS pipeline /
    ConversationHistoryStore stores them and syncs the cache. The orchestrator
    does use ``db`` for three things:
    * profile updates, via the injected callback;
    * memoir-state changes;
    * committing the conversation's stage.
    """

    def __init__(self):
        self.profile_agent = ProfileAgent()
        self.interview_agent = InterviewAgent()

    async def process_user_message(
        self,
        conversation_id: str,
        user_message: str,
        user: Optional["User"],
        conversation,  # used to update conversation_stage
        is_from_voice: bool,
        voice_session_id: Optional[str],
        db: AsyncSession,
        apply_extracted_profile_fn,
        get_missing_profile_fields_fn,
        get_filled_profile_fields_fn,
        user_message_timestamp: Optional[datetime] = None,
        audio_duration_seconds: Optional[int] = None,
    ) -> AgentChatTurn:
        """Handle one user message and return the AI reply for this turn.

        Routing:
        * Profile mode: ``user`` is set, has missing profile fields, and the
          conversation's turn total is below ``settings.chat_profile_max_turns``.
        * Interview mode: otherwise, when ``user`` has a truthy ``id``.
        * Unauthenticated: no user id; a fixed "please log in" turn with TTS
          skipped.

        Args:
            conversation_id: Conversation whose history window is read.
            user_message: Raw user text for this turn.
            user: Current user, or ``None`` when unauthenticated.
            conversation: Conversation record. Its ``conversation_stage`` is
                synced to the memoir state's current stage and committed when
                it changes.
            is_from_voice: Whether the text came from speech; forwarded to
                input normalization.
            voice_session_id: Accepted for interface compatibility; unused here.
            db: Session used for profile updates, memoir state and memory
                retrieval.
            apply_extracted_profile_fn: Awaitable ``(user, extracted, db)`` that
                applies extracted profile fields to the user.
            get_missing_profile_fields_fn: ``(user)`` -> missing profile fields.
            get_filled_profile_fields_fn: ``(user)`` -> filled profile fields.
            user_message_timestamp: Accepted for interface compatibility; unused.
            audio_duration_seconds: Accepted for interface compatibility; unused.

        Returns:
            The reply segments and whether TTS should be skipped.
        """
        t0 = time.perf_counter()

        # --- Profile-collection mode ---
        if user:
            missing = get_missing_profile_fields_fn(user)
            if missing:
                profile_turn = await self._try_profile_turn(
                    conversation_id=conversation_id,
                    user_message=user_message,
                    user=user,
                    missing=missing,
                    db=db,
                    apply_extracted_profile_fn=apply_extracted_profile_fn,
                    get_missing_profile_fields_fn=get_missing_profile_fields_fn,
                    get_filled_profile_fields_fn=get_filled_profile_fields_fn,
                    t0=t0,
                )
                if profile_turn is not None:
                    return profile_turn
                # None means the profile turn cap was hit: fall through to
                # the interview.

        # --- Interview mode ---
        user_id = user.id if user else None
        if not user_id:
            if agent_summary_enabled():
                logger.info(
                    "ChatOrchestrator.process_user_message route=unauth "
                    "duration_ms={:.2f} conversation_id={}",
                    (time.perf_counter() - t0) * 1000,
                    conversation_id,
                )
            # Hand out a fresh copy so a caller that edits `messages` in place
            # cannot corrupt the shared module-level template.
            return AgentChatTurn(
                messages=list(_UNAUTH_TURN.messages),
                skip_tts=_UNAUTH_TURN.skip_tts,
            )

        return await self._run_interview_turn(
            conversation_id=conversation_id,
            user_message=user_message,
            user=user,
            user_id=user_id,
            conversation=conversation,
            is_from_voice=is_from_voice,
            db=db,
            t0=t0,
        )

    async def _try_profile_turn(
        self,
        conversation_id: str,
        user_message: str,
        user: "User",
        missing: List[str],
        db: AsyncSession,
        apply_extracted_profile_fn,
        get_missing_profile_fields_fn,
        get_filled_profile_fields_fn,
        t0: float,
    ) -> Optional[AgentChatTurn]:
        """Run one profile-collection turn.

        Returns ``None`` when the conversation has already used
        ``settings.chat_profile_max_turns`` turns, so the caller falls through
        to interview mode. Otherwise it always returns a turn. If extraction or
        follow-up generation fails, that turn is a generic "please say that
        again" message.

        ``t0`` is the ``time.perf_counter()`` start of the turn, used only for
        duration logging. Errors from the history lookup propagate to the
        caller; only the extraction/follow-up section is guarded.
        """
        hw_profile = await get_history_with_window(
            conversation_id,
            max_pairs=settings.chat_history_max_pairs,
            max_chars=settings.chat_history_max_chars,
        )
        profile_turn_total = hw_profile.turn_total
        if profile_turn_total >= settings.chat_profile_max_turns:
            logger.info(
                "event=chat_profile_cap_skip conversation_id={} "
                "turn_total={} cap={} missing_fields={}",
                conversation_id,
                profile_turn_total,
                settings.chat_profile_max_turns,
                missing,
            )
            return None

        try:
            log_agent_detail(
                logger,
                "ChatOrchestrator route=profile conversation_id={} "
                "missing_fields={} user_msg_len={} profile_turn_total={}",
                conversation_id,
                missing,
                len(user_message or ""),
                profile_turn_total,
            )
            # Extract on every profile turn. Even a short acknowledgement can
            # carry inferable profile data, and skipping extraction would leave
            # slots un-updated for a long time.
            extracted = await self.profile_agent.extract_profile_from_message(
                user_message, missing, conversation_id=conversation_id
            )
            logger.info(
                "event=chat_profile_extract conversation_id={} "
                "extracted_keys={} missing_before={}",
                conversation_id,
                list(extracted.keys()) if extracted else [],
                missing,
            )
            if extracted:
                await apply_extracted_profile_fn(user, extracted, db)

            remaining = get_missing_profile_fields_fn(user)
            filled = get_filled_profile_fields_fn(user)
            interview_stage_hint = ""
            if not remaining:
                # The profile just became complete: pass the current
                # life-stage display name to the follow-up (presumably so it
                # can lead into the interview).
                st = await get_or_create_state(user.id, db)
                interview_stage_hint = life_stage_display_name(st.current_stage)

            responses = await self.profile_agent.generate_profile_followup(
                conversation_id=conversation_id,
                user_message=user_message,
                missing_fields=remaining,
                filled_fields=filled,
                nickname=user.nickname or "",
                interview_stage_hint=interview_stage_hint,
            )
            if agent_summary_enabled():
                logger.info(
                    "ChatOrchestrator.process_user_message route=profile "
                    "duration_ms={:.2f} conversation_id={} response_segments={}",
                    (time.perf_counter() - t0) * 1000,
                    conversation_id,
                    len(responses),
                )
            return AgentChatTurn(messages=responses, skip_tts=False)
        except Exception as e:
            # Pass the exception as a format argument instead of baking it into
            # an f-string. This module's logger formats "{}" placeholders, and
            # a str.format-style logger (e.g. loguru, when args/kwargs are
            # given) would parse braces inside the exception text as fields.
            # That could make this log call raise and lose the fallback reply.
            logger.error("资料收集处理失败: {}", e, exc_info=True)
            return AgentChatTurn(
                messages=["不好意思刚才没接住,你再说一遍好吗?"],
                skip_tts=False,
            )

    async def _run_interview_turn(
        self,
        conversation_id: str,
        user_message: str,
        user: "User",
        user_id: str,
        conversation,
        is_from_voice: bool,
        db: AsyncSession,
        t0: float,
    ) -> AgentChatTurn:
        """Run one interview turn for an authenticated user.

        Steps:
        1. Normalize the input (optionally with an LLM).
        2. Detect the primary life stage; switch the memoir state if it changed.
        3. Sync and commit ``conversation.conversation_stage``.
        4. Build the profile context, background voice and memory evidence.
        5. Delegate reply generation to ``InterviewAgent``.

        ``t0`` is the ``time.perf_counter()`` start of the turn, used only for
        duration logging.
        """
        log_agent_detail(
            logger,
            "ChatOrchestrator route=interview conversation_id={} user_msg_len={}",
            conversation_id,
            len(user_message or ""),
        )
        llm_n = None
        if settings.chat_input_normalize_enabled and (
            (settings.chat_input_normalize_mode or "").strip().lower() == "llm"
        ):
            llm_n = _llm_for_chat_input_normalize()
        normalized_user_message = normalize_chat_input_for_agent(
            user_message or "",
            llm=llm_n,
            is_from_voice=is_from_voice,
        )

        state = await get_or_create_state(user_id, db)
        detected = await detect_primary_life_stage(
            normalized_user_message,
            state.current_stage,
            self.interview_agent.llm,
        )
        if detected != state.current_stage:
            state = await switch_stage(user_id, detected, db)

        if conversation and conversation.conversation_stage != state.current_stage:
            conversation.conversation_stage = state.current_stage
            await db.commit()

        # Imported here as in the original; presumably to avoid import cycles
        # (TODO confirm).
        from app.agents.chat.background_voice import infer_background_voice
        from app.agents.chat.prompts_profile import format_user_profile_context

        # `user` is always set here: the caller only reaches this path when
        # `user.id` is truthy.
        user_profile_context = format_user_profile_context(
            birth_year=user.birth_year,
            birth_place=user.birth_place,
            grew_up_place=user.grew_up_place,
            occupation=user.occupation,
        )
        background_voice = infer_background_voice(user.occupation)
        occupation = user.occupation or ""

        memory_evidence_text = await _fetch_interview_memory_evidence(
            db, user_id, normalized_user_message
        )

        turn = await self.interview_agent.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=state,
            user_profile_context=user_profile_context,
            detected_user_stage=detected,
            memory_evidence_text=memory_evidence_text,
            background_voice=background_voice,
            normalized_user_message=normalized_user_message,
            occupation=occupation,
        )
        if agent_summary_enabled():
            logger.info(
                "ChatOrchestrator.process_user_message route=interview "
                "duration_ms={:.2f} conversation_id={} stage={} response_segments={} skip_tts={}",
                (time.perf_counter() - t0) * 1000,
                conversation_id,
                state.current_stage,
                len(turn.messages),
                turn.skip_tts,
            )
        return turn

    async def extract_profile_from_message(
        self,
        user_message: str,
        missing_fields: List[str],
        conversation_id: Optional[str] = None,
    ):
        """Delegate profile extraction to ``ProfileAgent``."""
        return await self.profile_agent.extract_profile_from_message(
            user_message, missing_fields, conversation_id=conversation_id
        )

    async def generate_profile_followup(
        self,
        conversation_id: str,
        user_message: str,
        missing_fields: List[str],
        filled_fields: dict,
        nickname: str = "",
        is_from_voice: bool = False,
        voice_session_id: str | None = None,
        user_message_timestamp: datetime | None = None,
        audio_duration_seconds: int | None = None,
    ) -> List[str]:
        """Delegate profile follow-up generation to ``ProfileAgent``.

        Persistence is the caller's responsibility. The voice/timestamp/duration
        parameters are accepted for interface compatibility but are not
        forwarded.
        """
        return await self.profile_agent.generate_profile_followup(
            conversation_id=conversation_id,
            user_message=user_message,
            missing_fields=missing_fields,
            filled_fields=filled_fields,
            nickname=nickname,
        )

    async def generate_profile_greeting(
        self,
        conversation_id: str,
        missing_fields: List[str],
        nickname: str = "",
    ) -> List[str]:
        """Delegate the profile-collection greeting to ``ProfileAgent``.

        Persistence is the caller's responsibility.
        """
        return await self.profile_agent.generate_profile_greeting(
            conversation_id=conversation_id,
            missing_fields=missing_fields,
            nickname=nickname,
        )

    async def generate_response_with_state(
        self,
        conversation_id: str,
        user_message: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
        is_from_voice: bool = False,
        voice_session_id: str | None = None,
        user_message_timestamp: datetime | None = None,
        audio_duration_seconds: int | None = None,
        detected_user_stage: str | None = None,
        memory_evidence_text: str = "",
        background_voice: str = "default",
        normalized_user_message: str | None = None,
        occupation: str = "",
    ) -> AgentChatTurn:
        """Delegate interview reply generation to ``InterviewAgent``.

        Persistence is the caller's responsibility. The voice/timestamp/duration
        parameters are accepted for interface compatibility but are not
        forwarded.
        """
        return await self.interview_agent.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=memoir_state,
            user_profile_context=user_profile_context,
            detected_user_stage=detected_user_stage,
            memory_evidence_text=memory_evidence_text,
            background_voice=background_voice,
            normalized_user_message=normalized_user_message,
            occupation=occupation,
        )

    def detect_user_stage(self, user_message: str) -> str:
        """Delegate user-stage detection to ``InterviewAgent``.

        NOTE(review): this calls the private ``InterviewAgent._detect_user_stage``.
        ``process_user_message`` now uses ``detect_primary_life_stage`` instead;
        confirm this helper still exists and is still the intended behavior.
        """
        return self.interview_agent._detect_user_stage(user_message)

    async def generate_opening_message(
        self,
        conversation_id: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
        background_voice: str = "default",
        occupation: str = "",
    ) -> List[str]:
        """Delegate the interview opening message to ``InterviewAgent``.

        Persistence is handled by the caller (ConversationHistoryStore).
        """
        return await self.interview_agent.generate_opening_message(
            conversation_id=conversation_id,
            memoir_state=memoir_state,
            user_profile_context=user_profile_context,
            background_voice=background_voice,
            occupation=occupation,
        )
|