diff --git a/api/routers/__init__.py b/api/routers/__init__.py new file mode 100644 index 0000000..f40c65b --- /dev/null +++ b/api/routers/__init__.py @@ -0,0 +1,4 @@ +""" +路由模块 +""" + diff --git a/api/routers/__pycache__/__init__.cpython-312.pyc b/api/routers/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000..b5a3faf Binary files /dev/null and b/api/routers/__pycache__/__init__.cpython-312.pyc differ diff --git a/api/routers/__pycache__/books.cpython-312.pyc b/api/routers/__pycache__/books.cpython-312.pyc new file mode 100644 index 0000000..0d1fd3f Binary files /dev/null and b/api/routers/__pycache__/books.cpython-312.pyc differ diff --git a/api/routers/__pycache__/chapters.cpython-312.pyc b/api/routers/__pycache__/chapters.cpython-312.pyc new file mode 100644 index 0000000..0facbd6 Binary files /dev/null and b/api/routers/__pycache__/chapters.cpython-312.pyc differ diff --git a/api/routers/__pycache__/conversations.cpython-312.pyc b/api/routers/__pycache__/conversations.cpython-312.pyc new file mode 100644 index 0000000..7776e6c Binary files /dev/null and b/api/routers/__pycache__/conversations.cpython-312.pyc differ diff --git a/api/routers/__pycache__/websocket.cpython-312.pyc b/api/routers/__pycache__/websocket.cpython-312.pyc new file mode 100644 index 0000000..072e954 Binary files /dev/null and b/api/routers/__pycache__/websocket.cpython-312.pyc differ diff --git a/api/routers/books.py b/api/routers/books.py new file mode 100644 index 0000000..3dbd144 --- /dev/null +++ b/api/routers/books.py @@ -0,0 +1,61 @@ +""" +回忆录相关 API 路由 +""" +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database import get_async_db +from database.models import Book as BookModel +from services.pdf_service import pdf_service + +router = APIRouter(prefix="/api/books", tags=["books"]) + + +@router.get("/current") +async def get_current_book( + user_id: str, + db: AsyncSession = Depends(get_async_db) +): + """获取当前回忆录""" + stmt = select(BookModel).where(BookModel.user_id == user_id).order_by(BookModel.updated_at.desc()) + result = await db.execute(stmt) + book = result.scalar_one_or_none() + + if not book: + return {"message": "No book found"} + + return { + "id": book.id, + "title": book.title, + "total_pages": book.total_pages, + "total_words": book.total_words, + "cover_image_url": book.cover_image_url + } + + +@router.post("/export-pdf") +async def export_pdf( + book_id: str, + user_id: str, + db: AsyncSession = Depends(get_async_db) +): + """导出 PDF""" + book = await db.get(BookModel, book_id) + if not book or book.user_id != user_id: + raise HTTPException(status_code=404, detail="Book not found") + + # 获取所有章节 + from ..database.models import Chapter + stmt = select(Chapter).where(Chapter.user_id == user_id).order_by(Chapter.order_index) + result = await db.execute(stmt) + chapters = result.scalars().all() + + # 生成 PDF + pdf_bytes = await pdf_service.generate_pdf(book, chapters) + + return { + "pdf_base64": pdf_bytes.decode('latin1'), # 简化处理,实际应该用 base64 + "filename": f"{book.title}.pdf" + } + diff --git a/api/routers/chapters.py b/api/routers/chapters.py new file mode 100644 index 0000000..742c051 --- /dev/null +++ b/api/routers/chapters.py @@ -0,0 +1,69 @@ +""" +章节相关 API 路由 +""" +from typing import List + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database import get_async_db +from database.models import Chapter as ChapterModel + +router = APIRouter(prefix="/api/chapters", tags=["chapters"]) + + +@router.get("", response_model=List[dict]) +async def get_chapters( + user_id: str, + db: AsyncSession = Depends(get_async_db) +): + """获取用户所有章节""" + stmt = select(ChapterModel).where(ChapterModel.user_id == user_id).order_by(ChapterModel.order_index) + result = await db.execute(stmt) + chapters = result.scalars().all() + + return [ + { + "id": ch.id, + "title": ch.title, + "content": ch.content, + "order_index": ch.order_index, + "status": ch.status, + "category": ch.category, + "images": ch.images or [] + } + for ch in chapters + ] + + +@router.get("/{chapter_id}", response_model=dict) +async def get_chapter( + chapter_id: str, + db: AsyncSession = Depends(get_async_db) +): + """获取章节详情""" + chapter = await db.get(ChapterModel, chapter_id) + if not chapter: + raise HTTPException(status_code=404, detail="Chapter not found") + + return { + "id": chapter.id, + "title": chapter.title, + "content": chapter.content, + "order_index": chapter.order_index, + "status": chapter.status, + "category": chapter.category, + "images": chapter.images or [] + } + + +@router.post("/{chapter_id}/regenerate") +async def regenerate_chapter( + chapter_id: str, + db: AsyncSession = Depends(get_async_db) +): + """重新整理章节""" + # TODO: 实现重新整理逻辑 + return {"status": "ok", "message": "Chapter regeneration triggered"} + diff --git a/api/routers/conversations.py b/api/routers/conversations.py new file mode 100644 index 0000000..e4f0d0c --- /dev/null +++ b/api/routers/conversations.py @@ -0,0 +1,89 @@ +""" +对话相关 API 路由 +""" +from datetime import datetime, timezone +from typing import List, Optional +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select +import uuid + +from database import get_async_db, Conversation, Segment, User +from database.models import Conversation as ConversationModel, Segment as SegmentModel + +router = APIRouter(prefix="/api/conversations", tags=["conversations"]) + + +@router.post("") +async def create_conversation( + user_id: str, + db: AsyncSession = Depends(get_async_db) +): + """创建新对话""" + conversation = ConversationModel( + id=str(uuid.uuid4()), + user_id=user_id, + started_at=datetime.now(timezone.utc), + status="active" + ) + db.add(conversation) + await db.commit() + await db.refresh(conversation) + + return { + "id": conversation.id, + "user_id": conversation.user_id, + "started_at": conversation.started_at.isoformat(), + "status": conversation.status + } + + +@router.get("/{conversation_id}") +async def get_conversation( + conversation_id: str, + db: AsyncSession = Depends(get_async_db) +): + """获取对话详情""" + conversation = await db.get(ConversationModel, conversation_id) + if not conversation: + raise HTTPException(status_code=404, detail="Conversation not found") + + return { + "id": conversation.id, + "user_id": conversation.user_id, + "started_at": conversation.started_at.isoformat(), + "ended_at": conversation.ended_at.isoformat() if conversation.ended_at else None, + "duration_seconds": conversation.duration_seconds, + "summary": conversation.summary, + "status": conversation.status, + "current_topic": conversation.current_topic, + "conversation_stage": conversation.conversation_stage + } + + +@router.post("/{conversation_id}/end") +async def end_conversation( + conversation_id: str, + db: AsyncSession = Depends(get_async_db) +): + """结束对话""" + conversation = await db.get(ConversationModel, conversation_id) + if not conversation: + raise HTTPException(status_code=404, detail="Conversation not found") + + conversation.status = "ended" + conversation.ended_at = datetime.now(timezone.utc) + + if conversation.started_at: + duration = (conversation.ended_at - conversation.started_at).total_seconds() + conversation.duration_seconds = int(duration) + + await db.commit() + + return { + "id": conversation.id, + "status": conversation.status, + "ended_at": conversation.ended_at.isoformat(), + "duration_seconds": conversation.duration_seconds + } + diff --git a/api/routers/websocket.py b/api/routers/websocket.py new file mode 100644 index 0000000..897790e --- /dev/null +++ b/api/routers/websocket.py @@ -0,0 +1,273 @@ +""" +WebSocket 路由:实时对话通信 +""" +import uuid +from datetime import datetime, timezone +from enum import Enum +from typing import Dict + +from fastapi import WebSocket, WebSocketDisconnect, HTTPException +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from agents import ConversationAgent, MemoryAgent +from agents.prompts import ConversationStage +from database import get_async_db +from database.models import Conversation, Segment +from services.asr_service import asr_service +from services.tts_service import tts_service + + +class MessageType(str, Enum): + """WebSocket 消息类型""" + CONNECT = "connect" + AUDIO_CHUNK = "audio_chunk" + TRANSCRIPT = "transcript" + AGENT_RESPONSE = "agent_response" + TTS_AUDIO = "tts_audio" + END_CONVERSATION = "end_conversation" + ERROR = "error" + + +# 连接管理 +class ConnectionManager: + """WebSocket 连接管理器""" + + def __init__(self): + self.active_connections: Dict[str, WebSocket] = {} + self.conversation_agents: Dict[str, ConversationAgent] = {} + self.memory_agent = MemoryAgent() + + async def connect(self, websocket: WebSocket, conversation_id: str): + """建立连接""" + await websocket.accept() + self.active_connections[conversation_id] = websocket + self.conversation_agents[conversation_id] = ConversationAgent() + + def disconnect(self, conversation_id: str): + """断开连接""" + if conversation_id in self.active_connections: + del self.active_connections[conversation_id] + if conversation_id in self.conversation_agents: + self.conversation_agents[conversation_id].clear_memory(conversation_id) + del self.conversation_agents[conversation_id] + + async def send_message(self, conversation_id: str, message: dict): + """发送消息""" + if conversation_id in self.active_connections: + websocket = self.active_connections[conversation_id] + await websocket.send_json(message) + + async def receive_message(self, conversation_id: str) -> dict: + """接收消息""" + if conversation_id in self.active_connections: + websocket = self.active_connections[conversation_id] + return await websocket.receive_json() + raise HTTPException(status_code=404, detail="Connection not found") + + +manager = ConnectionManager() + + +async def websocket_endpoint(websocket: WebSocket, conversation_id: str): + """ + WebSocket 端点:处理实时对话 + + Args: + websocket: WebSocket 连接 + conversation_id: 对话 ID + """ + await manager.connect(websocket, conversation_id) + + try: + # 发送连接确认 + await manager.send_message(conversation_id, { + "type": MessageType.CONNECT, + "conversation_id": conversation_id, + "data": {"status": "connected"}, + "timestamp": datetime.now(timezone.utc).isoformat() + }) + + # 从数据库获取对话信息 + async for db in get_async_db(): + conversation = await db.get(Conversation, conversation_id) + if not conversation: + # 如果对话不存在,创建新对话 + from database.models import User as UserModel + # 假设用户 ID 从连接参数获取(实际应该从认证获取) + user_id = "default_user" # TODO: 从认证获取实际用户 ID + + conversation = Conversation( + id=conversation_id, + user_id=user_id, + started_at=datetime.now(timezone.utc), + status="active" + ) + db.add(conversation) + await db.commit() + + current_stage = ConversationStage(conversation.conversation_stage) if conversation.conversation_stage else ConversationStage.CHILDHOOD + + # 主循环:处理消息 + while True: + try: + message = await websocket.receive_json() + msg_type = message.get("type") + + if msg_type == MessageType.AUDIO_CHUNK: + # 处理音频块 + audio_data = message.get("data", {}).get("audio_base64", "") + + # 调用 ASR 服务转文字 + transcript = await asr_service.transcribe(audio_data) + + # 保存段落到数据库 + segment = Segment( + id=str(uuid.uuid4()), + conversation_id=conversation_id, + transcript_text=transcript, + processed=False + ) + db.add(segment) + await db.commit() + + # 发送转写结果 + await manager.send_message(conversation_id, { + "type": MessageType.TRANSCRIPT, + "conversation_id": conversation_id, + "data": {"text": transcript}, + "timestamp": datetime.now(timezone.utc).isoformat() + }) + + # Agent 生成回应 + agent = manager.conversation_agents.get(conversation_id) + if agent: + # 检测对话阶段 + detected_stage = agent.detect_stage(conversation_id, transcript) + if detected_stage != current_stage: + current_stage = detected_stage + conversation.conversation_stage = current_stage.value + await db.commit() + + # 获取已聊话题 + stmt_segments = select(Segment).where( + Segment.conversation_id == conversation_id + ).order_by(Segment.created_at) + result_segments = await db.execute(stmt_segments) + previous_segments = result_segments.scalars().all() + covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category] + + # 生成回应 + response = agent.generate_response( + conversation_id=conversation_id, + user_message=transcript, + current_stage=current_stage, + covered_topics=covered_topics + ) + + # 更新段落的 Agent 回应 + segment.agent_response = response + await db.commit() + + # 发送 Agent 回应 + await manager.send_message(conversation_id, { + "type": MessageType.AGENT_RESPONSE, + "conversation_id": conversation_id, + "data": {"text": response}, + "timestamp": datetime.now(timezone.utc).isoformat() + }) + + # 调用 TTS 服务生成音频 + tts_audio = await tts_service.synthesize(response) + + # 发送 TTS 音频 + await manager.send_message(conversation_id, { + "type": MessageType.TTS_AUDIO, + "conversation_id": conversation_id, + "data": {"audio_base64": tts_audio}, + "timestamp": datetime.now(timezone.utc).isoformat() + }) + + elif msg_type == MessageType.END_CONVERSATION: + # 结束对话 + conversation.status = "ended" + conversation.ended_at = datetime.now(timezone.utc) + await db.commit() + + # 触发整理 Agent + await process_conversation_segments(conversation_id, db) + + await manager.send_message(conversation_id, { + "type": MessageType.END_CONVERSATION, + "conversation_id": conversation_id, + "data": {"status": "ended"}, + "timestamp": datetime.now(timezone.utc).isoformat() + }) + break + + except Exception as e: + await manager.send_message(conversation_id, { + "type": MessageType.ERROR, + "data": {"message": str(e)}, + "timestamp": datetime.now(timezone.utc).isoformat() + }) + + except WebSocketDisconnect: + manager.disconnect(conversation_id) + except Exception: + manager.disconnect(conversation_id) + raise + + +async def process_conversation_segments(conversation_id: str, db: AsyncSession): + """ + 处理对话段落,生成章节 + + Args: + conversation_id: 对话 ID + db: 数据库会话 + """ + # 获取所有未处理的段落 + stmt = select(Segment).where( + Segment.conversation_id == conversation_id, + Segment.processed == False + ) + result = await db.execute(stmt) + segments = result.scalars().all() + + if not segments: + return + + # 准备段落数据 + segments_data = [ + {"transcript_text": seg.transcript_text} + for seg in segments + ] + + # 调用整理 Agent + memory_agent = manager.memory_agent + chapters_data = memory_agent.process_segments(segments_data) + + # 保存章节到数据库 + from database import Chapter + conversation = await db.get(Conversation, conversation_id) + + for category, chapter_data in chapters_data.items(): + chapter = Chapter( + id=str(uuid.uuid4()), + user_id=conversation.user_id, + title=chapter_data["title"], + content=chapter_data["content"], + order_index=chapter_data.get("order_index", 999), + status="completed", + category=category, + images=chapter_data.get("image_suggestions", []) + ) + db.add(chapter) + + # 标记段落为已处理 + for seg in segments: + seg.processed = True + + await db.commit() +