Files
life-echo/api/routers/websocket.py

377 lines
15 KiB
Python
Raw Normal View History

2026-01-07 11:56:40 +08:00
"""
WebSocket 路由实时对话通信
支持异步 Agent 调用和 Redis 会话存储
2026-01-07 11:56:40 +08:00
"""
import logging
2026-01-07 11:56:40 +08:00
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Dict
from fastapi import WebSocket, WebSocketDisconnect, HTTPException
2026-01-07 11:56:40 +08:00
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from agents import ConversationAgent, MemoryAgent
2026-01-21 22:31:03 +01:00
from agents.memoir_processor import BackgroundTaskRunner
2026-01-07 11:56:40 +08:00
from database import get_async_db
from database.models import Conversation, Segment
from database.models import User as UserModel
from services.auth_service import verify_token
2026-01-21 22:31:03 +01:00
from services.memoir_state_service import get_or_create_state
from fastapi import HTTPException, status
2026-01-07 11:56:40 +08:00
# Module logger, named after this module's import path
logger = logging.getLogger(__name__)
2026-01-07 11:56:40 +08:00
class MessageType(str, Enum):
    """WebSocket message types (the ``"type"`` field of each JSON frame).

    Members mix in ``str``, so they compare equal to the raw strings read
    from incoming JSON (e.g. ``msg_type == MessageType.TEXT``) and serialize
    as their plain string values when sent.
    """
    # Server -> client: acknowledgement sent after the socket is registered
    CONNECT = "connect"
    # NOTE(review): not handled or sent in this module; presumably used elsewhere or reserved
    AUDIO_CHUNK = "audio_chunk"
    # Client -> server: user text message, payload in data.text
    TEXT = "text"  # text message
    # NOTE(review): not handled or sent in this module; presumably used elsewhere or reserved
    TRANSCRIPT = "transcript"
    # Server -> client: one agent reply (data.text, data.index, data.total)
    AGENT_RESPONSE = "agent_response"
    # NOTE(review): not handled or sent in this module; presumably used elsewhere or reserved
    TTS_AUDIO = "tts_audio"
    # Client -> server to end the conversation; echoed back with data.status == "ended"
    END_CONVERSATION = "end_conversation"
    # NOTE(review): not handled or sent in this module; presumably used elsewhere or reserved
    MEMOIR_UPDATE = "memoir_update"
    # Server -> client: failure description in data.message
    ERROR = "error"
# Connection management
class ConnectionManager:
    """Registry of live WebSocket connections plus the shared agent instances.

    Connections are keyed by conversation ID, so at most one socket per
    conversation is tracked; registering a second socket under the same ID
    replaces the first.
    """

    def __init__(self):
        # conversation_id -> accepted WebSocket
        self.active_connections: Dict[str, WebSocket] = {}
        # ConversationAgent is now stateless (sessions are stored in Redis), so one instance is reused
        self.conversation_agent = ConversationAgent()
        self.memory_agent = MemoryAgent()
        # Receives segment IDs via queue_message(); presumably batches them for memoir processing
        self.background_runner = BackgroundTaskRunner()

    async def connect(self, websocket: WebSocket, conversation_id: str):
        """Accept the WebSocket handshake and register the socket under ``conversation_id``."""
        await websocket.accept()
        self.active_connections[conversation_id] = websocket

    async def disconnect(self, conversation_id: str):
        """Forget the socket registered for ``conversation_id`` (no-op if there is none)."""
        self.active_connections.pop(conversation_id, None)
        # The Redis session memory is deliberately kept so it can be used for recovery;
        # to purge it instead, call
        # ``await self.conversation_agent.clear_memory(conversation_id)`` here.

    async def send_message(self, conversation_id: str, message: dict):
        """Send ``message`` as JSON to the conversation's socket, if one is registered.

        Send failures are logged rather than raised, and the failing socket is
        unregistered.
        """
        websocket = self.active_connections.get(conversation_id)
        if websocket is None:
            return
        try:
            # Raises if the connection has already been closed
            await websocket.send_json(message)
        except Exception as e:
            logger.warning(f"发送消息失败 (conversation_id={conversation_id}): {e}")
            # Unregister only if the entry still points at the socket that failed:
            # a reconnect may have registered a new socket while send_json was awaiting.
            if self.active_connections.get(conversation_id) is websocket:
                del self.active_connections[conversation_id]

    async def receive_message(self, conversation_id: str) -> dict:
        """Receive one JSON frame from the conversation's socket.

        Raises:
            HTTPException: 404 if no socket is registered for ``conversation_id``.
        """
        websocket = self.active_connections.get(conversation_id)
        if websocket is None:
            raise HTTPException(status_code=404, detail="Connection not found")
        return await websocket.receive_json()
