"""WebSocket route: real-time conversation messaging.

This module only contains the ``websocket_endpoint`` lifecycle function and a
few small private helpers; business logic is delegated to ``pipeline`` and the
other sibling modules imported below.
"""

import asyncio
import base64
from datetime import datetime, timezone

from fastapi import WebSocket, WebSocketDisconnect, status
from starlette.websockets import WebSocketState

from app.agents.chat.background_voice import infer_background_voice
from app.agents.chat.prompts_profile import format_user_profile_context
from app.core.agent_logging import log_asr_transcript_result
from app.core.config import settings
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_asr_provider
from app.core.logging import get_logger
from app.core.security import verify_token
from app.features.conversation.service import ConversationService
from app.features.conversation.ws.connection_manager import manager
from app.features.conversation.ws.message_types import MessageType
from app.features.conversation.ws.pipeline import (
    _delayed_listening_feedback,
    _voice_session_id_from_client_segment_id,
    bump_tts_cancel_epoch,
    chat_orchestrator,
    cleanup_segment_states,
    get_or_create_segment_state,
    handle_tts_request_on_demand,
    process_audio_segment,
    process_conversation_segments,
    process_persisted_user_segment_response,
    register_segment_task,
    register_user_response_task,
)
from app.features.conversation.ws.profile_collector import get_missing_profile_fields
from app.features.conversation.ws.quota_guard import check_ws_quota
from app.features.conversation.ws.topic_chips_push import maybe_send_topic_chips_ws
from app.features.memoir.service import MemoirService
from app.features.quota.service import QuotaService
from app.features.user.service import UserService
from app.features.conversation.constants import chat

logger = get_logger(__name__)


def _idle_hours_since(ts) -> float | None:
    """Return the number of hours elapsed since ``ts``.

    Returns ``None`` when ``ts`` is ``None`` or not a ``datetime``. A naive
    ``datetime`` is interpreted as UTC. The result is clamped to be >= 0.
    """
    if not isinstance(ts, datetime):
        return None
    if ts.tzinfo is None:
        ts = ts.replace(tzinfo=timezone.utc)
    delta = datetime.now(timezone.utc) - ts
    return max(0.0, delta.total_seconds() / 3600.0)


def _now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def _safe_int(value: object, default: int = 0) -> int:
    """Coerce a client-supplied value to ``int`` without raising.

    Falsy values (``None``, ``""``, ``0``) become ``0``; values that ``int()``
    rejects (e.g. ``"1.5"``, a list, ``Infinity``) yield ``default``.
    """
    try:
        return int(value or 0)
    except (TypeError, ValueError, OverflowError):
        return default


def _message_data(message: dict) -> dict:
    """Return the ``data`` payload of a client frame, or ``{}`` if it is not a dict.

    Clients may send ``"data": null`` or omit the key entirely; both are
    normalised to an empty dict so handlers can call ``.get`` safely.
    """
    data = message.get("data")
    return data if isinstance(data, dict) else {}


def _error_frame(text: str | None, code: str | None = None) -> dict:
    """Build an ``ERROR`` frame; ``code`` is included only when provided."""
    data: dict = {"message": text}
    if code is not None:
        data["code"] = code
    return {
        "type": MessageType.ERROR,
        "data": data,
        "timestamp": _now_iso(),
    }


