2026-01-07 11:56:40 +08:00
|
|
|
|
"""
|
|
|
|
|
|
WebSocket 路由:实时对话通信
|
|
|
|
|
|
"""
|
|
|
|
|
|
import uuid
|
|
|
|
|
|
from datetime import datetime, timezone
|
|
|
|
|
|
from enum import Enum
|
|
|
|
|
|
from typing import Dict
|
|
|
|
|
|
|
2026-01-18 15:57:51 +08:00
|
|
|
|
from fastapi import WebSocket, WebSocketDisconnect, HTTPException, Query
|
2026-01-07 11:56:40 +08:00
|
|
|
|
from sqlalchemy import select
|
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
|
|
|
|
|
|
|
from agents import ConversationAgent, MemoryAgent
|
2026-01-21 22:31:03 +01:00
|
|
|
|
from agents.memoir_processor import BackgroundTaskRunner
|
2026-01-07 11:56:40 +08:00
|
|
|
|
from database import get_async_db
|
|
|
|
|
|
from database.models import Conversation, Segment
|
2026-01-18 15:57:51 +08:00
|
|
|
|
from database.models import User as UserModel
|
|
|
|
|
|
from services.auth_service import verify_token
|
2026-01-21 22:31:03 +01:00
|
|
|
|
from services.memoir_state_service import get_or_create_state
|
2026-01-18 15:57:51 +08:00
|
|
|
|
from fastapi import HTTPException, status
|
2026-01-07 11:56:40 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MessageType(str, Enum):
    """Message kinds carried in the ``type`` field of WebSocket JSON frames.

    Because the enum mixes in ``str``, each member compares equal to its plain
    string value (``MessageType.TEXT == "text"``), which is what lets the
    endpoint compare the raw ``type`` field of incoming JSON against members.
    """

    CONNECT = "connect"                    # server -> client: connection acknowledged
    AUDIO_CHUNK = "audio_chunk"            # NOTE(review): not referenced in the code shown here
    TEXT = "text"                          # client -> server: plain text message
    TRANSCRIPT = "transcript"              # NOTE(review): not referenced in the code shown here
    AGENT_RESPONSE = "agent_response"      # server -> client: one agent reply message
    TTS_AUDIO = "tts_audio"                # NOTE(review): not referenced in the code shown here
    END_CONVERSATION = "end_conversation"  # client request to end; server echoes it as the ack
    MEMOIR_UPDATE = "memoir_update"        # NOTE(review): not referenced in the code shown here
    ERROR = "error"                        # server -> client: error report
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 连接管理
|
|
|
|
|
|
class ConnectionManager:
    """Registry of live WebSocket connections and their conversation agents.

    State is keyed by ``conversation_id``: at most one socket and one
    ``ConversationAgent`` per conversation. A single ``MemoryAgent`` and a
    single ``BackgroundTaskRunner`` are created here and shared by every
    connection that uses this manager.
    """

    def __init__(self):
        self.active_connections: Dict[str, WebSocket] = {}
        self.conversation_agents: Dict[str, ConversationAgent] = {}
        self.memory_agent = MemoryAgent()
        self.background_runner = BackgroundTaskRunner()

    async def connect(self, websocket: WebSocket, conversation_id: str):
        """Accept ``websocket`` and register it with a fresh ConversationAgent.

        If the conversation already has a registered socket, it is replaced
        (the previous socket is not closed here).
        """
        await websocket.accept()
        self.active_connections[conversation_id] = websocket
        self.conversation_agents[conversation_id] = ConversationAgent()

    def disconnect(self, conversation_id: str, websocket: "WebSocket | None" = None):
        """Unregister a conversation's socket and agent, clearing the agent's memory.

        Args:
            conversation_id: Conversation whose registration is removed.
            websocket: Optional. When given, the registration is removed only
                if this is the socket currently registered for the
                conversation, so a stale handler cannot tear down a newer
                connection for the same conversation.
        """
        if websocket is not None and self.active_connections.get(conversation_id) is not websocket:
            return

        self.active_connections.pop(conversation_id, None)
        # Remove the agent before clearing it so a failing clear_memory()
        # cannot leave a stale agent registered.
        agent = self.conversation_agents.pop(conversation_id, None)
        if agent is not None:
            agent.clear_memory(conversation_id)

    async def send_message(self, conversation_id: str, message: dict):
        """Send ``message`` as JSON; silently does nothing if the conversation has no socket."""
        websocket = self.active_connections.get(conversation_id)
        if websocket is not None:
            await websocket.send_json(message)

    async def receive_message(self, conversation_id: str) -> dict:
        """Receive one JSON message from the conversation's socket.

        Raises:
            HTTPException: 404 if no socket is registered for the conversation.
        """
        websocket = self.active_connections.get(conversation_id)
        if websocket is None:
            raise HTTPException(status_code=404, detail="Connection not found")
        return await websocket.receive_json()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level singleton: every WebSocket handled by this module registers with,
