"""
WebSocket route: real-time conversation channel.

Hosts the ``websocket_endpoint`` lifecycle (authentication, conversation
authorization, connection registration, message loop, cleanup) together with
small per-message-type handlers. Business logic is delegated to ``pipeline``
and the other ``ws`` submodules.
"""

import asyncio
import base64
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Optional, Tuple

from fastapi import WebSocket, WebSocketDisconnect, status
from starlette.websockets import WebSocketState

from app.agents.chat.prompts_profile import format_user_profile_context
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_asr_provider
from app.core.logging import get_logger
from app.core.security import verify_token
from app.features.conversation.models import Conversation, Segment
from app.features.conversation.service import ConversationService
from app.features.conversation.ws.connection_manager import manager
from app.features.conversation.ws.message_types import MessageType
from app.features.conversation.ws.pipeline import (
    SegmentStreamState,  # noqa: F401 — re-export for test backward compat
    _delayed_listening_feedback,
    _mark_conversation_active,
    _normalize_voice_session_id,
    _voice_session_id_from_client_segment_id,
    background_runner,
    cleanup_segment_states,
    conversation_agent,
    get_or_create_segment_state,
    process_audio_segment,
    process_conversation_segments,
    process_user_message,
    register_segment_task,
)
from app.features.conversation.ws.profile_collector import get_missing_profile_fields
from app.features.conversation.ws.quota_guard import check_ws_quota
from app.features.memoir.state_service import get_or_create_state
from app.features.quota.service import QuotaService
from app.features.user.models import User

logger = get_logger(__name__)

# Delay (seconds) inserted between consecutive parts of a multi-part agent reply.
_AGENT_MESSAGE_INTERVAL_SECONDS = 0.5


@dataclass
class _ConnectionContext:
    """Per-connection state shared by the message handlers below."""

    conversation_id: str
    # User id taken from the access token's ``sub`` claim.
    user_id: str
    user: User
    conversation: Conversation
    # The session opened by ``AsyncSessionLocal()`` in ``websocket_endpoint``.
    db: Any
    quota_service: QuotaService


def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string (message timestamps)."""
    return datetime.now(timezone.utc).isoformat()


async def _send_error(
    conversation_id: str, message: str, *, code: Optional[str] = None
) -> None:
    """Send an ERROR message to the connection registered for ``conversation_id``.

    ``code`` is included in ``data`` only when given (e.g. ``QUOTA_EXCEEDED``).
    """
    data = {"message": message}
    if code is not None:
        data["code"] = code
    await manager.send_message(
        conversation_id,
        {
            "type": MessageType.ERROR,
            "data": data,
            "timestamp": _utc_now_iso(),
        },
    )


async def _send_agent_responses(conversation_id: str, texts) -> None:
    """Send each text as an AGENT_RESPONSE, pausing between consecutive parts.

    Every message carries its ``index`` and the ``total`` part count.
    """
    total = len(texts)
    for index, text in enumerate(texts):
        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.AGENT_RESPONSE,
                "conversation_id": conversation_id,
                "data": {"text": text, "index": index, "total": total},
                "timestamp": _utc_now_iso(),
            },
        )
        if index < total - 1:
            await asyncio.sleep(_AGENT_MESSAGE_INTERVAL_SECONDS)


async def _authenticate(websocket: WebSocket) -> Optional[str]:
    """Validate the ``token`` query parameter and return the user id.

    On any failure the socket is closed with a policy-violation code (before
    the handshake is accepted) and ``None`` is returned.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌"
        )
        return None

    payload = verify_token(token)
    if not payload or payload.get("type") != "access":
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌"
        )
        return None

    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容"
        )
        return None
    return user_id


async def _load_authorized_conversation(
    db, conversation_id: str, user_id: str
) -> Tuple[Optional[Conversation], Optional[str]]:
    """Load the conversation, creating it if missing, and check access.

    Returns ``(conversation, None)`` when the user may use it, or
    ``(None, reason)`` when it belongs to another user or was deleted.
    A newly created conversation is committed immediately.
    """
    conversation = await db.get(Conversation, conversation_id)
    if not conversation:
        conversation = Conversation(
            id=conversation_id,
            user_id=user_id,
            started_at=datetime.now(timezone.utc),
            status="active",
        )
        db.add(conversation)
        await db.commit()
        return conversation, None

    if conversation.user_id != user_id:
        return None, "无权访问此对话"
    if conversation.deleted_at is not None:
        return None, "对话已删除"
    return conversation, None


