diff --git a/api/routers/chapters.py b/api/routers/chapters.py index f249e26..129b298 100644 --- a/api/routers/chapters.py +++ b/api/routers/chapters.py @@ -14,6 +14,12 @@ from database.models import Chapter as ChapterModel from database.models import User as UserModel from middleware.auth import get_current_user from agents.prompts.memory_prompts import CHAPTER_CATEGORIES, CHAPTER_ORDER, STAGE_TO_ORDER +from services.memoir_images.schema import ( + completed_image_assets, + IMAGE_STATUS_COMPLETED, + normalize_image_assets, +) +from services.memoir_images.settings import MemoirImageSettings from services.memoir_images.storage import ( TencentCosStorageService, normalize_cos_url, @@ -29,9 +35,13 @@ def _normalize_image_assets(images: list[dict] | None) -> list[dict]: region = os.getenv("TENCENT_COS_REGION", "") base_url = os.getenv("TENCENT_COS_BASE_URL", "") storage = TencentCosStorageService.from_env() + settings = MemoirImageSettings.from_env() + source_assets = normalize_image_assets(images) + if not settings.enabled: + source_assets = completed_image_assets(source_assets) normalized_assets: list[dict] = [] - for item in (images or []): + for item in source_assets: asset = dict(item) normalized_url = normalize_cos_url( asset.get("url"), @@ -40,12 +50,13 @@ def _normalize_image_assets(images: list[dict] | None) -> list[dict]: base_url=base_url, ) storage_key = resolve_image_storage_key(asset) - if asset.get("status") == "completed" and storage_key: + if asset.get("status") == IMAGE_STATUS_COMPLETED and storage_key: try: asset["url"] = storage.get_download_url(storage_key) except Exception as exc: logger.warning("章节图片签名失败: key=%s, error=%s", storage_key, exc) asset["url"] = normalized_url + asset["error"] = asset.get("error") or "image delivery unavailable" else: asset["url"] = normalized_url asset.pop("storage_key", None) diff --git a/api/services/memoir_images/prompting.py b/api/services/memoir_images/prompting.py index c0c0472..2fc494b 100644 --- a/api/services/memoir_images/prompting.py +++ b/api/services/memoir_images/prompting.py @@ -1,8 +1,13 @@ import json +import logging +import re from typing import Any, Optional from .settings import MemoirImageSettings +logger = logging.getLogger(__name__) +_CJK_RE = re.compile(r"[\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff]") + class MemoirImagePromptService: CATEGORY_STYLE_MAP = { @@ -14,6 +19,16 @@ class MemoirImagePromptService: "beliefs": "editorial illustration", "summary": "editorial illustration", } + CATEGORY_FALLBACK_SUBJECT_MAP = { + "childhood": "childhood memory", + "education": "school memory", + "career_early": "early career memory", + "career_achievement": "career achievement memory", + "career_challenge": "career challenge memory", + "family": "family memory", + "beliefs": "reflective life memory", + "summary": "memoir summary scene", + } def __init__(self, llm: Optional[Any], settings: MemoirImageSettings): self.llm = llm @@ -47,17 +62,69 @@ class MemoirImagePromptService: ) parsed = json.loads(response.content) return { - "prompt": parsed["prompt"], + "prompt": _ensure_style_in_prompt(parsed["prompt"], parsed.get("style", style)), "style": parsed.get("style", style), "size": parsed.get("size", self.settings.default_size), "prompt_context": prompt_context, } - except Exception: - pass + except Exception as exc: + logger.warning( + "图片 prompt 生成回退到默认模板: chapter_category=%s, title=%s, error=%s", + chapter_category, + chapter_title, + exc, + ) return { - "prompt": f"{description}\n\nScene context: {context_excerpt}", + "prompt": _ensure_style_in_prompt( + self._build_fallback_prompt( + chapter_category=chapter_category, + description=description, + context_excerpt=context_excerpt, + style=style, + ), + style, + ), "style": style, "size": self.settings.default_size, "prompt_context": prompt_context, } + + def _build_fallback_prompt( + self, + chapter_category: str, + description: str, + context_excerpt: str, + style: str, + ) -> str: + subject = self.CATEGORY_FALLBACK_SUBJECT_MAP.get(chapter_category, "memoir scene") + if _contains_cjk(description) or _contains_cjk(context_excerpt): + return ( + f"A {style} illustration of a {subject}, emotionally resonant, cinematic composition, " + "authentic everyday details, natural lighting, expressive environment, no text overlay." + ) + + details = ". ".join(part.strip() for part in (description, context_excerpt) if part.strip()) + if not details: + details = "A personal life story scene with authentic emotional detail" + return ( + f"A {style} illustration of a {subject}. " + f"Scene details: {details}. " + "Cinematic composition, authentic emotions, natural lighting, no text overlay." + ) + + +def _contains_cjk(value: str) -> bool: + return bool(_CJK_RE.search(value or "")) + + +def _ensure_style_in_prompt(prompt: str, style: str) -> str: + cleaned_prompt = (prompt or "").strip() + cleaned_style = (style or "").strip() + if not cleaned_style: + return cleaned_prompt + if cleaned_style.lower() in cleaned_prompt.lower(): + return cleaned_prompt + if not cleaned_prompt: + return cleaned_style + return f"{cleaned_style}, {cleaned_prompt}" diff --git a/api/services/memoir_images/provider.py b/api/services/memoir_images/provider.py index 51e0be5..a0b52cc 100644 --- a/api/services/memoir_images/provider.py +++ b/api/services/memoir_images/provider.py @@ -2,14 +2,23 @@ import base64 import hmac import logging import os +import re import time import uuid from hashlib import sha1 +from urllib.parse import urlparse import httpx +from .settings import DEFAULT_LIBLIB_TEMPLATE_UUID + logger = logging.getLogger(__name__) +_SENSITIVE_QUERY_PARAMS = ("AccessKey", "Signature", "Timestamp", "SignatureNonce") +_SENSITIVE_QUERY_RE = re.compile( + r"([?&])(" + "|".join(_SENSITIVE_QUERY_PARAMS) + r")=([^&\s]+)" +) + _SIZE_TO_ASPECT_RATIO = { "1024x1024": "square", "768x1024": "portrait", @@ -29,33 +38,40 @@ class LiblibImageProvider: secret_key: str | None = None, base_url: str | None = None, template_uuid: str | None = None, + allowed_download_hosts: tuple[str, ...] | None = None, ): + _install_http_log_redaction() + self._owns_http_client = http_client is None self.http_client = http_client or httpx.Client(timeout=120) self.access_key = access_key or os.getenv("LIBLIB_ACCESS_KEY", "") self.secret_key = secret_key or os.getenv("LIBLIB_SECRET_KEY", "") self.base_url = (base_url or os.getenv("LIBLIB_BASE_URL", "https://openapi.liblibai.cloud")).rstrip("/") - self.template_uuid = template_uuid or os.getenv( - "LIBLIB_TEMPLATE_UUID", "5d7e67009b344550bc1aa6ccbfa1d7f4" + self.template_uuid = template_uuid or os.getenv("LIBLIB_TEMPLATE_UUID") or DEFAULT_LIBLIB_TEMPLATE_UUID + self.allowed_download_hosts = _build_allowed_download_hosts( + self.base_url, + allowed_download_hosts=allowed_download_hosts, ) # ------------------------------------------------------------------ # Signature helpers # ------------------------------------------------------------------ - def _sign(self, uri: str) -> str: - """Build a full URL with Liblib HMAC-SHA1 query-string auth.""" + def _build_url(self, uri: str) -> str: + return f"{self.base_url}{uri}" + + def _sign(self, uri: str) -> dict[str, str]: + """Build Liblib HMAC-SHA1 query-string auth params.""" timestamp = str(int(time.time() * 1000)) nonce = str(uuid.uuid4()) content = "&".join((uri, timestamp, nonce)) digest = hmac.new(self.secret_key.encode(), content.encode(), sha1).digest() signature = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() - return ( - f"{self.base_url}{uri}" - f"?AccessKey={self.access_key}" - f"&Signature={signature}" - f"&Timestamp={timestamp}" - f"&SignatureNonce={nonce}" - ) + return { + "AccessKey": self.access_key, + "Signature": signature, + "Timestamp": timestamp, + "SignatureNonce": nonce, + } # ------------------------------------------------------------------ # Public API @@ -63,14 +79,16 @@ class LiblibImageProvider: def submit_generation(self, prompt: str, size: str, style: str) -> dict: uri = "/api/generate/webui/text2img/ultra" - url = self._sign(uri) + url = self._build_url(uri) + params = self._sign(uri) + styled_prompt = _apply_style_to_prompt(prompt, style) aspect_ratio = _SIZE_TO_ASPECT_RATIO.get(size, "square") body = { "templateUuid": self.template_uuid, "generateParams": { - "prompt": prompt, + "prompt": styled_prompt, "aspectRatio": aspect_ratio, "imgCount": 1, "steps": 30, @@ -78,6 +96,7 @@ class LiblibImageProvider: } response = self.http_client.post( url, + params=params, headers={"Content-Type": "application/json"}, json=body, ) @@ -96,9 +115,11 @@ class LiblibImageProvider: uri = "/api/generate/webui/status" for attempt in range(max_attempts): - url = self._sign(uri) + url = self._build_url(uri) + params = self._sign(uri) response = self.http_client.post( url, + params=params, headers={"Content-Type": "application/json"}, json={"generateUuid": job["job_id"]}, ) @@ -136,6 +157,104 @@ class LiblibImageProvider: raise TimeoutError(f"Liblib image generation timed out after {max_attempts} attempts for {job['job_id']}") def download_image(self, job: dict) -> bytes: - response = self.http_client.get(job["image_url"]) + image_url = job["image_url"] + _validate_download_url(image_url, self.allowed_download_hosts) + response = self.http_client.get(image_url) response.raise_for_status() return response.content + + def close(self) -> None: + if self._owns_http_client: + self.http_client.close() + + def __enter__(self) -> "LiblibImageProvider": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self.close() + + +class _LiblibAuthRedactionFilter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + record.msg = _redact_sensitive_query_values(record.msg) + if record.args: + if isinstance(record.args, dict): + record.args = { + key: _redact_sensitive_query_values(value) + for key, value in record.args.items() + } + else: + record.args = tuple(_redact_sensitive_query_values(value) for value in record.args) + return True + + +def _redact_sensitive_query_values(value): + if isinstance(value, str): + return _SENSITIVE_QUERY_RE.sub(r"\1\2=[REDACTED]", value) + return value + + +def _install_http_log_redaction() -> None: + for logger_name in ( + "httpx", + "httpcore", + "httpcore.connection", + "httpcore.http11", + "httpcore.proxy", + ): + target_logger = logging.getLogger(logger_name) + if getattr(target_logger, "_liblib_auth_redaction_installed", False): + continue + target_logger.addFilter(_LiblibAuthRedactionFilter()) + target_logger._liblib_auth_redaction_installed = True + + +def _build_allowed_download_hosts( + base_url: str, + allowed_download_hosts: tuple[str, ...] | None = None, +) -> tuple[str, ...]: + configured_hosts = allowed_download_hosts + if configured_hosts is None: + configured_hosts = tuple( + host.strip().lower() + for host in os.getenv("MEMOIR_IMAGE_DOWNLOAD_HOSTS", "").split(",") + if host.strip() + ) + + base_hostname = (urlparse(base_url).hostname or "").lower() + default_hosts: set[str] = set() + if base_hostname: + default_hosts.add(base_hostname) + if base_hostname.endswith(".liblibai.cloud") or base_hostname == "liblibai.cloud": + default_hosts.add("liblibai.cloud") + + return tuple(sorted(default_hosts.union(configured_hosts))) + + +def _validate_download_url(image_url: str, allowed_hosts: tuple[str, ...]) -> None: + parsed = urlparse(image_url) + hostname = (parsed.hostname or "").lower() + if parsed.scheme != "https" or not hostname: + raise ValueError(f"Unsupported image download URL: {image_url}") + + if not any(_hostname_matches(hostname, allowed_host) for allowed_host in allowed_hosts): + raise ValueError(f"Image download host is not allowed: {hostname}") + + +def _hostname_matches(hostname: str, allowed_host: str) -> bool: + normalized_allowed = allowed_host.strip().lower() + if not normalized_allowed: + return False + return hostname == normalized_allowed or hostname.endswith(f".{normalized_allowed}") + + +def _apply_style_to_prompt(prompt: str, style: str) -> str: + cleaned_prompt = (prompt or "").strip() + cleaned_style = (style or "").strip() + if not cleaned_style: + return cleaned_prompt + if cleaned_style.lower() in cleaned_prompt.lower(): + return cleaned_prompt + if not cleaned_prompt: + return cleaned_style + return f"{cleaned_style}, {cleaned_prompt}" diff --git a/api/services/memoir_images/schema.py b/api/services/memoir_images/schema.py new file mode 100644 index 0000000..c2fe592 --- /dev/null +++ b/api/services/memoir_images/schema.py @@ -0,0 +1,108 @@ +import re +from typing import Any + +IMAGE_STATUS_PENDING = "pending" +IMAGE_STATUS_PROCESSING = "processing" +IMAGE_STATUS_COMPLETED = "completed" +IMAGE_STATUS_FAILED = "failed" + +VALID_IMAGE_STATUSES = { + IMAGE_STATUS_PENDING, + IMAGE_STATUS_PROCESSING, + IMAGE_STATUS_COMPLETED, + IMAGE_STATUS_FAILED, +} + +_PLACEHOLDER_DESCRIPTION_RE = re.compile(r"\{\{\{\{IMAGE:(.*?)\}\}\}\}|\{\{IMAGE:(.*?)\}\}") + + +def normalize_image_asset(asset: dict[str, Any] | None) -> dict[str, Any] | None: + if not isinstance(asset, dict): + return None + + placeholder = _as_non_empty_string(asset.get("placeholder")) + description = _as_non_empty_string(asset.get("description")) or _extract_description_from_placeholder( + placeholder + ) + if not placeholder or not description: + return None + + normalized = dict(asset) + normalized["index"] = _coerce_int(asset.get("index"), default=0) + normalized["placeholder"] = placeholder + normalized["description"] = description + + status = _as_non_empty_string(asset.get("status")) or IMAGE_STATUS_PENDING + if status not in VALID_IMAGE_STATUSES: + normalized["status"] = IMAGE_STATUS_FAILED + normalized["error"] = asset.get("error") or f"invalid image status: {status}" + return normalized + + normalized["status"] = status + normalized["prompt"] = _as_optional_string(asset.get("prompt")) + normalized["url"] = _as_optional_string(asset.get("url")) + normalized["storage_key"] = _as_optional_string(asset.get("storage_key")) + normalized["provider"] = _as_optional_string(asset.get("provider")) + normalized["style"] = _as_optional_string(asset.get("style")) + normalized["size"] = _as_optional_string(asset.get("size")) + normalized["error"] = _as_optional_string(asset.get("error")) + normalized["created_at"] = _as_optional_string(asset.get("created_at")) + normalized["updated_at"] = _as_optional_string(asset.get("updated_at")) + + if normalized["status"] == IMAGE_STATUS_COMPLETED and not ( + normalized["url"] or normalized["storage_key"] + ): + normalized["status"] = IMAGE_STATUS_FAILED + normalized["error"] = normalized["error"] or "missing image url" + + return normalized + + +def normalize_image_assets(images: list[dict[str, Any]] | None) -> list[dict[str, Any]]: + normalized_assets: list[dict[str, Any]] = [] + for item in images or []: + normalized = normalize_image_asset(item) + if normalized: + normalized_assets.append(normalized) + return normalized_assets + + +def completed_image_assets(images: list[dict[str, Any]] | None) -> list[dict[str, Any]]: + return [ + asset + for asset in normalize_image_assets(images) + if asset.get("status") == IMAGE_STATUS_COMPLETED + and (asset.get("storage_key") or asset.get("url")) + ] + + +def _as_non_empty_string(value: Any) -> str | None: + if isinstance(value, str): + stripped = value.strip() + return stripped or None + return None + + +def _as_optional_string(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + return str(value) + + +def _coerce_int(value: Any, default: int) -> int: + try: + return int(value) + except (TypeError, ValueError): + return default + + +def _extract_description_from_placeholder(placeholder: str | None) -> str | None: + if not placeholder: + return None + match = _PLACEHOLDER_DESCRIPTION_RE.fullmatch(placeholder.strip()) + if not match: + return None + description = (match.group(1) or match.group(2) or "").strip() + return description or None diff --git a/api/tasks/memoir_tasks.py b/api/tasks/memoir_tasks.py index cc382fc..181fb23 100644 --- a/api/tasks/memoir_tasks.py +++ b/api/tasks/memoir_tasks.py @@ -5,11 +5,13 @@ import json import logging import os import uuid +from io import BytesIO from typing import Dict, List from datetime import datetime, timezone import redis from celery import shared_task +from PIL import Image from sqlalchemy import select from sqlalchemy.orm import Session @@ -31,31 +33,64 @@ import hashlib from services.memoir_images.parser import build_initial_image_assets, parse_image_placeholders from services.memoir_images.prompting import MemoirImagePromptService from services.memoir_images.provider import LiblibImageProvider +from services.memoir_images.schema import ( + completed_image_assets, + IMAGE_STATUS_COMPLETED, + IMAGE_STATUS_FAILED, + IMAGE_STATUS_PENDING, + IMAGE_STATUS_PROCESSING, + normalize_image_assets, +) from services.memoir_images.settings import MemoirImageSettings from services.memoir_images.storage import TencentCosStorageService logger = logging.getLogger(__name__) +_REDIS_CLIENTS: dict[bool, redis.Redis] = {} + + +def _get_redis_client(*, decode_responses: bool = False) -> redis.Redis: + client = _REDIS_CLIENTS.get(decode_responses) + if client is None: + client = redis.from_url( + os.getenv("REDIS_URL", "redis://localhost:6379/0"), + decode_responses=decode_responses, + ) + _REDIS_CLIENTS[decode_responses] = client + return client def _acquire_chapter_lock(user_id: str, stage: str, timeout: int = 120) -> bool: """获取章节分布式锁,防止并发写入同一章节""" - r = redis.from_url(os.getenv("REDIS_URL", "redis://localhost:6379/0")) + r = _get_redis_client() lock_key = f"lock:chapter:{user_id}:{stage}" return r.set(lock_key, "1", nx=True, ex=timeout) def _release_chapter_lock(user_id: str, stage: str): """释放章节分布式锁""" - r = redis.from_url(os.getenv("REDIS_URL", "redis://localhost:6379/0")) + r = _get_redis_client() lock_key = f"lock:chapter:{user_id}:{stage}" r.delete(lock_key) +def _acquire_chapter_image_lock(chapter_id: str, timeout: int = 600) -> bool: + """获取章节补图分布式锁,避免同一章节重复补图。""" + r = _get_redis_client() + lock_key = f"lock:chapter-images:{chapter_id}" + return r.set(lock_key, "1", nx=True, ex=timeout) + + +def _release_chapter_image_lock(chapter_id: str): + """释放章节补图分布式锁。""" + r = _get_redis_client() + lock_key = f"lock:chapter-images:{chapter_id}" + r.delete(lock_key) + + def _update_task_status_sync(user_id: str, task_id: str, status: str, result: Dict = None): """同步更新任务状态(在 Celery 任务中使用)""" try: - redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") - r = redis.from_url(redis_url, decode_responses=True) + r = _get_redis_client(decode_responses=True) key = f"task:user:{user_id}:tasks" @@ -87,9 +122,10 @@ def _merge_chapter_image_assets( size: str, now_iso: str, ) -> list[dict]: + normalized_existing_images = normalize_image_assets(existing_images) existing_by_placeholder = { item.get("placeholder"): dict(item) - for item in (existing_images or []) + for item in normalized_existing_images if item.get("placeholder") } merged_assets: list[dict] = [] @@ -106,10 +142,10 @@ def _merge_chapter_image_assets( merged_item["size"] = merged_item.get("size") or size merged_item["created_at"] = merged_item.get("created_at") or now_iso merged_item["updated_at"] = merged_item.get("updated_at") or now_iso - if merged_item.get("status") == "completed" and not ( + if merged_item.get("status") == IMAGE_STATUS_COMPLETED and not ( merged_item.get("storage_key") or merged_item.get("url") ): - merged_item["status"] = "failed" + merged_item["status"] = IMAGE_STATUS_FAILED merged_item["error"] = merged_item.get("error") or "missing image url" else: merged_item = build_initial_image_assets( @@ -125,14 +161,17 @@ def _merge_chapter_image_assets( def chapter_has_images_to_generate(images: list[dict] | None) -> bool: - return any(item.get("status") in {"pending", "failed"} for item in (images or [])) + return any( + item.get("status") in {IMAGE_STATUS_PENDING, IMAGE_STATUS_FAILED} + for item in normalize_image_assets(images) + ) def initialize_chapter_images(chapter) -> list[dict]: """Parse IMAGE placeholders from chapter content and build pending image assets.""" settings = MemoirImageSettings.from_env() if not settings.enabled: - chapter.images = [] + chapter.images = completed_image_assets(chapter.images) logger.info(f"章节图片初始化跳过: chapter={chapter.id}, enabled=false") return chapter.images @@ -157,6 +196,19 @@ def initialize_chapter_images(chapter) -> list[dict]: return chapter.images +def _normalize_image_bytes_for_storage(image_bytes: bytes) -> bytes: + with Image.open(BytesIO(image_bytes)) as image: + output = BytesIO() + if image.mode in {"RGBA", "LA"}: + normalized = image + elif image.mode == "P": + normalized = image.convert("RGBA") + else: + normalized = image.convert("RGB") + normalized.save(output, format="PNG") + return output.getvalue() + + STAGE_KEYWORDS = { "childhood": ["童年", "小时候", "出生", "家乡", "小镇"], "education": ["上学", "学校", "老师", "同学", "教育", "大学"], @@ -304,6 +356,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): # 获取用户状态和资料 state = _get_or_create_state_sync(user_id, db) llm = llm_service.get_llm() + image_settings = MemoirImageSettings.from_env() user_obj = db.get(User, user_id) user_profile = "" @@ -361,19 +414,19 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): category_to_segments.setdefault(chapter_category, []).append(segment) # 按 8 分类生成章节内容 - for stage, stage_segments in category_to_segments.items(): - if not _acquire_chapter_lock(user_id, stage): - logger.warning(f"章节锁竞争: user={user_id}, stage={stage}, 延迟重试") + for chapter_category, category_segments in category_to_segments.items(): + if not _acquire_chapter_lock(user_id, chapter_category): + logger.warning(f"章节锁竞争: user={user_id}, category={chapter_category}, 延迟重试") raise self.retry(countdown=10) try: - segment_texts = [seg.transcript_text for seg in stage_segments] + segment_texts = [seg.transcript_text for seg in category_segments] combined_text = "\n\n".join(segment_texts) - source_ids = [seg.id for seg in stage_segments] + source_ids = [seg.id for seg in category_segments] # 查找 active 章节(被清除的章节不继续更新,而是创建新的) stmt_chapter = select(Chapter).where( Chapter.user_id == user_id, - Chapter.category == stage, + Chapter.category == chapter_category, Chapter.is_active == True, ) result_chapter = db.execute(stmt_chapter) @@ -382,12 +435,12 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): # 获取 slot snippets slot_snippets = { key: value.snippet - for key, value in (state.slots.get(stage, {}) or {}).items() + for key, value in (state.slots.get(chapter_category, {}) or {}).items() if value.snippet } # 生成标题和内容 - title = chapter.title if chapter else f"{stage} 回忆" + title = chapter.title if chapter else f"{chapter_category} 回忆" existing_content = chapter.content if chapter else "" narrative = combined_text @@ -395,7 +448,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): try: if not chapter: title_prompt = get_creative_title_prompt( - stage=stage, + stage=chapter_category, emotion="neutral", slots=slot_snippets, user_profile=user_profile, @@ -405,7 +458,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): title = title_response.content.strip().strip('"') narrative_prompt = get_narrative_prompt( - stage=stage, + stage=chapter_category, slots=slot_snippets, new_content=combined_text, existing_content=existing_content, @@ -429,7 +482,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): if existing_content and len(narrative) < len(existing_content) * 0.8: logger.warning( f"内容长度异常: existing={len(existing_content)}, " - f"new={len(narrative)}, stage={stage}. 回退为追加模式" + f"new={len(narrative)}, category={chapter_category}. 回退为追加模式" ) narrative = f"{existing_content}\n\n{combined_text}" @@ -441,7 +494,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): chapter.source_segments = list({*(chapter.source_segments or []), *source_ids}) else: # 根据 stage 计算正确的排序索引 - calculated_order_index = STAGE_TO_ORDER.get(stage, 999) + calculated_order_index = STAGE_TO_ORDER.get(chapter_category, 999) chapter = Chapter( id=str(uuid.uuid4()), user_id=user_id, @@ -449,7 +502,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): content=narrative, order_index=calculated_order_index, status="completed", - category=stage, + category=chapter_category, images=[], is_new=True, source_segments=source_ids, @@ -459,7 +512,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): db.flush() initialize_chapter_images(chapter) - if chapter_has_images_to_generate(chapter.images): + if image_settings.enabled and chapter_has_images_to_generate(chapter.images): chapters_to_enqueue.add(chapter.id) # 更新 Book @@ -479,7 +532,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): book.has_update = True book.last_update_chapter_id = chapter.id finally: - _release_chapter_lock(user_id, stage) + _release_chapter_lock(user_id, chapter_category) # 标记段落为已处理 for seg in segments: @@ -607,6 +660,8 @@ def build_cos_key(user_id: str, chapter_id: str, index: int, prompt: str) -> str def generate_chapter_images(self, chapter_id: str): """Async task to generate images for a chapter's pending image assets.""" db = SessionLocal() + lock_acquired = False + provider = None try: chapter = db.get(Chapter, chapter_id) if not chapter or not chapter.images: @@ -614,26 +669,40 @@ def generate_chapter_images(self, chapter_id: str): return {"status": "no_images"} settings = MemoirImageSettings.from_env() + if not settings.enabled: + chapter.images = completed_image_assets(chapter.images) + db.commit() + logger.info("章节补图跳过: chapter=%s, reason=disabled", chapter_id) + return {"status": "disabled"} + + lock_acquired = _acquire_chapter_image_lock(chapter_id) + if not lock_acquired: + logger.info("章节补图跳过: chapter=%s, reason=locked", chapter_id) + return {"status": "locked"} + prompt_service = MemoirImagePromptService(llm_service.get_llm(), settings) provider = LiblibImageProvider(template_uuid=settings.liblib_template_uuid) storage = TencentCosStorageService.from_env() - images = [dict(item) for item in (chapter.images or [])] - pending_count = sum(1 for item in images if item.get("status") in {"pending", "failed"}) + images = normalize_image_assets(chapter.images) + pending_count = sum( + 1 for item in images if item.get("status") in {IMAGE_STATUS_PENDING, IMAGE_STATUS_FAILED} + ) logger.info( "章节补图开始: chapter=%s, total_images=%d, pending_images=%d", chapter_id, len(images), pending_count, ) + failures: list[str] = [] for index, item in enumerate(images): - if item.get("status") == "completed" and (item.get("storage_key") or item.get("url")): + if item.get("status") == IMAGE_STATUS_COMPLETED and (item.get("storage_key") or item.get("url")): continue - if item.get("status") not in {"pending", "failed"}: + if item.get("status") not in {IMAGE_STATUS_PENDING, IMAGE_STATUS_FAILED}: continue current_item = dict(item) - current_item["status"] = "processing" + current_item["status"] = IMAGE_STATUS_PROCESSING current_item["updated_at"] = datetime.now(timezone.utc).isoformat() images[index] = current_item chapter.images = images @@ -660,14 +729,14 @@ def generate_chapter_images(self, chapter_id: str): poll_interval_seconds=settings.poll_interval_seconds, max_attempts=settings.max_attempts, ) - image_bytes = provider.download_image(job) + image_bytes = _normalize_image_bytes_for_storage(provider.download_image(job)) key = build_cos_key(chapter.user_id, chapter.id, current_item["index"], prompt_data["prompt"]) current_item["storage_key"] = key current_item["url"] = storage.upload_bytes(image_bytes, key, "image/png") current_item["prompt"] = prompt_data["prompt"] current_item["style"] = prompt_data["style"] current_item["size"] = prompt_data["size"] - current_item["status"] = "completed" + current_item["status"] = IMAGE_STATUS_COMPLETED current_item["error"] = None logger.info( "章节补图成功: chapter=%s, index=%s, url=%s", @@ -676,8 +745,9 @@ def generate_chapter_images(self, chapter_id: str): current_item["url"], ) except Exception as exc: - current_item["status"] = "failed" + current_item["status"] = IMAGE_STATUS_FAILED current_item["error"] = str(exc) + failures.append(f"index={current_item.get('index')}, error={exc}") logger.warning(f"图片生成失败: chapter={chapter_id}, index={current_item.get('index')}, error={exc}") current_item["updated_at"] = datetime.now(timezone.utc).isoformat() @@ -685,6 +755,18 @@ def generate_chapter_images(self, chapter_id: str): chapter.images = images db.commit() + if failures: + raise RuntimeError( + f"章节补图存在失败项: chapter={chapter_id}, failures={'; '.join(failures)}" + ) + return {"status": "success"} + except Exception as exc: + logger.error("章节补图任务失败: chapter=%s, error=%s", chapter_id, exc) + raise self.retry(exc=exc) finally: + if provider: + provider.close() + if lock_acquired: + _release_chapter_image_lock(chapter_id) db.close() diff --git a/api/tests/test_chapters_router_images.py b/api/tests/test_chapters_router_images.py index 62c697a..cbc12c6 100644 --- a/api/tests/test_chapters_router_images.py +++ b/api/tests/test_chapters_router_images.py @@ -56,3 +56,130 @@ class ChaptersRouterImagesTest(unittest.TestCase): ) self.assertEqual(payload["images"][0]["prompt"], "A serene southern China town") self.assertNotIn("storage_key", payload["images"][0]) + + @patch("api.routers.chapters.TencentCosStorageService") + @patch.dict( + os.environ, + { + "TENCENT_COS_BUCKET": "life-echo-dev-1319381411", + "TENCENT_COS_REGION": "ap-shanghai", + "TENCENT_COS_BASE_URL": "https://life-echo-dev-1319381411.cos.ap-shanghai.myqcloud.com", + }, + clear=False, + ) + def test_chapter_to_dict_preserves_completed_asset_when_signing_fails(self, storage_cls): + storage = Mock() + storage.get_download_url.side_effect = RuntimeError("cos unavailable") + storage_cls.from_env.return_value = storage + + chapter = type( + "ChapterStub", + (), + { + "id": "chapter-1", + "title": "童年的夏天", + "content": "{{IMAGE:南方小镇的青石板路}}", + "order_index": 0, + "status": "completed", + "category": "childhood", + "images": [ + { + "index": 0, + "placeholder": "{{IMAGE:南方小镇的青石板路}}", + "description": "南方小镇的青石板路", + "status": "completed", + "prompt": "A serene southern China town", + "url": "https://life-echo-dev-1319381411.cos.ap-shanghai.myqcloud.com/memoirs/u1/c1/0-demo.png", + "storage_key": "memoirs/u1/c1/0-demo.png", + } + ], + "updated_at": None, + "is_new": False, + "source_segments": [], + }, + )() + + payload = _chapter_to_dict(chapter) + + self.assertEqual(payload["images"][0]["status"], "completed") + self.assertEqual( + payload["images"][0]["url"], + "https://life-echo-dev-1319381411.cos.ap-shanghai.myqcloud.com/memoirs/u1/c1/0-demo.png", + ) + self.assertEqual(payload["images"][0]["prompt"], "A serene southern China town") + self.assertNotIn("storage_key", payload["images"][0]) + + @patch("api.routers.chapters.TencentCosStorageService") + def test_chapter_to_dict_drops_malformed_image_assets(self, storage_cls): + storage_cls.from_env.return_value = Mock() + + chapter = type( + "ChapterStub", + (), + { + "id": "chapter-1", + "title": "童年的夏天", + "content": "{{IMAGE:南方小镇的青石板路}}", + "order_index": 0, + "status": "completed", + "category": "childhood", + "images": [ + { + "index": 0, + "status": "completed", + } + ], + "updated_at": None, + "is_new": False, + "source_segments": [], + }, + )() + + payload = _chapter_to_dict(chapter) + + self.assertEqual(payload["images"], []) + + @patch("api.routers.chapters.TencentCosStorageService") + @patch.dict(os.environ, {"MEMOIR_IMAGE_ENABLED": "false"}, clear=False) + def test_chapter_to_dict_hides_non_completed_assets_when_feature_disabled(self, storage_cls): + storage = Mock() + storage.get_download_url.return_value = "https://signed.example.com/0.png?sig=123" + storage_cls.from_env.return_value = storage + + chapter = type( + "ChapterStub", + (), + { + "id": "chapter-1", + "title": "童年的夏天", + "content": "{{IMAGE:南方小镇的青石板路}}", + "order_index": 0, + "status": "completed", + "category": "childhood", + "images": [ + { + "index": 0, + "placeholder": "{{IMAGE:南方小镇的青石板路}}", + "description": "南方小镇的青石板路", + "status": "pending", + "url": None, + }, + { + "index": 1, + "placeholder": "{{IMAGE:奶奶坐在院子里的藤椅上}}", + "description": "奶奶坐在院子里的藤椅上", + "status": "completed", + "url": "https://life-echo-dev-1319381411.cos.ap-shanghai.myqcloud.com/memoirs/u1/c1/1-demo.png", + "storage_key": "memoirs/u1/c1/1-demo.png", + }, + ], + "updated_at": None, + "is_new": False, + "source_segments": [], + }, + )() + + payload = _chapter_to_dict(chapter) + + self.assertEqual(len(payload["images"]), 1) + self.assertEqual(payload["images"][0]["status"], "completed") diff --git a/api/tests/test_generate_chapter_images_task.py b/api/tests/test_generate_chapter_images_task.py index bc723ce..52ac047 100644 --- a/api/tests/test_generate_chapter_images_task.py +++ b/api/tests/test_generate_chapter_images_task.py @@ -1,10 +1,119 @@ import unittest +from io import BytesIO +from types import SimpleNamespace from unittest.mock import Mock, patch +from PIL import Image + +from api.tasks import memoir_tasks from api.tasks.memoir_tasks import generate_chapter_images class GenerateChapterImagesTaskTest(unittest.TestCase): + def setUp(self): + memoir_tasks._REDIS_CLIENTS.clear() + + @patch("api.tasks.memoir_tasks.redis.from_url") + @patch("api.tasks.memoir_tasks.SessionLocal") + @patch("api.tasks.memoir_tasks.TencentCosStorageService") + @patch("api.tasks.memoir_tasks.LiblibImageProvider") + @patch("api.tasks.memoir_tasks.MemoirImagePromptService") + def test_generate_chapter_images_skips_when_lock_is_already_held( + self, + prompt_service_cls, + provider_cls, + storage_cls, + session_local_cls, + redis_from_url, + ): + chapter = type( + "ChapterStub", + (), + { + "id": "chapter-1", + "user_id": "user-1", + "title": "童年的夏天", + "category": "childhood", + "content": "那条路我一直记得。", + "images": [ + { + "index": 0, + "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", + "description": "南方小镇的青石板路", + "status": "pending", + "url": None, + } + ], + }, + )() + + db = Mock() + db.get.return_value = chapter + session_local_cls.return_value = db + redis_from_url.return_value.set.return_value = False + + result = generate_chapter_images.run("chapter-1") + + self.assertEqual(result, {"status": "locked"}) + provider_cls.return_value.submit_generation.assert_not_called() + storage_cls.from_env.return_value.upload_bytes.assert_not_called() + db.commit.assert_not_called() + + @patch("api.tasks.memoir_tasks.SessionLocal") + @patch("api.tasks.memoir_tasks.TencentCosStorageService") + @patch("api.tasks.memoir_tasks.LiblibImageProvider") + @patch("api.tasks.memoir_tasks.MemoirImagePromptService") + def test_generate_chapter_images_retries_when_any_item_generation_fails( + self, + prompt_service_cls, + provider_cls, + storage_cls, + session_local_cls, + ): + chapter = type( + "ChapterStub", + (), + { + "id": "chapter-1", + "user_id": "user-1", + "title": "童年的夏天", + "category": "childhood", + "content": "那条路我一直记得。", + "images": [ + { + "index": 0, + "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", + "description": "南方小镇的青石板路", + "status": "pending", + "url": None, + } + ], + }, + )() + + db = Mock() + db.get.return_value = chapter + session_local_cls.return_value = db + prompt_service_cls.return_value.build_prompt.return_value = { + "prompt": "A serene southern China town", + "style": "watercolor", + "size": "1024x1024", + "prompt_context": "childhood: 童年的夏天", + } + provider_cls.return_value.submit_generation.side_effect = RuntimeError("transient provider error") + + retry_error = RuntimeError("retry requested") + task_self = SimpleNamespace(request=SimpleNamespace(id="task-1"), retry=Mock(side_effect=retry_error)) + + with self.assertRaises(RuntimeError) as ctx: + generate_chapter_images.run.__func__(task_self, "chapter-1") + + self.assertIs(ctx.exception, retry_error) + self.assertEqual(chapter.images[0]["status"], "failed") + self.assertEqual(chapter.images[0]["error"], "transient provider error") + task_self.retry.assert_called_once() + storage_cls.from_env.return_value.upload_bytes.assert_not_called() + @patch("api.tasks.memoir_tasks.SessionLocal") @patch("api.tasks.memoir_tasks.TencentCosStorageService") @patch("api.tasks.memoir_tasks.LiblibImageProvider") @@ -61,8 +170,127 @@ class GenerateChapterImagesTaskTest(unittest.TestCase): self.assertEqual(chapter.images[0]["storage_key"], "memoirs/user-1/chapter-1/0-7e1f860790.png") self.assertEqual(chapter.images[0]["url"], "https://cos.example.com/memoirs/u1/c1/0.png") self.assertEqual(chapter.images[0]["prompt"], "A serene southern China town") + provider_inst.close.assert_called_once() db.commit.assert_called() + @patch("api.tasks.memoir_tasks.SessionLocal") + @patch("api.tasks.memoir_tasks.TencentCosStorageService") + @patch("api.tasks.memoir_tasks.LiblibImageProvider") + @patch("api.tasks.memoir_tasks.MemoirImagePromptService") + @patch("api.tasks.memoir_tasks.MemoirImageSettings.from_env") + def test_generate_chapter_images_returns_disabled_when_feature_flag_is_off( + self, + settings_from_env, + prompt_service_cls, + provider_cls, + storage_cls, + session_local_cls, + ): + chapter = type( + "ChapterStub", + (), + { + "id": "chapter-1", + "user_id": "user-1", + "title": "童年的夏天", + "category": "childhood", + "content": "那条路我一直记得。", + "images": [ + { + "index": 0, + "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", + "description": "南方小镇的青石板路", + "status": "pending", + "url": None, + } + ], + }, + )() + + settings_from_env.return_value = SimpleNamespace( + enabled=False, + max_per_chapter=2, + provider="liblib", + default_style="watercolor", + default_size="1024x1024", + poll_interval_seconds=3, + max_attempts=20, + liblib_template_uuid="tpl-uuid", + ) + + db = Mock() + db.get.return_value = chapter + session_local_cls.return_value = db + + result = generate_chapter_images.run("chapter-1") + + self.assertEqual(result, {"status": "disabled"}) + self.assertEqual(chapter.images, []) + prompt_service_cls.assert_not_called() + provider_cls.assert_not_called() + storage_cls.from_env.return_value.upload_bytes.assert_not_called() + db.commit.assert_called_once() + + @patch("api.tasks.memoir_tasks.SessionLocal") + @patch("api.tasks.memoir_tasks.TencentCosStorageService") + @patch("api.tasks.memoir_tasks.LiblibImageProvider") + @patch("api.tasks.memoir_tasks.MemoirImagePromptService") + def test_generate_chapter_images_converts_non_png_payload_before_upload( + self, + prompt_service_cls, + provider_cls, + storage_cls, + session_local_cls, + ): + chapter = type( + "ChapterStub", + (), + { + "id": "chapter-1", + "user_id": "user-1", + "title": "童年的夏天", + "category": "childhood", + "content": "那条路我一直记得。\n\n{{{{IMAGE:南方小镇的青石板路}}}}", + "images": [ + { + "index": 0, + "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", + "description": "南方小镇的青石板路", + "status": "pending", + "url": None, + } + ], + }, + )() + + image_buffer = BytesIO() + Image.new("RGB", (2, 1), color="white").save(image_buffer, format="JPEG") + jpeg_bytes = image_buffer.getvalue() + + db = Mock() + db.get.return_value = chapter + session_local_cls.return_value = db + prompt_service_cls.return_value.build_prompt.return_value = { + "prompt": "A serene southern China town", + "style": "watercolor", + "size": "1024x1024", + "prompt_context": "childhood: 童年的夏天", + } + provider_inst = provider_cls.return_value + provider_inst.submit_generation.return_value = { + "status": "completed", + "image_url": "https://provider.example.com/1.jpg", + } + provider_inst.download_image.return_value = jpeg_bytes + storage_inst = storage_cls.from_env.return_value + storage_inst.upload_bytes.return_value = "https://cos.example.com/memoirs/u1/c1/0.png" + + generate_chapter_images.run("chapter-1") + + upload_args = storage_inst.upload_bytes.call_args.args + self.assertTrue(upload_args[0].startswith(b"\x89PNG\r\n\x1a\n")) + self.assertEqual(upload_args[2], "image/png") + @patch("api.tasks.memoir_tasks.SessionLocal") @patch("api.tasks.memoir_tasks.TencentCosStorageService") @patch("api.tasks.memoir_tasks.LiblibImageProvider") diff --git a/api/tests/test_memoir_image_bootstrap.py b/api/tests/test_memoir_image_bootstrap.py index d076d5d..3b26805 100644 --- a/api/tests/test_memoir_image_bootstrap.py +++ b/api/tests/test_memoir_image_bootstrap.py @@ -6,6 +6,42 @@ from api.tasks.memoir_tasks import initialize_chapter_images class MemoirImageBootstrapTest(unittest.TestCase): + def test_initialize_chapter_images_keeps_only_completed_assets_when_disabled(self): + chapter = type( + "ChapterStub", + (), + { + "id": "chapter-1", + "title": "童年的夏天", + "category": "childhood", + "content": "那条路我一直记得。", + "images": [ + { + "index": 0, + "placeholder": "{{IMAGE:南方小镇的青石板路}}", + "description": "南方小镇的青石板路", + "status": "completed", + "url": "https://cos.example.com/existing.png", + }, + { + "index": 1, + "placeholder": "{{IMAGE:奶奶坐在院子里的藤椅上}}", + "description": "奶奶坐在院子里的藤椅上", + "status": "pending", + "url": None, + }, + ], + }, + )() + + with unittest.mock.patch.dict(os.environ, {"MEMOIR_IMAGE_ENABLED": "false"}, clear=False): + assets = initialize_chapter_images(chapter) + + self.assertEqual(len(assets), 1) + self.assertEqual(assets[0]["status"], "completed") + self.assertEqual(assets[0]["url"], "https://cos.example.com/existing.png") + self.assertEqual(chapter.images, assets) + def test_initialize_chapter_images_sets_pending_assets_when_enabled(self): chapter = type( "ChapterStub", @@ -86,3 +122,30 @@ class MemoirImageBootstrapTest(unittest.TestCase): self.assertEqual(len(assets), 1) self.assertEqual(assets[0]["status"], "pending") self.assertEqual(assets[0]["placeholder"], "{{IMAGE:1938年初的上海弄堂口,冬日萧瑟}}") + + def test_initialize_chapter_images_normalizes_invalid_existing_asset_status(self): + chapter = type( + "ChapterStub", + (), + { + "id": "chapter-1", + "title": "童年的夏天", + "category": "childhood", + "content": "开头。\n\n{{IMAGE:南方小镇的青石板路}}\n\n结尾。", + "images": [ + { + "index": 0, + "placeholder": "{{IMAGE:南方小镇的青石板路}}", + "description": "南方小镇的青石板路", + "status": "mystery", + } + ], + }, + )() + + with unittest.mock.patch.dict(os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False): + assets = initialize_chapter_images(chapter) + + self.assertEqual(len(assets), 1) + self.assertEqual(assets[0]["status"], "failed") + self.assertEqual(assets[0]["error"], "invalid image status: mystery") diff --git a/api/tests/test_memoir_image_provider.py b/api/tests/test_memoir_image_provider.py index fa9b817..39a77b9 100644 --- a/api/tests/test_memoir_image_provider.py +++ b/api/tests/test_memoir_image_provider.py @@ -1,7 +1,9 @@ +import os import unittest from unittest.mock import Mock, patch from api.services.memoir_images.provider import LiblibImageProvider +from api.services.memoir_images.settings import DEFAULT_LIBLIB_TEMPLATE_UUID def _make_provider(http_client=None): @@ -15,17 +17,40 @@ def _make_provider(http_client=None): class LiblibSignatureTest(unittest.TestCase): - def test_sign_returns_url_with_auth_params(self): + def test_sign_returns_auth_params(self): provider = _make_provider() - url = provider._sign("/api/generate/webui/text2img/ultra") - self.assertIn("AccessKey=test-ak", url) - self.assertIn("Signature=", url) - self.assertIn("Timestamp=", url) - self.assertIn("SignatureNonce=", url) - self.assertTrue(url.startswith("https://openapi.liblibai.cloud/api/generate/webui/text2img/ultra?")) + params = provider._sign("/api/generate/webui/text2img/ultra") + self.assertEqual(params["AccessKey"], "test-ak") + self.assertIn("Signature", params) + self.assertIn("Timestamp", params) + self.assertIn("SignatureNonce", params) class SubmitGenerationTest(unittest.TestCase): + def test_submit_keeps_auth_params_out_of_url_string(self): + http_client = Mock() + resp = Mock() + resp.json.return_value = { + "code": 0, + "data": {"generateUuid": "uuid-abc"}, + } + resp.raise_for_status = Mock() + http_client.post.return_value = resp + + provider = _make_provider(http_client) + provider.submit_generation(prompt="a cat", size="1024x1024", style="watercolor") + + call_kwargs = http_client.post.call_args + url = call_kwargs.args[0] if call_kwargs.args else call_kwargs.kwargs["url"] + params = call_kwargs.kwargs.get("params") + + self.assertEqual(url, "https://openapi.liblibai.cloud/api/generate/webui/text2img/ultra") + self.assertNotIn("AccessKey=", url) + self.assertEqual(params["AccessKey"], "test-ak") + self.assertIn("Signature", params) + self.assertIn("Timestamp", params) + self.assertIn("SignatureNonce", params) + def test_submit_returns_processing_with_job_id(self): http_client = Mock() resp = Mock() @@ -46,9 +71,29 @@ class SubmitGenerationTest(unittest.TestCase): call_kwargs = http_client.post.call_args body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json") self.assertEqual(body["templateUuid"], "tpl-uuid") - self.assertEqual(body["generateParams"]["prompt"], "a cat") + self.assertIn("a cat", body["generateParams"]["prompt"]) + self.assertIn("watercolor", body["generateParams"]["prompt"].lower()) self.assertEqual(body["generateParams"]["aspectRatio"], "square") + def test_submit_applies_style_when_prompt_does_not_include_it(self): + http_client = Mock() + resp = Mock() + resp.json.return_value = { + "code": 0, + "data": {"generateUuid": "uuid-abc"}, + } + resp.raise_for_status = Mock() + http_client.post.return_value = resp + + provider = _make_provider(http_client) + provider.submit_generation(prompt="a cat under the rain", size="1024x1024", style="watercolor") + + call_kwargs = http_client.post.call_args + body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json") + + self.assertIn("a cat under the rain", body["generateParams"]["prompt"]) + self.assertIn("watercolor", body["generateParams"]["prompt"].lower()) + def test_submit_raises_on_error_code(self): http_client = Mock() resp = Mock() @@ -131,6 +176,7 @@ class PollUntilCompleteTest(unittest.TestCase): class DownloadImageTest(unittest.TestCase): + @patch.dict(os.environ, {"MEMOIR_IMAGE_DOWNLOAD_HOSTS": "cdn.example.com"}, clear=False) def test_download_fetches_binary_payload(self): http_client = Mock() resp = Mock() @@ -142,3 +188,57 @@ class DownloadImageTest(unittest.TestCase): payload = provider.download_image({"image_url": "https://cdn.example.com/1.png"}) self.assertEqual(payload, b"png-bytes") + + def test_download_rejects_unapproved_host(self): + http_client = Mock() + provider = _make_provider(http_client) + + with self.assertRaises(ValueError): + provider.download_image({"image_url": "https://evil.example.com/1.png"}) + + http_client.get.assert_not_called() + + +class ProviderResourceManagementTest(unittest.TestCase): + @patch("api.services.memoir_images.provider.httpx.Client") + def test_provider_closes_owned_http_client(self, httpx_client_cls): + http_client = Mock() + httpx_client_cls.return_value = http_client + + provider = LiblibImageProvider( + access_key="test-ak", + secret_key="test-sk", + base_url="https://openapi.liblibai.cloud", + template_uuid="tpl-uuid", + ) + + provider.close() + + http_client.close.assert_called_once() + + def test_provider_does_not_close_injected_http_client(self): + http_client = Mock() + provider = LiblibImageProvider( + http_client=http_client, + access_key="test-ak", + secret_key="test-sk", + base_url="https://openapi.liblibai.cloud", + template_uuid="tpl-uuid", + ) + + provider.close() + + http_client.close.assert_not_called() + + +class ProviderDefaultsTest(unittest.TestCase): + @patch.dict(os.environ, {"LIBLIB_TEMPLATE_UUID": ""}, clear=False) + def test_provider_uses_shared_template_uuid_default(self): + provider = LiblibImageProvider( + http_client=Mock(), + access_key="test-ak", + secret_key="test-sk", + base_url="https://openapi.liblibai.cloud", + ) + + self.assertEqual(provider.template_uuid, DEFAULT_LIBLIB_TEMPLATE_UUID) diff --git a/api/tests/test_process_memoir_segments_image_enqueue.py b/api/tests/test_process_memoir_segments_image_enqueue.py index 0e17efe..7bbf1cd 100644 --- a/api/tests/test_process_memoir_segments_image_enqueue.py +++ b/api/tests/test_process_memoir_segments_image_enqueue.py @@ -73,3 +73,64 @@ class ProcessMemoirSegmentsImageEnqueueTest(unittest.TestCase): delay_events = [event for event in events if event.startswith("delay:")] self.assertEqual(len(delay_events), 1) self.assertGreater(events.index(delay_events[0]), events.index("commit")) + + @patch("api.tasks.memoir_tasks._update_task_status_sync") + @patch("api.tasks.memoir_tasks._release_chapter_lock") + @patch("api.tasks.memoir_tasks._acquire_chapter_lock", return_value=True) + @patch("api.tasks.memoir_tasks._classify_chapter_category", return_value="childhood") + @patch("api.tasks.memoir_tasks._get_or_create_state_sync") + @patch("api.tasks.memoir_tasks.llm_service.get_llm", return_value=None) + @patch("api.tasks.memoir_tasks.generate_chapter_images.delay") + @patch("api.tasks.memoir_tasks.SessionLocal") + @patch("api.tasks.memoir_tasks.MemoirImageSettings.from_env") + def test_process_memoir_segments_does_not_enqueue_image_jobs_when_feature_disabled( + self, + settings_from_env, + session_local_cls, + delay_mock, + _get_llm, + get_state_mock, + _classify_mock, + _acquire_lock_mock, + _release_lock_mock, + _update_status_mock, + ): + settings_from_env.return_value = MemoirImageSettings( + enabled=False, + max_per_chapter=2, + provider="liblib", + default_style="watercolor", + default_size="1024x1024", + poll_interval_seconds=3, + max_attempts=20, + liblib_template_uuid="tpl-uuid", + ) + get_state_mock.return_value = SimpleNamespace(current_stage="childhood", slots={}) + + segment = SimpleNamespace( + id="segment-1", + transcript_text="那条路我一直记得。\n\n{{{{IMAGE:南方小镇的青石板路}}}}", + processed=False, + ) + + segments_result = Mock() + segments_result.scalars.return_value.all.return_value = [segment] + + chapter_result = Mock() + chapter_result.scalar_one_or_none.return_value = None + + book_result = Mock() + book_result.scalar_one_or_none.return_value = None + + db = Mock() + db.execute.side_effect = [segments_result, chapter_result, book_result] + db.get.return_value = None + session_local_cls.return_value = db + + task_self = SimpleNamespace( + request=SimpleNamespace(id="task-1"), + retry=Mock(side_effect=AssertionError("retry should not be called")), + ) + process_memoir_segments.run.__func__(task_self, "user-1", ["segment-1"]) + + delay_mock.assert_not_called() diff --git a/app-android/app/src/main/java/com/huaga/life_echo/network/models/MemoirImageStatus.kt b/app-android/app/src/main/java/com/huaga/life_echo/network/models/MemoirImageStatus.kt new file mode 100644 index 0000000..5711b72 --- /dev/null +++ b/app-android/app/src/main/java/com/huaga/life_echo/network/models/MemoirImageStatus.kt @@ -0,0 +1,6 @@ +package com.huaga.life_echo.network.models + +const val MEMOIR_IMAGE_STATUS_PENDING = "pending" +const val MEMOIR_IMAGE_STATUS_PROCESSING = "processing" +const val MEMOIR_IMAGE_STATUS_COMPLETED = "completed" +const val MEMOIR_IMAGE_STATUS_FAILED = "failed" diff --git a/app-android/app/src/main/java/com/huaga/life_echo/ui/screens/MemoirImagePolling.kt b/app-android/app/src/main/java/com/huaga/life_echo/ui/screens/MemoirImagePolling.kt index 2fad380..a8e4f3c 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/ui/screens/MemoirImagePolling.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/ui/screens/MemoirImagePolling.kt @@ -1,13 +1,41 @@ package com.huaga.life_echo.ui.screens import com.huaga.life_echo.network.models.ChapterDto +import com.huaga.life_echo.network.models.MEMOIR_IMAGE_STATUS_PENDING +import com.huaga.life_echo.network.models.MEMOIR_IMAGE_STATUS_PROCESSING internal const val MEMOIR_IMAGE_POLL_INTERVAL_MS = 3_000L +internal const val MEMOIR_IMAGE_PROVIDER_MAX_ATTEMPTS = 60L +internal const val MEMOIR_IMAGE_POLL_GRACE_MS = 30_000L internal fun hasPendingMemoirImages(chapters: List): Boolean { return chapters.any { chapter -> chapter.images.any { image -> - image.status == "pending" || image.status == "processing" + image.status == MEMOIR_IMAGE_STATUS_PENDING || image.status == MEMOIR_IMAGE_STATUS_PROCESSING } } } + +internal fun longestPendingMemoirImageSequence(chapters: List): Long { + return chapters.maxOfOrNull { chapter -> + chapter.images.count { image -> + image.status == MEMOIR_IMAGE_STATUS_PENDING || image.status == MEMOIR_IMAGE_STATUS_PROCESSING + }.toLong() + } ?: 0L +} + +internal fun memoirImagePollTimeoutMs(maxObservedPendingImages: Long): Long { + val pendingImages = maxOf(1L, maxObservedPendingImages) + return pendingImages * MEMOIR_IMAGE_POLL_INTERVAL_MS * MEMOIR_IMAGE_PROVIDER_MAX_ATTEMPTS + + MEMOIR_IMAGE_POLL_GRACE_MS +} + +internal fun shouldContinueMemoirImagePolling( + chapters: List, + pollStartedAtMs: Long, + nowMs: Long, + maxObservedPendingImages: Long, +): Boolean { + if (!hasPendingMemoirImages(chapters)) return false + return nowMs - pollStartedAtMs < memoirImagePollTimeoutMs(maxObservedPendingImages) +} diff --git a/app-android/app/src/main/java/com/huaga/life_echo/ui/screens/MyMemoirScreen.kt b/app-android/app/src/main/java/com/huaga/life_echo/ui/screens/MyMemoirScreen.kt index 3b2dd0c..789f0eb 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/ui/screens/MyMemoirScreen.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/ui/screens/MyMemoirScreen.kt @@ -168,17 +168,31 @@ fun MyMemoirScreen( val shouldPollChapterImages = remember(chapterDtos) { hasPendingMemoirImages(chapterDtos) } + val latestChapterDtos by rememberUpdatedState(chapterDtos) val latestIsLoading by rememberUpdatedState(isLoading) val latestIsRefreshing by rememberUpdatedState(isRefreshing) LaunchedEffect(shouldPollChapterImages) { if (!shouldPollChapterImages) return@LaunchedEffect + val pollStartedAtMs = System.currentTimeMillis() + var maxObservedPendingImages = longestPendingMemoirImageSequence(latestChapterDtos) - while (true) { + while ( + shouldContinueMemoirImagePolling( + chapters = latestChapterDtos, + pollStartedAtMs = pollStartedAtMs, + nowMs = System.currentTimeMillis(), + maxObservedPendingImages = maxObservedPendingImages, + ) + ) { delay(MEMOIR_IMAGE_POLL_INTERVAL_MS) if (!latestIsLoading && !latestIsRefreshing) { viewModel.refreshChapters() } + maxObservedPendingImages = maxOf( + maxObservedPendingImages, + longestPendingMemoirImageSequence(latestChapterDtos), + ) } } diff --git a/app-android/app/src/test/java/com/huaga/life_echo/ui/screens/MemoirImagePollingTest.kt b/app-android/app/src/test/java/com/huaga/life_echo/ui/screens/MemoirImagePollingTest.kt index eb7a53b..e42f7fc 100644 --- a/app-android/app/src/test/java/com/huaga/life_echo/ui/screens/MemoirImagePollingTest.kt +++ b/app-android/app/src/test/java/com/huaga/life_echo/ui/screens/MemoirImagePollingTest.kt @@ -2,6 +2,7 @@ package com.huaga.life_echo.ui.screens import com.huaga.life_echo.network.models.ChapterDto import com.huaga.life_echo.network.models.ChapterImageDto +import org.junit.Assert.assertEquals import org.junit.Assert.assertFalse import org.junit.Assert.assertTrue import org.junit.Test @@ -28,7 +29,85 @@ class MemoirImagePollingTest { assertFalse(hasPendingMemoirImages(chapters)) } - private fun chapterWithImages(status: String): ChapterDto { + @Test + fun memoirImagePollTimeoutMs_scalesWithPendingImagesInSameChapter() { + val chapters = listOf( + chapterWithImages("processing", imageCount = 2), + ) + + assertTrue(memoirImagePollTimeoutMs(2L) > MEMOIR_IMAGE_POLL_GRACE_MS) + assertEquals( + 390_000L, + memoirImagePollTimeoutMs(2L), + ) + assertEquals(2L, longestPendingMemoirImageSequence(chapters)) + } + + @Test + fun shouldContinueMemoirImagePolling_returnsFalse_afterDynamicTimeout() { + val chapters = listOf( + chapterWithImages("processing", imageCount = 2), + ) + + assertFalse( + shouldContinueMemoirImagePolling( + chapters = chapters, + pollStartedAtMs = 0L, + nowMs = memoirImagePollTimeoutMs(2L) + 1L, + maxObservedPendingImages = 2L, + ) + ) + } + + @Test + fun shouldContinueMemoirImagePolling_returnsTrue_beforeDynamicTimeout_whenPending() { + val chapters = listOf( + chapterWithImages("processing", imageCount = 2), + ) + + assertTrue( + shouldContinueMemoirImagePolling( + chapters = chapters, + pollStartedAtMs = 0L, + nowMs = memoirImagePollTimeoutMs(2L) - 1L, + maxObservedPendingImages = 2L, + ) + ) + } + + @Test + fun shouldContinueMemoirImagePolling_keepsOriginalWindow_whenPendingCountDropsMidSession() { + val chapters = listOf( + chapterWithImages("processing", imageCount = 1), + ) + + assertTrue( + shouldContinueMemoirImagePolling( + chapters = chapters, + pollStartedAtMs = 0L, + nowMs = 300_000L, + maxObservedPendingImages = 2L, + ) + ) + } + + @Test + fun shouldContinueMemoirImagePolling_returnsFalse_whenLatestChapterStateHasNoPendingImages() { + val chapters = listOf( + chapterWithImages("completed", imageCount = 1), + ) + + assertFalse( + shouldContinueMemoirImagePolling( + chapters = chapters, + pollStartedAtMs = 0L, + nowMs = 1L, + maxObservedPendingImages = 1L, + ) + ) + } + + private fun chapterWithImages(status: String, imageCount: Int = 1): ChapterDto { return ChapterDto( id = "chapter-$status", title = "title-$status", @@ -36,13 +115,13 @@ class MemoirImagePollingTest { order_index = 0, status = "partial", category = "childhood", - images = listOf( + images = (0 until imageCount).map { index -> ChapterImageDto( - index = 0, - placeholder = "{{IMAGE:$status}}", + index = index, + placeholder = "{{IMAGE:$status-$index}}", description = status, prompt = null, - url = if (status == "completed") "https://cos.example.com/$status.png" else null, + url = if (status == "completed") "https://cos.example.com/$status-$index.png" else null, status = status, provider = "liblib", style = "memoir", @@ -51,7 +130,7 @@ class MemoirImagePollingTest { created_at = null, updated_at = null, ) - ), + }, updated_at = null, is_new = false, source_segments = emptyList(),