diff --git a/api/routers/websocket.py b/api/routers/websocket.py index 33d1593..27b9a78 100644 --- a/api/routers/websocket.py +++ b/api/routers/websocket.py @@ -9,6 +9,7 @@ from enum import Enum from typing import Dict from fastapi import WebSocket, WebSocketDisconnect, HTTPException +from starlette.websockets import WebSocketState from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -163,6 +164,9 @@ async def websocket_endpoint( # 主循环:处理消息 while True: try: + if websocket.application_state != WebSocketState.CONNECTED: + logger.info(f"WebSocket 已非连接状态,退出循环: conversation_id={conversation_id}") + break message = await websocket.receive_json() msg_type = message.get("type") @@ -326,10 +330,14 @@ async def websocket_endpoint( break except RuntimeError as e: - # 检查是否是断开连接相关的错误 + # 检查是否是断开连接或未连接状态(如 accept 前/后连接被关闭) error_msg = str(e) - if "disconnect" in error_msg.lower() or "Cannot call \"receive\"" in error_msg: - logger.info(f"WebSocket 连接已断开: conversation_id={conversation_id}") + if ( + "disconnect" in error_msg.lower() + or "Cannot call \"receive\"" in error_msg + or "accept" in error_msg.lower() and "not connected" in error_msg.lower() + ): + logger.info(f"WebSocket 连接已断开或未就绪: conversation_id={conversation_id}, error={error_msg}") break else: logger.error(f"处理消息时发生 RuntimeError: {e}", exc_info=True)