async def _reject_after_accept(websocket: WebSocket, reason: str) -> None:
    """Accept the handshake only to tell the client why it is rejected, then close.

    The socket is never registered with ``manager``. Because ``manager`` and the
    segment states are keyed by conversation id alone, this keeps a rejected
    attempt from tearing down the owner's live connection state.
    """
    try:
        await websocket.accept()
        # Same ERROR payload shape this module sends via manager.send_message.
        await websocket.send_json(
            {
                "type": MessageType.ERROR,
                "data": {"message": reason},
                "timestamp": _utc_now_iso(),
            }
        )
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=reason)
    except Exception as exc:
        # Best effort: the client may already have gone away.
        logger.warning(f"发送错误消息失败: {exc}")


async def _send_opening_messages(ctx: _ConnectionContext) -> None:
    """Send the opening agent messages for a conversation with no history.

    If the user's profile is missing fields, a profile-collection greeting is
    sent; otherwise an opening message built from memoir state and profile.
    Generation/send failures are logged and do not abort the connection.
    """
    user = ctx.user
    missing_profile = get_missing_profile_fields(user)

    if missing_profile:
        try:
            greetings = await conversation_agent.generate_profile_greeting(
                conversation_id=ctx.conversation_id,
                missing_fields=missing_profile,
                nickname=user.nickname or "",
            )
            await _send_agent_responses(ctx.conversation_id, greetings)
        except Exception as e:
            logger.error(f"发送资料收集开场白失败: {e}", exc_info=True)
        return

    try:
        state = await get_or_create_state(ctx.user_id, ctx.db)
        user_profile_context = format_user_profile_context(
            birth_year=user.birth_year,
            birth_place=user.birth_place,
            grew_up_place=user.grew_up_place,
            occupation=user.occupation,
        )
        opening_messages = await conversation_agent.generate_opening_message(
            conversation_id=ctx.conversation_id,
            memoir_state=state,
            user_profile_context=user_profile_context,
        )
        await _send_agent_responses(ctx.conversation_id, opening_messages)
    except Exception as e:
        logger.error(f"发送空对话开场白失败: {e}", exc_info=True)


async def _ensure_quota(ctx: _ConnectionContext) -> bool:
    """Return True if the user may send; otherwise send QUOTA_EXCEEDED and return False."""
    can_send, quota_msg = await check_ws_quota(
        ctx.quota_service, ctx.user_id, ctx.user.subscription_type
    )
    if not can_send:
        await _send_error(ctx.conversation_id, quota_msg, code="QUOTA_EXCEEDED")
    return can_send


async def _save_user_segment(
    ctx: _ConnectionContext, transcript_text, *, audio_url: Optional[str] = None
):
    """Persist a user utterance as an unprocessed Segment and queue it.

    Marks the conversation active, commits, refreshes the segment, and queues
    it on ``background_runner``. Returns ``(segment, user_message_timestamp)``,
    where the timestamp is the value returned by ``_mark_conversation_active``.
    """
    segment_fields = {
        "id": str(uuid.uuid4()),
        "conversation_id": ctx.conversation_id,
        "transcript_text": transcript_text,
        "processed": False,
    }
    # Only pass audio_url when given so text segments keep the model default.
    if audio_url is not None:
        segment_fields["audio_url"] = audio_url
    segment = Segment(**segment_fields)

    ctx.db.add(segment)
    user_message_timestamp = _mark_conversation_active(ctx.conversation)
    await ctx.db.commit()
    await ctx.db.refresh(segment)
    await background_runner.queue_message(ctx.conversation.user_id, segment.id)
    return segment, user_message_timestamp


async def _transcribe_base64(audio_base64: str):
    """Decode base64 audio and transcribe it with the configured ASR provider (as m4a)."""
    asr = get_asr_provider()
    audio_bytes = base64.b64decode(audio_base64)
    return await asr.transcribe(audio_bytes, "m4a")


