fix: 修复 Liblib provider 认证和多个图片生成关键缺陷
- 重写 LiblibImageProvider:Bearer token 改为 HMAC-SHA1 签名认证, 适配 Liblib 真实 API(Star-3 Alpha 文生图端点) - 修复 chapter.images JSON 列原地修改不持久化(深拷贝+整列重赋值) - 修复 generate_chapter_images 在事务提交前派发(改为 commit 后统一 delay) - 修复 initialize_chapter_images 覆盖已完成图片(新增 merge 去重逻辑) - 修复 Android failed 图片渲染为错误卡片(改为隐藏,保持正文连续) - 模型模板 UUID 改为环境变量配置(LIBLIB_TEMPLATE_UUID) - 更新 .env 凭证格式为 ACCESS_KEY/SECRET_KEY - 补充 test_memoir_image_bootstrap 缺失的 unittest.mock 导入 Made-with: Cursor
This commit is contained in:
@@ -48,7 +48,7 @@ tencentcloud-sdk-python>=3.0.1000
|
||||
openai
|
||||
|
||||
# Tencent COS for memoir image storage
|
||||
cos-python-sdk-v5>=1.9.30
|
||||
cos-python-sdk-v5>=1.9.40
|
||||
|
||||
# Payment - WeChat Pay & Alipay
|
||||
wechatpayv3>=0.3.0
|
||||
|
||||
@@ -1,43 +1,141 @@
|
||||
import base64
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from hashlib import sha1
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SIZE_TO_ASPECT_RATIO = {
|
||||
"1024x1024": "square",
|
||||
"768x1024": "portrait",
|
||||
"1024x768": "landscape",
|
||||
"1280x720": "landscape",
|
||||
"720x1280": "portrait",
|
||||
}
|
||||
|
||||
|
||||
class LiblibImageProvider:
|
||||
"""Liblib (https://openapi.liblibai.cloud) image generation adapter."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
http_client=None,
|
||||
api_key: str | None = None,
|
||||
http_client: httpx.Client | None = None,
|
||||
access_key: str | None = None,
|
||||
secret_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
template_uuid: str | None = None,
|
||||
):
|
||||
self.http_client = http_client or httpx.Client(timeout=60)
|
||||
self.api_key = api_key or os.getenv("LIBLIB_API_KEY", "")
|
||||
self.base_url = (base_url or os.getenv("LIBLIB_BASE_URL", "")).rstrip("/")
|
||||
self.http_client = http_client or httpx.Client(timeout=120)
|
||||
self.access_key = access_key or os.getenv("LIBLIB_ACCESS_KEY", "")
|
||||
self.secret_key = secret_key or os.getenv("LIBLIB_SECRET_KEY", "")
|
||||
self.base_url = (base_url or os.getenv("LIBLIB_BASE_URL", "https://openapi.liblibai.cloud")).rstrip("/")
|
||||
self.template_uuid = template_uuid or os.getenv(
|
||||
"LIBLIB_TEMPLATE_UUID", "5d7e67009b344550bc1aa6ccbfa1d7f4"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Signature helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _sign(self, uri: str) -> str:
|
||||
"""Build a full URL with Liblib HMAC-SHA1 query-string auth."""
|
||||
timestamp = str(int(time.time() * 1000))
|
||||
nonce = str(uuid.uuid4())
|
||||
content = "&".join((uri, timestamp, nonce))
|
||||
digest = hmac.new(self.secret_key.encode(), content.encode(), sha1).digest()
|
||||
signature = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
return (
|
||||
f"{self.base_url}{uri}"
|
||||
f"?AccessKey={self.access_key}"
|
||||
f"&Signature={signature}"
|
||||
f"&Timestamp={timestamp}"
|
||||
f"&SignatureNonce={nonce}"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def submit_generation(self, prompt: str, size: str, style: str) -> dict:
|
||||
response = self.http_client.post(
|
||||
f"{self.base_url}/v1/images/generations",
|
||||
headers={"Authorization": f"Bearer {self.api_key}"},
|
||||
json={"prompt": prompt, "size": size, "style": style},
|
||||
)
|
||||
data = response.json()
|
||||
if data.get("image_url"):
|
||||
return {"status": "completed", "image_url": data["image_url"], "job_id": None}
|
||||
return {"status": "processing", "job_id": data.get("task_id"), "image_url": None}
|
||||
uri = "/api/generate/webui/text2img/ultra"
|
||||
url = self._sign(uri)
|
||||
|
||||
def poll_until_complete(self, job: dict, poll_interval_seconds: int, max_attempts: int) -> dict:
|
||||
for _ in range(max_attempts):
|
||||
response = self.http_client.get(
|
||||
f"{self.base_url}/v1/images/generations/{job['job_id']}",
|
||||
headers={"Authorization": f"Bearer {self.api_key}"},
|
||||
aspect_ratio = _SIZE_TO_ASPECT_RATIO.get(size, "square")
|
||||
|
||||
body = {
|
||||
"templateUuid": self.template_uuid,
|
||||
"generateParams": {
|
||||
"prompt": prompt,
|
||||
"aspectRatio": aspect_ratio,
|
||||
"imgCount": 1,
|
||||
"steps": 30,
|
||||
},
|
||||
}
|
||||
response = self.http_client.post(
|
||||
url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=body,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if data.get("code") != 0:
|
||||
raise RuntimeError(f"Liblib submit failed: {data.get('msg', data)}")
|
||||
|
||||
generate_uuid = data["data"]["generateUuid"]
|
||||
return {"status": "processing", "job_id": generate_uuid, "image_url": None}
|
||||
|
||||
def poll_until_complete(
|
||||
self, job: dict, poll_interval_seconds: int, max_attempts: int
|
||||
) -> dict:
|
||||
uri = "/api/generate/webui/status"
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
url = self._sign(uri)
|
||||
response = self.http_client.post(
|
||||
url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json={"generateUuid": job["job_id"]},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
if data.get("image_url"):
|
||||
return {"status": "completed", "image_url": data["image_url"], "job_id": job["job_id"]}
|
||||
|
||||
if data.get("code") != 0:
|
||||
raise RuntimeError(f"Liblib status query failed: {data.get('msg', data)}")
|
||||
|
||||
result = data.get("data", {})
|
||||
status = result.get("generateStatus")
|
||||
|
||||
if status == 5: # success
|
||||
images = result.get("images") or []
|
||||
if images:
|
||||
return {
|
||||
"status": "completed",
|
||||
"image_url": images[0]["imageUrl"],
|
||||
"job_id": job["job_id"],
|
||||
}
|
||||
raise RuntimeError(f"Liblib returned success but no images for {job['job_id']}")
|
||||
|
||||
if status == 6: # failed
|
||||
raise RuntimeError(f"Liblib generation failed: {result.get('generateMsg', 'unknown')}")
|
||||
|
||||
if status == 7: # timeout on Liblib side
|
||||
raise TimeoutError(f"Liblib generation timed out on server side: {job['job_id']}")
|
||||
|
||||
logger.debug(
|
||||
"Liblib poll attempt %d/%d, status=%s, job=%s",
|
||||
attempt + 1, max_attempts, status, job["job_id"],
|
||||
)
|
||||
time.sleep(poll_interval_seconds)
|
||||
raise TimeoutError(f"Liblib image generation timed out for {job['job_id']}")
|
||||
|
||||
raise TimeoutError(f"Liblib image generation timed out after {max_attempts} attempts for {job['job_id']}")
|
||||
|
||||
def download_image(self, job: dict) -> bytes:
|
||||
response = self.http_client.get(job["image_url"])
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
@@ -11,6 +11,7 @@ class MemoirImageSettings:
|
||||
default_size: str
|
||||
poll_interval_seconds: int
|
||||
max_attempts: int
|
||||
liblib_template_uuid: str
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "MemoirImageSettings":
|
||||
@@ -21,5 +22,8 @@ class MemoirImageSettings:
|
||||
default_style=os.getenv("MEMOIR_IMAGE_STYLE_DEFAULT", "watercolor"),
|
||||
default_size=os.getenv("MEMOIR_IMAGE_SIZE_DEFAULT", "1024x1024"),
|
||||
poll_interval_seconds=int(os.getenv("MEMOIR_IMAGE_POLL_INTERVAL", "3")),
|
||||
max_attempts=int(os.getenv("MEMOIR_IMAGE_MAX_ATTEMPTS", "20")),
|
||||
max_attempts=int(os.getenv("MEMOIR_IMAGE_MAX_ATTEMPTS", "60")),
|
||||
liblib_template_uuid=os.getenv(
|
||||
"LIBLIB_TEMPLATE_UUID", "5d7e67009b344550bc1aa6ccbfa1d7f4"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -78,6 +78,54 @@ def _update_task_status_sync(user_id: str, task_id: str, status: str, result: Di
|
||||
except Exception as e:
|
||||
logger.error(f"更新任务状态失败: {e}")
|
||||
|
||||
|
||||
def _merge_chapter_image_assets(
|
||||
existing_images: list[dict] | None,
|
||||
placeholders: list[dict],
|
||||
provider: str,
|
||||
style: str,
|
||||
size: str,
|
||||
now_iso: str,
|
||||
) -> list[dict]:
|
||||
existing_by_placeholder = {
|
||||
item.get("placeholder"): dict(item)
|
||||
for item in (existing_images or [])
|
||||
if item.get("placeholder")
|
||||
}
|
||||
merged_assets: list[dict] = []
|
||||
|
||||
for item in placeholders:
|
||||
existing = existing_by_placeholder.get(item["placeholder"])
|
||||
if existing:
|
||||
merged_item = dict(existing)
|
||||
merged_item["index"] = item["index"]
|
||||
merged_item["placeholder"] = item["placeholder"]
|
||||
merged_item["description"] = item["description"]
|
||||
merged_item["provider"] = merged_item.get("provider") or provider
|
||||
merged_item["style"] = merged_item.get("style") or style
|
||||
merged_item["size"] = merged_item.get("size") or size
|
||||
merged_item["created_at"] = merged_item.get("created_at") or now_iso
|
||||
merged_item["updated_at"] = merged_item.get("updated_at") or now_iso
|
||||
if merged_item.get("status") == "completed" and not merged_item.get("url"):
|
||||
merged_item["status"] = "failed"
|
||||
merged_item["error"] = merged_item.get("error") or "missing image url"
|
||||
else:
|
||||
merged_item = build_initial_image_assets(
|
||||
placeholders=[item],
|
||||
provider=provider,
|
||||
style=style,
|
||||
size=size,
|
||||
now_iso=now_iso,
|
||||
)[0]
|
||||
merged_assets.append(merged_item)
|
||||
|
||||
return merged_assets
|
||||
|
||||
|
||||
def chapter_has_images_to_generate(images: list[dict] | None) -> bool:
|
||||
return any(item.get("status") in {"pending", "failed"} for item in (images or []))
|
||||
|
||||
|
||||
def initialize_chapter_images(chapter) -> list[dict]:
|
||||
"""Parse IMAGE placeholders from chapter content and build pending image assets."""
|
||||
settings = MemoirImageSettings.from_env()
|
||||
@@ -88,15 +136,14 @@ def initialize_chapter_images(chapter) -> list[dict]:
|
||||
prompt_service = MemoirImagePromptService(llm=None, settings=settings)
|
||||
placeholders = parse_image_placeholders(chapter.content, settings.max_per_chapter)
|
||||
style = prompt_service.CATEGORY_STYLE_MAP.get(chapter.category, settings.default_style)
|
||||
chapter.images = build_initial_image_assets(
|
||||
chapter.images = _merge_chapter_image_assets(
|
||||
existing_images=chapter.images,
|
||||
placeholders=placeholders,
|
||||
provider=settings.provider,
|
||||
style=style,
|
||||
size=settings.default_size,
|
||||
now_iso=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
if chapter.images:
|
||||
generate_chapter_images.delay(chapter.id)
|
||||
return chapter.images
|
||||
|
||||
|
||||
@@ -234,6 +281,7 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
chapters_to_enqueue: set[str] = set()
|
||||
# 获取段落
|
||||
stmt = select(Segment).where(Segment.id.in_(segment_ids))
|
||||
result = db.execute(stmt)
|
||||
@@ -401,6 +449,8 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
db.flush()
|
||||
|
||||
initialize_chapter_images(chapter)
|
||||
if chapter_has_images_to_generate(chapter.images):
|
||||
chapters_to_enqueue.add(chapter.id)
|
||||
|
||||
# 更新 Book
|
||||
stmt_book = select(Book).where(Book.user_id == user_id).order_by(Book.updated_at.desc())
|
||||
@@ -426,6 +476,13 @@ def process_memoir_segments(self, user_id: str, segment_ids: List[str]):
|
||||
seg.processed = True
|
||||
|
||||
db.commit()
|
||||
|
||||
for chapter_id in sorted(chapters_to_enqueue):
|
||||
try:
|
||||
generate_chapter_images.delay(chapter_id)
|
||||
except Exception as exc:
|
||||
logger.warning(f"补图任务派发失败: chapter={chapter_id}, error={exc}")
|
||||
|
||||
logger.info(f"回忆录处理完成: user_id={user_id}, task_id={task_id}")
|
||||
|
||||
# 更新任务状态为成功
|
||||
@@ -546,16 +603,21 @@ def generate_chapter_images(self, chapter_id: str):
|
||||
|
||||
settings = MemoirImageSettings.from_env()
|
||||
prompt_service = MemoirImagePromptService(llm_service.get_llm(), settings)
|
||||
provider = LiblibImageProvider()
|
||||
provider = LiblibImageProvider(template_uuid=settings.liblib_template_uuid)
|
||||
storage = TencentCosStorageService.from_env()
|
||||
images = [dict(item) for item in (chapter.images or [])]
|
||||
|
||||
for item in chapter.images:
|
||||
for index, item in enumerate(images):
|
||||
if item.get("status") == "completed" and item.get("url"):
|
||||
continue
|
||||
if item.get("status") not in {"pending", "failed"}:
|
||||
continue
|
||||
|
||||
item["status"] = "processing"
|
||||
current_item = dict(item)
|
||||
current_item["status"] = "processing"
|
||||
current_item["updated_at"] = datetime.now(timezone.utc).isoformat()
|
||||
images[index] = current_item
|
||||
chapter.images = images
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
@@ -580,21 +642,23 @@ def generate_chapter_images(self, chapter_id: str):
|
||||
max_attempts=settings.max_attempts,
|
||||
)
|
||||
image_bytes = provider.download_image(job)
|
||||
key = build_cos_key(chapter.user_id, chapter.id, item["index"], prompt_data["prompt"])
|
||||
item["url"] = storage.upload_bytes(image_bytes, key, "image/png")
|
||||
item["prompt"] = prompt_data["prompt"]
|
||||
item["style"] = prompt_data["style"]
|
||||
item["size"] = prompt_data["size"]
|
||||
item["status"] = "completed"
|
||||
item["error"] = None
|
||||
key = build_cos_key(chapter.user_id, chapter.id, current_item["index"], prompt_data["prompt"])
|
||||
current_item["url"] = storage.upload_bytes(image_bytes, key, "image/png")
|
||||
current_item["prompt"] = prompt_data["prompt"]
|
||||
current_item["style"] = prompt_data["style"]
|
||||
current_item["size"] = prompt_data["size"]
|
||||
current_item["status"] = "completed"
|
||||
current_item["error"] = None
|
||||
except Exception as exc:
|
||||
item["status"] = "failed"
|
||||
item["error"] = str(exc)
|
||||
logger.warning(f"图片生成失败: chapter={chapter_id}, index={item.get('index')}, error={exc}")
|
||||
current_item["status"] = "failed"
|
||||
current_item["error"] = str(exc)
|
||||
logger.warning(f"图片生成失败: chapter={chapter_id}, index={current_item.get('index')}, error={exc}")
|
||||
|
||||
item["updated_at"] = datetime.now(timezone.utc).isoformat()
|
||||
current_item["updated_at"] = datetime.now(timezone.utc).isoformat()
|
||||
images[index] = current_item
|
||||
chapter.images = images
|
||||
db.commit()
|
||||
|
||||
db.commit()
|
||||
return {"status": "success"}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
import unittest.mock
|
||||
|
||||
from api.tasks.memoir_tasks import initialize_chapter_images
|
||||
|
||||
|
||||
class MemoirImageBootstrapTest(unittest.TestCase):
|
||||
@patch("api.tasks.memoir_tasks.generate_chapter_images.delay")
|
||||
def test_initialize_chapter_images_sets_pending_assets_and_enqueues_task(self, delay_mock):
|
||||
def test_initialize_chapter_images_sets_pending_assets_when_enabled(self):
|
||||
chapter = type(
|
||||
"ChapterStub",
|
||||
(),
|
||||
@@ -19,8 +19,50 @@ class MemoirImageBootstrapTest(unittest.TestCase):
|
||||
},
|
||||
)()
|
||||
|
||||
assets = initialize_chapter_images(chapter)
|
||||
with unittest.mock.patch.dict(os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False):
|
||||
assets = initialize_chapter_images(chapter)
|
||||
|
||||
self.assertEqual(len(assets), 1)
|
||||
self.assertEqual(assets[0]["status"], "pending")
|
||||
delay_mock.assert_called_once_with("chapter-1")
|
||||
|
||||
def test_initialize_chapter_images_preserves_completed_assets_and_adds_only_new_placeholders(self):
|
||||
chapter = type(
|
||||
"ChapterStub",
|
||||
(),
|
||||
{
|
||||
"id": "chapter-1",
|
||||
"title": "童年的夏天",
|
||||
"category": "childhood",
|
||||
"content": (
|
||||
"那条路我一直记得。\n\n"
|
||||
"{{{{IMAGE:南方小镇的青石板路}}}}\n\n"
|
||||
"奶奶总坐在门口。\n\n"
|
||||
"{{{{IMAGE:奶奶坐在院子里的藤椅上}}}}"
|
||||
),
|
||||
"images": [
|
||||
{
|
||||
"index": 0,
|
||||
"placeholder": "{{{{IMAGE:南方小镇的青石板路}}}}",
|
||||
"description": "南方小镇的青石板路",
|
||||
"prompt": "A serene southern China town",
|
||||
"url": "https://cos.example.com/existing.png",
|
||||
"status": "completed",
|
||||
"provider": "liblib",
|
||||
"style": "watercolor",
|
||||
"size": "1024x1024",
|
||||
"error": None,
|
||||
"created_at": "2026-03-10T10:00:00Z",
|
||||
"updated_at": "2026-03-10T10:00:00Z",
|
||||
}
|
||||
],
|
||||
},
|
||||
)()
|
||||
|
||||
with unittest.mock.patch.dict(os.environ, {"MEMOIR_IMAGE_ENABLED": "true"}, clear=False):
|
||||
assets = initialize_chapter_images(chapter)
|
||||
|
||||
self.assertEqual(len(assets), 2)
|
||||
self.assertEqual(assets[0]["status"], "completed")
|
||||
self.assertEqual(assets[0]["url"], "https://cos.example.com/existing.png")
|
||||
self.assertEqual(assets[1]["status"], "pending")
|
||||
self.assertEqual(assets[1]["description"], "奶奶坐在院子里的藤椅上")
|
||||
|
||||
@@ -1,58 +1,144 @@
|
||||
import unittest
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from api.services.memoir_images.provider import LiblibImageProvider
|
||||
|
||||
|
||||
class MemoirImageProviderTest(unittest.TestCase):
|
||||
def test_submit_generation_handles_sync_provider_response(self):
|
||||
def _make_provider(http_client=None):
|
||||
return LiblibImageProvider(
|
||||
http_client=http_client or Mock(),
|
||||
access_key="test-ak",
|
||||
secret_key="test-sk",
|
||||
base_url="https://openapi.liblibai.cloud",
|
||||
template_uuid="tpl-uuid",
|
||||
)
|
||||
|
||||
|
||||
class LiblibSignatureTest(unittest.TestCase):
|
||||
def test_sign_returns_url_with_auth_params(self):
|
||||
provider = _make_provider()
|
||||
url = provider._sign("/api/generate/webui/text2img/ultra")
|
||||
self.assertIn("AccessKey=test-ak", url)
|
||||
self.assertIn("Signature=", url)
|
||||
self.assertIn("Timestamp=", url)
|
||||
self.assertIn("SignatureNonce=", url)
|
||||
self.assertTrue(url.startswith("https://openapi.liblibai.cloud/api/generate/webui/text2img/ultra?"))
|
||||
|
||||
|
||||
class SubmitGenerationTest(unittest.TestCase):
|
||||
def test_submit_returns_processing_with_job_id(self):
|
||||
http_client = Mock()
|
||||
http_client.post.return_value.json.return_value = {
|
||||
"status": "succeeded",
|
||||
"image_url": "https://provider.example.com/1.png",
|
||||
resp = Mock()
|
||||
resp.json.return_value = {
|
||||
"code": 0,
|
||||
"data": {"generateUuid": "uuid-abc"},
|
||||
}
|
||||
provider = LiblibImageProvider(http_client=http_client, api_key="test-key", base_url="https://example.com")
|
||||
resp.raise_for_status = Mock()
|
||||
http_client.post.return_value = resp
|
||||
|
||||
job = provider.submit_generation(prompt="foo", size="1024x1024", style="watercolor")
|
||||
|
||||
self.assertEqual(job["status"], "completed")
|
||||
self.assertEqual(job["image_url"], "https://provider.example.com/1.png")
|
||||
|
||||
def test_submit_generation_handles_async_provider_response(self):
|
||||
http_client = Mock()
|
||||
http_client.post.return_value.json.return_value = {
|
||||
"task_id": "job-123",
|
||||
"status": "queued",
|
||||
}
|
||||
provider = LiblibImageProvider(http_client=http_client, api_key="test-key", base_url="https://example.com")
|
||||
|
||||
job = provider.submit_generation(prompt="foo", size="1024x1024", style="watercolor")
|
||||
provider = _make_provider(http_client)
|
||||
job = provider.submit_generation(prompt="a cat", size="1024x1024", style="watercolor")
|
||||
|
||||
self.assertEqual(job["status"], "processing")
|
||||
self.assertEqual(job["job_id"], "job-123")
|
||||
self.assertEqual(job["job_id"], "uuid-abc")
|
||||
self.assertIsNone(job["image_url"])
|
||||
|
||||
def test_poll_until_complete_returns_completed_job(self):
|
||||
call_kwargs = http_client.post.call_args
|
||||
body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
|
||||
self.assertEqual(body["templateUuid"], "tpl-uuid")
|
||||
self.assertEqual(body["generateParams"]["prompt"], "a cat")
|
||||
self.assertEqual(body["generateParams"]["aspectRatio"], "square")
|
||||
|
||||
def test_submit_raises_on_error_code(self):
|
||||
http_client = Mock()
|
||||
http_client.get.return_value.json.side_effect = [
|
||||
{"status": "queued"},
|
||||
{"status": "succeeded", "image_url": "https://provider.example.com/1.png"},
|
||||
]
|
||||
provider = LiblibImageProvider(http_client=http_client, api_key="test-key", base_url="https://example.com")
|
||||
resp = Mock()
|
||||
resp.json.return_value = {"code": 100000, "msg": "param error"}
|
||||
resp.raise_for_status = Mock()
|
||||
http_client.post.return_value = resp
|
||||
|
||||
provider = _make_provider(http_client)
|
||||
with self.assertRaises(RuntimeError):
|
||||
provider.submit_generation(prompt="a cat", size="1024x1024", style="watercolor")
|
||||
|
||||
|
||||
class PollUntilCompleteTest(unittest.TestCase):
|
||||
def test_returns_completed_on_status_5(self):
|
||||
http_client = Mock()
|
||||
pending_resp = Mock()
|
||||
pending_resp.json.return_value = {
|
||||
"code": 0,
|
||||
"data": {"generateStatus": 2, "images": []},
|
||||
}
|
||||
pending_resp.raise_for_status = Mock()
|
||||
|
||||
success_resp = Mock()
|
||||
success_resp.json.return_value = {
|
||||
"code": 0,
|
||||
"data": {
|
||||
"generateStatus": 5,
|
||||
"images": [{"imageUrl": "https://cdn.example.com/1.png", "auditStatus": 3}],
|
||||
},
|
||||
}
|
||||
success_resp.raise_for_status = Mock()
|
||||
http_client.post.side_effect = [pending_resp, success_resp]
|
||||
|
||||
provider = _make_provider(http_client)
|
||||
job = provider.poll_until_complete(
|
||||
{"status": "processing", "job_id": "job-123"},
|
||||
{"status": "processing", "job_id": "uuid-abc"},
|
||||
poll_interval_seconds=0,
|
||||
max_attempts=2,
|
||||
max_attempts=3,
|
||||
)
|
||||
|
||||
self.assertEqual(job["status"], "completed")
|
||||
self.assertEqual(job["image_url"], "https://provider.example.com/1.png")
|
||||
self.assertEqual(job["image_url"], "https://cdn.example.com/1.png")
|
||||
self.assertEqual(job["job_id"], "uuid-abc")
|
||||
|
||||
def test_download_image_fetches_binary_payload(self):
|
||||
def test_raises_on_status_6_failure(self):
|
||||
http_client = Mock()
|
||||
http_client.get.return_value.content = b"png-bytes"
|
||||
provider = LiblibImageProvider(http_client=http_client, api_key="test-key", base_url="https://example.com")
|
||||
resp = Mock()
|
||||
resp.json.return_value = {
|
||||
"code": 0,
|
||||
"data": {"generateStatus": 6, "generateMsg": "content violation"},
|
||||
}
|
||||
resp.raise_for_status = Mock()
|
||||
http_client.post.return_value = resp
|
||||
|
||||
payload = provider.download_image({"image_url": "https://provider.example.com/1.png"})
|
||||
provider = _make_provider(http_client)
|
||||
with self.assertRaises(RuntimeError, msg="content violation"):
|
||||
provider.poll_until_complete(
|
||||
{"status": "processing", "job_id": "uuid-abc"},
|
||||
poll_interval_seconds=0,
|
||||
max_attempts=2,
|
||||
)
|
||||
|
||||
def test_raises_timeout_after_max_attempts(self):
|
||||
http_client = Mock()
|
||||
resp = Mock()
|
||||
resp.json.return_value = {
|
||||
"code": 0,
|
||||
"data": {"generateStatus": 2, "images": []},
|
||||
}
|
||||
resp.raise_for_status = Mock()
|
||||
http_client.post.return_value = resp
|
||||
|
||||
provider = _make_provider(http_client)
|
||||
with self.assertRaises(TimeoutError):
|
||||
provider.poll_until_complete(
|
||||
{"status": "processing", "job_id": "uuid-abc"},
|
||||
poll_interval_seconds=0,
|
||||
max_attempts=2,
|
||||
)
|
||||
|
||||
|
||||
class DownloadImageTest(unittest.TestCase):
|
||||
def test_download_fetches_binary_payload(self):
|
||||
http_client = Mock()
|
||||
resp = Mock()
|
||||
resp.content = b"png-bytes"
|
||||
resp.raise_for_status = Mock()
|
||||
http_client.get.return_value = resp
|
||||
|
||||
provider = _make_provider(http_client)
|
||||
payload = provider.download_image({"image_url": "https://cdn.example.com/1.png"})
|
||||
|
||||
self.assertEqual(payload, b"png-bytes")
|
||||
|
||||
74
api/tests/test_process_memoir_segments_image_enqueue.py
Normal file
74
api/tests/test_process_memoir_segments_image_enqueue.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from api.tasks.memoir_tasks import MemoirImageSettings, process_memoir_segments
|
||||
|
||||
|
||||
class ProcessMemoirSegmentsImageEnqueueTest(unittest.TestCase):
|
||||
@patch("api.tasks.memoir_tasks._update_task_status_sync")
|
||||
@patch("api.tasks.memoir_tasks._release_chapter_lock")
|
||||
@patch("api.tasks.memoir_tasks._acquire_chapter_lock", return_value=True)
|
||||
@patch("api.tasks.memoir_tasks._classify_chapter_category", return_value="childhood")
|
||||
@patch("api.tasks.memoir_tasks._get_or_create_state_sync")
|
||||
@patch("api.tasks.memoir_tasks.llm_service.get_llm", return_value=None)
|
||||
@patch("api.tasks.memoir_tasks.generate_chapter_images.delay")
|
||||
@patch("api.tasks.memoir_tasks.SessionLocal")
|
||||
@patch("api.tasks.memoir_tasks.MemoirImageSettings.from_env")
|
||||
def test_process_memoir_segments_enqueues_image_jobs_after_commit(
|
||||
self,
|
||||
settings_from_env,
|
||||
session_local_cls,
|
||||
delay_mock,
|
||||
_get_llm,
|
||||
get_state_mock,
|
||||
_classify_mock,
|
||||
_acquire_lock_mock,
|
||||
_release_lock_mock,
|
||||
_update_status_mock,
|
||||
):
|
||||
settings_from_env.return_value = MemoirImageSettings(
|
||||
enabled=True,
|
||||
max_per_chapter=2,
|
||||
provider="liblib",
|
||||
default_style="watercolor",
|
||||
default_size="1024x1024",
|
||||
poll_interval_seconds=3,
|
||||
max_attempts=20,
|
||||
)
|
||||
get_state_mock.return_value = SimpleNamespace(current_stage="childhood", slots={})
|
||||
|
||||
segment = SimpleNamespace(
|
||||
id="segment-1",
|
||||
transcript_text="那条路我一直记得。\n\n{{{{IMAGE:南方小镇的青石板路}}}}",
|
||||
processed=False,
|
||||
)
|
||||
|
||||
segments_result = Mock()
|
||||
segments_result.scalars.return_value.all.return_value = [segment]
|
||||
|
||||
chapter_result = Mock()
|
||||
chapter_result.scalar_one_or_none.return_value = None
|
||||
|
||||
book_result = Mock()
|
||||
book_result.scalar_one_or_none.return_value = None
|
||||
|
||||
db = Mock()
|
||||
db.execute.side_effect = [segments_result, chapter_result, book_result]
|
||||
db.get.return_value = None
|
||||
session_local_cls.return_value = db
|
||||
|
||||
events: list[str] = []
|
||||
db.commit.side_effect = lambda: events.append("commit")
|
||||
delay_mock.side_effect = lambda chapter_id: events.append(f"delay:{chapter_id}")
|
||||
|
||||
task_self = SimpleNamespace(
|
||||
request=SimpleNamespace(id="task-1"),
|
||||
retry=Mock(side_effect=AssertionError("retry should not be called")),
|
||||
)
|
||||
process_memoir_segments.run.__func__(task_self, "user-1", ["segment-1"])
|
||||
|
||||
self.assertIn("commit", events)
|
||||
delay_events = [event for event in events if event.startswith("delay:")]
|
||||
self.assertEqual(len(delay_events), 1)
|
||||
self.assertGreater(events.index(delay_events[0]), events.index("commit"))
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.huaga.life_echo.ui.components.memoir
|
||||
|
||||
import androidx.activity.ComponentActivity
|
||||
import androidx.compose.ui.test.assertDoesNotExist
|
||||
import androidx.compose.ui.test.assertIsDisplayed
|
||||
import androidx.compose.ui.test.junit4.createAndroidComposeRule
|
||||
import androidx.compose.ui.test.onNodeWithTag
|
||||
@@ -50,7 +51,7 @@ class ChapterReadingImageBlocksTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun chapterReadingView_showsFailureCard_forFailedImage_withoutRawPlaceholderText() {
|
||||
fun chapterReadingView_hidesFailedImageBlock() {
|
||||
val chapter = ChapterContentDto(
|
||||
id = "chapter-1",
|
||||
title = "童年的夏天",
|
||||
@@ -81,6 +82,6 @@ class ChapterReadingImageBlocksTest {
|
||||
|
||||
composeRule.setContent { ChapterReadingView(chapter = chapter) }
|
||||
|
||||
composeRule.onNodeWithTag("memoir-image-error-0").assertIsDisplayed()
|
||||
composeRule.onNodeWithTag("memoir-image-error-0").assertDoesNotExist()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ fun splitMemoirContent(content: String, images: List<ChapterImageDto>): List<Mem
|
||||
if (before.isNotBlank()) blocks += MemoirContentBlock.Text(before)
|
||||
if (image.status == "completed" && !image.url.isNullOrBlank()) {
|
||||
blocks += MemoirContentBlock.Image(image)
|
||||
} else if (image.status in listOf("pending", "processing", "failed")) {
|
||||
} else if (image.status == "pending" || image.status == "processing") {
|
||||
blocks += MemoirContentBlock.Image(image)
|
||||
}
|
||||
remaining = parts.getOrElse(1) { "" }
|
||||
|
||||
@@ -8,10 +8,8 @@ import androidx.compose.animation.core.tween
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.clickable
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.shape.RoundedCornerShape
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.runtime.Composable
|
||||
@@ -21,7 +19,6 @@ import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.draw.clip
|
||||
import androidx.compose.ui.layout.ContentScale
|
||||
import androidx.compose.ui.platform.testTag
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.unit.dp
|
||||
import coil.compose.AsyncImage
|
||||
import com.huaga.life_echo.network.models.ChapterImageDto
|
||||
@@ -73,29 +70,6 @@ fun MemoirInlineImage(
|
||||
)
|
||||
}
|
||||
}
|
||||
"failed" -> Column(
|
||||
modifier = modifier
|
||||
.fillMaxWidth()
|
||||
.clip(RoundedCornerShape(16.dp))
|
||||
.background(LightPurple.copy(alpha = 0.10f))
|
||||
.padding(16.dp)
|
||||
.testTag("memoir-image-error-${image.index}")
|
||||
) {
|
||||
Text(
|
||||
text = "图片生成失败",
|
||||
fontSize = AppTypography.bodyMedium,
|
||||
fontWeight = FontWeight.Medium,
|
||||
color = SlatePurple,
|
||||
)
|
||||
if (image.description.isNotBlank()) {
|
||||
Text(
|
||||
text = image.description,
|
||||
fontSize = AppTypography.captionMedium,
|
||||
color = SlatePurple.copy(alpha = 0.6f),
|
||||
modifier = Modifier.padding(top = 4.dp),
|
||||
)
|
||||
}
|
||||
}
|
||||
else -> Unit
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.huaga.life_echo.ui.components.memoir
|
||||
|
||||
import com.huaga.life_echo.network.models.ChapterImageDto
|
||||
import org.junit.Assert.assertEquals
|
||||
import org.junit.Assert.assertFalse
|
||||
import org.junit.Assert.assertTrue
|
||||
import org.junit.Test
|
||||
|
||||
@@ -56,4 +57,33 @@ class MemoirContentBlocksTest {
|
||||
val text = (blocks[0] as MemoirContentBlock.Text).content
|
||||
assertTrue(!text.contains("IMAGE:"))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun splitMemoirContent_skipsFailedImages_andRemovesTheirPlaceholders() {
|
||||
val blocks = splitMemoirContent(
|
||||
content = "开头。\n\n{{{{IMAGE:生成失败的图}}}}\n\n结尾。",
|
||||
images = listOf(
|
||||
ChapterImageDto(
|
||||
index = 0,
|
||||
placeholder = "{{{{IMAGE:生成失败的图}}}}",
|
||||
description = "生成失败的图",
|
||||
prompt = null,
|
||||
url = null,
|
||||
status = "failed",
|
||||
provider = "liblib",
|
||||
style = "watercolor",
|
||||
size = "1024x1024",
|
||||
error = "provider timeout",
|
||||
created_at = null,
|
||||
updated_at = null,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
assertFalse(blocks.any { it is MemoirContentBlock.Image })
|
||||
val combinedText = blocks.filterIsInstance<MemoirContentBlock.Text>().joinToString("\n") { it.content }
|
||||
assertFalse(combinedText.contains("IMAGE:"))
|
||||
assertTrue(combinedText.contains("开头"))
|
||||
assertTrue(combinedText.contains("结尾"))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user