Files
life-echo/api/agents/memoir_processor.py
penghanyuan 3f899aa16c feat: 新增任务状态管理和API支持
- 添加任务状态API路由,支持获取当前用户的任务状态和待处理任务列表
- 实现任务追踪服务,使用Redis存储任务状态
- 更新回忆录处理逻辑,集成Celery任务提交和状态更新
- 增强测试用例,支持任务状态的获取和清除功能
- 优化代码结构,提升可读性和维护性
2026-01-21 23:37:00 +01:00

219 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
回忆录后台处理器
负责:
- 分析用户对话内容,提取关键信息
- 更新回忆录状态slots
- 生成/更新章节内容
- 创建创意章节标题
使用 Celery 进行后台任务处理,支持可靠的任务队列和重试机制
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from typing import Dict, List
from agents.state_schema import MemoirStateSchema
from services.llm_service import llm_service
from .prompts.memory_prompts import (
get_creative_title_prompt,
get_narrative_prompt,
get_state_extraction_prompt,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Maps each memoir stage name to the Chinese keywords that mark a user
# message as belonging to that stage. ContentAnalyzer._detect_stage walks the
# stages in this insertion order and returns the first stage with a match,
# so the ordering acts as a priority when a message hits several stages.
STAGE_KEYWORDS = {
    "childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
    "education": ["上学", "学校", "老师", "同学", "教育", "大学"],
    "career": ["工作", "职业", "事业", "公司", "同事", "创业"],
    "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
    "belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
}
@dataclass
class AnalysisResult:
    """Outcome of analyzing one user message.

    Fields are populated by ``ContentAnalyzer.analyze_message``:

    * ``detected_stage`` - memoir stage the message was attributed to.
    * ``extracted_slots`` - slot name -> text extracted for that slot.
    * ``emotion`` - emotion label; ``"neutral"`` when nothing was detected.
    * ``is_new_chapter`` - whether the message should start a new chapter.
    """

    detected_stage: str
    extracted_slots: Dict[str, str]
    emotion: str
    is_new_chapter: bool
class ContentAnalyzer:
    """Analyze conversation messages (async-capable).

    Combines a keyword-based stage detector with an optional LLM call that
    extracts structured slot data. Whenever the LLM is unavailable or its
    reply cannot be used, the analyzer falls back to storing the raw message
    in the first empty slot of the detected stage.
    """

    def __init__(self) -> None:
        # analyze_message treats a falsy value here as "no LLM available".
        self.llm = llm_service.get_llm()

    def _detect_stage(self, user_message: str, fallback_stage: str) -> str:
        """Return the first stage whose keywords occur in ``user_message``.

        Stages are checked in ``STAGE_KEYWORDS`` order; ``fallback_stage`` is
        returned when no keyword matches.
        """
        message = user_message.lower()
        for stage, keywords in STAGE_KEYWORDS.items():
            if any(word in message for word in keywords):
                return stage
        return fallback_stage

    def _fallback_slots(self, state: MemoirStateSchema, stage: str, user_message: str) -> Dict[str, str]:
        """Put the raw message into the first slot of ``stage`` with no snippet.

        Returns a single-entry dict ``{slot_name: text}`` where ``text`` is
        the stripped message truncated to 200 characters, or an empty dict
        when the stage is unknown or every slot already has a snippet.
        """
        stage_slots = state.slots.get(stage, {})
        for key, value in stage_slots.items():
            if not value.snippet:
                return {key: user_message.strip()[:200]}
        return {}

    @staticmethod
    def _parse_llm_json(content: str) -> dict:
        """Parse an LLM reply that is expected to contain one JSON object.

        Models frequently wrap JSON in a Markdown code fence, optionally
        tagged ``json``; such a wrapper is removed before parsing.

        Raises:
            json.JSONDecodeError: the text is not valid JSON.
            ValueError: the JSON is valid but its top level is not an object.
        """
        text = content.strip()
        if text.startswith("```"):
            text = text[3:]
            if text.endswith("```"):
                text = text[:-3]
            text = text.strip()
            # Drop the optional language tag that follows the opening fence.
            if text[:4].lower() == "json":
                text = text[4:]
        parsed = json.loads(text)
        if not isinstance(parsed, dict):
            raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
        return parsed

    async def analyze_message(self, user_message: str, current_state: MemoirStateSchema) -> AnalysisResult:
        """Analyze one user message against the current memoir state.

        The stage is first guessed from keywords. If an LLM is available it
        is asked for structured output (``detected_stage``, ``slots``,
        ``emotion``, ``is_new_chapter``); only well-typed values from its
        reply override the defaults. Any failure of the LLM path falls back
        to ``_fallback_slots`` and never raises.
        """
        detected_stage = self._detect_stage(user_message, current_state.current_stage)
        extracted_slots: Dict[str, str] = {}
        emotion = "neutral"
        is_new_chapter = False

        if not self.llm:
            extracted_slots = self._fallback_slots(current_state, detected_stage, user_message)
        else:
            try:
                prompt = get_state_extraction_prompt(
                    user_message=user_message,
                    current_stage=current_state.current_stage,
                    stage_slots=current_state.slots.get(detected_stage, {}),
                )
                # Asynchronous LLM call.
                response = await self.llm.ainvoke(prompt)
                parsed = self._parse_llm_json(response.content)
            except json.JSONDecodeError:
                # Reply was not JSON: keep the keyword-detected stage and
                # store the raw message instead.
                extracted_slots = self._fallback_slots(current_state, detected_stage, user_message)
            except Exception as e:
                logger.error(f"分析消息失败: {e}")
                extracted_slots = self._fallback_slots(current_state, detected_stage, user_message)
            else:
                # Only accept non-empty strings so that an explicit JSON null
                # (or a wrong type) cannot replace the defaults with None.
                llm_stage = parsed.get("detected_stage")
                if isinstance(llm_stage, str) and llm_stage:
                    detected_stage = llm_stage
                llm_emotion = parsed.get("emotion")
                if isinstance(llm_emotion, str) and llm_emotion:
                    emotion = llm_emotion
                llm_slots = parsed.get("slots")
                extracted_slots = llm_slots if isinstance(llm_slots, dict) else {}
                is_new_chapter = bool(parsed.get("is_new_chapter", is_new_chapter))

        return AnalysisResult(
            detected_stage=detected_stage,
            extracted_slots=extracted_slots,
            emotion=emotion,
            is_new_chapter=is_new_chapter,
        )
class MemoirGenerator:
    """Generate and update memoir text (async-capable).

    Both public coroutines degrade gracefully: when no LLM is available, the
    LLM call fails, or the LLM returns nothing usable, a deterministic
    fallback is returned instead of raising or returning an empty string.
    """

    def __init__(self) -> None:
        # The generate_* methods treat a falsy value here as "no LLM available".
        self.llm = llm_service.get_llm()

    @staticmethod
    def _default_title(stage: str) -> str:
        """Return the fallback chapter title for ``stage``."""
        return f"{stage} 回忆"

    @staticmethod
    def _clean_title(raw_title: str, stage: str) -> str:
        """Strip whitespace and surrounding double quotes from an LLM title.

        Falls back to the default title when nothing is left, so callers
        never receive an empty title.
        """
        title = raw_title.strip().strip('"').strip()
        return title or MemoirGenerator._default_title(stage)

    @staticmethod
    def _append_content(existing_content: str, new_content: str) -> str:
        """Append ``new_content`` to ``existing_content`` as a new paragraph.

        Returns ``new_content`` alone when there is no existing content.
        """
        if existing_content:
            return f"{existing_content}\n\n{new_content}"
        return new_content

    async def generate_chapter_title(self, stage: str, slots: Dict[str, str], emotion: str) -> str:
        """Produce a chapter title for ``stage``.

        Uses the LLM with the creative-title prompt when available; returns
        the default ``"<stage> 回忆"`` title when there is no LLM, the call
        fails, or the reply is empty after cleaning.
        """
        if not self.llm:
            return self._default_title(stage)
        try:
            prompt = get_creative_title_prompt(stage=stage, emotion=emotion, slots=slots)
            # Asynchronous LLM call.
            response = await self.llm.ainvoke(prompt)
            return self._clean_title(response.content, stage)
        except Exception as e:
            logger.error(f"生成标题失败: {e}")
            return self._default_title(stage)

    async def generate_narrative(self, stage: str, slots: Dict[str, str], new_content: str, existing_content: str) -> str:
        """Produce the chapter narrative that incorporates ``new_content``.

        Uses the LLM with the narrative prompt when available. When there is
        no LLM, the call fails, or the reply is blank, the new content is
        simply appended to the existing content so no text is ever lost.
        """
        if not self.llm:
            return self._append_content(existing_content, new_content)
        try:
            prompt = get_narrative_prompt(
                stage=stage,
                slots=slots,
                new_content=new_content,
                existing_content=existing_content,
            )
            # Asynchronous LLM call.
            response = await self.llm.ainvoke(prompt)
            narrative = response.content.strip()
            if narrative:
                return narrative
            # A blank reply must not replace the existing chapter text.
            logger.warning("生成叙事失败: LLM 返回空内容")
        except Exception as e:
            logger.error(f"生成叙事失败: {e}")
        return self._append_content(existing_content, new_content)
class BackgroundTaskRunner:
    """Background task scheduler (work is executed by Celery).

    Segment ids are collected per user in memory and debounced with an
    asyncio timer; once the debounce window passes without new messages the
    collected ids are submitted to Celery as a single task.
    """

    def __init__(self, debounce_seconds: int = 5) -> None:
        """Create a runner that waits ``debounce_seconds`` before submitting."""
        self.debounce_seconds = debounce_seconds
        # In-memory pending segment ids per user (used for debouncing).
        self._pending: Dict[str, List[str]] = {}
        # Debounce timer per user; an entry exists only while its timer is
        # still waiting, never while a submission is in flight.
        self._timers: Dict[str, "asyncio.Task[None]"] = {}
        self.analyzer = ContentAnalyzer()
        self.generator = MemoirGenerator()

    async def _submit_task(self, user_id: str, segment_ids: List[str]) -> str | None:
        """Submit a Celery task for ``segment_ids`` and record it.

        Returns:
            The Celery task id, or ``None`` if submission or recording failed
            (the error is logged, not raised).
        """
        try:
            # Imported lazily, as in the original code.
            # NOTE(review): presumably to avoid an import cycle - confirm.
            from tasks.memoir_tasks import process_memoir_segments
            from services.task_tracker import task_tracker

            # Hand the work over to Celery.
            result = process_memoir_segments.delay(user_id, segment_ids)
            task_id = result.id
            # Record the task for this user under the "memoir" label.
            await task_tracker.add_task(user_id, task_id, "memoir")
            logger.info(f"已提交 Celery 任务: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}")
            return task_id
        except Exception as e:
            logger.error(f"提交 Celery 任务失败: {e}")
            return None

    async def queue_message(self, user_id: str, segment_id: str) -> None:
        """Add ``segment_id`` to the user's pending list and restart the debounce timer.

        Debouncing is done with an in-process asyncio task: every call
        cancels the user's waiting timer and starts a new one. When a timer
        survives ``debounce_seconds`` it submits all pending ids at once.
        Must be called from within a running event loop.
        """
        import asyncio

        # Collect the segment id to process.
        self._pending.setdefault(user_id, []).append(segment_id)

        # Cancel the timer that is still waiting, if any.
        previous = self._timers.pop(user_id, None)
        if previous is not None:
            previous.cancel()

        async def delayed_submit() -> None:
            try:
                await asyncio.sleep(self.debounce_seconds)
                # The debounce window is over. Deregister before submitting
                # so that a later queue_message/flush_pending cannot cancel
                # this task between dispatching the Celery job and recording
                # it, and so finished timers do not accumulate in _timers.
                if self._timers.get(user_id) is asyncio.current_task():
                    del self._timers[user_id]
                segment_ids = self._pending.pop(user_id, [])
                if segment_ids:
                    await self._submit_task(user_id, segment_ids)
            except asyncio.CancelledError:
                # Superseded by a newer message or by flush_pending.
                pass
            except Exception as e:
                logger.error(f"延迟提交任务失败: {e}")

        self._timers[user_id] = asyncio.create_task(delayed_submit())

    async def flush_pending(self, user_id: str) -> str | None:
        """Submit the user's pending segments immediately (e.g. when a conversation ends).

        Returns:
            The Celery task id, or ``None`` when nothing was pending or the
            submission failed.
        """
        # Cancel the waiting timer, if any.
        timer = self._timers.pop(user_id, None)
        if timer is not None:
            timer.cancel()

        # Submit whatever is pending.
        segment_ids = self._pending.pop(user_id, [])
        if segment_ids:
            return await self._submit_task(user_id, segment_ids)
        return None