"""Core message-processing pipeline: agent invocation, ASR transcription and
ordered aggregation of segmented voice input."""

import asyncio
import base64
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

# NOTE: project/third-party imports keep their original relative order
# (app.core.logging first) so import-time side effects are unchanged.
from app.core.logging import get_logger

if TYPE_CHECKING:
    from app.features.quota.service import QuotaService

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.agents import ConversationAgent, MemoryAgent
from app.agents.memoir_processor import BackgroundTaskRunner
from app.agents.prompts.profile_prompts import format_user_profile_context
from app.core.db import AsyncSessionLocal
from app.features.conversation.models import Conversation, Segment
from app.features.conversation.ws.connection_manager import manager
from app.features.conversation.ws.message_types import LEGACY_VOICE_SESSION_ID, MessageType
from app.features.conversation.ws.profile_collector import (
    apply_extracted_profile,
    get_filled_profile_fields,
    get_missing_profile_fields,
)
from app.features.user.models import User
from app.core.config import settings
from app.core.dependencies import get_asr_provider, get_tts_provider
from app.features.memoir.state_service import get_or_create_state

logger = get_logger(__name__)


async def _send_tts_audio(conversation_id: str, text: str) -> None:
    """Synthesize *text* and push it to the client as a TTS_AUDIO message.

    TTS is best-effort: an empty synthesis result or any exception is logged
    and swallowed, so a TTS problem never interrupts the text reply.
    """
    try:
        tts = get_tts_provider()
        audio_bytes = await tts.synthesize(text)
        if not audio_bytes:
            logger.warning("TTS skipped: synthesize returned empty. Check TTS config in .env")
            return
        await manager.send_message(conversation_id, {
            "type": MessageType.TTS_AUDIO,
            "conversation_id": conversation_id,
            "data": {
                "audio_base64": base64.b64encode(audio_bytes).decode("utf-8"),
                "format": settings.tts_codec,
            },
            "timestamp": datetime.now(timezone.utc).isoformat(),
        })
    except Exception as e:
        err_str = str(e)
        if "PkgExhausted" in err_str:
            # Provider resource package used up: an operational condition rather
            # than a code bug, so warn with a truncated message and no traceback.
            logger.warning(
                "TTS skipped: 腾讯云语音合成资源包已用尽,请在控制台购买或开通后付费: %s",
                err_str[:100],
            )
        else:
            logger.error("TTS synthesize failed: %s", e, exc_info=True)


# ── Agent instances (moved out of ConnectionManager) ─────────────────────
conversation_agent = ConversationAgent()
memory_agent = MemoryAgent()
background_runner = BackgroundTaskRunner()


# ── Segment stream state ─────────────────────────────────────────────────
@dataclass
class SegmentStreamState:
    """Per-(conversation, voice session) state for parallel ASR with ordered aggregation."""

    # Guards every field below; only held around short in-memory updates.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Indices currently being processed; discarded when process_audio_segment
    # finishes (presumably added by the caller — not visible here).
    pending_indices: Set[int] = field(default_factory=set)
    # Indices that have been persisted (or found already persisted).
    processed_indices: Set[int] = field(default_factory=set)
    # Persisted segments waiting for every lower index before reaching the agent.
    buffered_transcripts: Dict[int, Tuple[str, Segment]] = field(default_factory=dict)
    # Highest index of the gap-free prefix already handed to the agent; -1 = none.
    consumed_index: int = -1
    # Registered processing tasks; the state is dropped when this empties after disconnect.
    active_tasks: Set[asyncio.Task] = field(default_factory=set)
    # Whether the one-off "I'm listening" feedback has been sent for this session.
    listening_feedback_sent: bool = False
    listening_feedback_task: Optional[asyncio.Task] = None


# Registry keyed by (conversation_id, voice_session_id).
_segment_states: Dict[Tuple[str, str], SegmentStreamState] = {}


