Files
life-echo/api/app/tasks/story_image_tasks.py
2026-03-23 13:54:41 +08:00

358 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Story 主插图生成 Celery 任务。
从 story_image_intents 原子 claim intent，生成图片，写入 assets，更新 intent。
不读取正文占位符。
"""
import hashlib
import uuid
from datetime import datetime, timedelta, timezone
from celery import shared_task
from PIL import Image
from sqlalchemy import and_, func, or_, select, update
from app.core.db import get_sync_db
from app.core.dependencies import get_image_generator
from app.core.logging import get_logger
from app.core.redis_lock import acquire_redis_lock, release_redis_lock
from app.features.asset.models import Asset
from app.features.memoir.asset_resolver import strip_asset_image_refs_from_markdown
from app.features.memoir.memoir_images.settings import MemoirImageSettings
from app.features.memoir.memoir_images.storage import TencentCosStorageService
from app.features.story.backfill import backfill_image_into_markdown
from app.features.story.models import Story, StoryImageIntent, StoryVersion
from app.ports.image_gen import TaskStatus
logger = get_logger(__name__)
# TTL (seconds) of the per-story Redis lock acquired by generate_story_image;
# it bounds how long a crashed worker can block other runs for the same story.
STORY_IMAGE_LOCK_TTL_SECONDS = 1800
# Age (seconds) after which a "processing" intent's claim is considered stale,
# making the intent claimable again by _story_image_claimable_clause.
STORY_IMAGE_CLAIM_TTL_SECONDS = 1800
def _enqueue_chapter_recompose_for_story(story_id: str) -> None:
    """Propagate a story body change (from image backfill) to its chapters.

    Best-effort and never raises. It works in two steps:

    1. Call ``memoir_repo.mark_chapters_dirty_for_story_sync`` and commit in
       a fresh session.
    2. Enqueue the ``recompose_chapters_for_story`` Celery task.

    If step 1 fails, step 2 is skipped. Every failure is only logged as a
    warning.
    """
    try:
        with get_sync_db() as session:
            from app.features.memoir import repo as memoir_repo

            memoir_repo.mark_chapters_dirty_for_story_sync(session, story_id)
            session.commit()
    except Exception as err:
        logger.warning(
            "mark_chapters_dirty_for_story_sync failed story=%s: %s", story_id, err
        )
    else:
        # Only schedule the recompose once the dirty flags are durably stored.
        try:
            from app.tasks.chapter_compose_tasks import recompose_chapters_for_story

            recompose_chapters_for_story.delay(story_id)
        except Exception as err:
            logger.warning(
                "recompose_chapters_for_story.delay failed story=%s: %s",
                story_id,
                err,
            )
def _build_story_image_cos_key(
user_id: str, story_id: str, intent_id: str, prompt: str
) -> str:
short_hash = hashlib.sha1(prompt.encode("utf-8")).hexdigest()[:10]
return f"stories/{user_id}/{story_id}/{intent_id}-{short_hash}.png"
def _normalize_image_bytes(image_bytes: bytes) -> bytes:
    """Decode ``image_bytes`` with Pillow and re-encode them as PNG.

    Mode handling:

    - Palette (``"P"``) images are converted to RGBA.
    - RGBA and LA images are written unchanged.
    - Every other mode is converted to RGB.
    """
    from io import BytesIO

    buffer = BytesIO()
    with Image.open(BytesIO(image_bytes)) as source:
        if source.mode == "P":
            target = source.convert("RGBA")
        elif source.mode not in ("RGBA", "LA"):
            target = source.convert("RGB")
        else:
            target = source
        target.save(buffer, format="PNG")
    return buffer.getvalue()
def _build_story_image_prompt(
    prompt_brief: str,
    story_title: str = "",
    story_stage: str | None = None,
    style_profile: str | None = None,
) -> str:
    """Build the image-generation prompt from an intent's ``prompt_brief``.

    The result is ``IMAGE_PLACEHOLDER_TEMPLATE`` followed by the stripped
    brief. When the brief is empty or blank, the suffix is the story title
    and stage concatenated, falling back to the literal "人生故事".
    ``style_profile`` is accepted for interface compatibility but is not used
    here.
    """
    from app.agents.memoir.prompts import IMAGE_PLACEHOLDER_TEMPLATE

    subject = prompt_brief.strip() if prompt_brief else ""
    if not subject:
        pieces = [piece for piece in (story_title, story_stage or "") if piece]
        subject = "".join(pieces) or "人生故事"
    return f"{IMAGE_PLACEHOLDER_TEMPLATE}{subject}"
def _story_image_claimable_clause(now: datetime):
    """Return the SQL filter for intents that may be claimed at ``now``.

    An intent is claimable when either:

    - its status is ``pending`` or ``failed``; or
    - its status is ``processing`` but it has no ``claimed_at``, or its
      claim is older than ``STORY_IMAGE_CLAIM_TTL_SECONDS`` (an abandoned
      claim).
    """
    stale_before = now - timedelta(seconds=STORY_IMAGE_CLAIM_TTL_SECONDS)
    is_open = StoryImageIntent.status.in_(["pending", "failed"])
    claim_expired = or_(
        StoryImageIntent.claimed_at.is_(None),
        StoryImageIntent.claimed_at < stale_before,
    )
    is_abandoned = and_(StoryImageIntent.status == "processing", claim_expired)
    return or_(is_open, is_abandoned)
def _claim_story_image_intent_sync(db, story_id: str, claim_token: str):
    """Atomically claim the newest claimable ``primary`` intent of a story.

    Steps:

    1. Pick the most recently updated (then created) ``primary`` intent that
       matches ``_story_image_claimable_clause``.
    2. Claim it with a conditional UPDATE that re-applies the same clause, so
       the claim only succeeds if the row is still claimable at UPDATE time.
       A successful claim sets status ``processing``, stores ``claim_token``
       and ``claimed_at``, clears ``error`` and increments ``attempt_count``.
    3. Load the intent joined with its ``Story`` and commit.

    Returns:
        The ``(StoryImageIntent, Story)`` row. Returns ``None`` when nothing
        is claimable, or when the UPDATE did not affect exactly one row; in
        that case the session is rolled back.
    """
    claimed_at = datetime.now(timezone.utc)

    pick_stmt = (
        select(StoryImageIntent.id)
        .where(
            StoryImageIntent.story_id == story_id,
            StoryImageIntent.intent_role == "primary",
            _story_image_claimable_clause(claimed_at),
        )
        .order_by(
            StoryImageIntent.updated_at.desc(), StoryImageIntent.created_at.desc()
        )
        .limit(1)
    )
    target_id = db.execute(pick_stmt).scalar_one_or_none()
    if not target_id:
        return None

    claim_stmt = (
        update(StoryImageIntent)
        .where(
            StoryImageIntent.id == target_id,
            # Re-check claimability so a concurrent claimer cannot be overwritten.
            _story_image_claimable_clause(claimed_at),
        )
        .values(
            status="processing",
            claim_token=claim_token,
            claimed_at=claimed_at,
            updated_at=claimed_at,
            error=None,
            attempt_count=func.coalesce(StoryImageIntent.attempt_count, 0) + 1,
        )
    )
    affected = db.execute(claim_stmt).rowcount or 0
    if affected != 1:
        db.rollback()
        return None

    fetch_stmt = (
        select(StoryImageIntent, Story)
        .join(Story, StoryImageIntent.story_id == Story.id)
        .where(StoryImageIntent.id == target_id)
    )
    claimed_row = db.execute(fetch_stmt).unique().first()
    db.commit()
    return claimed_row
def _mark_story_image_intent_skipped(intent_id, reason: str) -> None:
    """Mark a still-``processing`` intent as ``skipped`` and release its claim.

    ``reason`` is stored in the intent's ``error`` column. Does nothing if the
    intent no longer exists or is no longer ``processing``.
    """
    with get_sync_db() as db:
        intent_db = db.get(StoryImageIntent, intent_id)
        if intent_db and (intent_db.status or "").strip() == "processing":
            intent_db.status = "skipped"
            intent_db.error = reason
            intent_db.claim_token = None
            intent_db.claimed_at = None
            intent_db.updated_at = datetime.now(timezone.utc)
            db.commit()


def _record_story_image_failure(intent_id, claim_token: str, exc: BaseException) -> None:
    """Best-effort: mark the intent ``failed`` if this run still owns its claim.

    ``completed`` intents, and intents whose ``claim_token`` no longer matches
    (i.e. re-claimed by another run), are left untouched. Any error raised
    while recording is logged and swallowed. Bookkeeping problems therefore
    neither mask the original exception nor prevent the task retry.
    """
    try:
        with get_sync_db() as db:
            intent_db = db.get(StoryImageIntent, intent_id)
            if (
                intent_db
                and (intent_db.status or "").strip() != "completed"
                and (intent_db.claim_token or "").strip() == claim_token
            ):
                intent_db.status = "failed"
                intent_db.claim_token = None
                intent_db.claimed_at = None
                intent_db.error = str(exc)
                intent_db.updated_at = datetime.now(timezone.utc)
                db.commit()
    except Exception as record_exc:
        logger.warning(
            "generate_story_image: failed to record failure intent=%s: %s",
            intent_id,
            record_exc,
        )


def _persist_story_image(
    story_id: str,
    intent_id,
    claim_token: str,
    *,
    cos_key: str,
    url: str,
    prompt_final: str,
    provider,
    style_profile,
) -> dict:
    """Persist an uploaded story image and backfill it into the story body.

    Everything happens in one DB session:

    1. Verify this run still owns the intent's claim.
    2. Insert the ``Asset`` and mark the intent ``completed``.
    3. If the intent still targets the story's current version, append a
       ``StoryVersion`` with the image backfilled into its markdown and make
       it the story's current version.

    Returns:
        The task result dict. ``status`` is one of ``superseded_or_cancelled``,
        ``success_stale``, ``success_no_snapshot`` or ``success``; every
        status except ``superseded_or_cancelled`` also carries ``asset_id``.
    """
    asset_id = str(uuid.uuid4())
    with get_sync_db() as db:
        intent_db = db.get(StoryImageIntent, intent_id)
        if (
            not intent_db
            or (intent_db.status or "").strip() != "processing"
            or (intent_db.claim_token or "").strip() != claim_token
        ):
            # The claim was lost (stale-claim takeover or cancellation). The
            # object already uploaded to COS is left unreferenced.
            logger.debug(
                "generate_story_image: skip persist intent=%s status=%s claim=%s",
                intent_id,
                getattr(intent_db, "status", None),
                getattr(intent_db, "claim_token", None),
            )
            return {"status": "superseded_or_cancelled"}

        db.add(
            Asset(
                id=asset_id,
                asset_type="story_image",
                storage_key=cos_key,
                url=url,
                provider=provider,
                style_profile=style_profile,
                prompt_final=prompt_final,
                status="completed",
            )
        )
        db.flush()

        story_db = db.get(Story, story_id)
        # A missing story yields current_vid=None, which routes into the
        # stale branch below instead of raising AttributeError (and retrying).
        current_vid = story_db.current_version_id if story_db is not None else None
        target_vid = intent_db.story_version_id or current_vid

        intent_db.asset_id = asset_id
        intent_db.status = "completed"
        intent_db.claim_token = None
        intent_db.claimed_at = None
        intent_db.error = None
        intent_db.updated_at = datetime.now(timezone.utc)
        db.flush()

        # Only backfill the body while the intent still targets the current
        # version, so a slow task/retry never inserts the image into a newer one.
        if not target_vid or target_vid != current_vid:
            db.commit()
            logger.debug(
                "generate_story_image: stale intent skip backfill story=%s "
                "intent_ver=%s current=%s url=%s asset=%s",
                story_id,
                target_vid,
                current_vid,
                url,
                asset_id,
            )
            return {"status": "success_stale", "asset_id": asset_id}

        ver = db.get(StoryVersion, target_vid)
        if not ver:
            db.commit()
            return {"status": "success_no_snapshot", "asset_id": asset_id}

        base_md = strip_asset_image_refs_from_markdown(ver.markdown_snapshot or "")
        # Prefer the prompt brief as alt text, then an optional caption.
        alt_text = (intent_db.prompt_brief or "").strip() or (
            getattr(intent_db, "caption", None) or ""
        ).strip()
        backfilled_md = backfill_image_into_markdown(
            base_md,
            asset_id=asset_id,
            image_alt=alt_text or "主插图",
        )
        # NOTE(review): max+1 is only serialized against other runs of this
        # task (Redis lock); concurrent version writers elsewhere presumably
        # rely on a DB uniqueness constraint — confirm.
        max_no = db.execute(
            select(func.max(StoryVersion.version_no)).where(
                StoryVersion.story_id == story_id
            )
        ).scalar()
        new_ver = StoryVersion(
            id=str(uuid.uuid4()),
            story_id=story_id,
            version_no=(max_no or 0) + 1,
            markdown_snapshot=backfilled_md,
            change_summary="主插图回填",
            actor_type="system",
            source_type="image_backfill",
            parent_version_id=current_vid,
        )
        db.add(new_ver)
        db.flush()
        story_db.current_version_id = new_ver.id
        story_db.canonical_markdown = backfilled_md
        db.commit()
    return {"status": "success", "asset_id": asset_id}


@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def generate_story_image(self, story_id: str):
    """Generate the primary illustration for a story (Celery task).

    Flow:

    1. Take the per-story Redis lock. If it is held, return
       ``{"status": "locked"}``.
    2. Atomically claim the story's ``primary`` intent from
       story_image_intents. If none is claimable, return
       ``{"status": "no_intent"}``.
    3. If the body (image refs stripped) is shorter than
       ``story_image_min_body_chars``, mark the intent ``skipped`` and return
       ``{"status": "skipped_body_too_short"}``.
    4. Generate the image, normalize it to PNG and upload it to COS.
    5. Persist the asset, complete the intent and backfill the story body
       when still current. On full success, enqueue chapter recompose.

    Returns:
        A result dict with ``status`` (plus ``asset_id`` for success variants).

    On any exception, the intent is marked ``failed`` (if this run still owns
    the claim) and the error is re-raised through ``self.retry``. The Redis
    lock is always released.
    """
    lock_key = f"lock:story-image:{story_id}"
    lock_handle = acquire_redis_lock(lock_key, ttl_seconds=STORY_IMAGE_LOCK_TTL_SECONDS)
    if lock_handle is None:
        logger.debug("generate_story_image: story=%s, reason=locked", story_id)
        return {"status": "locked"}

    claim_token = uuid.uuid4().hex
    # Set only after a successful claim; tells the failure handler whether
    # there is an intent to mark as failed.
    intent_id = None
    try:
        with get_sync_db() as db:
            row = _claim_story_image_intent_sync(db, story_id, claim_token)
        if not row:
            logger.debug(
                "generate_story_image: story=%s, reason=no_claimable_intent",
                story_id,
            )
            return {"status": "no_intent"}
        # NOTE(review): intent/story are read after the claiming session has
        # closed; this relies on get_sync_db() sessions not expiring loaded
        # attributes on commit — confirm.
        intent, story = row
        intent_id = intent.id

        settings = MemoirImageSettings.from_env()
        min_body = settings.story_image_min_body_chars
        if min_body > 0:
            plain = strip_asset_image_refs_from_markdown(
                story.canonical_markdown or ""
            ).strip()
            if len(plain) < min_body:
                _mark_story_image_intent_skipped(
                    intent_id, f"body_below_min_chars:{len(plain)}"
                )
                logger.info(
                    "generate_story_image: skipped body too short story=%s len=%s min=%s",
                    story_id,
                    len(plain),
                    min_body,
                )
                return {"status": "skipped_body_too_short"}

        style_profile = intent.style_profile or settings.default_style
        image_generator = get_image_generator()
        storage = TencentCosStorageService.from_env()
        prompt_final = _build_story_image_prompt(
            intent.prompt_brief or "",
            story_title=story.title or "",
            story_stage=story.stage,
            style_profile=style_profile,
        )
        result = image_generator.generate(
            prompt_final,
            settings.default_size,
            style_profile,
        )
        if result.status != TaskStatus.COMPLETED or not result.image_url:
            raise RuntimeError(result.error or "Image generation failed")
        image_bytes = _normalize_image_bytes(
            image_generator.download_image(result.image_url)
        )
        cos_key = _build_story_image_cos_key(
            story.user_id, story_id, intent_id, prompt_final
        )
        url = storage.upload_bytes(image_bytes, cos_key, "image/png")

        outcome = _persist_story_image(
            story_id,
            intent_id,
            claim_token,
            cos_key=cos_key,
            url=url,
            prompt_final=prompt_final,
            provider=settings.provider,
            style_profile=style_profile,
        )
        if outcome["status"] == "success":
            # The story body changed, so chapters built from it must be recomposed.
            _enqueue_chapter_recompose_for_story(story_id)
            logger.info(
                "generate_story_image: story=%s, asset=%s",
                story_id,
                outcome["asset_id"],
            )
            logger.debug(
                "generate_story_image: story=%s asset=%s url=%s cos_key=%s prompt_final=%s",
                story_id,
                outcome["asset_id"],
                url,
                cos_key,
                prompt_final,
            )
        return outcome
    except Exception as exc:
        # Log first so the error is visible even if recording it fails.
        logger.warning(
            "generate_story_image failed: story=%s, error=%s",
            story_id,
            exc,
            exc_info=True,
        )
        if intent_id is not None:
            _record_story_image_failure(intent_id, claim_token, exc)
        raise self.retry(exc=exc) from exc
    finally:
        release_redis_lock(lock_handle)