44 lines
1.7 KiB
Python
44 lines
1.7 KiB
Python
import os
|
|
import time
|
|
|
|
import httpx
|
|
|
|
|
|
class LiblibImageProvider:
|
|
def __init__(
|
|
self,
|
|
http_client=None,
|
|
api_key: str | None = None,
|
|
base_url: str | None = None,
|
|
):
|
|
self.http_client = http_client or httpx.Client(timeout=60)
|
|
self.api_key = api_key or os.getenv("LIBLIB_API_KEY", "")
|
|
self.base_url = (base_url or os.getenv("LIBLIB_BASE_URL", "")).rstrip("/")
|
|
|
|
def submit_generation(self, prompt: str, size: str, style: str) -> dict:
|
|
response = self.http_client.post(
|
|
f"{self.base_url}/v1/images/generations",
|
|
headers={"Authorization": f"Bearer {self.api_key}"},
|
|
json={"prompt": prompt, "size": size, "style": style},
|
|
)
|
|
data = response.json()
|
|
if data.get("image_url"):
|
|
return {"status": "completed", "image_url": data["image_url"], "job_id": None}
|
|
return {"status": "processing", "job_id": data.get("task_id"), "image_url": None}
|
|
|
|
def poll_until_complete(self, job: dict, poll_interval_seconds: int, max_attempts: int) -> dict:
|
|
for _ in range(max_attempts):
|
|
response = self.http_client.get(
|
|
f"{self.base_url}/v1/images/generations/{job['job_id']}",
|
|
headers={"Authorization": f"Bearer {self.api_key}"},
|
|
)
|
|
data = response.json()
|
|
if data.get("image_url"):
|
|
return {"status": "completed", "image_url": data["image_url"], "job_id": job["job_id"]}
|
|
time.sleep(poll_interval_seconds)
|
|
raise TimeoutError(f"Liblib image generation timed out for {job['job_id']}")
|
|
|
|
def download_image(self, job: dict) -> bytes:
|
|
response = self.http_client.get(job["image_url"])
|
|
return response.content
|