Files
life-echo/api/tasks/memoir_tasks.py
penghanyuan 3f899aa16c feat: 新增任务状态管理和API支持
- 添加任务状态API路由,支持获取当前用户的任务状态和待处理任务列表
- 实现任务追踪服务,使用Redis存储任务状态
- 更新回忆录处理逻辑,集成Celery任务提交和状态更新
- 增强测试用例,支持任务状态的获取和清除功能
- 优化代码结构,提升可读性和维护性
2026-01-21 23:37:00 +01:00

389 lines
14 KiB
Python

"""
回忆录处理 Celery 任务
"""
import json
import logging
import os
import uuid
from typing import Dict, List
from datetime import datetime, timezone
import redis
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import Session
from database.database import SessionLocal
from database.models import Book, Chapter, Segment, MemoirState
from services.llm_service import llm_service
from agents.state_schema import MemoirStateSchema, SlotData, default_state
from agents.prompts.memory_prompts import (
get_creative_title_prompt,
get_narrative_prompt,
get_state_extraction_prompt,
)
logger = logging.getLogger(__name__)
def _update_task_status_sync(user_id: str, task_id: str, status: str, result: Dict = None):
"""同步更新任务状态(在 Celery 任务中使用)"""
try:
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
r = redis.from_url(redis_url, decode_responses=True)
key = f"task:user:{user_id}:tasks"
# 获取现有任务信息
data = r.hget(key, task_id)
if data:
task_info = json.loads(data)
else:
task_info = {"task_id": task_id}
task_info["status"] = status
task_info["updated_at"] = datetime.now(timezone.utc).isoformat()
if result is not None:
task_info["result"] = result
r.hset(key, task_id, json.dumps(task_info))
r.expire(key, 3600) # 1小时过期
logger.info(f"任务状态已更新: task_id={task_id}, status={status}")
except Exception as e:
logger.error(f"更新任务状态失败: {e}")
STAGE_KEYWORDS = {
"childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
"education": ["上学", "学校", "老师", "同学", "教育", "大学"],
"career": ["工作", "职业", "事业", "公司", "同事", "创业"],
"family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
"belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
}
def _detect_stage(user_message: str, fallback_stage: str) -> str:
"""检测消息所属阶段"""
message = user_message.lower()
for stage, keywords in STAGE_KEYWORDS.items():
if any(word in message for word in keywords):
return stage
return fallback_stage
def _coerce_state(model: MemoirState) -> MemoirStateSchema:
"""将数据库模型转换为 Schema"""
return MemoirStateSchema.model_validate(
{
"stage_order": model.stage_order or default_state().stage_order,
"current_stage": model.current_stage,
"covered_stages": model.covered_stages or [],
"slots": model.slots if isinstance(model.slots, dict) else default_state().slots,
}
)
def _get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema:
"""同步获取或创建状态"""
stmt = select(MemoirState).where(MemoirState.user_id == user_id)
result = db.execute(stmt)
state = result.scalar_one_or_none()
if state:
return _coerce_state(state)
default = default_state()
state = MemoirState(
id=str(uuid.uuid4()),
user_id=user_id,
stage_order=default.stage_order,
current_stage=default.current_stage,
covered_stages=default.covered_stages,
slots={k: {sk: sv.model_dump() for sk, sv in v.items()} for k, v in default.slots.items()},
)
db.add(state)
db.commit()
db.refresh(state)
return _coerce_state(state)
def _update_slot_sync(
user_id: str,
stage: str,
slot_name: str,
snippet: str,
segment_ids: List[str],
db: Session,
) -> MemoirStateSchema:
"""同步更新 slot"""
stmt = select(MemoirState).where(MemoirState.user_id == user_id)
result = db.execute(stmt)
state = result.scalar_one_or_none()
if not state:
_get_or_create_state_sync(user_id, db)
result = db.execute(stmt)
state = result.scalar_one()
slots: Dict[str, Dict] = state.slots or {}
stage_slots = slots.get(stage, {})
existing = stage_slots.get(slot_name, {})
merged_segment_ids = list({*(existing.get("segment_ids") or []), *segment_ids})
stage_slots[slot_name] = SlotData(snippet=snippet, segment_ids=merged_segment_ids).model_dump()
slots[stage] = stage_slots
state.slots = slots
state.current_stage = state.current_stage or stage
db.commit()
db.refresh(state)
return _coerce_state(state)
@shared_task(bind=True, max_retries=3, default_retry_delay=60)
def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
"""
处理回忆录段落的 Celery 任务
Args:
user_id: 用户 ID
segment_ids: 段落 ID 列表
"""
task_id = self.request.id
logger.info(f"开始处理回忆录段落: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}")
# 更新任务状态为 running
_update_task_status_sync(user_id, task_id, "running")
try:
db = SessionLocal()
try:
# 获取段落
stmt = select(Segment).where(Segment.id.in_(segment_ids))
result = db.execute(stmt)
segments = result.scalars().all()
if not segments:
logger.warning(f"未找到段落: {segment_ids}")
return {"status": "no_segments"}
# 获取用户状态
state = _get_or_create_state_sync(user_id, db)
llm = llm_service.get_llm()
# 按阶段分组处理
stage_to_segments: Dict[str, List[Segment]] = {}
for segment in segments:
text = segment.transcript_text
detected_stage = _detect_stage(text, state.current_stage)
# 尝试使用 LLM 提取信息
extracted_slots = {}
if llm:
try:
prompt = get_state_extraction_prompt(
user_message=text,
current_stage=state.current_stage,
stage_slots=state.slots.get(detected_stage, {}),
)
response = llm.invoke(prompt)
content = response.content.strip()
parsed = json.loads(content)
detected_stage = parsed.get("detected_stage", detected_stage)
extracted_slots = parsed.get("slots", {}) or {}
except (json.JSONDecodeError, Exception) as e:
logger.warning(f"LLM 解析失败: {e}")
# 更新 slots
for slot_name, snippet in extracted_slots.items():
state = _update_slot_sync(
user_id=user_id,
stage=detected_stage,
slot_name=slot_name,
snippet=snippet,
segment_ids=[segment.id],
db=db,
)
stage_to_segments.setdefault(detected_stage, []).append(segment)
# 生成章节内容
for stage, stage_segments in stage_to_segments.items():
segment_texts = [seg.transcript_text for seg in stage_segments]
combined_text = "\n\n".join(segment_texts)
source_ids = [seg.id for seg in stage_segments]
# 查找或创建章节
stmt_chapter = select(Chapter).where(
Chapter.user_id == user_id,
Chapter.category == stage,
)
result_chapter = db.execute(stmt_chapter)
chapter = result_chapter.scalar_one_or_none()
# 获取 slot snippets
slot_snippets = {
key: value.snippet
for key, value in (state.slots.get(stage, {}) or {}).items()
if value.snippet
}
# 生成标题和内容
title = chapter.title if chapter else f"{stage} 回忆"
existing_content = chapter.content if chapter else ""
narrative = combined_text
if llm:
try:
if not chapter:
title_prompt = get_creative_title_prompt(
stage=stage,
emotion="neutral",
slots=slot_snippets
)
title_response = llm.invoke(title_prompt)
title = title_response.content.strip().strip('"')
narrative_prompt = get_narrative_prompt(
stage=stage,
slots=slot_snippets,
new_content=combined_text,
existing_content=existing_content,
)
narrative_response = llm.invoke(narrative_prompt)
narrative = narrative_response.content.strip()
except Exception as e:
logger.warning(f"LLM 生成失败: {e}")
if existing_content:
narrative = f"{existing_content}\n\n{combined_text}"
# 更新或创建章节
if chapter:
chapter.content = narrative
chapter.title = title
chapter.is_new = True
chapter.source_segments = list({*(chapter.source_segments or []), *source_ids})
else:
chapter = Chapter(
id=str(uuid.uuid4()),
user_id=user_id,
title=title,
content=narrative,
order_index=999,
status="completed",
category=stage,
images=[],
is_new=True,
source_segments=source_ids,
)
db.add(chapter)
db.flush()
# 更新 Book
stmt_book = select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc())
result_book = db.execute(stmt_book)
book = result_book.scalar_one_or_none()
if not book:
book = Book(
id=str(uuid.uuid4()),
user_id=user_id,
title="我的回忆录",
total_pages=0,
total_words=0,
cover_image_url=None,
)
db.add(book)
book.has_update = True
book.last_update_chapter_id = chapter.id
# 标记段落为已处理
for seg in segments:
seg.processed = True
db.commit()
logger.info(f"回忆录处理完成: user_id={user_id}, task_id={task_id}")
# 更新任务状态为成功
_update_task_status_sync(user_id, task_id, "success", {"processed": len(segments)})
return {"status": "success", "processed": len(segments)}
finally:
db.close()
except Exception as e:
logger.error(f"回忆录处理失败: {e}")
# 更新任务状态为失败
_update_task_status_sync(user_id, task_id, "failure", {"error": str(e)})
# 重试
raise self.retry(exc=e)
@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def generate_chapter_content(self, user_id: str, stage: str, new_content: str):
"""
单独生成章节内容的任务(用于实时更新)
Args:
user_id: 用户 ID
stage: 阶段
new_content: 新内容
"""
logger.info(f"生成章节内容: user_id={user_id}, stage={stage}")
try:
db = SessionLocal()
try:
llm = llm_service.get_llm()
# 查找章节
stmt = select(Chapter).where(
Chapter.user_id == user_id,
Chapter.category == stage,
)
result = db.execute(stmt)
chapter = result.scalar_one_or_none()
existing_content = chapter.content if chapter else ""
if llm:
prompt = get_narrative_prompt(
stage=stage,
slots={},
new_content=new_content,
existing_content=existing_content,
)
response = llm.invoke(prompt)
narrative = response.content.strip()
else:
narrative = f"{existing_content}\n\n{new_content}" if existing_content else new_content
if chapter:
chapter.content = narrative
chapter.is_new = True
else:
chapter = Chapter(
id=str(uuid.uuid4()),
user_id=user_id,
title=f"{stage} 回忆",
content=narrative,
order_index=999,
status="completed",
category=stage,
images=[],
is_new=True,
source_segments=[],
)
db.add(chapter)
db.commit()
return {"status": "success"}
finally:
db.close()
except Exception as e:
logger.error(f"章节生成失败: {e}")
raise self.retry(exc=e)