async def _handle_text(ctx: _ConnectionContext, message: dict) -> None:
    """TEXT: persist the text as a segment and run the chat pipeline on it."""
    text_message = message.get("data", {}).get("text", "")
    if not text_message:
        return
    if not await _ensure_quota(ctx):
        return

    segment, user_message_timestamp = await _save_user_segment(ctx, text_message)
    await process_user_message(
        conversation_id=ctx.conversation_id,
        user_message=text_message,
        conversation=ctx.conversation,
        segment=segment,
        db=ctx.db,
        user=ctx.user,
        user_message_timestamp=segment.created_at or user_message_timestamp,
    )


async def _handle_recording_started(ctx: _ConnectionContext, message: dict) -> None:
    """RECORDING_STARTED: schedule delayed "listening" feedback once per voice session."""
    data = message.get("data", {})
    voice_session_id = _normalize_voice_session_id(data.get("voice_session_id"))
    segment_state = get_or_create_segment_state(ctx.conversation_id, voice_session_id)

    async with segment_state.lock:
        existing_task = segment_state.listening_feedback_task
        if existing_task is not None and not existing_task.done():
            return
        if segment_state.listening_feedback_sent:
            return
        segment_state.listening_feedback_task = asyncio.create_task(
            _delayed_listening_feedback(
                conversation_id=ctx.conversation_id,
                voice_session_id=voice_session_id,
            )
        )


async def _handle_audio_segment(ctx: _ConnectionContext, message: dict) -> None:
    """AUDIO_SEGMENT: validate, de-duplicate, and hand one streamed chunk to the pipeline."""
    conversation_id = ctx.conversation_id
    data = message.get("data", {})
    audio_base64 = data.get("audio_base64", "")
    segment_index_raw = data.get("segment_index")
    voice_session_id = _normalize_voice_session_id(
        data.get("voice_session_id")
        or _voice_session_id_from_client_segment_id(data.get("client_segment_id"))
    )
    is_last = bool(data.get("is_last", False))
    audio_duration = int(data.get("duration", 0) or 0)

    if not audio_base64:
        await _send_error(conversation_id, "缺少 audio_base64")
        return
    if segment_index_raw is None:
        await _send_error(conversation_id, "缺少 segment_index")
        return
    try:
        segment_index = int(segment_index_raw)
    except (TypeError, ValueError):
        await _send_error(conversation_id, "segment_index 必须为整数")
        return
    if segment_index < 0:
        await _send_error(conversation_id, "segment_index 不能为负数")
        return

    if not await _ensure_quota(ctx):
        return

    segment_state = get_or_create_segment_state(conversation_id, voice_session_id)
    async with segment_state.lock:
        already_seen = (
            segment_index in segment_state.pending_indices
            or segment_index in segment_state.processed_indices
            or segment_index <= segment_state.consumed_index
        )
        if not already_seen:
            segment_state.pending_indices.add(segment_index)
    if already_seen:
        logger.info(
            "收到重复分段,跳过处理: "
            f"conversation_id={conversation_id}, voice_session_id={voice_session_id}, segment_index={segment_index}"
        )
        return

    if is_last:
        # The recording is over, so any pending "listening" feedback is moot.
        async with segment_state.lock:
            pending_feedback = segment_state.listening_feedback_task
            segment_state.listening_feedback_task = None
        if pending_feedback is not None and not pending_feedback.done():
            pending_feedback.cancel()

    task = asyncio.create_task(
        process_audio_segment(
            conversation_id=conversation_id,
            user_id=ctx.user_id,
            voice_session_id=voice_session_id,
            segment_index=segment_index,
            audio_base64=audio_base64,
            audio_duration=audio_duration,
            is_last=is_last,
        )
    )
    register_segment_task(conversation_id, voice_session_id, task)