def get_or_create_segment_state(
    conversation_id: str,
    voice_session_id: str,
) -> SegmentStreamState:
    """Return the state for (conversation_id, voice_session_id), creating it on first use."""
    state_key = (conversation_id, voice_session_id)
    if state_key not in _segment_states:
        _segment_states[state_key] = SegmentStreamState()
    return _segment_states[state_key]


def register_segment_task(
    conversation_id: str,
    voice_session_id: str,
    task: asyncio.Task,
) -> None:
    """Track *task* in its session state and clean up once it finishes.

    When the last active task of a session completes and the conversation has
    no active connection any more, the session state is removed. Exceptions
    raised by the task are logged with their traceback.
    """
    state_key = (conversation_id, voice_session_id)
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    state.active_tasks.add(task)

    def _cleanup(done_task: asyncio.Task) -> None:
        state.active_tasks.discard(done_task)
        if not state.active_tasks and conversation_id not in manager.active_connections:
            _segment_states.pop(state_key, None)
        if done_task.cancelled():
            return
        exc = done_task.exception()
        if exc:
            # Done callbacks run outside any `except` block, so exc_info=True would
            # log "NoneType: None"; pass the exception itself to keep its traceback.
            logger.error(
                "分段处理任务异常 "
                f"(conversation_id={conversation_id}, voice_session_id={voice_session_id}): {exc}",
                exc_info=exc,
            )

    task.add_done_callback(_cleanup)


def cleanup_segment_states(conversation_id: str) -> None:
    """After a disconnect, drop this conversation's states that have no active tasks."""
    stale_keys = [
        key
        for key, state in _segment_states.items()
        if key[0] == conversation_id and not state.active_tasks
    ]
    for key in stale_keys:
        _segment_states.pop(key, None)


# ── Helpers ──────────────────────────────────────────────────────────────
def _utc_now() -> datetime:
    return datetime.now(timezone.utc)


def _mark_conversation_active(conversation: Conversation, at: Optional[datetime] = None) -> datetime:
    """Set ``conversation.last_message_at`` (default: now, UTC) and return the time used."""
    activity_time = at or _utc_now()
    conversation.last_message_at = activity_time
    return activity_time


def _normalize_voice_session_id(voice_session_id: Optional[str]) -> str:
    """Return *voice_session_id* as a string, or LEGACY_VOICE_SESSION_ID when empty/None."""
    if voice_session_id:
        return str(voice_session_id)
    return LEGACY_VOICE_SESSION_ID


def _voice_session_id_from_client_segment_id(client_segment_id: Optional[str]) -> Optional[str]:
    """Return everything before the last '-' of *client_segment_id*, or None if there is none."""
    if not client_segment_id:
        return None
    session_id, separator, _ = client_segment_id.rpartition("-")
    if separator and session_id:
        return session_id
    return None


def _build_segment_audio_url(voice_session_id: str, segment_index: int) -> str:
    """Build the idempotency key of a voice segment (stored in ``Segment.audio_url``)."""
    return f"audio-segment:{voice_session_id}:{segment_index}"


def _extract_segment_scope(audio_url: Optional[str]) -> Optional[Tuple[str, int]]:
    """Parse ``audio_url`` into ``(voice_session_id, segment_index)``.

    Accepts the current ``audio-segment:{voice_session_id}:{index}`` format and
    the legacy ``audio-segment:{index}`` format (mapped to
    LEGACY_VOICE_SESSION_ID). Returns None for any other value or a
    non-integer index.
    """
    prefix = "audio-segment:"
    if not audio_url or not audio_url.startswith(prefix):
        return None
    payload = audio_url[len(prefix):]
    voice_session_id_raw, separator, segment_index_raw = payload.rpartition(":")
    try:
        if separator:
            return (_normalize_voice_session_id(voice_session_id_raw), int(segment_index_raw))
        return (LEGACY_VOICE_SESSION_ID, int(payload))
    except ValueError:
        return None


def _voice_session_id_from_audio_url(audio_url: Optional[str]) -> Optional[str]:
    """Return the voice session id encoded in a segment ``audio_url``, if any."""
    scope = _extract_segment_scope(audio_url)
    if scope:
        return scope[0]
    return None


