- 添加任务状态API路由,支持获取当前用户的任务状态和待处理任务列表 - 实现任务追踪服务,使用Redis存储任务状态 - 更新回忆录处理逻辑,集成Celery任务提交和状态更新 - 增强测试用例,支持任务状态的获取和清除功能 - 优化代码结构,提升可读性和维护性
389 lines
14 KiB
Python
389 lines
14 KiB
Python
"""
|
|
回忆录处理 Celery 任务
|
|
"""
|
|
import json
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple

import redis
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import Session

from database.database import SessionLocal
from database.models import Book, Chapter, Segment, MemoirState
from services.llm_service import llm_service
from agents.state_schema import MemoirStateSchema, SlotData, default_state
from agents.prompts.memory_prompts import (
    get_creative_title_prompt,
    get_narrative_prompt,
    get_state_extraction_prompt,
)
|
|
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
# TTL (seconds) applied to a user's task-status hash on every write.
_TASK_STATUS_TTL_SECONDS = 3600

# Redis clients cached per URL, so repeated status updates reuse one client
# (and its connection pool) instead of building a new one on every call.
_redis_clients: Dict[str, "redis.Redis"] = {}


def _get_redis_client() -> "redis.Redis":
    """Return a cached Redis client for ``REDIS_URL`` (default: local DB 0)."""
    redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
    client = _redis_clients.get(redis_url)
    if client is None:
        client = redis.from_url(redis_url, decode_responses=True)
        _redis_clients[redis_url] = client
    return client


def _update_task_status_sync(
    user_id: str,
    task_id: str,
    status: str,
    result: Optional[Dict] = None,
) -> None:
    """Record a task's status in Redis from synchronous (Celery) code.

    Task records live in the per-user hash ``task:user:<user_id>:tasks``,
    keyed by ``task_id``. Each record is a JSON object with ``task_id``,
    ``status``, ``updated_at`` (UTC, ISO-8601) and, when given, ``result``.
    Other fields of an existing record are preserved. Every write resets the
    TTL of the whole hash.

    Best-effort: any error is logged (with traceback) and swallowed, so status
    tracking can never fail the task that calls it.

    Args:
        user_id: Owner of the task; selects the Redis hash.
        task_id: Celery task id; the field inside the hash.
        status: New status string (this module uses "running", "success",
            "failure").
        result: Optional JSON-serializable payload stored under ``result``.
    """
    try:
        r = _get_redis_client()
        key = f"task:user:{user_id}:tasks"

        # Start from the existing record so earlier fields are kept.
        data = r.hget(key, task_id)
        task_info = json.loads(data) if data else {"task_id": task_id}

        task_info["status"] = status
        task_info["updated_at"] = datetime.now(timezone.utc).isoformat()
        if result is not None:
            task_info["result"] = result

        # Send the write and the TTL refresh as one MULTI/EXEC. The hash can
        # then never be left without an expiry if the worker dies in between.
        with r.pipeline(transaction=True) as pipe:
            pipe.hset(key, task_id, json.dumps(task_info))
            pipe.expire(key, _TASK_STATUS_TTL_SECONDS)
            pipe.execute()

        logger.info(f"任务状态已更新: task_id={task_id}, status={status}")
    except Exception as e:
        logger.exception(f"更新任务状态失败: {e}")
|
|
|
|
STAGE_KEYWORDS = {
|
|
"childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
|
|
"education": ["上学", "学校", "老师", "同学", "教育", "大学"],
|
|
"career": ["工作", "职业", "事业", "公司", "同事", "创业"],
|
|
"family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
|
|
"belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
|
|
}
|
|
|
|
|
|
def _detect_stage(user_message: str, fallback_stage: str) -> str:
|
|
"""检测消息所属阶段"""
|
|
message = user_message.lower()
|
|
for stage, keywords in STAGE_KEYWORDS.items():
|
|
if any(word in message for word in keywords):
|
|
return stage
|
|
return fallback_stage
|
|
|
|
|
|
def _coerce_state(model: MemoirState) -> MemoirStateSchema:
    """Build a validated ``MemoirStateSchema`` from a ``MemoirState`` row.

    Fallbacks:
    - a falsy ``stage_order`` is replaced by the default stage order;
    - a falsy ``covered_stages`` becomes an empty list;
    - ``slots`` is used only when it is a dict, otherwise the default slots
      are substituted.
    """
    stage_order = model.stage_order
    if not stage_order:
        stage_order = default_state().stage_order

    current_stage = model.current_stage

    covered = model.covered_stages
    if not covered:
        covered = []

    raw_slots = model.slots
    if isinstance(raw_slots, dict):
        slots = raw_slots
    else:
        slots = default_state().slots

    payload = {
        "stage_order": stage_order,
        "current_stage": current_stage,
        "covered_stages": covered,
        "slots": slots,
    }
    return MemoirStateSchema.model_validate(payload)
