358 lines
13 KiB
Python
358 lines
13 KiB
Python
"""
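Celery task that generates the primary illustration for a Story.

Atomically claims an intent from story_image_intents, generates the image,
writes it to assets, and updates the intent.
It does not read placeholders from the story body.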
"""
|
||
|
||
import hashlib
|
||
import uuid
|
||
from datetime import datetime, timedelta, timezone
|
||
|
||
from celery import shared_task
|
||
from PIL import Image
|
||
from sqlalchemy import and_, func, or_, select, update
|
||
|
||
from app.core.db import get_sync_db
|
||
from app.core.dependencies import get_image_generator
|
||
from app.core.logging import get_logger
|
||
from app.core.redis_lock import acquire_redis_lock, release_redis_lock
|
||
from app.features.asset.models import Asset
|
||
from app.features.memoir.asset_resolver import strip_asset_image_refs_from_markdown
|
||
from app.features.memoir.memoir_images.settings import MemoirImageSettings
|
||
from app.features.memoir.memoir_images.storage import TencentCosStorageService
|
||
from app.features.story.backfill import backfill_image_into_markdown
|
||
from app.features.story.models import Story, StoryImageIntent, StoryVersion
|
||
from app.ports.image_gen import TaskStatus
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
# TTL, in seconds, passed to acquire_redis_lock for the per-story image lock.
STORY_IMAGE_LOCK_TTL_SECONDS = 30 * 60

# Age, in seconds, after which a "processing" intent claim may be re-claimed.
STORY_IMAGE_CLAIM_TTL_SECONDS = 30 * 60
|
||
|
||
|
||
def _enqueue_chapter_recompose_for_story(story_id: str) -> None:
    """Mark the story's chapters dirty, then queue their recompose task.

    Called after the primary-image backfill has changed the story body.
    Both steps are best-effort: any failure is logged as a warning and is
    never raised. If marking the chapters dirty fails, the recompose task
    is not queued.

    Args:
        story_id: ID of the story whose body was just changed.
    """
    try:
        with get_sync_db() as db:
            from app.features.memoir import repo as memoir_repo

            memoir_repo.mark_chapters_dirty_for_story_sync(db, story_id)
            db.commit()
    except Exception as mark_err:
        logger.warning(
            "mark_chapters_dirty_for_story_sync failed story=%s: %s",
            story_id,
            mark_err,
        )
    else:
        # Only reached when the dirty flags were committed successfully.
        try:
            from app.tasks.chapter_compose_tasks import recompose_chapters_for_story

            recompose_chapters_for_story.delay(story_id)
        except Exception as enqueue_err:
            logger.warning(
                "recompose_chapters_for_story.delay failed story=%s: %s",
                story_id,
                enqueue_err,
            )
|
||
|
||
|
||
def _build_story_image_cos_key(
    user_id: str, story_id: str, intent_id: str, prompt: str
) -> str:
    """Build the object-storage key for a generated story image.

    The key has the form
    ``stories/<user_id>/<story_id>/<intent_id>-<hash>.png`` where ``<hash>``
    is the first 10 hex characters of the SHA-1 of the UTF-8 encoded prompt.

    Args:
        user_id: Owner of the story.
        story_id: Story the image belongs to.
        intent_id: Image intent being fulfilled.
        prompt: Final prompt text; only used to derive the short hash.

    Returns:
        The storage key as a string.
    """
    # SHA-1 is only a short, non-security fingerprint here. Declaring that
    # keeps the call usable on FIPS-restricted OpenSSL builds; the digest
    # itself is unchanged.
    digest = hashlib.sha1(prompt.encode("utf-8"), usedforsecurity=False).hexdigest()
    short_hash = digest[:10]
    return f"stories/{user_id}/{story_id}/{intent_id}-{short_hash}.png"
|
||
|
||
|
||
def _normalize_image_bytes(image_bytes: bytes) -> bytes:
    """Re-encode arbitrary image bytes as PNG.

    Images in mode ``RGBA`` or ``LA`` are saved as they are, palette
    (``P``) images are converted to ``RGBA``, and every other mode is
    converted to ``RGB`` before saving.

    Args:
        image_bytes: Encoded image data in any format PIL can open.

    Returns:
        The PNG-encoded image bytes.
    """
    from io import BytesIO

    png_buffer = BytesIO()
    with Image.open(BytesIO(image_bytes)) as source:
        mode = source.mode
        if mode in {"RGBA", "LA"}:
            png_ready = source
        else:
            png_ready = source.convert("RGBA" if mode == "P" else "RGB")
        png_ready.save(png_buffer, format="PNG")
    return png_buffer.getvalue()
|
||
|
||
|
||
def _build_story_image_prompt(
    prompt_brief: str,
    story_title: str = "",
    story_stage: str | None = None,
    style_profile: str | None = None,
) -> str:
    """Build the image-generation prompt from an intent's ``prompt_brief``.

    The result is ``IMAGE_PLACEHOLDER_TEMPLATE`` followed by "。" and either
    the stripped brief or, when the brief is empty or whitespace-only, a
    fallback made of the non-empty title and stage joined by ",". If both
    are empty the fallback is the literal "人生故事".

    Args:
        prompt_brief: Short description taken from the intent.
        story_title: Story title, used only in the fallback.
        story_stage: Story stage, used only in the fallback.
        style_profile: Accepted for interface compatibility; not used when
            building the prompt text.

    Returns:
        The final prompt string.
    """
    from app.agents.memoir.prompts import IMAGE_PLACEHOLDER_TEMPLATE

    brief = prompt_brief.strip() if prompt_brief else ""
    if brief:
        return f"{IMAGE_PLACEHOLDER_TEMPLATE}。{brief}"

    parts = [part for part in (story_title, story_stage or "") if part]
    subject = ",".join(parts) or "人生故事"
    return f"{IMAGE_PLACEHOLDER_TEMPLATE}。{subject}"
|
||
|
||
|
||
def _story_image_claimable_clause(now: datetime):
    """Return the SQL predicate selecting intents that may be claimed at ``now``.

    An intent is claimable when its status is ``pending`` or ``failed``, or
    when it is ``processing`` but its ``claimed_at`` is NULL or older than
    ``STORY_IMAGE_CLAIM_TTL_SECONDS`` before ``now``.

    Args:
        now: Reference timestamp for the staleness cutoff.

    Returns:
        A SQLAlchemy boolean expression over ``StoryImageIntent``.
    """
    stale_before = now - timedelta(seconds=STORY_IMAGE_CLAIM_TTL_SECONDS)

    waiting = StoryImageIntent.status.in_(["pending", "failed"])
    in_progress = StoryImageIntent.status == "processing"
    claim_lapsed = or_(
        StoryImageIntent.claimed_at.is_(None),
        StoryImageIntent.claimed_at < stale_before,
    )
    return or_(waiting, and_(in_progress, claim_lapsed))
|
||
|
||
|
||
def _claim_story_image_intent_sync(db, story_id: str, claim_token: str):
    """Claim the story's most recently updated claimable primary intent.

    Picks one candidate, then issues a conditional UPDATE that re-checks the
    claimable predicate so that only one caller can win the row. On success
    the intent is set to ``processing``, stamped with ``claim_token`` and the
    current UTC time, its error is cleared and ``attempt_count`` is
    incremented.

    Args:
        db: Synchronous SQLAlchemy session; this function commits or rolls
            it back.
        story_id: Story whose primary intent should be claimed.
        claim_token: Token written to the claimed intent.

    Returns:
        The ``(StoryImageIntent, Story)`` row for the claimed intent, or
        ``None`` when there is no candidate or the conditional UPDATE did
        not affect exactly one row (in which case the session is rolled
        back). If the follow-up join finds no row, the claim is still
        committed and ``None`` is returned.
    """
    claimed_at = datetime.now(timezone.utc)
    claimable = _story_image_claimable_clause(claimed_at)

    pick_stmt = (
        select(StoryImageIntent.id)
        .where(
            StoryImageIntent.story_id == story_id,
            StoryImageIntent.intent_role == "primary",
            claimable,
        )
        .order_by(
            StoryImageIntent.updated_at.desc(),
            StoryImageIntent.created_at.desc(),
        )
        .limit(1)
    )
    intent_id = db.execute(pick_stmt).scalar_one_or_none()
    if not intent_id:
        return None

    # Re-apply the claimable predicate so a concurrent claimer loses cleanly.
    claim_stmt = (
        update(StoryImageIntent)
        .where(StoryImageIntent.id == intent_id, claimable)
        .values(
            status="processing",
            claim_token=claim_token,
            claimed_at=claimed_at,
            updated_at=claimed_at,
            error=None,
            attempt_count=func.coalesce(StoryImageIntent.attempt_count, 0) + 1,
        )
    )
    outcome = db.execute(claim_stmt)
    if (outcome.rowcount or 0) != 1:
        db.rollback()
        return None

    load_stmt = (
        select(StoryImageIntent, Story)
        .join(Story, StoryImageIntent.story_id == Story.id)
        .where(StoryImageIntent.id == intent_id)
    )
    claimed_row = db.execute(load_stmt).unique().first()
    db.commit()
    return claimed_row
|
||
|
||
|
||
def _mark_story_image_intent_failed(intent, claim_token: str, exc: Exception) -> None:
    """Best-effort: mark the intent claimed by this run as ``failed``.

    The intent is only touched when it is not already ``completed`` and
    still carries this run's ``claim_token``. Any error raised while doing
    so is logged and swallowed so that the caller can still schedule the
    task retry for the original exception.

    Args:
        intent: The intent object returned by the claim step.
        claim_token: Token this run wrote when claiming the intent.
        exc: The exception that made the run fail; stored as the error text.
    """
    try:
        with get_sync_db() as db:
            intent_db = db.get(StoryImageIntent, intent.id)
            if (
                intent_db
                and (intent_db.status or "").strip() != "completed"
                and (intent_db.claim_token or "").strip() == claim_token
            ):
                intent_db.status = "failed"
                intent_db.claim_token = None
                intent_db.claimed_at = None
                intent_db.error = str(exc)
                intent_db.updated_at = datetime.now(timezone.utc)
                db.commit()
    except Exception as bookkeeping_exc:
        logger.warning(
            "generate_story_image: could not record failure on intent: %s",
            bookkeeping_exc,
        )


@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def generate_story_image(self, story_id: str):
    """Generate the primary illustration for a story.

    Takes a per-story Redis lock, atomically claims the story's primary
    intent from ``story_image_intents``, generates and uploads the image,
    writes an ``Asset`` row and completes the intent. When the intent still
    targets the story's current version, a new ``StoryVersion`` with the
    image backfilled into the markdown is created and chapter recomposition
    is queued.

    Args:
        story_id: ID of the story to illustrate.

    Returns:
        A dict whose ``status`` is one of ``locked``, ``no_intent``,
        ``skipped_body_too_short``, ``superseded_or_cancelled``,
        ``success_stale``, ``success_no_snapshot`` or ``success``; the
        success variants also carry ``asset_id``.

    Raises:
        Whatever ``self.retry`` raises: on any exception the claimed intent
        is marked ``failed`` (best-effort) and the task is retried.
    """
    lock_key = f"lock:story-image:{story_id}"
    lock_handle = acquire_redis_lock(lock_key, ttl_seconds=STORY_IMAGE_LOCK_TTL_SECONDS)
    if lock_handle is None:
        logger.debug("generate_story_image: story=%s, reason=locked", story_id)
        return {"status": "locked"}

    claim_token = uuid.uuid4().hex
    intent = None
    story = None
    try:
        with get_sync_db() as db:
            row = _claim_story_image_intent_sync(db, story_id, claim_token)
            if not row:
                logger.debug(
                    "generate_story_image: story=%s, reason=no_claimable_intent",
                    story_id,
                )
                return {"status": "no_intent"}

            intent, story = row

        img_cfg = MemoirImageSettings.from_env()
        min_body = img_cfg.story_image_min_body_chars
        if min_body > 0:
            plain = strip_asset_image_refs_from_markdown(
                story.canonical_markdown or ""
            ).strip()
            if len(plain) < min_body:
                with get_sync_db() as db:
                    intent_db = db.get(StoryImageIntent, intent.id)
                    # Same ownership check as the persist and failure paths:
                    # only touch the intent while this run still holds it.
                    if (
                        intent_db
                        and (intent_db.status or "").strip() == "processing"
                        and (intent_db.claim_token or "").strip() == claim_token
                    ):
                        intent_db.status = "skipped"
                        intent_db.error = f"body_below_min_chars:{len(plain)}"
                        intent_db.claim_token = None
                        intent_db.claimed_at = None
                        intent_db.updated_at = datetime.now(timezone.utc)
                        db.commit()
                logger.info(
                    "generate_story_image: skipped body too short story=%s len=%s min=%s",
                    story_id,
                    len(plain),
                    min_body,
                )
                return {"status": "skipped_body_too_short"}

        image_generator = get_image_generator()
        storage = TencentCosStorageService.from_env()

        # The intent's own style wins; otherwise use the configured default.
        style_profile = intent.style_profile or img_cfg.default_style
        prompt_final = _build_story_image_prompt(
            intent.prompt_brief or "",
            story_title=story.title or "",
            story_stage=story.stage,
            style_profile=style_profile,
        )
        result = image_generator.generate(
            prompt_final,
            img_cfg.default_size,
            style_profile,
        )
        if result.status != TaskStatus.COMPLETED or not result.image_url:
            raise RuntimeError(result.error or "Image generation failed")

        image_bytes = _normalize_image_bytes(
            image_generator.download_image(result.image_url)
        )
        cos_key = _build_story_image_cos_key(
            story.user_id, story_id, intent.id, prompt_final
        )
        url = storage.upload_bytes(image_bytes, cos_key, "image/png")

        asset_id = str(uuid.uuid4())
        with get_sync_db() as db:
            intent_db = db.get(StoryImageIntent, intent.id)
            # Persist only if this run still owns the claim.
            if (
                not intent_db
                or (intent_db.status or "").strip() != "processing"
                or (intent_db.claim_token or "").strip() != claim_token
            ):
                logger.debug(
                    "generate_story_image: skip persist intent=%s status=%s claim=%s",
                    intent.id,
                    getattr(intent_db, "status", None),
                    getattr(intent_db, "claim_token", None),
                )
                return {"status": "superseded_or_cancelled"}

            asset = Asset(
                id=asset_id,
                asset_type="story_image",
                storage_key=cos_key,
                url=url,
                provider=img_cfg.provider,
                style_profile=style_profile,
                prompt_final=prompt_final,
                status="completed",
            )
            db.add(asset)
            db.flush()

            story_db = db.get(Story, story_id)
            target_vid = intent_db.story_version_id or story_db.current_version_id
            current_vid = story_db.current_version_id

            intent_db.asset_id = asset_id
            intent_db.status = "completed"
            intent_db.claim_token = None
            intent_db.claimed_at = None
            intent_db.error = None
            intent_db.updated_at = datetime.now(timezone.utc)
            db.flush()

            # Backfill the body only while the intent still targets the
            # current version, so a slow task or a retry does not insert the
            # image into a newer version.
            if not target_vid or target_vid != current_vid:
                db.commit()
                logger.debug(
                    "generate_story_image: stale intent skip backfill story=%s "
                    "intent_ver=%s current=%s url=%s asset=%s",
                    story_id,
                    target_vid,
                    current_vid,
                    url,
                    asset_id,
                )
                return {"status": "success_stale", "asset_id": asset_id}

            ver = db.get(StoryVersion, target_vid)
            if not ver:
                db.commit()
                return {"status": "success_no_snapshot", "asset_id": asset_id}

            base_md = strip_asset_image_refs_from_markdown(ver.markdown_snapshot or "")
            alt_text = (getattr(intent_db, "prompt_brief", None) or "").strip()
            if not alt_text:
                alt_text = (getattr(intent_db, "caption", None) or "").strip()
            backfilled_md = backfill_image_into_markdown(
                base_md,
                asset_id=asset_id,
                image_alt=alt_text or "主插图",
            )
            max_stmt = select(func.max(StoryVersion.version_no)).where(
                StoryVersion.story_id == story_id
            )
            max_no = db.execute(max_stmt).scalar()
            version_no = (max_no or 0) + 1
            new_ver = StoryVersion(
                id=str(uuid.uuid4()),
                story_id=story_id,
                version_no=version_no,
                markdown_snapshot=backfilled_md,
                change_summary="主插图回填",
                actor_type="system",
                source_type="image_backfill",
                parent_version_id=story_db.current_version_id,
            )
            db.add(new_ver)
            db.flush()
            story_db.current_version_id = new_ver.id
            story_db.canonical_markdown = backfilled_md

            db.commit()

        _enqueue_chapter_recompose_for_story(story_id)

        logger.info(
            "generate_story_image: story=%s, asset=%s",
            story_id,
            asset_id,
        )
        logger.debug(
            "generate_story_image: story=%s asset=%s url=%s cos_key=%s prompt_final=%s",
            story_id,
            asset_id,
            url,
            cos_key,
            prompt_final,
        )
        return {"status": "success", "asset_id": asset_id}
    except Exception as exc:
        if intent is not None:
            # Guarded bookkeeping: a failing DB write here must not replace
            # the original exception or prevent the retry below.
            _mark_story_image_intent_failed(intent, claim_token, exc)
        logger.warning("generate_story_image failed: story=%s, error=%s", story_id, exc)
        raise self.retry(exc=exc) from exc
    finally:
        release_redis_lock(lock_handle)
|