def _is_transcribe_failure(transcript_text: Optional[str]) -> bool:
    """True for an empty transcript or one carrying the ASR failure marker prefix."""
    if not transcript_text:
        return True
    return transcript_text.startswith("转写失败")


async def _find_existing_segment_by_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
    segment_index: int,
) -> Optional[Segment]:
    """Return the newest Segment already persisted for this (session, index), or None."""
    segment_audio_url = _build_segment_audio_url(voice_session_id, segment_index)
    stmt = (
        select(Segment)
        .where(
            Segment.conversation_id == conversation_id,
            Segment.audio_url == segment_audio_url,
        )
        .order_by(Segment.created_at.desc())
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.scalars().first()


async def _get_persisted_contiguous_segment_index(
    db: AsyncSession,
    conversation_id: str,
    voice_session_id: str,
) -> int:
    """Return the largest N such that segments 0..N of this voice session are persisted.

    Returns -1 when segment 0 is not persisted. Used to resume ordering after a
    reconnect.
    """
    # Only audio_url is needed, so avoid materialising full Segment rows.
    stmt = select(Segment.audio_url).where(Segment.conversation_id == conversation_id)
    result = await db.execute(stmt)

    persisted_indices: Set[int] = set()
    for audio_url in result.scalars():
        segment_scope = _extract_segment_scope(audio_url)
        if segment_scope and segment_scope[0] == voice_session_id:
            persisted_indices.add(segment_scope[1])

    contiguous_index = -1
    while contiguous_index + 1 in persisted_indices:
        contiguous_index += 1
    return contiguous_index


# ── Transition feedback ──────────────────────────────────────────────────
LISTENING_FEEDBACK_DELAY_SEC = 5.0
LISTENING_FEEDBACK_TEXT = "我在认真听,你继续说,我会边听边整理重点。"


async def _send_segment_transition_feedback(
    conversation_id: str,
    segment_index: int,
) -> None:
    """Send the supportive "I'm listening" transition message (called by the delayed task)."""
    await manager.send_message(conversation_id, {
        "type": MessageType.AGENT_RESPONSE,
        "conversation_id": conversation_id,
        "data": {
            "text": LISTENING_FEEDBACK_TEXT,
            "transition": True,
            "segment_index": segment_index,
        },
        "timestamp": datetime.now(timezone.utc).isoformat(),
    })


async def _delayed_listening_feedback(
    conversation_id: str,
    voice_session_id: str,
) -> None:
    """After LISTENING_FEEDBACK_DELAY_SEC, send the listening feedback at most once per session.

    NOTE(review): not sending once recording has ended presumably relies on the
    caller cancelling ``listening_feedback_task`` — not visible here, confirm.
    """
    await asyncio.sleep(LISTENING_FEEDBACK_DELAY_SEC)
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    async with state.lock:
        if state.listening_feedback_sent:
            return
        state.listening_feedback_sent = True
        state.listening_feedback_task = None
    await _send_segment_transition_feedback(conversation_id, 0)


# ── Asynchronous voice-segment processing ────────────────────────────────
def _drain_ready_segments(state: SegmentStreamState) -> List[Tuple[int, str, Segment]]:
    """Pop buffered segments that extend the consumed prefix, in index order.

    Must be called with ``state.lock`` held; advances ``state.consumed_index``
    past every segment returned.
    """
    ready: List[Tuple[int, str, Segment]] = []
    next_index = state.consumed_index + 1
    while next_index in state.buffered_transcripts:
        text, seg = state.buffered_transcripts.pop(next_index)
        ready.append((next_index, text, seg))
        state.consumed_index = next_index
        next_index += 1
    return ready


async def _prime_consumed_index(
    db: AsyncSession,
    state: SegmentStreamState,
    conversation_id: str,
    voice_session_id: str,
) -> None:
    """On a pristine state (e.g. after reconnect), skip the already-persisted prefix.

    The DB read happens outside the lock, so the update uses ``max`` and never
    moves ``consumed_index`` backwards.
    """
    async with state.lock:
        is_pristine = (
            state.consumed_index < 0
            and not state.processed_indices
            and not state.buffered_transcripts
        )
    if not is_pristine:
        return
    persisted_contiguous_index = await _get_persisted_contiguous_segment_index(
        db=db,
        conversation_id=conversation_id,
        voice_session_id=voice_session_id,
    )
    if persisted_contiguous_index >= 0:
        # NOTE(review): a concurrent task of this same state that has committed its
        # segment but not yet buffered it can be counted here as consumed; that
        # segment would then never reach the agent. Consider tracking in-flight
        # indices to exclude them.
        async with state.lock:
            state.consumed_index = max(state.consumed_index, persisted_contiguous_index)


async def process_audio_segment(
    conversation_id: str,
    user_id: str,
    voice_session_id: str,
    segment_index: int,
    audio_base64: str,
    audio_duration: int,
    is_last: bool,
) -> None:
    """Process one voice segment: ASR, idempotent persistence, ordered agent hand-off.

    Segments of a session may be transcribed concurrently; transcripts reach the
    agent strictly in ``segment_index`` order, with out-of-order segments buffered
    in the session's SegmentStreamState until the gap before them is filled.
    Failures are reported to the client as ERROR messages instead of raised.
    """
    state = get_or_create_segment_state(conversation_id, voice_session_id)
    try:
        async with AsyncSessionLocal() as db:
            conversation = await db.get(Conversation, conversation_id)
            user = await db.get(User, user_id)
            if not conversation:
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {"message": "对话不存在,分段处理已取消"},
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
                return
            if not user:
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {"message": "用户不存在,分段处理已取消"},
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
                return

            await _prime_consumed_index(db, state, conversation_id, voice_session_id)

            try:
                audio_bytes = base64.b64decode(audio_base64)
            except (ValueError, TypeError):
                # Undecodable payload (binascii.Error is a ValueError): pass empty
                # audio on to ASR, which presumably reports a transcription failure.
                audio_bytes = b""

            transcript_text = await get_asr_provider().transcribe(
                audio_bytes, format="m4a"
            )
            await manager.send_message(conversation_id, {
                "type": MessageType.TRANSCRIPT,
                "conversation_id": conversation_id,
                "data": {
                    "text": transcript_text or "",
                    "audio_duration": audio_duration,
                    "voice_session_id": voice_session_id,
                    "segment_index": segment_index,
                    "is_last": is_last,
                },
                "timestamp": datetime.now(timezone.utc).isoformat(),
            })

            if _is_transcribe_failure(transcript_text):
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {
                        "message": f"分段 {segment_index} 转写失败,请重试该片段",
                        "segment_index": segment_index,
                    },
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
                return

            existing_segment = await _find_existing_segment_by_index(
                db=db,
                conversation_id=conversation_id,
                voice_session_id=voice_session_id,
                segment_index=segment_index,
            )

            ready_segments: List[Tuple[int, str, Segment]] = []
            if existing_segment:
                async with state.lock:
                    state.processed_indices.add(segment_index)
                    # A persisted segment that was never consumed nor answered (e.g. it
                    # sat in the buffer of a state dropped after a disconnect) must be
                    # re-buffered, otherwise the ordering chain stalls at this index.
                    needs_replay = (
                        segment_index > state.consumed_index
                        and segment_index not in state.buffered_transcripts
                        and not existing_segment.agent_response
                    )
                    if needs_replay:
                        state.buffered_transcripts[segment_index] = (
                            existing_segment.transcript_text or "",
                            existing_segment,
                        )
                        ready_segments = _drain_ready_segments(state)
                if not needs_replay:
                    logger.info(
                        "分段已存在,按幂等处理跳过: "
                        f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}"
                    )
                    return
                logger.info(
                    "Replaying persisted segment without agent response: "
                    "conversation_id=%s, voice_session_id=%s, segment_index=%s",
                    conversation_id,
                    voice_session_id,
                    segment_index,
                )
                fallback_timestamp = existing_segment.created_at or _utc_now()
            else:
                segment = Segment(
                    id=str(uuid.uuid4()),
                    conversation_id=conversation_id,
                    transcript_text=transcript_text or "",
                    audio_url=_build_segment_audio_url(voice_session_id, segment_index),
                    processed=False,
                )
                db.add(segment)
                fallback_timestamp = _mark_conversation_active(conversation)
                await db.commit()
                await db.refresh(segment)
                await background_runner.queue_message(conversation.user_id, segment.id)

                async with state.lock:
                    state.processed_indices.add(segment_index)
                    state.buffered_transcripts[segment_index] = (transcript_text or "", segment)
                    ready_segments = _drain_ready_segments(state)

            for _, ordered_text, ordered_segment in ready_segments:
                await process_user_message(
                    conversation_id=conversation_id,
                    user_message=ordered_text,
                    conversation=conversation,
                    segment=ordered_segment,
                    db=db,
                    user=user,
                    user_message_timestamp=ordered_segment.created_at or fallback_timestamp,
                )
    except Exception as e:
        logger.error(
            f"处理语音分段失败: conversation_id={conversation_id}, segment_index={segment_index}, error={e}",
            exc_info=True,
        )
        await manager.send_message(conversation_id, {
            "type": MessageType.ERROR,
            "data": {
                "message": f"分段处理失败: {str(e)}",
                "segment_index": segment_index,
            },
            "timestamp": datetime.now(timezone.utc).isoformat(),
        })
    finally:
        async with state.lock:
            state.pending_indices.discard(segment_index)


