- 添加任务状态API路由,支持获取当前用户的任务状态和待处理任务列表 - 实现任务追踪服务,使用Redis存储任务状态 - 更新回忆录处理逻辑,集成Celery任务提交和状态更新 - 增强测试用例,支持任务状态的获取和清除功能 - 优化代码结构,提升可读性和维护性
219 lines
7.8 KiB
Python
219 lines
7.8 KiB
Python
"""
|
||
回忆录后台处理器
|
||
|
||
负责:
|
||
- 分析用户对话内容,提取关键信息
|
||
- 更新回忆录状态(slots)
|
||
- 生成/更新章节内容
|
||
- 创建创意章节标题
|
||
|
||
使用 Celery 进行后台任务处理,支持可靠的任务队列和重试机制
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
from dataclasses import dataclass
|
||
from typing import Dict, List
|
||
|
||
from agents.state_schema import MemoirStateSchema
|
||
from services.llm_service import llm_service
|
||
from .prompts.memory_prompts import (
|
||
get_creative_title_prompt,
|
||
get_narrative_prompt,
|
||
get_state_extraction_prompt,
|
||
)
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
STAGE_KEYWORDS = {
|
||
"childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
|
||
"education": ["上学", "学校", "老师", "同学", "教育", "大学"],
|
||
"career": ["工作", "职业", "事业", "公司", "同事", "创业"],
|
||
"family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
|
||
"belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
|
||
}
|
||
|
||
|
||
@dataclass
|
||
class AnalysisResult:
|
||
detected_stage: str
|
||
extracted_slots: Dict[str, str]
|
||
emotion: str
|
||
is_new_chapter: bool
|
||
|
||
|
||
class ContentAnalyzer:
|
||
"""对话内容分析(支持异步)"""
|
||
|
||
def __init__(self) -> None:
|
||
self.llm = llm_service.get_llm()
|
||
|
||
def _detect_stage(self, user_message: str, fallback_stage: str) -> str:
|
||
message = user_message.lower()
|
||
for stage, keywords in STAGE_KEYWORDS.items():
|
||
if any(word in message for word in keywords):
|
||
return stage
|
||
return fallback_stage
|
||
|
||
def _fallback_slots(self, state: MemoirStateSchema, stage: str, user_message: str) -> Dict[str, str]:
|
||
stage_slots = state.slots.get(stage, {})
|
||
for key, value in stage_slots.items():
|
||
if not value.snippet:
|
||
return {key: user_message.strip()[:200]}
|
||
return {}
|
||
|
||
async def analyze_message(self, user_message: str, current_state: MemoirStateSchema) -> AnalysisResult:
|
||
detected_stage = self._detect_stage(user_message, current_state.current_stage)
|
||
extracted_slots: Dict[str, str] = {}
|
||
emotion = "neutral"
|
||
is_new_chapter = False
|
||
|
||
if self.llm:
|
||
try:
|
||
prompt = get_state_extraction_prompt(
|
||
user_message=user_message,
|
||
current_stage=current_state.current_stage,
|
||
stage_slots=current_state.slots.get(detected_stage, {}),
|
||
)
|
||
# 使用异步调用
|
||
response = await self.llm.ainvoke(prompt)
|
||
content = response.content.strip()
|
||
parsed = json.loads(content)
|
||
detected_stage = parsed.get("detected_stage", detected_stage)
|
||
extracted_slots = parsed.get("slots", {}) or {}
|
||
emotion = parsed.get("emotion", emotion)
|
||
is_new_chapter = bool(parsed.get("is_new_chapter", is_new_chapter))
|
||
except json.JSONDecodeError:
|
||
extracted_slots = self._fallback_slots(current_state, detected_stage, user_message)
|
||
except Exception as e:
|
||
logger.error(f"分析消息失败: {e}")
|
||
extracted_slots = self._fallback_slots(current_state, detected_stage, user_message)
|
||
else:
|
||
extracted_slots = self._fallback_slots(current_state, detected_stage, user_message)
|
||
|
||
return AnalysisResult(
|
||
detected_stage=detected_stage,
|
||
extracted_slots=extracted_slots,
|
||
emotion=emotion,
|
||
is_new_chapter=is_new_chapter,
|
||
)
|
||
|
||
|
||
class MemoirGenerator:
|
||
"""回忆录生成与更新(支持异步)"""
|
||
|
||
def __init__(self) -> None:
|
||
self.llm = llm_service.get_llm()
|
||
|
||
async def generate_chapter_title(self, stage: str, slots: Dict[str, str], emotion: str) -> str:
|
||
if not self.llm:
|
||
return f"{stage} 回忆"
|
||
try:
|
||
prompt = get_creative_title_prompt(stage=stage, emotion=emotion, slots=slots)
|
||
# 使用异步调用
|
||
response = await self.llm.ainvoke(prompt)
|
||
return response.content.strip().strip('"')
|
||
except Exception as e:
|
||
logger.error(f"生成标题失败: {e}")
|
||
return f"{stage} 回忆"
|
||
|
||
async def generate_narrative(self, stage: str, slots: Dict[str, str], new_content: str, existing_content: str) -> str:
|
||
if not self.llm:
|
||
if existing_content:
|
||
return f"{existing_content}\n\n{new_content}"
|
||
return new_content
|
||
try:
|
||
prompt = get_narrative_prompt(stage=stage, slots=slots, new_content=new_content, existing_content=existing_content)
|
||
# 使用异步调用
|
||
response = await self.llm.ainvoke(prompt)
|
||
return response.content.strip()
|
||
except Exception as e:
|
||
logger.error(f"生成叙事失败: {e}")
|
||
if existing_content:
|
||
return f"{existing_content}\n\n{new_content}"
|
||
return new_content
|
||
|
||
|
||
class BackgroundTaskRunner:
|
||
"""后台任务调度器(使用 Celery)"""
|
||
|
||
def __init__(self, debounce_seconds: int = 5) -> None:
|
||
self.debounce_seconds = debounce_seconds
|
||
# 内存中的待处理任务(用于去抖)
|
||
self._pending: Dict[str, List[str]] = {}
|
||
self._timers: Dict[str, object] = {}
|
||
self.analyzer = ContentAnalyzer()
|
||
self.generator = MemoirGenerator()
|
||
|
||
async def _submit_task(self, user_id: str, segment_ids: List[str]) -> str | None:
|
||
"""
|
||
提交 Celery 任务并记录
|
||
|
||
Returns:
|
||
任务 ID,失败返回 None
|
||
"""
|
||
try:
|
||
from tasks.memoir_tasks import process_memoir_segments
|
||
from services.task_tracker import task_tracker
|
||
|
||
# 提交到 Celery
|
||
result = process_memoir_segments.delay(user_id, segment_ids)
|
||
task_id = result.id
|
||
|
||
# 记录任务
|
||
await task_tracker.add_task(user_id, task_id, "memoir")
|
||
|
||
logger.info(f"已提交 Celery 任务: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}")
|
||
return task_id
|
||
except Exception as e:
|
||
logger.error(f"提交 Celery 任务失败: {e}")
|
||
return None
|
||
|
||
async def queue_message(self, user_id: str, segment_id: str) -> None:
|
||
"""
|
||
将消息加入处理队列
|
||
|
||
使用 Celery 延迟任务实现去抖效果
|
||
"""
|
||
import asyncio
|
||
|
||
# 收集待处理的 segment_ids
|
||
self._pending.setdefault(user_id, []).append(segment_id)
|
||
|
||
# 取消之前的定时器
|
||
if user_id in self._timers:
|
||
self._timers[user_id].cancel()
|
||
|
||
# 创建新的定时器
|
||
async def delayed_submit():
|
||
try:
|
||
await asyncio.sleep(self.debounce_seconds)
|
||
segment_ids = self._pending.pop(user_id, [])
|
||
if segment_ids:
|
||
await self._submit_task(user_id, segment_ids)
|
||
except asyncio.CancelledError:
|
||
pass
|
||
except Exception as e:
|
||
logger.error(f"延迟提交任务失败: {e}")
|
||
|
||
self._timers[user_id] = asyncio.create_task(delayed_submit())
|
||
|
||
async def flush_pending(self, user_id: str) -> str | None:
|
||
"""
|
||
立即提交用户的待处理任务(用于对话结束时)
|
||
|
||
Returns:
|
||
任务 ID,无任务或失败返回 None
|
||
"""
|
||
# 取消定时器
|
||
if user_id in self._timers:
|
||
self._timers[user_id].cancel()
|
||
del self._timers[user_id]
|
||
|
||
# 提交待处理任务
|
||
segment_ids = self._pending.pop(user_id, [])
|
||
if segment_ids:
|
||
return await self._submit_task(user_id, segment_ids)
|
||
return None
|