Files
life-echo/api/routers/websocket.py

325 lines
12 KiB
Python
Raw Normal View History

2026-01-07 11:56:40 +08:00
"""
WebSocket 路由实时对话通信
"""
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Dict
from fastapi import WebSocket, WebSocketDisconnect, HTTPException, Query
2026-01-07 11:56:40 +08:00
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from agents import ConversationAgent, MemoryAgent
from agents.prompts import ConversationStage
from database import get_async_db
from database.models import Conversation, Segment
from database.models import User as UserModel
from services.auth_service import verify_token
from fastapi import HTTPException, status
2026-01-07 11:56:40 +08:00
class MessageType(str, Enum):
    """
    Message types used in the ``type`` field of WebSocket messages.

    Because the enum also subclasses ``str``, each member compares equal to
    its plain string value, so incoming ``message["type"]`` strings can be
    matched against members directly.
    """

    CONNECT = "connect"
    AUDIO_CHUNK = "audio_chunk"

    # Plain-text message carrying ``data.text``.
    TEXT = "text"

    TRANSCRIPT = "transcript"
    AGENT_RESPONSE = "agent_response"
    TTS_AUDIO = "tts_audio"
    END_CONVERSATION = "end_conversation"
    ERROR = "error"
# Connection management
class ConnectionManager:
    """
    Registry of open WebSocket connections and their conversation agents.

    Both registries are keyed by conversation ID. One ``MemoryAgent`` is
    created per manager instance and exposed as ``memory_agent``.
    """

    def __init__(self):
        # conversation_id -> accepted WebSocket
        self.active_connections: Dict[str, WebSocket] = {}
        # conversation_id -> ConversationAgent created in ``connect``
        self.conversation_agents: Dict[str, ConversationAgent] = {}
        self.memory_agent = MemoryAgent()

    async def connect(self, websocket: WebSocket, conversation_id: str):
        """
        Accept the handshake, then register the socket and a new agent.

        Any entries already stored under ``conversation_id`` are replaced.
        """
        await websocket.accept()
        self.active_connections[conversation_id] = websocket
        self.conversation_agents[conversation_id] = ConversationAgent()

    def disconnect(self, conversation_id: str):
        """
        Unregister ``conversation_id``.

        Removes the socket entry, then clears the agent's memory for this
        conversation and removes the agent. Unknown IDs are ignored.
        """
        self.active_connections.pop(conversation_id, None)
        try:
            agent = self.conversation_agents[conversation_id]
        except KeyError:
            return
        agent.clear_memory(conversation_id)
        del self.conversation_agents[conversation_id]

    async def send_message(self, conversation_id: str, message: dict):
        """Send ``message`` as JSON; does nothing if the ID is not registered."""
        target = self.active_connections.get(conversation_id)
        if target is None:
            return
        await target.send_json(message)

    async def receive_message(self, conversation_id: str) -> dict:
        """
        Receive one JSON message from the registered socket.

        Raises:
            HTTPException: 404 if no socket is registered for the ID.
        """
        target = self.active_connections.get(conversation_id)
        if target is None:
            raise HTTPException(status_code=404, detail="Connection not found")
        return await target.receive_json()
