refactor: 优化后端路由功能
- 扩展books.py路由,添加新接口 - 优化websocket.py路由,增强WebSocket功能
This commit is contained in:
@@ -55,6 +55,44 @@ async def clear_book_update(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
class UpdateBookRequest(BaseModel):
|
||||
title: str
|
||||
subtitle: str | None = None # 目前数据库不支持subtitle,但保留字段以便将来扩展
|
||||
|
||||
|
||||
@router.put("/{book_id}")
|
||||
async def update_book(
|
||||
book_id: str,
|
||||
request: UpdateBookRequest = Body(...),
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
):
|
||||
"""更新书籍标题(需要认证,只能更新自己的回忆录)"""
|
||||
book = await db.get(BookModel, book_id)
|
||||
if not book:
|
||||
raise HTTPException(status_code=404, detail="Book not found")
|
||||
|
||||
# 验证用户权限
|
||||
if book.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="无权更新此回忆录")
|
||||
|
||||
# 更新标题
|
||||
book.title = request.title
|
||||
# subtitle字段目前数据库不支持,暂时忽略
|
||||
await db.commit()
|
||||
await db.refresh(book)
|
||||
|
||||
return {
|
||||
"id": book.id,
|
||||
"title": book.title,
|
||||
"total_pages": book.total_pages,
|
||||
"total_words": book.total_words,
|
||||
"cover_image_url": book.cover_image_url,
|
||||
"has_update": book.has_update,
|
||||
"last_update_chapter_id": book.last_update_chapter_id,
|
||||
}
|
||||
|
||||
|
||||
class ExportPdfRequest(BaseModel):
|
||||
book_id: str
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect, HTTPException, Query
|
||||
from fastapi import WebSocket, WebSocketDisconnect, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -64,7 +64,14 @@ class ConnectionManager:
|
||||
"""发送消息"""
|
||||
if conversation_id in self.active_connections:
|
||||
websocket = self.active_connections[conversation_id]
|
||||
await websocket.send_json(message)
|
||||
try:
|
||||
# 尝试发送消息,如果连接已关闭会抛出异常
|
||||
await websocket.send_json(message)
|
||||
except (RuntimeError, Exception) as e:
|
||||
logger.warning(f"发送消息失败 (conversation_id={conversation_id}): {e}")
|
||||
# 如果发送失败,从连接列表中移除
|
||||
if conversation_id in self.active_connections:
|
||||
del self.active_connections[conversation_id]
|
||||
|
||||
async def receive_message(self, conversation_id: str) -> dict:
|
||||
"""接收消息"""
|
||||
@@ -79,8 +86,7 @@ manager = ConnectionManager()
|
||||
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
conversation_id: str,
|
||||
token: str = Query(..., description="访问令牌")
|
||||
conversation_id: str
|
||||
):
|
||||
"""
|
||||
WebSocket 端点:处理实时对话
|
||||
@@ -88,8 +94,13 @@ async def websocket_endpoint(
|
||||
Args:
|
||||
websocket: WebSocket 连接
|
||||
conversation_id: 对话 ID
|
||||
token: 访问令牌(从查询参数获取)
|
||||
"""
|
||||
# 从查询参数获取token
|
||||
token = websocket.query_params.get("token")
|
||||
if not token:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌")
|
||||
return
|
||||
|
||||
# 验证JWT令牌
|
||||
payload = verify_token(token)
|
||||
if not payload or payload.get("type") != "access":
|
||||
@@ -134,11 +145,14 @@ async def websocket_endpoint(
|
||||
else:
|
||||
# 验证用户权限:只能访问自己的对话
|
||||
if conversation.user_id != user_id:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "无权访问此对话"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "无权访问此对话"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
except Exception:
|
||||
pass # 如果发送失败,直接关闭连接
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话")
|
||||
return
|
||||
|
||||
@@ -193,19 +207,51 @@ async def websocket_endpoint(
|
||||
})
|
||||
break
|
||||
|
||||
except RuntimeError as e:
|
||||
# 检查是否是断开连接相关的错误
|
||||
error_msg = str(e)
|
||||
if "disconnect" in error_msg.lower() or "Cannot call \"receive\"" in error_msg:
|
||||
logger.info(f"WebSocket 连接已断开: conversation_id={conversation_id}")
|
||||
break
|
||||
else:
|
||||
logger.error(f"处理消息时发生 RuntimeError: {e}", exc_info=True)
|
||||
# 只在连接仍然活跃时发送错误消息
|
||||
if conversation_id in manager.active_connections:
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": str(e)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
except Exception as send_error:
|
||||
logger.warning(f"发送错误消息失败: {send_error}")
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
|
||||
break
|
||||
except Exception as e:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": str(e)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
logger.error(f"处理消息时发生错误: {e}", exc_info=True)
|
||||
# 只在连接仍然活跃时发送错误消息
|
||||
if conversation_id in manager.active_connections:
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": str(e)},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
except Exception as send_error:
|
||||
logger.warning(f"发送错误消息失败: {send_error}")
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
|
||||
await manager.disconnect(conversation_id)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 端点发生错误: {e}", exc_info=True)
|
||||
await manager.disconnect(conversation_id)
|
||||
finally:
|
||||
# 确保清理连接
|
||||
await manager.disconnect(conversation_id)
|
||||
raise
|
||||
|
||||
|
||||
async def process_user_message(
|
||||
@@ -272,12 +318,17 @@ async def process_user_message(
|
||||
await _asyncio.sleep(0.5)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理用户消息失败: {e}")
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": f"生成回应失败: {str(e)}"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
logger.error(f"处理用户消息失败: {e}", exc_info=True)
|
||||
# 只在连接仍然活跃时发送错误消息
|
||||
if conversation_id in manager.active_connections:
|
||||
try:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": f"生成回应失败: {str(e)}"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
except Exception as send_error:
|
||||
logger.warning(f"发送错误消息失败: {send_error}")
|
||||
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user