Files
life-echo/api/routers/websocket.py
penghanyuan 44bd478c1e agent init
2026-01-21 22:31:09 +01:00

304 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
WebSocket router: real-time conversation messaging
"""
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Dict
from fastapi import WebSocket, WebSocketDisconnect, HTTPException, Query
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from agents import ConversationAgent, MemoryAgent
from agents.memoir_processor import BackgroundTaskRunner
from database import get_async_db
from database.models import Conversation, Segment
from database.models import User as UserModel
from services.auth_service import verify_token
from services.memoir_state_service import get_or_create_state
from fastapi import HTTPException, status
class MessageType(str, Enum):
    """Message kinds carried in the ``type`` field of WebSocket payloads.

    The enum mixes in ``str``, so every member compares equal to its plain
    string value (e.g. ``MessageType.TEXT == "text"``).
    """

    # Connection acknowledgement.
    CONNECT = "connect"
    # NOTE(review): presumably a chunk of streamed audio; not handled in the visible code.
    AUDIO_CHUNK = "audio_chunk"
    # Text message.
    TEXT = "text"
    # NOTE(review): presumably a speech transcript; not handled in the visible code.
    TRANSCRIPT = "transcript"
    # Reply text produced by the conversation agent.
    AGENT_RESPONSE = "agent_response"
    # NOTE(review): presumably synthesized speech audio; not handled in the visible code.
    TTS_AUDIO = "tts_audio"
    # Request to end (and confirmation of ending) the conversation.
    END_CONVERSATION = "end_conversation"
    # NOTE(review): presumably a memoir change notification; not handled in the visible code.
    MEMOIR_UPDATE = "memoir_update"
    # Error report.
    ERROR = "error"
# Connection management
class ConnectionManager:
    """Registry of live WebSocket connections, keyed by conversation ID.

    Each registered conversation also gets its own ``ConversationAgent``.
    A single ``MemoryAgent`` and a single ``BackgroundTaskRunner`` are
    created once and held for the lifetime of the manager.
    """

    def __init__(self):
        # conversation_id -> accepted WebSocket
        self.active_connections: Dict[str, WebSocket] = {}
        # conversation_id -> agent created for that connection
        self.conversation_agents: Dict[str, ConversationAgent] = {}
        self.memory_agent = MemoryAgent()
        self.background_runner = BackgroundTaskRunner()

    async def connect(self, websocket: WebSocket, conversation_id: str):
        """Accept ``websocket`` and register it, with a fresh agent, under ``conversation_id``.

        An existing registration for the same ID is overwritten.
        """
        await websocket.accept()
        self.active_connections[conversation_id] = websocket
        self.conversation_agents[conversation_id] = ConversationAgent()

    def disconnect(self, conversation_id: str):
        """Forget the connection and agent registered for ``conversation_id``.

        The agent's ``clear_memory`` is called before the agent is dropped.
        Unknown IDs are ignored.
        """
        self.active_connections.pop(conversation_id, None)

        try:
            agent = self.conversation_agents[conversation_id]
        except KeyError:
            return
        agent.clear_memory(conversation_id)
        del self.conversation_agents[conversation_id]

    async def send_message(self, conversation_id: str, message: dict):
        """Send ``message`` as JSON to the conversation's socket; no-op if none is registered."""
        target = self.active_connections.get(conversation_id)
        if target is not None:
            await target.send_json(message)

    async def receive_message(self, conversation_id: str) -> dict:
        """Receive one JSON message from the conversation's socket.

        Raises:
            HTTPException: with status 404 when no connection is registered
                for ``conversation_id``.
        """
        if conversation_id not in self.active_connections:
            raise HTTPException(status_code=404, detail="Connection not found")
        return await self.active_connections[conversation_id].receive_json()
