70 lines
2.5 KiB
Python
70 lines
2.5 KiB
Python
"""Liblib image generation adapter — implements ImageGenerator port."""
|
|
|
|
from app.ports.image_gen import ImageResult, TaskStatus
|
|
|
|
from .liblib_provider import LiblibImageProvider
|
|
|
|
|
|
class LiblibImageGenerator:
    """Wrap ``LiblibImageProvider`` to implement the ImageGenerator port.

    ``generate`` and ``check_status`` never raise for provider errors; they
    convert them into ``ImageResult`` values with ``TaskStatus.FAILED`` so the
    caller always receives a result object.
    """

    def __init__(
        self,
        access_key: str,
        secret_key: str,
        base_url: str = "https://openapi.liblibai.cloud",
        template_uuid: str = "",
        poll_interval: int = 3,
        max_attempts: int = 20,
    ):
        """Create the adapter and the underlying Liblib provider.

        Args:
            access_key: Liblib access key, forwarded to the provider.
            secret_key: Liblib secret key, forwarded to the provider.
            base_url: API base URL, forwarded to the provider.
            template_uuid: Template UUID; an empty string is forwarded as None.
            poll_interval: Interval handed to ``poll_until_complete``
                (presumably seconds — units are defined by the provider).
            max_attempts: Number of polling attempts ``generate`` allows
                before giving up.
        """
        self._provider = LiblibImageProvider(
            access_key=access_key,
            secret_key=secret_key,
            base_url=base_url,
            # Normalise the "" default to None before handing it to the provider.
            template_uuid=template_uuid or None,
        )
        self._poll_interval = poll_interval
        self._max_attempts = max_attempts

    @staticmethod
    def _failed(task_id: str, exc: Exception) -> ImageResult:
        """Build a FAILED result, never with an empty error message."""
        # Some exceptions (e.g. a bare ``TimeoutError()``) stringify to "";
        # fall back to the exception type so the caller still sees a cause.
        return ImageResult(
            status=TaskStatus.FAILED,
            task_id=task_id,
            error=str(exc) or type(exc).__name__,
        )

    def generate(self, prompt: str, size: str, style: str) -> ImageResult:
        """Submit a generation job and block until it completes.

        Args:
            prompt: Text prompt forwarded to the provider.
            size: Image size forwarded to the provider.
            style: Style forwarded to the provider.

        Returns:
            ``COMPLETED`` with the provider's ``image_url`` on success, or
            ``FAILED`` with an error message if submission, polling (including
            the ``TimeoutError`` presumably raised once ``max_attempts`` polls
            are exhausted) or result handling raises.
        """
        # Initialised up front so a failure during submission still yields a
        # well-formed result with an empty task id.
        task_id = ""
        try:
            job = self._provider.submit_generation(prompt, size, style)
            task_id = job.get("job_id", "")
            result = self._provider.poll_until_complete(
                job, self._poll_interval, self._max_attempts
            )
            return ImageResult(
                status=TaskStatus.COMPLETED,
                task_id=result.get("job_id", task_id),
                image_url=result.get("image_url"),
            )
        except Exception as exc:  # adapter boundary: report, don't propagate
            return self._failed(task_id, exc)

    def check_status(self, task_id: str) -> ImageResult:
        """Poll an existing job once and report its current state.

        Args:
            task_id: Job id previously returned by ``generate``.

        Returns:
            ``COMPLETED`` (with ``image_url``) when the provider reports
            ``status == "completed"``, ``PROCESSING`` when the job has not
            finished yet, or ``FAILED`` with an error message on any other
            provider error.
        """
        try:
            result = self._provider.poll_until_complete(
                {"job_id": task_id}, self._poll_interval, 1
            )
            status = (
                TaskStatus.COMPLETED
                if result.get("status") == "completed"
                else TaskStatus.PROCESSING
            )
            return ImageResult(
                status=status,
                task_id=task_id,
                image_url=result.get("image_url"),
            )
        except TimeoutError:
            # With a single allowed attempt, running out of attempts (which
            # ``generate`` treats as a TimeoutError from the provider) only
            # means the job has not finished yet — it is not a failure.
            return ImageResult(status=TaskStatus.PROCESSING, task_id=task_id)
        except Exception as exc:  # adapter boundary: report, don't propagate
            return self._failed(task_id, exc)

    def download_image(self, image_url: str) -> bytes:
        """Download image bytes from ``image_url`` (part of the port contract).

        Delegates to the provider; provider exceptions propagate to the caller.
        """
        return self._provider.download_image_from_url(image_url)

    def close(self) -> None:
        """Release the underlying provider by calling its ``close()``."""
        self._provider.close()

    def __enter__(self) -> "LiblibImageGenerator":
        """Allow ``with LiblibImageGenerator(...) as gen:`` usage."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the provider on exit; exceptions are not suppressed."""
        self.close()
|