refactor: 更新API路由

- 更新books路由以支持用户认证
- 更新chapters路由以支持用户认证
- 更新conversations路由以支持用户认证
- 更新websocket路由以支持用户认证和连接管理
This commit is contained in:
徐在坤
2026-01-18 15:57:51 +08:00
parent 3a4c9f0838
commit 802f5a3833
4 changed files with 211 additions and 120 deletions

View File

@@ -1,24 +1,27 @@
""" """
回忆录相关 API 路由 回忆录相关 API 路由
""" """
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException, Query, Body
from pydantic import BaseModel
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database import get_async_db from database import get_async_db
from database.models import Book as BookModel from database.models import Book as BookModel
from database.models import User as UserModel
from services.pdf_service import pdf_service from services.pdf_service import pdf_service
from middleware.auth import get_current_user
router = APIRouter(prefix="/api/books", tags=["books"]) router = APIRouter(prefix="/api/books", tags=["books"])
@router.get("/current") @router.get("/current")
async def get_current_book( async def get_current_book(
user_id: str, current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_async_db)
): ):
"""获取当前回忆录""" """获取当前回忆录(需要认证)"""
stmt = select(BookModel).where(BookModel.user_id == user_id).order_by(BookModel.updated_at.desc()) stmt = select(BookModel).where(BookModel.user_id == current_user.id).order_by(BookModel.updated_at.desc())
result = await db.execute(stmt) result = await db.execute(stmt)
book = result.scalar_one_or_none() book = result.scalar_one_or_none()
@@ -34,20 +37,28 @@ async def get_current_book(
} }
class ExportPdfRequest(BaseModel):
book_id: str
@router.post("/export-pdf") @router.post("/export-pdf")
async def export_pdf( async def export_pdf(
book_id: str, request: ExportPdfRequest = Body(...),
user_id: str, current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_async_db)
): ):
"""导出 PDF""" """导出 PDF(需要认证,只能导出自己的回忆录)"""
book = await db.get(BookModel, book_id) book = await db.get(BookModel, request.book_id)
if not book or book.user_id != user_id: if not book:
raise HTTPException(status_code=404, detail="Book not found") raise HTTPException(status_code=404, detail="Book not found")
# 验证用户权限
if book.user_id != current_user.id:
raise HTTPException(status_code=403, detail="无权导出此回忆录")
# 获取所有章节 # 获取所有章节
from ..database.models import Chapter from database.models import Chapter
stmt = select(Chapter).where(Chapter.user_id == user_id).order_by(Chapter.order_index) stmt = select(Chapter).where(Chapter.user_id == current_user.id).order_by(Chapter.order_index)
result = await db.execute(stmt) result = await db.execute(stmt)
chapters = result.scalars().all() chapters = result.scalars().all()

View File