async def _handle_audio_message(ctx: _ConnectionContext, message: dict) -> None:
    """AUDIO_MESSAGE: transcribe a full clip, echo the transcript, persist it, then reply.

    Any failure inside the transcription/reply step is reported to the client
    as an ERROR instead of propagating to the message loop.
    """
    conversation_id = ctx.conversation_id
    data = message.get("data", {})
    audio_base64 = data.get("audio_base64", "")
    audio_duration = data.get("duration", 0)
    if not audio_base64:
        return
    if not await _ensure_quota(ctx):
        return

    logger.info(f"收到音频消息,时长: {audio_duration}s")
    try:
        transcript_text = await _transcribe_base64(audio_base64)
        logger.info("ASR 转写结果: %s", transcript_text)
        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.TRANSCRIPT,
                "conversation_id": conversation_id,
                "data": {
                    "text": transcript_text,
                    "audio_duration": audio_duration,
                },
                "timestamp": _utc_now_iso(),
            },
        )

        # NOTE(review): the segment is saved and queued even when the transcript
        # is empty or a "转写失败..." marker — confirm that is intended.
        segment, user_message_timestamp = await _save_user_segment(
            ctx, transcript_text, audio_url=f"audio:{audio_duration}s"
        )

        if transcript_text and not transcript_text.startswith("转写失败"):
            await process_user_message(
                conversation_id=conversation_id,
                user_message=transcript_text,
                conversation=ctx.conversation,
                segment=segment,
                db=ctx.db,
                user=ctx.user,
                user_message_timestamp=segment.created_at or user_message_timestamp,
            )
        else:
            await _send_error(conversation_id, "语音转写失败,请重试或使用文字输入")
    except Exception as e:
        logger.error(f"处理音频消息失败: {e}", exc_info=True)
        await _send_error(conversation_id, f"处理音频消息失败: {str(e)}")


async def _handle_transcribe_only(ctx: _ConnectionContext, message: dict) -> None:
    """TRANSCRIBE_ONLY: return the transcript without persisting or replying."""
    conversation_id = ctx.conversation_id
    data = message.get("data", {})
    audio_base64 = data.get("audio_base64", "")
    if not audio_base64:
        await _send_error(conversation_id, "缺少 audio_base64")
        return

    try:
        transcript_text = await _transcribe_base64(audio_base64)
        await manager.send_message(
            conversation_id,
            {
                "type": MessageType.TRANSCRIPT,
                "conversation_id": conversation_id,
                "data": {"text": transcript_text or ""},
                "timestamp": _utc_now_iso(),
            },
        )
    except Exception as e:
        logger.error(f"仅转写失败: {e}", exc_info=True)
        await _send_error(conversation_id, f"转写失败: {str(e)}")


async def _handle_end_conversation(ctx: _ConnectionContext) -> None:
    """END_CONVERSATION: mark ended, process segments, and confirm to the client."""
    conversation = ctx.conversation
    conversation.status = "ended"
    conversation.ended_at = datetime.now(timezone.utc)
    await ctx.db.commit()
    await process_conversation_segments(ctx.conversation_id, ctx.db, ctx.quota_service)
    await manager.send_message(
        ctx.conversation_id,
        {
            "type": MessageType.END_CONVERSATION,
            "conversation_id": ctx.conversation_id,
            "data": {"status": "ended"},
            "timestamp": _utc_now_iso(),
        },
    )


async def _dispatch_message(ctx: _ConnectionContext, message: dict) -> bool:
    """Route one client message to its handler.

    Returns True when the message loop should stop (END_CONVERSATION).
    Unknown message types are ignored.
    """
    msg_type = message.get("type")
    if msg_type == MessageType.TEXT:
        await _handle_text(ctx, message)
    elif msg_type == MessageType.RECORDING_STARTED:
        await _handle_recording_started(ctx, message)
    elif msg_type == MessageType.AUDIO_SEGMENT:
        await _handle_audio_segment(ctx, message)
    elif msg_type == MessageType.AUDIO_MESSAGE:
        await _handle_audio_message(ctx, message)
    elif msg_type == MessageType.TRANSCRIBE_ONLY:
        await _handle_transcribe_only(ctx, message)
    elif msg_type == MessageType.END_CONVERSATION:
        await _handle_end_conversation(ctx)
        return True
    return False


def _is_disconnect_runtime_error(error_msg: str) -> bool:
    """Heuristic: does this RuntimeError text mean the socket is closed or not yet accepted?"""
    lowered = error_msg.lower()
    return (
        "disconnect" in lowered
        or 'Cannot call "receive"' in error_msg
        or ("accept" in lowered and "not connected" in lowered)
    )


