"""Core message-processing pipeline: agent invocation, ASR transcription, and ordered aggregation of voice segments."""

import asyncio
import base64
import io
import time
import traceback
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple

from app.core.logging import get_logger

if TYPE_CHECKING:
    from app.features.quota.service import QuotaService

from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents.chat import ChatOrchestrator
from app.core.agent_logging import agent_summary_enabled
from app.core.config import settings
from app.core.cos_url_keys import TTS_PRESIGNED_EXPIRES_SEC
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_asr_provider, get_object_storage, get_tts_provider
from app.features.conversation.history_store import (
    AI_RESPONSE_SEGMENT_JOIN,
    ConversationHistoryStore,
)
from app.features.conversation.models import Conversation, Segment
from app.features.conversation.ws.connection_manager import manager
from app.features.conversation.ws.message_types import MessageType
from app.features.conversation.ws.profile_collector import (
    apply_extracted_profile,
    get_filled_profile_fields,
    get_missing_profile_fields,
)
from app.features.memoir.background_runner import BackgroundTaskRunner
from app.features.user.models import User

logger = get_logger(__name__)

# Incremented whenever the client sends `tts_cancel`. The TTS loop in
# process_user_message compares the epoch before and after each synthesis
# so the remaining chunks can be short-circuited.
_tts_cancel_epoch: dict[str, int] = {}


def bump_tts_cancel_epoch(conversation_id: str) -> None:
    _tts_cancel_epoch[conversation_id] = _tts_cancel_epoch.get(conversation_id, 0) + 1


def _tts_epoch_value(conversation_id: str) -> int:
    return _tts_cancel_epoch.get(conversation_id, 0)


def _tts_object_ext(codec: str) -> str:
    """Map the configured TTS codec to an object-storage file extension."""
    c = (codec or "mp3").lower().lstrip(".")
    if c == "wave":
        return "wav"
    return c or "mp3"


def _tts_codec_to_content_type(codec: str) -> str:
    """Map the configured TTS codec to an HTTP Content-Type."""
    c = (codec or "mp3").lower().lstrip(".")
    if c == "mp3":
        return "audio/mpeg"
    if c in ("wav", "wave"):
        return "audio/wav"
    return "application/octet-stream"


async def _send_tts_audio(
    conversation_id: str,
    text: str,
    *,
    chunk_index: int,
    chunk_total: int,
    assistant_message_id: str | None,
    tts_epoch_start: int,
) -> str | None:
    """Synthesize TTS audio, upload it to object storage (COS), and push a TTS_AUDIO message.

    Returns the public URL of the uploaded object, or None when TTS is disabled,
    the turn was cancelled (the epoch changed), or any step failed.
    """
    if not settings.enable_tts:
        return None
    if _tts_epoch_value(conversation_id) != tts_epoch_start:
        return None
    try:
        tts = get_tts_provider()
        audio_bytes = await tts.synthesize(text)
        if not audio_bytes:
            logger.warning(
                "TTS skipped: synthesize returned empty. Check TTS config in .env"
            )
            return None
        # The client may have cancelled while synthesis was in flight.
        if _tts_epoch_value(conversation_id) != tts_epoch_start:
            return None
        ext = _tts_object_ext(settings.tts_codec)
        content_type = _tts_codec_to_content_type(settings.tts_codec)
        storage = get_object_storage()
        key = f"conversations/{conversation_id}/tts/{uuid.uuid4().hex}.{ext}"
        public_url = storage.upload(key, audio_bytes, content_type)
        # Send a playable presigned URL, consistent with
        # `tts_delivery.apply_presigned_tts_urls_to_messages` and memoir image presigning.
        playback_url = storage.get_url(key, expires=TTS_PRESIGNED_EXPIRES_SEC)
        payload_data: Dict[str, Any] = {
            "audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
            "format": settings.tts_codec,
            "audio_url": playback_url,
            "index": chunk_index,
            "total": chunk_total,
        }
        if assistant_message_id:
            payload_data["assistant_message_id"] = assistant_message_id
        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.TTS_AUDIO,
                "conversation_id": conversation_id,
                "data": payload_data,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
        )
        return public_url
    except Exception as e:
        err_str = str(e)
        if "PkgExhausted" in err_str:
            logger.warning(
                "TTS skipped: the Tencent Cloud TTS resource package is exhausted; "
                "buy a package or enable pay-as-you-go billing in the console: {}",
                err_str[:100],
            )
        else:
            logger.error("TTS synthesize/upload/send failed: {}", e)
        return None


# ── Agent instances (moved out of ConnectionManager) ────────────

chat_orchestrator = ChatOrchestrator()
background_runner = BackgroundTaskRunner()
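

# ── Illustrative sketch: epoch-based cancellation ───────────────
# A minimal, self-contained sketch of the pattern behind bump_tts_cancel_epoch
# and _send_tts_audio, assuming a single event loop. A cancel request only
# increments a per-key counter; each worker snapshots the counter when it
# starts and re-checks it before and after every await, abandoning the rest of
# its work once the value has moved. `_EpochGate` and `_sketch_speak_chunks`
# are hypothetical names for illustration only; the pipeline never calls them.
# One likely advantage over task.cancel() is that an in-flight send is never
# interrupted half-way, and the caller can still persist whatever audio was
# already produced (as process_user_message does with `tts_urls`).


class _EpochGate:
    """Hypothetical per-key cancel counter (illustration only)."""

    def __init__(self) -> None:
        self._epochs: Dict[str, int] = {}

    def bump(self, key: str) -> None:
        self._epochs[key] = self._epochs.get(key, 0) + 1

    def value(self, key: str) -> int:
        return self._epochs.get(key, 0)


async def _sketch_speak_chunks(
    gate: _EpochGate, key: str, chunks: List[str], delay_s: float = 0.01
) -> List[str]:
    """Hypothetical worker that 'speaks' chunks until the epoch changes (illustration only)."""
    start = gate.value(key)  # snapshot taken when the turn starts
    spoken: List[str] = []
    for chunk in chunks:
        if gate.value(key) != start:  # cancelled before this chunk started
            break
        await asyncio.sleep(delay_s)  # stands in for TTS synthesis
        if gate.value(key) != start:  # cancelled while "synthesizing": discard
            break
        spoken.append(chunk)
    return spoken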


# ── Segment stream state ────────────────────────────────────────


@dataclass
class SegmentStreamState:
    """Per-session segment-processing state (parallel ASR + ordered aggregation)."""

    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    pending_indices: Set[int] = field(default_factory=set)
    processed_indices: Set[int] = field(default_factory=set)
    # segment_index -> (transcript, persisted Segment), waiting for earlier indices
    buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict)
    # Highest index already released to the agent in order (-1 = none yet)
    consumed_index: int = -1
    active_tasks: Set[asyncio.Task] = field(default_factory=set)
    listening_feedback_sent: bool = False
    listening_feedback_task: Optional[asyncio.Task] = None


# Keyed by (conversation_id, voice_session_id)
_segment_states: Dict[Tuple[str, str], SegmentStreamState] = {}


def get_or_create_segment_state(
    conversation_id: str,
    voice_session_id: str,
) -> SegmentStreamState:
    state_key = (conversation_id, voice_session_id)
    if state_key not in _segment_states:
        _segment_states[state_key] = SegmentStreamState()
    return _segment_states[state_key]


def register_segment_task(
    conversation_id: str,
    voice_session_id: str,
    task: asyncio.Task,
) -> None:
    state_key = (conversation_id, voice_session_id)
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    state.active_tasks.add(task)

    def _cleanup(done_task: asyncio.Task) -> None:
        state.active_tasks.discard(done_task)
        # Drop the state once the last task finishes after the client has disconnected.
        if not state.active_tasks and conversation_id not in manager.active_connections:
            _segment_states.pop(state_key, None)
        if done_task.cancelled():
            return
        exc = done_task.exception()
        if exc:
            # No exception is "active" inside a done callback, so exc_info=True
            # would not attach a traceback; format it from the exception itself.
            logger.error(
                "Segment task failed (conversation_id={}, voice_session_id={}):\n{}",
                conversation_id,
                voice_session_id,
                "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
            )

    task.add_done_callback(_cleanup)


def cleanup_segment_states(conversation_id: str) -> None:
    """After a disconnect, drop this conversation's segment states that have no active tasks."""
    stale_keys = [
        key
        for key, state in _segment_states.items()
        if key[0] == conversation_id and not state.active_tasks
    ]
    for key in stale_keys:
        _segment_states.pop(key, None)


# ── Helper functions ────────────────────────────────────────────


def _utc_now() -> datetime:
    return datetime.now(timezone.utc)


def _mark_conversation_active(
    conversation: Conversation, at: Optional[datetime] = None
) -> datetime:
    activity_time = at or _utc_now()
    conversation.last_message_at = activity_time
    return activity_time


def _voice_session_id_from_client_segment_id(
    client_segment_id: Optional[str],
) -> Optional[str]:
    """Return the part of a client segment ID before its last '-' (the voice session ID), or None."""
    if not client_segment_id:
        return None
    session_id, separator, _ = client_segment_id.rpartition("-")
    if separator and session_id:
        return session_id
    return None


def _build_segment_audio_url(voice_session_id: str, segment_index: int) -> str:
    """Build the idempotency key of a voice segment, stored in `Segment.audio_url`.

    Together with conversation_id it identifies (voice_session_id, segment_index).
    """
    return f"audio-segment:{voice_session_id}:{segment_index}"


def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int]]:
    """Parse voice_session_id and segment_index from `audio-segment:{session_id}:{index}`."""
    prefix = "audio-segment:"
    if not audio_url or not audio_url.startswith(prefix):
        return None
    payload = audio_url[len(prefix) :]
    voice_session_id_raw, separator, segment_index_raw = payload.rpartition(":")
    if not separator:
        return None
    sid = voice_session_id_raw.strip()
    if not sid:
        return None
    try:
        return (sid, int(segment_index_raw))
    except ValueError:
        return None


def _voice_session_id_from_audio_url(audio_url: Optional[str]) -> Optional[str]:
    scope = _extract_segment_scope(audio_url)
    return scope[0] if scope else None


def _is_transcribe_failure(transcript_text: Optional[str]) -> bool:
    # Empty text, or text prefixed with "转写失败" ("transcription failed"), marks an ASR failure.
    if not transcript_text:
        return True
    return transcript_text.startswith("转写失败")


async def _find_existing_segment_by_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
    segment_index: int,
) -> Optional[Segment]:
    """Return the most recent Segment already persisted for this (voice session, index), if any."""
    segment_audio_url = _build_segment_audio_url(voice_session_id, segment_index)
    stmt = (
        select(Segment)
        .where(
            Segment.conversation_id == conversation_id,
            Segment.audio_url == segment_audio_url,
        )
        .order_by(Segment.created_at.desc())
    )
    result = await db.execute(stmt)
    candidates = result.scalars().all()
    for item in candidates:
        if (
            item.conversation_id == conversation_id
            and item.audio_url == segment_audio_url
        ):
            return item
    return None


async def _get_persisted_contiguous_segment_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
) -> int:
    """Return the highest segment_index of this voice session persisted contiguously from 0.

    Used to restore ordering state after a reconnect; returns -1 if index 0 is missing.
    """
    stmt = select(Segment).where(Segment.conversation_id == conversation_id)
    result = await db.execute(stmt)
    candidates = result.scalars().all()
    persisted_indices: Set[int] = set()
    for item in candidates:
        if item.conversation_id != conversation_id:
            continue
        segment_scope = _extract_segment_scope(item.audio_url)
        if not segment_scope:
            continue
        item_voice_session_id, item_index = segment_scope
        if item_voice_session_id != voice_session_id:
            continue
        persisted_indices.add(item_index)
    contiguous_index = -1
    while contiguous_index + 1 in persisted_indices:
        contiguous_index += 1
    return contiguous_index


# ── Transition feedback ─────────────────────────────────────────

LISTENING_FEEDBACK_DELAY_SEC = 5.0
# Client-facing text (Chinese): "I'm listening carefully; keep talking, and I'll
# organize the key points as I listen."
LISTENING_FEEDBACK_TEXT = "我在认真听,你继续说,我会边听边整理重点。"


async def _send_segment_transition_feedback(
    conversation_id: str,
    segment_index: int,
) -> None:
    """Send a single companion-style "I'm listening" transition message (called by the delayed task)."""
    await manager.send_message(
        conversation_id,
        {
            "type": MessageType.AGENT_RESPONSE,
            "conversation_id": conversation_id,
            "data": {
                "text": LISTENING_FEEDBACK_TEXT,
                "transition": True,
                "segment_index": segment_index,
            },
            "timestamp": datetime.now(timezone.utc).isoformat(),
        },
    )


async def _delayed_listening_feedback(
    conversation_id: str,
    voice_session_id: str,
) -> None:
    """Send "I'm listening" once, LISTENING_FEEDBACK_DELAY_SEC (5 s) after recording starts.

    Sent at most once per voice session; nothing is sent if the user has already
    finished recording.
    """
    await asyncio.sleep(LISTENING_FEEDBACK_DELAY_SEC)
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    async with state.lock:
        if state.listening_feedback_sent:
            return
        state.listening_feedback_sent = True
        state.listening_feedback_task = None
    await _send_segment_transition_feedback(conversation_id, 0)


# ── Long-audio chunked transcription ────────────────────────────

# Maximum audio length sent to ASR in a single request (55 s).
MAX_ASR_CHUNK_MS = 55_000
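

# ── Illustrative sketch: why 55 s chunks fit the ASR size limit ─
# A minimal sketch, assuming chunks are PCM WAV exactly as _split_audio_bytes
# exports them (16 kHz, mono, 2-byte samples) with a canonical 44-byte header,
# and reading "3 MB" as 3 * 1024 * 1024 bytes. Uncompressed size is
#     bytes = sample_rate * channels * sample_width * seconds + header
# so a full 55 s chunk is 16000 * 1 * 2 * 55 + 44 = 1,760,044 bytes (about
# 1.7 MiB), well under the limit. `_sketch_chunk_bounds` reproduces the windows
# of the `range(0, duration_ms, MAX_ASR_CHUNK_MS)` loop; the last window is
# shorter when the duration is not a multiple of 55 s. Both helpers and the
# header constant are hypothetical, for illustration only.

_SKETCH_WAV_HEADER_BYTES = 44  # assumed canonical PCM WAV header size


def _sketch_wav_chunk_bytes(
    duration_ms: int, sample_rate: int = 16000, channels: int = 1, sample_width: int = 2
) -> int:
    """Size in bytes of an uncompressed PCM WAV chunk (illustration only)."""
    frames = sample_rate * duration_ms // 1000
    return frames * channels * sample_width + _SKETCH_WAV_HEADER_BYTES


def _sketch_chunk_bounds(duration_ms: int, max_chunk_ms: int = 55_000) -> List[Tuple[int, int]]:
    """(start_ms, end_ms) windows for a given duration; 55_000 mirrors MAX_ASR_CHUNK_MS."""
    return [
        (start, min(start + max_chunk_ms, duration_ms))
        for start in range(0, duration_ms, max_chunk_ms)
    ]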


def _split_audio_bytes(audio_bytes: bytes, fmt: str) -> list[bytes]:
    """Split long audio into chunks of at most 55 s with pydub.

    Each chunk is exported as 16 kHz mono 16-bit WAV, which stays within Tencent
    ASR's 3 MB request limit. Audio no longer than MAX_ASR_CHUNK_MS is returned
    unchanged as a single chunk.
    """
    from pydub import AudioSegment as PydubSegment

    audio = PydubSegment.from_file(io.BytesIO(audio_bytes), format=fmt)
    duration_ms = len(audio)
    if duration_ms <= MAX_ASR_CHUNK_MS:
        return [audio_bytes]
    mono_16k = audio.set_frame_rate(16000).set_channels(1).set_sample_width(2)
    chunks: list[bytes] = []
    for start in range(0, duration_ms, MAX_ASR_CHUNK_MS):
        chunk = mono_16k[start : start + MAX_ASR_CHUNK_MS]
        buf = io.BytesIO()
        chunk.export(buf, format="wav")
        chunks.append(buf.getvalue())
    return chunks


async def _transcribe_long_audio(audio_bytes: bytes, fmt: str = "m4a") -> str:
    """Transcribe audio; anything longer than 55 s is split and the chunks are transcribed in parallel."""
    asr = get_asr_provider()
    try:
        chunks = await asyncio.to_thread(_split_audio_bytes, audio_bytes, fmt)
    except Exception as exc:
        logger.warning("pydub split failed ({}), falling back to direct transcription", exc)
        return await asr.transcribe(audio_bytes, format=fmt)
    if len(chunks) <= 1:
        return await asr.transcribe(audio_bytes, format=fmt)
    logger.info("Long audio split into {} chunks", len(chunks))
    results = await asyncio.gather(
        *[asr.transcribe(c, format="wav") for c in chunks],
        return_exceptions=True,
    )
    texts: list[str] = []
    for i, r in enumerate(results):
        if isinstance(r, BaseException):
            logger.warning("Transcription of chunk {} raised: {}", i, r)
            continue
        if r and not _is_transcribe_failure(r):
            texts.append(r)
    return "".join(texts)


# ── Asynchronous voice-segment processing ───────────────────────


async def process_audio_segment(
    conversation_id: str,
    user_id: str,
    voice_session_id: str,
    segment_index: int,
    audio_base64: str,
    audio_duration: int,
    is_last: bool,
) -> None:
    """Process one voice segment: parallel ASR, idempotent persistence, ordered hand-off to the agent.

    Segments may finish ASR out of order. Each transcript is buffered by index,
    and only the contiguous run following `consumed_index` is released, in index
    order, to `process_user_message`.
    """
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    logger.info(
        "process_audio_segment start: conversation_id={} voice_session_id={} "
        "segment_index={} is_last={} duration_s={} audio_b64_len={}",
        conversation_id,
        voice_session_id,
        segment_index,
        is_last,
        audio_duration,
        len(audio_base64 or ""),
    )
    try:
        async with AsyncSessionLocal() as db:
            conversation = await db.get(Conversation, conversation_id)
            user = await db.get(User, user_id)
            if not conversation or conversation.deleted_at is not None:
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        # "Conversation does not exist; segment processing cancelled"
                        "data": {"message": "对话不存在,分段处理已取消"},
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
                return
            if not user:
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        # "User does not exist; segment processing cancelled"
                        "data": {"message": "用户不存在,分段处理已取消"},
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
                return

            # After a reconnect the in-memory state is empty: prime consumed_index
            # from the segments already persisted for this voice session.
            async with state.lock:
                should_prime_state = (
                    state.consumed_index < 0
                    and not state.processed_indices
                    and not state.buffered_transcripts
                )
            if should_prime_state:
                persisted_contiguous_index = await _get_persisted_contiguous_segment_index(
                    db=db,
                    conversation_id=conversation_id,
                    voice_session_id=voice_session_id,
                )
                if persisted_contiguous_index >= 0:
                    async with state.lock:
                        state.consumed_index = max(
                            state.consumed_index, persisted_contiguous_index
                        )

            try:
                audio_bytes = base64.b64decode(audio_base64)
            except Exception:
                audio_bytes = b""
            if not audio_bytes:
                logger.warning(
                    "process_audio_segment: decoded audio is empty, "
                    "conversation_id={} segment_index={}",
                    conversation_id,
                    segment_index,
                )

            transcript_text = await _transcribe_long_audio(audio_bytes, fmt="m4a")
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.TRANSCRIPT,
                    "conversation_id": conversation_id,
                    "data": {
                        "text": transcript_text or "",
                        "audio_duration": audio_duration,
                        "voice_session_id": voice_session_id,
                        "segment_index": segment_index,
                        "is_last": is_last,
                    },
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )

            if _is_transcribe_failure(transcript_text):
                detail = (transcript_text or "").strip()
                # Client-facing text (Chinese): "segment N transcription failed: ...";
                # the empty case adds "no speech recognized (check the backend ASR config)".
                if detail.startswith("转写失败"):
                    user_msg = f"分段 {segment_index} {detail}"
                elif not detail:
                    user_msg = f"分段 {segment_index} 转写失败:未识别到内容(请检查后端 ASR 配置)"
                else:
                    user_msg = f"分段 {segment_index} 转写失败:{detail[:400]}"
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {
                            "message": user_msg,
                            "segment_index": segment_index,
                        },
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
                return

            existing_segment = await _find_existing_segment_by_index(
                db=db,
                conversation_id=conversation_id,
                voice_session_id=voice_session_id,
                segment_index=segment_index,
            )
            if existing_segment:
                async with state.lock:
                    state.processed_indices.add(segment_index)
                logger.debug(
                    "Segment already exists, skipping (idempotent): conversation_id={} "
                    "voice_session_id={} segment_index={} segment_id={} transcript={}",
                    conversation_id,
                    voice_session_id,
                    segment_index,
                    existing_segment.id,
                    existing_segment.user_input_text or "",
                )
                return

            segment = Segment(
                id=str(uuid.uuid4()),
                conversation_id=conversation_id,
                user_input_text=transcript_text or "",
                audio_url=_build_segment_audio_url(voice_session_id, segment_index),
                audio_duration_seconds=audio_duration if audio_duration > 0 else None,
                processed=False,
            )
            db.add(segment)
            user_message_timestamp = _mark_conversation_active(conversation)
            await db.commit()
            await db.refresh(segment)
            await background_runner.queue_message(
                conversation.user_id,
                segment.id,
                text_char_count=len((transcript_text or "").strip()),
            )

            # Buffer this transcript, then release the contiguous run that starts
            # right after consumed_index, in index order.
            ready_segments: List[Tuple[int, str, Segment]] = []
            async with state.lock:
                state.processed_indices.add(segment_index)
                state.buffered_transcripts[segment_index] = (
                    transcript_text or "",
                    segment,
                )
                next_index = state.consumed_index + 1
                while next_index in state.buffered_transcripts:
                    text, seg = state.buffered_transcripts.pop(next_index)
                    ready_segments.append((next_index, text, seg))
                    state.consumed_index = next_index
                    next_index += 1

            for _, ordered_text, ordered_segment in ready_segments:
                await process_user_message(
                    conversation_id=conversation_id,
                    user_message=ordered_text,
                    conversation=conversation,
                    segment=ordered_segment,
                    db=db,
                    user=user,
                    user_message_timestamp=ordered_segment.created_at
                    or user_message_timestamp,
                )
    except Exception as e:
        # logger.exception attaches the traceback; placeholders (not an f-string)
        # keep braces in the error text from breaking message formatting.
        logger.exception(
            "Failed to process voice segment: conversation_id={}, segment_index={}, error={}",
            conversation_id,
            segment_index,
            e,
        )
        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.ERROR,
                "data": {
                    # "Segment processing failed: ..."
                    "message": f"分段处理失败: {str(e)}",
                    "segment_index": segment_index,
                },
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
        )
    finally:
        async with state.lock:
            state.pending_indices.discard(segment_index)
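

# ── Illustrative sketch: in-order release of out-of-order segments ─
# A minimal, synchronous sketch of the reorder buffer that process_audio_segment
# keeps in SegmentStreamState, with the database, ASR, and agent calls left out.
# Finished transcripts are parked by index, and each arrival releases the
# contiguous run after `consumed`, in index order. `prime()` mirrors the
# reconnect path: it starts from the highest index persisted contiguously from 0,
# as _get_persisted_contiguous_segment_index computes. `_SketchReorderBuffer` is
# a hypothetical name; in the real code the same steps run under `state.lock`.


class _SketchReorderBuffer:
    """Hypothetical reorder buffer (illustration only, not used by the pipeline)."""

    def __init__(self) -> None:
        self.consumed = -1
        self.buffered: Dict[int, str] = {}

    def prime(self, persisted_indices: Set[int]) -> None:
        """Skip past indices 0..k that are already persisted contiguously."""
        contiguous = -1
        while contiguous + 1 in persisted_indices:
            contiguous += 1
        self.consumed = max(self.consumed, contiguous)

    def push(self, index: int, text: str) -> List[Tuple[int, str]]:
        """Buffer one transcript and return every segment that is now ready, in order."""
        self.buffered[index] = text
        ready: List[Tuple[int, str]] = []
        next_index = self.consumed + 1
        while next_index in self.buffered:
            ready.append((next_index, self.buffered.pop(next_index)))
            self.consumed = next_index
            next_index += 1
        return ready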


# ── User message handling ───────────────────────────────────────


async def process_user_message(
    conversation_id: str,
    user_message: str,
    conversation: Conversation,
    segment: Segment,
    db: AsyncSession,
    user: Optional[User] = None,
    user_message_timestamp: Optional[datetime] = None,
) -> None:
    """Generate the agent reply for one user message.

    ChatOrchestrator routes the message to ProfileAgent or InterviewAgent. The
    reply is recorded in the conversation history, streamed chunk by chunk as
    AGENT_RESPONSE messages and, unless the turn sets `skip_tts`, synthesized with
    TTS. A client `tts_cancel` bumps the TTS epoch and ends the loop early, so the
    remaining chunks are neither sent nor synthesized.
    """
    store = ConversationHistoryStore(db)
    tts_urls: list[str] = []

    async def _persist_tts_urls() -> None:
        """Attach the uploaded TTS URLs to the AI history entry and to the Segment row."""
        await store.attach_ai_tts_audio_urls(
            conversation_id,
            tts_audio_urls=tts_urls,
            segment_id=segment.id,
        )
        await db.execute(
            update(Segment)
            .where(Segment.id == segment.id)
            .values(tts_audio_urls=tts_urls)
        )
        await db.commit()

    try:
        logger.info(
            "process_user_message start: conversation_id={} segment_id={} user_chars={}",
            conversation_id,
            segment.id,
            len(user_message or ""),
        )
        is_from_voice = bool(segment.audio_url)
        voice_session_id = _voice_session_id_from_audio_url(segment.audio_url)
        audio_dur = getattr(segment, "audio_duration_seconds", None)
        t_pipeline = time.perf_counter()
        turn = await chat_orchestrator.process_user_message(
            conversation_id=conversation_id,
            user_message=user_message,
            user=user,
            conversation=conversation,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            db=db,
            apply_extracted_profile_fn=apply_extracted_profile,
            get_missing_profile_fields_fn=get_missing_profile_fields,
            get_filled_profile_fields_fn=get_filled_profile_fields,
            user_message_timestamp=user_message_timestamp,
            audio_duration_seconds=audio_dur,
        )
        if agent_summary_enabled():
            logger.info(
                "pipeline.process_user_message duration_ms={:.2f} "
                "conversation_id={} segment_id={} user_msg_len={} "
                "response_segments={} skip_tts={}",
                (time.perf_counter() - t_pipeline) * 1000,
                conversation_id,
                segment.id,
                len(user_message or ""),
                len(turn.messages),
                turn.skip_tts,
            )
        responses = turn.messages
        skip_tts = turn.skip_tts
        segment.agent_response = AI_RESPONSE_SEGMENT_JOIN.join(responses)
        _mark_conversation_active(conversation)
        ai_msg_id = await store.record_human_ai_turn(
            conversation_id=conversation_id,
            user_message=user_message,
            responses=responses,
            user_message_timestamp=user_message_timestamp,
            is_from_voice=is_from_voice,
            voice_session_id=voice_session_id,
            audio_duration_seconds=audio_dur,
            tts_audio_urls=None,
            segment_id=segment.id,
        )
        if not ai_msg_id:
            logger.warning(
                "process_user_message: no usable assistant segments (responses empty), "
                "conversation_id={} segment_id={}",
                conversation_id,
                segment.id,
            )
            if conversation_id in manager.active_connections:
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        "data": {
                            # "No reply was generated; please retry or try again later"
                            "message": "未生成回复,请重试或稍后再试",
                        },
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
            return

        tts_epoch_start = _tts_epoch_value(conversation_id)
        n = len(responses)
        for i, response_text in enumerate(responses):
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.AGENT_RESPONSE,
                    "conversation_id": conversation_id,
                    "data": {
                        "text": response_text,
                        "index": i,
                        "total": n,
                        "assistant_message_id": ai_msg_id,
                    },
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                },
            )
            if not skip_tts:
                # Stop before and after synthesis if the client sent tts_cancel.
                if _tts_epoch_value(conversation_id) != tts_epoch_start:
                    break
                url = await _send_tts_audio(
                    conversation_id,
                    response_text,
                    chunk_index=i,
                    chunk_total=n,
                    assistant_message_id=ai_msg_id,
                    tts_epoch_start=tts_epoch_start,
                )
                if url:
                    tts_urls.append(url)
                if _tts_epoch_value(conversation_id) != tts_epoch_start:
                    break
            if i < n - 1:
                await asyncio.sleep(0.5)  # short pause between reply chunks

        if tts_urls:
            await _persist_tts_urls()
    except Exception as e:
        logger.exception("Failed to process user message: {}", e)
        # Best effort: keep the TTS audio already uploaded before the failure.
        if tts_urls:
            try:
                await _persist_tts_urls()
            except Exception as persist_error:
                logger.warning("Failed to backfill TTS metadata: {}", persist_error)
        if conversation_id in manager.active_connections:
            try:
                await manager.send_message(
                    conversation_id,
                    {
                        "type": MessageType.ERROR,
                        # "Failed to generate a response: ..."
                        "data": {"message": f"生成回应失败: {str(e)}"},
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    },
                )
            except Exception as send_error:
                logger.warning("Failed to send error message: {}", send_error)


# ── Conversation-end handling ───────────────────────────────────


async def process_conversation_segments(
    conversation_id: str, db: AsyncSession, quota_service: "QuotaService"
) -> None:
    """
    At conversation end, hand this conversation's segments still awaiting Phase 1 to the memoir pipeline.

    `BackgroundTaskRunner.flush_pending` merges the in-memory debounce batch with
    the segment IDs queried here (`topic_category IS NULL`), de-duplicates them,
    and submits `process_memoir_phase1` **once**; at the end of the flush it also
    dispatches pending Phase 2 narrative work. This prevents the conversation-end
    path and the debounce flush from both submitting Phase 1.

    Quota checks go through the injected `quota_service` instead of importing quota internals.
    """
    conversation = await db.get(Conversation, conversation_id)
    if not conversation or conversation.deleted_at is not None:
        return

    stmt = select(Segment).where(
        Segment.conversation_id == conversation_id,
        Segment.processed == False,  # noqa: E712
        Segment.topic_category.is_(None),
    )
    result = await db.execute(stmt)
    segments = result.scalars().all()
    if not segments:
        await background_runner.flush_pending(conversation.user_id)
        return

    user = await db.get(User, conversation.user_id)
    if user:
        can_submit, _ = await quota_service.check_can_submit_organize(
            user.id, user.subscription_type
        )
        if not can_submit:
            logger.info(
                "User {} has used up the chapter quota; skipping organize task: conversation_id={}",
                user.id,
                conversation_id,
            )
            await background_runner.flush_pending(conversation.user_id)
            return

    segment_ids = [seg.id for seg in segments]
    try:
        await background_runner.flush_pending(
            conversation.user_id, extra_segment_ids=segment_ids
        )
        logger.info(
            "Conversation ended; merged batched segments with unclassified DB segments "
            "and submitted Phase 1 once: conversation_id={} segments={}",
            conversation_id,
            len(segment_ids),
        )
    except Exception as e:
        logger.exception("Failed to submit Celery task: {}", e)
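

# ── Illustrative sketch: merging batched and queried segment IDs ─
# The process_conversation_segments docstring says flush_pending merges the
# in-memory debounce batch with the IDs queried at conversation end and
# de-duplicates them before a single Phase 1 submission. flush_pending lives in
# BackgroundTaskRunner and is not shown here; this is a minimal sketch of one
# way to do such a merge, assuming the batch order is kept and unseen extra IDs
# are appended after it. `_sketch_merge_segment_ids` is a hypothetical helper.


def _sketch_merge_segment_ids(
    batched_ids: List[str], extra_ids: Optional[List[str]] = None
) -> List[str]:
    """Order-preserving union: batched IDs first, then unseen extra IDs (illustration only)."""
    # dicts keep insertion order (Python 3.7+), so fromkeys() drops duplicates
    # while keeping the first occurrence of each ID.
    return list(dict.fromkeys([*batched_ids, *(extra_ids or [])]))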