|
|
|
|
|
|
def _get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema:
    """Return the user's memoir state, creating a default row when none exists.

    When a row must be created it is built from ``default_state()``, with
    slot objects serialized via ``model_dump()``. It is added, committed on
    ``db`` and refreshed before conversion.
    """
    query = select(MemoirState).where(MemoirState.user_id == user_id)
    row = db.execute(query).scalar_one_or_none()

    if not row:
        defaults = default_state()

        serialized_slots = {}
        for stage_name, stage_slots in defaults.slots.items():
            serialized_slots[stage_name] = {
                slot_name: slot.model_dump()
                for slot_name, slot in stage_slots.items()
            }

        row = MemoirState(
            id=str(uuid.uuid4()),
            user_id=user_id,
            stage_order=defaults.stage_order,
            current_stage=defaults.current_stage,
            covered_stages=defaults.covered_stages,
            slots=serialized_slots,
        )
        db.add(row)
        db.commit()
        db.refresh(row)

    return _coerce_state(row)
|
|
|
|
|
|
def _update_slot_sync(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: Session,
) -> MemoirStateSchema:
    """Set one slot of the user's memoir state and commit.

    The slot ``slots[stage][slot_name]`` is replaced by
    ``SlotData(snippet=snippet, segment_ids=...)``. Its segment ids are the
    union of the slot's previous ids and ``segment_ids``, de-duplicated in
    first-seen order. If the user has no current stage yet, it is set to
    ``stage``. A state row is created first if the user has none.

    Returns:
        The refreshed state as a ``MemoirStateSchema``.
    """
    stmt = select(MemoirState).where(MemoirState.user_id == user_id)
    state = db.execute(stmt).scalar_one_or_none()
    if not state:
        _get_or_create_state_sync(user_id, db)
        state = db.execute(stmt).scalar_one()

    # Copy the two levels we modify instead of mutating the loaded JSON in
    # place. For a plain JSON column, SQLAlchemy detects a change by comparing
    # the assigned value with the loaded one. Mutating in place makes them the
    # same object, so no UPDATE is emitted and refresh() discards the change.
    slots: Dict[str, Dict] = dict(state.slots) if isinstance(state.slots, dict) else {}
    current_stage_slots = slots.get(stage)
    stage_slots = dict(current_stage_slots) if isinstance(current_stage_slots, dict) else {}
    existing = stage_slots.get(slot_name) or {}

    # Order-preserving union. A set() would give a run-to-run varying order.
    merged_segment_ids = list(
        dict.fromkeys([*(existing.get("segment_ids") or []), *segment_ids])
    )
    stage_slots[slot_name] = SlotData(snippet=snippet, segment_ids=merged_segment_ids).model_dump()
    slots[stage] = stage_slots

    state.slots = slots
    state.current_stage = state.current_stage or stage
    db.commit()
    db.refresh(state)
    return _coerce_state(state)
|
|
|
|
|
|
def _parse_llm_json(content: str) -> Dict[str, Any]:
    """Parse the JSON object in an LLM reply, tolerating a Markdown code fence.

    LLMs often wrap JSON in a ```json ... ``` fence. The fence is stripped
    before parsing so such replies are not discarded.

    Raises:
        ValueError: if the text is not valid JSON (``json.JSONDecodeError`` is
            a subclass) or the top-level value is not an object.
    """
    text = content.strip()
    if text.startswith("```"):
        # Drop the opening fence line ("```" / "```json") and the closing fence.
        text = text.split("\n", 1)[1] if "\n" in text else ""
        text = text.rsplit("```", 1)[0]
    parsed = json.loads(text)
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    return parsed


def _classify_segment(
    llm: Any, text: str, state: "MemoirStateSchema"
) -> Tuple[str, Dict[str, Any]]:
    """Choose a stage for one segment's text and extract slot snippets from it.

    The keyword heuristic, falling back to ``state.current_stage``, gives the
    initial stage. When ``llm`` is truthy, its JSON reply may override the
    stage (``detected_stage``) and supply snippets (``slots``). Values of the
    wrong type are ignored. Any LLM or parsing error is logged, and the
    heuristic stage is kept with no snippets.

    Returns:
        ``(stage, {slot_name: snippet})``.
    """
    stage = _detect_stage(text, state.current_stage)
    if not llm:
        return stage, {}

    try:
        prompt = get_state_extraction_prompt(
            user_message=text,
            current_stage=state.current_stage,
            stage_slots=state.slots.get(stage, {}),
        )
        parsed = _parse_llm_json(llm.invoke(prompt).content)
    except Exception as e:
        logger.warning(f"LLM 解析失败: {e}")
        return stage, {}

    llm_stage = parsed.get("detected_stage")
    if isinstance(llm_stage, str) and llm_stage:
        stage = llm_stage
    slots = parsed.get("slots")
    return stage, slots if isinstance(slots, dict) else {}