# Module-level instance, created once at import time and shared by the
# WebSocket endpoint and the helper coroutines in this module.
manager = ConnectionManager()
async def _authenticate_websocket(websocket: WebSocket):
    """Return the ``sub`` claim of the ``token`` query parameter, or close the socket and return None.

    The token must be present, pass ``verify_token``, have ``type == "access"``
    and carry a non-empty ``sub``; each failure closes the (not yet accepted)
    socket with policy-violation code 1008 and its own reason.
    """
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌")
        return None

    payload = verify_token(token)
    if not payload or payload.get("type") != "access":
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌")
        return None

    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容")
        return None
    return user_id


async def _send_error_if_connected(conversation_id: str, text: str) -> None:
    """Send an ERROR frame to the conversation's socket if it is still registered (best effort)."""
    if conversation_id not in manager.active_connections:
        return
    try:
        await manager.send_message(conversation_id, {
            "type": MessageType.ERROR,
            "data": {"message": text},
            "timestamp": datetime.now(timezone.utc).isoformat()
        })
    except Exception as send_error:
        logger.warning(f"发送错误消息失败: {send_error}")


async def _run_message_loop(
    websocket: WebSocket,
    conversation_id: str,
    conversation: Conversation,
    db: AsyncSession,
) -> None:
    """Receive and dispatch client frames until END_CONVERSATION, a disconnect, or an error.

    Any exception ends the loop; errors other than a disconnect are logged and
    reported to the client as an ERROR frame first.
    """
    while True:
        try:
            message = await websocket.receive_json()
            msg_type = message.get("type")

            if msg_type == MessageType.TEXT:
                text_message = message.get("data", {}).get("text", "")
                if text_message:
                    # Persist the user's text as a new, unprocessed segment
                    segment = Segment(
                        id=str(uuid.uuid4()),
                        conversation_id=conversation_id,
                        transcript_text=text_message,
                        processed=False
                    )
                    db.add(segment)
                    await db.commit()
                    await db.refresh(segment)
                    await manager.background_runner.queue_message(conversation.user_id, segment.id)

                    # Let the agent generate and send its reply
                    await process_user_message(
                        conversation_id=conversation_id,
                        user_message=text_message,
                        conversation=conversation,
                        segment=segment,
                        db=db,
                        manager=manager
                    )

            elif msg_type == MessageType.END_CONVERSATION:
                conversation.status = "ended"
                conversation.ended_at = datetime.now(timezone.utc)
                await db.commit()
                # Hand the remaining segments to the memoir pipeline
                await process_conversation_segments(conversation_id, db)
                await manager.send_message(conversation_id, {
                    "type": MessageType.END_CONVERSATION,
                    "conversation_id": conversation_id,
                    "data": {"status": "ended"},
                    "timestamp": datetime.now(timezone.utc).isoformat()
                })
                return

        except RuntimeError as e:
            error_msg = str(e)
            # A RuntimeError mentioning a disconnect / 'Cannot call "receive"'
            # means the socket is already closed: stop quietly.
            if "disconnect" in error_msg.lower() or "Cannot call \"receive\"" in error_msg:
                logger.info(f"WebSocket 连接已断开: conversation_id={conversation_id}")
            else:
                logger.error(f"处理消息时发生 RuntimeError: {e}", exc_info=True)
                await _send_error_if_connected(conversation_id, str(e))
            return
        except WebSocketDisconnect:
            logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
            return
        except Exception as e:
            logger.error(f"处理消息时发生错误: {e}", exc_info=True)
            await _send_error_if_connected(conversation_id, str(e))
            return


async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: str
):
    """
    WebSocket endpoint for real-time conversation.

    Authenticates the ``token`` query parameter, checks that the user exists,
    loads the conversation (creating it if missing) and verifies the caller
    owns it. Only then is the socket registered with ``manager``, acknowledged
    with a CONNECT frame, and handed to the message loop. A caller who does
    not own the conversation receives an ERROR frame and a 1008 close without
    ever being registered, so it cannot displace the owner's connection.

    Args:
        websocket: WebSocket connection
        conversation_id: conversation ID; also the key used in
            ``manager.active_connections``
    """
    user_id = await _authenticate_websocket(websocket)
    if user_id is None:
        return

    async for db in get_async_db():
        # Verify that the user still exists
        user = await db.get(UserModel, user_id)
        if not user:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在")
            return

        try:
            conversation = await db.get(Conversation, conversation_id)
            if conversation is None:
                # First connection for this ID: create the conversation
                conversation = Conversation(
                    id=conversation_id,
                    user_id=user_id,
                    started_at=datetime.now(timezone.utc),
                    status="active"
                )
                db.add(conversation)
                await db.commit()
            elif conversation.user_id != user_id:
                # Users may only access their own conversations. This is checked
                # BEFORE manager.connect(): registering first would replace the
                # owner's socket in manager.active_connections, and the cleanup
                # below would then remove the owner's entry.
                await websocket.accept()
                try:
                    await websocket.send_json({
                        "type": MessageType.ERROR,
                        "data": {"message": "无权访问此对话"},
                        "timestamp": datetime.now(timezone.utc).isoformat()
                    })
                except Exception as send_error:
                    # Best effort: the socket is closed right below either way
                    logger.warning(f"发送错误消息失败: {send_error}")
                await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话")
                return

            await manager.connect(websocket, conversation_id)
            # Send the connection acknowledgement
            await manager.send_message(conversation_id, {
                "type": MessageType.CONNECT,
                "conversation_id": conversation_id,
                "data": {"status": "connected"},
                "timestamp": datetime.now(timezone.utc).isoformat()
            })

            await _run_message_loop(websocket, conversation_id, conversation, db)
        except WebSocketDisconnect:
            logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
        except Exception as e:
            logger.error(f"WebSocket 端点发生错误: {e}", exc_info=True)
        finally:
            # Unregister only if our socket is still the registered one: a
            # reconnect under the same conversation_id may have replaced it, and
            # removing that newer entry would make send_message silently drop
            # every reply to the live client.
            if manager.active_connections.get(conversation_id) is websocket:
                await manager.disconnect(conversation_id)