# and looks up its agent through, this one shared ConnectionManager.
# NOTE(review): this state lives in process memory, so connections are not
# shared across multiple worker processes — confirm deployment assumptions.
manager = ConnectionManager()
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-18 15:57:51 +08:00
|
|
|
|
async def websocket_endpoint(
    websocket: WebSocket,
    conversation_id: str,
    token: str = Query(..., description="访问令牌")
):
    """
    WebSocket endpoint: handle a real-time conversation.

    Flow:
      1. Validate the JWT access token passed as a query parameter; close with
         1008 (policy violation) if it is invalid or carries no subject.
      2. Check that the user exists and load the conversation (creating it for
         this user if missing), rejecting users who do not own it. This runs
         *before* the socket is registered with the shared ``manager``, so a
         rejected client can never replace the owner's registered socket/agent.
      3. Register the socket, acknowledge with a ``connect`` message, then serve
         ``text`` / ``end_conversation`` messages until the conversation ends or
         the client disconnects. The socket is always unregistered afterwards.

    Args:
        websocket: WebSocket connection
        conversation_id: conversation ID
        token: access token (taken from the query string)
    """
    # Validate the JWT access token.
    payload = verify_token(token)
    if not payload or payload.get("type") != "access":
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的认证令牌")
        return

    user_id = payload.get("sub")
    if not user_id:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无效的令牌内容")
        return

    # No return/break inside this loop: letting it run to completion lets the
    # session generator resume past its yield and run any cleanup it has there.
    # NOTE(review): assumes get_async_db() yields exactly one session.
    async for db in get_async_db():
        conversation = await _authorize_conversation(websocket, conversation_id, user_id, db)
        if conversation is not None:
            await _run_conversation(websocket, conversation_id, conversation, db)


async def _authorize_conversation(
    websocket: WebSocket,
    conversation_id: str,
    user_id: str,
    db: AsyncSession
) -> "Conversation | None":
    """Return the user's conversation (creating it if missing), or close the socket and return None."""
    # The token's user must exist.
    user = await db.get(UserModel, user_id)
    if not user:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="用户不存在")
        return None

    conversation = await db.get(Conversation, conversation_id)
    if conversation is None:
        # Unknown conversation: create it for this user.
        conversation = Conversation(
            id=conversation_id,
            user_id=user_id,
            started_at=datetime.now(timezone.utc),
            status="active"
        )
        db.add(conversation)
        await db.commit()
        return conversation

    # Users may only access their own conversations.
    if conversation.user_id != user_id:
        # Accept first so the client can receive the error frame before the close.
        await websocket.accept()
        await websocket.send_json({
            "type": MessageType.ERROR,
            "data": {"message": "无权访问此对话"},
            "timestamp": datetime.now(timezone.utc).isoformat()
        })
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="无权访问此对话")
        return None

    return conversation


async def _run_conversation(
    websocket: WebSocket,
    conversation_id: str,
    conversation: Conversation,
    db: AsyncSession
):
    """Register the socket and serve messages until the conversation ends or the client disconnects."""
    await manager.connect(websocket, conversation_id)
    try:
        # Connection acknowledgement.
        await manager.send_message(conversation_id, {
            "type": MessageType.CONNECT,
            "conversation_id": conversation_id,
            "data": {"status": "connected"},
            "timestamp": datetime.now(timezone.utc).isoformat()
        })

        while True:
            try:
                message = await websocket.receive_json()
                msg_type = message.get("type")

                if msg_type == MessageType.TEXT:
                    await _handle_text_message(conversation_id, conversation, db, message)
                elif msg_type == MessageType.END_CONVERSATION:
                    await _end_conversation(conversation_id, conversation, db)
                    break
                # Any other message type is ignored.
            except WebSocketDisconnect:
                # The client is gone: do not treat this as a per-message error
                # (sending an error frame on a closed socket would fail).
                raise
            except Exception as e:
                # NOTE(review): if the failure came from a DB flush/commit, the
                # session may need a rollback before it can be reused — confirm.
                await manager.send_message(conversation_id, {
                    "type": MessageType.ERROR,
                    "data": {"message": str(e)},
                    "timestamp": datetime.now(timezone.utc).isoformat()
                })
    except WebSocketDisconnect:
        pass
    finally:
        # Always unregister, but only if the manager still points at *this*
        # socket; a newer connection for the same conversation may own it now.
        if manager.active_connections.get(conversation_id) is websocket:
            manager.disconnect(conversation_id)


