refactor: 优化后端路由功能
- 扩展books.py路由,添加新接口 - 优化websocket.py路由,增强WebSocket功能
This commit is contained in:
@@ -8,7 +8,7 @@ from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect, HTTPException, Query
|
||||
from fastapi import WebSocket, WebSocketDisconnect, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -64,7 +64,14 @@ class ConnectionManager:
|
||||
"""发送消息"""
|
||||
if conversation_id in self.active_connections:
|
||||
websocket = self.active_connections[conversation_id]
|
||||
await websocket.send_json(message)
|
||||
try:
|
||||
# 尝试发送消息,如果连接已关闭会抛出异常
|
||||
await websocket.send_json(message)
|
||||
except (RuntimeError, Exception) as e:
|
||||
logger.warning(f"发送消息失败 (conversation_id={conversation_id}): {e}")
|
||||
# 如果发送失败,从连接列表中移除
|
||||
if conversation_id in self.active_connections:
|
||||
del self.active_connections[conversation_id]
|
||||
|
||||
async def receive_message(self, conversation_id: str) -> dict:
|
||||
"""接收消息"""
|
||||
@@ -79,8 +86,7 @@ manager = ConnectionManager()
|
||||
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
conversation_id: str,
|
||||
token: str = Query(..., description="访问令牌")
|
||||
conversation_id: str
|
||||
):
|
||||
"""
|
||||
WebSocket 端点:处理实时对话
|
||||
@@ -88,8 +94,13 @@ async def websocket_endpoint(
|
||||
Args:
|
||||
websocket: WebSocket 连接
|
||||
conversation_id: 对话 ID
|
||||
token: 访问令牌(从查询参数获取)
|
||||
"""
|
||||
# 从查询参数获取token
|
||||
token = websocket.query_params.get("token")
|
||||
if not token:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌")
|
||||
return
|
||||
|
||||
# 验证JWT令牌
|
||||
payload = verify_token(token)
|
||||
if not payload or payload.get("type") != "access":
|
||||
@@ -134,11 +145,14 @@ async def websocket_endpoint(
|
||||
else:
|
||||
# 验证用户权限:只能访问自己的对话
|
||||
if conversation.user_id != user_id:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "无权访问此对话"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "无权访问此对话"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
except Exception:
|
||||
pass # 如果发送失败,直接关闭连接
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话")
|
||||
return
|
||||
|
||||
@@ -193,19 +207,51 @@ async def websocket_endpoint(
|
||||
})
|
||||
break
|
||||
|
||||
except RuntimeError as e:
|
||||
# 检查是否是断开连接相关的错误
|
||||
error_msg = str(e)
|
||||
if "disconnect" in error_msg.lower() or "Cannot call \"receive\"" in error_msg:
|
||||
logger.info(f"WebSocket 连接已断开: conversation_id={conversation_id}")
|
||||
break
|
||||
else:
|
||||
logger.error(f"处理消息时发生 RuntimeError: {e}", exc_info=True)
|
||||
# 只在连接仍然活跃时发送错误消息
|
||||
if conversation_id in manager.active_connections:
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": str(e)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
except Exception as send_error:
|
||||
logger.warning(f"发送错误消息失败: {send_error}")
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
|
||||
break
|
||||
except Exception as e:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": str(e)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
logger.error(f"处理消息时发生错误: {e}", exc_info=True)
|
||||
# 只在连接仍然活跃时发送错误消息
|
||||
if conversation_id in manager.active_connections:
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": str(e)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
except Exception as send_error:
|
||||
logger.warning(f"发送错误消息失败: {send_error}")
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
|
||||
await manager.disconnect(conversation_id)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 端点发生错误: {e}", exc_info=True)
|
||||
await manager.disconnect(conversation_id)
|
||||
finally:
|
||||
# 确保清理连接
|
||||
await manager.disconnect(conversation_id)
|
||||
raise
|
||||
|
||||
|
||||
async def process_user_message(
|
||||
@@ -272,12 +318,17 @@ async def process_user_message(
|
||||
await _asyncio.sleep(0.5)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理用户消息失败: {e}")
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": f"生成回应失败: {str(e)}"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
logger.error(f"处理用户消息失败: {e}", exc_info=True)
|
||||
# 只在连接仍然活跃时发送错误消息
|
||||
if conversation_id in manager.active_connections:
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": f"生成回应失败: {str(e)}"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
except Exception as send_error:
|
||||
logger.warning(f"发送错误消息失败: {send_error}")
|
||||
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user