59 lines
2.4 KiB
Python
59 lines
2.4 KiB
Python
|
|
import unittest
|
||
|
|
from unittest.mock import Mock
|
||
|
|
|
||
|
|
from api.services.memoir_images.provider import LiblibImageProvider
|
||
|
|
|
||
|
|
|
||
|
|
class MemoirImageProviderTest(unittest.TestCase):
    """Unit tests for ``LiblibImageProvider`` driven by a mocked HTTP client.

    Each test injects a ``Mock`` as the provider's ``http_client``, scripts
    the payload that client returns, and checks both how the provider
    normalizes that payload and which HTTP calls it issued.
    """

    API_KEY = "test-key"
    BASE_URL = "https://example.com"
    IMAGE_URL = "https://provider.example.com/1.png"

    def _make_provider(self, http_client):
        """Build a provider wired to ``http_client`` with fixed test credentials."""
        return LiblibImageProvider(
            http_client=http_client,
            api_key=self.API_KEY,
            base_url=self.BASE_URL,
        )

    def test_submit_generation_handles_sync_provider_response(self):
        """A synchronous "succeeded" reply is normalized to a completed job."""
        http_client = Mock()
        http_client.post.return_value.json.return_value = {
            "status": "succeeded",
            "image_url": self.IMAGE_URL,
        }
        provider = self._make_provider(http_client)

        job = provider.submit_generation(prompt="foo", size="1024x1024", style="watercolor")

        # The submission is expected to go out as a single POST request.
        http_client.post.assert_called_once()
        self.assertEqual(job["status"], "completed")
        self.assertEqual(job["image_url"], self.IMAGE_URL)

    def test_submit_generation_handles_async_provider_response(self):
        """A queued reply becomes a processing job keyed by the provider task id."""
        http_client = Mock()
        http_client.post.return_value.json.return_value = {
            "task_id": "job-123",
            "status": "queued",
        }
        provider = self._make_provider(http_client)

        job = provider.submit_generation(prompt="foo", size="1024x1024", style="watercolor")

        http_client.post.assert_called_once()
        self.assertEqual(job["status"], "processing")
        # The provider's "task_id" is exposed to callers as "job_id".
        self.assertEqual(job["job_id"], "job-123")

    def test_poll_until_complete_returns_completed_job(self):
        """Polling stops at the first "succeeded" status and returns a completed job."""
        http_client = Mock()
        # One in-progress status response followed by a terminal success.
        http_client.get.return_value.json.side_effect = [
            {"status": "queued"},
            {"status": "succeeded", "image_url": self.IMAGE_URL},
        ]
        provider = self._make_provider(http_client)

        job = provider.poll_until_complete(
            {"status": "processing", "job_id": "job-123"},
            poll_interval_seconds=0,  # keep the test from sleeping between polls
            max_attempts=2,
        )

        self.assertEqual(job["status"], "completed")
        self.assertEqual(job["image_url"], self.IMAGE_URL)
        # One GET per poll attempt, each querying the job being polled.
        # The job id is matched against the call's repr so the check holds
        # whether it is sent in the URL path or as a request parameter.
        self.assertEqual(http_client.get.call_count, 2)
        for status_call in http_client.get.call_args_list:
            self.assertIn("job-123", str(status_call))

    def test_download_image_fetches_binary_payload(self):
        """``download_image`` returns the raw bytes fetched from the job's image URL."""
        http_client = Mock()
        http_client.get.return_value.content = b"png-bytes"
        provider = self._make_provider(http_client)

        payload = provider.download_image({"image_url": self.IMAGE_URL})

        self.assertEqual(payload, b"png-bytes")
        # The bytes must come from a single GET of the job's image URL.
        http_client.get.assert_called_once()
        self.assertIn(self.IMAGE_URL, str(http_client.get.call_args))
|