async def _handle_text_message(
    conversation_id: str,
    conversation: Conversation,
    db: AsyncSession,
    message: dict
):
    """Persist a text message as a segment, queue it for background processing, and reply via the agent."""
    text_message = message.get("data", {}).get("text", "")
    if not text_message:
        return

    segment = Segment(
        id=str(uuid.uuid4()),
        conversation_id=conversation_id,
        transcript_text=text_message,
        processed=False
    )
    db.add(segment)
    await db.commit()
    await db.refresh(segment)
    await manager.background_runner.queue_message(conversation.user_id, segment.id)

    # Agent reply.
    await process_user_message(
        conversation_id=conversation_id,
        user_message=text_message,
        conversation=conversation,
        segment=segment,
        db=db,
        manager=manager
    )


async def _end_conversation(conversation_id: str, conversation: Conversation, db: AsyncSession):
    """Mark the conversation ended, queue its remaining segments, and acknowledge to the client."""
    conversation.status = "ended"
    conversation.ended_at = datetime.now(timezone.utc)
    await db.commit()

    await process_conversation_segments(conversation_id, db)

    await manager.send_message(conversation_id, {
        "type": MessageType.END_CONVERSATION,
        "conversation_id": conversation_id,
        "data": {"status": "ended"},
        "timestamp": datetime.now(timezone.utc).isoformat()
    })
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def process_user_message(
    conversation_id: str,
    user_message: str,
    conversation: Conversation,
    segment: Segment,
    db: AsyncSession,
    manager: ConnectionManager
) -> None:
    """
    Generate the agent's reply to a user message and send it to the client.

    Steps:
      * sync ``conversation.conversation_stage`` with the user's memoir state;
      * ask the conversation's agent for a list of reply messages;
      * store them, joined by blank lines, on ``segment.agent_response``;
      * send each reply as its own ``agent_response`` frame (carrying its
        ``index`` and the ``total`` count), pausing 0.5 s between frames.

    Does nothing if no agent is registered for ``conversation_id``.

    Args:
        conversation_id: conversation ID
        user_message: the user's message text
        conversation: the conversation row
        segment: the segment row that receives the agent reply
        db: database session
        manager: connection manager holding the socket and agent

    Returns:
        None.
    """
    import asyncio

    agent = manager.conversation_agents.get(conversation_id)
    if agent is None:
        # NOTE(review): the client receives no reply in this case.
        return

    # Keep the conversation's stage in step with the user's memoir state.
    state = await get_or_create_state(conversation.user_id, db)
    if conversation.conversation_stage != state.current_stage:
        conversation.conversation_stage = state.current_stage
        await db.commit()

    # Generate the reply (possibly several messages).
    # NOTE(review): this call is synchronous; if it performs blocking I/O (for
    # example a model request) it stalls the event loop — consider
    # asyncio.to_thread once the agent's thread-safety is confirmed.
    responses = agent.generate_response_with_state(
        conversation_id=conversation_id,
        user_message=user_message,
        memoir_state=state
    )

    # Persist the full reply on the segment.
    segment.agent_response = "\n\n".join(responses)
    await db.commit()

    # Send each reply separately, with a short pause to mimic typing.
    total = len(responses)
    for index, response_text in enumerate(responses):
        await manager.send_message(conversation_id, {
            "type": MessageType.AGENT_RESPONSE,
            "conversation_id": conversation_id,
            "data": {"text": response_text, "index": index, "total": total},
            "timestamp": datetime.now(timezone.utc).isoformat()
        })
        if index < total - 1:
            await asyncio.sleep(0.5)
|
2026-01-07 11:56:40 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def process_conversation_segments(conversation_id: str, db: AsyncSession):
    """
    Queue a conversation's still-unprocessed segments for memoir processing.

    Called when a conversation ends. Segments are normally handed to the
    BackgroundTaskRunner one by one as they arrive; this catches any that are
    not yet marked ``processed``. They are queued in creation order via
    ``manager.background_runner.queue_message``, which the original design
    intends to enqueue without waiting for processing to finish.

    NOTE(review): a segment that was queued on arrival but not yet processed
    is queued again here — confirm BackgroundTaskRunner tolerates duplicates.

    Args:
        conversation_id: conversation ID
        db: database session
    """
    conversation = await db.get(Conversation, conversation_id)
    if conversation is None:
        return

    # All unprocessed segments, oldest first.
    stmt = (
        select(Segment)
        .where(
            Segment.conversation_id == conversation_id,
            Segment.processed.is_(False)
        )
        .order_by(Segment.created_at)
    )
    segments = (await db.execute(stmt)).scalars().all()

    for seg in segments:
        await manager.background_runner.queue_message(conversation.user_id, seg.id)
|
2026-01-07 11:56:40 +08:00
|
|
|
|
|