# ── User message handling ────────────────────────────────────────────────
# Pause between consecutive agent replies of one turn.
AGENT_RESPONSE_INTERVAL_SEC = 0.5


async def _send_agent_responses(conversation_id: str, responses: List[str]) -> None:
    """Send each reply as AGENT_RESPONSE plus best-effort TTS, pausing between replies."""
    total = len(responses)
    for i, response_text in enumerate(responses):
        await manager.send_message(conversation_id, {
            "type": MessageType.AGENT_RESPONSE,
            "conversation_id": conversation_id,
            "data": {"text": response_text, "index": i, "total": total},
            "timestamp": datetime.now(timezone.utc).isoformat(),
        })
        await _send_tts_audio(conversation_id, response_text)
        if i < total - 1:
            await asyncio.sleep(AGENT_RESPONSE_INTERVAL_SEC)


async def _reply_in_profile_collection_mode(
    conversation_id: str,
    user_message: str,
    conversation: Conversation,
    segment: Segment,
    db: AsyncSession,
    user: User,
    missing_fields,
    user_message_timestamp: Optional[datetime],
) -> bool:
    """Run one profile-collection turn: extract profile fields, then ask follow-ups.

    *missing_fields* is the value returned by get_missing_profile_fields(user).
    Returns True when replies were generated, stored and sent; False if anything
    failed, so the caller can fall back to the regular interview flow.
    """
    agent = conversation_agent
    try:
        extracted = await agent.extract_profile_from_message(
            user_message, missing_fields, conversation_id=conversation_id
        )
        if extracted:
            await apply_extracted_profile(user, extracted, db)
        remaining = get_missing_profile_fields(user)
        filled = get_filled_profile_fields(user)
        responses = await agent.generate_profile_followup(
            conversation_id=conversation_id,
            user_message=user_message,
            missing_fields=remaining,
            filled_fields=filled,
            nickname=user.nickname or "",
            is_from_voice=bool(segment.audio_url),
            voice_session_id=_voice_session_id_from_audio_url(segment.audio_url),
            user_message_timestamp=user_message_timestamp,
        )
        segment.agent_response = "\n\n".join(responses)
        _mark_conversation_active(conversation)
        await db.commit()
        await _send_agent_responses(conversation_id, responses)
        return True
    except Exception as e:
        logger.error(f"资料收集处理失败: {e}", exc_info=True)
        return False


