"""
WebSocket routes: real-time conversation communication.
"""
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Dict

from fastapi import HTTPException, Query, WebSocket, WebSocketDisconnect, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from agents import ConversationAgent, MemoryAgent
from agents.prompts import ConversationStage
from database import get_async_db
from database.models import Conversation, Segment
from database.models import User as UserModel
from services.auth_service import verify_token


class MessageType(str, Enum):
    """WebSocket message types exchanged with the client."""
    CONNECT = "connect"
    AUDIO_CHUNK = "audio_chunk"
    TEXT = "text"  # plain text message
    TRANSCRIPT = "transcript"
    AGENT_RESPONSE = "agent_response"
    TTS_AUDIO = "tts_audio"
    END_CONVERSATION = "end_conversation"
    ERROR = "error"


def _utc_timestamp() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


# Connection management
class ConnectionManager:
    """Registry of active WebSocket connections, keyed by conversation ID.

    Each registered conversation gets its own ``ConversationAgent``; a single
    ``MemoryAgent`` is shared by all conversations.
    """

    def __init__(self):
        self.active_connections: Dict[str, WebSocket] = {}
        self.conversation_agents: Dict[str, ConversationAgent] = {}
        self.memory_agent = MemoryAgent()

    async def connect(self, websocket: WebSocket, conversation_id: str):
        """Accept the socket and register it together with a fresh agent.

        An existing entry for the same ``conversation_id`` is overwritten.
        """
        await websocket.accept()
        self.active_connections[conversation_id] = websocket
        self.conversation_agents[conversation_id] = ConversationAgent()

    def disconnect(self, conversation_id: str):
        """Forget the connection and agent of a conversation.

        Safe to call for an unknown or already removed conversation.
        """
        self.active_connections.pop(conversation_id, None)
        # Remove the agent from the registry first so the entry is gone even
        # if clearing its memory raises.
        agent = self.conversation_agents.pop(conversation_id, None)
        if agent is not None:
            agent.clear_memory(conversation_id)

    async def send_message(self, conversation_id: str, message: dict):
        """Send a JSON message; silently does nothing if not registered."""
        websocket = self.active_connections.get(conversation_id)
        if websocket is not None:
            await websocket.send_json(message)

    async def receive_message(self, conversation_id: str) -> dict:
        """Receive one JSON message from the conversation's socket.

        Raises:
            HTTPException: 404 if the conversation has no registered socket.
        """
        websocket = self.active_connections.get(conversation_id)
        if websocket is None:
            raise HTTPException(status_code=404, detail="Connection not found")
        return await websocket.receive_json()


manager = ConnectionManager()


async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: str,
    token: str = Query(..., description="访问令牌")
):
    """
    WebSocket endpoint: handle a real-time conversation.

    The token, the user and the ownership of the conversation are all checked
    before the socket is registered with the connection manager, so a rejected
    client can never replace the socket or agent of a live conversation.

    Args:
        websocket: the WebSocket connection
        conversation_id: conversation ID
        token: access token (taken from the query string)
    """
    # Validate the JWT
    payload = verify_token(token)
    if not payload or payload.get("type") != "access":
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌")
        return

    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容")
        return

    # NOTE(review): presumably get_async_db() yields a single session; the
    # loop is left to run to completion so the generator's cleanup executes.
    async for db in get_async_db():
        # Make sure the user exists
        user = await db.get(UserModel, user_id)
        if not user:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在")
            return

        # Authorise BEFORE manager.connect(): users may only open their own
        # conversations, and connect() overwrites any existing registration.
        conversation = await db.get(Conversation, conversation_id)
        if conversation is not None and conversation.user_id != user_id:
            # Talk to the socket directly; it is never registered.
            await websocket.accept()
            await websocket.send_json({
                "type": MessageType.ERROR,
                "data": {"message": "无权访问此对话"},
                "timestamp": _utc_timestamp()
            })
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话")
            return

        await manager.connect(websocket, conversation_id)

        try:
            # Connection acknowledgement
            await manager.send_message(conversation_id, {
                "type": MessageType.CONNECT,
                "conversation_id": conversation_id,
                "data": {"status": "connected"},
                "timestamp": _utc_timestamp()
            })

            if conversation is None:
                # Unknown conversation ID: create it for this user
                conversation = Conversation(
                    id=conversation_id,
                    user_id=user_id,
                    started_at=datetime.now(timezone.utc),
                    status="active"
                )
                db.add(conversation)
                await db.commit()

            if conversation.conversation_stage:
                current_stage = ConversationStage(conversation.conversation_stage)
            else:
                current_stage = ConversationStage.CHILDHOOD

            # Main loop: handle incoming messages
            while True:
                try:
                    message = await websocket.receive_json()
                    msg_type = message.get("type")

                    if msg_type == MessageType.TEXT:
                        text_message = message.get("data", {}).get("text", "")
                        if text_message:
                            # Persist the user's text as a segment
                            segment = Segment(
                                id=str(uuid.uuid4()),
                                conversation_id=conversation_id,
                                transcript_text=text_message,
                                processed=False
                            )
                            db.add(segment)
                            await db.commit()

                            # Let the agent answer
                            current_stage = await process_user_message(
                                conversation_id=conversation_id,
                                user_message=text_message,
                                current_stage=current_stage,
                                conversation=conversation,
                                segment=segment,
                                db=db,
                                manager=manager
                            )

                    elif msg_type == MessageType.END_CONVERSATION:
                        # Close the conversation
                        conversation.status = "ended"
                        conversation.ended_at = datetime.now(timezone.utc)
                        await db.commit()

                        # Trigger the organising (memory) agent
                        await process_conversation_segments(conversation_id, db)

                        await manager.send_message(conversation_id, {
                            "type": MessageType.END_CONVERSATION,
                            "conversation_id": conversation_id,
                            "data": {"status": "ended"},
                            "timestamp": _utc_timestamp()
                        })
                        break

                except WebSocketDisconnect:
                    # Must not reach the generic handler below: it would try
                    # to send on a dead socket and keep looping.
                    raise
                except Exception as e:
                    # A failed flush/commit leaves the session unusable until
                    # it is rolled back.
                    await db.rollback()
                    await manager.send_message(conversation_id, {
                        "type": MessageType.ERROR,
                        "data": {"message": str(e)},
                        "timestamp": _utc_timestamp()
                    })

        except WebSocketDisconnect:
            # Client went away: nothing to report, cleanup happens below.
            pass
        finally:
            # Runs on every exit path (normal end, disconnect, error).
            manager.disconnect(conversation_id)


