"""Unit tests for ``LiblibImageProvider``.

Every test either injects a ``Mock`` HTTP client or patches ``httpx.Client``,
so no real network traffic is generated.
"""

import os
import unittest
from unittest.mock import Mock, patch

from app.features.memoir.memoir_images.provider import LiblibImageProvider
from app.features.memoir.memoir_images.settings import DEFAULT_LIBLIB_TEMPLATE_UUID

# Query parameters (besides ``AccessKey``) that a signed request must carry.
_AUTH_KEYS = ("Signature", "Timestamp", "SignatureNonce")


def _make_provider(http_client=None):
    """Build a provider with fixed test credentials and an injected client."""
    return LiblibImageProvider(
        http_client=http_client or Mock(),
        access_key="test-ak",
        secret_key="test-sk",
        base_url="https://openapi.liblibai.cloud",
        template_uuid="tpl-uuid",
    )


def _json_response(payload):
    """Return a mock HTTP response whose ``json()`` yields *payload*."""
    resp = Mock()
    resp.json.return_value = payload
    resp.raise_for_status = Mock()
    return resp


def _binary_response(content):
    """Return a mock HTTP response whose ``content`` is *content*."""
    resp = Mock()
    resp.content = content
    resp.raise_for_status = Mock()
    return resp


def _submit_client():
    """Return a mock HTTP client whose ``post`` accepts a generation job."""
    http_client = Mock()
    http_client.post.return_value = _json_response(
        {"code": 0, "data": {"generateUuid": "uuid-abc"}}
    )
    return http_client


def _posted_json(http_client):
    """Return the ``json=`` body of the most recent ``post`` call."""
    return http_client.post.call_args.kwargs["json"]


def _processing_job():
    """Return a fresh in-flight job dict (fresh so no test sees another's mutations)."""
    return {"status": "processing", "job_id": "uuid-abc"}


class LiblibSignatureTest(unittest.TestCase):
    """Request signing via ``_sign``."""

    def test_sign_returns_auth_params(self):
        provider = _make_provider()

        params = provider._sign("/api/generate/webui/text2img/ultra")

        self.assertEqual(params["AccessKey"], "test-ak")
        for key in _AUTH_KEYS:
            self.assertIn(key, params)


class SubmitGenerationTest(unittest.TestCase):
    """``submit_generation``: request shape, returned job, and error codes."""

    def test_submit_keeps_auth_params_out_of_url_string(self):
        http_client = _submit_client()
        provider = _make_provider(http_client)

        provider.submit_generation(prompt="a cat", size="1024x1024", style="watercolor")

        call = http_client.post.call_args
        url = call.args[0] if call.args else call.kwargs["url"]
        params = call.kwargs.get("params")
        self.assertEqual(
            url, "https://openapi.liblibai.cloud/api/generate/webui/text2img/ultra"
        )
        self.assertNotIn("AccessKey=", url)
        # Fail with a clear message instead of a TypeError on ``None[...]``.
        self.assertIsNotNone(params, "auth must be passed via params=")
        self.assertEqual(params["AccessKey"], "test-ak")
        for key in _AUTH_KEYS:
            self.assertIn(key, params)

    def test_submit_returns_processing_with_job_id(self):
        http_client = _submit_client()
        provider = _make_provider(http_client)

        job = provider.submit_generation(
            prompt="a cat", size="1024x1024", style="watercolor"
        )

        self.assertEqual(job["status"], "processing")
        self.assertEqual(job["job_id"], "uuid-abc")
        self.assertIsNone(job["image_url"])
        body = _posted_json(http_client)
        self.assertEqual(body["templateUuid"], "tpl-uuid")
        self.assertIn("a cat", body["generateParams"]["prompt"])
        self.assertIn("watercolor", body["generateParams"]["prompt"].lower())
        self.assertEqual(body["generateParams"]["aspectRatio"], "square")

    def test_submit_applies_style_when_prompt_does_not_include_it(self):
        http_client = _submit_client()
        provider = _make_provider(http_client)

        provider.submit_generation(
            prompt="a cat under the rain", size="1024x1024", style="watercolor"
        )

        body = _posted_json(http_client)
        self.assertIn("a cat under the rain", body["generateParams"]["prompt"])
        self.assertIn("watercolor", body["generateParams"]["prompt"].lower())

    def test_submit_raises_on_error_code(self):
        http_client = Mock()
        http_client.post.return_value = _json_response(
            {"code": 100000, "msg": "param error"}
        )
        provider = _make_provider(http_client)

        with self.assertRaises(RuntimeError):
            provider.submit_generation(
                prompt="a cat", size="1024x1024", style="watercolor"
            )


