"""WebSocket route: real-time conversation channel.

This module only holds the ``websocket_endpoint`` lifecycle function (plus a
few tiny private helpers); business logic is delegated to ``pipeline`` and the
sibling sub-modules.
"""

import asyncio
import base64
from datetime import datetime, timezone

from fastapi import WebSocket, WebSocketDisconnect, status
from starlette.websockets import WebSocketState

from app.agents.chat.background_voice import infer_background_voice
from app.agents.chat.prompts_conversation import build_topic_chips
from app.agents.chat.prompts_profile import format_user_profile_context
from app.agents.stage_constants import STAGE_TO_ORDER
from app.agents.state_schema import (
    interview_control_state,
    narrative_coverage_state,
)
from app.core.config import settings
from app.core.db import AsyncSessionLocal
from app.core.dependencies import get_asr_provider
from app.core.logging import get_logger
from app.core.security import verify_token
from app.features.conversation.history_store import ConversationHistoryStore
from app.features.conversation.service import ConversationService
from app.features.conversation.ws.connection_manager import manager
from app.features.conversation.ws.message_types import MessageType
from app.features.conversation.ws.pipeline import (
    _delayed_listening_feedback,
    _voice_session_id_from_client_segment_id,
    bump_tts_cancel_epoch,
    chat_orchestrator,
    cleanup_segment_states,
    get_or_create_segment_state,
    memoir_ingest_scheduler,
    process_audio_segment,
    process_conversation_segments,
    process_user_message,
    register_segment_task,
)
from app.features.conversation.ws.profile_collector import get_missing_profile_fields
from app.features.conversation.ws.quota_guard import check_ws_quota
from app.features.memoir.state_service import get_or_create_state
from app.features.quota.service import QuotaService
from app.features.user.models import User

logger = get_logger(__name__)

# Access-check outcomes that reject the connection, mapped to the text used
# both in the ERROR frame and as the close reason.
_WS_REJECTION_REASONS = {
    "forbidden": "无权访问此对话",
    "deleted": "对话已删除",
}


def _now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def _payload_data(message: dict) -> dict:
    """Return ``message["data"]`` as a dict.

    A missing, ``null`` or non-object ``data`` is treated as empty so that a
    malformed client frame cannot raise ``AttributeError`` in the handlers.
    """
    data = message.get("data")
    return data if isinstance(data, dict) else {}


def _coerce_duration_seconds(raw) -> int:
    """Convert a client-supplied duration to ``int``; unparsable values give 0."""
    try:
        return int(raw or 0)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers Infinity, which Python's JSON parser accepts.
        return 0


async def _send_error(
    conversation_id: str, message: str, code: str | None = None
) -> None:
    """Send an ERROR frame; ``code`` is added to ``data`` only when given."""
    data = {"message": message}
    if code is not None:
        data["code"] = code
    await manager.send_message(
        conversation_id,
        {
            "type": MessageType.ERROR,
            "data": data,
            "timestamp": _now_iso(),
        },
    )


def _idle_hours_since(ts) -> float | None:
    """Return the hours elapsed since ``ts``.

    Returns ``None`` when ``ts`` is ``None`` or not a ``datetime``. A naive
    datetime is interpreted as UTC; a timestamp in the future yields ``0.0``.
    """
    if not isinstance(ts, datetime):
        return None
    if ts.tzinfo is None:
        ts = ts.replace(tzinfo=timezone.utc)
    delta = datetime.now(timezone.utc) - ts
    return max(0.0, delta.total_seconds() / 3600.0)


