From cf460dd2a4e84b424e5c02532647fe683165216f Mon Sep 17 00:00:00 2001 From: Kevin Date: Tue, 10 Mar 2026 16:00:59 +0800 Subject: [PATCH] feat(api): add liblib memoir image provider adapter Made-with: Cursor --- api/services/memoir_images/provider.py | 43 ++++++++++++++++++ api/tests/test_memoir_image_provider.py | 58 +++++++++++++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 api/services/memoir_images/provider.py create mode 100644 api/tests/test_memoir_image_provider.py diff --git a/api/services/memoir_images/provider.py b/api/services/memoir_images/provider.py new file mode 100644 index 0000000..dee32c0 --- /dev/null +++ b/api/services/memoir_images/provider.py @@ -0,0 +1,43 @@ +import os +import time + +import httpx + + +class LiblibImageProvider: + def __init__( + self, + http_client=None, + api_key: str | None = None, + base_url: str | None = None, + ): + self.http_client = http_client or httpx.Client(timeout=60) + self.api_key = api_key or os.getenv("LIBLIB_API_KEY", "") + self.base_url = (base_url or os.getenv("LIBLIB_BASE_URL", "")).rstrip("/") + + def submit_generation(self, prompt: str, size: str, style: str) -> dict: + response = self.http_client.post( + f"{self.base_url}/v1/images/generations", + headers={"Authorization": f"Bearer {self.api_key}"}, + json={"prompt": prompt, "size": size, "style": style}, + ) + data = response.json() + if data.get("image_url"): + return {"status": "completed", "image_url": data["image_url"], "job_id": None} + return {"status": "processing", "job_id": data.get("task_id"), "image_url": None} + + def poll_until_complete(self, job: dict, poll_interval_seconds: int, max_attempts: int) -> dict: + for _ in range(max_attempts): + response = self.http_client.get( + f"{self.base_url}/v1/images/generations/{job['job_id']}", + headers={"Authorization": f"Bearer {self.api_key}"}, + ) + data = response.json() + if data.get("image_url"): + return {"status": "completed", "image_url": data["image_url"], "job_id": job["job_id"]} + time.sleep(poll_interval_seconds) + raise TimeoutError(f"Liblib image generation timed out for {job['job_id']}") + + def download_image(self, job: dict) -> bytes: + response = self.http_client.get(job["image_url"]) + return response.content diff --git a/api/tests/test_memoir_image_provider.py b/api/tests/test_memoir_image_provider.py new file mode 100644 index 0000000..e0c8de5 --- /dev/null +++ b/api/tests/test_memoir_image_provider.py @@ -0,0 +1,58 @@ +import unittest +from unittest.mock import Mock + +from api.services.memoir_images.provider import LiblibImageProvider + + +class MemoirImageProviderTest(unittest.TestCase): + def test_submit_generation_handles_sync_provider_response(self): + http_client = Mock() + http_client.post.return_value.json.return_value = { + "status": "succeeded", + "image_url": "https://provider.example.com/1.png", + } + provider = LiblibImageProvider(http_client=http_client, api_key="test-key", base_url="https://example.com") + + job = provider.submit_generation(prompt="foo", size="1024x1024", style="watercolor") + + self.assertEqual(job["status"], "completed") + self.assertEqual(job["image_url"], "https://provider.example.com/1.png") + + def test_submit_generation_handles_async_provider_response(self): + http_client = Mock() + http_client.post.return_value.json.return_value = { + "task_id": "job-123", + "status": "queued", + } + provider = LiblibImageProvider(http_client=http_client, api_key="test-key", base_url="https://example.com") + + job = provider.submit_generation(prompt="foo", size="1024x1024", style="watercolor") + + self.assertEqual(job["status"], "processing") + self.assertEqual(job["job_id"], "job-123") + + def test_poll_until_complete_returns_completed_job(self): + http_client = Mock() + http_client.get.return_value.json.side_effect = [ + {"status": "queued"}, + {"status": "succeeded", "image_url": "https://provider.example.com/1.png"}, + ] + provider = LiblibImageProvider(http_client=http_client, api_key="test-key", base_url="https://example.com") + + job = provider.poll_until_complete( + {"status": "processing", "job_id": "job-123"}, + poll_interval_seconds=0, + max_attempts=2, + ) + + self.assertEqual(job["status"], "completed") + self.assertEqual(job["image_url"], "https://provider.example.com/1.png") + + def test_download_image_fetches_binary_payload(self): + http_client = Mock() + http_client.get.return_value.content = b"png-bytes" + provider = LiblibImageProvider(http_client=http_client, api_key="test-key", base_url="https://example.com") + + payload = provider.download_image({"image_url": "https://provider.example.com/1.png"}) + + self.assertEqual(payload, b"png-bytes")