# Module-level manager instance used by the handlers defined in this module.
manager = ConnectionManager()
async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: str,
    token: str = Query(..., description="访问令牌")
):
    """
    WebSocket endpoint: handle one real-time conversation.

    Flow:
      1. Validate the access token and the user; on failure the socket is
         closed with policy-violation code 1008 and nothing is registered.
      2. Load the conversation and check ownership *before* registering the
         socket with the connection manager, so a caller who does not own the
         conversation can never replace the owner's live connection or agent.
      3. Register the connection, send a ``connect`` confirmation and create
         the conversation row if it does not exist yet.
      4. Loop over incoming messages until the client disconnects or sends
         ``end_conversation``.
      5. Always unregister the connection on the way out.

    Args:
        websocket: The WebSocket connection.
        conversation_id: Conversation ID.
        token: Access token (taken from the query string).
    """
    # Validate the JWT.
    payload = verify_token(token)
    if not payload or payload.get("type") != "access":
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌")
        return

    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容")
        return

    # Everything that needs the DB session runs inside this loop body.
    async for db in get_async_db():
        # Make sure the user exists.
        user = await db.get(UserModel, user_id)
        if not user:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在")
            return

        # Authorize first: users may only access their own conversations.
        # This must happen before manager.connect(), which overwrites any
        # existing registration for the same conversation_id.
        conversation = await db.get(Conversation, conversation_id)
        if conversation is not None and conversation.user_id != user_id:
            # Accept only so the client can be told why it is being rejected;
            # the socket is never registered with the manager.
            await websocket.accept()
            await websocket.send_json({
                "type": MessageType.ERROR,
                "data": {"message": "无权访问此对话"},
                "timestamp": datetime.now(timezone.utc).isoformat()
            })
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话")
            return

        await manager.connect(websocket, conversation_id)
        try:
            # Send the connection confirmation.
            await manager.send_message(conversation_id, {
                "type": MessageType.CONNECT,
                "conversation_id": conversation_id,
                "data": {"status": "connected"},
                "timestamp": datetime.now(timezone.utc).isoformat()
            })

            if conversation is None:
                # The conversation does not exist yet: create it.
                conversation = Conversation(
                    id=conversation_id,
                    user_id=user_id,
                    started_at=datetime.now(timezone.utc),
                    status="active"
                )
                db.add(conversation)
                await db.commit()

            # Main loop: process incoming messages.
            while True:
                try:
                    message = await websocket.receive_json()
                    msg_type = message.get("type")

                    if msg_type == MessageType.TEXT:
                        # Handle a text message; empty text is ignored.
                        text_message = message.get("data", {}).get("text", "")
                        if text_message:
                            # Persist the segment.
                            segment = Segment(
                                id=str(uuid.uuid4()),
                                conversation_id=conversation_id,
                                transcript_text=text_message,
                                processed=False
                            )
                            db.add(segment)
                            await db.commit()
                            await db.refresh(segment)
                            await manager.background_runner.queue_message(conversation.user_id, segment.id)

                            # Let the agent generate its reply.
                            await process_user_message(
                                conversation_id=conversation_id,
                                user_message=text_message,
                                conversation=conversation,
                                segment=segment,
                                db=db,
                                manager=manager
                            )

                    elif msg_type == MessageType.END_CONVERSATION:
                        # End the conversation.
                        conversation.status = "ended"
                        conversation.ended_at = datetime.now(timezone.utc)
                        await db.commit()

                        # Queue whatever segments are still unprocessed.
                        await process_conversation_segments(conversation_id, db)

                        await manager.send_message(conversation_id, {
                            "type": MessageType.END_CONVERSATION,
                            "conversation_id": conversation_id,
                            "data": {"status": "ended"},
                            "timestamp": datetime.now(timezone.utc).isoformat()
                        })
                        break

                    # Any other message type is ignored.

                except WebSocketDisconnect:
                    # Must be handled before the generic handler below:
                    # WebSocketDisconnect is an Exception subclass, and a
                    # disconnected peer cannot be sent an error message.
                    break
                except Exception as e:
                    # Report the failure to the client and keep serving.
                    # If this send itself fails, the error propagates and the
                    # connection is cleaned up by the ``finally`` below.
                    await manager.send_message(conversation_id, {
                        "type": MessageType.ERROR,
                        "data": {"message": str(e)},
                        "timestamp": datetime.now(timezone.utc).isoformat()
                    })
        finally:
            # Unregister on every exit path (disconnect, end of conversation,
            # unexpected error). Skip it if a newer connection has since been
            # registered under the same conversation_id.
            if manager.active_connections.get(conversation_id) is websocket:
                manager.disconnect(conversation_id)
