feat(api): add liblib memoir image provider adapter
Made-with: Cursor
This commit is contained in:
43
api/services/memoir_images/provider.py
Normal file
43
api/services/memoir_images/provider.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class LiblibImageProvider:
|
||||
def __init__(
|
||||
self,
|
||||
http_client=None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
):
|
||||
self.http_client = http_client or httpx.Client(timeout=60)
|
||||
self.api_key = api_key or os.getenv("LIBLIB_API_KEY", "")
|
||||
self.base_url = (base_url or os.getenv("LIBLIB_BASE_URL", "")).rstrip("/")
|
||||
|
||||
def submit_generation(self, prompt: str, size: str, style: str) -> dict:
|
||||
response = self.http_client.post(
|
||||
f"{self.base_url}/v1/images/generations",
|
||||
headers={"Authorization": f"Bearer {self.api_key}"},
|
||||
json={"prompt": prompt, "size": size, "style": style},
|
||||
)
|
||||
data = response.json()
|
||||
if data.get("image_url"):
|
||||
return {"status": "completed", "image_url": data["image_url"], "job_id": None}
|
||||
return {"status": "processing", "job_id": data.get("task_id"), "image_url": None}
|
||||
|
||||
def poll_until_complete(self, job: dict, poll_interval_seconds: int, max_attempts: int) -> dict:
|
||||
for _ in range(max_attempts):
|
||||
response = self.http_client.get(
|
||||
f"{self.base_url}/v1/images/generations/{job['job_id']}",
|
||||
headers={"Authorization": f"Bearer {self.api_key}"},
|
||||
)
|
||||
data = response.json()
|
||||
if data.get("image_url"):
|
||||
return {"status": "completed", "image_url": data["image_url"], "job_id": job["job_id"]}
|
||||
time.sleep(poll_interval_seconds)
|
||||
raise TimeoutError(f"Liblib image generation timed out for {job['job_id']}")
|
||||
|
||||
def download_image(self, job: dict) -> bytes:
|
||||
response = self.http_client.get(job["image_url"])
|
||||
return response.content
|
||||
58
api/tests/test_memoir_image_provider.py
Normal file
58
api/tests/test_memoir_image_provider.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import unittest
|
||||
from unittest.mock import Mock
|
||||
|
||||
from api.services.memoir_images.provider import LiblibImageProvider
|
||||
|
||||
|
||||
class MemoirImageProviderTest(unittest.TestCase):
|
||||
def test_submit_generation_handles_sync_provider_response(self):
|
||||
http_client = Mock()
|
||||
http_client.post.return_value.json.return_value = {
|
||||
"status": "succeeded",
|
||||
"image_url": "https://provider.example.com/1.png",
|
||||
}
|
||||
provider = LiblibImageProvider(http_client=http_client, api_key="test-key", base_url="https://example.com")
|
||||
|
||||
job = provider.submit_generation(prompt="foo", size="1024x1024", style="watercolor")
|
||||
|
||||
self.assertEqual(job["status"], "completed")
|
||||
self.assertEqual(job["image_url"], "https://provider.example.com/1.png")
|
||||
|
||||
def test_submit_generation_handles_async_provider_response(self):
|
||||
http_client = Mock()
|
||||
http_client.post.return_value.json.return_value = {
|
||||
"task_id": "job-123",
|
||||
"status": "queued",
|
||||
}
|
||||
provider = LiblibImageProvider(http_client=http_client, api_key="test-key", base_url="https://example.com")
|
||||
|
||||
job = provider.submit_generation(prompt="foo", size="1024x1024", style="watercolor")
|
||||
|
||||
self.assertEqual(job["status"], "processing")
|
||||
self.assertEqual(job["job_id"], "job-123")
|
||||
|
||||
def test_poll_until_complete_returns_completed_job(self):
|
||||
http_client = Mock()
|
||||
http_client.get.return_value.json.side_effect = [
|
||||
{"status": "queued"},
|
||||
{"status": "succeeded", "image_url": "https://provider.example.com/1.png"},
|
||||
]
|
||||
provider = LiblibImageProvider(http_client=http_client, api_key="test-key", base_url="https://example.com")
|
||||
|
||||
job = provider.poll_until_complete(
|
||||
{"status": "processing", "job_id": "job-123"},
|
||||
poll_interval_seconds=0,
|
||||
max_attempts=2,
|
||||
)
|
||||
|
||||
self.assertEqual(job["status"], "completed")
|
||||
self.assertEqual(job["image_url"], "https://provider.example.com/1.png")
|
||||
|
||||
def test_download_image_fetches_binary_payload(self):
|
||||
http_client = Mock()
|
||||
http_client.get.return_value.content = b"png-bytes"
|
||||
provider = LiblibImageProvider(http_client=http_client, api_key="test-key", base_url="https://example.com")
|
||||
|
||||
payload = provider.download_image({"image_url": "https://provider.example.com/1.png"})
|
||||
|
||||
self.assertEqual(payload, b"png-bytes")
|
||||
Reference in New Issue
Block a user