async def process_user_message(
    conversation_id: str,
    user_message: str,
    conversation: Conversation,
    segment: Segment,
    db: AsyncSession,
    manager: ConnectionManager
) -> None:
    """
    Generate the agent's reply to one user message and send it to the client (async).

    Steps:
    1. Sync ``conversation.conversation_stage`` with the user's memoir state.
    2. Ask the conversation agent for one or more replies.
    3. Store the replies in ``segment.agent_response``, joined by blank lines.
    4. Send each reply as its own AGENT_RESPONSE frame, 0.5 s apart.

    Failures while generating, storing or sending the replies are logged and
    reported to the client as an ERROR frame rather than raised. Errors from
    loading the memoir state or committing the stage change propagate to the
    caller.

    Args:
        conversation_id: conversation ID
        user_message: text of the user's message
        conversation: conversation ORM object
        segment: segment ORM object that receives the agent's reply
        db: database session
        manager: connection manager used to reach the client

    Returns:
        None.
    """
    import asyncio

    agent = manager.conversation_agent
    state = await get_or_create_state(conversation.user_id, db)

    # Keep the conversation's stage aligned with the user's memoir state
    if conversation.conversation_stage != state.current_stage:
        conversation.conversation_stage = state.current_stage
        await db.commit()

    try:
        # The agent may answer with several messages
        responses = await agent.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=state
        )

        # Store the complete reply on the segment
        segment.agent_response = "\n\n".join(responses)
        await db.commit()

        total = len(responses)
        for i, response_text in enumerate(responses):
            await manager.send_message(conversation_id, {
                "type": MessageType.AGENT_RESPONSE,
                "conversation_id": conversation_id,
                "data": {"text": response_text, "index": i, "total": total},
                "timestamp": datetime.now(timezone.utc).isoformat()
            })
            # Short pause between consecutive messages to mimic typing
            if i < total - 1:
                await asyncio.sleep(0.5)
    except Exception as e:
        logger.error(f"处理用户消息失败: {e}", exc_info=True)
        # Only report the error while the connection is still registered
        if conversation_id in manager.active_connections:
            try:
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {"message": f"生成回应失败: {str(e)}"},
                    "timestamp": datetime.now(timezone.utc).isoformat()
                })
            except Exception as send_error:
                logger.warning(f"发送错误消息失败: {send_error}")
2026-01-07 11:56:40 +08:00
async def process_conversation_segments(conversation_id: str, db: AsyncSession):
    """
    Hand a finished conversation's unprocessed segments to the memoir pipeline.

    Called when a conversation ends. Most processing has already been done
    incrementally by Celery tasks. This function:
    - submits every segment of the conversation still marked
      ``processed=False`` directly to the ``process_memoir_segments`` Celery
      task, bypassing the background runner's debounce;
    - then flushes whatever the runner still has pending for the user.

    Does nothing if the conversation does not exist. A failure to submit the
    Celery task is logged (with traceback), not raised.

    Args:
        conversation_id: conversation ID
        db: database session
    """
    conversation = await db.get(Conversation, conversation_id)
    if not conversation:
        return

    # ``== False`` builds a SQL comparison on the column (``not`` would not)
    stmt = select(Segment).where(
        Segment.conversation_id == conversation_id,
        Segment.processed == False  # noqa: E712
    )
    result = await db.execute(stmt)
    segment_ids = [seg.id for seg in result.scalars().all()]

    if segment_ids:
        try:
            from tasks.memoir_tasks import process_memoir_segments
            process_memoir_segments.delay(conversation.user_id, segment_ids)
            logger.info(f"对话结束，提交 Celery 任务: conversation_id={conversation_id}, segments={len(segment_ids)}")
        except Exception as e:
            logger.error(f"提交 Celery 任务失败: {e}", exc_info=True)

    # NOTE(review): segments queued earlier via background_runner.queue_message()
    # may still be pending here and be submitted a second time by this flush;
    # confirm flush_pending() de-duplicates against the direct submission above.
    await manager.background_runner.flush_pending(conversation.user_id)
2026-01-07 11:56:40 +08:00