"""
WebSocket route: real-time conversation traffic.

Holds the ``websocket_endpoint`` lifecycle function plus a few private helpers
that build and send frames; business logic is delegated to ``pipeline`` and
the other sub-modules.
"""

import asyncio
import base64
from datetime import datetime, timezone
from typing import Any, Optional

from fastapi import WebSocket, WebSocketDisconnect, status
from starlette.websockets import WebSocketState

from app.agents.chat.background_voice import infer_background_voice
from app.agents.chat.prompts_profile import format_user_profile_context
from app.agents.stage_constants import STAGE_TO_ORDER
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_asr_provider
from app.core.logging import get_logger
from app.core.security import verify_token
from app.features.conversation.history_store import ConversationHistoryStore
from app.features.conversation.service import ConversationService
from app.features.conversation.ws.connection_manager import manager
from app.features.conversation.ws.message_types import MessageType
from app.features.conversation.ws.pipeline import (
    _delayed_listening_feedback,
    _voice_session_id_from_client_segment_id,
    background_runner,
    bump_tts_cancel_epoch,
    chat_orchestrator,
    cleanup_segment_states,
    get_or_create_segment_state,
    process_audio_segment,
    process_conversation_segments,
    process_user_message,
    register_segment_task,
)
from app.features.conversation.ws.profile_collector import get_missing_profile_fields
from app.features.conversation.ws.quota_guard import check_ws_quota
from app.features.memoir.state_service import get_or_create_state
from app.features.quota.service import QuotaService
from app.features.user.models import User

logger = get_logger(__name__)


def _utc_timestamp() -> str:
    """Return the current UTC time as an ISO-8601 string for outgoing frames."""
    return datetime.now(timezone.utc).isoformat()


def _error_frame(message: Any, code: Optional[str] = None) -> dict:
    """Build an ERROR frame; ``code`` is added to ``data`` only when given."""
    data = {"message": message}
    if code is not None:
        data["code"] = code
    return {
        "type": MessageType.ERROR,
        "data": data,
        "timestamp": _utc_timestamp(),
    }


def _payload_of(message: Any) -> dict:
    """Return the frame's ``data`` object, or ``{}`` when it is absent or not a dict.

    Guards against ``"data": null`` and non-object frames, which would
    otherwise raise ``AttributeError`` on the first ``.get`` call.
    """
    data = message.get("data") if isinstance(message, dict) else None
    return data if isinstance(data, dict) else {}


def _str_field(data: dict, key: str) -> str:
    """Return ``data[key]`` when it is a string, otherwise an empty string."""
    value = data.get(key)
    return value if isinstance(value, str) else ""


def _to_int(value: Any, default: int = 0) -> int:
    """Convert a client-supplied number to ``int``; fall back to ``default``."""
    try:
        return int(value or 0)
    except (TypeError, ValueError, OverflowError):
        return default


async def _send_error(
    conversation_id: str, message: Any, code: Optional[str] = None
) -> None:
    """Send an ERROR frame for ``conversation_id`` through the connection manager."""
    await manager.send_message(conversation_id, _error_frame(message, code))


async def _reject_connection(
    websocket: WebSocket, conversation_id: str, reason: str
) -> None:
    """Tell the client why access is refused, then close with a policy violation."""
    try:
        await _send_error(conversation_id, reason)
    except Exception as send_error:
        # Best effort only: the close below must happen either way.
        logger.debug(
            "关闭前发送错误消息失败: conversation_id={} error={}",
            conversation_id,
            send_error,
        )
    await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=reason)


async def _quota_allows(
    quota_service: QuotaService,
    user_id: Any,
    subscription_type: Any,
    conversation_id: str,
) -> bool:
    """Run the quota check; on refusal send a QUOTA_EXCEEDED error and return False."""
    can_send, quota_msg = await check_ws_quota(
        quota_service, user_id, subscription_type
    )
    if not can_send:
        await _send_error(conversation_id, quota_msg, code="QUOTA_EXCEEDED")
        return False
    return True


