refactor: 更新API路由
- 更新books路由以支持用户认证 - 更新chapters路由以支持用户认证 - 更新conversations路由以支持用户认证 - 更新websocket路由以支持用户认证和连接管理
This commit is contained in:
@@ -1,24 +1,27 @@
|
||||
"""
|
||||
回忆录相关 API 路由
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Body
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database import get_async_db
|
||||
from database.models import Book as BookModel
|
||||
from database.models import User as UserModel
|
||||
from services.pdf_service import pdf_service
|
||||
from middleware.auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/books", tags=["books"])
|
||||
|
||||
|
||||
@router.get("/current")
|
||||
async def get_current_book(
|
||||
user_id: str,
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
):
|
||||
"""获取当前回忆录"""
|
||||
stmt = select(BookModel).where(BookModel.user_id == user_id).order_by(BookModel.updated_at.desc())
|
||||
"""获取当前回忆录(需要认证)"""
|
||||
stmt = select(BookModel).where(BookModel.user_id == current_user.id).order_by(BookModel.updated_at.desc())
|
||||
result = await db.execute(stmt)
|
||||
book = result.scalar_one_or_none()
|
||||
|
||||
@@ -34,20 +37,28 @@ async def get_current_book(
|
||||
}
|
||||
|
||||
|
||||
class ExportPdfRequest(BaseModel):
|
||||
book_id: str
|
||||
|
||||
|
||||
@router.post("/export-pdf")
|
||||
async def export_pdf(
|
||||
book_id: str,
|
||||
user_id: str,
|
||||
request: ExportPdfRequest = Body(...),
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
):
|
||||
"""导出 PDF"""
|
||||
book = await db.get(BookModel, book_id)
|
||||
if not book or book.user_id != user_id:
|
||||
"""导出 PDF(需要认证,只能导出自己的回忆录)"""
|
||||
book = await db.get(BookModel, request.book_id)
|
||||
if not book:
|
||||
raise HTTPException(status_code=404, detail="Book not found")
|
||||
|
||||
# 验证用户权限
|
||||
if book.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="无权导出此回忆录")
|
||||
|
||||
# 获取所有章节
|
||||
from ..database.models import Chapter
|
||||
stmt = select(Chapter).where(Chapter.user_id == user_id).order_by(Chapter.order_index)
|
||||
from database.models import Chapter
|
||||
stmt = select(Chapter).where(Chapter.user_id == current_user.id).order_by(Chapter.order_index)
|
||||
result = await db.execute(stmt)
|
||||
chapters = result.scalars().all()
|
||||
|
||||
|
||||
@@ -1,25 +1,27 @@
|
||||
"""
|
||||
章节相关 API 路由
|
||||
"""
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database import get_async_db
|
||||
from database.models import Chapter as ChapterModel
|
||||
from database.models import User as UserModel
|
||||
from middleware.auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/api/chapters", tags=["chapters"])
|
||||
|
||||
|
||||
@router.get("", response_model=List[dict])
|
||||
async def get_chapters(
|
||||
user_id: str,
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
):
|
||||
"""获取用户所有章节"""
|
||||
stmt = select(ChapterModel).where(ChapterModel.user_id == user_id).order_by(ChapterModel.order_index)
|
||||
"""获取用户所有章节(需要认证)"""
|
||||
stmt = select(ChapterModel).where(ChapterModel.user_id == current_user.id).order_by(ChapterModel.order_index)
|
||||
result = await db.execute(stmt)
|
||||
chapters = result.scalars().all()
|
||||
|
||||
@@ -40,13 +42,18 @@ async def get_chapters(
|
||||
@router.get("/{chapter_id}", response_model=dict)
|
||||
async def get_chapter(
|
||||
chapter_id: str,
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
):
|
||||
"""获取章节详情"""
|
||||
"""获取章节详情(需要认证,只能访问自己的章节)"""
|
||||
chapter = await db.get(ChapterModel, chapter_id)
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="Chapter not found")
|
||||
|
||||
# 验证用户权限
|
||||
if chapter.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="无权访问此章节")
|
||||
|
||||
return {
|
||||
"id": chapter.id,
|
||||
"title": chapter.title,
|
||||
@@ -61,9 +68,18 @@ async def get_chapter(
|
||||
@router.post("/{chapter_id}/regenerate")
|
||||
async def regenerate_chapter(
|
||||
chapter_id: str,
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
):
|
||||
"""重新整理章节"""
|
||||
"""重新整理章节(需要认证,只能操作自己的章节)"""
|
||||
chapter = await db.get(ChapterModel, chapter_id)
|
||||
if not chapter:
|
||||
raise HTTPException(status_code=404, detail="Chapter not found")
|
||||
|
||||
# 验证用户权限
|
||||
if chapter.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="无权操作此章节")
|
||||
|
||||
# TODO: 实现重新整理逻辑
|
||||
return {"status": "ok", "message": "Chapter regeneration triggered"}
|
||||
|
||||
|
||||
@@ -3,26 +3,29 @@
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Body
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
import uuid
|
||||
|
||||
from database import get_async_db, Conversation, Segment, User
|
||||
from database.models import Conversation as ConversationModel, Segment as SegmentModel
|
||||
from middleware.auth import get_current_user
|
||||
from database.models import User as UserModel
|
||||
|
||||
router = APIRouter(prefix="/api/conversations", tags=["conversations"])
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_conversation(
|
||||
user_id: str,
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
):
|
||||
"""创建新对话"""
|
||||
"""创建新对话(需要认证)"""
|
||||
conversation = ConversationModel(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
user_id=current_user.id,
|
||||
started_at=datetime.now(timezone.utc),
|
||||
status="active"
|
||||
)
|
||||
@@ -41,13 +44,18 @@ async def create_conversation(
|
||||
@router.get("/{conversation_id}")
|
||||
async def get_conversation(
|
||||
conversation_id: str,
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
):
|
||||
"""获取对话详情"""
|
||||
"""获取对话详情(需要认证,只能访问自己的对话)"""
|
||||
conversation = await db.get(ConversationModel, conversation_id)
|
||||
if not conversation:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found")
|
||||
|
||||
# 验证用户权限
|
||||
if conversation.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="无权访问此对话")
|
||||
|
||||
return {
|
||||
"id": conversation.id,
|
||||
"user_id": conversation.user_id,
|
||||
@@ -64,13 +72,18 @@ async def get_conversation(
|
||||
@router.post("/{conversation_id}/end")
|
||||
async def end_conversation(
|
||||
conversation_id: str,
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_async_db)
|
||||
):
|
||||
"""结束对话"""
|
||||
"""结束对话(需要认证,只能结束自己的对话)"""
|
||||
conversation = await db.get(ConversationModel, conversation_id)
|
||||
if not conversation:
|
||||
raise HTTPException(status_code=404, detail="Conversation not found")
|
||||
|
||||
# 验证用户权限
|
||||
if conversation.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="无权操作此对话")
|
||||
|
||||
conversation.status = "ended"
|
||||
conversation.ended_at = datetime.now(timezone.utc)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import WebSocket, WebSocketDisconnect, HTTPException
|
||||
from fastapi import WebSocket, WebSocketDisconnect, HTTPException, Query
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -14,14 +14,16 @@ from agents import ConversationAgent, MemoryAgent
|
||||
from agents.prompts import ConversationStage
|
||||
from database import get_async_db
|
||||
from database.models import Conversation, Segment
|
||||
from services.asr_service import asr_service
|
||||
from services.tts_service import tts_service
|
||||
from database.models import User as UserModel
|
||||
from services.auth_service import verify_token
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
"""WebSocket 消息类型"""
|
||||
CONNECT = "connect"
|
||||
AUDIO_CHUNK = "audio_chunk"
|
||||
TEXT = "text" # 文本消息
|
||||
TRANSCRIPT = "transcript"
|
||||
AGENT_RESPONSE = "agent_response"
|
||||
TTS_AUDIO = "tts_audio"
|
||||
@@ -69,14 +71,37 @@ class ConnectionManager:
|
||||
manager = ConnectionManager()
|
||||
|
||||
|
||||
async def websocket_endpoint(websocket: WebSocket, conversation_id: str):
|
||||
async def websocket_endpoint(
|
||||
websocket: WebSocket,
|
||||
conversation_id: str,
|
||||
token: str = Query(..., description="访问令牌")
|
||||
):
|
||||
"""
|
||||
WebSocket 端点:处理实时对话
|
||||
|
||||
Args:
|
||||
websocket: WebSocket 连接
|
||||
conversation_id: 对话 ID
|
||||
token: 访问令牌(从查询参数获取)
|
||||
"""
|
||||
# 验证JWT令牌
|
||||
payload = verify_token(token)
|
||||
if not payload or payload.get("type") != "access":
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌")
|
||||
return
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容")
|
||||
return
|
||||
|
||||
# 验证用户是否存在
|
||||
async for db in get_async_db():
|
||||
user = await db.get(UserModel, user_id)
|
||||
if not user:
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在")
|
||||
return
|
||||
|
||||
await manager.connect(websocket, conversation_id)
|
||||
|
||||
try:
|
||||
@@ -89,14 +114,9 @@ async def websocket_endpoint(websocket: WebSocket, conversation_id: str):
|
||||
})
|
||||
|
||||
# 从数据库获取对话信息
|
||||
async for db in get_async_db():
|
||||
conversation = await db.get(Conversation, conversation_id)
|
||||
if not conversation:
|
||||
# 如果对话不存在,创建新对话
|
||||
from database.models import User as UserModel
|
||||
# 假设用户 ID 从连接参数获取(实际应该从认证获取)
|
||||
user_id = "default_user" # TODO: 从认证获取实际用户 ID
|
||||
|
||||
conversation = Conversation(
|
||||
id=conversation_id,
|
||||
user_id=user_id,
|
||||
@@ -105,6 +125,17 @@ async def websocket_endpoint(websocket: WebSocket, conversation_id: str):
|
||||
)
|
||||
db.add(conversation)
|
||||
await db.commit()
|
||||
else:
|
||||
# 验证用户权限:只能访问自己的对话
|
||||
if conversation.user_id != user_id:
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": "无权访问此对话"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话")
|
||||
return
|
||||
|
||||
|
||||
current_stage = ConversationStage(conversation.conversation_stage) if conversation.conversation_stage else ConversationStage.CHILDHOOD
|
||||
|
||||
@@ -114,80 +145,32 @@ async def websocket_endpoint(websocket: WebSocket, conversation_id: str):
|
||||
message = await websocket.receive_json()
|
||||
msg_type = message.get("type")
|
||||
|
||||
if msg_type == MessageType.AUDIO_CHUNK:
|
||||
# 处理音频块
|
||||
audio_data = message.get("data", {}).get("audio_base64", "")
|
||||
|
||||
# 调用 ASR 服务转文字
|
||||
transcript = await asr_service.transcribe(audio_data)
|
||||
if msg_type == MessageType.TEXT:
|
||||
# 处理文本消息
|
||||
text_message = message.get("data", {}).get("text", "")
|
||||
|
||||
if text_message:
|
||||
# 保存段落到数据库
|
||||
segment = Segment(
|
||||
id=str(uuid.uuid4()),
|
||||
conversation_id=conversation_id,
|
||||
transcript_text=transcript,
|
||||
transcript_text=text_message,
|
||||
processed=False
|
||||
)
|
||||
db.add(segment)
|
||||
await db.commit()
|
||||
|
||||
# 发送转写结果
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.TRANSCRIPT,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": transcript},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
|
||||
# Agent 生成回应
|
||||
agent = manager.conversation_agents.get(conversation_id)
|
||||
if agent:
|
||||
# 检测对话阶段
|
||||
detected_stage = agent.detect_stage(conversation_id, transcript)
|
||||
if detected_stage != current_stage:
|
||||
current_stage = detected_stage
|
||||
conversation.conversation_stage = current_stage.value
|
||||
await db.commit()
|
||||
|
||||
# 获取已聊话题
|
||||
stmt_segments = select(Segment).where(
|
||||
Segment.conversation_id == conversation_id
|
||||
).order_by(Segment.created_at)
|
||||
result_segments = await db.execute(stmt_segments)
|
||||
previous_segments = result_segments.scalars().all()
|
||||
covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category]
|
||||
|
||||
# 生成回应
|
||||
response = agent.generate_response(
|
||||
current_stage = await process_user_message(
|
||||
conversation_id=conversation_id,
|
||||
user_message=transcript,
|
||||
user_message=text_message,
|
||||
current_stage=current_stage,
|
||||
covered_topics=covered_topics
|
||||
conversation=conversation,
|
||||
segment=segment,
|
||||
db=db,
|
||||
manager=manager
|
||||
)
|
||||
|
||||
# 更新段落的 Agent 回应
|
||||
segment.agent_response = response
|
||||
await db.commit()
|
||||
|
||||
# 发送 Agent 回应
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": response},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
|
||||
# 调用 TTS 服务生成音频
|
||||
tts_audio = await tts_service.synthesize(response)
|
||||
|
||||
# 发送 TTS 音频
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.TTS_AUDIO,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"audio_base64": tts_audio},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
|
||||
elif msg_type == MessageType.END_CONVERSATION:
|
||||
# 结束对话
|
||||
conversation.status = "ended"
|
||||
@@ -214,11 +197,76 @@ async def websocket_endpoint(websocket: WebSocket, conversation_id: str):
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(conversation_id)
|
||||
except Exception:
|
||||
break
|
||||
except Exception as e:
|
||||
manager.disconnect(conversation_id)
|
||||
raise
|
||||
|
||||
|
||||
async def process_user_message(
|
||||
conversation_id: str,
|
||||
user_message: str,
|
||||
current_stage: ConversationStage,
|
||||
conversation: Conversation,
|
||||
segment: Segment,
|
||||
db: AsyncSession,
|
||||
manager: ConnectionManager
|
||||
) -> ConversationStage:
|
||||
"""
|
||||
处理用户消息,生成Agent回应
|
||||
|
||||
Args:
|
||||
conversation_id: 对话ID
|
||||
user_message: 用户消息文本
|
||||
current_stage: 当前对话阶段
|
||||
conversation: 对话对象
|
||||
segment: 段落对象
|
||||
db: 数据库会话
|
||||
manager: 连接管理器
|
||||
|
||||
Returns:
|
||||
更新后的对话阶段
|
||||
"""
|
||||
agent = manager.conversation_agents.get(conversation_id)
|
||||
if agent:
|
||||
# 检测对话阶段
|
||||
detected_stage = agent.detect_stage(conversation_id, user_message)
|
||||
if detected_stage != current_stage:
|
||||
current_stage = detected_stage
|
||||
conversation.conversation_stage = current_stage.value
|
||||
await db.commit()
|
||||
|
||||
# 获取已聊话题
|
||||
stmt_segments = select(Segment).where(
|
||||
Segment.conversation_id == conversation_id
|
||||
).order_by(Segment.created_at)
|
||||
result_segments = await db.execute(stmt_segments)
|
||||
previous_segments = result_segments.scalars().all()
|
||||
covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category]
|
||||
|
||||
# 生成回应
|
||||
response = agent.generate_response(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
current_stage=current_stage,
|
||||
covered_topics=covered_topics
|
||||
)
|
||||
|
||||
# 更新段落的 Agent 回应
|
||||
segment.agent_response = response
|
||||
await db.commit()
|
||||
|
||||
# 发送 Agent 回应(仅文字,不生成语音)
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
"conversation_id": conversation_id,
|
||||
"data": {"text": response},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
|
||||
return current_stage
|
||||
|
||||
|
||||
async def process_conversation_segments(conversation_id: str, db: AsyncSession):
|
||||
"""
|
||||
处理对话段落,生成章节
|
||||
@@ -249,15 +297,18 @@ async def process_conversation_segments(conversation_id: str, db: AsyncSession):
|
||||
chapters_data = memory_agent.process_segments(segments_data)
|
||||
|
||||
# 保存章节到数据库
|
||||
from database import Chapter
|
||||
from database.models import Chapter as ChapterModel
|
||||
conversation = await db.get(Conversation, conversation_id)
|
||||
|
||||
if not conversation:
|
||||
return
|
||||
|
||||
for category, chapter_data in chapters_data.items():
|
||||
chapter = Chapter(
|
||||
chapter = ChapterModel(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=conversation.user_id,
|
||||
title=chapter_data["title"],
|
||||
content=chapter_data["content"],
|
||||
title=chapter_data.get("title", f"章节-{category}"),
|
||||
content=chapter_data.get("content", ""),
|
||||
order_index=chapter_data.get("order_index", 999),
|
||||
status="completed",
|
||||
category=category,
|
||||
|
||||
Reference in New Issue
Block a user