Files
life-echo/api/tasks/memoir_tasks.py
Kevin 0970cb7408 fix: 修复 Liblib provider 认证和多个图片生成关键缺陷
- 重写 LiblibImageProvider:Bearer token 改为 HMAC-SHA1 签名认证,
  适配 Liblib 真实 API(Star-3 Alpha 文生图端点)
- 修复 chapter.images JSON 列原地修改不持久化(深拷贝+整列重赋值)
- 修复 generate_chapter_images 在事务提交前派发(改为 commit 后统一 delay)
- 修复 initialize_chapter_images 覆盖已完成图片(新增 merge 去重逻辑)
- 修复 Android failed 图片渲染为错误卡片(改为隐藏,保持正文连续)
- 模型模板 UUID 改为环境变量配置(LIBLIB_TEMPLATE_UUID)
- 更新 .env 凭证格式为 ACCESS_KEY/SECRET_KEY
- 补充 test_memoir_image_bootstrap 缺失的 unittest.mock 导入

Made-with: Cursor
2026-03-10 17:02:50 +08:00

665 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
回忆录处理 Celery 任务
"""
import json
import logging
import os
import uuid
from typing import Dict, List
from datetime import datetime, timezone
import redis
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import Session
from database.database import SessionLocal
from database.models import Book, Chapter, Segment, MemoirState, User
from services.llm_service import llm_service
from agents.state_schema import MemoirStateSchema, SlotData, default_state
from agents.prompts.memory_prompts import (
get_creative_title_prompt,
get_narrative_prompt,
get_state_extraction_prompt,
get_chapter_classification_prompt,
STAGE_TO_ORDER,
CHAPTER_CATEGORIES,
)
from agents.prompts.profile_prompts import format_user_profile_context
import hashlib
from services.memoir_images.parser import build_initial_image_assets, parse_image_placeholders
from services.memoir_images.prompting import MemoirImagePromptService
from services.memoir_images.provider import LiblibImageProvider
from services.memoir_images.settings import MemoirImageSettings
from services.memoir_images.storage import TencentCosStorageService
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _acquire_chapter_lock(user_id: str, stage: str, timeout: int = 120) -> bool:
    """Try to take the distributed per-chapter lock without blocking.

    The lock is the Redis key ``lock:chapter:{user_id}:{stage}``, written with
    ``SET ... NX EX`` so that concurrent workers do not write the same chapter.

    Args:
        user_id: Owner of the chapter.
        stage: Chapter stage/category the lock protects.
        timeout: Seconds after which Redis expires the key on its own.

    Returns:
        True when this call created the lock key, False when it already existed.
    """
    r = redis.from_url(os.getenv("REDIS_URL", "redis://localhost:6379/0"))
    lock_key = f"lock:chapter:{user_id}:{stage}"
    # redis-py returns None (not False) when NX prevents the write; normalise
    # to a real bool so the annotated return type is honoured.
    return bool(r.set(lock_key, "1", nx=True, ex=timeout))
def _release_chapter_lock(user_id: str, stage: str):
    """Release the distributed per-chapter lock by deleting its Redis key."""
    redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
    client = redis.from_url(redis_url)
    client.delete(f"lock:chapter:{user_id}:{stage}")
def _update_task_status_sync(user_id: str, task_id: str, status: str, result: Dict = None):
    """Synchronously update a task's status record in Redis (for use inside Celery tasks).

    The record is a JSON value stored in the hash ``task:user:{user_id}:tasks``
    under ``task_id``. Any error is logged and swallowed, so a status update
    never fails the calling task.

    Args:
        user_id: User whose task hash is updated.
        task_id: Field name inside the hash.
        status: New status string.
        result: Optional payload stored under ``"result"`` when not None.
    """
    try:
        client = redis.from_url(
            os.getenv("REDIS_URL", "redis://localhost:6379/0"),
            decode_responses=True,
        )
        hash_key = f"task:user:{user_id}:tasks"
        # Start from the stored record when there is one, else a fresh stub.
        stored = client.hget(hash_key, task_id)
        record = json.loads(stored) if stored else {"task_id": task_id}
        record["status"] = status
        record["updated_at"] = datetime.now(timezone.utc).isoformat()
        if result is not None:
            record["result"] = result
        client.hset(hash_key, task_id, json.dumps(record))
        # The whole hash expires after one hour.
        client.expire(hash_key, 3600)
        logger.info(f"任务状态已更新: task_id={task_id}, status={status}")
    except Exception as e:
        logger.error(f"更新任务状态失败: {e}")
def _merge_chapter_image_assets(
existing_images: list[dict] | None,
placeholders: list[dict],
provider: str,
style: str,
size: str,
now_iso: str,
) -> list[dict]:
existing_by_placeholder = {
item.get("placeholder"): dict(item)
for item in (existing_images or [])
if item.get("placeholder")
}
merged_assets: list[dict] = []
for item in placeholders:
existing = existing_by_placeholder.get(item["placeholder"])
if existing:
merged_item = dict(existing)
merged_item["index"] = item["index"]
merged_item["placeholder"] = item["placeholder"]
merged_item["description"] = item["description"]
merged_item["provider"] = merged_item.get("provider") or provider
merged_item["style"] = merged_item.get("style") or style
merged_item["size"] = merged_item.get("size") or size
merged_item["created_at"] = merged_item.get("created_at") or now_iso
merged_item["updated_at"] = merged_item.get("updated_at") or now_iso
if merged_item.get("status") == "completed" and not merged_item.get("url"):
merged_item["status"] = "failed"
merged_item["error"] = merged_item.get("error") or "missing image url"
else:
merged_item = build_initial_image_assets(
placeholders=[item],
provider=provider,
style=style,
size=size,
now_iso=now_iso,
)[0]
merged_assets.append(merged_item)
return merged_assets
def chapter_has_images_to_generate(images: list[dict] | None) -> bool:
return any(item.get("status") in {"pending", "failed"} for item in (images or []))
def initialize_chapter_images(chapter) -> list[dict]:
    """Parse IMAGE placeholders from the chapter content and (re)build its image assets.

    When image generation is disabled in the settings, ``chapter.images`` is
    reset to an empty list. Otherwise the placeholders found in
    ``chapter.content`` are merged with the assets already on the chapter.

    Returns:
        The value of ``chapter.images`` after assignment.
    """
    cfg = MemoirImageSettings.from_env()
    if cfg.enabled:
        # The prompt service is only used for its category -> style table.
        style_source = MemoirImagePromptService(llm=None, settings=cfg)
        found = parse_image_placeholders(chapter.content, cfg.max_per_chapter)
        chosen_style = style_source.CATEGORY_STYLE_MAP.get(chapter.category, cfg.default_style)
        merged = _merge_chapter_image_assets(
            existing_images=chapter.images,
            placeholders=found,
            provider=cfg.provider,
            style=chosen_style,
            size=cfg.default_size,
            now_iso=datetime.now(timezone.utc).isoformat(),
        )
    else:
        merged = []
    chapter.images = merged
    return chapter.images
# Keyword lists (Chinese) per life stage. _detect_stage returns the first
# stage, in the order listed here, that has a keyword contained in the message.
STAGE_KEYWORDS = {
    "childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
    "education": ["上学", "学校", "老师", "同学", "教育", "大学"],
    "career": ["工作", "职业", "事业", "公司", "同事", "创业"],
    "family": ["伴侣", "孩子", "家庭", "家人", "结婚", "父母"],
    "belief": ["信念", "价值观", "座右铭", "坚持", "原则"],
}
# 5-stage -> default 8-category mapping (fallback when LLM classification fails).
_STAGE_TO_DEFAULT_CATEGORY = {
    "childhood": "childhood",
    "education": "education",
    "career": "career_early",
    "family": "family",
    "belief": "beliefs",
}
def _detect_stage(user_message: str, fallback_stage: str) -> str:
    """Detect which of the 5 stages a message belongs to (used for state tracking).

    Returns the first stage in ``STAGE_KEYWORDS`` with a keyword contained in
    the lower-cased message, or ``fallback_stage`` when none matches.
    """
    lowered = user_message.lower()
    return next(
        (
            stage_name
            for stage_name, stage_words in STAGE_KEYWORDS.items()
            if any(keyword in lowered for keyword in stage_words)
        ),
        fallback_stage,
    )
def _classify_chapter_category(text: str, fallback_stage: str, llm=None) -> str | None:
    """Classify content into one of the 8 chapter categories.

    The LLM is tried first. If it is unavailable, raises, or answers with an
    unknown label, the 5-stage keyword match is mapped to its default category.

    Returns:
        The category name, or None when the LLM answers "none" (content has
        no memoir value).
    """
    if llm:
        try:
            reply = llm.invoke(get_chapter_classification_prompt(text))
            label = reply.content.strip().lower()
            if label == "none":
                logger.info(f"LLM 判定内容无回忆录价值,跳过: {text[:80]}...")
                return None
            if label in CHAPTER_CATEGORIES:
                return label
        except Exception as e:
            logger.warning(f"LLM 章节分类失败: {e}")

    # Keyword fallback: detected stage -> fallback stage -> "childhood".
    keyword_stage = _detect_stage(text, fallback_stage)
    last_resort = _STAGE_TO_DEFAULT_CATEGORY.get(fallback_stage, "childhood")
    return _STAGE_TO_DEFAULT_CATEGORY.get(keyword_stage, last_resort)
def _coerce_state(model: MemoirState) -> MemoirStateSchema:
    """Convert the database model into the validated state schema.

    A falsy ``stage_order`` and a non-dict ``slots`` are replaced by the
    values from ``default_state()``; a falsy ``covered_stages`` becomes ``[]``.
    """
    stage_order = model.stage_order
    if not stage_order:
        stage_order = default_state().stage_order
    current_stage = model.current_stage
    covered_stages = model.covered_stages or []
    if isinstance(model.slots, dict):
        slots = model.slots
    else:
        slots = default_state().slots
    payload = {
        "stage_order": stage_order,
        "current_stage": current_stage,
        "covered_stages": covered_stages,
        "slots": slots,
    }
    return MemoirStateSchema.model_validate(payload)
def _get_or_create_state_sync(user_id: str, db: Session) -> MemoirStateSchema:
    """Synchronously fetch the user's memoir state, creating it from defaults when absent.

    A newly created row is committed and refreshed before being returned.
    """
    lookup = select(MemoirState).where(MemoirState.user_id == user_id)
    found = db.execute(lookup).scalar_one_or_none()
    if found:
        return _coerce_state(found)

    template = default_state()
    created = MemoirState(
        id=str(uuid.uuid4()),
        user_id=user_id,
        stage_order=template.stage_order,
        current_stage=template.current_stage,
        covered_stages=template.covered_stages,
        # Slot models are dumped to plain dicts before being stored.
        slots={
            stage_name: {slot_name: slot.model_dump() for slot_name, slot in stage_slots.items()}
            for stage_name, stage_slots in template.slots.items()
        },
    )
    db.add(created)
    db.commit()
    db.refresh(created)
    return _coerce_state(created)
def _update_slot_sync(
    user_id: str,
    stage: str,
    slot_name: str,
    snippet: str,
    segment_ids: List[str],
    db: Session,
) -> MemoirStateSchema:
    """Synchronously write one slot of the user's memoir state and commit.

    The slot's snippet is replaced; its ``segment_ids`` are merged with the
    ones already stored (first-seen order, no duplicates). The state row is
    created first when the user has none. ``current_stage`` is only set when
    it is still empty.

    Args:
        user_id: Owner of the state row.
        stage: Stage the slot belongs to.
        slot_name: Name of the slot inside the stage.
        snippet: New snippet text for the slot.
        segment_ids: Segment IDs that contributed to the snippet.
        db: Open SQLAlchemy session; this function commits on it.

    Returns:
        The refreshed state as a schema object.
    """
    import copy  # local import keeps this block self-contained

    stmt = select(MemoirState).where(MemoirState.user_id == user_id)
    state = db.execute(stmt).scalar_one_or_none()
    if not state:
        _get_or_create_state_sync(user_id, db)
        state = db.execute(stmt).scalar_one()

    # Work on a deep copy and reassign the whole column. Mutating the loaded
    # JSON value in place and assigning the same object back can leave the
    # ORM unable to see a difference, so the update may never be written.
    slots: Dict[str, Dict] = copy.deepcopy(state.slots) if isinstance(state.slots, dict) else {}
    stage_slots = slots.get(stage) or {}
    existing = stage_slots.get(slot_name) or {}
    # dict.fromkeys de-duplicates while keeping order (a set would scramble it).
    merged_segment_ids = list(
        dict.fromkeys([*(existing.get("segment_ids") or []), *segment_ids])
    )
    stage_slots[slot_name] = SlotData(snippet=snippet, segment_ids=merged_segment_ids).model_dump()
    slots[stage] = stage_slots
    state.slots = slots
    state.current_stage = state.current_stage or stage
    db.commit()
    db.refresh(state)
    return _coerce_state(state)
@shared_task(bind=True, max_retries=3, default_retry_delay=60)
def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
    """Celery task that turns transcript segments into memoir chapters.

    Per segment the task (1) extracts slot snippets for the 5-stage state
    tracker and (2) classifies the text into one of the 8 chapter categories.
    Segments are then grouped by category; for each category the active
    chapter is extended (or a new one created), its image assets are
    initialised and the user's book is flagged as updated. Everything is
    committed in one transaction, after which image-generation tasks are
    dispatched for chapters that need them.

    Args:
        user_id: User ID.
        segment_ids: IDs of the segments to process.

    Returns:
        ``{"status": "no_segments"}`` when none of the IDs exist, otherwise
        ``{"status": "success", "processed": <number of segments>}``.

    Raises:
        celery.exceptions.Retry: On chapter-lock contention (10 s countdown)
            and on any unexpected error (default retry delay).
    """
    # Local import keeps this block self-contained.
    from celery.exceptions import Retry

    task_id = self.request.id
    logger.info(f"开始处理回忆录段落: user_id={user_id}, task_id={task_id}, segments={len(segment_ids)}")
    # Mark the task as running.
    _update_task_status_sync(user_id, task_id, "running")
    try:
        db = SessionLocal()
        # Categories whose chapter lock is currently held by this run.
        held_locks: List[str] = []
        try:
            chapters_to_enqueue: set[str] = set()
            # Load the segments.
            stmt = select(Segment).where(Segment.id.in_(segment_ids))
            result = db.execute(stmt)
            segments = result.scalars().all()
            if not segments:
                logger.warning(f"未找到段落: {segment_ids}")
                # Close out the status record; otherwise it would stay
                # "running" until the Redis key expires.
                _update_task_status_sync(user_id, task_id, "success", {"processed": 0})
                return {"status": "no_segments"}

            # Load the user's state and profile.
            state = _get_or_create_state_sync(user_id, db)
            llm = llm_service.get_llm()
            user_obj = db.get(User, user_id)
            user_profile = ""
            user_birth_year = None
            if user_obj:
                user_birth_year = user_obj.birth_year
                user_profile = format_user_profile_context(
                    birth_year=user_obj.birth_year,
                    birth_place=user_obj.birth_place,
                    grew_up_place=user_obj.grew_up_place,
                    occupation=user_obj.occupation,
                )

            # Two passes per segment:
            #   1) 5-stage state tracking (slots)
            #   2) 8-category chapter classification (chapter creation)
            category_to_segments: Dict[str, List[Segment]] = {}
            for segment in segments:
                # A missing transcript is treated as empty text instead of crashing.
                text = segment.transcript_text or ""
                detected_stage = _detect_stage(text, state.current_stage)

                # Extract slots (5-stage state tracking).
                extracted_slots = {}
                if llm:
                    try:
                        prompt = get_state_extraction_prompt(
                            user_message=text,
                            current_stage=state.current_stage,
                            stage_slots=state.slots.get(detected_stage, {}),
                        )
                        response = llm.invoke(prompt)
                        content = response.content.strip()
                        parsed = json.loads(content)
                        # Keep the keyword-based stage when the LLM returns null/empty.
                        detected_stage = parsed.get("detected_stage") or detected_stage
                        extracted_slots = parsed.get("slots", {}) or {}
                    except Exception as e:
                        logger.warning(f"LLM 解析失败: {e}")
                if not isinstance(extracted_slots, dict):
                    # Malformed LLM output: ignore rather than fail the task.
                    extracted_slots = {}

                for slot_name, snippet in extracted_slots.items():
                    state = _update_slot_sync(
                        user_id=user_id,
                        stage=detected_stage,
                        slot_name=slot_name,
                        snippet=snippet,
                        segment_ids=[segment.id],
                        db=db,
                    )

                # 8-category chapter classification.
                chapter_category = _classify_chapter_category(text, detected_stage, llm)
                if chapter_category is None:
                    logger.info(f"段落无回忆录价值,跳过: segment_id={segment.id}")
                    continue
                category_to_segments.setdefault(chapter_category, []).append(segment)

            try:
                # Generate chapter content per category. From here on ``stage``
                # holds an 8-category name.
                for stage, stage_segments in category_to_segments.items():
                    if not _acquire_chapter_lock(user_id, stage):
                        logger.warning(f"章节锁竞争: user={user_id}, stage={stage}, 延迟重试")
                        raise self.retry(countdown=10)
                    held_locks.append(stage)

                    segment_texts = [seg.transcript_text or "" for seg in stage_segments]
                    combined_text = "\n\n".join(segment_texts)
                    source_ids = [seg.id for seg in stage_segments]

                    # Find the active chapter (cleared chapters are not
                    # updated; a new chapter is created instead).
                    stmt_chapter = select(Chapter).where(
                        Chapter.user_id == user_id,
                        Chapter.category == stage,
                        Chapter.is_active == True,  # noqa: E712 - SQL expression
                    )
                    result_chapter = db.execute(stmt_chapter)
                    chapter = result_chapter.scalar_one_or_none()

                    # Collect the slot snippets.
                    # NOTE(review): slots are written under 5-stage names above
                    # but looked up by category name here - confirm the keys
                    # are meant to line up.
                    slot_snippets = {
                        key: value.snippet
                        for key, value in (state.slots.get(stage, {}) or {}).items()
                        if value.snippet
                    }

                    # Title and content.
                    title = chapter.title if chapter else f"{stage} 回忆"
                    existing_content = chapter.content if chapter else ""
                    # Default (no LLM, or LLM failure): append the raw
                    # transcripts so existing content is never replaced.
                    if existing_content:
                        narrative = f"{existing_content}\n\n{combined_text}"
                    else:
                        narrative = combined_text
                    if llm:
                        try:
                            if not chapter:
                                title_prompt = get_creative_title_prompt(
                                    stage=stage,
                                    emotion="neutral",
                                    slots=slot_snippets,
                                    user_profile=user_profile,
                                    birth_year=user_birth_year,
                                )
                                title_response = llm.invoke(title_prompt)
                                title = title_response.content.strip().strip('"')
                            narrative_prompt = get_narrative_prompt(
                                stage=stage,
                                slots=slot_snippets,
                                new_content=combined_text,
                                existing_content=existing_content,
                                user_profile=user_profile,
                                birth_year=user_birth_year,
                            )
                            narrative_response = llm.invoke(narrative_prompt)
                            new_narrative = narrative_response.content.strip()
                            # Append rather than replace.
                            if existing_content:
                                narrative = f"{existing_content}\n\n{new_narrative}"
                            else:
                                narrative = new_narrative
                        except Exception as e:
                            # ``narrative`` still holds the append-mode default.
                            logger.warning(f"LLM 生成失败: {e}")

                    # Safety check: the new content must not be much shorter
                    # than the old content.
                    if existing_content and len(narrative) < len(existing_content) * 0.8:
                        logger.warning(
                            f"内容长度异常: existing={len(existing_content)}, "
                            f"new={len(narrative)}, stage={stage}. 回退为追加模式"
                        )
                        narrative = f"{existing_content}\n\n{combined_text}"

                    # Update or create the chapter.
                    if chapter:
                        chapter.content = narrative
                        chapter.title = title
                        chapter.is_new = True
                        chapter.source_segments = list({*(chapter.source_segments or []), *source_ids})
                    else:
                        # Ordering index derived from the category.
                        calculated_order_index = STAGE_TO_ORDER.get(stage, 999)
                        chapter = Chapter(
                            id=str(uuid.uuid4()),
                            user_id=user_id,
                            title=title,
                            content=narrative,
                            order_index=calculated_order_index,
                            status="completed",
                            category=stage,
                            images=[],
                            is_new=True,
                            source_segments=source_ids,
                        )
                        db.add(chapter)
                    db.flush()
                    initialize_chapter_images(chapter)
                    if chapter_has_images_to_generate(chapter.images):
                        chapters_to_enqueue.add(chapter.id)

                    # Update the Book. The query is ordered and may match
                    # several rows, so take the most recently updated one
                    # (scalar_one_or_none() would raise on multiple rows).
                    stmt_book = select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc())
                    book = db.execute(stmt_book).scalars().first()
                    if not book:
                        book = Book(
                            id=str(uuid.uuid4()),
                            user_id=user_id,
                            title="我的回忆录",
                            total_pages=0,
                            total_words=0,
                            cover_image_url=None,
                        )
                        db.add(book)
                    book.has_update = True
                    book.last_update_chapter_id = chapter.id

                # Mark the segments as processed.
                for seg in segments:
                    seg.processed = True
                db.commit()
            finally:
                # Locks are held until after the commit (or the failure), so
                # another worker cannot read the chapter before this write is
                # visible. Release is best-effort: the key expires on its own,
                # and an error here must not trigger a retry of committed work.
                for locked_stage in held_locks:
                    try:
                        _release_chapter_lock(user_id, locked_stage)
                    except Exception as exc:
                        logger.warning(
                            f"章节锁释放失败: user={user_id}, stage={locked_stage}, error={exc}"
                        )

            # Dispatch image generation only after the commit, so the worker
            # sees the stored chapter.
            for chapter_id in sorted(chapters_to_enqueue):
                try:
                    generate_chapter_images.delay(chapter_id)
                except Exception as exc:
                    logger.warning(f"补图任务派发失败: chapter={chapter_id}, error={exc}")
            logger.info(f"回忆录处理完成: user_id={user_id}, task_id={task_id}")
            # Mark the task as successful.
            _update_task_status_sync(user_id, task_id, "success", {"processed": len(segments)})
            return {"status": "success", "processed": len(segments)}
        finally:
            db.close()
    except Retry:
        # The lock-contention retry is already scheduled; let it propagate
        # instead of recording a failure and scheduling a second retry.
        raise
    except Exception as e:
        logger.error(f"回忆录处理失败: {e}")
        # Mark the task as failed.
        _update_task_status_sync(user_id, task_id, "failure", {"error": str(e)})
        # Retry.
        raise self.retry(exc=e)
@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def generate_chapter_content(self, user_id: str, stage: str, new_content: str):
    """Standalone task that generates chapter content (used for real-time updates).

    The new material is appended to the active chapter of ``stage``, or a new
    chapter is created when none is active. Any error triggers a Celery retry.

    Args:
        user_id: User ID.
        stage: Stage.
        new_content: New content.

    Returns:
        ``{"status": "success"}`` after the commit.
    """
    logger.info(f"生成章节内容: user_id={user_id}, stage={stage}")
    try:
        db = SessionLocal()
        try:
            model = llm_service.get_llm()
            # Find the active chapter (cleared chapters are not updated; a
            # new chapter is created instead).
            active_query = select(Chapter).where(
                Chapter.user_id == user_id,
                Chapter.category == stage,
                Chapter.is_active == True,
            )
            chapter = db.execute(active_query).scalar_one_or_none()
            existing_content = chapter.content if chapter else ""

            # Text to add: the LLM narrative when available, else the raw input.
            if model:
                reply = model.invoke(
                    get_narrative_prompt(
                        stage=stage,
                        slots={},
                        new_content=new_content,
                        existing_content=existing_content,
                    )
                )
                addition = reply.content.strip()
            else:
                addition = new_content

            # Append rather than replace.
            if existing_content:
                narrative = f"{existing_content}\n\n{addition}"
            else:
                narrative = addition

            # Safety check: the new content must not be much shorter than the old.
            if existing_content and len(narrative) < len(existing_content) * 0.8:
                logger.warning(
                    f"内容长度异常: existing={len(existing_content)}, "
                    f"new={len(narrative)}, stage={stage}. 回退为追加模式"
                )
                narrative = f"{existing_content}\n\n{new_content}"

            if not chapter:
                # Ordering index derived from the stage.
                position = STAGE_TO_ORDER.get(stage, 999)
                chapter = Chapter(
                    id=str(uuid.uuid4()),
                    user_id=user_id,
                    title=f"{stage} 回忆",
                    content=narrative,
                    order_index=position,
                    status="completed",
                    category=stage,
                    images=[],
                    is_new=True,
                    source_segments=[],
                )
                db.add(chapter)
            else:
                chapter.content = narrative
                chapter.is_new = True
            db.commit()
            return {"status": "success"}
        finally:
            db.close()
    except Exception as e:
        logger.error(f"章节生成失败: {e}")
        raise self.retry(exc=e)
def build_cos_key(user_id: str, chapter_id: str, index: int, prompt: str) -> str:
    """Build the storage key for a generated chapter image.

    The key has the form ``memoirs/<user>/<chapter>/<index>-<hash>.png``,
    where ``<hash>`` is the first 10 hex characters of the SHA-1 of the
    UTF-8 encoded prompt.
    """
    prompt_bytes = prompt.encode("utf-8")
    digest_hex = hashlib.sha1(prompt_bytes).hexdigest()
    short_hash = digest_hex[:10]
    return f"memoirs/{user_id}/{chapter_id}/{index}-{short_hash}.png"
@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def generate_chapter_images(self, chapter_id: str):
    """Async task that generates images for a chapter's pending or failed image assets.

    Each eligible asset is first saved as "processing", then generated,
    downloaded and uploaded, and finally saved as "completed" (with its URL)
    or "failed" (with the error text). One failing asset does not stop the
    others.

    Args:
        chapter_id: ID of the chapter whose ``images`` list is processed.

    Returns:
        ``{"status": "no_images"}`` when the chapter is missing or has no
        image assets, otherwise ``{"status": "success"}``.
    """
    db = SessionLocal()
    try:
        chapter = db.get(Chapter, chapter_id)
        if not chapter or not chapter.images:
            return {"status": "no_images"}
        settings = MemoirImageSettings.from_env()
        prompt_service = MemoirImagePromptService(llm_service.get_llm(), settings)
        provider = LiblibImageProvider(template_uuid=settings.liblib_template_uuid)
        storage = TencentCosStorageService.from_env()
        # Working copy; the stored JSON value is never mutated in place.
        images = [dict(item) for item in (chapter.images or [])]
        for index, item in enumerate(images):
            # Only pending/failed assets are (re)generated; this also skips
            # completed and in-progress ones.
            if item.get("status") not in {"pending", "failed"}:
                continue
            current_item = dict(item)
            current_item["status"] = "processing"
            current_item["updated_at"] = datetime.now(timezone.utc).isoformat()
            images[index] = current_item
            # Assign a fresh list each time: reassigning the same mutated list
            # object may not register as a change, depending on whether the
            # session expires attributes on commit.
            chapter.images = [dict(entry) for entry in images]
            db.commit()
            try:
                # Short excerpt (first 5 lines, max 200 chars) as prompt context.
                context_lines = (chapter.content or "").split("\n")
                context_excerpt = " ".join(context_lines[:5])[:200]
                prompt_data = prompt_service.build_prompt(
                    chapter_title=chapter.title,
                    chapter_category=chapter.category or "",
                    description=item.get("description", ""),
                    context_excerpt=context_excerpt,
                )
                job = provider.submit_generation(
                    prompt=prompt_data["prompt"],
                    size=prompt_data["size"],
                    style=prompt_data["style"],
                )
                if job["status"] != "completed":
                    job = provider.poll_until_complete(
                        job,
                        poll_interval_seconds=settings.poll_interval_seconds,
                        max_attempts=settings.max_attempts,
                    )
                image_bytes = provider.download_image(job)
                key = build_cos_key(chapter.user_id, chapter.id, current_item["index"], prompt_data["prompt"])
                current_item["url"] = storage.upload_bytes(image_bytes, key, "image/png")
                current_item["prompt"] = prompt_data["prompt"]
                current_item["style"] = prompt_data["style"]
                current_item["size"] = prompt_data["size"]
                current_item["status"] = "completed"
                current_item["error"] = None
            except Exception as exc:
                # Record the failure on the asset and continue with the next one.
                current_item["status"] = "failed"
                current_item["error"] = str(exc)
                logger.warning(f"图片生成失败: chapter={chapter_id}, index={current_item.get('index')}, error={exc}")
            current_item["updated_at"] = datetime.now(timezone.utc).isoformat()
            images[index] = current_item
            chapter.images = [dict(entry) for entry in images]
            db.commit()
        return {"status": "success"}
    finally:
        db.close()