class PollUntilCompleteTest(unittest.TestCase):
    """``poll_until_complete``: success, reported failure, and timeout."""

    def test_returns_completed_on_status_5(self):
        http_client = Mock()
        pending = _json_response(
            {"code": 0, "data": {"generateStatus": 2, "images": []}}
        )
        success = _json_response(
            {
                "code": 0,
                "data": {
                    "generateStatus": 5,
                    "images": [
                        {"imageUrl": "https://cdn.example.com/1.png", "auditStatus": 3}
                    ],
                },
            }
        )
        http_client.post.side_effect = [pending, success]
        provider = _make_provider(http_client)

        job = provider.poll_until_complete(
            _processing_job(), poll_interval_seconds=0, max_attempts=3
        )

        self.assertEqual(job["status"], "completed")
        self.assertEqual(job["image_url"], "https://cdn.example.com/1.png")
        self.assertEqual(job["job_id"], "uuid-abc")

    def test_raises_on_status_6_failure(self):
        http_client = Mock()
        http_client.post.return_value = _json_response(
            {
                "code": 0,
                "data": {"generateStatus": 6, "generateMsg": "content violation"},
            }
        )
        provider = _make_provider(http_client)

        # assertRaises(..., msg=...) only customises the "not raised" failure
        # text and never inspects the exception, so match the message explicitly.
        with self.assertRaisesRegex(RuntimeError, "content violation"):
            provider.poll_until_complete(
                _processing_job(), poll_interval_seconds=0, max_attempts=2
            )

    def test_raises_timeout_after_max_attempts(self):
        http_client = Mock()
        http_client.post.return_value = _json_response(
            {"code": 0, "data": {"generateStatus": 2, "images": []}}
        )
        provider = _make_provider(http_client)

        with self.assertRaises(TimeoutError):
            provider.poll_until_complete(
                _processing_job(), poll_interval_seconds=0, max_attempts=2
            )


class DownloadImageTest(unittest.TestCase):
    """``download_image``: host allow-listing and payload retrieval."""

    def test_download_allows_liblib_tmp_image_host_by_default(self):
        http_client = Mock()
        http_client.get.return_value = _binary_response(b"png-bytes")
        provider = _make_provider(http_client)

        payload = provider.download_image(
            {"image_url": "https://liblibai-tmp-image.liblib.cloud/img/demo.png"}
        )

        self.assertEqual(payload, b"png-bytes")
        http_client.get.assert_called_once_with(
            "https://liblibai-tmp-image.liblib.cloud/img/demo.png"
        )

    def test_download_fetches_binary_payload(self):
        http_client = Mock()
        http_client.get.return_value = _binary_response(b"png-bytes")
        provider = LiblibImageProvider(
            http_client=http_client,
            access_key="test-ak",
            secret_key="test-sk",
            base_url="https://openapi.liblibai.cloud",
            template_uuid="tpl-uuid",
            allowed_download_hosts=("cdn.example.com",),
        )

        payload = provider.download_image({"image_url": "https://cdn.example.com/1.png"})

        self.assertEqual(payload, b"png-bytes")
        http_client.get.assert_called_once_with("https://cdn.example.com/1.png")

    def test_download_rejects_unapproved_host(self):
        http_client = Mock()
        provider = _make_provider(http_client)

        with self.assertRaises(ValueError):
            provider.download_image({"image_url": "https://evil.example.com/1.png"})
        http_client.get.assert_not_called()


class ProviderResourceManagementTest(unittest.TestCase):
    """``close`` releases only HTTP clients the provider created itself."""

    # The previous target "app.adapters.image_gen.liblib_provider.httpx.Client"
    # resolved to the ``httpx`` module and patched its ``Client`` attribute,
    # i.e. exactly this patch, but it also depended on an unrelated module path.
    # NOTE(review): like the original, this assumes the provider calls
    # ``httpx.Client(...)`` through the module (not ``from httpx import Client``).
    @patch("httpx.Client")
    def test_provider_closes_owned_http_client(self, httpx_client_cls):
        http_client = Mock()
        httpx_client_cls.return_value = http_client
        provider = LiblibImageProvider(
            access_key="test-ak",
            secret_key="test-sk",
            base_url="https://openapi.liblibai.cloud",
            template_uuid="tpl-uuid",
        )

        provider.close()

        http_client.close.assert_called_once()

    def test_provider_does_not_close_injected_http_client(self):
        http_client = Mock()
        provider = LiblibImageProvider(
            http_client=http_client,
            access_key="test-ak",
            secret_key="test-sk",
            base_url="https://openapi.liblibai.cloud",
            template_uuid="tpl-uuid",
        )

        provider.close()

        http_client.close.assert_not_called()


class ProviderDefaultsTest(unittest.TestCase):
    """Fallback configuration when constructor arguments are omitted."""

    # An empty LIBLIB_TEMPLATE_UUID must not override the shared default.
    @patch.dict(os.environ, {"LIBLIB_TEMPLATE_UUID": ""}, clear=False)
    def test_provider_uses_shared_template_uuid_default(self):
        provider = LiblibImageProvider(
            http_client=Mock(),
            access_key="test-ak",
            secret_key="test-sk",
            base_url="https://openapi.liblibai.cloud",
        )

        self.assertEqual(provider.template_uuid, DEFAULT_LIBLIB_TEMPLATE_UUID)