@@ -1,25 +1,27 @@
""" """
章节相关 API 路由 章节相关 API 路由
""" """
from typing import List from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database import get_async_db from database import get_async_db
from database.models import Chapter as ChapterModel from database.models import Chapter as ChapterModel
from database.models import User as UserModel
from middleware.auth import get_current_user
router = APIRouter(prefix="/api/chapters", tags=["chapters"]) router = APIRouter(prefix="/api/chapters", tags=["chapters"])
@router.get("", response_model=List[dict]) @router.get("", response_model=List[dict])
async def get_chapters( async def get_chapters(
user_id: str, current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_async_db)
): ):
"""获取用户所有章节""" """获取用户所有章节(需要认证)"""
stmt = select(ChapterModel).where(ChapterModel.user_id == user_id).order_by(ChapterModel.order_index) stmt = select(ChapterModel).where(ChapterModel.user_id == current_user.id).order_by(ChapterModel.order_index)
result = await db.execute(stmt) result = await db.execute(stmt)
chapters = result.scalars().all() chapters = result.scalars().all()
@@ -40,13 +42,18 @@ async def get_chapters(
@router.get("/{chapter_id}", response_model=dict) @router.get("/{chapter_id}", response_model=dict)
async def get_chapter( async def get_chapter(
chapter_id: str, chapter_id: str,
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_async_db)
): ):
"""获取章节详情""" """获取章节详情(需要认证,只能访问自己的章节)"""
chapter = await db.get(ChapterModel, chapter_id) chapter = await db.get(ChapterModel, chapter_id)
if not chapter: if not chapter:
raise HTTPException(status_code=404, detail="Chapter not found") raise HTTPException(status_code=404, detail="Chapter not found")
# 验证用户权限
if chapter.user_id != current_user.id:
raise HTTPException(status_code=403, detail="无权访问此章节")
return { return {
"id": chapter.id, "id": chapter.id,
"title": chapter.title, "title": chapter.title,
@@ -61,9 +68,18 @@ async def get_chapter(
@router.post("/{chapter_id}/regenerate") @router.post("/{chapter_id}/regenerate")
async def regenerate_chapter( async def regenerate_chapter(
chapter_id: str, chapter_id: str,
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_async_db)
): ):
"""重新整理章节""" """重新整理章节(需要认证,只能操作自己的章节)"""
chapter = await db.get(ChapterModel, chapter_id)
if not chapter:
raise HTTPException(status_code=404, detail="Chapter not found")
# 验证用户权限
if chapter.user_id != current_user.id:
raise HTTPException(status_code=403, detail="无权操作此章节")
# TODO: 实现重新整理逻辑 # TODO: 实现重新整理逻辑
return {"status": "ok", "message": "Chapter regeneration triggered"} return {"status": "ok", "message": "Chapter regeneration triggered"}

View File

@@ -3,26 +3,29 @@
""" """
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException, Query, Body
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select from sqlalchemy import select
import uuid import uuid
from database import get_async_db, Conversation, Segment, User from database import get_async_db, Conversation, Segment, User
from database.models import Conversation as ConversationModel, Segment as SegmentModel from database.models import Conversation as ConversationModel, Segment as SegmentModel
from middleware.auth import get_current_user
from database.models import User as UserModel
router = APIRouter(prefix="/api/conversations", tags=["conversations"]) router = APIRouter(prefix="/api/conversations", tags=["conversations"])
@router.post("") @router.post("")
async def create_conversation( async def create_conversation(
user_id: str, current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_async_db)
): ):
"""创建新对话""" """创建新对话(需要认证)"""
conversation = ConversationModel( conversation = ConversationModel(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
user_id=user_id, user_id=current_user.id,
started_at=datetime.now(timezone.utc), started_at=datetime.now(timezone.utc),
status="active" status="active"
) )
@@ -41,13 +44,18 @@ async def create_conversation(
@router.get("/{conversation_id}") @router.get("/{conversation_id}")
async def get_conversation( async def get_conversation(
conversation_id: str, conversation_id: str,
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_async_db)
): ):
"""获取对话详情""" """获取对话详情(需要认证,只能访问自己的对话)"""
conversation = await db.get(ConversationModel, conversation_id) conversation = await db.get(ConversationModel, conversation_id)
if not conversation: if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found") raise HTTPException(status_code=404, detail="Conversation not found")
# 验证用户权限
if conversation.user_id != current_user.id:
raise HTTPException(status_code=403, detail="无权访问此对话")
return { return {
"id": conversation.id, "id": conversation.id,
"user_id": conversation.user_id, "user_id": conversation.user_id,
@@ -64,13 +72,18 @@ async def get_conversation(
@router.post("/{conversation_id}/end") @router.post("/{conversation_id}/end")
async def end_conversation( async def end_conversation(
conversation_id: str, conversation_id: str,
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_async_db) db: AsyncSession = Depends(get_async_db)
): ):
"""结束对话""" """结束对话(需要认证,只能结束自己的对话)"""
conversation = await db.get(ConversationModel, conversation_id) conversation = await db.get(ConversationModel, conversation_id)
if not conversation: if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found") raise HTTPException(status_code=404, detail="Conversation not found")
# 验证用户权限
if conversation.user_id != current_user.id:
raise HTTPException(status_code=403, detail="无权操作此对话")
conversation.status = "ended" conversation.status = "ended"
conversation.ended_at = datetime.now(timezone.utc) conversation.ended_at = datetime.now(timezone.utc)

View File

@@ -6,7 +6,7 @@ from datetime import datetime, timezone
from enum import Enum from enum import Enum
from typing import Dict from typing import Dict
from fastapi import WebSocket, WebSocketDisconnect, HTTPException from fastapi import WebSocket, WebSocketDisconnect, HTTPException, Query
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -14,14 +14,16 @@ from agents import ConversationAgent, MemoryAgent
from agents.prompts import ConversationStage from agents.prompts import ConversationStage
from database import get_async_db from database import get_async_db
from database.models import Conversation, Segment from database.models import Conversation, Segment
from services.asr_service import asr_service from database.models import User as UserModel
from services.tts_service import tts_service from services.auth_service import verify_token
from fastapi import HTTPException, status
class MessageType(str, Enum): class MessageType(str, Enum):
"""WebSocket 消息类型""" """WebSocket 消息类型"""
CONNECT = "connect" CONNECT = "connect"
AUDIO_CHUNK = "audio_chunk" AUDIO_CHUNK = "audio_chunk"
TEXT = "text" # 文本消息
TRANSCRIPT = "transcript" TRANSCRIPT = "transcript"
AGENT_RESPONSE = "agent_response" AGENT_RESPONSE = "agent_response"
TTS_AUDIO = "tts_audio" TTS_AUDIO = "tts_audio"
@@ -69,34 +71,52 @@ class ConnectionManager:
manager = ConnectionManager() manager = ConnectionManager()
async def websocket_endpoint(websocket: WebSocket, conversation_id: str): async def websocket_endpoint(
websocket: WebSocket,
conversation_id: str,
token: str = Query(..., description="访问令牌")
):
""" """
WebSocket 端点:处理实时对话 WebSocket 端点:处理实时对话
Args: Args:
websocket: WebSocket 连接 websocket: WebSocket 连接
conversation_id: 对话 ID conversation_id: 对话 ID
token: 访问令牌(从查询参数获取)
""" """
await manager.connect(websocket, conversation_id) # 验证JWT令牌
payload = verify_token(token)
if not payload or payload.get("type") != "access":
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌")
return
try: user_id = payload.get("sub")
# 发送连接确认 if not user_id:
await manager.send_message(conversation_id, { await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容")
"type": MessageType.CONNECT, return
"conversation_id": conversation_id,
"data": {"status": "connected"}, # 验证用户是否存在
"timestamp": datetime.now(timezone.utc).isoformat() async for db in get_async_db():
}) user = await db.get(UserModel, user_id)
if not user:
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在")
return
# 从数据库获取对话信息 await manager.connect(websocket, conversation_id)
async for db in get_async_db():
try:
# 发送连接确认
await manager.send_message(conversation_id, {
"type": MessageType.CONNECT,
"conversation_id": conversation_id,
"data": {"status": "connected"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
# 从数据库获取对话信息
conversation = await db.get(Conversation, conversation_id) conversation = await db.get(Conversation, conversation_id)
if not conversation: if not conversation:
# 如果对话不存在,创建新对话 # 如果对话不存在,创建新对话
from database.models import User as UserModel
# 假设用户 ID 从连接参数获取(实际应该从认证获取)
user_id = "default_user" # TODO: 从认证获取实际用户 ID
conversation = Conversation( conversation = Conversation(
id=conversation_id, id=conversation_id,
user_id=user_id, user_id=user_id,
@@ -105,6 +125,17 @@ async def websocket_endpoint(websocket: WebSocket, conversation_id: str):
) )
db.add(conversation) db.add(conversation)
await db.commit() await db.commit()
else:
# 验证用户权限:只能访问自己的对话
if conversation.user_id != user_id:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": "无权访问此对话"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话")
return
current_stage = ConversationStage(conversation.conversation_stage) if conversation.conversation_stage else ConversationStage.CHILDHOOD current_stage = ConversationStage(conversation.conversation_stage) if conversation.conversation_stage else ConversationStage.CHILDHOOD
@@ -114,79 +145,31 @@ async def websocket_endpoint(websocket: WebSocket, conversation_id: str):
message = await websocket.receive_json() message = await websocket.receive_json()
msg_type = message.get("type") msg_type = message.get("type")
if msg_type == MessageType.AUDIO_CHUNK: if msg_type == MessageType.TEXT:
# 处理音频块 # 处理文本消息
audio_data = message.get("data", {}).get("audio_base64", "") text_message = message.get("data", {}).get("text", "")
# 调用 ASR 服务转文字 if text_message:
transcript = await asr_service.transcribe(audio_data) # 保存段落到数据库
segment = Segment(
# 保存段落到数据库 id=str(uuid.uuid4()),
segment = Segment(
id=str(uuid.uuid4()),
conversation_id=conversation_id,
transcript_text=transcript,
processed=False
)
db.add(segment)
await db.commit()
# 发送转写结果
await manager.send_message(conversation_id, {
"type": MessageType.TRANSCRIPT,
"conversation_id": conversation_id,
"data": {"text": transcript},
"timestamp": datetime.now(timezone.utc).isoformat()
})
# Agent 生成回应
agent = manager.conversation_agents.get(conversation_id)
if agent:
# 检测对话阶段
detected_stage = agent.detect_stage(conversation_id, transcript)
if detected_stage != current_stage:
current_stage = detected_stage
conversation.conversation_stage = current_stage.value
await db.commit()
# 获取已聊话题
stmt_segments = select(Segment).where(
Segment.conversation_id == conversation_id
).order_by(Segment.created_at)
result_segments = await db.execute(stmt_segments)
previous_segments = result_segments.scalars().all()
covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category]
# 生成回应
response = agent.generate_response(
conversation_id=conversation_id, conversation_id=conversation_id,
user_message=transcript, transcript_text=text_message,
current_stage=current_stage, processed=False
covered_topics=covered_topics
) )
db.add(segment)
# 更新段落的 Agent 回应
segment.agent_response = response
await db.commit() await db.commit()
# 发送 Agent 回应 # Agent 生成回应
await manager.send_message(conversation_id, { current_stage = await process_user_message(
"type": MessageType.AGENT_RESPONSE, conversation_id=conversation_id,
"conversation_id": conversation_id, user_message=text_message,
"data": {"text": response}, current_stage=current_stage,
"timestamp": datetime.now(timezone.utc).isoformat() conversation=conversation,
}) segment=segment,
db=db,
# 调用 TTS 服务生成音频 manager=manager
tts_audio = await tts_service.synthesize(response) )
# 发送 TTS 音频
await manager.send_message(conversation_id, {
"type": MessageType.TTS_AUDIO,
"conversation_id": conversation_id,
"data": {"audio_base64": tts_audio},
"timestamp": datetime.now(timezone.utc).isoformat()
})
elif msg_type == MessageType.END_CONVERSATION: elif msg_type == MessageType.END_CONVERSATION:
# 结束对话 # 结束对话
@@ -212,11 +195,76 @@ async def websocket_endpoint(websocket: WebSocket, conversation_id: str):
"timestamp": datetime.now(timezone.utc).isoformat() "timestamp": datetime.now(timezone.utc).isoformat()
}) })
except WebSocketDisconnect: except WebSocketDisconnect:
manager.disconnect(conversation_id) manager.disconnect(conversation_id)
except Exception: break
manager.disconnect(conversation_id) except Exception as e:
raise manager.disconnect(conversation_id)
raise
async def process_user_message(
conversation_id: str,
user_message: str,
current_stage: ConversationStage,
conversation: Conversation,
segment: Segment,
db: AsyncSession,
manager: ConnectionManager
) -> ConversationStage:
"""
处理用户消息生成Agent回应
Args:
conversation_id: 对话ID
user_message: 用户消息文本
current_stage: 当前对话阶段
conversation: 对话对象
segment: 段落对象
db: 数据库会话
manager: 连接管理器
Returns:
更新后的对话阶段
"""
agent = manager.conversation_agents.get(conversation_id)
if agent:
# 检测对话阶段
detected_stage = agent.detect_stage(conversation_id, user_message)
if detected_stage != current_stage:
current_stage = detected_stage
conversation.conversation_stage = current_stage.value
await db.commit()
# 获取已聊话题
stmt_segments = select(Segment).where(
Segment.conversation_id == conversation_id
).order_by(Segment.created_at)
result_segments = await db.execute(stmt_segments)
previous_segments = result_segments.scalars().all()
covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category]
# 生成回应
response = agent.generate_response(
conversation_id=conversation_id,
user_message=user_message,
current_stage=current_stage,
covered_topics=covered_topics
)
# 更新段落的 Agent 回应
segment.agent_response = response
await db.commit()
# 发送 Agent 回应(仅文字,不生成语音)
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {"text": response},
"timestamp": datetime.now(timezone.utc).isoformat()
})
return current_stage
async def process_conversation_segments(conversation_id: str, db: AsyncSession): async def process_conversation_segments(conversation_id: str, db: AsyncSession):
@@ -249,15 +297,18 @@ async def process_conversation_segments(conversation_id: str, db: AsyncSession):
chapters_data = memory_agent.process_segments(segments_data) chapters_data = memory_agent.process_segments(segments_data)
# 保存章节到数据库 # 保存章节到数据库
from database import Chapter from database.models import Chapter as ChapterModel
conversation = await db.get(Conversation, conversation_id) conversation = await db.get(Conversation, conversation_id)
if not conversation:
return
for category, chapter_data in chapters_data.items(): for category, chapter_data in chapters_data.items():
chapter = Chapter( chapter = ChapterModel(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
user_id=conversation.user_id, user_id=conversation.user_id,
title=chapter_data["title"], title=chapter_data.get("title", f"章节-{category}"),
content=chapter_data["content"], content=chapter_data.get("content", ""),
order_index=chapter_data.get("order_index", 999), order_index=chapter_data.get("order_index", 999),
status="completed", status="completed",
category=category, category=category,