diff --git a/api/routers/books.py b/api/routers/books.py index 3dbd144..875ee4c 100644 --- a/api/routers/books.py +++ b/api/routers/books.py @@ -1,24 +1,27 @@ """ 回忆录相关 API 路由 """ -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Query, Body +from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from database import get_async_db from database.models import Book as BookModel +from database.models import User as UserModel from services.pdf_service import pdf_service +from middleware.auth import get_current_user router = APIRouter(prefix="/api/books", tags=["books"]) @router.get("/current") async def get_current_book( - user_id: str, + current_user: UserModel = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): - """获取当前回忆录""" - stmt = select(BookModel).where(BookModel.user_id == user_id).order_by(BookModel.updated_at.desc()) + """获取当前回忆录(需要认证)""" + stmt = select(BookModel).where(BookModel.user_id == current_user.id).order_by(BookModel.updated_at.desc()) result = await db.execute(stmt) book = result.scalar_one_or_none() @@ -34,20 +37,28 @@ async def get_current_book( } +class ExportPdfRequest(BaseModel): + book_id: str + + @router.post("/export-pdf") async def export_pdf( - book_id: str, - user_id: str, + request: ExportPdfRequest = Body(...), + current_user: UserModel = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): - """导出 PDF""" - book = await db.get(BookModel, book_id) - if not book or book.user_id != user_id: + """导出 PDF(需要认证,只能导出自己的回忆录)""" + book = await db.get(BookModel, request.book_id) + if not book: raise HTTPException(status_code=404, detail="Book not found") + # 验证用户权限 + if book.user_id != current_user.id: + raise HTTPException(status_code=403, detail="无权导出此回忆录") + # 获取所有章节 - from ..database.models import Chapter - stmt = select(Chapter).where(Chapter.user_id == user_id).order_by(Chapter.order_index) + from database.models import Chapter + stmt = select(Chapter).where(Chapter.user_id == current_user.id).order_by(Chapter.order_index) result = await db.execute(stmt) chapters = result.scalars().all() diff --git a/api/routers/chapters.py b/api/routers/chapters.py index 742c051..59d04e1 100644 --- a/api/routers/chapters.py +++ b/api/routers/chapters.py @@ -1,25 +1,27 @@ """ 章节相关 API 路由 """ -from typing import List +from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Query from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from database import get_async_db from database.models import Chapter as ChapterModel +from database.models import User as UserModel +from middleware.auth import get_current_user router = APIRouter(prefix="/api/chapters", tags=["chapters"]) @router.get("", response_model=List[dict]) async def get_chapters( - user_id: str, + current_user: UserModel = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): - """获取用户所有章节""" - stmt = select(ChapterModel).where(ChapterModel.user_id == user_id).order_by(ChapterModel.order_index) + """获取用户所有章节(需要认证)""" + stmt = select(ChapterModel).where(ChapterModel.user_id == current_user.id).order_by(ChapterModel.order_index) result = await db.execute(stmt) chapters = result.scalars().all() @@ -40,13 +42,18 @@ async def get_chapters( @router.get("/{chapter_id}", response_model=dict) async def get_chapter( chapter_id: str, + current_user: UserModel = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): - """获取章节详情""" + """获取章节详情(需要认证,只能访问自己的章节)""" chapter = await db.get(ChapterModel, chapter_id) if not chapter: raise HTTPException(status_code=404, detail="Chapter not found") + # 验证用户权限 + if chapter.user_id != current_user.id: + raise HTTPException(status_code=403, detail="无权访问此章节") + return { "id": chapter.id, "title": chapter.title, @@ -61,9 +68,18 @@ async def get_chapter( @router.post("/{chapter_id}/regenerate") async def regenerate_chapter( chapter_id: str, + current_user: UserModel = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): - """重新整理章节""" + """重新整理章节(需要认证,只能操作自己的章节)""" + chapter = await db.get(ChapterModel, chapter_id) + if not chapter: + raise HTTPException(status_code=404, detail="Chapter not found") + + # 验证用户权限 + if chapter.user_id != current_user.id: + raise HTTPException(status_code=403, detail="无权操作此章节") + # TODO: 实现重新整理逻辑 return {"status": "ok", "message": "Chapter regeneration triggered"} diff --git a/api/routers/conversations.py b/api/routers/conversations.py index e4f0d0c..c333a9d 100644 --- a/api/routers/conversations.py +++ b/api/routers/conversations.py @@ -3,26 +3,29 @@ """ from datetime import datetime, timezone from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Query, Body +from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select import uuid from database import get_async_db, Conversation, Segment, User from database.models import Conversation as ConversationModel, Segment as SegmentModel +from middleware.auth import get_current_user +from database.models import User as UserModel router = APIRouter(prefix="/api/conversations", tags=["conversations"]) @router.post("") async def create_conversation( - user_id: str, + current_user: UserModel = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): - """创建新对话""" + """创建新对话(需要认证)""" conversation = ConversationModel( id=str(uuid.uuid4()), - user_id=user_id, + user_id=current_user.id, started_at=datetime.now(timezone.utc), status="active" ) @@ -41,13 +44,18 @@ async def create_conversation( @router.get("/{conversation_id}") async def get_conversation( conversation_id: str, + current_user: UserModel = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): - """获取对话详情""" + """获取对话详情(需要认证,只能访问自己的对话)""" conversation = await db.get(ConversationModel, conversation_id) if not conversation: raise HTTPException(status_code=404, detail="Conversation not found") + # 验证用户权限 + if conversation.user_id != current_user.id: + raise HTTPException(status_code=403, detail="无权访问此对话") + return { "id": conversation.id, "user_id": conversation.user_id, @@ -64,13 +72,18 @@ async def get_conversation( @router.post("/{conversation_id}/end") async def end_conversation( conversation_id: str, + current_user: UserModel = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): - """结束对话""" + """结束对话(需要认证,只能结束自己的对话)""" conversation = await db.get(ConversationModel, conversation_id) if not conversation: raise HTTPException(status_code=404, detail="Conversation not found") + # 验证用户权限 + if conversation.user_id != current_user.id: + raise HTTPException(status_code=403, detail="无权操作此对话") + conversation.status = "ended" conversation.ended_at = datetime.now(timezone.utc) diff --git a/api/routers/websocket.py b/api/routers/websocket.py index 897790e..9b06c92 100644 --- a/api/routers/websocket.py +++ b/api/routers/websocket.py @@ -6,7 +6,7 @@ from datetime import datetime, timezone from enum import Enum from typing import Dict -from fastapi import WebSocket, WebSocketDisconnect, HTTPException +from fastapi import WebSocket, WebSocketDisconnect, HTTPException, Query from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -14,14 +14,16 @@ from agents import ConversationAgent, MemoryAgent from agents.prompts import ConversationStage from database import get_async_db from database.models import Conversation, Segment -from services.asr_service import asr_service -from services.tts_service import tts_service +from database.models import User as UserModel +from services.auth_service import verify_token +from fastapi import HTTPException, status class MessageType(str, Enum): """WebSocket 消息类型""" CONNECT = "connect" AUDIO_CHUNK = "audio_chunk" + TEXT = "text" # 文本消息 TRANSCRIPT = "transcript" AGENT_RESPONSE = "agent_response" TTS_AUDIO = "tts_audio" @@ -69,34 +71,52 @@ class ConnectionManager: manager = ConnectionManager() -async def websocket_endpoint(websocket: WebSocket, conversation_id: str): +async def websocket_endpoint( + websocket: WebSocket, + conversation_id: str, + token: str = Query(..., description="访问令牌") +): """ WebSocket 端点:处理实时对话 Args: websocket: WebSocket 连接 conversation_id: 对话 ID + token: 访问令牌(从查询参数获取) """ - await manager.connect(websocket, conversation_id) + # 验证JWT令牌 + payload = verify_token(token) + if not payload or payload.get("type") != "access": + await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌") + return - try: - # 发送连接确认 - await manager.send_message(conversation_id, { - "type": MessageType.CONNECT, - "conversation_id": conversation_id, - "data": {"status": "connected"}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) + user_id = payload.get("sub") + if not user_id: + await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容") + return + + # 验证用户是否存在 + async for db in get_async_db(): + user = await db.get(UserModel, user_id) + if not user: + await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在") + return - # 从数据库获取对话信息 - async for db in get_async_db(): + await manager.connect(websocket, conversation_id) + + try: + # 发送连接确认 + await manager.send_message(conversation_id, { + "type": MessageType.CONNECT, + "conversation_id": conversation_id, + "data": {"status": "connected"}, + "timestamp": datetime.now(timezone.utc).isoformat() + }) + + # 从数据库获取对话信息 conversation = await db.get(Conversation, conversation_id) if not conversation: # 如果对话不存在,创建新对话 - from database.models import User as UserModel - # 假设用户 ID 从连接参数获取(实际应该从认证获取) - user_id = "default_user" # TODO: 从认证获取实际用户 ID - conversation = Conversation( id=conversation_id, user_id=user_id, @@ -105,6 +125,17 @@ async def websocket_endpoint(websocket: WebSocket, conversation_id: str): ) db.add(conversation) await db.commit() + else: + # 验证用户权限:只能访问自己的对话 + if conversation.user_id != user_id: + await manager.send_message(conversation_id, { + "type": MessageType.ERROR, + "data": {"message": "无权访问此对话"}, + "timestamp": datetime.now(timezone.utc).isoformat() + }) + await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话") + return + current_stage = ConversationStage(conversation.conversation_stage) if conversation.conversation_stage else ConversationStage.CHILDHOOD @@ -114,79 +145,31 @@ async def websocket_endpoint(websocket: WebSocket, conversation_id: str): message = await websocket.receive_json() msg_type = message.get("type") - if msg_type == MessageType.AUDIO_CHUNK: - # 处理音频块 - audio_data = message.get("data", {}).get("audio_base64", "") + if msg_type == MessageType.TEXT: + # 处理文本消息 + text_message = message.get("data", {}).get("text", "") - # 调用 ASR 服务转文字 - transcript = await asr_service.transcribe(audio_data) - - # 保存段落到数据库 - segment = Segment( - id=str(uuid.uuid4()), - conversation_id=conversation_id, - transcript_text=transcript, - processed=False - ) - db.add(segment) - await db.commit() - - # 发送转写结果 - await manager.send_message(conversation_id, { - "type": MessageType.TRANSCRIPT, - "conversation_id": conversation_id, - "data": {"text": transcript}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - - # Agent 生成回应 - agent = manager.conversation_agents.get(conversation_id) - if agent: - # 检测对话阶段 - detected_stage = agent.detect_stage(conversation_id, transcript) - if detected_stage != current_stage: - current_stage = detected_stage - conversation.conversation_stage = current_stage.value - await db.commit() - - # 获取已聊话题 - stmt_segments = select(Segment).where( - Segment.conversation_id == conversation_id - ).order_by(Segment.created_at) - result_segments = await db.execute(stmt_segments) - previous_segments = result_segments.scalars().all() - covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category] - - # 生成回应 - response = agent.generate_response( + if text_message: + # 保存段落到数据库 + segment = Segment( + id=str(uuid.uuid4()), conversation_id=conversation_id, - user_message=transcript, - current_stage=current_stage, - covered_topics=covered_topics + transcript_text=text_message, + processed=False ) - - # 更新段落的 Agent 回应 - segment.agent_response = response + db.add(segment) await db.commit() - # 发送 Agent 回应 - await manager.send_message(conversation_id, { - "type": MessageType.AGENT_RESPONSE, - "conversation_id": conversation_id, - "data": {"text": response}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) - - # 调用 TTS 服务生成音频 - tts_audio = await tts_service.synthesize(response) - - # 发送 TTS 音频 - await manager.send_message(conversation_id, { - "type": MessageType.TTS_AUDIO, - "conversation_id": conversation_id, - "data": {"audio_base64": tts_audio}, - "timestamp": datetime.now(timezone.utc).isoformat() - }) + # Agent 生成回应 + current_stage = await process_user_message( + conversation_id=conversation_id, + user_message=text_message, + current_stage=current_stage, + conversation=conversation, + segment=segment, + db=db, + manager=manager + ) elif msg_type == MessageType.END_CONVERSATION: # 结束对话 @@ -212,11 +195,76 @@ async def websocket_endpoint(websocket: WebSocket, conversation_id: str): "timestamp": datetime.now(timezone.utc).isoformat() }) - except WebSocketDisconnect: - manager.disconnect(conversation_id) - except Exception: - manager.disconnect(conversation_id) - raise + except WebSocketDisconnect: + manager.disconnect(conversation_id) + break + except Exception as e: + manager.disconnect(conversation_id) + raise + + +async def process_user_message( + conversation_id: str, + user_message: str, + current_stage: ConversationStage, + conversation: Conversation, + segment: Segment, + db: AsyncSession, + manager: ConnectionManager +) -> ConversationStage: + """ + 处理用户消息,生成Agent回应 + + Args: + conversation_id: 对话ID + user_message: 用户消息文本 + current_stage: 当前对话阶段 + conversation: 对话对象 + segment: 段落对象 + db: 数据库会话 + manager: 连接管理器 + + Returns: + 更新后的对话阶段 + """ + agent = manager.conversation_agents.get(conversation_id) + if agent: + # 检测对话阶段 + detected_stage = agent.detect_stage(conversation_id, user_message) + if detected_stage != current_stage: + current_stage = detected_stage + conversation.conversation_stage = current_stage.value + await db.commit() + + # 获取已聊话题 + stmt_segments = select(Segment).where( + Segment.conversation_id == conversation_id + ).order_by(Segment.created_at) + result_segments = await db.execute(stmt_segments) + previous_segments = result_segments.scalars().all() + covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category] + + # 生成回应 + response = agent.generate_response( + conversation_id=conversation_id, + user_message=user_message, + current_stage=current_stage, + covered_topics=covered_topics + ) + + # 更新段落的 Agent 回应 + segment.agent_response = response + await db.commit() + + # 发送 Agent 回应(仅文字,不生成语音) + await manager.send_message(conversation_id, { + "type": MessageType.AGENT_RESPONSE, + "conversation_id": conversation_id, + "data": {"text": response}, + "timestamp": datetime.now(timezone.utc).isoformat() + }) + + return current_stage async def process_conversation_segments(conversation_id: str, db: AsyncSession): @@ -249,15 +297,18 @@ async def process_conversation_segments(conversation_id: str, db: AsyncSession): chapters_data = memory_agent.process_segments(segments_data) # 保存章节到数据库 - from database import Chapter + from database.models import Chapter as ChapterModel conversation = await db.get(Conversation, conversation_id) + if not conversation: + return + for category, chapter_data in chapters_data.items(): - chapter = Chapter( + chapter = ChapterModel( id=str(uuid.uuid4()), user_id=conversation.user_id, - title=chapter_data["title"], - content=chapter_data["content"], + title=chapter_data.get("title", f"章节-{category}"), + content=chapter_data.get("content", ""), order_index=chapter_data.get("order_index", 999), status="completed", category=category,