Files
life-echo/api/routers/websocket.py
iammm0 be3532d4b1 feat: 扩展后端WebSocket处理
- 优化api/routers/websocket.py

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-10 17:09:48 +08:00

509 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
WebSocket 路由:实时对话通信
支持异步 Agent 调用和 Redis 会话存储
"""
import logging
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Dict
from fastapi import WebSocket, WebSocketDisconnect, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from agents import ConversationAgent, MemoryAgent
from agents.memoir_processor import BackgroundTaskRunner
from database import get_async_db
from database.models import Conversation, Segment
from database.models import User as UserModel
from services.auth_service import verify_token
from services.memoir_state_service import get_or_create_state
from services.asr_service import asr_service
from fastapi import HTTPException, status
logger = logging.getLogger(__name__)
class MessageType(str, Enum):
    """Values of the ``type`` field carried by WebSocket messages.

    Members are also ``str`` instances, so they compare equal to the raw
    strings found in decoded JSON payloads.
    """

    CONNECT = "connect"
    AUDIO_CHUNK = "audio_chunk"
    # One complete voice message (similar to a WeChat voice note).
    AUDIO_MESSAGE = "audio_message"
    # Transcription only: nothing is persisted and no agent is triggered;
    # used by the client's "convert to text" send flow.
    TRANSCRIBE_ONLY = "transcribe_only"
    # Plain text message.
    TEXT = "text"
    # Speech-to-text result.
    TRANSCRIPT = "transcript"
    AGENT_RESPONSE = "agent_response"
    TTS_AUDIO = "tts_audio"
    END_CONVERSATION = "end_conversation"
    MEMOIR_UPDATE = "memoir_update"
    ERROR = "error"
# 连接管理
class ConnectionManager:
    """Registry of live WebSocket connections keyed by conversation ID.

    The manager also owns the agent instances and the background task runner
    that the rest of this module reaches through it.
    """

    def __init__(self):
        # conversation_id -> WebSocket currently registered for it.
        self.active_connections: Dict[str, WebSocket] = {}
        # ConversationAgent is stateless (session state lives in Redis),
        # so a single instance is reused for every connection.
        self.conversation_agent = ConversationAgent()
        self.memory_agent = MemoryAgent()
        self.background_runner = BackgroundTaskRunner()

    async def connect(self, websocket: WebSocket, conversation_id: str) -> None:
        """Accept the handshake and register ``websocket`` under ``conversation_id``.

        An existing registration for the same ID is replaced.
        """
        await websocket.accept()
        self.active_connections[conversation_id] = websocket

    async def disconnect(self, conversation_id: str) -> None:
        """Drop the registration for ``conversation_id``; no-op if absent.

        The agent's conversation memory is intentionally not cleared here so
        that a session can be resumed later.
        """
        self.active_connections.pop(conversation_id, None)

    async def send_message(self, conversation_id: str, message: dict) -> None:
        """Send ``message`` as JSON to the socket registered for the conversation.

        Best-effort: when no socket is registered nothing happens, and a
        failed send is logged and swallowed rather than raised.
        """
        websocket = self.active_connections.get(conversation_id)
        if websocket is None:
            return
        try:
            # Raises if the underlying connection has already been closed.
            await websocket.send_json(message)
        except Exception as e:
            logger.warning(f"发送消息失败 (conversation_id={conversation_id}): {e}")
            # Remove the dead socket, but only if it is still the registered
            # one: a reconnect may have replaced it while the send was awaited.
            if self.active_connections.get(conversation_id) is websocket:
                del self.active_connections[conversation_id]

    async def receive_message(self, conversation_id: str) -> dict:
        """Receive one JSON message from the socket registered for the conversation.

        Raises:
            HTTPException: 404 when no socket is registered under the ID.
        """
        websocket = self.active_connections.get(conversation_id)
        if websocket is None:
            raise HTTPException(status_code=404, detail="Connection not found")
        return await websocket.receive_json()
# Module-level ConnectionManager instance used by websocket_endpoint and
# process_conversation_segments in this module.
manager = ConnectionManager()
def _utc_timestamp() -> str:
    """Return the current UTC time as an ISO-8601 string for message envelopes."""
    return datetime.now(timezone.utc).isoformat()


async def _send_error(conversation_id: str, text: str, code: str = "") -> None:
    """Send an ERROR frame through the manager; ``code`` is added only when given."""
    payload = {"message": text}
    if code:
        payload["code"] = code
    await manager.send_message(conversation_id, {
        "type": MessageType.ERROR,
        "data": payload,
        "timestamp": _utc_timestamp(),
    })


async def _reject_unauthorized(websocket: WebSocket) -> None:
    """Report "not your conversation" and close, without registering the socket."""
    # Accept directly instead of via manager.connect so the caller's socket
    # never enters the registry under someone else's conversation ID.
    await websocket.accept()
    try:
        await websocket.send_json({
            "type": MessageType.ERROR,
            "data": {"message": "无权访问此对话"},
            "timestamp": _utc_timestamp(),
        })
    except Exception as send_error:
        # Best effort: the connection is closed below regardless.
        logger.warning(f"发送错误消息失败: {send_error}")
    try:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话")
    except Exception as e:
        logger.error(f"WebSocket 端点发生错误: {e}", exc_info=True)


async def _quota_allows_message(conversation_id: str, user: UserModel, user_id, db: AsyncSession) -> bool:
    """Check the user's segment quota; on refusal notify the client and return False."""
    # Imported at call time, as the original code did for the quota helpers.
    from routers.quota import get_segment_count, check_can_send_message

    seg_count = await get_segment_count(user_id, db)
    can_send, quota_msg = check_can_send_message(user.subscription_type, seg_count)
    if not can_send:
        await _send_error(conversation_id, quota_msg, code="QUOTA_EXCEEDED")
        return False
    return True