# Module-level ConnectionManager instance, referenced by the handlers below.
manager = ConnectionManager()
async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: str,
    token: str = Query(..., description="访问令牌")
):
    """
    WebSocket endpoint for a real-time conversation.

    The endpoint proceeds in four steps:

    1. Authenticate the caller from the ``token`` query parameter.
    2. Authorize access to ``conversation_id``, creating the conversation for
       the caller if it does not exist yet.
    3. Register the connection with the module-level ``manager``.
    4. Process client messages until the client ends the conversation or
       disconnects.

    The connection is always unregistered from ``manager`` on the way out.

    Args:
        websocket: The WebSocket connection.
        conversation_id: Conversation ID.
        token: Access token, taken from the query string.
    """
    # Validate the JWT access token.
    payload = verify_token(token)
    if not payload or payload.get("type") != "access":
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌")
        return

    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容")
        return

    async for db in get_async_db():
        # Check that the user still exists.
        user = await db.get(UserModel, user_id)
        if not user:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在")
            return

        # Authorize BEFORE registering with the manager. Registering first let
        # a rejected client overwrite the owner's entry for this
        # conversation_id and leave a closed socket behind in the manager.
        conversation = await db.get(Conversation, conversation_id)
        if conversation is None:
            # Unknown conversation: create it for this user.
            conversation = Conversation(
                id=conversation_id,
                user_id=user_id,
                started_at=datetime.now(timezone.utc),
                status="active"
            )
            db.add(conversation)
            await db.commit()
        elif conversation.user_id != user_id:
            # Users may only access their own conversations. Accept just long
            # enough to tell the client why, without touching the manager.
            await websocket.accept()
            await websocket.send_json({
                "type": MessageType.ERROR,
                "data": {"message": "无权访问此对话"},
                "timestamp": datetime.now(timezone.utc).isoformat()
            })
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话")
            return

        await manager.connect(websocket, conversation_id)
        try:
            # Confirm the connection to the client.
            await manager.send_message(conversation_id, {
                "type": MessageType.CONNECT,
                "conversation_id": conversation_id,
                "data": {"status": "connected"},
                "timestamp": datetime.now(timezone.utc).isoformat()
            })
            await _run_conversation_loop(websocket, conversation_id, conversation, db)
        except WebSocketDisconnect:
            # The client went away; cleanup happens in ``finally``.
            pass
        finally:
            # Runs on a normal end, a disconnect and unexpected errors alike, so
            # the manager never keeps a stale socket/agent for this conversation.
            manager.disconnect(conversation_id)


async def _run_conversation_loop(
    websocket: WebSocket,
    conversation_id: str,
    conversation: Conversation,
    db: AsyncSession,
) -> None:
    """
    Receive and dispatch client messages until the conversation is ended.

    Returns after handling an ``end_conversation`` message.
    ``WebSocketDisconnect`` is propagated to the caller. Any other error is
    logged, the DB session is rolled back, the error is reported to the client,
    and the loop keeps going.
    """
    import logging

    current_stage = (
        ConversationStage(conversation.conversation_stage)
        if conversation.conversation_stage
        else ConversationStage.CHILDHOOD
    )

    while True:
        try:
            message = await websocket.receive_json()
            msg_type = message.get("type")

            if msg_type == MessageType.TEXT:
                text_message = message.get("data", {}).get("text", "")
                if text_message:
                    # Persist the user's utterance before asking the agent to reply.
                    segment = Segment(
                        id=str(uuid.uuid4()),
                        conversation_id=conversation_id,
                        transcript_text=text_message,
                        processed=False
                    )
                    db.add(segment)
                    await db.commit()

                    current_stage = await process_user_message(
                        conversation_id=conversation_id,
                        user_message=text_message,
                        current_stage=current_stage,
                        conversation=conversation,
                        segment=segment,
                        db=db,
                        manager=manager
                    )

            elif msg_type == MessageType.END_CONVERSATION:
                conversation.status = "ended"
                conversation.ended_at = datetime.now(timezone.utc)
                await db.commit()

                # Hand the collected segments to the memory agent (chapters).
                await process_conversation_segments(conversation_id, db)

                await manager.send_message(conversation_id, {
                    "type": MessageType.END_CONVERSATION,
                    "conversation_id": conversation_id,
                    "data": {"status": "ended"},
                    "timestamp": datetime.now(timezone.utc).isoformat()
                })
                return

        except WebSocketDisconnect:
            # Must precede ``except Exception`` (it is a subclass): the client is
            # gone, so there is nobody to report an error to.
            raise
        except Exception as e:
            logging.getLogger(__name__).exception(
                "Error while handling a message for conversation %s", conversation_id
            )
            # A failed flush/commit leaves the session unusable until it is
            # rolled back; without this every later message would fail too.
            await db.rollback()
            await manager.send_message(conversation_id, {
                "type": MessageType.ERROR,
                "data": {"message": str(e)},
                "timestamp": datetime.now(timezone.utc).isoformat()
            })