async def process_user_message(
    conversation_id: str,
    user_message: str,
    current_stage: ConversationStage,
    conversation: Conversation,
    segment: Segment,
    db: AsyncSession,
    manager: ConnectionManager
) -> ConversationStage:
    """
    Handle one user message and send the agent's reply.

    If no agent is registered for the conversation nothing is sent and
    ``current_stage`` is returned unchanged.

    Args:
        conversation_id: conversation ID
        user_message: text of the user's message
        current_stage: current conversation stage
        conversation: conversation row, updated when the stage changes
        segment: segment row holding the user's message; receives the reply
        db: database session
        manager: connection manager

    Returns:
        The conversation stage after processing the message
    """
    agent = manager.conversation_agents.get(conversation_id)
    if not agent:
        return current_stage

    # Detect the conversation stage and persist it when it changed
    detected_stage = agent.detect_stage(conversation_id, user_message)
    if detected_stage != current_stage:
        current_stage = detected_stage
        conversation.conversation_stage = current_stage.value
        await db.commit()

    # Collect the topics already covered in this conversation
    stmt_segments = select(Segment).where(
        Segment.conversation_id == conversation_id
    ).order_by(Segment.created_at)
    result_segments = await db.execute(stmt_segments)
    previous_segments = result_segments.scalars().all()
    covered_topics = [
        seg.topic_category for seg in previous_segments if seg.topic_category
    ]

    # Generate the reply
    response = agent.generate_response(
        conversation_id=conversation_id,
        user_message=user_message,
        current_stage=current_stage,
        covered_topics=covered_topics
    )

    # Store the reply on the segment
    segment.agent_response = response
    await db.commit()

    # Send the reply (text only, no speech synthesis)
    await manager.send_message(conversation_id, {
        "type": MessageType.AGENT_RESPONSE,
        "conversation_id": conversation_id,
        "data": {"text": response},
        "timestamp": _utc_timestamp()
    })

    return current_stage


async def process_conversation_segments(conversation_id: str, db: AsyncSession):
    """
    Turn the unprocessed segments of a conversation into chapters.

    Returns without changes when there are no unprocessed segments or the
    conversation does not exist.

    Args:
        conversation_id: conversation ID
        db: database session
    """
    # Function-scope import kept from the original.
    from database.models import Chapter as ChapterModel

    # Fetch every segment that has not been processed yet.
    # "== False" is a SQLAlchemy column expression, not a Python comparison.
    stmt = select(Segment).where(
        Segment.conversation_id == conversation_id,
        Segment.processed == False  # noqa: E712
    )
    result = await db.execute(stmt)
    segments = result.scalars().all()

    if not segments:
        return

    # Look the conversation up before invoking the agent so no agent work is
    # done for a conversation that does not exist.
    conversation = await db.get(Conversation, conversation_id)
    if not conversation:
        return

    # Prepare the segment data
    segments_data = [
        {"transcript_text": seg.transcript_text}
        for seg in segments
    ]

    # Call the organising (memory) agent
    chapters_data = manager.memory_agent.process_segments(segments_data)

    # Save the chapters
    for category, chapter_data in chapters_data.items():
        chapter = ChapterModel(
            id=str(uuid.uuid4()),
            user_id=conversation.user_id,
            title=chapter_data.get("title", f"章节-{category}"),
            content=chapter_data.get("content", ""),
            order_index=chapter_data.get("order_index", 999),
            status="completed",
            category=category,
            images=chapter_data.get("image_suggestions", [])
        )
        db.add(chapter)

    # Mark the segments as processed
    for seg in segments:
        seg.processed = True

    await db.commit()