async def _store_segment(conversation_id: str, conversation: Conversation, db: AsyncSession, **fields) -> Segment:
    """Persist a new unprocessed Segment and queue it with the background runner."""
    segment = Segment(
        id=str(uuid.uuid4()),
        conversation_id=conversation_id,
        processed=False,
        **fields,
    )
    db.add(segment)
    await db.commit()
    await db.refresh(segment)
    await manager.background_runner.queue_message(conversation.user_id, segment.id)
    return segment


async def _handle_text(message: dict, conversation_id: str, conversation: Conversation,
                       user: UserModel, user_id, db: AsyncSession) -> None:
    """Handle a TEXT message: quota check, persist the segment, run the agent."""
    text_message = message.get("data", {}).get("text", "")
    if not text_message:
        # Empty text is ignored without a reply.
        return
    if not await _quota_allows_message(conversation_id, user, user_id, db):
        return
    segment = await _store_segment(
        conversation_id, conversation, db, transcript_text=text_message
    )
    await process_user_message(
        conversation_id=conversation_id,
        user_message=text_message,
        conversation=conversation,
        segment=segment,
        db=db,
        manager=manager,
    )


async def _handle_audio_message(message: dict, conversation_id: str, conversation: Conversation,
                                user: UserModel, user_id, db: AsyncSession) -> None:
    """Handle a complete audio message: transcribe, echo the transcript, persist, run the agent."""
    data = message.get("data", {})
    audio_base64 = data.get("audio_base64", "")
    audio_duration = data.get("duration", 0)
    if not audio_base64:
        # A message without audio payload is ignored without a reply.
        return
    if not await _quota_allows_message(conversation_id, user, user_id, db):
        return

    logger.info(f"收到音频消息,时长: {audio_duration}s")
    try:
        # 1. Speech-to-text.
        transcript_text = await asr_service.transcribe(audio_base64)
        logger.info(f"ASR 转写结果: {transcript_text}")

        # 2. Send the transcript back to the client.
        await manager.send_message(conversation_id, {
            "type": MessageType.TRANSCRIPT,
            "conversation_id": conversation_id,
            "data": {
                "text": transcript_text,
                "audio_duration": audio_duration,
            },
            "timestamp": _utc_timestamp(),
        })

        # 3. Persist the segment. Only a duration marker is stored in
        #    audio_url, flagging the segment as an audio message.
        segment = await _store_segment(
            conversation_id,
            conversation,
            db,
            transcript_text=transcript_text,
            audio_url=f"audio:{audio_duration}s",
        )

        # 4. Run the agent on the transcript, unless transcription failed
        #    (empty text or text starting with the failure marker).
        if transcript_text and not transcript_text.startswith("转写失败"):
            await process_user_message(
                conversation_id=conversation_id,
                user_message=transcript_text,
                conversation=conversation,
                segment=segment,
                db=db,
                manager=manager,
            )
        else:
            await _send_error(conversation_id, "语音转写失败,请重试或使用文字输入")
    except Exception as e:
        logger.error(f"处理音频消息失败: {e}", exc_info=True)
        await _send_error(conversation_id, f"处理音频消息失败: {str(e)}")


async def _handle_transcribe_only(message: dict, conversation_id: str) -> None:
    """Handle TRANSCRIBE_ONLY: return the transcript; nothing is stored and no agent runs."""
    audio_base64 = message.get("data", {}).get("audio_base64", "")
    if not audio_base64:
        await _send_error(conversation_id, "缺少 audio_base64")
        return
    try:
        transcript_text = await asr_service.transcribe(audio_base64)
        await manager.send_message(conversation_id, {
            "type": MessageType.TRANSCRIPT,
            "conversation_id": conversation_id,
            "data": {"text": transcript_text or ""},
            "timestamp": _utc_timestamp(),
        })
    except Exception as e:
        logger.error(f"仅转写失败: {e}", exc_info=True)
        await _send_error(conversation_id, f"转写失败: {str(e)}")


