280 lines
9.5 KiB
Python
280 lines
9.5 KiB
Python
import os
|
|
import unittest
|
|
from unittest.mock import Mock, patch
|
|
|
|
from app.features.memoir.memoir_images.provider import LiblibImageProvider
|
|
from app.features.memoir.memoir_images.settings import DEFAULT_LIBLIB_TEMPLATE_UUID
|
|
|
|
|
|
def _make_provider(http_client=None, **overrides):
    """Build a ``LiblibImageProvider`` wired with fixed test credentials.

    Args:
        http_client: HTTP client to inject. A fresh ``Mock`` is used only when
            this is omitted / ``None``.
        **overrides: Extra keyword arguments forwarded to the
            ``LiblibImageProvider`` constructor. They are merged over the
            defaults below, so they can add options (e.g.
            ``allowed_download_hosts=(...)``) or replace a default
            (e.g. ``template_uuid=...``).

    Returns:
        The constructed ``LiblibImageProvider``.
    """
    kwargs = {
        # Explicit `is None` instead of `or`: an injected client must never be
        # silently replaced just because it happens to evaluate as falsy.
        "http_client": Mock() if http_client is None else http_client,
        "access_key": "test-ak",
        "secret_key": "test-sk",
        "base_url": "https://openapi.liblibai.cloud",
        "template_uuid": "tpl-uuid",
    }
    kwargs.update(overrides)
    return LiblibImageProvider(**kwargs)
|
|
|
|
|
|
class LiblibSignatureTest(unittest.TestCase):
    """Tests for the provider's request-signing helper."""

    def test_sign_returns_auth_params(self):
        """_sign() returns the access key plus signature, timestamp and nonce fields."""
        signed = _make_provider()._sign("/api/generate/webui/text2img/ultra")

        self.assertEqual(signed["AccessKey"], "test-ak")
        for field in ("Signature", "Timestamp", "SignatureNonce"):
            self.assertIn(field, signed)
|
|
|
|
|
|
class SubmitGenerationTest(unittest.TestCase):
    """Tests for ``LiblibImageProvider.submit_generation``."""

    @staticmethod
    def _client_returning(payload):
        """Return a mock HTTP client whose ``post`` responds with *payload* as JSON."""
        resp = Mock()
        resp.json.return_value = payload
        resp.raise_for_status = Mock()
        http_client = Mock()
        http_client.post.return_value = resp
        return http_client

    def _accepted_client(self):
        """Mock client whose submit call succeeds with generateUuid ``uuid-abc``."""
        return self._client_returning(
            {
                "code": 0,
                "data": {"generateUuid": "uuid-abc"},
            }
        )

    def _posted_json(self, http_client):
        """Return the ``json=`` body of the most recent ``post`` call on *http_client*."""
        # `call_args.kwargs` and `call_args[1]` are the same dict, so one lookup suffices.
        body = http_client.post.call_args.kwargs.get("json")
        self.assertIsNotNone(body, "submit_generation should send a json= body")
        return body

    def test_submit_keeps_auth_params_out_of_url_string(self):
        """Auth fields travel as ``params=`` rather than being baked into the URL."""
        http_client = self._accepted_client()
        provider = _make_provider(http_client)

        provider.submit_generation(prompt="a cat", size="1024x1024", style="watercolor")

        call = http_client.post.call_args
        url = call.args[0] if call.args else call.kwargs["url"]
        params = call.kwargs.get("params")

        self.assertEqual(
            url, "https://openapi.liblibai.cloud/api/generate/webui/text2img/ultra"
        )
        self.assertNotIn("AccessKey=", url)
        # Fail with a clear message instead of a TypeError on None["AccessKey"].
        self.assertIsNotNone(params, "auth values should be passed via params=")
        self.assertEqual(params["AccessKey"], "test-ak")
        for field in ("Signature", "Timestamp", "SignatureNonce"):
            self.assertIn(field, params)

    def test_submit_returns_processing_with_job_id(self):
        """A successful submit yields a processing job keyed by generateUuid."""
        http_client = self._accepted_client()
        provider = _make_provider(http_client)

        job = provider.submit_generation(
            prompt="a cat", size="1024x1024", style="watercolor"
        )

        self.assertEqual(job["status"], "processing")
        self.assertEqual(job["job_id"], "uuid-abc")
        self.assertIsNone(job["image_url"])

        body = self._posted_json(http_client)
        self.assertEqual(body["templateUuid"], "tpl-uuid")
        self.assertIn("a cat", body["generateParams"]["prompt"])
        self.assertIn("watercolor", body["generateParams"]["prompt"].lower())
        self.assertEqual(body["generateParams"]["aspectRatio"], "square")

    def test_submit_applies_style_when_prompt_does_not_include_it(self):
        """The style keyword is added to the prompt when the prompt lacks it."""
        http_client = self._accepted_client()
        provider = _make_provider(http_client)

        provider.submit_generation(
            prompt="a cat under the rain", size="1024x1024", style="watercolor"
        )

        body = self._posted_json(http_client)
        self.assertIn("a cat under the rain", body["generateParams"]["prompt"])
        self.assertIn("watercolor", body["generateParams"]["prompt"].lower())

    def test_submit_raises_on_error_code(self):
        """A non-zero API ``code`` is surfaced as RuntimeError."""
        http_client = self._client_returning({"code": 100000, "msg": "param error"})
        provider = _make_provider(http_client)

        with self.assertRaises(RuntimeError):
            provider.submit_generation(
                prompt="a cat", size="1024x1024", style="watercolor"
            )
