添加API路由模块

This commit is contained in:
iammm0
2026-01-07 11:56:40 +08:00
parent 56dffc300b
commit 9ca3a3a89a
10 changed files with 496 additions and 0 deletions

4
api/routers/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
"""
路由模块
"""

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

61
api/routers/books.py Normal file
View File

@@ -0,0 +1,61 @@
"""
回忆录相关 API 路由
"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_async_db
from database.models import Book as BookModel
from services.pdf_service import pdf_service
router = APIRouter(prefix="/api/books", tags=["books"])
@router.get("/current")
async def get_current_book(
user_id: str,
db: AsyncSession = Depends(get_async_db)
):
"""获取当前回忆录"""
stmt = select(BookModel).where(BookModel.user_id == user_id).order_by(BookModel.updated_at.desc())
result = await db.execute(stmt)
book = result.scalar_one_or_none()
if not book:
return {"message": "No book found"}
return {
"id": book.id,
"title": book.title,
"total_pages": book.total_pages,
"total_words": book.total_words,
"cover_image_url": book.cover_image_url
}
@router.post("/export-pdf")
async def export_pdf(
book_id: str,
user_id: str,
db: AsyncSession = Depends(get_async_db)
):
"""导出 PDF"""
book = await db.get(BookModel, book_id)
if not book or book.user_id != user_id:
raise HTTPException(status_code=404, detail="Book not found")
# 获取所有章节
from ..database.models import Chapter
stmt = select(Chapter).where(Chapter.user_id == user_id).order_by(Chapter.order_index)
result = await db.execute(stmt)
chapters = result.scalars().all()
# 生成 PDF
pdf_bytes = await pdf_service.generate_pdf(book, chapters)
return {
"pdf_base64": pdf_bytes.decode('latin1'), # 简化处理,实际应该用 base64
"filename": f"{book.title}.pdf"
}

69
api/routers/chapters.py Normal file
View File

@@ -0,0 +1,69 @@
"""
章节相关 API 路由
"""
from typing import List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database import get_async_db
from database.models import Chapter as ChapterModel
router = APIRouter(prefix="/api/chapters", tags=["chapters"])
@router.get("", response_model=List[dict])
async def get_chapters(
user_id: str,
db: AsyncSession = Depends(get_async_db)
):
"""获取用户所有章节"""
stmt = select(ChapterModel).where(ChapterModel.user_id == user_id).order_by(ChapterModel.order_index)
result = await db.execute(stmt)
chapters = result.scalars().all()
return [
{
"id": ch.id,
"title": ch.title,
"content": ch.content,
"order_index": ch.order_index,
"status": ch.status,
"category": ch.category,
"images": ch.images or []
}
for ch in chapters
]
@router.get("/{chapter_id}", response_model=dict)
async def get_chapter(
chapter_id: str,
db: AsyncSession = Depends(get_async_db)
):
"""获取章节详情"""
chapter = await db.get(ChapterModel, chapter_id)
if not chapter:
raise HTTPException(status_code=404, detail="Chapter not found")
return {
"id": chapter.id,
"title": chapter.title,
"content": chapter.content,
"order_index": chapter.order_index,
"status": chapter.status,
"category": chapter.category,
"images": chapter.images or []
}
@router.post("/{chapter_id}/regenerate")
async def regenerate_chapter(
chapter_id: str,
db: AsyncSession = Depends(get_async_db)
):
"""重新整理章节"""
# TODO: 实现重新整理逻辑
return {"status": "ok", "message": "Chapter regeneration triggered"}

View File

@@ -0,0 +1,89 @@
"""
对话相关 API 路由
"""
from datetime import datetime, timezone
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
import uuid
from database import get_async_db, Conversation, Segment, User
from database.models import Conversation as ConversationModel, Segment as SegmentModel
router = APIRouter(prefix="/api/conversations", tags=["conversations"])
@router.post("")
async def create_conversation(
user_id: str,
db: AsyncSession = Depends(get_async_db)
):
"""创建新对话"""
conversation = ConversationModel(
id=str(uuid.uuid4()),
user_id=user_id,
started_at=datetime.now(timezone.utc),
status="active"
)
db.add(conversation)
await db.commit()
await db.refresh(conversation)
return {
"id": conversation.id,
"user_id": conversation.user_id,
"started_at": conversation.started_at.isoformat(),
"status": conversation.status
}
@router.get("/{conversation_id}")
async def get_conversation(
conversation_id: str,
db: AsyncSession = Depends(get_async_db)
):
"""获取对话详情"""
conversation = await db.get(ConversationModel, conversation_id)
if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found")
return {
"id": conversation.id,
"user_id": conversation.user_id,
"started_at": conversation.started_at.isoformat(),
"ended_at": conversation.ended_at.isoformat() if conversation.ended_at else None,
"duration_seconds": conversation.duration_seconds,
"summary": conversation.summary,
"status": conversation.status,
"current_topic": conversation.current_topic,
"conversation_stage": conversation.conversation_stage
}
@router.post("/{conversation_id}/end")
async def end_conversation(
conversation_id: str,
db: AsyncSession = Depends(get_async_db)
):
"""结束对话"""
conversation = await db.get(ConversationModel, conversation_id)
if not conversation:
raise HTTPException(status_code=404, detail="Conversation not found")
conversation.status = "ended"
conversation.ended_at = datetime.now(timezone.utc)
if conversation.started_at:
duration = (conversation.ended_at - conversation.started_at).total_seconds()
conversation.duration_seconds = int(duration)
await db.commit()
return {
"id": conversation.id,
"status": conversation.status,
"ended_at": conversation.ended_at.isoformat(),
"duration_seconds": conversation.duration_seconds
}

273
api/routers/websocket.py Normal file
View File

@@ -0,0 +1,273 @@
"""
WebSocket 路由:实时对话通信
"""
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Dict
from fastapi import WebSocket, WebSocketDisconnect, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from agents import ConversationAgent, MemoryAgent
from agents.prompts import ConversationStage
from database import get_async_db
from database.models import Conversation, Segment
from services.asr_service import asr_service
from services.tts_service import tts_service
class MessageType(str, Enum):
"""WebSocket 消息类型"""
CONNECT = "connect"
AUDIO_CHUNK = "audio_chunk"
TRANSCRIPT = "transcript"
AGENT_RESPONSE = "agent_response"
TTS_AUDIO = "tts_audio"
END_CONVERSATION = "end_conversation"
ERROR = "error"
# 连接管理
class ConnectionManager:
"""WebSocket 连接管理器"""
def __init__(self):
self.active_connections: Dict[str, WebSocket] = {}
self.conversation_agents: Dict[str, ConversationAgent] = {}
self.memory_agent = MemoryAgent()
async def connect(self, websocket: WebSocket, conversation_id: str):
"""建立连接"""
await websocket.accept()
self.active_connections[conversation_id] = websocket
self.conversation_agents[conversation_id] = ConversationAgent()
def disconnect(self, conversation_id: str):
"""断开连接"""
if conversation_id in self.active_connections:
del self.active_connections[conversation_id]
if conversation_id in self.conversation_agents:
self.conversation_agents[conversation_id].clear_memory(conversation_id)
del self.conversation_agents[conversation_id]
async def send_message(self, conversation_id: str, message: dict):
"""发送消息"""
if conversation_id in self.active_connections:
websocket = self.active_connections[conversation_id]
await websocket.send_json(message)
async def receive_message(self, conversation_id: str) -> dict:
"""接收消息"""
if conversation_id in self.active_connections:
websocket = self.active_connections[conversation_id]
return await websocket.receive_json()
raise HTTPException(status_code=404, detail="Connection not found")
manager = ConnectionManager()
async def websocket_endpoint(websocket: WebSocket, conversation_id: str):
"""
WebSocket 端点:处理实时对话
Args:
websocket: WebSocket 连接
conversation_id: 对话 ID
"""
await manager.connect(websocket, conversation_id)
try:
# 发送连接确认
await manager.send_message(conversation_id, {
"type": MessageType.CONNECT,
"conversation_id": conversation_id,
"data": {"status": "connected"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
# 从数据库获取对话信息
async for db in get_async_db():
conversation = await db.get(Conversation, conversation_id)
if not conversation:
# 如果对话不存在,创建新对话
from database.models import User as UserModel
# 假设用户 ID 从连接参数获取(实际应该从认证获取)
user_id = "default_user" # TODO: 从认证获取实际用户 ID
conversation = Conversation(
id=conversation_id,
user_id=user_id,
started_at=datetime.now(timezone.utc),
status="active"
)
db.add(conversation)
await db.commit()
current_stage = ConversationStage(conversation.conversation_stage) if conversation.conversation_stage else ConversationStage.CHILDHOOD
# 主循环:处理消息
while True:
try:
message = await websocket.receive_json()
msg_type = message.get("type")
if msg_type == MessageType.AUDIO_CHUNK:
# 处理音频块
audio_data = message.get("data", {}).get("audio_base64", "")
# 调用 ASR 服务转文字
transcript = await asr_service.transcribe(audio_data)
# 保存段落到数据库
segment = Segment(
id=str(uuid.uuid4()),
conversation_id=conversation_id,
transcript_text=transcript,
processed=False
)
db.add(segment)
await db.commit()
# 发送转写结果
await manager.send_message(conversation_id, {
"type": MessageType.TRANSCRIPT,
"conversation_id": conversation_id,
"data": {"text": transcript},
"timestamp": datetime.now(timezone.utc).isoformat()
})
# Agent 生成回应
agent = manager.conversation_agents.get(conversation_id)
if agent:
# 检测对话阶段
detected_stage = agent.detect_stage(conversation_id, transcript)
if detected_stage != current_stage:
current_stage = detected_stage
conversation.conversation_stage = current_stage.value
await db.commit()
# 获取已聊话题
stmt_segments = select(Segment).where(
Segment.conversation_id == conversation_id
).order_by(Segment.created_at)
result_segments = await db.execute(stmt_segments)
previous_segments = result_segments.scalars().all()
covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category]
# 生成回应
response = agent.generate_response(
conversation_id=conversation_id,
user_message=transcript,
current_stage=current_stage,
covered_topics=covered_topics
)
# 更新段落的 Agent 回应
segment.agent_response = response
await db.commit()
# 发送 Agent 回应
await manager.send_message(conversation_id, {
"type": MessageType.AGENT_RESPONSE,
"conversation_id": conversation_id,
"data": {"text": response},
"timestamp": datetime.now(timezone.utc).isoformat()
})
# 调用 TTS 服务生成音频
tts_audio = await tts_service.synthesize(response)
# 发送 TTS 音频
await manager.send_message(conversation_id, {
"type": MessageType.TTS_AUDIO,
"conversation_id": conversation_id,
"data": {"audio_base64": tts_audio},
"timestamp": datetime.now(timezone.utc).isoformat()
})
elif msg_type == MessageType.END_CONVERSATION:
# 结束对话
conversation.status = "ended"
conversation.ended_at = datetime.now(timezone.utc)
await db.commit()
# 触发整理 Agent
await process_conversation_segments(conversation_id, db)
await manager.send_message(conversation_id, {
"type": MessageType.END_CONVERSATION,
"conversation_id": conversation_id,
"data": {"status": "ended"},
"timestamp": datetime.now(timezone.utc).isoformat()
})
break
except Exception as e:
await manager.send_message(conversation_id, {
"type": MessageType.ERROR,
"data": {"message": str(e)},
"timestamp": datetime.now(timezone.utc).isoformat()
})
except WebSocketDisconnect:
manager.disconnect(conversation_id)
except Exception:
manager.disconnect(conversation_id)
raise
async def process_conversation_segments(conversation_id: str, db: AsyncSession):
"""
处理对话段落,生成章节
Args:
conversation_id: 对话 ID
db: 数据库会话
"""
# 获取所有未处理的段落
stmt = select(Segment).where(
Segment.conversation_id == conversation_id,
Segment.processed == False
)
result = await db.execute(stmt)
segments = result.scalars().all()
if not segments:
return
# 准备段落数据
segments_data = [
{"transcript_text": seg.transcript_text}
for seg in segments
]
# 调用整理 Agent
memory_agent = manager.memory_agent
chapters_data = memory_agent.process_segments(segments_data)
# 保存章节到数据库
from database import Chapter
conversation = await db.get(Conversation, conversation_id)
for category, chapter_data in chapters_data.items():
chapter = Chapter(
id=str(uuid.uuid4()),
user_id=conversation.user_id,
title=chapter_data["title"],
content=chapter_data["content"],
order_index=chapter_data.get("order_index", 999),
status="completed",
category=category,
images=chapter_data.get("image_suggestions", [])
)
db.add(chapter)
# 标记段落为已处理
for seg in segments:
seg.processed = True
await db.commit()