async def _handle_end_conversation(conversation_id: str, conversation: Conversation, db: AsyncSession) -> None:
    """Mark the conversation ended, hand its segments to processing, confirm to the client."""
    conversation.status = "ended"
    conversation.ended_at = datetime.now(timezone.utc)
    await db.commit()

    await process_conversation_segments(conversation_id, db)

    await manager.send_message(conversation_id, {
        "type": MessageType.END_CONVERSATION,
        "conversation_id": conversation_id,
        "data": {"status": "ended"},
        "timestamp": _utc_timestamp(),
    })


async def _run_message_loop(websocket: WebSocket, conversation_id: str, conversation: Conversation,
                            user: UserModel, user_id, db: AsyncSession) -> None:
    """Receive and dispatch client messages until the conversation ends or any error occurs.

    Unknown message types are ignored. Every exception ends the loop; errors
    other than a disconnect are also reported to the client.
    """
    while True:
        try:
            message = await websocket.receive_json()
            msg_type = message.get("type")

            if msg_type == MessageType.TEXT:
                await _handle_text(message, conversation_id, conversation, user, user_id, db)
            elif msg_type == MessageType.AUDIO_MESSAGE:
                await _handle_audio_message(message, conversation_id, conversation, user, user_id, db)
            elif msg_type == MessageType.TRANSCRIBE_ONLY:
                await _handle_transcribe_only(message, conversation_id)
            elif msg_type == MessageType.END_CONVERSATION:
                await _handle_end_conversation(conversation_id, conversation, db)
                return
        except RuntimeError as e:
            # A RuntimeError whose text mentions a disconnect is treated as a
            # closed connection rather than a failure.
            error_msg = str(e)
            if "disconnect" in error_msg.lower() or "Cannot call \"receive\"" in error_msg:
                logger.info(f"WebSocket 连接已断开: conversation_id={conversation_id}")
            else:
                logger.error(f"处理消息时发生 RuntimeError: {e}", exc_info=True)
                # send_message is a no-op when the socket is no longer registered.
                await _send_error(conversation_id, str(e))
            return
        except WebSocketDisconnect:
            logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
            return
        except Exception as e:
            logger.error(f"处理消息时发生错误: {e}", exc_info=True)
            await _send_error(conversation_id, str(e))
            return


async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: str
):
    """
    WebSocket endpoint for a real-time conversation.

    Flow:
      1. Read the ``token`` query parameter and verify it as an access token.
         Failures close the socket with code 1008 before it is accepted.
      2. Load the user, then the conversation. A conversation that belongs
         to a different user is rejected with an ERROR frame and a 1008
         close, and the socket is never registered with the manager.
      3. Register the socket, send the CONNECT confirmation, create the
         conversation row if it does not exist yet, and run the message loop.
      4. On exit, unregister the socket.

    Args:
        websocket: The incoming WebSocket connection.
        conversation_id: ID of the conversation to attach to (created on
            first use when it does not exist).
    """
    # The access token is passed as a query parameter.
    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="缺少访问令牌")
        return

    # Verify the JWT and require an access-type token.
    payload = verify_token(token)
    if not payload or payload.get("type") != "access":
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌")
        return

    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容")
        return

    async for db in get_async_db():
        # The user named by the token must exist.
        user = await db.get(UserModel, user_id)
        if not user:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在")
            return

        # Check ownership BEFORE registering the socket: registration
        # replaces whatever socket is stored under this conversation ID, so
        # registering first would let any authenticated user displace the
        # owner's connection and receive frames addressed to it.
        conversation = await db.get(Conversation, conversation_id)
        if conversation and conversation.user_id != user_id:
            await _reject_unauthorized(websocket)
            return

        await manager.connect(websocket, conversation_id)
        try:
            # Connection confirmation.
            await manager.send_message(conversation_id, {
                "type": MessageType.CONNECT,
                "conversation_id": conversation_id,
                "data": {"status": "connected"},
                "timestamp": _utc_timestamp(),
            })

            if not conversation:
                # First use of this ID: create the conversation for the user.
                conversation = Conversation(
                    id=conversation_id,
                    user_id=user_id,
                    started_at=datetime.now(timezone.utc),
                    status="active"
                )
                db.add(conversation)
                await db.commit()

            await _run_message_loop(websocket, conversation_id, conversation, user, user_id, db)
        except WebSocketDisconnect:
            logger.info(f"WebSocket 断开连接: conversation_id={conversation_id}")
        except Exception as e:
            logger.error(f"WebSocket 端点发生错误: {e}", exc_info=True)
        finally:
            # Unregister only if the registry still points at THIS socket, so
            # a stale handler cannot remove a newer reconnect's registration.
            if manager.active_connections.get(conversation_id) is websocket:
                await manager.disconnect(conversation_id)
