"""核心消息处理管道:Agent 调用、ASR 转写、分段有序聚合""" import asyncio import base64 import time import uuid from dataclasses import dataclass, field from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple from app.core.logging import get_logger if TYPE_CHECKING: from app.features.quota.service import QuotaService from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.agents.chat import ChatOrchestrator from app.agents.chat.reply_limits import segments_from_llm_response from app.core.agent_logging import agent_summary_enabled, log_asr_transcript_result from app.core.business_telemetry import business_span from app.core.config import settings from app.core.cos_url_keys import ( TTS_PRESIGNED_EXPIRES_SEC, extract_cos_object_key_if_owned, ) from app.core.db import AsyncSessionLocal from app.features.conversation.ws.persist import ( persist_message_tts_url_segment, persist_voice_segment_row, ) from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider from app.features.conversation.chat_turn import ( ChatTurnContext, ChatTurnInput, ChatTurnService, ) from app.features.conversation.history_store import ( AI_RESPONSE_SEGMENT_JOIN, ConversationHistoryStore, ) from app.features.conversation.models import Conversation, ConversationMessage, Segment from app.features.conversation.ws.connection_manager import manager from app.features.conversation.ws.message_types import MessageType from app.features.conversation.ws.profile_collector import ( apply_extracted_profile, get_filled_profile_fields, get_missing_profile_fields, ) from app.features.conversation.ws.topic_chips_push import maybe_send_topic_chips_ws from app.features.memoir.background_runner import BackgroundTaskRunner from app.features.memoir.ingest_scheduler import MemoirIngestScheduler, MemoirTrigger from app.features.memoir.state_service import get_or_create_state from app.features.user.models import User from app.ports.asr import ASRTranscriptionError from app.core.runtime_constants import tts_defaults logger = get_logger(__name__) # 客户端发送 tts_cancel 时递增;process_user_message 内 TTS 循环与合成前后对照,用于短路剩余片段 _tts_cancel_epoch: dict[str, int] = {} def bump_tts_cancel_epoch(conversation_id: str) -> None: _tts_cancel_epoch[conversation_id] = _tts_cancel_epoch.get(conversation_id, 0) + 1 def _tts_epoch_value(conversation_id: str) -> int: return _tts_cancel_epoch.get(conversation_id, 0) def _resolve_user_language(user) -> str: """Return 'en' iff user.language_preference is set to 'en'; default 'zh'.""" raw = getattr(user, "language_preference", "zh") if user is not None else "zh" return "en" if str(raw or "zh").strip().lower() == "en" else "zh" def _tts_object_ext(codec: str) -> str: c = (codec or "mp3").lower().lstrip(".") if c in ("wave",): return "wav" return c if c else "mp3" def _tts_codec_to_content_type(codec: str) -> str: c = (codec or "mp3").lower().lstrip(".") if c == "mp3": return "audio/mpeg" if c in ("wav", "wave"): return "audio/wav" return "application/octet-stream" async def _send_tts_audio( conversation_id: str, text: str, *, chunk_index: int, chunk_total: int, assistant_message_id: str | None, tts_epoch_start: int, manual: bool = False, language: str = "zh", ) -> str | None: """Synthesize TTS, upload to COS, append Redis, send TTS_AUDIO. Returns public URL or None.""" current_epoch = _tts_epoch_value(conversation_id) # enable_tts:仅禁用「助手回复自动生成 TTS」(want_tts 路径);用户点喇叭(manual=True)仍可合成。 if not manual and not settings.enable_tts: return None if current_epoch != tts_epoch_start: return None try: tts = get_tts_provider() audio_bytes = await tts.synthesize(text, language=language) if not audio_bytes: logger.warning( "TTS skipped: synthesize returned empty conversation_id={} chunk_index={} " "language={} text_preview={!r} voice_provider={}", conversation_id, chunk_index, language, (text or "")[:30], tts_defaults.provider, ) return None if _tts_epoch_value(conversation_id) != tts_epoch_start: return None ext = _tts_object_ext(tts_defaults.codec) content_type = _tts_codec_to_content_type(tts_defaults.codec) storage = get_object_storage() key = f"conversations/{conversation_id}/tts/{uuid.uuid4().hex}.{ext}" public_url = storage.upload(key, audio_bytes, content_type) # 与 `tts_delivery.apply_presigned_tts_urls_to_messages` / 回忆录图片 presign 一致:下发可播 URL playback_url = storage.get_url(key, expires=TTS_PRESIGNED_EXPIRES_SEC) payload_data: Dict[str, Any] = { "format": tts_defaults.codec, "audio_base64": base64.b64encode(audio_bytes).decode("utf-8"), "audio_url": playback_url, "index": chunk_index, "total": chunk_total, } if assistant_message_id: payload_data["assistant_message_id"] = assistant_message_id if manual: payload_data["manual"] = True await manager.send_message( conversation_id, { "type": MessageType.TTS_AUDIO, "conversation_id": conversation_id, "data": payload_data, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) return public_url except Exception as e: err_str = str(e) if "PkgExhausted" in err_str: logger.warning( "TTS skipped: 腾讯云语音合成资源包已用尽,请在控制台购买或开通后付费: {}", err_str[:100], ) else: logger.error("TTS synthesize failed: {}", e) return None async def handle_tts_request_on_demand( *, conversation_id: str, user_id: str, assistant_message_id: str, segment_index: int, segment_text: str | None, db: AsyncSession, ) -> tuple[bool, str]: """用户点喇叭:该段已有 TTS 则预签名下发;否则合成后落库并下发。不重复合成同一段。""" logger.info( "pipeline.handle_tts_request_on_demand entry conversation_id={} user_id={} " "assistant_message_id={} segment_index={} segment_text_len={} enable_tts={} provider={}", conversation_id, user_id, assistant_message_id, segment_index, len(segment_text or ""), settings.enable_tts, tts_defaults.provider, ) conv = await db.get(Conversation, conversation_id) if not conv or conv.user_id != user_id or conv.deleted_at is not None: logger.debug( "pipeline.handle_tts_request_on_demand result ok=False reason=对话不存在或无权访问 " "conversation_id={} user_id={}", conversation_id, user_id, ) return False, "对话不存在或无权访问" msg = await db.get(ConversationMessage, assistant_message_id) if not msg or msg.conversation_id != conversation_id or msg.role != "ai": logger.debug( "pipeline.handle_tts_request_on_demand result ok=False reason=消息不存在 " "conversation_id={} assistant_message_id={}", conversation_id, assistant_message_id, ) return False, "消息不存在" # 与客户端 splitMessageParts / segments_from_llm_response 对齐(含无 [SPLIT] 时的段落拆段) parts = segments_from_llm_response(msg.content or "", max_segments=3) if segment_index < 0 or segment_index >= len(parts): return False, "分段序号无效" canon = (parts[segment_index] or "").strip() if not canon: return False, "该段无朗读文本" if segment_text and segment_text.strip() and segment_text.strip() != canon: logger.debug( "按需 TTS: 客户端传入 segment_text 与规范化后 canon 不完全一致,已按 segment_index 朗读 canon " "(client_len={} canon_len={})", len(segment_text.strip()), len(canon), ) urls: List[str] = [] for x in msg.tts_audio_urls or []: if isinstance(x, str) and x.strip(): urls.append(x) else: urls.append("") while len(urls) < len(parts): urls.append("") existing = urls[segment_index].strip() if segment_index < len(urls) else "" chunk_total = len(parts) if existing: logger.info( "pipeline.handle_tts_request_on_demand reuse existing url conversation_id={} " "assistant_message_id={} segment_index={} url_len={}", conversation_id, assistant_message_id, segment_index, len(existing), ) storage = get_object_storage() key = extract_cos_object_key_if_owned(existing) try: playback_url = ( storage.get_url(key, expires=TTS_PRESIGNED_EXPIRES_SEC) if key else existing ) except Exception as exc: logger.warning("按需 TTS 预签名失败: {}", exc) playback_url = existing await manager.send_message( conversation_id, { "type": MessageType.TTS_AUDIO, "conversation_id": conversation_id, "data": { "audio_url": playback_url, "format": tts_defaults.codec, "index": segment_index, "total": chunk_total, "assistant_message_id": assistant_message_id, "manual": True, }, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) logger.info( "pipeline.handle_tts_request_on_demand result ok=True reason=existing_reused " "conversation_id={} assistant_message_id={} segment_index={}", conversation_id, assistant_message_id, segment_index, ) return True, "" logger.info( "pipeline.handle_tts_request_on_demand no existing url, will synthesize " "conversation_id={} assistant_message_id={} segment_index={} canon_len={}", conversation_id, assistant_message_id, segment_index, len(canon), ) user_obj = await db.get(User, user_id) user_language = _resolve_user_language(user_obj) tts_epoch_start = _tts_epoch_value(conversation_id) url_stored = await _send_tts_audio( conversation_id, canon, chunk_index=segment_index, chunk_total=chunk_total, assistant_message_id=assistant_message_id, tts_epoch_start=tts_epoch_start, manual=True, language=user_language, ) logger.info( "pipeline.handle_tts_request_on_demand _send_tts_audio returned url_stored_set={} " "conversation_id={} assistant_message_id={} segment_index={}", bool(url_stored), conversation_id, assistant_message_id, segment_index, ) if not url_stored: logger.info( "pipeline.handle_tts_request_on_demand result ok=False reason=语音合成失败 " "conversation_id={} assistant_message_id={} segment_index={}", conversation_id, assistant_message_id, segment_index, ) return False, "语音合成失败" await persist_message_tts_url_segment(db, msg, segment_index, url_stored) store = ConversationHistoryStore(db) await store._sync_redis_best_effort(conversation_id) logger.info( "pipeline.handle_tts_request_on_demand result ok=True reason=synthesized " "conversation_id={} assistant_message_id={} segment_index={}", conversation_id, assistant_message_id, segment_index, ) return True, "" # ── Agent 实例(从 ConnectionManager 移出) ───────────────────── chat_orchestrator = ChatOrchestrator() chat_turn_service = ChatTurnService(chat_orchestrator) _background_runner = BackgroundTaskRunner() memoir_ingest_scheduler = MemoirIngestScheduler(_background_runner) async def _schedule_memoir_ingest_for_segment( user_id: str, segment: Segment, *, trigger: MemoirTrigger = "turn", ) -> None: """Queue memoir phase1 after segment text (and ideally lineage) is durable.""" text = (segment.user_input_text or "").strip() if not text: return await memoir_ingest_scheduler.queue_segment( user_id, str(segment.id), text_char_count=len(text), trigger=trigger, ) # ── 分段流状态 ────────────────────────────────────────────────── @dataclass class SegmentStreamState: """会话内分段处理状态(用于并行 ASR + 有序聚合)""" lock: asyncio.Lock = field(default_factory=asyncio.Lock) #: 本条语音会话最近一次分段上行携带的本轮朗读开关(客户端每段一致即可) tts_this_turn: bool = False pending_indices: Set[int] = field(default_factory=set) processed_indices: Set[int] = field(default_factory=set) buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict) consumed_index: int = -1 active_tasks: Set[asyncio.Task] = field(default_factory=set) listening_feedback_sent: bool = False listening_feedback_task: Optional[asyncio.Task] = None _segment_states: Dict[Tuple[str, str], SegmentStreamState] = {} _user_response_tasks: Dict[str, Set[asyncio.Task]] = {} _user_response_locks: Dict[str, asyncio.Lock] = {} def _get_user_response_lock(conversation_id: str) -> asyncio.Lock: lock = _user_response_locks.get(conversation_id) if lock is None: lock = asyncio.Lock() _user_response_locks[conversation_id] = lock return lock def register_user_response_task(conversation_id: str, task: asyncio.Task) -> None: tasks = _user_response_tasks.setdefault(conversation_id, set()) tasks.add(task) def _cleanup(done_task: asyncio.Task) -> None: tasks.discard(done_task) if not tasks: _user_response_tasks.pop(conversation_id, None) _user_response_locks.pop(conversation_id, None) if done_task.cancelled(): logger.warning( "用户回复后台任务被取消 conversation_id={}", conversation_id, ) return exc = done_task.exception() if exc: logger.error( "用户回复后台任务异常 conversation_id={}: {}", conversation_id, exc, exc_info=True, ) task.add_done_callback(_cleanup) def get_or_create_segment_state( conversation_id: str, voice_session_id: str, ) -> SegmentStreamState: state_key = (conversation_id, voice_session_id) if state_key not in _segment_states: _segment_states[state_key] = SegmentStreamState() return _segment_states[state_key] def register_segment_task( conversation_id: str, voice_session_id: str, task: asyncio.Task, ) -> None: state_key = (conversation_id, voice_session_id) state = get_or_create_segment_state(conversation_id, voice_session_id) state.active_tasks.add(task) def _cleanup(done_task: asyncio.Task) -> None: state.active_tasks.discard(done_task) if not state.active_tasks and conversation_id not in manager.active_connections: _segment_states.pop(state_key, None) if done_task.cancelled(): return exc = done_task.exception() if exc: logger.error( "分段处理任务异常 " f"(conversation_id={conversation_id}, voice_session_id={voice_session_id}): {exc}", exc_info=True, ) task.add_done_callback(_cleanup) def cleanup_segment_states(conversation_id: str) -> None: """断开连接后清理无活跃任务的分段状态""" stale_keys = [ key for key, state in _segment_states.items() if key[0] == conversation_id and not state.active_tasks ] for key in stale_keys: _segment_states.pop(key, None) # ── 工具函数 ──────────────────────────────────────────────────── def _utc_now() -> datetime: return datetime.now(timezone.utc) def _voice_session_id_from_client_segment_id( client_segment_id: Optional[str], ) -> Optional[str]: if not client_segment_id: return None session_id, separator, _ = client_segment_id.rpartition("-") if separator and session_id: return session_id return None def _build_segment_audio_url(voice_session_id: str, segment_index: int) -> str: """构建分段语音的幂等标识(conversation_id + voice_session_id + segment_index)。""" return f"audio-segment:{voice_session_id}:{segment_index}" def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int]]: """从 audio_url 解析 voice_session_id 与 segment_index(audio-segment:{session_id}:{index})。""" prefix = "audio-segment:" if not audio_url or not audio_url.startswith(prefix): return None payload = audio_url[len(prefix) :] voice_session_id_raw, separator, segment_index_raw = payload.rpartition(":") if not separator: return None try: sid = str(voice_session_id_raw).strip() if not sid: return None return (sid, int(segment_index_raw)) except ValueError: return None def _voice_session_id_from_audio_url(audio_url: Optional[str]) -> Optional[str]: scope = _extract_segment_scope(audio_url) if scope: return scope[0] return None def _is_transcribe_failure(transcript_text: Optional[str]) -> bool: if not transcript_text: return True return transcript_text.startswith("转写失败") async def _find_existing_segment_by_index( db: AsyncSession, conversation_id: str, voice_session_id: str, segment_index: int, ) -> Optional[Segment]: segment_audio_url = _build_segment_audio_url(voice_session_id, segment_index) stmt = ( select(Segment) .where( Segment.conversation_id == conversation_id, Segment.audio_url == segment_audio_url, ) .order_by(Segment.created_at.desc()) ) result = await db.execute(stmt) candidates = result.scalars().all() for item in candidates: if ( item.conversation_id == conversation_id and item.audio_url == segment_audio_url ): return item return None async def _get_persisted_contiguous_segment_index( db: AsyncSession, conversation_id: str, voice_session_id: str, ) -> int: """读取数据库中当前 voice session 已连续落库的最大 segment_index,用于重连恢复。""" stmt = select(Segment).where(Segment.conversation_id == conversation_id) result = await db.execute(stmt) candidates = result.scalars().all() persisted_indices: Set[int] = set() for item in candidates: if item.conversation_id != conversation_id: continue segment_scope = _extract_segment_scope(item.audio_url) if not segment_scope: continue item_voice_session_id, item_index = segment_scope if item_voice_session_id != voice_session_id: continue persisted_indices.add(item_index) contiguous_index = -1 while contiguous_index + 1 in persisted_indices: contiguous_index += 1 return contiguous_index # ── 过渡反馈 ──────────────────────────────────────────────────── LISTENING_FEEDBACK_DELAY_SEC = 5.0 LISTENING_FEEDBACK_TEXT = "我在认真听,你继续说,我会边听边整理重点。" async def _send_segment_transition_feedback( conversation_id: str, segment_index: int, ) -> None: """发送一次「我在认真听」陪伴式过渡反馈(由延迟任务调用)。""" await manager.send_message( conversation_id, { "type": MessageType.AGENT_RESPONSE, "conversation_id": conversation_id, "data": { "text": LISTENING_FEEDBACK_TEXT, "transition": True, "segment_index": segment_index, }, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) async def _delayed_listening_feedback( conversation_id: str, voice_session_id: str, ) -> None: """录音开始后延迟 5 秒发送一次「我在认真听」,本会话内只发一次;若用户已结束录音则不再发送。""" await asyncio.sleep(LISTENING_FEEDBACK_DELAY_SEC) state = get_or_create_segment_state(conversation_id, voice_session_id) async with state.lock: if state.listening_feedback_sent: return state.listening_feedback_sent = True state.listening_feedback_task = None await _send_segment_transition_feedback(conversation_id, 0) # ── 分段语音异步处理 ──────────────────────────────────────────── async def process_audio_segment( conversation_id: str, user_id: str, voice_session_id: str, segment_index: int, audio_base64: str, audio_duration: int, is_last: bool, *, tts_this_turn: bool = False, ) -> None: """分段语音的异步处理:并行 ASR + 幂等落库 + 有序聚合触发 Agent。""" state = get_or_create_segment_state(conversation_id, voice_session_id) async with state.lock: state.tts_this_turn = bool(tts_this_turn) logger.info( "process_audio_segment 开始: conversation_id={} voice_session_id={} " "segment_index={} is_last={} duration_s={} audio_b64_len={}", conversation_id, voice_session_id, segment_index, is_last, audio_duration, len(audio_base64 or ""), ) try: async with AsyncSessionLocal() as db: conversation = await db.get(Conversation, conversation_id) user = await db.get(User, user_id) if not conversation or conversation.deleted_at is not None: await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": {"message": "对话不存在,分段处理已取消"}, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) return if not user: await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": {"message": "用户不存在,分段处理已取消"}, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) return async with state.lock: should_prime_state = ( state.consumed_index < 0 and not state.processed_indices and not state.buffered_transcripts ) if should_prime_state: persisted_contiguous_index = ( await _get_persisted_contiguous_segment_index( db=db, conversation_id=conversation_id, voice_session_id=voice_session_id, ) ) if persisted_contiguous_index >= 0: async with state.lock: state.consumed_index = max( state.consumed_index, persisted_contiguous_index ) try: audio_bytes = base64.b64decode(audio_base64) except Exception: audio_bytes = b"" if not audio_bytes: logger.warning( "process_audio_segment: 解码后音频为空 conversation_id={} segment_index={}", conversation_id, segment_index, ) try: asr = get_asr_provider() transcript_text = await asr.transcribe(audio_bytes, format="m4a") if transcript_text: log_asr_transcript_result( logger, text=transcript_text, conversation_id=conversation_id, voice_session_id=voice_session_id, segment_index=segment_index, duration_s=audio_duration, audio_len=len(audio_bytes), source="audio_segment", ) except ASRTranscriptionError as e: logger.warning( "ASR 转写失败 segment_index={} conversation_id={}: {}", segment_index, conversation_id, e, ) transcript_text = "" await manager.send_message( conversation_id, { "type": MessageType.TRANSCRIPT, "conversation_id": conversation_id, "data": { "text": transcript_text or "", "audio_duration": audio_duration, "voice_session_id": voice_session_id, "segment_index": segment_index, "is_last": is_last, }, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) if _is_transcribe_failure(transcript_text): detail = (transcript_text or "").strip() if not detail: user_msg = f"分段 {segment_index} 未识别到语音内容,请重试或检查麦克风与网络" else: user_msg = f"分段 {segment_index} 语音识别失败,请稍后再试" await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": { "message": user_msg, "segment_index": segment_index, }, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) return existing_segment = await _find_existing_segment_by_index( db=db, conversation_id=conversation_id, voice_session_id=voice_session_id, segment_index=segment_index, ) if existing_segment: async with state.lock: state.processed_indices.add(segment_index) logger.debug( "分段已存在,按幂等跳过: conversation_id={} voice_session_id={} " "segment_index={} segment_id={} transcript={}", conversation_id, voice_session_id, segment_index, existing_segment.id, existing_segment.user_input_text or "", ) return else: segment = Segment( id=str(uuid.uuid4()), conversation_id=conversation_id, user_input_text=transcript_text or "", audio_url=_build_segment_audio_url(voice_session_id, segment_index), audio_duration_seconds=audio_duration if audio_duration > 0 else None, processed=False, ) await persist_voice_segment_row(db, segment, conversation) user_message_timestamp = conversation.last_message_at await db.refresh(segment) ready_segments: List[Tuple[int, str, Segment]] = [] tts_flag_this_voice_session = False async with state.lock: state.processed_indices.add(segment_index) state.buffered_transcripts[segment_index] = ( transcript_text or "", segment, ) next_index = state.consumed_index + 1 while next_index in state.buffered_transcripts: text, seg = state.buffered_transcripts.pop(next_index) ready_segments.append((next_index, text, seg)) state.consumed_index = next_index next_index += 1 tts_flag_this_voice_session = bool(state.tts_this_turn) for _, ordered_text, ordered_segment in ready_segments: await process_user_message( conversation_id=conversation_id, user_message=ordered_text, conversation=conversation, segment=ordered_segment, db=db, user=user, user_message_timestamp=ordered_segment.created_at or user_message_timestamp, tts_this_turn=tts_flag_this_voice_session, ) except Exception as e: logger.error( f"处理语音分段失败: conversation_id={conversation_id}, segment_index={segment_index}, error={e}", exc_info=True, ) await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": { "message": "语音分段处理遇到问题,请重试", "segment_index": segment_index, }, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) finally: async with state.lock: state.pending_indices.discard(segment_index) # ── 用户消息处理 ──────────────────────────────────────────────── async def process_persisted_user_segment_response( *, conversation_id: str, user_id: str, segment_id: str, tts_this_turn: bool = False, ) -> None: """后台继续生成已落库用户段落的助手回复;即使 WS 页面退出也要完成落库。""" lock = _get_user_response_lock(conversation_id) async with lock: async with AsyncSessionLocal() as db: conversation = await db.get(Conversation, conversation_id) user = await db.get(User, user_id) segment = await db.get(Segment, segment_id) if ( not conversation or conversation.deleted_at is not None or conversation.user_id != user_id or not user or not segment or segment.conversation_id != conversation_id ): logger.warning( "跳过用户回复后台任务: conversation_id={} segment_id={} user_id={}", conversation_id, segment_id, user_id, ) return await process_user_message( conversation_id=conversation_id, user_message=segment.user_input_text or "", conversation=conversation, segment=segment, db=db, user=user, user_message_timestamp=segment.created_at or conversation.last_message_at, tts_this_turn=tts_this_turn, ) async def process_user_message( conversation_id: str, user_message: str, conversation: Conversation, segment: Segment, db: AsyncSession, user: User = None, user_message_timestamp: Optional[datetime] = None, *, force_skip_tts: bool = False, tts_this_turn: Optional[bool] = None, memoir_trigger: MemoirTrigger = "turn", schedule_memoir: bool = True, ) -> None: """处理用户消息,生成 Agent 回应。由 ChatOrchestrator 路由到 ProfileAgent 或 InterviewAgent。""" with business_span("conversation.ws.process_turn"): await _process_user_message_inner( conversation_id, user_message, conversation, segment, db, user, user_message_timestamp, force_skip_tts=force_skip_tts, tts_this_turn=tts_this_turn, memoir_trigger=memoir_trigger, schedule_memoir=schedule_memoir, ) async def _process_user_message_inner( conversation_id: str, user_message: str, conversation: Conversation, segment: Segment, db: AsyncSession, user: User = None, user_message_timestamp: Optional[datetime] = None, *, force_skip_tts: bool = False, tts_this_turn: Optional[bool] = None, memoir_trigger: MemoirTrigger = "turn", schedule_memoir: bool = True, ) -> None: store = ConversationHistoryStore(db) tts_urls: list[str] = [] user_language = _resolve_user_language(user) try: logger.info( "process_user_message 开始: conversation_id={} segment_id={} user_chars={}", conversation_id, segment.id, len(user_message or ""), ) is_from_voice = bool(segment.audio_url) voice_session_id = _voice_session_id_from_audio_url(segment.audio_url) audio_dur = getattr(segment, "audio_duration_seconds", None) t_pipeline = time.perf_counter() turn = await chat_turn_service.process_turn( ChatTurnInput( conversation_id=conversation_id, user_message=user_message, is_from_voice=is_from_voice, voice_session_id=voice_session_id, user_message_timestamp=user_message_timestamp, audio_duration_seconds=audio_dur, force_skip_tts=force_skip_tts, ), ChatTurnContext( db=db, user=user, conversation=conversation, apply_extracted_profile_fn=apply_extracted_profile, get_missing_profile_fields_fn=get_missing_profile_fields, get_filled_profile_fields_fn=get_filled_profile_fields, ), ) responses = turn.messages skip_tts = bool(turn.skip_tts) want_voice = bool(tts_this_turn) if tts_this_turn is not None else False want_tts = want_voice and settings.enable_tts and not skip_tts if agent_summary_enabled(): logger.info( "pipeline.process_user_message duration_ms={:.2f} " "conversation_id={} segment_id={} user_msg_len={} " "response_segments={} skip_tts={} want_tts={}", (time.perf_counter() - t_pipeline) * 1000, conversation_id, segment.id, len(user_message or ""), len(turn.messages), turn.skip_tts, want_tts, ) agent_response = AI_RESPONSE_SEGMENT_JOIN.join(responses) turn_ids = await store.record_human_ai_turn_with_segment( conversation_id=conversation_id, user_message=user_message, responses=responses, segment=segment, user_message_timestamp=user_message_timestamp, is_from_voice=is_from_voice, voice_session_id=voice_session_id, audio_duration_seconds=audio_dur, agent_response=agent_response, memory_retrieval_trace=turn.memory_retrieval_trace, ) if not turn_ids: logger.warning( "process_user_message: 无有效助手段落(responses 为空),conversation_id={} segment_id={}", conversation_id, segment.id, ) if conversation_id in manager.active_connections: await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": { "message": "未生成回复,请重试或稍后再试", }, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) owner_id = (user.id if user is not None else None) or conversation.user_id if schedule_memoir: await _schedule_memoir_ingest_for_segment( owner_id, segment, trigger=memoir_trigger, ) return owner_id = (user.id if user is not None else None) or conversation.user_id if schedule_memoir: await _schedule_memoir_ingest_for_segment( owner_id, segment, trigger=memoir_trigger, ) ai_msg_id = turn_ids.assistant_message_id tts_epoch_start = _tts_epoch_value(conversation_id) n = len(responses) # tts_cancelled 仅用于跳过后续 TTS 合成;AGENT_RESPONSE 必须为每段完整下发, # 否则 FE 会停留在 "正在回复…" 或丢失尾段文本。 tts_cancelled = False for i, response_text in enumerate(responses): url_for_segment: Optional[str] = None if want_tts and not tts_cancelled: if _tts_epoch_value(conversation_id) != tts_epoch_start: tts_cancelled = True logger.info( "pipeline.process_user_message segment={}/{} tts_branch=skip_cancelled " "tts_cancelled={} conversation_id={}", i, n, tts_cancelled, conversation_id, ) else: logger.info( "pipeline.process_user_message segment={}/{} tts_branch=synthesize " "tts_cancelled={} conversation_id={}", i, n, tts_cancelled, conversation_id, ) url_for_segment = await _send_tts_audio( conversation_id, response_text, chunk_index=i, chunk_total=n, assistant_message_id=ai_msg_id, tts_epoch_start=tts_epoch_start, language=user_language, ) if url_for_segment: tts_urls.append(url_for_segment) if _tts_epoch_value(conversation_id) != tts_epoch_start: tts_cancelled = True else: logger.info( "pipeline.process_user_message segment={}/{} tts_branch={} " "tts_cancelled={} want_tts={} conversation_id={}", i, n, "skip_cancelled" if tts_cancelled else "skip_no_tts", tts_cancelled, want_tts, conversation_id, ) await manager.send_message( conversation_id, { "type": MessageType.AGENT_RESPONSE, "conversation_id": conversation_id, "data": { "text": response_text, "index": i, "total": n, "assistant_message_id": ai_msg_id, }, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) if i < n - 1: await asyncio.sleep(0.5) if user is not None: try: fresh_memoir = await get_or_create_state(user.id, db) await maybe_send_topic_chips_ws( conversation_id, user=user, memoir_state=fresh_memoir, reason="after_assistant_turn", language=user_language, ) except Exception as chip_err: logger.warning("after-turn topic chips skipped: {}", chip_err) if tts_urls: await store.attach_ai_tts_for_turn( conversation_id, tts_audio_urls=tts_urls, segment=segment, ) except Exception as e: if tts_urls: try: await store.attach_ai_tts_for_turn( conversation_id, tts_audio_urls=tts_urls, segment=segment, ) except Exception as persist_error: logger.warning("补写 TTS 元数据失败: {}", persist_error) logger.exception("处理用户消息失败: {}", e) if conversation_id in manager.active_connections: try: await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": {"message": "生成回应时遇到问题,请稍后再试"}, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) except Exception as send_error: logger.warning("发送错误消息失败: {}", send_error) # ── 对话结束处理 ──────────────────────────────────────────────── async def process_conversation_segments( conversation_id: str, db: AsyncSession, quota_service: "QuotaService" ): """ 对话结束时:把本对话仍待 Phase1 的段落交给回忆录管线。 经 `MemoirIngestScheduler.flush_pending` 将内存防抖 batch 与当前查询到的 `topic_category IS NULL` 段 ID 合并、去重后**单次**提交 `process_memoir_phase1`, 并在 flush 末尾触发待叙事 Phase2 派发;避免会话结束路径与 debounce flush 双发 Phase1。 配额检查通过注入的 `quota_service` 完成,不直接 import quota 内部函数。 """ conversation = await db.get(Conversation, conversation_id) if not conversation or conversation.deleted_at is not None: return stmt = select(Segment).where( Segment.conversation_id == conversation_id, Segment.processed == False, Segment.topic_category.is_(None), ) result = await db.execute(stmt) segments = result.scalars().all() if not segments: await memoir_ingest_scheduler.flush_pending( conversation.user_id, trigger="conversation_end", ) return user = await db.get(User, conversation.user_id) if user: can_submit, _ = await quota_service.check_can_submit_organize( user.id, user.subscription_type ) if not can_submit: logger.info( f"用户 {user.id} 章节配额已用尽,跳过提交整理任务: conversation_id={conversation_id}" ) await memoir_ingest_scheduler.flush_pending( conversation.user_id, trigger="conversation_end", ) return segment_ids = [seg.id for seg in segments] try: await memoir_ingest_scheduler.flush_pending( conversation.user_id, extra_segment_ids=segment_ids, trigger="conversation_end", ) logger.info( "对话结束,合并批内 segment 与 DB 待分类段,单次提交 Phase1: " "conversation_id={} segments={}", conversation_id, len(segment_ids), ) except Exception as e: logger.error("提交 Celery 任务失败: {}", e)