Fix memoir image prompt parsing and host allowlist
This commit is contained in:
@@ -7,6 +7,7 @@ from .settings import MemoirImageSettings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_CJK_RE = re.compile(r"[\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff]")
|
||||
_MARKDOWN_JSON_FENCE_RE = re.compile(r"^\s*```(?:json)?\s*(.*?)\s*```\s*$", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
|
||||
class MemoirImagePromptService:
|
||||
@@ -54,13 +55,15 @@ class MemoirImagePromptService:
|
||||
}
|
||||
|
||||
if self.llm:
|
||||
raw_response = None
|
||||
try:
|
||||
response = self.llm.invoke(
|
||||
"Return JSON only with keys prompt, style, size. "
|
||||
"Convert the memoir scene into an image-generation prompt.\n"
|
||||
+ json.dumps(llm_input, ensure_ascii=False)
|
||||
)
|
||||
parsed = json.loads(response.content)
|
||||
raw_response = response.content
|
||||
parsed = json.loads(_extract_json_payload(raw_response))
|
||||
return {
|
||||
"prompt": _ensure_style_in_prompt(parsed["prompt"], parsed.get("style", style)),
|
||||
"style": parsed.get("style", style),
|
||||
@@ -118,6 +121,23 @@ def _contains_cjk(value: str) -> bool:
|
||||
return bool(_CJK_RE.search(value or ""))
|
||||
|
||||
|
||||
def _extract_json_payload(raw_response: str | None) -> str:
|
||||
cleaned = (raw_response or "").strip()
|
||||
fenced_match = _MARKDOWN_JSON_FENCE_RE.match(cleaned)
|
||||
if fenced_match:
|
||||
cleaned = fenced_match.group(1).strip()
|
||||
|
||||
if cleaned.startswith("{") and cleaned.endswith("}"):
|
||||
return cleaned
|
||||
|
||||
start = cleaned.find("{")
|
||||
end = cleaned.rfind("}")
|
||||
if start != -1 and end != -1 and end > start:
|
||||
return cleaned[start : end + 1].strip()
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
def _ensure_style_in_prompt(prompt: str, style: str) -> str:
|
||||
cleaned_prompt = (prompt or "").strip()
|
||||
cleaned_style = (style or "").strip()
|
||||
|
||||
@@ -227,6 +227,8 @@ def _build_allowed_download_hosts(
|
||||
default_hosts.add(base_hostname)
|
||||
if base_hostname.endswith(".liblibai.cloud") or base_hostname == "liblibai.cloud":
|
||||
default_hosts.add("liblibai.cloud")
|
||||
# Liblib returns generated image downloads from *.liblib.cloud.
|
||||
default_hosts.add("liblib.cloud")
|
||||
|
||||
return tuple(sorted(default_hosts.union(configured_hosts)))
|
||||
|
||||
|
||||
@@ -63,8 +63,12 @@ class GenerateChapterImagesTaskTest(unittest.TestCase):
|
||||
@patch("api.tasks.memoir_tasks.TencentCosStorageService")
|
||||
@patch("api.tasks.memoir_tasks.LiblibImageProvider")
|
||||
@patch("api.tasks.memoir_tasks.MemoirImagePromptService")
|
||||
@patch("api.tasks.memoir_tasks._release_chapter_image_lock")
|
||||
@patch("api.tasks.memoir_tasks._acquire_chapter_image_lock", return_value=True)
|
||||
def test_generate_chapter_images_retries_when_any_item_generation_fails(
|
||||
self,
|
||||
_acquire_lock_mock,
|
||||
_release_lock_mock,
|
||||
prompt_service_cls,
|
||||
provider_cls,
|
||||
storage_cls,
|
||||
@@ -118,8 +122,12 @@ class GenerateChapterImagesTaskTest(unittest.TestCase):
|
||||
@patch("api.tasks.memoir_tasks.TencentCosStorageService")
|
||||
@patch("api.tasks.memoir_tasks.LiblibImageProvider")
|
||||
@patch("api.tasks.memoir_tasks.MemoirImagePromptService")
|
||||
@patch("api.tasks.memoir_tasks._release_chapter_image_lock")
|
||||
@patch("api.tasks.memoir_tasks._acquire_chapter_image_lock", return_value=True)
|
||||
def test_generate_chapter_images_marks_successful_item_completed(
|
||||
self,
|
||||
_acquire_lock_mock,
|
||||
_release_lock_mock,
|
||||
prompt_service_cls,
|
||||
provider_cls,
|
||||
storage_cls,
|
||||
@@ -160,7 +168,9 @@ class GenerateChapterImagesTaskTest(unittest.TestCase):
|
||||
"status": "completed",
|
||||
"image_url": "https://provider.example.com/1.png",
|
||||
}
|
||||
provider_inst.download_image.return_value = b"png-bytes"
|
||||
png_buffer = BytesIO()
|
||||
Image.new("RGB", (1, 1), color="white").save(png_buffer, format="PNG")
|
||||
provider_inst.download_image.return_value = png_buffer.getvalue()
|
||||
storage_inst = storage_cls.from_env.return_value
|
||||
storage_inst.upload_bytes.return_value = "https://cos.example.com/memoirs/u1/c1/0.png"
|
||||
|
||||
@@ -235,8 +245,12 @@ class GenerateChapterImagesTaskTest(unittest.TestCase):
|
||||
@patch("api.tasks.memoir_tasks.TencentCosStorageService")
|
||||
@patch("api.tasks.memoir_tasks.LiblibImageProvider")
|
||||
@patch("api.tasks.memoir_tasks.MemoirImagePromptService")
|
||||
@patch("api.tasks.memoir_tasks._release_chapter_image_lock")
|
||||
@patch("api.tasks.memoir_tasks._acquire_chapter_image_lock", return_value=True)
|
||||
def test_generate_chapter_images_converts_non_png_payload_before_upload(
|
||||
self,
|
||||
_acquire_lock_mock,
|
||||
_release_lock_mock,
|
||||
prompt_service_cls,
|
||||
provider_cls,
|
||||
storage_cls,
|
||||
@@ -295,8 +309,12 @@ class GenerateChapterImagesTaskTest(unittest.TestCase):
|
||||
@patch("api.tasks.memoir_tasks.TencentCosStorageService")
|
||||
@patch("api.tasks.memoir_tasks.LiblibImageProvider")
|
||||
@patch("api.tasks.memoir_tasks.MemoirImagePromptService")
|
||||
@patch("api.tasks.memoir_tasks._release_chapter_image_lock")
|
||||
@patch("api.tasks.memoir_tasks._acquire_chapter_image_lock", return_value=True)
|
||||
def test_generate_chapter_images_skips_completed_items_for_idempotency(
|
||||
self,
|
||||
_acquire_lock_mock,
|
||||
_release_lock_mock,
|
||||
prompt_service_cls,
|
||||
provider_cls,
|
||||
storage_cls,
|
||||
|
||||
@@ -62,6 +62,41 @@ class MemoirImagePromptingTest(unittest.TestCase):
|
||||
self.assertEqual(result["style"], "watercolor")
|
||||
self.assertEqual(result["size"], "1024x1024")
|
||||
|
||||
def test_prompt_service_parses_markdown_wrapped_json_response(self):
|
||||
settings = MemoirImageSettings(
|
||||
enabled=True,
|
||||
max_per_chapter=2,
|
||||
provider="liblib",
|
||||
default_style="watercolor",
|
||||
default_size="1024x1024",
|
||||
poll_interval_seconds=3,
|
||||
max_attempts=20,
|
||||
liblib_template_uuid="tpl-uuid",
|
||||
)
|
||||
llm = Mock()
|
||||
llm.invoke.return_value.content = """```json
|
||||
{
|
||||
"prompt": "A middle-aged teacher stands on the empty stage, realistic, cinematic lighting",
|
||||
"style": "realistic",
|
||||
"size": "1280x720"
|
||||
}
|
||||
```"""
|
||||
service = MemoirImagePromptService(llm=llm, settings=settings)
|
||||
|
||||
result = service.build_prompt(
|
||||
chapter_title="二十出头 · 在小镇讲台上种下第一粒种子",
|
||||
chapter_category="career_early",
|
||||
description="空荡荡的教室讲台前,一个年轻老师站着",
|
||||
context_excerpt="第一次站上讲台,心里紧张又兴奋。",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
result["prompt"],
|
||||
"A middle-aged teacher stands on the empty stage, realistic, cinematic lighting",
|
||||
)
|
||||
self.assertEqual(result["style"], "realistic")
|
||||
self.assertEqual(result["size"], "1280x720")
|
||||
|
||||
@patch("api.services.memoir_images.prompting.logger")
|
||||
def test_prompt_service_logs_warning_and_falls_back_when_llm_response_is_invalid(
|
||||
self, logger_mock
|
||||
@@ -89,4 +124,9 @@ class MemoirImagePromptingTest(unittest.TestCase):
|
||||
|
||||
self.assertIn("childhood memory", result["prompt"])
|
||||
self.assertNotIn("奶奶坐在院子里的藤椅上", result["prompt"])
|
||||
logger_mock.warning.assert_called_once()
|
||||
logger_mock.warning.assert_called_once_with(
|
||||
"图片 prompt 生成回退到默认模板: chapter_category=%s, title=%s, error=%s",
|
||||
"childhood",
|
||||
"童年的夏天",
|
||||
unittest.mock.ANY,
|
||||
)
|
||||
|
||||
@@ -176,6 +176,23 @@ class PollUntilCompleteTest(unittest.TestCase):
|
||||
|
||||
|
||||
class DownloadImageTest(unittest.TestCase):
|
||||
def test_download_allows_liblib_tmp_image_host_by_default(self):
|
||||
http_client = Mock()
|
||||
resp = Mock()
|
||||
resp.content = b"png-bytes"
|
||||
resp.raise_for_status = Mock()
|
||||
http_client.get.return_value = resp
|
||||
|
||||
provider = _make_provider(http_client)
|
||||
payload = provider.download_image(
|
||||
{"image_url": "https://liblibai-tmp-image.liblib.cloud/img/demo.png"}
|
||||
)
|
||||
|
||||
self.assertEqual(payload, b"png-bytes")
|
||||
http_client.get.assert_called_once_with(
|
||||
"https://liblibai-tmp-image.liblib.cloud/img/demo.png"
|
||||
)
|
||||
|
||||
@patch.dict(os.environ, {"MEMOIR_IMAGE_DOWNLOAD_HOSTS": "cdn.example.com"}, clear=False)
|
||||
def test_download_fetches_binary_payload(self):
|
||||
http_client = Mock()
|
||||
|
||||
Reference in New Issue
Block a user