async def process_user_message(
    conversation_id: str,
    user_message: str,
    conversation: Conversation,
    segment: Segment,
    db: AsyncSession,
    user: Optional[User] = None,
    user_message_timestamp: Optional[datetime] = None,
) -> None:
    """Generate and deliver the agent's reply to one user message.

    While *user* still has missing profile fields, a profile-collection turn is
    tried first; if it fails, the regular interview flow driven by the user's
    memoir state is used. The reply is stored on ``segment.agent_response`` and
    pushed to the client (with TTS). Errors in the interview flow are logged and,
    if the client is still connected, reported as an ERROR message.
    """
    if user:
        missing = get_missing_profile_fields(user)
        if missing and await _reply_in_profile_collection_mode(
            conversation_id=conversation_id,
            user_message=user_message,
            conversation=conversation,
            segment=segment,
            db=db,
            user=user,
            missing_fields=missing,
            user_message_timestamp=user_message_timestamp,
        ):
            return

    agent = conversation_agent
    state = await get_or_create_state(conversation.user_id, db)
    if conversation.conversation_stage != state.current_stage:
        # Keep the conversation's stage in sync with the user's memoir state.
        conversation.conversation_stage = state.current_stage
        await db.commit()

    user_profile_context = ""
    if user:
        user_profile_context = format_user_profile_context(
            birth_year=user.birth_year,
            birth_place=user.birth_place,
            grew_up_place=user.grew_up_place,
            occupation=user.occupation,
        )

    try:
        responses = await agent.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=state,
            user_profile_context=user_profile_context,
            is_from_voice=bool(segment.audio_url),
            voice_session_id=_voice_session_id_from_audio_url(segment.audio_url),
            user_message_timestamp=user_message_timestamp,
        )
        segment.agent_response = "\n\n".join(responses)
        _mark_conversation_active(conversation)
        await db.commit()
        await _send_agent_responses(conversation_id, responses)
    except Exception as e:
        logger.error(f"处理用户消息失败: {e}", exc_info=True)
        if conversation_id in manager.active_connections:
            try:
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {"message": f"生成回应失败: {str(e)}"},
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                })
            except Exception as send_error:
                logger.warning(f"发送错误消息失败: {send_error}")


