Files
life-echo/api/app/agents/chat/orchestrator.py
Sully 53e0065e3e refactor(api): TOML 配置 SSOT、统一错误契约、Auth/事务加固与可观测性 (#33)
配置 SSOT(TOML + .env)
统一错误契约
Auth 与事务边界
Redis / Celery 可靠性:业务 Redis(DB/0)与 Celery broker/backend(DB/1)显式拆分;连接池、sync client
可观测性(OpenTelemetry + LGTM)
2026-05-22 13:44:50 +08:00

541 lines
21 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
ChatOrchestrator(AI 回复用户模块的编排层)。
负责路由(Profile vs Interview)并调用 Specialist Agent;持久化由 feature 层 ConversationHistoryStore 完成。
"""
import time
from collections.abc import Callable
from datetime import datetime
from typing import TYPE_CHECKING, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from app.agents.chat.agent_turn import AgentChatTurn
from app.agents.chat.helpers import get_history_with_window
from app.agents.chat.interview_agent import InterviewAgent
from app.agents.chat.interview_state_hints import (
build_runtime_interview_state,
extract_scene_cues,
)
from app.agents.chat.profile_agent import ProfileAgent
from app.agents.chat.stage_detection import (
detect_primary_life_stage,
life_stage_display_name,
)
from app.agents.state_schema import MemoirStateSchema
from app.core.agent_logging import agent_summary_enabled, log_agent_detail
from app.core.config import settings
from app.core.dependencies import get_embedding_provider
from app.core.llm_gateway import LlmGateway
from app.core.logging import get_logger
from app.features.conversation.input_normalize import normalize_chat_input_for_agent
from app.features.memoir.state_service import (
get_or_create_state,
save_interview_state_meta,
switch_stage,
)
from app.features.memory.prompt_adapter import MemoryPromptAdapter
from app.features.conversation.constants import chat
from app.features.memory.constants import memory
from app.features.story.constants import story
def _llm_for_chat_input_normalize():
    """
    Build the LangChain LLM used for LLM-mode chat input normalization.

    Best-effort: returns ``None`` when the gateway cannot provide an LLM; the
    caller then passes ``llm=None`` to ``normalize_chat_input_for_agent``.
    The failure is logged so that a configured-but-unavailable LLM normalize
    mode is visible instead of silently degrading on every turn.
    """
    try:
        return LlmGateway().langchain_llm_for()
    except Exception as e:
        # ``logger`` is a module global bound further down in this module; it
        # exists by the time this function is called at request time.
        logger.warning("event=chat_input_normalize_llm_unavailable error={}", e)
        return None
if TYPE_CHECKING:
from app.features.user.models import User
from app.ports.embedding import EmbeddingProvider
from app.ports.llm import LLMProvider
logger = get_logger(__name__)

# Canned replies returned when there is no authenticated user to chat with.
# Both skip TTS.
_UNAUTH_TURN_ZH = AgentChatTurn(
    skip_tts=True,
    messages=["暂时没法继续对话,请先登录后再试。"],
)
_UNAUTH_TURN_EN = AgentChatTurn(
    skip_tts=True,
    messages=["You'll need to sign in again before we can continue."],
)
def _user_language(user: Optional["User"]) -> str:
if not user:
return "zh"
lang = getattr(user, "language_preference", None) or "zh"
return "en" if str(lang).lower() == "en" else "zh"
async def _fetch_interview_memory_bundle(
    db: AsyncSession,
    user_id: str,
    user_message: str,
    *,
    get_embedding_provider_fn: Callable[[], "EmbeddingProvider"],
) -> tuple[dict | None, object | None]:
    """
    Retrieve the memory bundle used by an interview turn.

    The raw bundle is returned unfiltered; what actually reaches the main
    prompt is decided afterwards by the memory prompt adapter.

    Returns:
        ``(bundle_dict, retrieval_trace)`` on success, where ``bundle_dict`` is
        the retrieved bundle's ``model_dump()``. ``(None, None)`` when
        retrieval is disabled via ``chat.memory_retrieval_enabled``, when the
        stripped message is empty, or when retrieval raises. In the last case
        the session is rolled back (best-effort) and a warning is logged.
    """
    from app.features.memory.retrieval_trace import (
        chat_memory_retrieval_trace_from_bundle,
    )
    from app.features.memory.service import MemoryService

    if not chat.memory_retrieval_enabled:
        logger.debug(
            "event=chat_memory_retrieval_skip reason=disabled user_id={}", user_id
        )
        return None, None

    query = (user_message or "").strip()
    if not query:
        logger.debug(
            "event=chat_memory_retrieval_skip reason=empty user_id={}", user_id
        )
        return None, None

    try:
        memory_service = MemoryService(
            db, embedding_provider=get_embedding_provider_fn()
        )
        top_k = chat.memory_top_k
        retrieved = await memory_service.retrieve(user_id, query, top_k=top_k)
        bundle_dict = retrieved.model_dump()
        retrieval_trace = chat_memory_retrieval_trace_from_bundle(
            bundle_dict, top_k=top_k, query_len=len(query)
        )
        logger.info(
            "event=memory_retrieval_bundle user_id={} top_k={}",
            user_id,
            top_k,
        )
    except Exception as exc:
        # Best-effort rollback (presumably so the shared session stays usable
        # for the rest of the turn); a failing rollback is only logged.
        try:
            await db.rollback()
        except Exception as rollback_exc:
            logger.warning("访谈记忆检索失败后回滚也失败: {}", rollback_exc)
        logger.warning("访谈记忆检索失败: {}", exc)
        return None, None
    else:
        return bundle_dict, retrieval_trace
class ChatOrchestrator:
    """
    Chat orchestrator: routes each user turn to ProfileAgent or InterviewAgent
    depending on whether the user's profile still has missing fields.

    Chat messages themselves are persisted elsewhere: the WS pipeline /
    ConversationHistoryStore write them and sync the cache. This class does,
    however, update memoir state through ``state_service`` helpers
    (``switch_stage``, ``save_interview_state_meta``). It may also update
    ``conversation.conversation_stage``.

    ``get_embedding_provider_fn`` / ``llm_provider`` can be injected by tests
    or scripts; by default the global dependencies are used.
    """

    def __init__(
        self,
        *,
        get_embedding_provider_fn: Callable[[], "EmbeddingProvider"] | None = None,
        llm_provider: "LLMProvider | None" = None,
    ):
        self._get_embedding_provider_fn = (
            get_embedding_provider_fn or get_embedding_provider
        )
        # Only the profile agent receives the injected LLM provider; the
        # interview agent is built with its own defaults.
        self.profile_agent = ProfileAgent(llm_provider=llm_provider)
        self.interview_agent = InterviewAgent()
        self.memory_prompt_adapter = MemoryPromptAdapter()

    async def process_user_message(
        self,
        conversation_id: str,
        user_message: str,
        user: Optional["User"],
        conversation,  # used to update conversation_stage
        is_from_voice: bool,
        voice_session_id: Optional[str],
        db: AsyncSession,
        apply_extracted_profile_fn,
        get_missing_profile_fields_fn,
        get_filled_profile_fields_fn,
        user_message_timestamp: Optional[datetime] = None,
        audio_duration_seconds: Optional[int] = None,
    ) -> AgentChatTurn:
        """
        Process one user message and return the AI reply (segments + skip-TTS flag).

        Routing:

        * ``user`` has missing profile fields and the profile turn cap
          (``chat.profile_max_turns``) is not reached -> ProfileAgent;
        * no user, or a user without an id -> a fixed "please sign in" turn;
        * otherwise (including when the profile cap is reached) -> InterviewAgent.

        ``voice_session_id``, ``user_message_timestamp`` and
        ``audio_duration_seconds`` are accepted for interface compatibility
        but are not used here; persisting the turn is the caller's job.
        """
        t0 = time.perf_counter()
        language = _user_language(user)

        # --- profile-collection mode ---
        if user:
            missing = get_missing_profile_fields_fn(user)
            if missing:
                profile_turn = await self._run_profile_turn(
                    t0=t0,
                    conversation_id=conversation_id,
                    user_message=user_message,
                    user=user,
                    missing=missing,
                    language=language,
                    db=db,
                    apply_extracted_profile_fn=apply_extracted_profile_fn,
                    get_missing_profile_fields_fn=get_missing_profile_fields_fn,
                    get_filled_profile_fields_fn=get_filled_profile_fields_fn,
                )
                if profile_turn is not None:
                    return profile_turn
                # Profile turn cap reached: continue with the interview even
                # though some profile fields are still missing.

        # --- interview mode ---
        user_id = user.id if user else None
        if not user_id:
            if agent_summary_enabled():
                logger.info(
                    "ChatOrchestrator.process_user_message route=unauth "
                    "duration_ms={:.2f} conversation_id={}",
                    (time.perf_counter() - t0) * 1000,
                    conversation_id,
                )
            return _UNAUTH_TURN_EN if language == "en" else _UNAUTH_TURN_ZH
        return await self._run_interview_turn(
            t0=t0,
            conversation_id=conversation_id,
            user_message=user_message,
            user=user,
            user_id=user_id,
            conversation=conversation,
            is_from_voice=is_from_voice,
            language=language,
            db=db,
        )

    async def _run_profile_turn(
        self,
        *,
        t0: float,
        conversation_id: str,
        user_message: str,
        user: "User",
        missing,
        language: str,
        db: AsyncSession,
        apply_extracted_profile_fn,
        get_missing_profile_fields_fn,
        get_filled_profile_fields_fn,
    ) -> Optional[AgentChatTurn]:
        """
        Run one profile-collection turn.

        Returns ``None`` when the conversation has already reached
        ``chat.profile_max_turns``, so the caller falls through to the
        interview. Errors raised after the cap check are logged and turned
        into a localized "please say that again" turn.
        """
        hw_profile = await get_history_with_window(
            conversation_id,
            max_pairs=chat.history_max_pairs,
            max_chars=chat.history_max_chars,
        )
        profile_turn_total = hw_profile.turn_total
        if profile_turn_total >= chat.profile_max_turns:
            logger.info(
                "event=chat_profile_cap_skip conversation_id={} "
                "turn_total={} cap={} missing_fields={}",
                conversation_id,
                profile_turn_total,
                chat.profile_max_turns,
                missing,
            )
            return None
        try:
            log_agent_detail(
                logger,
                "ChatOrchestrator route=profile conversation_id={} "
                "missing_fields={} user_msg_len={} profile_turn_total={}",
                conversation_id,
                missing,
                len(user_message or ""),
                profile_turn_total,
            )
            # Extract on every profile turn: even a short confirmation can
            # carry inferable profile data, and skipping extraction would leave
            # slots un-updated for a long time.
            extracted = await self.profile_agent.extract_profile_from_message(
                user_message,
                missing,
                conversation_id=conversation_id,
                language=language,
            )
            logger.info(
                "event=chat_profile_extract conversation_id={} "
                "extracted_keys={} missing_before={}",
                conversation_id,
                list(extracted.keys()) if extracted else [],
                missing,
            )
            if extracted:
                await apply_extracted_profile_fn(user, extracted, db)
            remaining = get_missing_profile_fields_fn(user)
            filled = get_filled_profile_fields_fn(user)
            interview_stage_hint = ""
            if not remaining:
                # Profile just became complete: hint at the interview stage
                # that comes next.
                st = await get_or_create_state(user.id, db)
                interview_stage_hint = life_stage_display_name(
                    st.current_stage, language=language
                )
            responses = await self.profile_agent.generate_profile_followup(
                conversation_id=conversation_id,
                user_message=user_message,
                missing_fields=remaining,
                filled_fields=filled,
                nickname=user.nickname or "",
                interview_stage_hint=interview_stage_hint,
                language=language,
            )
            if agent_summary_enabled():
                logger.info(
                    "ChatOrchestrator.process_user_message route=profile "
                    "duration_ms={:.2f} conversation_id={} response_segments={}",
                    (time.perf_counter() - t0) * 1000,
                    conversation_id,
                    len(responses),
                )
            return AgentChatTurn(
                messages=responses,
                skip_tts=False,
                memory_retrieval_trace=None,
            )
        except Exception as e:
            logger.exception("资料收集处理失败: {}", e)
            fb_msg = (
                "Sorry, I missed that. Could you say it again?"
                if language == "en"
                else "不好意思刚才没接住,你再说一遍好吗?"
            )
            return AgentChatTurn(
                messages=[fb_msg],
                skip_tts=False,
                memory_retrieval_trace=None,
            )

    async def _run_interview_turn(
        self,
        *,
        t0: float,
        conversation_id: str,
        user_message: str,
        user: "User",
        user_id,
        conversation,
        is_from_voice: bool,
        language: str,
        db: AsyncSession,
    ) -> AgentChatTurn:
        """
        Run one memoir-interview turn for an authenticated ``user``.

        Steps: normalize the input, detect/switch the life stage, retrieve
        memories, build the runtime interview state, and call InterviewAgent.
        It then saves the interview state meta.
        """
        log_agent_detail(
            logger,
            "ChatOrchestrator route=interview conversation_id={} user_msg_len={}",
            conversation_id,
            len(user_message or ""),
        )
        llm_n = None
        if chat.input_normalize_enabled and (
            (chat.input_normalize_mode or "").strip().lower() == "llm"
        ):
            llm_n = _llm_for_chat_input_normalize()
        normalized_user_message = normalize_chat_input_for_agent(
            user_message or "",
            llm=llm_n,
            is_from_voice=is_from_voice,
        )

        state = await get_or_create_state(user_id, db)
        stage_before = state.current_stage
        detected = await detect_primary_life_stage(
            normalized_user_message,
            state.current_stage,
            self.interview_agent.llm,
        )
        # The ``detected or state.current_stage`` fallback below implies the
        # detector can yield an empty value; never switch the persisted stage
        # to it.
        stage_switched_this_turn = bool(detected) and detected != stage_before
        if stage_switched_this_turn:
            state = await switch_stage(user_id, detected, db)
        if conversation and conversation.conversation_stage != state.current_stage:
            from app.core.db import transactional

            async with transactional(db):
                conversation.conversation_stage = state.current_stage

        from app.agents.chat.background_voice import infer_background_voice
        from app.agents.chat.prompts_profile import format_user_profile_context

        # ``user`` is always truthy here (process_user_message returns the
        # unauth turn otherwise), so profile fields are read directly.
        user_profile_context = format_user_profile_context(
            birth_year=user.birth_year,
            birth_place=user.birth_place,
            grew_up_place=user.grew_up_place,
            occupation=user.occupation,
            language=language,
        )
        background_voice = infer_background_voice(user.occupation)
        occupation = user.occupation or ""

        memory_bundle, mem_trace = await _fetch_interview_memory_bundle(
            db,
            user_id,
            normalized_user_message,
            get_embedding_provider_fn=self._get_embedding_provider_fn,
        )
        mem_slices = self.memory_prompt_adapter.slice_for_interview(
            memory_bundle,
            normalized_user_message,
        )
        # Scene keywords are only an auxiliary input for the focus planner.
        # They are not spliced into the memory block, so they cannot override
        # the user's explicit relationship/identity cues.
        scene_cues_for_planner = extract_scene_cues(normalized_user_message)
        profile_birth_year = user.birth_year
        profile_era_place = (user.birth_place or user.grew_up_place or "").strip()
        prompt_state = build_runtime_interview_state(
            state,
            user_message=normalized_user_message,
            active_stage=detected or state.current_stage,
            birth_year=profile_birth_year,
            birth_place=(user.birth_place or "").strip(),
            grew_up_place=(user.grew_up_place or "").strip(),
            occupation=occupation,
        )
        turn = await self.interview_agent.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=prompt_state,
            user_profile_context=user_profile_context,
            detected_user_stage=detected,
            memory_evidence_text=mem_slices.prompt_excerpt,
            memory_anchor_source=mem_slices.anchor_source,
            memory_planner_text=mem_slices.planner_preview,
            background_voice=background_voice,
            normalized_user_message=normalized_user_message,
            occupation=occupation,
            profile_birth_year=profile_birth_year,
            profile_era_place=profile_era_place,
            stage_switched_this_turn=stage_switched_this_turn,
            scene_cues_for_planner=scene_cues_for_planner,
            language=language,
        )

        # Prefer the recent-question list reported back by the agent, if any.
        recent_questions = prompt_state.recent_questions
        if turn.interview_state_meta and isinstance(turn.interview_state_meta, dict):
            raw_recent = turn.interview_state_meta.get("recent_questions")
            if isinstance(raw_recent, list):
                recent_questions = [
                    str(x).strip() for x in raw_recent if str(x).strip()
                ]
        await save_interview_state_meta(
            user_id,
            known_facts=prompt_state.known_facts,
            persona_threads=prompt_state.persona_threads,
            recent_questions=recent_questions,
            db=db,
        )
        if agent_summary_enabled():
            logger.info(
                "ChatOrchestrator.process_user_message route=interview "
                "duration_ms={:.2f} conversation_id={} stage={} response_segments={} skip_tts={}",
                (time.perf_counter() - t0) * 1000,
                conversation_id,
                state.current_stage,
                len(turn.messages),
                turn.skip_tts,
            )
        if mem_trace is not None:
            # Attach the memory-retrieval trace.
            # NOTE(review): only these four fields are carried over from
            # ``turn``; confirm AgentChatTurn has no other fields to keep.
            return AgentChatTurn(
                messages=turn.messages,
                skip_tts=turn.skip_tts,
                memory_retrieval_trace=mem_trace,
                interview_state_meta=turn.interview_state_meta,
            )
        return turn

    async def extract_profile_from_message(
        self,
        user_message: str,
        missing_fields: List[str],
        conversation_id: Optional[str] = None,
        language: Optional[str] = None,
    ):
        """
        Delegate profile extraction to ProfileAgent.

        ``language`` is forwarded only when given, so existing callers keep
        the agent's own default. Callers of this facade can now pass the
        user's language the same way ``process_user_message`` does.
        """
        extra: dict = {}
        if language is not None:
            extra["language"] = language
        return await self.profile_agent.extract_profile_from_message(
            user_message, missing_fields, conversation_id=conversation_id, **extra
        )

    async def generate_profile_followup(
        self,
        conversation_id: str,
        user_message: str,
        missing_fields: List[str],
        filled_fields: dict,
        nickname: str = "",
        is_from_voice: bool = False,
        voice_session_id: str | None = None,
        user_message_timestamp: datetime | None = None,
        audio_duration_seconds: int | None = None,
        language: str = "zh",
    ) -> List[str]:
        """
        Delegate the profile follow-up question to ProfileAgent.

        Persistence is the caller's job. ``is_from_voice``,
        ``voice_session_id``, ``user_message_timestamp`` and
        ``audio_duration_seconds`` are accepted for compatibility but are not
        forwarded.
        """
        return await self.profile_agent.generate_profile_followup(
            conversation_id=conversation_id,
            user_message=user_message,
            missing_fields=missing_fields,
            filled_fields=filled_fields,
            nickname=nickname,
            language=language,
        )

    async def generate_profile_greeting(
        self,
        conversation_id: str,
        missing_fields: List[str],
        nickname: str = "",
        language: str = "zh",
    ) -> List[str]:
        """Delegate the profile-collection opening message to ProfileAgent (persistence is the caller's job)."""
        return await self.profile_agent.generate_profile_greeting(
            conversation_id=conversation_id,
            missing_fields=missing_fields,
            nickname=nickname,
            language=language,
        )

    async def generate_response_with_state(
        self,
        conversation_id: str,
        user_message: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
        is_from_voice: bool = False,
        voice_session_id: str | None = None,
        user_message_timestamp: datetime | None = None,
        audio_duration_seconds: int | None = None,
        detected_user_stage: str | None = None,
        memory_evidence_text: str = "",
        memory_anchor_source: str = "",
        memory_planner_text: str = "",
        background_voice: str = "default",
        normalized_user_message: str | None = None,
        occupation: str = "",
        profile_birth_year: int | None = None,
        profile_era_place: str = "",
        stage_switched_this_turn: bool = False,
        scene_cues_for_planner: Optional[list[str]] = None,
        language: str = "zh",
    ) -> AgentChatTurn:
        """
        Delegate the interview reply to InterviewAgent.

        Persistence is the caller's job. ``is_from_voice``,
        ``voice_session_id``, ``user_message_timestamp`` and
        ``audio_duration_seconds`` are accepted for compatibility but are not
        forwarded.
        """
        return await self.interview_agent.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=memoir_state,
            user_profile_context=user_profile_context,
            detected_user_stage=detected_user_stage,
            memory_evidence_text=memory_evidence_text,
            memory_anchor_source=memory_anchor_source,
            memory_planner_text=memory_planner_text,
            background_voice=background_voice,
            normalized_user_message=normalized_user_message,
            occupation=occupation,
            profile_birth_year=profile_birth_year,
            profile_era_place=profile_era_place,
            stage_switched_this_turn=stage_switched_this_turn,
            scene_cues_for_planner=scene_cues_for_planner,
            language=language,
        )

    def detect_user_stage(self, user_message: str) -> str:
        """
        Delegate user-stage detection to InterviewAgent.

        NOTE(review): this relies on InterviewAgent's private
        ``_detect_user_stage``, whereas ``process_user_message`` uses
        ``detect_primary_life_stage``; confirm the two are meant to agree.
        """
        return self.interview_agent._detect_user_stage(user_message)

    async def generate_opening_message(
        self,
        conversation_id: str,
        memoir_state: MemoirStateSchema,
        user_profile_context: str = "",
        background_voice: str = "default",
        occupation: str = "",
        profile_birth_year: Optional[int] = None,
        profile_era_place: str = "",
        language: str = "zh",
    ) -> List[str]:
        """
        Delegate the interview opening message to InterviewAgent.

        Persistence is done by the caller (ConversationHistoryStore).
        """
        return await self.interview_agent.generate_opening_message(
            conversation_id=conversation_id,
            memoir_state=memoir_state,
            user_profile_context=user_profile_context,
            background_voice=background_voice,
            occupation=occupation,
            profile_birth_year=profile_birth_year,
            profile_era_place=profile_era_place,
            language=language,
        )

    async def generate_re_greeting_message(
        self,
        conversation_id: str,
        memoir_state: MemoirStateSchema,
        idle_hours: float,
        user_profile_context: str = "",
        background_voice: str = "default",
        occupation: str = "",
        profile_birth_year: Optional[int] = None,
        profile_era_place: str = "",
        language: str = "zh",
    ) -> List[str]:
        """Delegate the re-greeting for a resumed (idle) conversation to InterviewAgent (persistence is the caller's job)."""
        return await self.interview_agent.generate_re_greeting_message(
            conversation_id=conversation_id,
            memoir_state=memoir_state,
            idle_hours=idle_hours,
            user_profile_context=user_profile_context,
            background_voice=background_voice,
            occupation=occupation,
            profile_birth_year=profile_birth_year,
            profile_era_place=profile_era_place,
            language=language,
        )