async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: str,
):
    """WebSocket endpoint handling one real-time conversation.

    Flow, as implemented below:

    1. Authenticate via the ``token`` query parameter (must verify as an
       ``access`` token carrying ``sub``) and load the user; on failure the
       socket is closed with policy-violation code 1008.
    2. Register the socket with ``manager``, send a ``CONNECT`` frame and check
       access with ``ConversationService.ensure_ws_connection``.
    3. Align the conversation stage with the memoir state, restore history and
       optionally push a profile greeting, an opening message or a re-greeting.
    4. Receive JSON frames in a loop and dispatch on ``type``.
    5. Always disconnect from ``manager`` and clean up segment state on exit.

    Args:
        websocket: The WebSocket connection.
        conversation_id: Conversation ID this socket is bound to.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌"
        )
        return

    payload = verify_token(token)
    if not payload or payload.get("type") != "access":
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌"
        )
        return

    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容"
        )
        return

    async with AsyncSessionLocal() as db:
        user_service = UserService(db)
        user = await user_service.get_by_id(user_id)
        if not user:
            await websocket.close(
                code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在"
            )
            return

        await manager.connect(websocket, conversation_id)
        logger.info(
            "WebSocket 已连接 conversation_id={} user_id={}",
            conversation_id,
            user_id,
        )

        quota_service = QuotaService(db=db)
        conversation_service = ConversationService(db=db, quota_service=quota_service)
        memoir_service = MemoirService(db=db)

        async def _send_error(text: str | None, code: str | None = None) -> None:
            """Send an ERROR frame to this conversation's socket."""
            await manager.send_message(conversation_id, _error_frame(text, code))

        async def _reject_if_over_quota() -> bool:
            """Send QUOTA_EXCEEDED and return True when the user may not send."""
            can_send, quota_msg = await check_ws_quota(
                quota_service, user_id, user.subscription_type
            )
            if can_send:
                return False
            await _send_error(quota_msg, code="QUOTA_EXCEEDED")
            return True

        try:
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.CONNECT,
                    "conversation_id": conversation_id,
                    "data": {"status": "connected"},
                    "timestamp": _now_iso(),
                },
            )

            conversation, ws_conn_err = await conversation_service.ensure_ws_connection(
                conversation_id, user_id
            )
            rejection_reasons = {
                "forbidden": "无权访问此对话",
                "deleted": "对话已删除",
            }
            reject_reason = rejection_reasons.get(ws_conn_err)
            if reject_reason is not None:
                # Best effort: the close frame below carries the same reason,
                # so a failed send must not prevent closing the socket.
                try:
                    await _send_error(reject_reason)
                except Exception as send_error:
                    logger.debug("发送错误消息失败: {}", send_error)
                await websocket.close(
                    code=status.WS_1008_POLICY_VIOLATION, reason=reject_reason
                )
                return

            # On cold start, align conversation_stage with
            # MemoirState.current_stage; if the conversation row already holds a
            # later life stage (larger STAGE_TO_ORDER) it is not overwritten, to
            # avoid regressing.
            memoir_state = await memoir_service.get_or_create_memoir_state(user_id)
            await conversation_service.align_conversation_stage_from_memoir(
                conversation, memoir_state.current_stage or ""
            )
            await db.refresh(conversation)

            history = await conversation_service.ensure_redis_history_from_db(
                conversation_id
            )

            user_language = (
                "en"
                if str(getattr(user, "language_preference", "zh") or "zh").lower()
                == "en"
                else "zh"
            )

            async def _stream_ai_only_messages(
                texts: list[str], log_label: str
            ) -> None:
                """Persist a batch of AI-only texts, then push them one by one.

                Nothing is sent when ``texts`` is empty or when recording the
                turn yields no assistant message id. Consecutive segments are
                spaced by a 0.5 s pause.
                """
                if not texts:
                    return
                ai_msg_id = await conversation_service.record_ai_only_turn(
                    conversation_id, texts
                )
                if not ai_msg_id:
                    return
                total_n = len(texts)
                for i, text in enumerate(texts):
                    await manager.send_message(
                        conversation_id,
                        {
                            "type": MessageType.AGENT_RESPONSE,
                            "conversation_id": conversation_id,
                            "data": {
                                "text": text,
                                "index": i,
                                "total": total_n,
                                "assistant_message_id": ai_msg_id,
                            },
                            "timestamp": _now_iso(),
                        },
                    )
                    if i < total_n - 1:
                        await asyncio.sleep(0.5)
                logger.info(
                    "event=ws_auto_ai_sent label={} conversation_id={} segments={}",
                    log_label,
                    conversation_id,
                    total_n,
                )

            async def _maybe_send_topic_chips(reason: str) -> None:
                """Delegate to ``maybe_send_topic_chips_ws`` for this session."""
                await maybe_send_topic_chips_ws(
                    conversation_id,
                    user=user,
                    memoir_state=memoir_state,
                    reason=reason,
                    language=user_language,
                )

            def _profile_prompt_kwargs() -> dict:
                """Keyword arguments shared by the opening and re-greeting prompts."""
                return {
                    "user_profile_context": format_user_profile_context(
                        birth_year=user.birth_year,
                        birth_place=user.birth_place,
                        grew_up_place=user.grew_up_place,
                        occupation=user.occupation,
                        language=user_language,
                    ),
                    "background_voice": infer_background_voice(user.occupation),
                    "occupation": user.occupation or "",
                    "profile_birth_year": user.birth_year,
                    "profile_era_place": (
                        user.grew_up_place or user.birth_place or ""
                    ),
                    "language": user_language,
                }

            if not history:
                missing_profile = get_missing_profile_fields(user)
                if missing_profile:
                    try:
                        greetings = await chat_orchestrator.generate_profile_greeting(
                            conversation_id=conversation_id,
                            missing_fields=missing_profile,
                            nickname=user.nickname or "",
                            language=user_language,
                        )
                        await _stream_ai_only_messages(
                            greetings, log_label="profile_greeting"
                        )
                    except Exception as e:
                        logger.exception("发送资料收集开场白失败: {}", e)
                else:
                    try:
                        opening_messages = (
                            await chat_orchestrator.generate_opening_message(
                                conversation_id=conversation_id,
                                memoir_state=memoir_state,
                                **_profile_prompt_kwargs(),
                            )
                        )
                        await _stream_ai_only_messages(
                            opening_messages, log_label="opening"
                        )
                        await _maybe_send_topic_chips(reason="opening")
                    except Exception as e:
                        logger.exception("发送空对话开场白失败: {}", e)
            else:
                # Non-empty history: decide whether a re-greeting is due (time
                # since the last message has reached the configured threshold).
                idle_hours = _idle_hours_since(conversation.last_message_at)
                threshold = float(chat.re_greeting_idle_hours)
                if (
                    chat.re_greeting_enabled
                    and not get_missing_profile_fields(user)
                    and idle_hours is not None
                    and idle_hours >= threshold
                ):
                    try:
                        re_greetings = (
                            await chat_orchestrator.generate_re_greeting_message(
                                conversation_id=conversation_id,
                                memoir_state=memoir_state,
                                idle_hours=idle_hours,
                                **_profile_prompt_kwargs(),
                            )
                        )
                        await _stream_ai_only_messages(
                            re_greetings, log_label="re_greeting"
                        )
                        logger.info(
                            "event=ws_re_greeting_emitted conversation_id={} "
                            "idle_hours={:.2f} threshold={:.2f}",
                            conversation_id,
                            idle_hours,
                            threshold,
                        )
                        await _maybe_send_topic_chips(reason="re_greeting")
                    except Exception as e:
                        logger.exception("发送回访问候失败: {}", e)
                else:
                    # No re-greeting: still offer topic chips to lower the
                    # cold-start barrier.
                    await _maybe_send_topic_chips(reason="resume")

            while True:
                try:
                    if websocket.application_state != WebSocketState.CONNECTED:
                        logger.debug(
                            "WebSocket 已非连接状态,退出循环: conversation_id={}",
                            conversation_id,
                        )
                        break

                    message = await websocket.receive_json()
                    if not isinstance(message, dict):
                        # Valid JSON but not an object (e.g. a list): ignore it
                        # instead of crashing on ``.get`` and dropping the socket.
                        logger.warning(
                            "WebSocket 收到非对象 JSON conversation_id={}",
                            conversation_id,
                        )
                        continue

                    msg_type = message.get("type")
                    # Normalised once so every branch tolerates ``"data": null``.
                    data = _message_data(message)

                    if msg_type == MessageType.AUDIO_SEGMENT:
                        logged_audio = data.get("audio_base64")
                        logger.info(
                            "WebSocket 收到消息 type={} conversation_id={} "
                            "segment_index={} is_last={} duration_s={} audio_b64_len={}",
                            msg_type,
                            conversation_id,
                            data.get("segment_index"),
                            bool(data.get("is_last")),
                            _safe_int(data.get("duration")),
                            len(logged_audio) if isinstance(logged_audio, str) else 0,
                        )
                    elif msg_type is not None:
                        logger.info(
                            "WebSocket 收到消息 type={} conversation_id={}",
                            msg_type,
                            conversation_id,
                        )
                    else:
                        logger.warning(
                            "WebSocket 收到缺少 type 的 JSON conversation_id={}",
                            conversation_id,
                        )

                    if msg_type == MessageType.TEXT:
                        text_message = data.get("text", "")
                        tts_this_turn = bool(data.get("tts_this_turn"))

                        if text_message:
                            if await _reject_if_over_quota():
                                continue

                            segment = await conversation_service.create_user_segment(
                                conversation,
                                user_id,
                                text_message,
                            )
                            task = asyncio.create_task(
                                process_persisted_user_segment_response(
                                    conversation_id=conversation_id,
                                    user_id=user_id,
                                    segment_id=segment.id,
                                    tts_this_turn=tts_this_turn,
                                )
                            )
                            register_user_response_task(conversation_id, task)

                    elif msg_type == MessageType.RECORDING_STARTED:
                        raw_vs = data.get("voice_session_id")
                        if not raw_vs or not str(raw_vs).strip():
                            await _send_error("缺少 voice_session_id")
                            continue
                        voice_session_id = str(raw_vs).strip()
                        segment_state = get_or_create_segment_state(
                            conversation_id,
                            voice_session_id,
                        )
                        async with segment_state.lock:
                            # Schedule at most one delayed "listening" feedback
                            # per voice session.
                            if (
                                segment_state.listening_feedback_task is not None
                                and not segment_state.listening_feedback_task.done()
                            ):
                                continue
                            if segment_state.listening_feedback_sent:
                                continue
                            delayed_task = asyncio.create_task(
                                _delayed_listening_feedback(
                                    conversation_id=conversation_id,
                                    voice_session_id=voice_session_id,
                                )
                            )
                            segment_state.listening_feedback_task = delayed_task

                    elif msg_type == MessageType.AUDIO_SEGMENT:
                        audio_base64 = data.get("audio_base64", "")
                        segment_index_raw = data.get("segment_index")
                        resolved_vs = data.get("voice_session_id") or (
                            _voice_session_id_from_client_segment_id(
                                data.get("client_segment_id")
                            )
                        )
                        if not resolved_vs or not str(resolved_vs).strip():
                            await _send_error(
                                "缺少 voice_session_id 或有效的 client_segment_id"
                            )
                            continue
                        voice_session_id = str(resolved_vs).strip()
                        is_last = bool(data.get("is_last", False))
                        # A malformed optional duration must not kill the socket.
                        audio_duration = _safe_int(data.get("duration", 0))
                        tts_this_turn_segment = bool(data.get("tts_this_turn"))

                        # A non-string payload is treated the same as a missing one.
                        if not isinstance(audio_base64, str) or not audio_base64:
                            await _send_error("缺少 audio_base64")
                            continue

                        if segment_index_raw is None:
                            await _send_error("缺少 segment_index")
                            continue
                        try:
                            segment_index = int(segment_index_raw)
                        except (TypeError, ValueError, OverflowError):
                            await _send_error("segment_index 必须为整数")
                            continue
                        if segment_index < 0:
                            await _send_error("segment_index 不能为负数")
                            continue

                        if await _reject_if_over_quota():
                            continue

                        segment_state = get_or_create_segment_state(
                            conversation_id,
                            voice_session_id,
                        )

                        # De-duplicate under the lock: a segment index is
                        # processed only the first time it is seen.
                        should_process = False
                        async with segment_state.lock:
                            already_seen = (
                                segment_index in segment_state.pending_indices
                                or segment_index in segment_state.processed_indices
                                or segment_index <= segment_state.consumed_index
                            )
                            if not already_seen:
                                segment_state.pending_indices.add(segment_index)
                                should_process = True

                        if not should_process:
                            logger.debug(
                                "收到重复分段,跳过: conversation_id={} voice_session_id={} "
                                "segment_index={} audio_b64_len={} duration={}",
                                conversation_id,
                                voice_session_id,
                                segment_index,
                                len(audio_base64 or ""),
                                audio_duration,
                            )
                            continue

                        if is_last:
                            # Final segment: drop any pending "listening"
                            # feedback task for this voice session.
                            async with segment_state.lock:
                                pending_feedback = segment_state.listening_feedback_task
                                segment_state.listening_feedback_task = None
                            if (
                                pending_feedback is not None
                                and not pending_feedback.done()
                            ):
                                pending_feedback.cancel()

                        task = asyncio.create_task(
                            process_audio_segment(
                                conversation_id=conversation_id,
                                user_id=user_id,
                                voice_session_id=voice_session_id,
                                segment_index=segment_index,
                                audio_base64=audio_base64,
                                audio_duration=audio_duration,
                                is_last=is_last,
                                tts_this_turn=tts_this_turn_segment,
                            )
                        )
                        register_segment_task(conversation_id, voice_session_id, task)

                    elif msg_type == MessageType.AUDIO_MESSAGE:
                        audio_base64 = data.get("audio_base64", "")
                        audio_duration = data.get("duration", 0)
                        tts_this_turn = bool(data.get("tts_this_turn"))

                        if audio_base64:
                            if await _reject_if_over_quota():
                                continue

                            logger.debug(
                                "收到音频消息: conversation_id={} duration_s={}",
                                conversation_id,
                                audio_duration,
                            )

                            try:
                                asr = get_asr_provider()
                                audio_bytes = base64.b64decode(audio_base64)
                                asr_text = await asr.transcribe(audio_bytes, "m4a")
                                log_asr_transcript_result(
                                    logger,
                                    text=asr_text or "",
                                    conversation_id=conversation_id,
                                    duration_s=audio_duration,
                                    source="audio_message",
                                )

                                await manager.send_message(
                                    conversation_id,
                                    {
                                        "type": MessageType.TRANSCRIPT,
                                        "conversation_id": conversation_id,
                                        "data": {
                                            "text": asr_text,
                                            "audio_duration": audio_duration,
                                        },
                                        "timestamp": _now_iso(),
                                    },
                                )

                                ads = _safe_int(audio_duration)
                                segment = (
                                    await conversation_service.create_user_segment(
                                        conversation,
                                        user_id,
                                        asr_text,
                                        audio_url=f"audio:{audio_duration}s",
                                        audio_duration_seconds=ads if ads > 0 else None,
                                    )
                                )

                                if asr_text and not asr_text.startswith("转写失败"):
                                    task = asyncio.create_task(
                                        process_persisted_user_segment_response(
                                            conversation_id=conversation_id,
                                            user_id=user_id,
                                            segment_id=segment.id,
                                            tts_this_turn=tts_this_turn,
                                        )
                                    )
                                    register_user_response_task(conversation_id, task)
                                else:
                                    await _send_error(
                                        "语音转写失败,请重试或使用文字输入"
                                    )
                            except Exception as e:
                                logger.exception("处理音频消息失败: {}", e)
                                await _send_error(
                                    "语音处理失败,请重试或使用文字输入"
                                )

                    elif msg_type == MessageType.TRANSCRIBE_ONLY:
                        audio_base64 = data.get("audio_base64", "")
                        if not audio_base64:
                            await _send_error("缺少 audio_base64")
                            continue
                        try:
                            asr = get_asr_provider()
                            audio_bytes = base64.b64decode(audio_base64)
                            asr_text = await asr.transcribe(audio_bytes, "m4a")
                            log_asr_transcript_result(
                                logger,
                                text=asr_text or "",
                                conversation_id=conversation_id,
                                source="transcribe_only",
                            )
                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.TRANSCRIPT,
                                    "conversation_id": conversation_id,
                                    "data": {"text": asr_text or ""},
                                    "timestamp": _now_iso(),
                                },
                            )
                        except Exception as e:
                            logger.exception("仅转写失败: {}", e)
                            await _send_error("语音转写失败,请重试")

                    elif msg_type == MessageType.TTS_CANCEL:
                        bump_tts_cancel_epoch(conversation_id)

                    elif msg_type == MessageType.TTS_REQUEST:
                        # Both snake_case and camelCase keys are accepted.
                        aid = data.get("assistant_message_id") or data.get(
                            "assistantMessageId"
                        )
                        if not aid or not str(aid).strip():
                            logger.warning(
                                "ws.TTS_REQUEST 缺少 assistant_message_id "
                                "conversation_id={} user_id={}",
                                conversation_id,
                                user_id,
                            )
                            await _send_error("缺少助手消息 id")
                            continue
                        seg_idx = _safe_int(
                            data.get("segment_index", data.get("segmentIndex", 0))
                        )
                        st = data.get("segment_text") or data.get("segmentText")
                        st_val: str | None = (
                            None if st is None else (str(st).strip() or None)
                        )
                        ok, err_msg = await handle_tts_request_on_demand(
                            conversation_id=conversation_id,
                            user_id=user_id,
                            assistant_message_id=str(aid).strip(),
                            segment_index=seg_idx,
                            segment_text=st_val,
                            db=db,
                        )
                        if not ok:
                            await _send_error(err_msg or "朗读请求失败")

                    elif msg_type == MessageType.END_CONVERSATION:
                        await conversation_service.end(conversation_id, user_id)
                        await process_conversation_segments(
                            conversation_id, db, quota_service
                        )
                        await manager.send_message(
                            conversation_id,
                            {
                                "type": MessageType.END_CONVERSATION,
                                "conversation_id": conversation_id,
                                "data": {"status": "ended"},
                                "timestamp": _now_iso(),
                            },
                        )
                        break

                    elif msg_type == MessageType.PING:
                        await manager.send_message(
                            conversation_id,
                            {
                                "type": MessageType.PONG,
                                "conversation_id": conversation_id,
                                "data": {},
                                "timestamp": _now_iso(),
                            },
                        )

                    else:
                        if msg_type is not None:
                            logger.warning(
                                "WebSocket 未识别的消息 type={} conversation_id={}",
                                msg_type,
                                conversation_id,
                            )

                except RuntimeError as e:
                    error_msg = str(e)
                    lowered = error_msg.lower()
                    # RuntimeErrors whose text indicates the socket is gone or
                    # was never accepted are treated as a normal disconnect.
                    looks_disconnected = (
                        "disconnect" in lowered
                        or 'Cannot call "receive"' in error_msg
                        or ("accept" in lowered and "not connected" in lowered)
                    )
                    if looks_disconnected:
                        logger.debug(
                            "WebSocket 连接已断开或未就绪: conversation_id={} error={}",
                            conversation_id,
                            error_msg,
                        )
                    else:
                        logger.exception("处理消息时发生 RuntimeError: {}", e)
                        if conversation_id in manager.active_connections:
                            try:
                                await _send_error("处理失败,请重试")
                            except Exception as send_error:
                                logger.warning("发送错误消息失败: {}", send_error)
                    break
                except WebSocketDisconnect as disc:
                    logger.info(
                        "WebSocket 断开连接(收消息循环): conversation_id={} code={}",
                        conversation_id,
                        getattr(disc, "code", None),
                    )
                    break
                except Exception as e:
                    logger.exception("处理消息时发生错误: {}", e)
                    if conversation_id in manager.active_connections:
                        try:
                            await _send_error("处理失败,请重试")
                        except Exception as send_error:
                            logger.warning("发送错误消息失败: {}", send_error)
                    break

        except WebSocketDisconnect as disc:
            logger.info(
                "WebSocket 断开连接: conversation_id={} code={}",
                conversation_id,
                getattr(disc, "code", None),
            )
        except Exception as e:
            logger.exception("WebSocket 端点发生错误: {}", e)
        finally:
            # Single cleanup point for every exit path (return, break, error).
            await manager.disconnect(conversation_id)
            cleanup_segment_states(conversation_id)