fix: harden memoir image generation flow

This commit is contained in:
Kevin
2026-03-11 11:26:42 +08:00
parent a76cf8da18
commit 00092d34c9
14 changed files with 1162 additions and 69 deletions

View File

@@ -5,11 +5,13 @@ import json
import logging
import os
import uuid
from io import BytesIO
from typing import Dict, List
from datetime import datetime, timezone
import redis
from celery import shared_task
from PIL import Image
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -31,31 +33,64 @@ import hashlib
from services.memoir_images.parser import build_initial_image_assets, parse_image_placeholders
from services.memoir_images.prompting import MemoirImagePromptService
from services.memoir_images.provider import LiblibImageProvider
from services.memoir_images.schema import (
completed_image_assets,
IMAGE_STATUS_COMPLETED,
IMAGE_STATUS_FAILED,
IMAGE_STATUS_PENDING,
IMAGE_STATUS_PROCESSING,
normalize_image_assets,
)
from services.memoir_images.settings import MemoirImageSettings
from services.memoir_images.storage import TencentCosStorageService
logger = logging.getLogger(__name__)
_REDIS_CLIENTS: dict[bool, redis.Redis] = {}
def _get_redis_client(*, decode_responses: bool = False) -> redis.Redis:
client = _REDIS_CLIENTS.get(decode_responses)
if client is None:
client = redis.from_url(
os.getenv("REDIS_URL", "redis://localhost:6379/0"),
decode_responses=decode_responses,
)
_REDIS_CLIENTS[decode_responses] = client
return client
def _acquire_chapter_lock(user_id: str, stage: str, timeout: int = 120) -> bool:
"""获取章节分布式锁,防止并发写入同一章节"""
r = redis.from_url(os.getenv("REDIS_URL", "redis://localhost:6379/0"))
r = _get_redis_client()
lock_key = f"lock:chapter:{user_id}:{stage}"
return r.set(lock_key, "1", nx=True, ex=timeout)
def _release_chapter_lock(user_id: str, stage: str):
"""释放章节分布式锁"""
r = redis.from_url(os.getenv("REDIS_URL", "redis://localhost:6379/0"))
r = _get_redis_client()
lock_key = f"lock:chapter:{user_id}:{stage}"
r.delete(lock_key)
def _acquire_chapter_image_lock(chapter_id: str, timeout: int = 600) -> bool:
"""获取章节补图分布式锁,避免同一章节重复补图。"""
r = _get_redis_client()
lock_key = f"lock:chapter-images:{chapter_id}"
return r.set(lock_key, "1", nx=True, ex=timeout)
def _release_chapter_image_lock(chapter_id: str):
"""释放章节补图分布式锁。"""
r = _get_redis_client()
lock_key = f"lock:chapter-images:{chapter_id}"
r.delete(lock_key)
def _update_task_status_sync(user_id: str, task_id: str, status: str, result: Dict = None):
"""同步更新任务状态(在 Celery 任务中使用)"""
try:
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
r = redis.from_url(redis_url, decode_responses=True)
r = _get_redis_client(decode_responses=True)
key = f"task:user:{user_id}:tasks"
@@ -87,9 +122,10 @@ def _merge_chapter_image_assets(
size: str,
now_iso: str,
) -> list[dict]:
normalized_existing_images = normalize_image_assets(existing_images)
existing_by_placeholder = {
item.get("placeholder"): dict(item)
for item in (existing_images or [])
for item in normalized_existing_images
if item.get("placeholder")
}
merged_assets: list[dict] = []
@@ -106,10 +142,10 @@ def _merge_chapter_image_assets(
merged_item["size"] = merged_item.get("size") or size
merged_item["created_at"] = merged_item.get("created_at") or now_iso
merged_item["updated_at"] = merged_item.get("updated_at") or now_iso
if merged_item.get("status") == "completed" and not (
if merged_item.get("status") == IMAGE_STATUS_COMPLETED and not (
merged_item.get("storage_key") or merged_item.get("url")
):
merged_item["status"] = "failed"
merged_item["status"] = IMAGE_STATUS_FAILED
merged_item["error"] = merged_item.get("error") or "missing image url"
else:
merged_item = build_initial_image_assets(
@@ -125,14 +161,17 @@ def _merge_chapter_image_assets(
def chapter_has_images_to_generate(images: list[dict] | None) -> bool:
return any(item.get("status") in {"pending", "failed"} for item in (images or []))
return any(
item.get("status") in {IMAGE_STATUS_PENDING, IMAGE_STATUS_FAILED}
for item in normalize_image_assets(images)
)
def initialize_chapter_images(chapter) -> list[dict]:
"""Parse IMAGE placeholders from chapter content and build pending image assets."""
settings = MemoirImageSettings.from_env()
if not settings.enabled:
chapter.images = []
chapter.images = completed_image_assets(chapter.images)
logger.info(f"章节图片初始化跳过: chapter={chapter.id}, enabled=false")
return chapter.images
@@ -157,6 +196,19 @@ def initialize_chapter_images(chapter) -> list[dict]:
return chapter.images
def _normalize_image_bytes_for_storage(image_bytes: bytes) -> bytes:
with Image.open(BytesIO(image_bytes)) as image:
output = BytesIO()
if image.mode in {"RGBA", "LA"}:
normalized = image
elif image.mode == "P":
normalized = image.convert("RGBA")
else:
normalized = image.convert("RGB")
normalized.save(output, format="PNG")
return output.getvalue()
STAGE_KEYWORDS = {
"childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
"education": ["上学", "学校", "老师", "同学", "教育", "大学"],
@@ -304,6 +356,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
# 获取用户状态和资料
state = _get_or_create_state_sync(user_id, db)
llm = llm_service.get_llm()
image_settings = MemoirImageSettings.from_env()
user_obj = db.get(User, user_id)
user_profile = ""
@@ -361,19 +414,19 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
category_to_segments.setdefault(chapter_category, []).append(segment)
# 按 8 分类生成章节内容
for stage, stage_segments in category_to_segments.items():
if not _acquire_chapter_lock(user_id, stage):
logger.warning(f"章节锁竞争: user={user_id}, stage={stage}, 延迟重试")
for chapter_category, category_segments in category_to_segments.items():
if not _acquire_chapter_lock(user_id, chapter_category):
logger.warning(f"章节锁竞争: user={user_id}, category={chapter_category}, 延迟重试")
raise self.retry(countdown=10)
try:
segment_texts = [seg.transcript_text for seg in stage_segments]
segment_texts = [seg.transcript_text for seg in category_segments]
combined_text = "\n\n".join(segment_texts)
source_ids = [seg.id for seg in stage_segments]
source_ids = [seg.id for seg in category_segments]
# 查找 active 章节(被清除的章节不继续更新,而是创建新的)
stmt_chapter = select(Chapter).where(
Chapter.user_id == user_id,
Chapter.category == stage,
Chapter.category == chapter_category,
Chapter.is_active == True,
)
result_chapter = db.execute(stmt_chapter)
@@ -382,12 +435,12 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
# 获取 slot snippets
slot_snippets = {
key: value.snippet
for key, value in (state.slots.get(stage, {}) or {}).items()
for key, value in (state.slots.get(chapter_category, {}) or {}).items()
if value.snippet
}
# 生成标题和内容
title = chapter.title if chapter else f"{stage} 回忆"
title = chapter.title if chapter else f"{chapter_category} 回忆"
existing_content = chapter.content if chapter else ""
narrative = combined_text
@@ -395,7 +448,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
try:
if not chapter:
title_prompt = get_creative_title_prompt(
stage=stage,
stage=chapter_category,
emotion="neutral",
slots=slot_snippets,
user_profile=user_profile,
@@ -405,7 +458,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
title = title_response.content.strip().strip('"')
narrative_prompt = get_narrative_prompt(
stage=stage,
stage=chapter_category,
slots=slot_snippets,
new_content=combined_text,
existing_content=existing_content,
@@ -429,7 +482,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
if existing_content and len(narrative) < len(existing_content) * 0.8:
logger.warning(
f"内容长度异常: existing={len(existing_content)}, "
f"new={len(narrative)}, stage={stage}. 回退为追加模式"
f"new={len(narrative)}, category={chapter_category}. 回退为追加模式"
)
narrative = f"{existing_content}\n\n{combined_text}"
@@ -441,7 +494,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
chapter.source_segments = list({*(chapter.source_segments or []), *source_ids})
else:
# 根据 stage 计算正确的排序索引
calculated_order_index = STAGE_TO_ORDER.get(stage, 999)
calculated_order_index = STAGE_TO_ORDER.get(chapter_category, 999)
chapter = Chapter(
id=str(uuid.uuid4()),
user_id=user_id,
@@ -449,7 +502,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
content=narrative,
order_index=calculated_order_index,
status="completed",
category=stage,
category=chapter_category,
images=[],
is_new=True,
source_segments=source_ids,
@@ -459,7 +512,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
db.flush()
initialize_chapter_images(chapter)
if chapter_has_images_to_generate(chapter.images):
if image_settings.enabled and chapter_has_images_to_generate(chapter.images):
chapters_to_enqueue.add(chapter.id)
# 更新 Book
@@ -479,7 +532,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
book.has_update = True
book.last_update_chapter_id = chapter.id
finally:
_release_chapter_lock(user_id, stage)
_release_chapter_lock(user_id, chapter_category)
# 标记段落为已处理
for seg in segments:
@@ -607,6 +660,8 @@ def build_cos_key(user_id: str, chapter_id: str, index: int, prompt: str) -> str
def generate_chapter_images(self, chapter_id: str):
"""Async task to generate images for a chapter's pending image assets."""
db = SessionLocal()
lock_acquired = False
provider = None
try:
chapter = db.get(Chapter, chapter_id)
if not chapter or not chapter.images:
@@ -614,26 +669,40 @@ def generate_chapter_images(self, chapter_id: str):
return {"status": "no_images"}
settings = MemoirImageSettings.from_env()
if not settings.enabled:
chapter.images = completed_image_assets(chapter.images)
db.commit()
logger.info("章节补图跳过: chapter=%s, reason=disabled", chapter_id)
return {"status": "disabled"}
lock_acquired = _acquire_chapter_image_lock(chapter_id)
if not lock_acquired:
logger.info("章节补图跳过: chapter=%s, reason=locked", chapter_id)
return {"status": "locked"}
prompt_service = MemoirImagePromptService(llm_service.get_llm(), settings)
provider = LiblibImageProvider(template_uuid=settings.liblib_template_uuid)
storage = TencentCosStorageService.from_env()
images = [dict(item) for item in (chapter.images or [])]
pending_count = sum(1 for item in images if item.get("status") in {"pending", "failed"})
images = normalize_image_assets(chapter.images)
pending_count = sum(
1 for item in images if item.get("status") in {IMAGE_STATUS_PENDING, IMAGE_STATUS_FAILED}
)
logger.info(
"章节补图开始: chapter=%s, total_images=%d, pending_images=%d",
chapter_id,
len(images),
pending_count,
)
failures: list[str] = []
for index, item in enumerate(images):
if item.get("status") == "completed" and (item.get("storage_key") or item.get("url")):
if item.get("status") == IMAGE_STATUS_COMPLETED and (item.get("storage_key") or item.get("url")):
continue
if item.get("status") not in {"pending", "failed"}:
if item.get("status") not in {IMAGE_STATUS_PENDING, IMAGE_STATUS_FAILED}:
continue
current_item = dict(item)
current_item["status"] = "processing"
current_item["status"] = IMAGE_STATUS_PROCESSING
current_item["updated_at"] = datetime.now(timezone.utc).isoformat()
images[index] = current_item
chapter.images = images
@@ -660,14 +729,14 @@ def generate_chapter_images(self, chapter_id: str):
poll_interval_seconds=settings.poll_interval_seconds,
max_attempts=settings.max_attempts,
)
image_bytes = provider.download_image(job)
image_bytes = _normalize_image_bytes_for_storage(provider.download_image(job))
key = build_cos_key(chapter.user_id, chapter.id, current_item["index"], prompt_data["prompt"])
current_item["storage_key"] = key
current_item["url"] = storage.upload_bytes(image_bytes, key, "image/png")
current_item["prompt"] = prompt_data["prompt"]
current_item["style"] = prompt_data["style"]
current_item["size"] = prompt_data["size"]
current_item["status"] = "completed"
current_item["status"] = IMAGE_STATUS_COMPLETED
current_item["error"] = None
logger.info(
"章节补图成功: chapter=%s, index=%s, url=%s",
@@ -676,8 +745,9 @@ def generate_chapter_images(self, chapter_id: str):
current_item["url"],
)
except Exception as exc:
current_item["status"] = "failed"
current_item["status"] = IMAGE_STATUS_FAILED
current_item["error"] = str(exc)
failures.append(f"index={current_item.get('index')}, error={exc}")
logger.warning(f"图片生成失败: chapter={chapter_id}, index={current_item.get('index')}, error={exc}")
current_item["updated_at"] = datetime.now(timezone.utc).isoformat()
@@ -685,6 +755,18 @@ def generate_chapter_images(self, chapter_id: str):
chapter.images = images
db.commit()
if failures:
raise RuntimeError(
f"章节补图存在失败项: chapter={chapter_id}, failures={'; '.join(failures)}"
)
return {"status": "success"}
except Exception as exc:
logger.error("章节补图任务失败: chapter=%s, error=%s", chapter_id, exc)
raise self.retry(exc=exc)
finally:
if provider:
provider.close()
if lock_acquired:
_release_chapter_image_lock(chapter_id)
db.close()