diff --git a/api/requirements.txt b/api/requirements.txt index bda6e32..5082a07 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -48,7 +48,7 @@ tencentcloud-sdk-python>=3.0.1000 openai # Tencent COS for memoir image storage -cos-python-sdk-v5>=1.9.30 +cos-python-sdk-v5>=1.9.40 # Payment - WeChat Pay & Alipay wechatpayv3>=0.3.0 diff --git a/api/services/memoir_images/provider.py b/api/services/memoir_images/provider.py index dee32c0..51e0be5 100644 --- a/api/services/memoir_images/provider.py +++ b/api/services/memoir_images/provider.py @@ -1,43 +1,141 @@ +import base64 +import hmac +import logging import os import time +import uuid +from hashlib import sha1 import httpx +logger = logging.getLogger(__name__) + +_SIZE_TO_ASPECT_RATIO = { + "1024x1024": "square", + "768x1024": "portrait", + "1024x768": "landscape", + "1280x720": "landscape", + "720x1280": "portrait", +} + class LiblibImageProvider: + """Liblib (https://openapi.liblibai.cloud) image generation adapter.""" + def __init__( self, - http_client=None, - api_key: str | None = None, + http_client: httpx.Client | None = None, + access_key: str | None = None, + secret_key: str | None = None, base_url: str | None = None, + template_uuid: str | None = None, ): - self.http_client = http_client or httpx.Client(timeout=60) - self.api_key = api_key or os.getenv("LIBLIB_API_KEY", "") - self.base_url = (base_url or os.getenv("LIBLIB_BASE_URL", "")).rstrip("/") + self.http_client = http_client or httpx.Client(timeout=120) + self.access_key = access_key or os.getenv("LIBLIB_ACCESS_KEY", "") + self.secret_key = secret_key or os.getenv("LIBLIB_SECRET_KEY", "") + self.base_url = (base_url or os.getenv("LIBLIB_BASE_URL", "https://openapi.liblibai.cloud")).rstrip("/") + self.template_uuid = template_uuid or os.getenv( + "LIBLIB_TEMPLATE_UUID", "5d7e67009b344550bc1aa6ccbfa1d7f4" + ) + + # ------------------------------------------------------------------ + # Signature helpers + # ------------------------------------------------------------------ + + def _sign(self, uri: str) -> str: + """Build a full URL with Liblib HMAC-SHA1 query-string auth.""" + timestamp = str(int(time.time() * 1000)) + nonce = str(uuid.uuid4()) + content = "&".join((uri, timestamp, nonce)) + digest = hmac.new(self.secret_key.encode(), content.encode(), sha1).digest() + signature = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() + return ( + f"{self.base_url}{uri}" + f"?AccessKey={self.access_key}" + f"&Signature={signature}" + f"&Timestamp={timestamp}" + f"&SignatureNonce={nonce}" + ) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ def submit_generation(self, prompt: str, size: str, style: str) -> dict: - response = self.http_client.post( - f"{self.base_url}/v1/images/generations", - headers={"Authorization": f"Bearer {self.api_key}"}, - json={"prompt": prompt, "size": size, "style": style}, - ) - data = response.json() - if data.get("image_url"): - return {"status": "completed", "image_url": data["image_url"], "job_id": None} - return {"status": "processing", "job_id": data.get("task_id"), "image_url": None} + uri = "/api/generate/webui/text2img/ultra" + url = self._sign(uri) - def poll_until_complete(self, job: dict, poll_interval_seconds: int, max_attempts: int) -> dict: - for _ in range(max_attempts): - response = self.http_client.get( - f"{self.base_url}/v1/images/generations/{job['job_id']}", - headers={"Authorization": f"Bearer {self.api_key}"}, + aspect_ratio = _SIZE_TO_ASPECT_RATIO.get(size, "square") + + body = { + "templateUuid": self.template_uuid, + "generateParams": { + "prompt": prompt, + "aspectRatio": aspect_ratio, + "imgCount": 1, + "steps": 30, + }, + } + response = self.http_client.post( + url, + headers={"Content-Type": "application/json"}, + json=body, + ) + response.raise_for_status() + data = response.json() + + if data.get("code") != 0: + raise RuntimeError(f"Liblib submit failed: {data.get('msg', data)}") + + generate_uuid = data["data"]["generateUuid"] + return {"status": "processing", "job_id": generate_uuid, "image_url": None} + + def poll_until_complete( + self, job: dict, poll_interval_seconds: int, max_attempts: int + ) -> dict: + uri = "/api/generate/webui/status" + + for attempt in range(max_attempts): + url = self._sign(uri) + response = self.http_client.post( + url, + headers={"Content-Type": "application/json"}, + json={"generateUuid": job["job_id"]}, ) + response.raise_for_status() data = response.json() - if data.get("image_url"): - return {"status": "completed", "image_url": data["image_url"], "job_id": job["job_id"]} + + if data.get("code") != 0: + raise RuntimeError(f"Liblib status query failed: {data.get('msg', data)}") + + result = data.get("data", {}) + status = result.get("generateStatus") + + if status == 5: # success + images = result.get("images") or [] + if images: + return { + "status": "completed", + "image_url": images[0]["imageUrl"], + "job_id": job["job_id"], + } + raise RuntimeError(f"Liblib returned success but no images for {job['job_id']}") + + if status == 6: # failed + raise RuntimeError(f"Liblib generation failed: {result.get('generateMsg', 'unknown')}") + + if status == 7: # timeout on Liblib side + raise TimeoutError(f"Liblib generation timed out on server side: {job['job_id']}") + + logger.debug( + "Liblib poll attempt %d/%d, status=%s, job=%s", + attempt + 1, max_attempts, status, job["job_id"], + ) time.sleep(poll_interval_seconds) - raise TimeoutError(f"Liblib image generation timed out for {job['job_id']}") + + raise TimeoutError(f"Liblib image generation timed out after {max_attempts} attempts for {job['job_id']}") def download_image(self, job: dict) -> bytes: response = self.http_client.get(job["image_url"]) + response.raise_for_status() return response.content diff --git a/api/services/memoir_images/settings.py b/api/services/memoir_images/settings.py index bb181ef..ec1e9df 100644 --- a/api/services/memoir_images/settings.py +++ b/api/services/memoir_images/settings.py @@ -11,6 +11,7 @@ class MemoirImageSettings: default_size: str poll_interval_seconds: int max_attempts: int + liblib_template_uuid: str @classmethod def from_env(cls) -> "MemoirImageSettings": @@ -21,5 +22,8 @@ class MemoirImageSettings: default_style=os.getenv("MEMOIR_IMAGE_STYLE_DEFAULT", "watercolor"), default_size=os.getenv("MEMOIR_IMAGE_SIZE_DEFAULT", "1024x1024"), poll_interval_seconds=int(os.getenv("MEMOIR_IMAGE_POLL_INTERVAL", "3")), - max_attempts=int(os.getenv("MEMOIR_IMAGE_MAX_ATTEMPTS", "20")), + max_attempts=int(os.getenv("MEMOIR_IMAGE_MAX_ATTEMPTS", "60")), + liblib_template_uuid=os.getenv( + "LIBLIB_TEMPLATE_UUID", "5d7e67009b344550bc1aa6ccbfa1d7f4" + ), ) diff --git a/api/tasks/memoir_tasks.py b/api/tasks/memoir_tasks.py index 3e5421e..5517fcc 100644 --- a/api/tasks/memoir_tasks.py +++ b/api/tasks/memoir_tasks.py @@ -78,6 +78,54 @@ def _update_task_status_sync(user_id: str, task_id: str, status: str, result: Di except Exception as e: logger.error(f"更新任务状态失败: {e}") + +def _merge_chapter_image_assets( + existing_images: list[dict] | None, + placeholders: list[dict], + provider: str, + style: str, + size: str, + now_iso: str, +) -> list[dict]: + existing_by_placeholder = { + item.get("placeholder"): dict(item) + for item in (existing_images or []) + if item.get("placeholder") + } + merged_assets: list[dict] = [] + + for item in placeholders: + existing = existing_by_placeholder.get(item["placeholder"]) + if existing: + merged_item = dict(existing) + merged_item["index"] = item["index"] + merged_item["placeholder"] = item["placeholder"] + merged_item["description"] = item["description"] + merged_item["provider"] = merged_item.get("provider") or provider + merged_item["style"] = merged_item.get("style") or style + merged_item["size"] = merged_item.get("size") or size + merged_item["created_at"] = merged_item.get("created_at") or now_iso + merged_item["updated_at"] = merged_item.get("updated_at") or now_iso + if merged_item.get("status") == "completed" and not merged_item.get("url"): + merged_item["status"] = "failed" + merged_item["error"] = merged_item.get("error") or "missing image url" + else: + merged_item = build_initial_image_assets( + placeholders=[item], + provider=provider, + style=style, + size=size, + now_iso=now_iso, + )[0] + merged_assets.append(merged_item) + + return merged_assets + + +def chapter_has_images_to_generate(images: list[dict] | None) -> bool: + return any(item.get("status") in {"pending", "failed"} for item in (images or [])) + + def initialize_chapter_images(chapter) -> list[dict]: """Parse IMAGE placeholders from chapter content and build pending image assets.""" settings = MemoirImageSettings.from_env() @@ -88,15 +136,14 @@ def initialize_chapter_images(chapter) -> list[dict]: prompt_service = MemoirImagePromptService(llm=None, settings=settings) placeholders = parse_image_placeholders(chapter.content, settings.max_per_chapter) style = prompt_service.CATEGORY_STYLE_MAP.get(chapter.category, settings.default_style) - chapter.images = build_initial_image_assets( + chapter.images = _merge_chapter_image_assets( + existing_images=chapter.images, placeholders=placeholders, provider=settings.provider, style=style, size=settings.default_size, now_iso=datetime.now(timezone.utc).isoformat(), ) - if chapter.images: - generate_chapter_images.delay(chapter.id) return chapter.images @@ -234,6 +281,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): try: db = SessionLocal() try: + chapters_to_enqueue: set[str] = set() # 获取段落 stmt = select(Segment).where(Segment.id.in_(segment_ids)) result = db.execute(stmt) @@ -401,6 +449,8 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): db.flush() initialize_chapter_images(chapter) + if chapter_has_images_to_generate(chapter.images): + chapters_to_enqueue.add(chapter.id) # 更新 Book stmt_book = select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc()) @@ -426,6 +476,13 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]): seg.processed = True db.commit() + + for chapter_id in sorted(chapters_to_enqueue): + try: + generate_chapter_images.delay(chapter_id) + except Exception as exc: + logger.warning(f"补图任务派发失败: chapter={chapter_id}, error={exc}") + logger.info(f"回忆录处理完成: user_id={user_id}, task_id={task_id}") # 更新任务状态为成功 @@ -546,16 +603,21 @@ def generate_chapter_images(self, chapter_id: str): settings = MemoirImageSettings.from_env() prompt_service = MemoirImagePromptService(llm_service.get_llm(), settings) - provider = LiblibImageProvider() + provider = LiblibImageProvider(template_uuid=settings.liblib_template_uuid) storage = TencentCosStorageService.from_env() + images = [dict(item) for item in (chapter.images or [])] - for item in chapter.images: + for index, item in enumerate(images): if item.get("status") == "completed" and item.get("url"): continue if item.get("status") not in {"pending", "failed"}: continue - item["status"] = "processing" + current_item = dict(item) + current_item["status"] = "processing" + current_item["updated_at"] = datetime.now(timezone.utc).isoformat() + images[index] = current_item + chapter.images = images db.commit() try: @@ -580,21 +642,23 @@ def generate_chapter_images(self, chapter_id: str): max_attempts=settings.max_attempts, ) image_bytes = provider.download_image(job) - key = build_cos_key(chapter.user_id, chapter.id, item["index"], prompt_data["prompt"]) - item["url"] = storage.upload_bytes(image_bytes, key, "image/png") - item["prompt"] = prompt_data["prompt"] - item["style"] = prompt_data["style"] - item["size"] = prompt_data["size"] - item["status"] = "completed" - item["error"] = None + key = build_cos_key(chapter.user_id, chapter.id, current_item["index"], prompt_data["prompt"]) + current_item["url"] = storage.upload_bytes(image_bytes, key, "image/png") + current_item["prompt"] = prompt_data["prompt"] + current_item["style"] = prompt_data["style"] + current_item["size"] = prompt_data["size"] + current_item["status"] = "completed" + current_item["error"] = None except Exception as exc: - item["status"] = "failed" - item["error"] = str(exc) - logger.warning(f"图片生成失败: chapter={chapter_id}, index={item.get('index')}, error={exc}") + current_item["status"] = "failed" + current_item["error"] = str(exc) + logger.warning(f"图片生成失败: chapter={chapter_id}, index={current_item.get('index')}, error={exc}") - item["updated_at"] = datetime.now(timezone.utc).isoformat() + current_item["updated_at"] = datetime.now(timezone.utc).isoformat() + images[index] = current_item + chapter.images = images + db.commit() - db.commit() return {"status": "success"} finally: db.close() diff --git a/api/tests/test_memoir_image_bootstrap.py b/api/tests/test_memoir_image_bootstrap.py index 64fb09d..2270eda 100644 --- a/api/tests/test_memoir_image_bootstrap.py +++ b/api/tests/test_memoir_image_bootstrap.py @@ -1,12 +1,12 @@ +import os import unittest -from unittest.mock import patch +import unittest.mock from api.tasks.memoir_tasks import initialize_chapter_images class MemoirImageBootstrapTest(unittest.TestCase): - @patch("api.tasks.memoir_tasks.generate_chapter_images.delay") - def test_initialize_chapter_images_sets_pending_assets_and_enqueues_task(self, delay_mock): + def test_initialize_chapter_images_sets_pending_assets_when_enabled(self): chapter = type( "ChapterStub", (), @@ -19,8 +19,50 @@ class MemoirImageBootstrapTest(unittest.TestCase): }, )() - assets = initialize_chapter_images(chapter) + with unittest.mock.patch.dict(os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False): + assets = initialize_chapter_images(chapter) self.assertEqual(len(assets), 1) self.assertEqual(assets[0]["status"], "pending") - delay_mock.assert_called_once_with("chapter-1") + + def test_initialize_chapter_images_preserves_completed_assets_and_adds_only_new_placeholders(self): + chapter = type( + "ChapterStub", + (), + { + "id": "chapter-1", + "title": "童年的夏天", + "category": "childhood", + "content": ( + "那条路我一直记得。\n\n" + "{{{{IMAGE:南方小镇的青石板路}}}}\n\n" + "奶奶总坐在门口。\n\n" + "{{{{IMAGE:奶奶坐在院子里的藤椅上}}}}" + ), + "images": [ + { + "index": 0, + "placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}", + "description": "南方小镇的青石板路", + "prompt": "A serene southern China town", + "url": "https://cos.example.com/existing.png", + "status": "completed", + "provider": "liblib", + "style": "watercolor", + "size": "1024x1024", + "error": None, + "created_at": "2026-03-10T10:00:00Z", + "updated_at": "2026-03-10T10:00:00Z", + } + ], + }, + )() + + with unittest.mock.patch.dict(os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False): + assets = initialize_chapter_images(chapter) + + self.assertEqual(len(assets), 2) + self.assertEqual(assets[0]["status"], "completed") + self.assertEqual(assets[0]["url"], "https://cos.example.com/existing.png") + self.assertEqual(assets[1]["status"], "pending") + self.assertEqual(assets[1]["description"], "奶奶坐在院子里的藤椅上") diff --git a/api/tests/test_memoir_image_provider.py b/api/tests/test_memoir_image_provider.py index e0c8de5..fa9b817 100644 --- a/api/tests/test_memoir_image_provider.py +++ b/api/tests/test_memoir_image_provider.py @@ -1,58 +1,144 @@ import unittest -from unittest.mock import Mock +from unittest.mock import Mock, patch from api.services.memoir_images.provider import LiblibImageProvider -class MemoirImageProviderTest(unittest.TestCase): - def test_submit_generation_handles_sync_provider_response(self): +def _make_provider(http_client=None): + return LiblibImageProvider( + http_client=http_client or Mock(), + access_key="test-ak", + secret_key="test-sk", + base_url="https://openapi.liblibai.cloud", + template_uuid="tpl-uuid", + ) + + +class LiblibSignatureTest(unittest.TestCase): + def test_sign_returns_url_with_auth_params(self): + provider = _make_provider() + url = provider._sign("/api/generate/webui/text2img/ultra") + self.assertIn("AccessKey=test-ak", url) + self.assertIn("Signature=", url) + self.assertIn("Timestamp=", url) + self.assertIn("SignatureNonce=", url) + self.assertTrue(url.startswith("https://openapi.liblibai.cloud/api/generate/webui/text2img/ultra?")) + + +class SubmitGenerationTest(unittest.TestCase): + def test_submit_returns_processing_with_job_id(self): http_client = Mock() - http_client.post.return_value.json.return_value = { - "status": "succeeded", - "image_url": "https://provider.example.com/1.png", + resp = Mock() + resp.json.return_value = { + "code": 0, + "data": {"generateUuid": "uuid-abc"}, } - provider = LiblibImageProvider(http_client=http_client, api_key="test-key", base_url="https://example.com") + resp.raise_for_status = Mock() + http_client.post.return_value = resp - job = provider.submit_generation(prompt="foo", size="1024x1024", style="watercolor") - - self.assertEqual(job["status"], "completed") - self.assertEqual(job["image_url"], "https://provider.example.com/1.png") - - def test_submit_generation_handles_async_provider_response(self): - http_client = Mock() - http_client.post.return_value.json.return_value = { - "task_id": "job-123", - "status": "queued", - } - provider = LiblibImageProvider(http_client=http_client, api_key="test-key", base_url="https://example.com") - - job = provider.submit_generation(prompt="foo", size="1024x1024", style="watercolor") + provider = _make_provider(http_client) + job = provider.submit_generation(prompt="a cat", size="1024x1024", style="watercolor") self.assertEqual(job["status"], "processing") - self.assertEqual(job["job_id"], "job-123") + self.assertEqual(job["job_id"], "uuid-abc") + self.assertIsNone(job["image_url"]) - def test_poll_until_complete_returns_completed_job(self): + call_kwargs = http_client.post.call_args + body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json") + self.assertEqual(body["templateUuid"], "tpl-uuid") + self.assertEqual(body["generateParams"]["prompt"], "a cat") + self.assertEqual(body["generateParams"]["aspectRatio"], "square") + + def test_submit_raises_on_error_code(self): http_client = Mock() - http_client.get.return_value.json.side_effect = [ - {"status": "queued"}, - {"status": "succeeded", "image_url": "https://provider.example.com/1.png"}, - ] - provider = LiblibImageProvider(http_client=http_client, api_key="test-key", base_url="https://example.com") + resp = Mock() + resp.json.return_value = {"code": 100000, "msg": "param error"} + resp.raise_for_status = Mock() + http_client.post.return_value = resp + provider = _make_provider(http_client) + with self.assertRaises(RuntimeError): + provider.submit_generation(prompt="a cat", size="1024x1024", style="watercolor") + + +class PollUntilCompleteTest(unittest.TestCase): + def test_returns_completed_on_status_5(self): + http_client = Mock() + pending_resp = Mock() + pending_resp.json.return_value = { + "code": 0, + "data": {"generateStatus": 2, "images": []}, + } + pending_resp.raise_for_status = Mock() + + success_resp = Mock() + success_resp.json.return_value = { + "code": 0, + "data": { + "generateStatus": 5, + "images": [{"imageUrl": "https://cdn.example.com/1.png", "auditStatus": 3}], + }, + } + success_resp.raise_for_status = Mock() + http_client.post.side_effect = [pending_resp, success_resp] + + provider = _make_provider(http_client) job = provider.poll_until_complete( - {"status": "processing", "job_id": "job-123"}, + {"status": "processing", "job_id": "uuid-abc"}, poll_interval_seconds=0, - max_attempts=2, + max_attempts=3, ) self.assertEqual(job["status"], "completed") - self.assertEqual(job["image_url"], "https://provider.example.com/1.png") + self.assertEqual(job["image_url"], "https://cdn.example.com/1.png") + self.assertEqual(job["job_id"], "uuid-abc") - def test_download_image_fetches_binary_payload(self): + def test_raises_on_status_6_failure(self): http_client = Mock() - http_client.get.return_value.content = b"png-bytes" - provider = LiblibImageProvider(http_client=http_client, api_key="test-key", base_url="https://example.com") + resp = Mock() + resp.json.return_value = { + "code": 0, + "data": {"generateStatus": 6, "generateMsg": "content violation"}, + } + resp.raise_for_status = Mock() + http_client.post.return_value = resp - payload = provider.download_image({"image_url": "https://provider.example.com/1.png"}) + provider = _make_provider(http_client) + with self.assertRaises(RuntimeError, msg="content violation"): + provider.poll_until_complete( + {"status": "processing", "job_id": "uuid-abc"}, + poll_interval_seconds=0, + max_attempts=2, + ) + + def test_raises_timeout_after_max_attempts(self): + http_client = Mock() + resp = Mock() + resp.json.return_value = { + "code": 0, + "data": {"generateStatus": 2, "images": []}, + } + resp.raise_for_status = Mock() + http_client.post.return_value = resp + + provider = _make_provider(http_client) + with self.assertRaises(TimeoutError): + provider.poll_until_complete( + {"status": "processing", "job_id": "uuid-abc"}, + poll_interval_seconds=0, + max_attempts=2, + ) + + +class DownloadImageTest(unittest.TestCase): + def test_download_fetches_binary_payload(self): + http_client = Mock() + resp = Mock() + resp.content = b"png-bytes" + resp.raise_for_status = Mock() + http_client.get.return_value = resp + + provider = _make_provider(http_client) + payload = provider.download_image({"image_url": "https://cdn.example.com/1.png"}) self.assertEqual(payload, b"png-bytes") diff --git a/api/tests/test_process_memoir_segments_image_enqueue.py b/api/tests/test_process_memoir_segments_image_enqueue.py new file mode 100644 index 0000000..f6b5de8 --- /dev/null +++ b/api/tests/test_process_memoir_segments_image_enqueue.py @@ -0,0 +1,74 @@ +import unittest +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from api.tasks.memoir_tasks import MemoirImageSettings, process_memoir_segments + + +class ProcessMemoirSegmentsImageEnqueueTest(unittest.TestCase): + @patch("api.tasks.memoir_tasks._update_task_status_sync") + @patch("api.tasks.memoir_tasks._release_chapter_lock") + @patch("api.tasks.memoir_tasks._acquire_chapter_lock", return_value=True) + @patch("api.tasks.memoir_tasks._classify_chapter_category", return_value="childhood") + @patch("api.tasks.memoir_tasks._get_or_create_state_sync") + @patch("api.tasks.memoir_tasks.llm_service.get_llm", return_value=None) + @patch("api.tasks.memoir_tasks.generate_chapter_images.delay") + @patch("api.tasks.memoir_tasks.SessionLocal") + @patch("api.tasks.memoir_tasks.MemoirImageSettings.from_env") + def test_process_memoir_segments_enqueues_image_jobs_after_commit( + self, + settings_from_env, + session_local_cls, + delay_mock, + _get_llm, + get_state_mock, + _classify_mock, + _acquire_lock_mock, + _release_lock_mock, + _update_status_mock, + ): + settings_from_env.return_value = MemoirImageSettings( + enabled=True, + max_per_chapter=2, + provider="liblib", + default_style="watercolor", + default_size="1024x1024", + poll_interval_seconds=3, + max_attempts=20, + ) + get_state_mock.return_value = SimpleNamespace(current_stage="childhood", slots={}) + + segment = SimpleNamespace( + id="segment-1", + transcript_text="那条路我一直记得。\n\n{{{{IMAGE:南方小镇的青石板路}}}}", + processed=False, + ) + + segments_result = Mock() + segments_result.scalars.return_value.all.return_value = [segment] + + chapter_result = Mock() + chapter_result.scalar_one_or_none.return_value = None + + book_result = Mock() + book_result.scalar_one_or_none.return_value = None + + db = Mock() + db.execute.side_effect = [segments_result, chapter_result, book_result] + db.get.return_value = None + session_local_cls.return_value = db + + events: list[str] = [] + db.commit.side_effect = lambda: events.append("commit") + delay_mock.side_effect = lambda chapter_id: events.append(f"delay:{chapter_id}") + + task_self = SimpleNamespace( + request=SimpleNamespace(id="task-1"), + retry=Mock(side_effect=AssertionError("retry should not be called")), + ) + process_memoir_segments.run.__func__(task_self, "user-1", ["segment-1"]) + + self.assertIn("commit", events) + delay_events = [event for event in events if event.startswith("delay:")] + self.assertEqual(len(delay_events), 1) + self.assertGreater(events.index(delay_events[0]), events.index("commit")) diff --git a/app-android/app/src/androidTest/java/com/huaga/life_echo/ui/components/memoir/ChapterReadingImageBlocksTest.kt b/app-android/app/src/androidTest/java/com/huaga/life_echo/ui/components/memoir/ChapterReadingImageBlocksTest.kt index e0036d1..39a40c4 100644 --- a/app-android/app/src/androidTest/java/com/huaga/life_echo/ui/components/memoir/ChapterReadingImageBlocksTest.kt +++ b/app-android/app/src/androidTest/java/com/huaga/life_echo/ui/components/memoir/ChapterReadingImageBlocksTest.kt @@ -1,6 +1,7 @@ package com.huaga.life_echo.ui.components.memoir import androidx.activity.ComponentActivity +import androidx.compose.ui.test.assertDoesNotExist import androidx.compose.ui.test.assertIsDisplayed import androidx.compose.ui.test.junit4.createAndroidComposeRule import androidx.compose.ui.test.onNodeWithTag @@ -50,7 +51,7 @@ class ChapterReadingImageBlocksTest { } @Test - fun chapterReadingView_showsFailureCard_forFailedImage_withoutRawPlaceholderText() { + fun chapterReadingView_hidesFailedImageBlock() { val chapter = ChapterContentDto( id = "chapter-1", title = "童年的夏天", @@ -81,6 +82,6 @@ class ChapterReadingImageBlocksTest { composeRule.setContent { ChapterReadingView(chapter = chapter) } - composeRule.onNodeWithTag("memoir-image-error-0").assertIsDisplayed() + composeRule.onNodeWithTag("memoir-image-error-0").assertDoesNotExist() } } diff --git a/app-android/app/src/main/java/com/huaga/life_echo/ui/components/memoir/MemoirContentBlocks.kt b/app-android/app/src/main/java/com/huaga/life_echo/ui/components/memoir/MemoirContentBlocks.kt index d36265a..4a0b6fe 100644 --- a/app-android/app/src/main/java/com/huaga/life_echo/ui/components/memoir/MemoirContentBlocks.kt +++ b/app-android/app/src/main/java/com/huaga/life_echo/ui/components/memoir/MemoirContentBlocks.kt @@ -24,7 +24,7 @@ fun splitMemoirContent(content: String, images: List): List Column( - modifier = modifier - .fillMaxWidth() - .clip(RoundedCornerShape(16.dp)) - .background(LightPurple.copy(alpha = 0.10f)) - .padding(16.dp) - .testTag("memoir-image-error-${image.index}") - ) { - Text( - text = "图片生成失败", - fontSize = AppTypography.bodyMedium, - fontWeight = FontWeight.Medium, - color = SlatePurple, - ) - if (image.description.isNotBlank()) { - Text( - text = image.description, - fontSize = AppTypography.captionMedium, - color = SlatePurple.copy(alpha = 0.6f), - modifier = Modifier.padding(top = 4.dp), - ) - } - } else -> Unit } } diff --git a/app-android/app/src/test/java/com/huaga/life_echo/ui/components/memoir/MemoirContentBlocksTest.kt b/app-android/app/src/test/java/com/huaga/life_echo/ui/components/memoir/MemoirContentBlocksTest.kt index d840d24..d5c5b34 100644 --- a/app-android/app/src/test/java/com/huaga/life_echo/ui/components/memoir/MemoirContentBlocksTest.kt +++ b/app-android/app/src/test/java/com/huaga/life_echo/ui/components/memoir/MemoirContentBlocksTest.kt @@ -2,6 +2,7 @@ package com.huaga.life_echo.ui.components.memoir import com.huaga.life_echo.network.models.ChapterImageDto import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse import org.junit.Assert.assertTrue import org.junit.Test @@ -56,4 +57,33 @@ class MemoirContentBlocksTest { val text = (blocks[0] as MemoirContentBlock.Text).content assertTrue(!text.contains("IMAGE:")) } + + @Test + fun splitMemoirContent_skipsFailedImages_andRemovesTheirPlaceholders() { + val blocks = splitMemoirContent( + content = "开头。\n\n{{{{IMAGE:生成失败的图}}}}\n\n结尾。", + images = listOf( + ChapterImageDto( + index = 0, + placeholder = "{{{{IMAGE:生成失败的图}}}}", + description = "生成失败的图", + prompt = null, + url = null, + status = "failed", + provider = "liblib", + style = "watercolor", + size = "1024x1024", + error = "provider timeout", + created_at = null, + updated_at = null, + ) + ) + ) + + assertFalse(blocks.any { it is MemoirContentBlock.Image }) + val combinedText = blocks.filterIsInstance().joinToString("\n") { it.content } + assertFalse(combinedText.contains("IMAGE:")) + assertTrue(combinedText.contains("开头")) + assertTrue(combinedText.contains("结尾")) + } }