async def _record_and_send_ai_turn(db: Any, conversation_id: str, texts: Any) -> None:
    """Persist an AI-only turn and stream its texts as AGENT_RESPONSE frames.

    Nothing is sent when ``record_ai_only_turn`` returns a falsy message id.
    Consecutive frames are spaced 0.5 s apart.
    """
    ai_msg_id = await ConversationHistoryStore(db).record_ai_only_turn(
        conversation_id, texts
    )
    if not ai_msg_id:
        return
    total = len(texts)
    for index, text in enumerate(texts):
        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.AGENT_RESPONSE,
                "conversation_id": conversation_id,
                "data": {
                    "text": text,
                    "index": index,
                    "total": total,
                    "assistant_message_id": ai_msg_id,
                },
                "timestamp": _utc_timestamp(),
            },
        )
        if index < total - 1:
            await asyncio.sleep(0.5)


async def _notify_processing_failure(conversation_id: str) -> None:
    """Send the generic failure frame if the manager still tracks the connection."""
    if conversation_id not in manager.active_connections:
        return
    try:
        await _send_error(conversation_id, "处理失败,请重试")
    except Exception as send_error:
        logger.warning(f"发送错误消息失败: {send_error}")


async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: str,
):
    """
    WebSocket endpoint: drive one real-time conversation.

    Flow, as implemented below:
      1. Authenticate via the ``token`` query parameter (must verify as an
         ``access`` token carrying ``sub``) and load the user; any failure
         closes the socket with ``WS_1008_POLICY_VIOLATION``.
      2. Register with the connection manager, send a CONNECT frame and
         validate access to the conversation.
      3. Align the conversation stage with the memoir state and, when the
         conversation has no history, send an opening or profile greeting.
      4. Dispatch incoming frames by ``type`` until the client disconnects,
         the conversation is ended, or an error occurs.
      5. Always disconnect from the manager and clean up segment state.

    Args:
        websocket: The WebSocket connection.
        conversation_id: Conversation ID.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌"
        )
        return

    payload = verify_token(token)
    if not payload or payload.get("type") != "access":
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌"
        )
        return

    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容"
        )
        return

    async with AsyncSessionLocal() as db:
        user = await db.get(User, user_id)
        if not user:
            await websocket.close(
                code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在"
            )
            return

        await manager.connect(websocket, conversation_id)
        logger.info(
            "WebSocket 已连接 conversation_id={} user_id={}",
            conversation_id,
            user_id,
        )
        quota_service = QuotaService(db=db)
        conversation_service = ConversationService(db=db, quota_service=quota_service)

        try:
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.CONNECT,
                    "conversation_id": conversation_id,
                    "data": {"status": "connected"},
                    "timestamp": _utc_timestamp(),
                },
            )

            conversation, ws_conn_err = await conversation_service.ensure_ws_connection(
                conversation_id, user_id
            )
            if ws_conn_err == "forbidden":
                await _reject_connection(websocket, conversation_id, "无权访问此对话")
                return
            if ws_conn_err == "deleted":
                await _reject_connection(websocket, conversation_id, "对话已删除")
                return

            # Cold start: align conversation_stage with MemoirState.current_stage.
            # If the conversation row already holds a later life stage (larger
            # STAGE_TO_ORDER value), leave it untouched so the stage never regresses.
            memoir_state = await get_or_create_state(user_id, db)
            ms = (memoir_state.current_stage or "").strip()
            cs = (conversation.conversation_stage or "").strip()
            if ms and (
                not cs or STAGE_TO_ORDER.get(ms, -1) >= STAGE_TO_ORDER.get(cs, -1)
            ):
                conversation.conversation_stage = ms
            await db.commit()
            await db.refresh(conversation)

            history = await conversation_service.ensure_redis_history_from_db(
                conversation_id
            )
            if not history:
                # Empty conversation: greet first. Failures here are logged and
                # must not prevent the receive loop from starting.
                missing_profile = get_missing_profile_fields(user)
                if missing_profile:
                    try:
                        greetings = await chat_orchestrator.generate_profile_greeting(
                            conversation_id=conversation_id,
                            missing_fields=missing_profile,
                            nickname=user.nickname or "",
                        )
                        await _record_and_send_ai_turn(db, conversation_id, greetings)
                    except Exception as e:
                        logger.error(f"发送资料收集开场白失败: {e}", exc_info=True)
                else:
                    try:
                        user_profile_context = format_user_profile_context(
                            birth_year=user.birth_year,
                            birth_place=user.birth_place,
                            grew_up_place=user.grew_up_place,
                            occupation=user.occupation,
                        )
                        opening_messages = (
                            await chat_orchestrator.generate_opening_message(
                                conversation_id=conversation_id,
                                memoir_state=memoir_state,
                                user_profile_context=user_profile_context,
                                background_voice=infer_background_voice(
                                    user.occupation
                                ),
                                occupation=user.occupation or "",
                            )
                        )
                        await _record_and_send_ai_turn(
                            db, conversation_id, opening_messages
                        )
                    except Exception as e:
                        logger.error(f"发送空对话开场白失败: {e}", exc_info=True)

            while True:
                try:
                    if websocket.application_state != WebSocketState.CONNECTED:
                        logger.debug(
                            "WebSocket 已非连接状态,退出循环: conversation_id={}",
                            conversation_id,
                        )
                        break

                    message = await websocket.receive_json()
                    # A frame that is not a JSON object is treated as having no type.
                    msg_type = (
                        message.get("type") if isinstance(message, dict) else None
                    )
                    data = _payload_of(message)

                    if msg_type == MessageType.AUDIO_SEGMENT:
                        # Log sizes only; fields are not validated yet, so every
                        # conversion here must be non-raising.
                        logger.info(
                            "WebSocket 收到消息 type={} conversation_id={} "
                            "segment_index={} is_last={} duration_s={} audio_b64_len={}",
                            msg_type,
                            conversation_id,
                            data.get("segment_index"),
                            bool(data.get("is_last")),
                            _to_int(data.get("duration")),
                            len(_str_field(data, "audio_base64")),
                        )
                    elif msg_type is not None:
                        logger.info(
                            "WebSocket 收到消息 type={} conversation_id={}",
                            msg_type,
                            conversation_id,
                        )
                    else:
                        logger.warning(
                            "WebSocket 收到缺少 type 的 JSON conversation_id={}",
                            conversation_id,
                        )

                    if msg_type == MessageType.TEXT:
                        text_message = _str_field(data, "text")
                        if text_message:
                            if not await _quota_allows(
                                quota_service,
                                user_id,
                                user.subscription_type,
                                conversation_id,
                            ):
                                continue

                            segment = await conversation_service.create_user_segment(
                                conversation,
                                user_id,
                                text_message,
                            )
                            user_message_timestamp = conversation.last_message_at
                            await background_runner.queue_message(
                                conversation.user_id,
                                segment.id,
                                text_char_count=len(text_message.strip()),
                            )
                            await process_user_message(
                                conversation_id=conversation_id,
                                user_message=text_message,
                                conversation=conversation,
                                segment=segment,
                                db=db,
                                user=user,
                                user_message_timestamp=segment.created_at
                                or user_message_timestamp,
                            )

                    elif msg_type == MessageType.RECORDING_STARTED:
                        raw_vs = data.get("voice_session_id")
                        if not raw_vs or not str(raw_vs).strip():
                            await _send_error(conversation_id, "缺少 voice_session_id")
                            continue

                        voice_session_id = str(raw_vs).strip()
                        segment_state = get_or_create_segment_state(
                            conversation_id,
                            voice_session_id,
                        )
                        async with segment_state.lock:
                            pending = segment_state.listening_feedback_task
                            feedback_running = (
                                pending is not None and not pending.done()
                            )
                            # Schedule the delayed feedback once per voice session:
                            # skip if a task is in flight or it was already sent.
                            if (
                                not feedback_running
                                and not segment_state.listening_feedback_sent
                            ):
                                segment_state.listening_feedback_task = (
                                    asyncio.create_task(
                                        _delayed_listening_feedback(
                                            conversation_id=conversation_id,
                                            voice_session_id=voice_session_id,
                                        )
                                    )
                                )

                    elif msg_type == MessageType.AUDIO_SEGMENT:
                        audio_base64 = _str_field(data, "audio_base64")
                        segment_index_raw = data.get("segment_index")
                        resolved_vs = data.get("voice_session_id") or (
                            _voice_session_id_from_client_segment_id(
                                data.get("client_segment_id")
                            )
                        )
                        if not resolved_vs or not str(resolved_vs).strip():
                            await _send_error(
                                conversation_id,
                                "缺少 voice_session_id 或有效的 client_segment_id",
                            )
                            continue

                        voice_session_id = str(resolved_vs).strip()
                        is_last = bool(data.get("is_last", False))
                        audio_duration = _to_int(data.get("duration"))

                        if not audio_base64:
                            await _send_error(conversation_id, "缺少 audio_base64")
                            continue
                        if segment_index_raw is None:
                            await _send_error(conversation_id, "缺少 segment_index")
                            continue
                        try:
                            segment_index = int(segment_index_raw)
                        except (TypeError, ValueError, OverflowError):
                            await _send_error(
                                conversation_id, "segment_index 必须为整数"
                            )
                            continue
                        if segment_index < 0:
                            await _send_error(
                                conversation_id, "segment_index 不能为负数"
                            )
                            continue

                        if not await _quota_allows(
                            quota_service,
                            user_id,
                            user.subscription_type,
                            conversation_id,
                        ):
                            continue

                        segment_state = get_or_create_segment_state(
                            conversation_id,
                            voice_session_id,
                        )
                        # De-duplicate under the lock: an index counts as seen when
                        # it is pending, processed, or at/below consumed_index.
                        async with segment_state.lock:
                            already_seen = (
                                segment_index in segment_state.pending_indices
                                or segment_index in segment_state.processed_indices
                                or segment_index <= segment_state.consumed_index
                            )
                            if not already_seen:
                                segment_state.pending_indices.add(segment_index)

                        if already_seen:
                            logger.debug(
                                "收到重复分段,跳过: conversation_id={} voice_session_id={} "
                                "segment_index={} audio_b64_len={} duration={}",
                                conversation_id,
                                voice_session_id,
                                segment_index,
                                len(audio_base64 or ""),
                                audio_duration,
                            )
                            continue

                        if is_last:
                            # Final segment: detach any pending listening-feedback
                            # task and cancel it if it has not finished.
                            async with segment_state.lock:
                                feedback_task = segment_state.listening_feedback_task
                                segment_state.listening_feedback_task = None
                            if feedback_task is not None and not feedback_task.done():
                                feedback_task.cancel()

                        task = asyncio.create_task(
                            process_audio_segment(
                                conversation_id=conversation_id,
                                user_id=user_id,
                                voice_session_id=voice_session_id,
                                segment_index=segment_index,
                                audio_base64=audio_base64,
                                audio_duration=audio_duration,
                                is_last=is_last,
                            )
                        )
                        register_segment_task(conversation_id, voice_session_id, task)

                    elif msg_type == MessageType.AUDIO_MESSAGE:
                        audio_base64 = _str_field(data, "audio_base64")
                        # Kept as sent by the client: it is echoed back verbatim in
                        # the TRANSCRIPT frame and embedded in audio_url.
                        audio_duration = data.get("duration", 0)

                        if audio_base64:
                            if not await _quota_allows(
                                quota_service,
                                user_id,
                                user.subscription_type,
                                conversation_id,
                            ):
                                continue

                            logger.debug(
                                "收到音频消息: conversation_id={} duration_s={}",
                                conversation_id,
                                audio_duration,
                            )
                            try:
                                asr = get_asr_provider()
                                audio_bytes = base64.b64decode(audio_base64)
                                asr_text = await asr.transcribe(audio_bytes, "m4a")
                                logger.debug(
                                    "ASR 转写完成: conversation_id={} chars={}",
                                    conversation_id,
                                    len(asr_text or ""),
                                )
                                logger.debug(
                                    "ASR 转写全文: conversation_id={} text={}",
                                    conversation_id,
                                    asr_text,
                                )
                                await manager.send_message(
                                    conversation_id,
                                    {
                                        "type": MessageType.TRANSCRIPT,
                                        "conversation_id": conversation_id,
                                        "data": {
                                            "text": asr_text,
                                            "audio_duration": audio_duration,
                                        },
                                        "timestamp": _utc_timestamp(),
                                    },
                                )

                                ads = _to_int(audio_duration)
                                segment = (
                                    await conversation_service.create_user_segment(
                                        conversation,
                                        user_id,
                                        asr_text,
                                        audio_url=f"audio:{audio_duration}s",
                                        audio_duration_seconds=ads if ads > 0 else None,
                                    )
                                )
                                user_message_timestamp = conversation.last_message_at
                                await background_runner.queue_message(
                                    conversation.user_id,
                                    segment.id,
                                    text_char_count=len((asr_text or "").strip()),
                                )

                                # A transcript beginning with "转写失败" is treated
                                # as a failed transcription, not as user text.
                                if asr_text and not asr_text.startswith("转写失败"):
                                    await process_user_message(
                                        conversation_id=conversation_id,
                                        user_message=asr_text,
                                        conversation=conversation,
                                        segment=segment,
                                        db=db,
                                        user=user,
                                        user_message_timestamp=segment.created_at
                                        or user_message_timestamp,
                                    )
                                else:
                                    await _send_error(
                                        conversation_id,
                                        "语音转写失败,请重试或使用文字输入",
                                    )
                            except Exception as e:
                                logger.error(f"处理音频消息失败: {e}", exc_info=True)
                                await _send_error(
                                    conversation_id,
                                    "语音处理失败,请重试或使用文字输入",
                                )

                    elif msg_type == MessageType.TRANSCRIBE_ONLY:
                        audio_base64 = _str_field(data, "audio_base64")
                        if not audio_base64:
                            await _send_error(conversation_id, "缺少 audio_base64")
                            continue
                        try:
                            asr = get_asr_provider()
                            audio_bytes = base64.b64decode(audio_base64)
                            asr_text = await asr.transcribe(audio_bytes, "m4a")
                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.TRANSCRIPT,
                                    "conversation_id": conversation_id,
                                    "data": {"text": asr_text or ""},
                                    "timestamp": _utc_timestamp(),
                                },
                            )
                        except Exception as e:
                            logger.error(f"仅转写失败: {e}", exc_info=True)
                            await _send_error(conversation_id, "语音转写失败,请重试")

                    elif msg_type == MessageType.TTS_CANCEL:
                        bump_tts_cancel_epoch(conversation_id)

                    elif msg_type == MessageType.END_CONVERSATION:
                        await conversation_service.end(conversation_id, user_id)
                        await process_conversation_segments(
                            conversation_id, db, quota_service
                        )
                        await manager.send_message(
                            conversation_id,
                            {
                                "type": MessageType.END_CONVERSATION,
                                "conversation_id": conversation_id,
                                "data": {"status": "ended"},
                                "timestamp": _utc_timestamp(),
                            },
                        )
                        break

                    elif msg_type == MessageType.PING:
                        await manager.send_message(
                            conversation_id,
                            {
                                "type": MessageType.PONG,
                                "conversation_id": conversation_id,
                                "data": {},
                                "timestamp": _utc_timestamp(),
                            },
                        )

                    elif msg_type is not None:
                        logger.warning(
                            "WebSocket 未识别的消息 type={} conversation_id={}",
                            msg_type,
                            conversation_id,
                        )

                except RuntimeError as e:
                    error_msg = str(e)
                    lowered = error_msg.lower()
                    # RuntimeErrors whose text indicates a closed or not-yet-accepted
                    # socket are expected on disconnect and only logged at debug.
                    connection_gone = (
                        "disconnect" in lowered
                        or 'Cannot call "receive"' in error_msg
                        or ("accept" in lowered and "not connected" in lowered)
                    )
                    if connection_gone:
                        logger.debug(
                            "WebSocket 连接已断开或未就绪: conversation_id={} error={}",
                            conversation_id,
                            error_msg,
                        )
                    else:
                        logger.error(
                            f"处理消息时发生 RuntimeError: {e}", exc_info=True
                        )
                        await _notify_processing_failure(conversation_id)
                    break
                except WebSocketDisconnect as disc:
                    logger.info(
                        "WebSocket 断开连接(收消息循环): conversation_id={} code={}",
                        conversation_id,
                        getattr(disc, "code", None),
                    )
                    break
                except Exception as e:
                    logger.error(f"处理消息时发生错误: {e}", exc_info=True)
                    await _notify_processing_failure(conversation_id)
                    break

        except WebSocketDisconnect as disc:
            logger.info(
                "WebSocket 断开连接: conversation_id={} code={}",
                conversation_id,
                getattr(disc, "code", None),
            )
        except Exception as e:
            logger.error(f"WebSocket 端点发生错误: {e}", exc_info=True)
        finally:
            # Single cleanup point: it runs on normal exit, on the early returns
            # inside the try block, and after either handler above, so the
            # handlers do not repeat it.
            await manager.disconnect(conversation_id)
            cleanup_segment_states(conversation_id)