async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: str,
):
    """WebSocket endpoint that drives one real-time conversation.

    Flow: authenticate via the ``token`` query parameter, register the socket
    with the connection manager, verify access to the conversation, optionally
    emit an opening / re-greeting message and topic chips, then dispatch
    incoming frames by ``type`` until the client disconnects or ends the
    conversation. Connection-manager and segment-state cleanup always runs in
    ``finally``.

    Args:
        websocket: The WebSocket connection.
        conversation_id: ID of the conversation this socket is bound to.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌"
        )
        return

    payload = verify_token(token)
    if not payload or payload.get("type") != "access":
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌"
        )
        return

    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(
            code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容"
        )
        return

    async with AsyncSessionLocal() as db:
        user = await db.get(User, user_id)
        if not user:
            await websocket.close(
                code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在"
            )
            return

        await manager.connect(websocket, conversation_id)
        logger.info(
            "WebSocket 已连接 conversation_id={} user_id={}",
            conversation_id,
            user_id,
        )

        quota_service = QuotaService(db=db)
        conversation_service = ConversationService(db=db, quota_service=quota_service)

        try:
            await manager.send_message(
                conversation_id,
                {
                    "type": MessageType.CONNECT,
                    "conversation_id": conversation_id,
                    "data": {"status": "connected"},
                    "timestamp": _now_iso(),
                },
            )

            conversation, ws_conn_err = await conversation_service.ensure_ws_connection(
                conversation_id, user_id
            )
            if ws_conn_err in ("forbidden", "deleted"):
                rejection = _WS_REJECTION_REASONS[ws_conn_err]
                try:
                    await _send_error(conversation_id, rejection)
                except Exception as send_error:
                    # Best effort: the close below still tells the client why.
                    logger.debug("发送错误消息失败: {}", send_error)
                await websocket.close(
                    code=status.WS_1008_POLICY_VIOLATION, reason=rejection
                )
                return

            # Cold-start alignment of conversation_stage with
            # MemoirState.current_stage. If the conversation row is already at
            # a later life stage (larger STAGE_TO_ORDER) it is left untouched
            # so the stage never moves backwards.
            memoir_state = await get_or_create_state(user_id, db)
            ms = (memoir_state.current_stage or "").strip()
            cs = (conversation.conversation_stage or "").strip()
            if ms and (
                not cs or STAGE_TO_ORDER.get(ms, -1) >= STAGE_TO_ORDER.get(cs, -1)
            ):
                conversation.conversation_stage = ms
            # NOTE(review): commit/refresh run unconditionally here; confirm
            # this matches the intended nesting of the original code.
            await db.commit()
            await db.refresh(conversation)

            history = await conversation_service.ensure_redis_history_from_db(
                conversation_id
            )

            async def _stream_ai_only_messages(
                texts: list[str], log_label: str
            ) -> None:
                """Persist a group of AI messages and send them one by one."""
                if not texts:
                    return
                ai_msg_id = await ConversationHistoryStore(db).record_ai_only_turn(
                    conversation_id, texts
                )
                if not ai_msg_id:
                    return
                total_n = len(texts)
                for i, text in enumerate(texts):
                    await manager.send_message(
                        conversation_id,
                        {
                            "type": MessageType.AGENT_RESPONSE,
                            "conversation_id": conversation_id,
                            "data": {
                                "text": text,
                                "index": i,
                                "total": total_n,
                                "assistant_message_id": ai_msg_id,
                            },
                            "timestamp": _now_iso(),
                        },
                    )
                    # Short pause between segments, none after the last one.
                    if i < total_n - 1:
                        await asyncio.sleep(0.5)
                logger.info(
                    "event=ws_auto_ai_sent label={} conversation_id={} segments={}",
                    log_label,
                    conversation_id,
                    total_n,
                )

            async def _maybe_send_topic_chips(reason: str) -> None:
                """Send quick-start topic chips built from the current stage's
                empty slots; any failure is logged and swallowed."""
                if not settings.chat_topic_chips_enabled:
                    return
                # No chips while the profile is incomplete: profile collection
                # runs through its own flow and chips would only add noise.
                if get_missing_profile_fields(user):
                    return
                try:
                    narrative_state = narrative_coverage_state(memoir_state)
                    control_state = interview_control_state(memoir_state)
                    empty_slots = control_state.prompt_empty_slots_for_stage(
                        narrative_state, memoir_state.current_stage
                    )
                    chips = build_topic_chips(
                        memoir_state.current_stage,
                        empty_slots,
                        max_chips=settings.chat_topic_chips_max,
                    )
                    if not chips:
                        return
                    await manager.send_message(
                        conversation_id,
                        {
                            "type": MessageType.TOPIC_SUGGESTIONS,
                            "conversation_id": conversation_id,
                            "data": {
                                "reason": reason,
                                "stage": memoir_state.current_stage,
                                "suggestions": chips,
                            },
                            "timestamp": _now_iso(),
                        },
                    )
                    logger.info(
                        "event=ws_topic_chips_sent reason={} conversation_id={} "
                        "stage={} count={}",
                        reason,
                        conversation_id,
                        memoir_state.current_stage,
                        len(chips),
                    )
                except Exception as e:
                    logger.warning("发送话题 chips 失败: {}", e)

            def _profile_prompt_kwargs() -> dict:
                """Profile-derived keyword arguments shared by the opening and
                re-greeting generators."""
                return {
                    "user_profile_context": format_user_profile_context(
                        birth_year=user.birth_year,
                        birth_place=user.birth_place,
                        grew_up_place=user.grew_up_place,
                        occupation=user.occupation,
                    ),
                    "background_voice": infer_background_voice(user.occupation),
                    "occupation": user.occupation or "",
                    "profile_birth_year": user.birth_year,
                    "profile_era_place": user.grew_up_place or user.birth_place or "",
                }

            async def _quota_blocked() -> bool:
                """Check the user's quota; on rejection send a QUOTA_EXCEEDED
                error frame and return True."""
                can_send, quota_msg = await check_ws_quota(
                    quota_service, user_id, user.subscription_type
                )
                if can_send:
                    return False
                await _send_error(conversation_id, quota_msg, code="QUOTA_EXCEEDED")
                return True

            async def _notify_processing_failure() -> None:
                """Tell the client that processing failed, if still registered."""
                if conversation_id not in manager.active_connections:
                    return
                try:
                    await _send_error(conversation_id, "处理失败,请重试")
                except Exception as send_error:
                    logger.warning("发送错误消息失败: {}", send_error)

            if not history:
                missing_profile = get_missing_profile_fields(user)
                if missing_profile:
                    try:
                        greetings = await chat_orchestrator.generate_profile_greeting(
                            conversation_id=conversation_id,
                            missing_fields=missing_profile,
                            nickname=user.nickname or "",
                        )
                        await _stream_ai_only_messages(
                            greetings, log_label="profile_greeting"
                        )
                    except Exception as e:
                        logger.exception("发送资料收集开场白失败: {}", e)
                else:
                    try:
                        opening_messages = (
                            await chat_orchestrator.generate_opening_message(
                                conversation_id=conversation_id,
                                memoir_state=memoir_state,
                                **_profile_prompt_kwargs(),
                            )
                        )
                        await _stream_ai_only_messages(
                            opening_messages, log_label="opening"
                        )
                        await _maybe_send_topic_chips(reason="opening")
                    except Exception as e:
                        logger.exception("发送空对话开场白失败: {}", e)
            else:
                # Non-empty history: decide whether a re-greeting is due (the
                # last message is older than the configured idle threshold).
                idle_hours = _idle_hours_since(conversation.last_message_at)
                threshold = float(settings.chat_re_greeting_idle_hours)
                if (
                    settings.chat_re_greeting_enabled
                    and not get_missing_profile_fields(user)
                    and idle_hours is not None
                    and idle_hours >= threshold
                ):
                    try:
                        re_greetings = (
                            await chat_orchestrator.generate_re_greeting_message(
                                conversation_id=conversation_id,
                                memoir_state=memoir_state,
                                idle_hours=idle_hours,
                                **_profile_prompt_kwargs(),
                            )
                        )
                        await _stream_ai_only_messages(
                            re_greetings, log_label="re_greeting"
                        )
                        logger.info(
                            "event=ws_re_greeting_emitted conversation_id={} "
                            "idle_hours={:.2f} threshold={:.2f}",
                            conversation_id,
                            idle_hours,
                            threshold,
                        )
                        await _maybe_send_topic_chips(reason="re_greeting")
                    except Exception as e:
                        logger.exception("发送回访问候失败: {}", e)
                else:
                    # No re-greeting: still offer chips to lower the barrier
                    # of picking the conversation back up.
                    await _maybe_send_topic_chips(reason="resume")

            while True:
                try:
                    if websocket.application_state != WebSocketState.CONNECTED:
                        logger.debug(
                            "WebSocket 已非连接状态,退出循环: conversation_id={}",
                            conversation_id,
                        )
                        break

                    message = await websocket.receive_json()
                    if not isinstance(message, dict):
                        # Valid JSON but not an object: it cannot carry a type,
                        # so skip it instead of crashing the connection.
                        logger.warning(
                            "WebSocket 收到缺少 type 的 JSON conversation_id={}",
                            conversation_id,
                        )
                        continue

                    msg_type = message.get("type")
                    data = _payload_data(message)

                    if msg_type == MessageType.AUDIO_SEGMENT:
                        logger.info(
                            "WebSocket 收到消息 type={} conversation_id={} "
                            "segment_index={} is_last={} duration_s={} audio_b64_len={}",
                            msg_type,
                            conversation_id,
                            data.get("segment_index"),
                            bool(data.get("is_last")),
                            _coerce_duration_seconds(data.get("duration")),
                            len(data.get("audio_base64") or ""),
                        )
                    elif msg_type is not None:
                        logger.info(
                            "WebSocket 收到消息 type={} conversation_id={}",
                            msg_type,
                            conversation_id,
                        )
                    else:
                        logger.warning(
                            "WebSocket 收到缺少 type 的 JSON conversation_id={}",
                            conversation_id,
                        )

                    if msg_type == MessageType.TEXT:
                        text_message = data.get("text", "")
                        if text_message:
                            if await _quota_blocked():
                                continue

                            segment = await conversation_service.create_user_segment(
                                conversation,
                                user_id,
                                text_message,
                            )
                            user_message_timestamp = conversation.last_message_at
                            await memoir_ingest_scheduler.queue_segment(
                                conversation.user_id,
                                segment.id,
                                text_char_count=len(text_message.strip()),
                            )
                            await process_user_message(
                                conversation_id=conversation_id,
                                user_message=text_message,
                                conversation=conversation,
                                segment=segment,
                                db=db,
                                user=user,
                                user_message_timestamp=segment.created_at
                                or user_message_timestamp,
                            )

                    elif msg_type == MessageType.RECORDING_STARTED:
                        raw_vs = data.get("voice_session_id")
                        if not raw_vs or not str(raw_vs).strip():
                            await _send_error(conversation_id, "缺少 voice_session_id")
                            continue
                        voice_session_id = str(raw_vs).strip()
                        segment_state = get_or_create_segment_state(
                            conversation_id,
                            voice_session_id,
                        )
                        async with segment_state.lock:
                            # Schedule the delayed "listening" feedback at most
                            # once per voice session.
                            pending_feedback = segment_state.listening_feedback_task
                            if (
                                pending_feedback is not None
                                and not pending_feedback.done()
                            ):
                                continue
                            if segment_state.listening_feedback_sent:
                                continue
                            segment_state.listening_feedback_task = asyncio.create_task(
                                _delayed_listening_feedback(
                                    conversation_id=conversation_id,
                                    voice_session_id=voice_session_id,
                                )
                            )

                    elif msg_type == MessageType.AUDIO_SEGMENT:
                        audio_base64 = data.get("audio_base64", "")
                        segment_index_raw = data.get("segment_index")
                        resolved_vs = data.get("voice_session_id") or (
                            _voice_session_id_from_client_segment_id(
                                data.get("client_segment_id")
                            )
                        )
                        if not resolved_vs or not str(resolved_vs).strip():
                            await _send_error(
                                conversation_id,
                                "缺少 voice_session_id 或有效的 client_segment_id",
                            )
                            continue
                        voice_session_id = str(resolved_vs).strip()
                        is_last = bool(data.get("is_last", False))
                        audio_duration = _coerce_duration_seconds(data.get("duration"))

                        if not audio_base64:
                            await _send_error(conversation_id, "缺少 audio_base64")
                            continue

                        if segment_index_raw is None:
                            await _send_error(conversation_id, "缺少 segment_index")
                            continue

                        try:
                            segment_index = int(segment_index_raw)
                        except (TypeError, ValueError):
                            await _send_error(
                                conversation_id, "segment_index 必须为整数"
                            )
                            continue

                        if segment_index < 0:
                            await _send_error(
                                conversation_id, "segment_index 不能为负数"
                            )
                            continue

                        if await _quota_blocked():
                            continue

                        segment_state = get_or_create_segment_state(
                            conversation_id,
                            voice_session_id,
                        )

                        # De-duplicate: an index that is pending, processed or
                        # already consumed is dropped.
                        async with segment_state.lock:
                            is_duplicate = (
                                segment_index in segment_state.pending_indices
                                or segment_index in segment_state.processed_indices
                                or segment_index <= segment_state.consumed_index
                            )
                            if not is_duplicate:
                                segment_state.pending_indices.add(segment_index)

                        if is_duplicate:
                            logger.debug(
                                "收到重复分段,跳过: conversation_id={} voice_session_id={} "
                                "segment_index={} audio_b64_len={} duration={}",
                                conversation_id,
                                voice_session_id,
                                segment_index,
                                len(audio_base64 or ""),
                                audio_duration,
                            )
                            continue

                        if is_last:
                            # The recording is over, so the pending "listening"
                            # feedback is no longer wanted.
                            async with segment_state.lock:
                                t = segment_state.listening_feedback_task
                                segment_state.listening_feedback_task = None
                            if t is not None and not t.done():
                                t.cancel()

                        task = asyncio.create_task(
                            process_audio_segment(
                                conversation_id=conversation_id,
                                user_id=user_id,
                                voice_session_id=voice_session_id,
                                segment_index=segment_index,
                                audio_base64=audio_base64,
                                audio_duration=audio_duration,
                                is_last=is_last,
                            )
                        )
                        register_segment_task(conversation_id, voice_session_id, task)

                    elif msg_type == MessageType.AUDIO_MESSAGE:
                        audio_base64 = data.get("audio_base64", "")
                        audio_duration = data.get("duration", 0)

                        if audio_base64:
                            if await _quota_blocked():
                                continue

                            logger.debug(
                                "收到音频消息: conversation_id={} duration_s={}",
                                conversation_id,
                                audio_duration,
                            )

                            try:
                                asr = get_asr_provider()
                                audio_bytes = base64.b64decode(audio_base64)
                                asr_text = await asr.transcribe(audio_bytes, "m4a")

                                logger.debug(
                                    "ASR 转写完成: conversation_id={} chars={}",
                                    conversation_id,
                                    len(asr_text or ""),
                                )
                                logger.debug(
                                    "ASR 转写全文: conversation_id={} text={}",
                                    conversation_id,
                                    asr_text,
                                )

                                await manager.send_message(
                                    conversation_id,
                                    {
                                        "type": MessageType.TRANSCRIPT,
                                        "conversation_id": conversation_id,
                                        "data": {
                                            "text": asr_text,
                                            "audio_duration": audio_duration,
                                        },
                                        "timestamp": _now_iso(),
                                    },
                                )

                                ads = _coerce_duration_seconds(audio_duration)
                                # NOTE(review): the segment is stored and queued
                                # for ingest even when transcription failed;
                                # confirm that this is intended.
                                segment = (
                                    await conversation_service.create_user_segment(
                                        conversation,
                                        user_id,
                                        asr_text,
                                        audio_url=f"audio:{audio_duration}s",
                                        audio_duration_seconds=ads if ads > 0 else None,
                                    )
                                )
                                user_message_timestamp = conversation.last_message_at
                                await memoir_ingest_scheduler.queue_segment(
                                    conversation.user_id,
                                    segment.id,
                                    text_char_count=len((asr_text or "").strip()),
                                )

                                if asr_text and not asr_text.startswith("转写失败"):
                                    await process_user_message(
                                        conversation_id=conversation_id,
                                        user_message=asr_text,
                                        conversation=conversation,
                                        segment=segment,
                                        db=db,
                                        user=user,
                                        user_message_timestamp=segment.created_at
                                        or user_message_timestamp,
                                    )
                                else:
                                    await _send_error(
                                        conversation_id,
                                        "语音转写失败,请重试或使用文字输入",
                                    )
                            except Exception as e:
                                logger.exception("处理音频消息失败: {}", e)
                                await _send_error(
                                    conversation_id,
                                    "语音处理失败,请重试或使用文字输入",
                                )

                    elif msg_type == MessageType.TRANSCRIBE_ONLY:
                        audio_base64 = data.get("audio_base64", "")
                        if not audio_base64:
                            await _send_error(conversation_id, "缺少 audio_base64")
                            continue
                        try:
                            asr = get_asr_provider()
                            audio_bytes = base64.b64decode(audio_base64)
                            asr_text = await asr.transcribe(audio_bytes, "m4a")
                            await manager.send_message(
                                conversation_id,
                                {
                                    "type": MessageType.TRANSCRIPT,
                                    "conversation_id": conversation_id,
                                    "data": {"text": asr_text or ""},
                                    "timestamp": _now_iso(),
                                },
                            )
                        except Exception as e:
                            logger.exception("仅转写失败: {}", e)
                            await _send_error(conversation_id, "语音转写失败,请重试")

                    elif msg_type == MessageType.TTS_CANCEL:
                        bump_tts_cancel_epoch(conversation_id)

                    elif msg_type == MessageType.END_CONVERSATION:
                        await conversation_service.end(conversation_id, user_id)
                        await process_conversation_segments(
                            conversation_id, db, quota_service
                        )
                        await manager.send_message(
                            conversation_id,
                            {
                                "type": MessageType.END_CONVERSATION,
                                "conversation_id": conversation_id,
                                "data": {"status": "ended"},
                                "timestamp": _now_iso(),
                            },
                        )
                        break

                    elif msg_type == MessageType.PING:
                        await manager.send_message(
                            conversation_id,
                            {
                                "type": MessageType.PONG,
                                "conversation_id": conversation_id,
                                "data": {},
                                "timestamp": _now_iso(),
                            },
                        )
                    else:
                        if msg_type is not None:
                            logger.warning(
                                "WebSocket 未识别的消息 type={} conversation_id={}",
                                msg_type,
                                conversation_id,
                            )

                except RuntimeError as e:
                    error_msg = str(e)
                    lowered = error_msg.lower()
                    # RuntimeErrors whose text indicates a closed or
                    # not-yet-accepted socket are treated as a disconnect.
                    looks_disconnected = (
                        "disconnect" in lowered
                        or 'Cannot call "receive"' in error_msg
                        or ("accept" in lowered and "not connected" in lowered)
                    )
                    if looks_disconnected:
                        logger.debug(
                            "WebSocket 连接已断开或未就绪: conversation_id={} error={}",
                            conversation_id,
                            error_msg,
                        )
                    else:
                        logger.exception("处理消息时发生 RuntimeError: {}", e)
                        await _notify_processing_failure()
                    break
                except WebSocketDisconnect as disc:
                    logger.info(
                        "WebSocket 断开连接(收消息循环): conversation_id={} code={}",
                        conversation_id,
                        getattr(disc, "code", None),
                    )
                    break
                except Exception as e:
                    logger.exception("处理消息时发生错误: {}", e)
                    await _notify_processing_failure()
                    break

        except WebSocketDisconnect as disc:
            logger.info(
                "WebSocket 断开连接: conversation_id={} code={}",
                conversation_id,
                getattr(disc, "code", None),
            )
        except Exception as e:
            logger.exception("WebSocket 端点发生错误: {}", e)
        finally:
            # Single cleanup point for every exit path (return, break, error).
            await manager.disconnect(conversation_id)
            cleanup_segment_states(conversation_id)