def _compose_chapter(
    llm: Any,
    stage: str,
    chapter: Optional["Chapter"],
    combined_text: str,
    slot_snippets: Dict[str, Any],
) -> Tuple[str, str]:
    """Return ``(title, narrative)`` for a stage's chapter.

    Fallback rule: without an LLM, on any LLM error, or when the LLM returns
    an empty narrative, the new text is appended to the chapter's existing
    content. Existing content is therefore never replaced by nothing or by the
    new text alone.

    Titles are generated only for new chapters. An empty generated title keeps
    the default ``"<stage> 回忆"``.
    """
    title = chapter.title if chapter else f"{stage} 回忆"
    existing_content = (chapter.content or "") if chapter else ""
    narrative = ""

    if llm:
        try:
            if not chapter:
                title_prompt = get_creative_title_prompt(
                    stage=stage,
                    emotion="neutral",
                    slots=slot_snippets,
                )
                generated_title = llm.invoke(title_prompt).content.strip().strip('"')
                title = generated_title or title

            narrative_prompt = get_narrative_prompt(
                stage=stage,
                slots=slot_snippets,
                new_content=combined_text,
                existing_content=existing_content,
            )
            narrative = llm.invoke(narrative_prompt).content.strip()
        except Exception as e:
            logger.warning(f"LLM 生成失败: {e}")

    if not narrative:
        narrative = f"{existing_content}\n\n{combined_text}" if existing_content else combined_text
    return title, narrative


@shared_task(bind=True, max_retries=3, default_retry_delay=60)
def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
    """Celery task that turns transcribed segments into memoir chapters.

    Steps:
    1. For each segment, pick a life stage (keyword heuristic, optionally
       refined by the LLM) and store any extracted slot snippets in the
       user's memoir state.
    2. Group the segments by stage and merge each group into that stage's
       chapter, creating the chapter if missing.
    3. Flag the user's most recently updated book, creating one if missing.
    4. Mark the segments processed.

    Progress is mirrored to Redis via ``_update_task_status_sync``: "running"
    first, then "success" or "failure".

    Args:
        user_id: Owner of the segments.
        segment_ids: IDs of the ``Segment`` rows to process.

    Returns:
        ``{"status": "no_segments"}`` if none of the ids exist, otherwise
        ``{"status": "success", "processed": <number of segments>}``.

    On any error the Redis status is set to "failure" and the task is retried
    via ``self.retry`` (max_retries=3, 60 s delay).
    """
    task_id = self.request.id
    logger.info(f"开始处理回忆录段落: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}")

    _update_task_status_sync(user_id, task_id, "running")

    try:
        with SessionLocal() as db:
            segments = db.execute(
                select(Segment).where(Segment.id.in_(segment_ids))
            ).scalars().all()

            if not segments:
                logger.warning(f"未找到段落: {segment_ids}")
                # Close out the tracked task. Otherwise Redis keeps reporting
                # it as "running" until the key expires.
                _update_task_status_sync(user_id, task_id, "success", {"processed": 0})
                return {"status": "no_segments"}

            state = _get_or_create_state_sync(user_id, db)
            llm = llm_service.get_llm()

            # Pass 1: classify each segment and record its slot snippets.
            stage_to_segments: Dict[str, List[Segment]] = {}
            for segment in segments:
                # NOTE(review): transcript_text may be NULL; treated as empty text.
                text = segment.transcript_text or ""
                stage, extracted_slots = _classify_segment(llm, text, state)
                for slot_name, snippet in extracted_slots.items():
                    # An empty snippet would overwrite an existing one with nothing.
                    if not snippet:
                        continue
                    # _update_slot_sync commits and returns the refreshed state.
                    state = _update_slot_sync(
                        user_id=user_id,
                        stage=stage,
                        slot_name=slot_name,
                        snippet=snippet,
                        segment_ids=[segment.id],
                        db=db,
                    )
                stage_to_segments.setdefault(stage, []).append(segment)

            # Pass 2: merge each stage's segments into its chapter.
            chapter = None
            for stage, stage_segments in stage_to_segments.items():
                combined_text = "\n\n".join(seg.transcript_text or "" for seg in stage_segments)
                source_ids = [seg.id for seg in stage_segments]

                chapter = db.execute(
                    select(Chapter).where(
                        Chapter.user_id == user_id,
                        Chapter.category == stage,
                    )
                ).scalar_one_or_none()

                slot_snippets = {
                    key: value.snippet
                    for key, value in (state.slots.get(stage, {}) or {}).items()
                    if value.snippet
                }
                title, narrative = _compose_chapter(llm, stage, chapter, combined_text, slot_snippets)

                if chapter:
                    chapter.content = narrative
                    chapter.title = title
                    chapter.is_new = True
                    # Order-preserving de-duplication (a set would reorder ids).
                    chapter.source_segments = list(
                        dict.fromkeys([*(chapter.source_segments or []), *source_ids])
                    )
                else:
                    chapter = Chapter(
                        id=str(uuid.uuid4()),
                        user_id=user_id,
                        title=title,
                        content=narrative,
                        order_index=999,
                        status="completed",
                        category=stage,
                        images=[],
                        is_new=True,
                        source_segments=source_ids,
                    )
                    db.add(chapter)

                db.flush()

            # Flag the user's most recently updated book. Use first() rather
            # than scalar_one_or_none(): the latter raises MultipleResultsFound
            # as soon as the user owns more than one book.
            book = db.execute(
                select(Book)
                .where(Book.user_id == user_id)
                .order_by(Book.updated_at.desc())
                .limit(1)
            ).scalars().first()
            if book is None:
                book = Book(
                    id=str(uuid.uuid4()),
                    user_id=user_id,
                    title="我的回忆录",
                    total_pages=0,
                    total_words=0,
                    cover_image_url=None,
                )
                db.add(book)
            book.has_update = True
            # ``chapter`` is the last chapter touched in pass 2. That loop ran at
            # least once because ``segments`` is non-empty.
            book.last_update_chapter_id = chapter.id

            for seg in segments:
                seg.processed = True

            db.commit()
            logger.info(f"回忆录处理完成: user_id={user_id}, task_id={task_id}")

            _update_task_status_sync(user_id, task_id, "success", {"processed": len(segments)})
            return {"status": "success", "processed": len(segments)}

    except Exception as e:
        logger.exception(f"回忆录处理失败: {e}")
        _update_task_status_sync(user_id, task_id, "failure", {"error": str(e)})
        raise self.retry(exc=e)