async def process_user_message(
    conversation_id: str,
    user_message: str,
    conversation: Conversation,
    segment: Segment,
    db: AsyncSession,
    manager: ConnectionManager
) -> None:
    """
    Generate the agent's reply to a user message and deliver it (async).

    The conversation's stage is first synced to the user's memoir state.
    The agent may return several reply messages: they are stored on the
    segment joined by blank lines, then sent to the client one at a time.
    Any failure during generation, storage or delivery is logged and
    reported to the client as an ERROR frame instead of being raised.

    Args:
        conversation_id: Conversation ID.
        user_message: Text of the user's message.
        conversation: Conversation ORM object.
        segment: Segment ORM object that receives the agent response.
        db: Database session.
        manager: Connection manager used for the agent and for sending.

    Returns:
        None.
    """
    import asyncio as _asyncio

    agent = manager.conversation_agent
    state = await get_or_create_state(conversation.user_id, db)

    # Keep the stored conversation stage in step with the memoir state.
    if conversation.conversation_stage != state.current_stage:
        conversation.conversation_stage = state.current_stage
        await db.commit()

    try:
        # Generate the reply asynchronously (possibly several messages).
        responses = await agent.generate_response_with_state(
            conversation_id=conversation_id,
            user_message=user_message,
            memoir_state=state
        )

        # Store the complete reply on the segment.
        segment.agent_response = "\n\n".join(responses)
        await db.commit()

        total = len(responses)
        for i, response_text in enumerate(responses):
            await manager.send_message(conversation_id, {
                "type": MessageType.AGENT_RESPONSE,
                "conversation_id": conversation_id,
                "data": {"text": response_text, "index": i, "total": total},
                "timestamp": datetime.now(timezone.utc).isoformat()
            })
            # Short pause between messages to imitate typing.
            if i < total - 1:
                await _asyncio.sleep(0.5)
    except Exception as e:
        logger.error(f"处理用户消息失败: {e}", exc_info=True)
        # Report the error only while the connection is still registered.
        if conversation_id in manager.active_connections:
            try:
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {"message": f"生成回应失败: {str(e)}"},
                    "timestamp": datetime.now(timezone.utc).isoformat()
                })
            except Exception as send_error:
                logger.warning(f"发送错误消息失败: {send_error}")
async def process_conversation_segments(conversation_id: str, db: AsyncSession):
    """
    Submit a conversation's unprocessed segments for memoir processing.

    Called when a conversation ends. Most processing has already been done
    incrementally by Celery tasks; this submits whatever is still
    unprocessed in one task, bypassing the debounce, and then flushes the
    background runner's pending work for the user.

    Submission is skipped when there are no unprocessed segments or when the
    user's chapter quota does not allow another organize task. The flush
    happens in every case where the conversation exists.

    Args:
        conversation_id: Conversation ID.
        db: Database session.
    """
    conversation = await db.get(Conversation, conversation_id)
    if not conversation:
        return

    # All segments of this conversation that are still unprocessed.
    pending_query = select(Segment).where(
        Segment.conversation_id == conversation_id,
        Segment.processed == False  # noqa: E712 - builds a SQL expression
    )
    pending_segments = (await db.execute(pending_query)).scalars().all()

    should_submit = bool(pending_segments)
    if should_submit:
        # The free tier allows only one organized chapter; check before submitting.
        from routers.quota import get_chapter_count, check_can_submit_organize

        user = await db.get(UserModel, conversation.user_id)
        if user:
            chapter_count = await get_chapter_count(user.id, db)
            can_submit, _ = check_can_submit_organize(user.subscription_type, chapter_count)
            if not can_submit:
                logger.info(
                    f"用户 {user.id} 章节配额已用尽,跳过提交整理任务: conversation_id={conversation_id}"
                )
                should_submit = False

    if should_submit:
        # Submit straight to Celery (no debounce).
        segment_ids = [seg.id for seg in pending_segments]
        try:
            from tasks.memoir_tasks import process_memoir_segments
            process_memoir_segments.delay(conversation.user_id, segment_ids)
            logger.info(f"对话结束,提交 Celery 任务: conversation_id={conversation_id}, segments={len(segment_ids)}")
        except Exception as e:
            # Keep the traceback, like the other error logs in this module.
            logger.error(f"提交 Celery 任务失败: {e}", exc_info=True)

    # Always flush whatever the background runner still holds for this user.
    await manager.background_runner.flush_pending(conversation.user_id)