|
|
|
|
|
|
class PollUntilCompleteTest(unittest.TestCase):
    """Tests for ``LiblibImageProvider.poll_until_complete``."""

    @staticmethod
    def _json_response(payload):
        """Return a mock HTTP response whose ``json()`` yields *payload*."""
        resp = Mock()
        resp.json.return_value = payload
        resp.raise_for_status = Mock()
        return resp

    @staticmethod
    def _pending_payload():
        """Fresh status-query payload reporting generateStatus 2 with no images."""
        return {
            "code": 0,
            "data": {"generateStatus": 2, "images": []},
        }

    def test_returns_completed_on_status_5(self):
        """Polling stops at generateStatus 5 and reports the first image URL."""
        http_client = Mock()
        http_client.post.side_effect = [
            self._json_response(self._pending_payload()),
            self._json_response(
                {
                    "code": 0,
                    "data": {
                        "generateStatus": 5,
                        "images": [
                            {"imageUrl": "https://cdn.example.com/1.png", "auditStatus": 3}
                        ],
                    },
                }
            ),
        ]

        provider = _make_provider(http_client)
        job = provider.poll_until_complete(
            {"status": "processing", "job_id": "uuid-abc"},
            poll_interval_seconds=0,
            max_attempts=3,
        )

        self.assertEqual(job["status"], "completed")
        self.assertEqual(job["image_url"], "https://cdn.example.com/1.png")
        self.assertEqual(job["job_id"], "uuid-abc")

    def test_raises_on_status_6_failure(self):
        """generateStatus 6 raises RuntimeError carrying the upstream generateMsg."""
        http_client = Mock()
        http_client.post.return_value = self._json_response(
            {
                "code": 0,
                "data": {"generateStatus": 6, "generateMsg": "content violation"},
            }
        )

        provider = _make_provider(http_client)
        # assertRaises(..., msg=...) only customizes the failure text used when
        # nothing is raised; it never inspects the exception. assertRaisesRegex
        # actually verifies the message is propagated.
        with self.assertRaisesRegex(RuntimeError, "content violation"):
            provider.poll_until_complete(
                {"status": "processing", "job_id": "uuid-abc"},
                poll_interval_seconds=0,
                max_attempts=2,
            )

    def test_raises_timeout_after_max_attempts(self):
        """A job that never leaves the pending state raises TimeoutError."""
        http_client = Mock()
        http_client.post.return_value = self._json_response(self._pending_payload())

        provider = _make_provider(http_client)
        with self.assertRaises(TimeoutError):
            provider.poll_until_complete(
                {"status": "processing", "job_id": "uuid-abc"},
                poll_interval_seconds=0,
                max_attempts=2,
            )