|
|
|
|
|
|
@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def generate_chapter_content(self, user_id: str, stage: str, new_content: str):
    """Celery task that merges ``new_content`` into the user's chapter for ``stage``.

    With an LLM, the chapter text is rewritten from the existing content plus
    ``new_content``; no slot snippets are passed. Without an LLM, or when the
    LLM returns an empty reply, ``new_content`` is appended to the existing
    text instead, so the chapter is never wiped. A missing chapter is created
    with the default title ``"<stage> 回忆"``. The chapter is flagged
    ``is_new``.

    Args:
        user_id: Owner of the chapter.
        stage: Stage name, matched against ``Chapter.category``.
        new_content: Text to merge in.

    Returns:
        ``{"status": "success"}``.

    Any error, including LLM errors, is logged and the task is retried via
    ``self.retry`` (max_retries=3, 30 s delay).
    """
    logger.info(f"生成章节内容: user_id={user_id}, stage={stage}")

    try:
        with SessionLocal() as db:
            llm = llm_service.get_llm()

            chapter = db.execute(
                select(Chapter).where(
                    Chapter.user_id == user_id,
                    Chapter.category == stage,
                )
            ).scalar_one_or_none()
            existing_content = (chapter.content or "") if chapter else ""

            narrative = ""
            if llm:
                prompt = get_narrative_prompt(
                    stage=stage,
                    slots={},
                    new_content=new_content,
                    existing_content=existing_content,
                )
                narrative = llm.invoke(prompt).content.strip()
            if not narrative:
                # No LLM, or an empty reply: append rather than replace.
                narrative = f"{existing_content}\n\n{new_content}" if existing_content else new_content

            if chapter:
                chapter.content = narrative
                chapter.is_new = True
            else:
                chapter = Chapter(
                    id=str(uuid.uuid4()),
                    user_id=user_id,
                    title=f"{stage} 回忆",
                    content=narrative,
                    order_index=999,
                    status="completed",
                    category=stage,
                    images=[],
                    is_new=True,
                    source_segments=[],
                )
                db.add(chapter)

            db.commit()
            return {"status": "success"}

    except Exception as e:
        logger.exception(f"章节生成失败: {e}")
        raise self.retry(exc=e)
|