async def process_user_message(
    conversation_id: str,
    user_message: str,
    conversation: Conversation,
    segment: Segment,
    db: AsyncSession,
    manager: ConnectionManager
) -> None:
    """
    Generate the agent's reply to a user message, persist it and send it.

    Steps:
      1. Look up the agent registered for the conversation; do nothing if
         there is none.
      2. Load (or create) the user's memoir state and copy its current stage
         onto the conversation when the two differ.
      3. Ask the agent for its responses (there may be several).
      4. Store the responses, joined by blank lines, on the segment.
      5. Send each response as a separate ``agent_response`` message, pausing
         0.5 s between consecutive messages.

    Args:
        conversation_id: Conversation ID.
        user_message: Text of the user's message.
        conversation: Conversation object.
        segment: Segment object that will receive the agent's reply.
        db: Database session.
        manager: Connection manager.

    Returns:
        None. Results are persisted through ``db`` and sent via ``manager``.
    """
    # Function-scope import: asyncio is only needed for the pacing sleep below.
    import asyncio

    agent = manager.conversation_agents.get(conversation_id)
    if agent is None:
        return

    state = await get_or_create_state(conversation.user_id, db)
    if conversation.conversation_stage != state.current_stage:
        conversation.conversation_stage = state.current_stage
        await db.commit()

    # Generate the reply (possibly several messages).
    # NOTE(review): this is a synchronous call inside a coroutine; if it does
    # network or model I/O it will block the event loop — confirm.
    responses = agent.generate_response_with_state(
        conversation_id=conversation_id,
        user_message=user_message,
        memoir_state=state
    )

    # Store the complete reply on the segment.
    segment.agent_response = "\n\n".join(responses)
    await db.commit()

    # Send the reply, one message per response.
    total = len(responses)
    for i, response_text in enumerate(responses):
        await manager.send_message(conversation_id, {
            "type": MessageType.AGENT_RESPONSE,
            "conversation_id": conversation_id,
            "data": {"text": response_text, "index": i, "total": total},
            "timestamp": datetime.now(timezone.utc).isoformat()
        })
        # Short pause between consecutive messages to mimic typing.
        if i < total - 1:
            await asyncio.sleep(0.5)
async def process_conversation_segments(conversation_id: str, db: AsyncSession):
    """
    Queue the conversation's still-unprocessed segments (called when it ends).

    Per the original author's note, most processing is already done
    incrementally by BackgroundTaskRunner; this only picks up messages that
    may have been missed at the end.

    Nothing happens when the conversation does not exist or has no
    unprocessed segments.

    Args:
        conversation_id: Conversation ID.
        db: Database session.
    """
    conversation = await db.get(Conversation, conversation_id)
    if not conversation:
        return

    # Select every segment of this conversation not yet marked as processed.
    pending_query = select(Segment).where(
        Segment.conversation_id == conversation_id,
        Segment.processed == False  # noqa: E712 - builds a SQL expression, not a Python truth test
    )
    pending = (await db.execute(pending_query)).scalars().all()

    # Hand each pending segment to the background runner's queue.
    # NOTE(review): the original comment says this does not wait for
    # processing to finish; that depends on queue_message — confirm.
    for pending_segment in pending:
        await manager.background_runner.queue_message(conversation.user_id, pending_segment.id)