refactor: 优化后端路由功能

- 扩展books.py路由,添加新接口
- 优化websocket.py路由,增强WebSocket功能
This commit is contained in:
iammm0
2026-01-26 11:54:05 +08:00
parent 5314077f3b
commit dae4a176fd
2 changed files with 112 additions and 23 deletions

View File

@@ -55,6 +55,44 @@ async def clear_book_update(
return {"status": "ok"}
class UpdateBookRequest(BaseModel):
title: str
subtitle: str | None = None # 目前数据库不支持subtitle但保留字段以便将来扩展
@router.put("/{book_id}")
async def update_book(
book_id: str,
request: UpdateBookRequest = Body(...),
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db)
):
"""更新书籍标题(需要认证,只能更新自己的回忆录)"""
book = await db.get(BookModel, book_id)
if not book:
raise HTTPException(status_code=404, detail="Book not found")
# 验证用户权限
if book.user_id != current_user.id:
raise HTTPException(status_code=403, detail="无权更新此回忆录")
# 更新标题
book.title = request.title
# subtitle字段目前数据库不支持暂时忽略
await db.commit()
await db.refresh(book)
return {
"id": book.id,
"title": book.title,
"total_pages": book.total_pages,
"total_words": book.total_words,
"cover_image_url": book.cover_image_url,
"has_update": book.has_update,
"last_update_chapter_id": book.last_update_chapter_id,
}
class ExportPdfRequest(BaseModel):
book_id: str

View File

@@ -8,7 +8,7 @@ from datetime import datetime, timezone
from enum import Enum
from typing import Dict
from fastapi import WebSocket, WebSocketDisconnect, HTTPException, Query
from fastapi import WebSocket, WebSocketDisconnect, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -64,7 +64,14 @@ class ConnectionManager:
"""发送消息"""
if conversation_id in self.active_connections:
websocket = self.active_connections[conversation_id]
await websocket.send_json(message)
try:
# 尝试发送消息,如果连接已关闭会抛出异常
await websocket.send_json(message)
except (RuntimeError, Exception) as e:
logger.warning(f"发送消息失败 (conversation_id={conversation_id}): {e}")
# 如果发送失败,从连接列表中移除
if conversation_id in self.active_connections:
del self.active_connections[conversation_id]
async def receive_message(self, conversation_id: str) -> dict:
"""接收消息"""
@@ -79,8 +86,7 @@ manager = ConnectionManager()
async def websocket_endpoint(
websocket: WebSocket,
conversation_id: str,
token: str = Query(..., description="访问令牌")
conversation_id: str
):
"""
WebSocket 端点:处理实时对话
@@ -88,8 +94,13 @@ async def websocket_endpoint(
Args:
websocket: WebSocket 连接
conversation_id: 对话 ID
token: 访问令牌(从查询参数获取)
"""
# 从查询参数获取token
token = websocket.query_params.get("token")
if not token:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌")
return
# 验证JWT令牌
payload = verify_token(token)
if not payload or payload.get("type") != "access":
@@ -134,11 +145,14 @@ async def websocket_endpoint(
else:
# 验证用户权限:只能访问自己的对话
if conversation.user_id != user_id:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "无权访问此对话"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
try:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "无权访问此对话"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
except Exception:
pass # 如果发送失败,直接关闭连接
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话")
return
@@ -193,19 +207,51 @@ async def websocket_endpoint(
})
break
except RuntimeError as e:
# 检查是否是断开连接相关的错误
error_msg = str(e)
if "disconnect" in error_msg.lower() or "Cannot call \"receive\"" in error_msg:
logger.info(f"WebSocket 连接已断开: conversation_id={conversation_id}")
break
else:
logger.error(f"处理消息时发生 RuntimeError: {e}", exc_info=True)
# 只在连接仍然活跃时发送错误消息
if conversation_id in manager.active_connections:
try:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": str(e)},
"timestamp": datetime.now(timezone.utc).isoformat()
})
except Exception as send_error:
logger.warning(f"发送错误消息失败: {send_error}")
break
except WebSocketDisconnect:
logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
break
except Exception as e:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": str(e)},
"timestamp": datetime.now(timezone.utc).isoformat()
})
logger.error(f"处理消息时发生错误: {e}", exc_info=True)
# 只在连接仍然活跃时发送错误消息
if conversation_id in manager.active_connections:
try:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": str(e)},
"timestamp": datetime.now(timezone.utc).isoformat()
})
except Exception as send_error:
logger.warning(f"发送错误消息失败: {send_error}")
break
except WebSocketDisconnect:
logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
await manager.disconnect(conversation_id)
break
except Exception as e:
logger.error(f"WebSocket 端点发生错误: {e}", exc_info=True)
await manager.disconnect(conversation_id)
finally:
# 确保清理连接
await manager.disconnect(conversation_id)
raise
async def process_user_message(
@@ -272,12 +318,17 @@ async def process_user_message(
await _asyncio.sleep(0.5)
except Exception as e:
logger.error(f"处理用户消息失败: {e}")
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": f"生成回应失败: {str(e)}"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
logger.error(f"处理用户消息失败: {e}", exc_info=True)
# 只在连接仍然活跃时发送错误消息
if conversation_id in manager.active_connections:
try:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": f"生成回应失败: {str(e)}"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
except Exception as send_error:
logger.warning(f"发送错误消息失败: {send_error}")
return