fix: harden memoir image generation flow

This commit is contained in:
Kevin
2026-03-11 11:26:42 +08:00
parent a76cf8da18
commit 00092d34c9
14 changed files with 1162 additions and 69 deletions

View File

@@ -1,7 +1,9 @@
import os
import unittest
from unittest.mock import Mock, patch
from api.services.memoir_images.provider import LiblibImageProvider
from api.services.memoir_images.settings import DEFAULT_LIBLIB_TEMPLATE_UUID
def _make_provider(http_client=None):
@@ -15,17 +17,40 @@ def _make_provider(http_client=None):
class LiblibSignatureTest(unittest.TestCase):
def test_sign_returns_url_with_auth_params(self):
def test_sign_returns_auth_params(self):
provider = _make_provider()
url = provider._sign("/api/generate/webui/text2img/ultra")
self.assertIn("AccessKey=test-ak", url)
self.assertIn("Signature=", url)
self.assertIn("Timestamp=", url)
self.assertIn("SignatureNonce=", url)
self.assertTrue(url.startswith("https://openapi.liblibai.cloud/api/generate/webui/text2img/ultra?"))
params = provider._sign("/api/generate/webui/text2img/ultra")
self.assertEqual(params["AccessKey"], "test-ak")
self.assertIn("Signature", params)
self.assertIn("Timestamp", params)
self.assertIn("SignatureNonce", params)
class SubmitGenerationTest(unittest.TestCase):
def test_submit_keeps_auth_params_out_of_url_string(self):
http_client = Mock()
resp = Mock()
resp.json.return_value = {
"code": 0,
"data": {"generateUuid": "uuid-abc"},
}
resp.raise_for_status = Mock()
http_client.post.return_value = resp
provider = _make_provider(http_client)
provider.submit_generation(prompt="a cat", size="1024x1024", style="watercolor")
call_kwargs = http_client.post.call_args
url = call_kwargs.args[0] if call_kwargs.args else call_kwargs.kwargs["url"]
params = call_kwargs.kwargs.get("params")
self.assertEqual(url, "https://openapi.liblibai.cloud/api/generate/webui/text2img/ultra")
self.assertNotIn("AccessKey=", url)
self.assertEqual(params["AccessKey"], "test-ak")
self.assertIn("Signature", params)
self.assertIn("Timestamp", params)
self.assertIn("SignatureNonce", params)
def test_submit_returns_processing_with_job_id(self):
http_client = Mock()
resp = Mock()
@@ -46,9 +71,29 @@ class SubmitGenerationTest(unittest.TestCase):
call_kwargs = http_client.post.call_args
body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
self.assertEqual(body["templateUuid"], "tpl-uuid")
self.assertEqual(body["generateParams"]["prompt"], "a cat")
self.assertIn("a cat", body["generateParams"]["prompt"])
self.assertIn("watercolor", body["generateParams"]["prompt"].lower())
self.assertEqual(body["generateParams"]["aspectRatio"], "square")
def test_submit_applies_style_when_prompt_does_not_include_it(self):
http_client = Mock()
resp = Mock()
resp.json.return_value = {
"code": 0,
"data": {"generateUuid": "uuid-abc"},
}
resp.raise_for_status = Mock()
http_client.post.return_value = resp
provider = _make_provider(http_client)
provider.submit_generation(prompt="a cat under the rain", size="1024x1024", style="watercolor")
call_kwargs = http_client.post.call_args
body = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json")
self.assertIn("a cat under the rain", body["generateParams"]["prompt"])
self.assertIn("watercolor", body["generateParams"]["prompt"].lower())
def test_submit_raises_on_error_code(self):
http_client = Mock()
resp = Mock()
@@ -131,6 +176,7 @@ class PollUntilCompleteTest(unittest.TestCase):
class DownloadImageTest(unittest.TestCase):
@patch.dict(os.environ, {"MEMOIR_IMAGE_DOWNLOAD_HOSTS": "cdn.example.com"}, clear=False)
def test_download_fetches_binary_payload(self):
http_client = Mock()
resp = Mock()
@@ -142,3 +188,57 @@ class DownloadImageTest(unittest.TestCase):
payload = provider.download_image({"image_url": "https://cdn.example.com/1.png"})
self.assertEqual(payload, b"png-bytes")
def test_download_rejects_unapproved_host(self):
http_client = Mock()
provider = _make_provider(http_client)
with self.assertRaises(ValueError):
provider.download_image({"image_url": "https://evil.example.com/1.png"})
http_client.get.assert_not_called()
class ProviderResourceManagementTest(unittest.TestCase):
@patch("api.services.memoir_images.provider.httpx.Client")
def test_provider_closes_owned_http_client(self, httpx_client_cls):
http_client = Mock()
httpx_client_cls.return_value = http_client
provider = LiblibImageProvider(
access_key="test-ak",
secret_key="test-sk",
base_url="https://openapi.liblibai.cloud",
template_uuid="tpl-uuid",
)
provider.close()
http_client.close.assert_called_once()
def test_provider_does_not_close_injected_http_client(self):
http_client = Mock()
provider = LiblibImageProvider(
http_client=http_client,
access_key="test-ak",
secret_key="test-sk",
base_url="https://openapi.liblibai.cloud",
template_uuid="tpl-uuid",
)
provider.close()
http_client.close.assert_not_called()
class ProviderDefaultsTest(unittest.TestCase):
@patch.dict(os.environ, {"LIBLIB_TEMPLATE_UUID": ""}, clear=False)
def test_provider_uses_shared_template_uuid_default(self):
provider = LiblibImageProvider(
http_client=Mock(),
access_key="test-ak",
secret_key="test-sk",
base_url="https://openapi.liblibai.cloud",
)
self.assertEqual(provider.template_uuid, DEFAULT_LIBLIB_TEMPLATE_UUID)