""" WebSocket 路由:实时对话通信 仅包含 websocket_endpoint 生命周期函数,业务逻辑委托给 pipeline 等子模块 """ import asyncio import base64 import uuid from datetime import datetime, timezone from fastapi import WebSocket, WebSocketDisconnect, status from starlette.websockets import WebSocketState from app.agents.chat.prompts_profile import format_user_profile_context from app.core.db import AsyncSessionLocal from app.core.dependencies import get_asr_provider from app.core.logging import get_logger from app.core.security import verify_token from app.features.conversation.history_store import ConversationHistoryStore from app.features.conversation.models import Conversation, Segment from app.features.conversation.service import ConversationService from app.features.conversation.ws.connection_manager import manager from app.features.conversation.ws.message_types import MessageType from app.features.conversation.ws.pipeline import ( _delayed_listening_feedback, _mark_conversation_active, _voice_session_id_from_client_segment_id, background_runner, bump_tts_cancel_epoch, chat_orchestrator, cleanup_segment_states, get_or_create_segment_state, process_audio_segment, process_conversation_segments, process_user_message, register_segment_task, ) from app.features.conversation.ws.profile_collector import get_missing_profile_fields from app.features.conversation.ws.quota_guard import check_ws_quota from app.features.memoir.state_service import get_or_create_state from app.features.quota.service import QuotaService from app.features.user.models import User logger = get_logger(__name__) async def websocket_endpoint( websocket: WebSocket, conversation_id: str, ): """ WebSocket 端点:处理实时对话 Args: websocket: WebSocket 连接 conversation_id: 对话 ID """ token = websocket.query_params.get("token") if not token: await websocket.close( code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌" ) return payload = verify_token(token) if not payload or payload.get("type") != "access": await websocket.close( code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌" ) return user_id = payload.get("sub") if not user_id: await websocket.close( code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容" ) return async with AsyncSessionLocal() as db: user = await db.get(User, user_id) if not user: await websocket.close( code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在" ) return await manager.connect(websocket, conversation_id) quota_service = QuotaService(db=db) conversation_service = ConversationService(db=db, quota_service=quota_service) try: await manager.send_message( conversation_id, { "type": MessageType.CONNECT, "conversation_id": conversation_id, "data": {"status": "connected"}, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) conversation = await db.get(Conversation, conversation_id) if not conversation: conversation = Conversation( id=conversation_id, user_id=user_id, started_at=datetime.now(timezone.utc), status="active", ) db.add(conversation) await db.commit() else: if conversation.user_id != user_id: try: await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": {"message": "无权访问此对话"}, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) except Exception: pass await websocket.close( code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话" ) return if conversation.deleted_at is not None: try: await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": {"message": "对话已删除"}, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) except Exception: pass await websocket.close( code=status.WS_1008_POLICY_VIOLATION, reason="对话已删除" ) return history = await conversation_service.ensure_redis_history_from_db( conversation_id ) if not history: missing_profile = get_missing_profile_fields(user) if missing_profile: try: greetings = await chat_orchestrator.generate_profile_greeting( conversation_id=conversation_id, missing_fields=missing_profile, nickname=user.nickname or "", ) await ConversationHistoryStore(db).record_ai_only_turn( conversation_id, greetings ) for i, text in enumerate(greetings): await manager.send_message( conversation_id, { "type": MessageType.AGENT_RESPONSE, "conversation_id": conversation_id, "data": { "text": text, "index": i, "total": len(greetings), }, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) if i < len(greetings) - 1: await asyncio.sleep(0.5) except Exception as e: logger.error(f"发送资料收集开场白失败: {e}", exc_info=True) else: try: state = await get_or_create_state(user_id, db) user_profile_context = format_user_profile_context( birth_year=user.birth_year, birth_place=user.birth_place, grew_up_place=user.grew_up_place, occupation=user.occupation, ) opening_messages = ( await chat_orchestrator.generate_opening_message( conversation_id=conversation_id, memoir_state=state, user_profile_context=user_profile_context, ) ) await ConversationHistoryStore(db).record_ai_only_turn( conversation_id, opening_messages ) for i, text in enumerate(opening_messages): await manager.send_message( conversation_id, { "type": MessageType.AGENT_RESPONSE, "conversation_id": conversation_id, "data": { "text": text, "index": i, "total": len(opening_messages), }, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) if i < len(opening_messages) - 1: await asyncio.sleep(0.5) except Exception as e: logger.error(f"发送空对话开场白失败: {e}", exc_info=True) while True: try: if websocket.application_state != WebSocketState.CONNECTED: logger.debug( "WebSocket 已非连接状态,退出循环: conversation_id={}", conversation_id, ) break message = await websocket.receive_json() msg_type = message.get("type") if msg_type == MessageType.TEXT: text_message = message.get("data", {}).get("text", "") if text_message: can_send, quota_msg = await check_ws_quota( quota_service, user_id, user.subscription_type ) if not can_send: await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": { "message": quota_msg, "code": "QUOTA_EXCEEDED", }, "timestamp": datetime.now( timezone.utc ).isoformat(), }, ) continue segment = Segment( id=str(uuid.uuid4()), conversation_id=conversation_id, user_input_text=text_message, processed=False, ) db.add(segment) user_message_timestamp = _mark_conversation_active( conversation ) await db.commit() await db.refresh(segment) await background_runner.queue_message( conversation.user_id, segment.id ) await process_user_message( conversation_id=conversation_id, user_message=text_message, conversation=conversation, segment=segment, db=db, user=user, user_message_timestamp=segment.created_at or user_message_timestamp, ) elif msg_type == MessageType.RECORDING_STARTED: data = message.get("data", {}) raw_vs = data.get("voice_session_id") if not raw_vs or not str(raw_vs).strip(): await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": {"message": "缺少 voice_session_id"}, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) continue voice_session_id = str(raw_vs).strip() segment_state = get_or_create_segment_state( conversation_id, voice_session_id, ) async with segment_state.lock: if ( segment_state.listening_feedback_task is not None and not segment_state.listening_feedback_task.done() ): continue if segment_state.listening_feedback_sent: continue delayed_task = asyncio.create_task( _delayed_listening_feedback( conversation_id=conversation_id, voice_session_id=voice_session_id, ) ) segment_state.listening_feedback_task = delayed_task elif msg_type == MessageType.AUDIO_SEGMENT: data = message.get("data", {}) audio_base64 = data.get("audio_base64", "") segment_index_raw = data.get("segment_index") resolved_vs = data.get("voice_session_id") or ( _voice_session_id_from_client_segment_id( data.get("client_segment_id") ) ) if not resolved_vs or not str(resolved_vs).strip(): await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": { "message": "缺少 voice_session_id 或有效的 client_segment_id" }, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) continue voice_session_id = str(resolved_vs).strip() is_last = bool(data.get("is_last", False)) audio_duration = int(data.get("duration", 0) or 0) if not audio_base64: await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": {"message": "缺少 audio_base64"}, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) continue if segment_index_raw is None: await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": {"message": "缺少 segment_index"}, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) continue try: segment_index = int(segment_index_raw) except (TypeError, ValueError): await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": {"message": "segment_index 必须为整数"}, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) continue if segment_index < 0: await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": {"message": "segment_index 不能为负数"}, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) continue can_send, quota_msg = await check_ws_quota( quota_service, user_id, user.subscription_type ) if not can_send: await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": { "message": quota_msg, "code": "QUOTA_EXCEEDED", }, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) continue segment_state = get_or_create_segment_state( conversation_id, voice_session_id, ) should_process = False async with segment_state.lock: already_seen = ( segment_index in segment_state.pending_indices or segment_index in segment_state.processed_indices or segment_index <= segment_state.consumed_index ) if not already_seen: segment_state.pending_indices.add(segment_index) should_process = True if not should_process: logger.debug( "收到重复分段,跳过: conversation_id={} voice_session_id={} " "segment_index={} audio_b64_len={} duration={}", conversation_id, voice_session_id, segment_index, len(audio_base64 or ""), audio_duration, ) continue if is_last: async with segment_state.lock: t = segment_state.listening_feedback_task segment_state.listening_feedback_task = None if t is not None and not t.done(): t.cancel() task = asyncio.create_task( process_audio_segment( conversation_id=conversation_id, user_id=user_id, voice_session_id=voice_session_id, segment_index=segment_index, audio_base64=audio_base64, audio_duration=audio_duration, is_last=is_last, ) ) register_segment_task(conversation_id, voice_session_id, task) elif msg_type == MessageType.AUDIO_MESSAGE: data = message.get("data", {}) audio_base64 = data.get("audio_base64", "") audio_duration = data.get("duration", 0) if audio_base64: can_send, quota_msg = await check_ws_quota( quota_service, user_id, user.subscription_type ) if not can_send: await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": { "message": quota_msg, "code": "QUOTA_EXCEEDED", }, "timestamp": datetime.now( timezone.utc ).isoformat(), }, ) continue logger.debug( "收到音频消息: conversation_id={} duration_s={}", conversation_id, audio_duration, ) try: asr = get_asr_provider() audio_bytes = base64.b64decode(audio_base64) asr_text = await asr.transcribe(audio_bytes, "m4a") logger.debug( "ASR 转写完成: conversation_id={} chars={}", conversation_id, len(asr_text or ""), ) logger.debug( "ASR 转写全文: conversation_id={} text={}", conversation_id, asr_text, ) await manager.send_message( conversation_id, { "type": MessageType.TRANSCRIPT, "conversation_id": conversation_id, "data": { "text": asr_text, "audio_duration": audio_duration, }, "timestamp": datetime.now( timezone.utc ).isoformat(), }, ) try: ads = int(audio_duration) except (TypeError, ValueError): ads = 0 segment = Segment( id=str(uuid.uuid4()), conversation_id=conversation_id, user_input_text=asr_text, audio_url=f"audio:{audio_duration}s", audio_duration_seconds=ads if ads > 0 else None, processed=False, ) db.add(segment) user_message_timestamp = _mark_conversation_active( conversation ) await db.commit() await db.refresh(segment) await background_runner.queue_message( conversation.user_id, segment.id ) if asr_text and not asr_text.startswith("转写失败"): await process_user_message( conversation_id=conversation_id, user_message=asr_text, conversation=conversation, segment=segment, db=db, user=user, user_message_timestamp=segment.created_at or user_message_timestamp, ) else: await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": { "message": "语音转写失败,请重试或使用文字输入" }, "timestamp": datetime.now( timezone.utc ).isoformat(), }, ) except Exception as e: logger.error(f"处理音频消息失败: {e}", exc_info=True) await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": { "message": f"处理音频消息失败: {str(e)}" }, "timestamp": datetime.now( timezone.utc ).isoformat(), }, ) elif msg_type == MessageType.TRANSCRIBE_ONLY: data = message.get("data", {}) audio_base64 = data.get("audio_base64", "") if not audio_base64: await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": {"message": "缺少 audio_base64"}, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) continue try: asr = get_asr_provider() audio_bytes = base64.b64decode(audio_base64) asr_text = await asr.transcribe(audio_bytes, "m4a") await manager.send_message( conversation_id, { "type": MessageType.TRANSCRIPT, "conversation_id": conversation_id, "data": {"text": asr_text or ""}, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) except Exception as e: logger.error(f"仅转写失败: {e}", exc_info=True) await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": {"message": f"转写失败: {str(e)}"}, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) elif msg_type == MessageType.TTS_CANCEL: bump_tts_cancel_epoch(conversation_id) elif msg_type == MessageType.END_CONVERSATION: conversation.status = "ended" conversation.ended_at = datetime.now(timezone.utc) await db.commit() await process_conversation_segments( conversation_id, db, quota_service ) await manager.send_message( conversation_id, { "type": MessageType.END_CONVERSATION, "conversation_id": conversation_id, "data": {"status": "ended"}, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) break except RuntimeError as e: error_msg = str(e) if ( "disconnect" in error_msg.lower() or 'Cannot call "receive"' in error_msg or "accept" in error_msg.lower() and "not connected" in error_msg.lower() ): logger.debug( "WebSocket 连接已断开或未就绪: conversation_id={} error={}", conversation_id, error_msg, ) break else: logger.error(f"处理消息时发生 RuntimeError: {e}", exc_info=True) if conversation_id in manager.active_connections: try: await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": {"message": str(e)}, "timestamp": datetime.now( timezone.utc ).isoformat(), }, ) except Exception as send_error: logger.warning(f"发送错误消息失败: {send_error}") break except WebSocketDisconnect: logger.debug( "WebSocket 断开连接: conversation_id={}", conversation_id ) break except Exception as e: logger.error(f"处理消息时发生错误: {e}", exc_info=True) if conversation_id in manager.active_connections: try: await manager.send_message( conversation_id, { "type": MessageType.ERROR, "data": {"message": str(e)}, "timestamp": datetime.now(timezone.utc).isoformat(), }, ) except Exception as send_error: logger.warning(f"发送错误消息失败: {send_error}") break except WebSocketDisconnect: logger.debug("WebSocket 断开连接: conversation_id={}", conversation_id) await manager.disconnect(conversation_id) cleanup_segment_states(conversation_id) except Exception as e: logger.error(f"WebSocket 端点发生错误: {e}", exc_info=True) await manager.disconnect(conversation_id) cleanup_segment_states(conversation_id) finally: await manager.disconnect(conversation_id) cleanup_segment_states(conversation_id)