# ── End-of-conversation handling ─────────────────────────────────────────
async def process_conversation_segments(
    conversation_id: str, db: AsyncSession, quota_service: "QuotaService"
) -> None:
    """Submit the conversation's still-unprocessed segments to Celery when it ends.

    Most processing has already been done incrementally by Celery tasks; this
    submits whatever is left. The chapter quota is checked through the injected
    ``quota_service`` instead of importing quota internals. Whenever no task is
    submitted (nothing pending, quota exhausted, submission failed),
    ``background_runner.flush_pending`` is called for the user.
    """
    conversation = await db.get(Conversation, conversation_id)
    if not conversation:
        return

    # Only the ids are needed for the Celery task.
    stmt = select(Segment.id).where(
        Segment.conversation_id == conversation_id,
        Segment.processed == False,  # noqa: E712 - SQL expression, not a Python comparison
    )
    result = await db.execute(stmt)
    segment_ids = list(result.scalars().all())

    if not segment_ids:
        await background_runner.flush_pending(conversation.user_id)
        return

    user = await db.get(User, conversation.user_id)
    if user:
        can_submit, _ = await quota_service.check_can_submit_organize(
            user.id, user.subscription_type
        )
        if not can_submit:
            logger.info(
                f"用户 {user.id} 章节配额已用尽,跳过提交整理任务: conversation_id={conversation_id}"
            )
            await background_runner.flush_pending(conversation.user_id)
            return

    try:
        # Kept as a local import, as in the original (presumably to avoid an import cycle).
        from app.tasks.memoir_tasks import process_memoir_segments

        process_memoir_segments.delay(conversation.user_id, segment_ids)
        logger.info(f"对话结束,提交 Celery 任务: conversation_id={conversation_id}, segments={len(segment_ids)}")
    except Exception as e:
        logger.error(f"提交 Celery 任务失败: {e}", exc_info=True)
        await background_runner.flush_pending(conversation.user_id)