""" 对话相关 API 路由 """ from datetime import datetime, timezone from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, Query, Body from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select import uuid from database import get_async_db, Conversation, Segment, User from database.models import Conversation as ConversationModel, Segment as SegmentModel from middleware.auth import get_current_user from database.models import User as UserModel router = APIRouter(prefix="/api/conversations", tags=["conversations"]) @router.get("") async def get_conversations( current_user: UserModel = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): """获取当前用户的所有对话列表(需要认证)""" stmt = select(ConversationModel).where( ConversationModel.user_id == current_user.id ).order_by(ConversationModel.started_at.desc()) result = await db.execute(stmt) conversations = result.scalars().all() # 转换为列表项格式 from services.redis_service import redis_service conversation_list = [] for conv in conversations: # 从Redis获取最新消息预览 latest_message = None try: history = await redis_service.get_conversation_history(conv.id) if history: latest_message = history[-1].get("content", "")[:50] # 取前50个字符 except: pass conversation_list.append({ "id": conv.id, "title": conv.summary[:30] if conv.summary else "岁月知己", # 使用summary作为标题,如果没有则使用默认标题 "avatarUrl": None, "latestMessagePreview": latest_message or conv.summary, "latestMessageTime": int(conv.started_at.timestamp() * 1000) if conv.started_at else int(datetime.now(timezone.utc).timestamp() * 1000), "unreadCount": 0, "isDefaultAssistant": conv.summary is None # 如果没有summary,则认为是默认助手 }) return conversation_list @router.post("") async def create_conversation( current_user: UserModel = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): """创建新对话(需要认证)。对话轮数在每次发送消息时校验。""" conversation = ConversationModel( id=str(uuid.uuid4()), user_id=current_user.id, started_at=datetime.now(timezone.utc), status="active" ) db.add(conversation) await db.commit() await db.refresh(conversation) return { "id": conversation.id, "user_id": conversation.user_id, "started_at": conversation.started_at.isoformat(), "status": conversation.status } @router.get("/{conversation_id}") async def get_conversation( conversation_id: str, current_user: UserModel = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): """获取对话详情(需要认证,只能访问自己的对话)""" conversation = await db.get(ConversationModel, conversation_id) if not conversation: raise HTTPException(status_code=404, detail="Conversation not found") # 验证用户权限 if conversation.user_id != current_user.id: raise HTTPException(status_code=403, detail="无权访问此对话") return { "id": conversation.id, "user_id": conversation.user_id, "started_at": conversation.started_at.isoformat(), "ended_at": conversation.ended_at.isoformat() if conversation.ended_at else None, "duration_seconds": conversation.duration_seconds, "summary": conversation.summary, "status": conversation.status, "current_topic": conversation.current_topic, "conversation_stage": conversation.conversation_stage } @router.post("/{conversation_id}/end") async def end_conversation( conversation_id: str, current_user: UserModel = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): """结束对话(需要认证,只能结束自己的对话)""" conversation = await db.get(ConversationModel, conversation_id) if not conversation: raise HTTPException(status_code=404, detail="Conversation not found") # 验证用户权限 if conversation.user_id != current_user.id: raise HTTPException(status_code=403, detail="无权操作此对话") conversation.status = "ended" conversation.ended_at = datetime.now(timezone.utc) if conversation.started_at: duration = (conversation.ended_at - conversation.started_at).total_seconds() conversation.duration_seconds = int(duration) await db.commit() return { "id": conversation.id, "status": conversation.status, "ended_at": conversation.ended_at.isoformat(), "duration_seconds": conversation.duration_seconds } @router.delete("/{conversation_id}") async def delete_conversation( conversation_id: str, current_user: UserModel = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): """删除对话(需要认证,只能删除自己的对话)""" conversation = await db.get(ConversationModel, conversation_id) if not conversation: raise HTTPException(status_code=404, detail="Conversation not found") # 验证用户权限 if conversation.user_id != current_user.id: raise HTTPException(status_code=403, detail="无权删除此对话") # 删除Redis中的对话历史 from services.redis_service import redis_service try: await redis_service.clear_conversation_history(conversation_id) except: pass # 删除数据库中的对话(级联删除segments) await db.delete(conversation) await db.commit() return {"message": "对话已删除"} @router.get("/{conversation_id}/messages") async def get_messages( conversation_id: str, current_user: UserModel = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): """获取对话的消息列表(需要认证,只能访问自己的对话)""" # 验证对话存在且属于当前用户 conversation = await db.get(ConversationModel, conversation_id) if not conversation: raise HTTPException(status_code=404, detail="Conversation not found") if conversation.user_id != current_user.id: raise HTTPException(status_code=403, detail="无权访问此对话") # 从Redis获取消息历史 from services.redis_service import redis_service try: history = await redis_service.get_conversation_history(conversation_id) messages = [] for idx, msg in enumerate(history): messages.append({ "id": f"{conversation_id}_msg_{idx}", "conversationId": conversation_id, "content": msg.get("content", ""), "senderType": "user" if msg.get("role") == "human" else "assistant", "timestamp": int(datetime.now(timezone.utc).timestamp() * 1000), # Redis中没有时间戳,使用当前时间 "messageType": "text" }) return messages except Exception as e: # 如果Redis中没有数据,返回空列表 return [] @router.post("/{conversation_id}/organize") async def organize_conversation( conversation_id: str, current_user: UserModel = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): """ 整理对话内容成章节(需要认证,只能整理自己的对话) 手动触发对话整理,将对话中的内容整理成回忆录章节 """ import logging logger = logging.getLogger(__name__) # 验证对话存在且属于当前用户 conversation = await db.get(ConversationModel, conversation_id) if not conversation: raise HTTPException(status_code=404, detail="Conversation not found") if conversation.user_id != current_user.id: raise HTTPException(status_code=403, detail="无权操作此对话") # 获取所有未处理的段落 stmt = select(SegmentModel).where( SegmentModel.conversation_id == conversation_id, SegmentModel.processed == False ) result = await db.execute(stmt) segments = result.scalars().all() if not segments: # 如果没有未处理的段落,尝试处理所有段落 stmt = select(SegmentModel).where( SegmentModel.conversation_id == conversation_id ) result = await db.execute(stmt) segments = result.scalars().all() if not segments: raise HTTPException(status_code=400, detail="该对话没有可整理的内容") # 免费版仅允许 1 个章节整理,Pro/Pro+ 无限制 from routers.quota import get_chapter_count, check_can_submit_organize chapter_count = await get_chapter_count(current_user.id, db) can_submit, quota_message = check_can_submit_organize( current_user.subscription_type, chapter_count ) if not can_submit: raise HTTPException(status_code=403, detail=quota_message) # 提交到Celery任务处理 try: from routers.websocket import manager from tasks.memoir_tasks import process_memoir_segments segment_ids = [seg.id for seg in segments] process_memoir_segments.delay(conversation.user_id, segment_ids) logger.info(f"手动触发对话整理: conversation_id={conversation_id}, segments={len(segment_ids)}") return { "message": "对话整理任务已提交", "conversation_id": conversation_id, "segments_count": len(segment_ids) } except Exception as e: logger.error(f"提交整理任务失败: {e}") raise HTTPException(status_code=500, detail=f"提交整理任务失败: {str(e)}")