|
|
|
|
|
|
class DownloadImageTest(unittest.TestCase):
    """Tests for ``download_image`` and its download-host allow-list."""

    @staticmethod
    def _client_serving(content):
        """Build a mock HTTP client whose ``get`` returns a response carrying *content*."""
        response = Mock()
        response.content = content
        response.raise_for_status = Mock()
        client = Mock()
        client.get.return_value = response
        return client

    def test_download_allows_liblib_tmp_image_host_by_default(self):
        """The liblib temporary-image host is accepted without extra configuration."""
        tmp_image_url = "https://liblibai-tmp-image.liblib.cloud/img/demo.png"
        client = self._client_serving(b"png-bytes")

        payload = _make_provider(client).download_image({"image_url": tmp_image_url})

        self.assertEqual(payload, b"png-bytes")
        client.get.assert_called_once_with(tmp_image_url)

    def test_download_fetches_binary_payload(self):
        """A host passed via allowed_download_hosts can be downloaded from."""
        client = self._client_serving(b"png-bytes")
        provider = LiblibImageProvider(
            http_client=client,
            access_key="test-ak",
            secret_key="test-sk",
            base_url="https://openapi.liblibai.cloud",
            template_uuid="tpl-uuid",
            allowed_download_hosts=("cdn.example.com",),
        )

        downloaded = provider.download_image({"image_url": "https://cdn.example.com/1.png"})

        self.assertEqual(downloaded, b"png-bytes")

    def test_download_rejects_unapproved_host(self):
        """Hosts outside the allow-list raise ValueError before any request is made."""
        client = Mock()
        provider = _make_provider(client)

        with self.assertRaises(ValueError):
            provider.download_image({"image_url": "https://evil.example.com/1.png"})

        client.get.assert_not_called()
|
|
|
|
|
|
class ProviderResourceManagementTest(unittest.TestCase):
    """Ownership rules for the provider's HTTP client."""

    # Patch httpx.Client through the module that defines the provider under test
    # (the same module LiblibImageProvider is imported from at the top of this
    # file), not an unrelated adapters module.
    # NOTE(review): assumes that module does `import httpx` — TODO confirm.
    @patch("app.features.memoir.memoir_images.provider.httpx.Client")
    def test_provider_closes_owned_http_client(self, httpx_client_cls):
        """A client the provider created itself is closed by ``close()``."""
        owned_client = Mock()
        httpx_client_cls.return_value = owned_client

        provider = LiblibImageProvider(
            access_key="test-ak",
            secret_key="test-sk",
            base_url="https://openapi.liblibai.cloud",
            template_uuid="tpl-uuid",
        )

        provider.close()

        owned_client.close.assert_called_once()

    def test_provider_does_not_close_injected_http_client(self):
        """A caller-supplied client is left open; the caller owns its lifecycle."""
        injected_client = Mock()
        provider = LiblibImageProvider(
            http_client=injected_client,
            access_key="test-ak",
            secret_key="test-sk",
            base_url="https://openapi.liblibai.cloud",
            template_uuid="tpl-uuid",
        )

        provider.close()

        injected_client.close.assert_not_called()
|
|
|
|
|
|
class ProviderDefaultsTest(unittest.TestCase):
    """Default configuration applied when constructor arguments are omitted."""

    def test_provider_uses_shared_template_uuid_default(self):
        """With LIBLIB_TEMPLATE_UUID blank and no template_uuid, the shared default applies."""
        blank_env = {"LIBLIB_TEMPLATE_UUID": ""}
        # Keep the assertion inside the patch so the environment matches the
        # decorator form for the whole test body.
        with patch.dict(os.environ, blank_env, clear=False):
            provider = LiblibImageProvider(
                http_client=Mock(),
                access_key="test-ak",
                secret_key="test-sk",
                base_url="https://openapi.liblibai.cloud",
            )
            self.assertEqual(provider.template_uuid, DEFAULT_LIBLIB_TEMPLATE_UUID)
|