async def _report_loop_error(conversation_id: str, error_text: str) -> bool:
    """Best-effort ERROR report for a failed message.

    Returns False when sending the report itself failed (the loop should stop).
    Returns True otherwise, including when no connection is registered.
    """
    if conversation_id not in manager.active_connections:
        return True
    try:
        await _send_error(conversation_id, error_text)
    except Exception as send_error:
        logger.warning(f"发送错误消息失败: {send_error}")
        return False
    return True


async def _run_message_loop(websocket: WebSocket, ctx: _ConnectionContext) -> None:
    """Receive and dispatch client messages until the conversation or socket ends.

    Per-message errors are reported to the client and the loop continues.
    Disconnects, END_CONVERSATION, or a failed error report end the loop.
    """
    conversation_id = ctx.conversation_id
    while True:
        try:
            if websocket.application_state != WebSocketState.CONNECTED:
                logger.info(
                    f"WebSocket 已非连接状态,退出循环: conversation_id={conversation_id}"
                )
                return

            message = await websocket.receive_json()
            if await _dispatch_message(ctx, message):
                return
        except RuntimeError as e:
            error_msg = str(e)
            if _is_disconnect_runtime_error(error_msg):
                logger.info(
                    f"WebSocket 连接已断开或未就绪: conversation_id={conversation_id}, error={error_msg}"
                )
                return
            logger.error(f"处理消息时发生 RuntimeError: {e}", exc_info=True)
            if not await _report_loop_error(conversation_id, str(e)):
                return
        except WebSocketDisconnect:
            logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
            return
        except Exception as e:
            logger.error(f"处理消息时发生错误: {e}", exc_info=True)
            if not await _report_loop_error(conversation_id, str(e)):
                return


async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: str,
):
    """
    WebSocket endpoint for a real-time conversation.

    Lifecycle:
      1. Authenticate the ``token`` query parameter and load the user; failures
         close the socket before the handshake is accepted.
      2. Load (or create) the conversation and check ownership/deletion BEFORE
         registering with ``manager``. A rejected attempt is told why and
         closed. It never touches ``manager`` or the segment states, which are
         keyed by conversation id alone and belong to the owner's connection.
      3. Register the connection, send CONNECT, and send opening messages when
         the conversation has no history.
      4. Run the message loop.
      5. Always unregister the connection and drop this conversation's
         segment states (single cleanup point in ``finally``).

    Args:
        websocket: The WebSocket connection.
        conversation_id: The conversation ID.
    """
    user_id = await _authenticate(websocket)
    if user_id is None:
        return

    async with AsyncSessionLocal() as db:
        user = await db.get(User, user_id)
        if not user:
            await websocket.close(
                code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在"
            )
            return

        try:
            conversation, rejection_reason = await _load_authorized_conversation(
                db, conversation_id, user_id
            )
        except Exception as e:
            logger.error(f"WebSocket 端点发生错误: {e}", exc_info=True)
            await websocket.close(code=status.WS_1011_INTERNAL_ERROR)
            return
        if rejection_reason is not None:
            await _reject_after_accept(websocket, rejection_reason)
            return

        await manager.connect(websocket, conversation_id)
        quota_service = QuotaService(db=db)
        conversation_service = ConversationService(db=db, quota_service=quota_service)
        ctx = _ConnectionContext(
            conversation_id=conversation_id,
            user_id=user_id,
            user=user,
            conversation=conversation,
            db=db,
            quota_service=quota_service,
        )

        try:
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.CONNECT,
                    "conversation_id": conversation_id,
                    "data": {"status": "connected"},
                    "timestamp": _utc_now_iso(),
                },
            )

            history = await conversation_service.ensure_redis_history_from_segments(
                conversation_id
            )
            if not history:
                await _send_opening_messages(ctx)

            await _run_message_loop(websocket, ctx)
        except WebSocketDisconnect:
            logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
        except Exception as e:
            logger.error(f"WebSocket 端点发生错误: {e}", exc_info=True)
        finally:
            # Only cleanup point: the except clauses above just log.
            await manager.disconnect(conversation_id)
            cleanup_segment_states(conversation_id)