feat: 添加Redis支持和Celery任务处理
- 新增Redis服务模块用于会话状态存储和缓存 - 集成Celery用于后台任务处理 - 更新Docker Compose配置以支持开发环境 - 优化API以支持异步调用和Redis会话存储 - 更新文档以反映新的开发环境配置和使用方法
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
"""
|
||||
WebSocket 路由:实时对话通信
|
||||
支持异步 Agent 调用和 Redis 会话存储
|
||||
"""
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
@@ -19,6 +21,8 @@ from services.auth_service import verify_token
|
||||
from services.memoir_state_service import get_or_create_state
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
"""WebSocket 消息类型"""
|
||||
@@ -39,7 +43,8 @@ class ConnectionManager:
|
||||
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.conversation_agents: Dict[str, ConversationAgent] = {}
|
||||
# ConversationAgent 现在是无状态的(会话存储在 Redis),可以复用
|
||||
self.conversation_agent = ConversationAgent()
|
||||
self.memory_agent = MemoryAgent()
|
||||
self.background_runner = BackgroundTaskRunner()
|
||||
|
||||
@@ -47,15 +52,13 @@ class ConnectionManager:
|
||||
"""建立连接"""
|
||||
await websocket.accept()
|
||||
self.active_connections[conversation_id] = websocket
|
||||
self.conversation_agents[conversation_id] = ConversationAgent()
|
||||
|
||||
def disconnect(self, conversation_id: str):
|
||||
async def disconnect(self, conversation_id: str):
|
||||
"""断开连接"""
|
||||
if conversation_id in self.active_connections:
|
||||
del self.active_connections[conversation_id]
|
||||
if conversation_id in self.conversation_agents:
|
||||
self.conversation_agents[conversation_id].clear_memory(conversation_id)
|
||||
del self.conversation_agents[conversation_id]
|
||||
# 清除 Redis 中的会话记忆(可选,也可以保留用于恢复)
|
||||
# await self.conversation_agent.clear_memory(conversation_id)
|
||||
|
||||
async def send_message(self, conversation_id: str, message: dict):
|
||||
"""发送消息"""
|
||||
@@ -198,10 +201,10 @@ async def websocket_endpoint(
|
||||
})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(conversation_id)
|
||||
await manager.disconnect(conversation_id)
|
||||
break
|
||||
except Exception as e:
|
||||
manager.disconnect(conversation_id)
|
||||
await manager.disconnect(conversation_id)
|
||||
raise
|
||||
|
||||
|
||||
@@ -214,7 +217,7 @@ async def process_user_message(
|
||||
manager: ConnectionManager
|
||||
) -> None:
|
||||
"""
|
||||
处理用户消息,生成Agent回应
|
||||
处理用户消息,生成Agent回应(异步版本)
|
||||
|
||||
Args:
|
||||
conversation_id: 对话ID
|
||||
@@ -227,24 +230,26 @@ async def process_user_message(
|
||||
Returns:
|
||||
更新后的对话阶段
|
||||
"""
|
||||
agent = manager.conversation_agents.get(conversation_id)
|
||||
if agent:
|
||||
state = await get_or_create_state(conversation.user_id, db)
|
||||
import asyncio as _asyncio
|
||||
|
||||
agent = manager.conversation_agent
|
||||
state = await get_or_create_state(conversation.user_id, db)
|
||||
|
||||
if conversation.conversation_stage != state.current_stage:
|
||||
conversation.conversation_stage = state.current_stage
|
||||
await db.commit()
|
||||
if conversation.conversation_stage != state.current_stage:
|
||||
conversation.conversation_stage = state.current_stage
|
||||
await db.commit()
|
||||
|
||||
# 获取已聊话题(保留老逻辑用于提示)
|
||||
stmt_segments = select(Segment).where(
|
||||
Segment.conversation_id == conversation_id
|
||||
).order_by(Segment.created_at)
|
||||
result_segments = await db.execute(stmt_segments)
|
||||
previous_segments = result_segments.scalars().all()
|
||||
covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category]
|
||||
|
||||
# 生成回应(可能是多条消息)
|
||||
responses = agent.generate_response_with_state(
|
||||
# 获取已聊话题(保留老逻辑用于提示)
|
||||
stmt_segments = select(Segment).where(
|
||||
Segment.conversation_id == conversation_id
|
||||
).order_by(Segment.created_at)
|
||||
result_segments = await db.execute(stmt_segments)
|
||||
previous_segments = result_segments.scalars().all()
|
||||
covered_topics = [seg.topic_category for seg in previous_segments if seg.topic_category]
|
||||
|
||||
try:
|
||||
# 异步生成回应(可能是多条消息)
|
||||
responses = await agent.generate_response_with_state(
|
||||
conversation_id=conversation_id,
|
||||
user_message=user_message,
|
||||
memoir_state=state
|
||||
@@ -255,7 +260,6 @@ async def process_user_message(
|
||||
await db.commit()
|
||||
|
||||
# 发送 Agent 回应(支持多条消息)
|
||||
import asyncio as _asyncio
|
||||
for i, response_text in enumerate(responses):
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.AGENT_RESPONSE,
|
||||
@@ -266,6 +270,14 @@ async def process_user_message(
|
||||
# 多条消息之间稍作间隔,模拟打字效果
|
||||
if i < len(responses) - 1:
|
||||
await _asyncio.sleep(0.5)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理用户消息失败: {e}")
|
||||
await manager.send_message(conversation_id, {
|
||||
"type": MessageType.ERROR,
|
||||
"data": {"message": f"生成回应失败: {str(e)}"},
|
||||
"timestamp": datetime.now(timezone.utc).isoformat()
|
||||
})
|
||||
|
||||
return
|
||||
|
||||
@@ -274,8 +286,8 @@ async def process_conversation_segments(conversation_id: str, db: AsyncSession):
|
||||
"""
|
||||
处理对话段落,生成章节(对话结束时调用)
|
||||
|
||||
注意:大部分处理已通过 BackgroundTaskRunner 增量完成
|
||||
这里只处理可能遗漏的最后几条消息
|
||||
注意:大部分处理已通过 Celery 任务增量完成
|
||||
这里立即提交所有待处理的段落到 Celery
|
||||
|
||||
Args:
|
||||
conversation_id: 对话 ID
|
||||
@@ -295,9 +307,19 @@ async def process_conversation_segments(conversation_id: str, db: AsyncSession):
|
||||
segments = result.scalars().all()
|
||||
|
||||
if not segments:
|
||||
# 没有未处理的段落,直接 flush 待处理任务
|
||||
await manager.background_runner.flush_pending(conversation.user_id)
|
||||
return
|
||||
|
||||
# 将未处理的段落加入后台任务队列(不等待完成,避免阻塞)
|
||||
for seg in segments:
|
||||
await manager.background_runner.queue_message(conversation.user_id, seg.id)
|
||||
# 将未处理的段落直接提交到 Celery(不通过去抖)
|
||||
segment_ids = [seg.id for seg in segments]
|
||||
try:
|
||||
from tasks.memoir_tasks import process_memoir_segments
|
||||
process_memoir_segments.delay(conversation.user_id, segment_ids)
|
||||
logger.info(f"对话结束,提交 Celery 任务: conversation_id={conversation_id}, segments={len(segment_ids)}")
|
||||
except Exception as e:
|
||||
logger.error(f"提交 Celery 任务失败: {e}")
|
||||
|
||||
# 同时 flush 任何待处理的任务
|
||||
await manager.background_runner.flush_pending(conversation.user_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user