Fix memoir image prompt parsing and host allowlist

This commit is contained in:
Kevin
2026-03-11 13:18:20 +08:00
parent 1f98b8bfd6
commit 32954d4b3f
5 changed files with 100 additions and 3 deletions

View File

@@ -7,6 +7,7 @@ from .settings import MemoirImageSettings
logger = logging.getLogger(__name__)
_CJK_RE = re.compile(r"[\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff]")
_MARKDOWN_JSON_FENCE_RE = re.compile(r"^\s*```(?:json)?\s*(.*?)\s*```\s*$", re.IGNORECASE | re.DOTALL)
class MemoirImagePromptService:
@@ -54,13 +55,15 @@ class MemoirImagePromptService:
}
if self.llm:
raw_response = None
try:
response = self.llm.invoke(
"Return JSON only with keys prompt, style, size. "
"Convert the memoir scene into an image-generation prompt.\n"
+ json.dumps(llm_input, ensure_ascii=False)
)
parsed = json.loads(response.content)
raw_response = response.content
parsed = json.loads(_extract_json_payload(raw_response))
return {
"prompt": _ensure_style_in_prompt(parsed["prompt"], parsed.get("style", style)),
"style": parsed.get("style", style),
@@ -118,6 +121,23 @@ def _contains_cjk(value: str) -> bool:
return bool(_CJK_RE.search(value or ""))
def _extract_json_payload(raw_response: str | None) -> str:
cleaned = (raw_response or "").strip()
fenced_match = _MARKDOWN_JSON_FENCE_RE.match(cleaned)
if fenced_match:
cleaned = fenced_match.group(1).strip()
if cleaned.startswith("{") and cleaned.endswith("}"):
return cleaned
start = cleaned.find("{")
end = cleaned.rfind("}")
if start != -1 and end != -1 and end > start:
return cleaned[start : end + 1].strip()
return cleaned
def _ensure_style_in_prompt(prompt: str, style: str) -> str:
cleaned_prompt = (prompt or "").strip()
cleaned_style = (style or "").strip()

View File

@@ -227,6 +227,8 @@ def _build_allowed_download_hosts(
default_hosts.add(base_hostname)
if base_hostname.endswith(".liblibai.cloud") or base_hostname == "liblibai.cloud":
default_hosts.add("liblibai.cloud")
# Liblib returns generated image downloads from *.liblib.cloud.
default_hosts.add("liblib.cloud")
return tuple(sorted(default_hosts.union(configured_hosts)))

View File

@@ -63,8 +63,12 @@ class GenerateChapterImagesTaskTest(unittest.TestCase):
@patch("api.tasks.memoir_tasks.TencentCosStorageService")
@patch("api.tasks.memoir_tasks.LiblibImageProvider")
@patch("api.tasks.memoir_tasks.MemoirImagePromptService")
@patch("api.tasks.memoir_tasks._release_chapter_image_lock")
@patch("api.tasks.memoir_tasks._acquire_chapter_image_lock", return_value=True)
def test_generate_chapter_images_retries_when_any_item_generation_fails(
self,
_acquire_lock_mock,
_release_lock_mock,
prompt_service_cls,
provider_cls,
storage_cls,
@@ -118,8 +122,12 @@ class GenerateChapterImagesTaskTest(unittest.TestCase):
@patch("api.tasks.memoir_tasks.TencentCosStorageService")
@patch("api.tasks.memoir_tasks.LiblibImageProvider")
@patch("api.tasks.memoir_tasks.MemoirImagePromptService")
@patch("api.tasks.memoir_tasks._release_chapter_image_lock")
@patch("api.tasks.memoir_tasks._acquire_chapter_image_lock", return_value=True)
def test_generate_chapter_images_marks_successful_item_completed(
self,
_acquire_lock_mock,
_release_lock_mock,
prompt_service_cls,
provider_cls,
storage_cls,
@@ -160,7 +168,9 @@ class GenerateChapterImagesTaskTest(unittest.TestCase):
"status": "completed",
"image_url": "https://provider.example.com/1.png",
}
provider_inst.download_image.return_value = b"png-bytes"
png_buffer = BytesIO()
Image.new("RGB", (1, 1), color="white").save(png_buffer, format="PNG")
provider_inst.download_image.return_value = png_buffer.getvalue()
storage_inst = storage_cls.from_env.return_value
storage_inst.upload_bytes.return_value = "https://cos.example.com/memoirs/u1/c1/0.png"
@@ -235,8 +245,12 @@ class GenerateChapterImagesTaskTest(unittest.TestCase):
@patch("api.tasks.memoir_tasks.TencentCosStorageService")
@patch("api.tasks.memoir_tasks.LiblibImageProvider")
@patch("api.tasks.memoir_tasks.MemoirImagePromptService")
@patch("api.tasks.memoir_tasks._release_chapter_image_lock")
@patch("api.tasks.memoir_tasks._acquire_chapter_image_lock", return_value=True)
def test_generate_chapter_images_converts_non_png_payload_before_upload(
self,
_acquire_lock_mock,
_release_lock_mock,
prompt_service_cls,
provider_cls,
storage_cls,
@@ -295,8 +309,12 @@ class GenerateChapterImagesTaskTest(unittest.TestCase):
@patch("api.tasks.memoir_tasks.TencentCosStorageService")
@patch("api.tasks.memoir_tasks.LiblibImageProvider")
@patch("api.tasks.memoir_tasks.MemoirImagePromptService")
@patch("api.tasks.memoir_tasks._release_chapter_image_lock")
@patch("api.tasks.memoir_tasks._acquire_chapter_image_lock", return_value=True)
def test_generate_chapter_images_skips_completed_items_for_idempotency(
self,
_acquire_lock_mock,
_release_lock_mock,
prompt_service_cls,
provider_cls,
storage_cls,

View File

@@ -62,6 +62,41 @@ class MemoirImagePromptingTest(unittest.TestCase):
self.assertEqual(result["style"], "watercolor")
self.assertEqual(result["size"], "1024x1024")
def test_prompt_service_parses_markdown_wrapped_json_response(self):
settings = MemoirImageSettings(
enabled=True,
max_per_chapter=2,
provider="liblib",
default_style="watercolor",
default_size="1024x1024",
poll_interval_seconds=3,
max_attempts=20,
liblib_template_uuid="tpl-uuid",
)
llm = Mock()
llm.invoke.return_value.content = """```json
{
"prompt": "A middle-aged teacher stands on the empty stage, realistic, cinematic lighting",
"style": "realistic",
"size": "1280x720"
}
```"""
service = MemoirImagePromptService(llm=llm, settings=settings)
result = service.build_prompt(
chapter_title="二十出头 · 在小镇讲台上种下第一粒种子",
chapter_category="career_early",
description="空荡荡的教室讲台前,一个年轻老师站着",
context_excerpt="第一次站上讲台,心里紧张又兴奋。",
)
self.assertEqual(
result["prompt"],
"A middle-aged teacher stands on the empty stage, realistic, cinematic lighting",
)
self.assertEqual(result["style"], "realistic")
self.assertEqual(result["size"], "1280x720")
@patch("api.services.memoir_images.prompting.logger")
def test_prompt_service_logs_warning_and_falls_back_when_llm_response_is_invalid(
self, logger_mock
@@ -89,4 +124,9 @@ class MemoirImagePromptingTest(unittest.TestCase):
self.assertIn("childhood memory", result["prompt"])
self.assertNotIn("奶奶坐在院子里的藤椅上", result["prompt"])
logger_mock.warning.assert_called_once()
logger_mock.warning.assert_called_once_with(
"图片 prompt 生成回退到默认模板: chapter_category=%s, title=%s, error=%s",
"childhood",
"童年的夏天",
unittest.mock.ANY,
)

View File

@@ -176,6 +176,23 @@ class PollUntilCompleteTest(unittest.TestCase):
class DownloadImageTest(unittest.TestCase):
def test_download_allows_liblib_tmp_image_host_by_default(self):
http_client = Mock()
resp = Mock()
resp.content = b"png-bytes"
resp.raise_for_status = Mock()
http_client.get.return_value = resp
provider = _make_provider(http_client)
payload = provider.download_image(
{"image_url": "https://liblibai-tmp-image.liblib.cloud/img/demo.png"}
)
self.assertEqual(payload, b"png-bytes")
http_client.get.assert_called_once_with(
"https://liblibai-tmp-image.liblib.cloud/img/demo.png"
)
@patch.dict(os.environ, {"MEMOIR_IMAGE_DOWNLOAD_HOSTS": "cdn.example.com"}, clear=False)
def test_download_fetches_binary_payload(self):
http_client = Mock()