fix: harden memoir image generation flow
This commit is contained in:
@@ -5,11 +5,13 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Dict, List
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import redis
|
||||
from celery import shared_task
|
||||
from PIL import Image
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -31,31 +33,64 @@ import hashlib
|
||||
from services.memoir_images.parser import build_initial_image_assets, parse_image_placeholders
|
||||
from services.memoir_images.prompting import MemoirImagePromptService
|
||||
from services.memoir_images.provider import LiblibImageProvider
|
||||
from services.memoir_images.schema import (
|
||||
completed_image_assets,
|
||||
IMAGE_STATUS_COMPLETED,
|
||||
IMAGE_STATUS_FAILED,
|
||||
IMAGE_STATUS_PENDING,
|
||||
IMAGE_STATUS_PROCESSING,
|
||||
normalize_image_assets,
|
||||
)
|
||||
from services.memoir_images.settings import MemoirImageSettings
|
||||
from services.memoir_images.storage import TencentCosStorageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_REDIS_CLIENTS: dict[bool, redis.Redis] = {}
|
||||
|
||||
|
||||
def _get_redis_client(*, decode_responses: bool = False) -> redis.Redis:
|
||||
client = _REDIS_CLIENTS.get(decode_responses)
|
||||
if client is None:
|
||||
client = redis.from_url(
|
||||
os.getenv("REDIS_URL", "redis://localhost:6379/0"),
|
||||
decode_responses=decode_responses,
|
||||
)
|
||||
_REDIS_CLIENTS[decode_responses] = client
|
||||
return client
|
||||
|
||||
|
||||
def _acquire_chapter_lock(user_id: str, stage: str, timeout: int = 120) -> bool:
|
||||
"""获取章节分布式锁,防止并发写入同一章节"""
|
||||
r = redis.from_url(os.getenv("REDIS_URL", "redis://localhost:6379/0"))
|
||||
r = _get_redis_client()
|
||||
lock_key = f"lock:chapter:{user_id}:{stage}"
|
||||
return r.set(lock_key, "1", nx=True, ex=timeout)
|
||||
|
||||
|
||||
def _release_chapter_lock(user_id: str, stage: str):
|
||||
"""释放章节分布式锁"""
|
||||
r = redis.from_url(os.getenv("REDIS_URL", "redis://localhost:6379/0"))
|
||||
r = _get_redis_client()
|
||||
lock_key = f"lock:chapter:{user_id}:{stage}"
|
||||
r.delete(lock_key)
|
||||
|
||||
|
||||
def _acquire_chapter_image_lock(chapter_id: str, timeout: int = 600) -> bool:
|
||||
"""获取章节补图分布式锁,避免同一章节重复补图。"""
|
||||
r = _get_redis_client()
|
||||
lock_key = f"lock:chapter-images:{chapter_id}"
|
||||
return r.set(lock_key, "1", nx=True, ex=timeout)
|
||||
|
||||
|
||||
def _release_chapter_image_lock(chapter_id: str):
|
||||
"""释放章节补图分布式锁。"""
|
||||
r = _get_redis_client()
|
||||
lock_key = f"lock:chapter-images:{chapter_id}"
|
||||
r.delete(lock_key)
|
||||
|
||||
|
||||
def _update_task_status_sync(user_id: str, task_id: str, status: str, result: Dict = None):
|
||||
"""同步更新任务状态(在 Celery 任务中使用)"""
|
||||
try:
|
||||
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
||||
r = redis.from_url(redis_url, decode_responses=True)
|
||||
r = _get_redis_client(decode_responses=True)
|
||||
|
||||
key = f"task:user:{user_id}:tasks"
|
||||
|
||||
@@ -87,9 +122,10 @@ def _merge_chapter_image_assets(
|
||||
size: str,
|
||||
now_iso: str,
|
||||
) -> list[dict]:
|
||||
normalized_existing_images = normalize_image_assets(existing_images)
|
||||
existing_by_placeholder = {
|
||||
item.get("placeholder"): dict(item)
|
||||
for item in (existing_images or [])
|
||||
for item in normalized_existing_images
|
||||
if item.get("placeholder")
|
||||
}
|
||||
merged_assets: list[dict] = []
|
||||
@@ -106,10 +142,10 @@ def _merge_chapter_image_assets(
|
||||
merged_item["size"] = merged_item.get("size") or size
|
||||
merged_item["created_at"] = merged_item.get("created_at") or now_iso
|
||||
merged_item["updated_at"] = merged_item.get("updated_at") or now_iso
|
||||
if merged_item.get("status") == "completed" and not (
|
||||
if merged_item.get("status") == IMAGE_STATUS_COMPLETED and not (
|
||||
merged_item.get("storage_key") or merged_item.get("url")
|
||||
):
|
||||
merged_item["status"] = "failed"
|
||||
merged_item["status"] = IMAGE_STATUS_FAILED
|
||||
merged_item["error"] = merged_item.get("error") or "missing image url"
|
||||
else:
|
||||
merged_item = build_initial_image_assets(
|
||||
@@ -125,14 +161,17 @@ def _merge_chapter_image_assets(
|
||||
|
||||
|
||||
def chapter_has_images_to_generate(images: list[dict] | None) -> bool:
|
||||
return any(item.get("status") in {"pending", "failed"} for item in (images or []))
|
||||
return any(
|
||||
item.get("status") in {IMAGE_STATUS_PENDING, IMAGE_STATUS_FAILED}
|
||||
for item in normalize_image_assets(images)
|
||||
)
|
||||
|
||||
|
||||
def initialize_chapter_images(chapter) -> list[dict]:
|
||||
"""Parse IMAGE placeholders from chapter content and build pending image assets."""
|
||||
settings = MemoirImageSettings.from_env()
|
||||
if not settings.enabled:
|
||||
chapter.images = []
|
||||
chapter.images = completed_image_assets(chapter.images)
|
||||
logger.info(f"章节图片初始化跳过: chapter={chapter.id}, enabled=false")
|
||||
return chapter.images
|
||||
|
||||
@@ -157,6 +196,19 @@ def initialize_chapter_images(chapter) -> list[dict]:
|
||||
return chapter.images
|
||||
|
||||
|
||||
def _normalize_image_bytes_for_storage(image_bytes: bytes) -> bytes:
|
||||
with Image.open(BytesIO(image_bytes)) as image:
|
||||
output = BytesIO()
|
||||
if image.mode in {"RGBA", "LA"}:
|
||||
normalized = image
|
||||
elif image.mode == "P":
|
||||
normalized = image.convert("RGBA")
|
||||
else:
|
||||
normalized = image.convert("RGB")
|
||||
normalized.save(output, format="PNG")
|
||||
return output.getvalue()
|
||||
|
||||
|
||||
STAGE_KEYWORDS = {
|
||||
"childhood": ["童年", "小时候", "出生", "家乡", "小镇"],
|
||||
"education": ["上学", "学校", "老师", "同学", "教育", "大学"],
|
||||
@@ -304,6 +356,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
# 获取用户状态和资料
|
||||
state = _get_or_create_state_sync(user_id, db)
|
||||
llm = llm_service.get_llm()
|
||||
image_settings = MemoirImageSettings.from_env()
|
||||
|
||||
user_obj = db.get(User, user_id)
|
||||
user_profile = ""
|
||||
@@ -361,19 +414,19 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
category_to_segments.setdefault(chapter_category, []).append(segment)
|
||||
|
||||
# 按 8 分类生成章节内容
|
||||
for stage, stage_segments in category_to_segments.items():
|
||||
if not _acquire_chapter_lock(user_id, stage):
|
||||
logger.warning(f"章节锁竞争: user={user_id}, stage={stage}, 延迟重试")
|
||||
for chapter_category, category_segments in category_to_segments.items():
|
||||
if not _acquire_chapter_lock(user_id, chapter_category):
|
||||
logger.warning(f"章节锁竞争: user={user_id}, category={chapter_category}, 延迟重试")
|
||||
raise self.retry(countdown=10)
|
||||
try:
|
||||
segment_texts = [seg.transcript_text for seg in stage_segments]
|
||||
segment_texts = [seg.transcript_text for seg in category_segments]
|
||||
combined_text = "\n\n".join(segment_texts)
|
||||
source_ids = [seg.id for seg in stage_segments]
|
||||
source_ids = [seg.id for seg in category_segments]
|
||||
|
||||
# 查找 active 章节(被清除的章节不继续更新,而是创建新的)
|
||||
stmt_chapter = select(Chapter).where(
|
||||
Chapter.user_id == user_id,
|
||||
Chapter.category == stage,
|
||||
Chapter.category == chapter_category,
|
||||
Chapter.is_active == True,
|
||||
)
|
||||
result_chapter = db.execute(stmt_chapter)
|
||||
@@ -382,12 +435,12 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
# 获取 slot snippets
|
||||
slot_snippets = {
|
||||
key: value.snippet
|
||||
for key, value in (state.slots.get(stage, {}) or {}).items()
|
||||
for key, value in (state.slots.get(chapter_category, {}) or {}).items()
|
||||
if value.snippet
|
||||
}
|
||||
|
||||
# 生成标题和内容
|
||||
title = chapter.title if chapter else f"{stage} 回忆"
|
||||
title = chapter.title if chapter else f"{chapter_category} 回忆"
|
||||
existing_content = chapter.content if chapter else ""
|
||||
narrative = combined_text
|
||||
|
||||
@@ -395,7 +448,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
try:
|
||||
if not chapter:
|
||||
title_prompt = get_creative_title_prompt(
|
||||
stage=stage,
|
||||
stage=chapter_category,
|
||||
emotion="neutral",
|
||||
slots=slot_snippets,
|
||||
user_profile=user_profile,
|
||||
@@ -405,7 +458,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
title = title_response.content.strip().strip('"')
|
||||
|
||||
narrative_prompt = get_narrative_prompt(
|
||||
stage=stage,
|
||||
stage=chapter_category,
|
||||
slots=slot_snippets,
|
||||
new_content=combined_text,
|
||||
existing_content=existing_content,
|
||||
@@ -429,7 +482,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
if existing_content and len(narrative) < len(existing_content) * 0.8:
|
||||
logger.warning(
|
||||
f"内容长度异常: existing={len(existing_content)}, "
|
||||
f"new={len(narrative)}, stage={stage}. 回退为追加模式"
|
||||
f"new={len(narrative)}, category={chapter_category}. 回退为追加模式"
|
||||
)
|
||||
narrative = f"{existing_content}\n\n{combined_text}"
|
||||
|
||||
@@ -441,7 +494,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
chapter.source_segments = list({*(chapter.source_segments or []), *source_ids})
|
||||
else:
|
||||
# 根据 stage 计算正确的排序索引
|
||||
calculated_order_index = STAGE_TO_ORDER.get(stage, 999)
|
||||
calculated_order_index = STAGE_TO_ORDER.get(chapter_category, 999)
|
||||
chapter = Chapter(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
@@ -449,7 +502,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
content=narrative,
|
||||
order_index=calculated_order_index,
|
||||
status="completed",
|
||||
category=stage,
|
||||
category=chapter_category,
|
||||
images=[],
|
||||
is_new=True,
|
||||
source_segments=source_ids,
|
||||
@@ -459,7 +512,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
db.flush()
|
||||
|
||||
initialize_chapter_images(chapter)
|
||||
if chapter_has_images_to_generate(chapter.images):
|
||||
if image_settings.enabled and chapter_has_images_to_generate(chapter.images):
|
||||
chapters_to_enqueue.add(chapter.id)
|
||||
|
||||
# 更新 Book
|
||||
@@ -479,7 +532,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
book.has_update = True
|
||||
book.last_update_chapter_id = chapter.id
|
||||
finally:
|
||||
_release_chapter_lock(user_id, stage)
|
||||
_release_chapter_lock(user_id, chapter_category)
|
||||
|
||||
# 标记段落为已处理
|
||||
for seg in segments:
|
||||
@@ -607,6 +660,8 @@ def build_cos_key(user_id: str, chapter_id: str, index: int, prompt: str) -> str
|
||||
def generate_chapter_images(self, chapter_id: str):
|
||||
"""Async task to generate images for a chapter's pending image assets."""
|
||||
db = SessionLocal()
|
||||
lock_acquired = False
|
||||
provider = None
|
||||
try:
|
||||
chapter = db.get(Chapter, chapter_id)
|
||||
if not chapter or not chapter.images:
|
||||
@@ -614,26 +669,40 @@ def generate_chapter_images(self, chapter_id: str):
|
||||
return {"status": "no_images"}
|
||||
|
||||
settings = MemoirImageSettings.from_env()
|
||||
if not settings.enabled:
|
||||
chapter.images = completed_image_assets(chapter.images)
|
||||
db.commit()
|
||||
logger.info("章节补图跳过: chapter=%s, reason=disabled", chapter_id)
|
||||
return {"status": "disabled"}
|
||||
|
||||
lock_acquired = _acquire_chapter_image_lock(chapter_id)
|
||||
if not lock_acquired:
|
||||
logger.info("章节补图跳过: chapter=%s, reason=locked", chapter_id)
|
||||
return {"status": "locked"}
|
||||
|
||||
prompt_service = MemoirImagePromptService(llm_service.get_llm(), settings)
|
||||
provider = LiblibImageProvider(template_uuid=settings.liblib_template_uuid)
|
||||
storage = TencentCosStorageService.from_env()
|
||||
images = [dict(item) for item in (chapter.images or [])]
|
||||
pending_count = sum(1 for item in images if item.get("status") in {"pending", "failed"})
|
||||
images = normalize_image_assets(chapter.images)
|
||||
pending_count = sum(
|
||||
1 for item in images if item.get("status") in {IMAGE_STATUS_PENDING, IMAGE_STATUS_FAILED}
|
||||
)
|
||||
logger.info(
|
||||
"章节补图开始: chapter=%s, total_images=%d, pending_images=%d",
|
||||
chapter_id,
|
||||
len(images),
|
||||
pending_count,
|
||||
)
|
||||
failures: list[str] = []
|
||||
|
||||
for index, item in enumerate(images):
|
||||
if item.get("status") == "completed" and (item.get("storage_key") or item.get("url")):
|
||||
if item.get("status") == IMAGE_STATUS_COMPLETED and (item.get("storage_key") or item.get("url")):
|
||||
continue
|
||||
if item.get("status") not in {"pending", "failed"}:
|
||||
if item.get("status") not in {IMAGE_STATUS_PENDING, IMAGE_STATUS_FAILED}:
|
||||
continue
|
||||
|
||||
current_item = dict(item)
|
||||
current_item["status"] = "processing"
|
||||
current_item["status"] = IMAGE_STATUS_PROCESSING
|
||||
current_item["updated_at"] = datetime.now(timezone.utc).isoformat()
|
||||
images[index] = current_item
|
||||
chapter.images = images
|
||||
@@ -660,14 +729,14 @@ def generate_chapter_images(self, chapter_id: str):
|
||||
poll_interval_seconds=settings.poll_interval_seconds,
|
||||
max_attempts=settings.max_attempts,
|
||||
)
|
||||
image_bytes = provider.download_image(job)
|
||||
image_bytes = _normalize_image_bytes_for_storage(provider.download_image(job))
|
||||
key = build_cos_key(chapter.user_id, chapter.id, current_item["index"], prompt_data["prompt"])
|
||||
current_item["storage_key"] = key
|
||||
current_item["url"] = storage.upload_bytes(image_bytes, key, "image/png")
|
||||
current_item["prompt"] = prompt_data["prompt"]
|
||||
current_item["style"] = prompt_data["style"]
|
||||
current_item["size"] = prompt_data["size"]
|
||||
current_item["status"] = "completed"
|
||||
current_item["status"] = IMAGE_STATUS_COMPLETED
|
||||
current_item["error"] = None
|
||||
logger.info(
|
||||
"章节补图成功: chapter=%s, index=%s, url=%s",
|
||||
@@ -676,8 +745,9 @@ def generate_chapter_images(self, chapter_id: str):
|
||||
current_item["url"],
|
||||
)
|
||||
except Exception as exc:
|
||||
current_item["status"] = "failed"
|
||||
current_item["status"] = IMAGE_STATUS_FAILED
|
||||
current_item["error"] = str(exc)
|
||||
failures.append(f"index={current_item.get('index')}, error={exc}")
|
||||
logger.warning(f"图片生成失败: chapter={chapter_id}, index={current_item.get('index')}, error={exc}")
|
||||
|
||||
current_item["updated_at"] = datetime.now(timezone.utc).isoformat()
|
||||
@@ -685,6 +755,18 @@ def generate_chapter_images(self, chapter_id: str):
|
||||
chapter.images = images
|
||||
db.commit()
|
||||
|
||||
if failures:
|
||||
raise RuntimeError(
|
||||
f"章节补图存在失败项: chapter={chapter_id}, failures={'; '.join(failures)}"
|
||||
)
|
||||
|
||||
return {"status": "success"}
|
||||
except Exception as exc:
|
||||
logger.error("章节补图任务失败: chapter=%s, error=%s", chapter_id, exc)
|
||||
raise self.retry(exc=exc)
|
||||
finally:
|
||||
if provider:
|
||||
provider.close()
|
||||
if lock_acquired:
|
||||
_release_chapter_image_lock(chapter_id)
|
||||
db.close()
|
||||
|
||||
Reference in New Issue
Block a user