async def process_user_message(
    conversation_id: str,
    user_message: str,
    current_stage: ConversationStage,
    conversation: Conversation,
    segment: Segment,
    db: AsyncSession,
    manager: ConnectionManager
) -> ConversationStage:
    """
    Produce the agent's reply to one user message and push it to the client.

    If no agent is registered for ``conversation_id``, nothing happens and
    ``current_stage`` is returned unchanged. Otherwise:

    - The stage is re-detected and, when it changed, persisted on
      ``conversation``.
    - The reply is stored on ``segment``.
    - The reply is sent as an ``agent_response`` message.

    Args:
        conversation_id: Conversation ID.
        user_message: Text of the user's message.
        current_stage: Stage before this message.
        conversation: Conversation row to update on a stage change.
        segment: Segment row that receives the agent reply.
        db: Database session.
        manager: Connection manager holding the agent and the socket.

    Returns:
        The (possibly updated) conversation stage.
    """
    agent = manager.conversation_agents.get(conversation_id)
    if not agent:
        return current_stage

    # Re-detect the stage from this message; persist it only when it moved.
    new_stage = agent.detect_stage(conversation_id, user_message)
    if new_stage != current_stage:
        current_stage = new_stage
        conversation.conversation_stage = current_stage.value
        await db.commit()

    # Topics already covered in this conversation, oldest segment first.
    history_stmt = (
        select(Segment)
        .where(Segment.conversation_id == conversation_id)
        .order_by(Segment.created_at)
    )
    history = (await db.execute(history_stmt)).scalars().all()
    covered_topics = [s.topic_category for s in history if s.topic_category]

    reply = agent.generate_response(
        conversation_id=conversation_id,
        user_message=user_message,
        current_stage=current_stage,
        covered_topics=covered_topics
    )

    segment.agent_response = reply
    await db.commit()

    # Text only: no speech is synthesized here.
    await manager.send_message(conversation_id, {
        "type": MessageType.AGENT_RESPONSE,
        "conversation_id": conversation_id,
        "data": {"text": reply},
        "timestamp": datetime.now(timezone.utc).isoformat()
    })
    return current_stage
2026-01-07 11:56:40 +08:00
async def process_conversation_segments(conversation_id: str, db: AsyncSession):
    """
    Turn a conversation's unprocessed segments into chapters.

    Sends the transcripts of the unprocessed segments to the module-level
    ``manager.memory_agent`` and stores one Chapter per category it returns.
    Those segments are marked as processed in the same commit. Does nothing
    if the conversation does not exist or has no unprocessed segments.

    Args:
        conversation_id: Conversation ID.
        db: Database session.
    """
    # Kept as a function-level import, as in the original (possibly to avoid
    # an import cycle; TODO confirm).
    from database.models import Chapter as ChapterModel

    # Resolve the owning conversation first. Previously the memory agent ran
    # before this check and its result was thrown away when the conversation
    # was missing.
    conversation = await db.get(Conversation, conversation_id)
    if not conversation:
        return

    result = await db.execute(
        select(Segment).where(
            Segment.conversation_id == conversation_id,
            Segment.processed == False  # noqa: E712 - SQL expression, not a bool test
        )
    )
    segments = result.scalars().all()
    if not segments:
        return

    segments_data = [{"transcript_text": seg.transcript_text} for seg in segments]
    chapters_data = manager.memory_agent.process_segments(segments_data)

    # Presumably maps category -> chapter dict (title/content/order_index/
    # image_suggestions); the keys read below are the only ones used.
    for category, chapter_data in chapters_data.items():
        db.add(ChapterModel(
            id=str(uuid.uuid4()),
            user_id=conversation.user_id,
            title=chapter_data.get("title", f"章节-{category}"),
            content=chapter_data.get("content", ""),
            order_index=chapter_data.get("order_index", 999),
            status="completed",
            category=category,
            images=chapter_data.get("image_suggestions", [])
        ))

    # Mark the segments as processed in the same transaction as the chapters.
    for seg in segments:
        seg.processed = True
    await db.commit()