- 更新books路由以支持用户认证 - 更新chapters路由以支持用户认证 - 更新conversations路由以支持用户认证 - 更新websocket路由以支持用户认证和连接管理
103 lines
3.4 KiB
Python
103 lines
3.4 KiB
Python
"""
|
|
对话相关 API 路由
|
|
"""
|
|
from datetime import datetime, timezone
|
|
from typing import List, Optional
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, Body
|
|
from pydantic import BaseModel
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
import uuid
|
|
|
|
from database import get_async_db, Conversation, Segment, User
|
|
from database.models import Conversation as ConversationModel, Segment as SegmentModel
|
|
from middleware.auth import get_current_user
|
|
from database.models import User as UserModel
|
|
|
|
# Router for conversation endpoints; every route below is mounted under
# the /api/conversations prefix and grouped under the "conversations" tag.
router = APIRouter(
    tags=["conversations"],
    prefix="/api/conversations",
)
|
|
|
|
|
|
@router.post("")
async def create_conversation(
    current_user: UserModel = Depends(get_current_user),
    db: AsyncSession = Depends(get_async_db),
):
    """Create a new conversation owned by the authenticated user.

    Authentication is required (``current_user`` comes from
    ``get_current_user``). The new row gets a random UUID4 string id, a
    timezone-aware UTC start time and the status ``"active"``.

    Returns:
        A dict with the conversation's ``id``, ``user_id``, ``started_at``
        (ISO-8601 string) and ``status``.
    """
    new_id = str(uuid.uuid4())
    started = datetime.now(timezone.utc)
    new_conversation = ConversationModel(
        id=new_id,
        user_id=current_user.id,
        started_at=started,
        status="active",
    )

    db.add(new_conversation)
    await db.commit()
    # Reload the row from the database after the commit so the attributes
    # read below are populated with the stored values.
    await db.refresh(new_conversation)

    payload = {
        "id": new_conversation.id,
        "user_id": new_conversation.user_id,
        "started_at": new_conversation.started_at.isoformat(),
        "status": new_conversation.status,
    }
    return payload
|
|
|
|
|
|
@router.get("/{conversation_id}")
async def get_conversation(
    conversation_id: str,
    current_user: UserModel = Depends(get_current_user),
    db: AsyncSession = Depends(get_async_db),
):
    """Return the details of one conversation owned by the caller.

    Authentication is required; users may only read their own conversations.

    Args:
        conversation_id: Primary key of the conversation to fetch.
        current_user: Authenticated user, injected by ``get_current_user``.
        db: Async database session, injected by ``get_async_db``.

    Returns:
        A dict of the conversation's fields. ``started_at`` and ``ended_at``
        are ISO-8601 strings, or ``None`` when the column is unset.

    Raises:
        HTTPException: 404 if no conversation has this id, 403 if it belongs
            to a different user.
    """
    conversation = await db.get(ConversationModel, conversation_id)
    if not conversation:
        raise HTTPException(status_code=404, detail="Conversation not found")

    # Ownership check.
    # NOTE(review): answering 403 (rather than 404) reveals that the id exists;
    # consider 404 if conversation ids should not be discoverable.
    if conversation.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权访问此对话")

    started_at = conversation.started_at
    ended_at = conversation.ended_at
    return {
        "id": conversation.id,
        "user_id": conversation.user_id,
        # Guard started_at the same way as ended_at so a NULL value yields
        # None instead of an AttributeError (HTTP 500).
        "started_at": started_at.isoformat() if started_at else None,
        "ended_at": ended_at.isoformat() if ended_at else None,
        "duration_seconds": conversation.duration_seconds,
        "summary": conversation.summary,
        "status": conversation.status,
        "current_topic": conversation.current_topic,
        "conversation_stage": conversation.conversation_stage,
    }
|
|
|
|
|
|
@router.post("/{conversation_id}/end")
async def end_conversation(
    conversation_id: str,
    current_user: UserModel = Depends(get_current_user),
    db: AsyncSession = Depends(get_async_db),
):
    """Mark one of the caller's conversations as ended.

    Sets ``status`` to ``"ended"`` and stamps ``ended_at`` with the current
    UTC time. When ``started_at`` is known, it also stores the elapsed whole
    seconds in ``duration_seconds``.

    Args:
        conversation_id: Primary key of the conversation to end.
        current_user: Authenticated user, injected by ``get_current_user``.
        db: Async database session, injected by ``get_async_db``.

    Returns:
        A dict with ``id``, ``status``, ``ended_at`` (ISO-8601 string) and
        ``duration_seconds``.

    Raises:
        HTTPException: 404 if no conversation has this id, 403 if it belongs
            to a different user.
    """
    conversation = await db.get(ConversationModel, conversation_id)
    if not conversation:
        raise HTTPException(status_code=404, detail="Conversation not found")

    # Ownership check: users may only end their own conversations.
    if conversation.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权操作此对话")

    # NOTE(review): an already-ended conversation is ended again here, which
    # overwrites ended_at and duration_seconds; confirm that is intended.
    ended_at = datetime.now(timezone.utc)
    conversation.status = "ended"
    conversation.ended_at = ended_at

    started_at = conversation.started_at
    if started_at:
        if started_at.tzinfo is None:
            # A DateTime column without timezone support returns naive values,
            # and subtracting naive from aware raises TypeError.
            # create_conversation stores UTC, so treat a naive value as UTC.
            # NOTE(review): confirm for rows written by other code paths.
            started_at = started_at.replace(tzinfo=timezone.utc)
        conversation.duration_seconds = int(
            (ended_at - started_at).total_seconds()
        )

    # Build the response before committing. With AsyncSession's default
    # expire_on_commit=True, reading attributes after commit would trigger an
    # implicit reload that cannot run under asyncio. This is harmless if
    # get_async_db sets expire_on_commit=False.
    response = {
        "id": conversation.id,
        "status": conversation.status,
        "ended_at": ended_at.isoformat(),
        "duration_seconds": conversation.duration_seconds,
    }

    await db.commit()
    return response
|
|
|