"""Liblib image-generation provider.

Signs requests to the Liblib OpenAPI, submits text-to-image jobs, polls them
until they finish, and downloads the resulting image from an allow-listed
host. Signed auth query parameters are scrubbed from httpx/httpcore log
records and from HTTP status errors raised by the submit/poll calls.
"""

import base64
import hmac
import logging
import os
import re
import time
import uuid
from hashlib import sha1
from urllib.parse import urlparse

import httpx

from .settings import DEFAULT_LIBLIB_TEMPLATE_UUID

logger = logging.getLogger(__name__)

# Query-string auth params appended by ``LiblibImageProvider._sign``; their
# values must not end up in logs or exception messages.
_SENSITIVE_QUERY_PARAMS = ("AccessKey", "Signature", "Timestamp", "SignatureNonce")
# The value match stops at '&', whitespace, or a quote so that surrounding
# text (e.g. the closing quote around the URL in httpx error messages)
# survives redaction.
_SENSITIVE_QUERY_RE = re.compile(
    r"([?&])(" + "|".join(_SENSITIVE_QUERY_PARAMS) + r")=([^&\s'\"]+)"
)

# LibLib Star-3 Alpha aspectRatio presets (official docs):
#   square    → 1:1,  1024x1024
#   portrait  → 3:4,  768x1024
#   landscape → 16:9, 1280x720
# To use exact pixel dimensions instead, pass imageSize (see submit_generation).
_SIZE_TO_ASPECT_RATIO = {
    "1024x1024": "square",
    "768x1024": "portrait",
    "1280x720": "landscape",
}

_DEFAULT_WIDTH, _DEFAULT_HEIGHT = 1024, 1024


def _parse_size(size: str) -> tuple[int, int]:
    """Parse a 'WxH' string into (width, height), clamped to 512~2048.

    Unparseable input (wrong format, non-numeric parts, non-string) falls
    back to the 1024x1024 default.
    """
    try:
        w_str, h_str = size.lower().split("x", 1)
        w = max(512, min(2048, int(w_str)))
        h = max(512, min(2048, int(h_str)))
        return w, h
    except (ValueError, AttributeError):
        return _DEFAULT_WIDTH, _DEFAULT_HEIGHT


class LiblibImageProvider:
    """Liblib (https://openapi.liblibai.cloud) image generation adapter.

    Typical flow: ``submit_generation`` -> ``poll_until_complete`` ->
    ``download_image``. Usable as a context manager; the HTTP client is only
    closed if this provider created it.
    """

    def __init__(
        self,
        http_client: httpx.Client | None = None,
        access_key: str | None = None,
        secret_key: str | None = None,
        base_url: str | None = None,
        template_uuid: str | None = None,
        allowed_download_hosts: tuple[str, ...] | None = None,
    ):
        """Configure credentials, endpoint, template and download allow-list.

        Falsy arguments fall back to the ``LIBLIB_ACCESS_KEY``,
        ``LIBLIB_SECRET_KEY``, ``LIBLIB_BASE_URL`` and ``LIBLIB_TEMPLATE_UUID``
        environment variables (then to built-in defaults). When no
        ``http_client`` is given, a private ``httpx.Client(timeout=120)`` is
        created and later closed by ``close()``.
        """
        _install_http_log_redaction()
        self._owns_http_client = http_client is None
        self.http_client = http_client or httpx.Client(timeout=120)
        self.access_key = access_key or os.getenv("LIBLIB_ACCESS_KEY", "")
        self.secret_key = secret_key or os.getenv("LIBLIB_SECRET_KEY", "")
        self.base_url = (
            base_url or os.getenv("LIBLIB_BASE_URL", "https://openapi.liblibai.cloud")
        ).rstrip("/")
        self.template_uuid = (
            template_uuid
            or os.getenv("LIBLIB_TEMPLATE_UUID")
            or DEFAULT_LIBLIB_TEMPLATE_UUID
        )
        self.allowed_download_hosts = _build_allowed_download_hosts(
            self.base_url,
            allowed_download_hosts=allowed_download_hosts,
        )

    # ------------------------------------------------------------------
    # Signature helpers
    # ------------------------------------------------------------------

    def _build_url(self, uri: str) -> str:
        """Join the configured base URL with an API path such as ``/api/...``."""
        return f"{self.base_url}{uri}"

    def _sign(self, uri: str) -> dict[str, str]:
        """Build Liblib HMAC-SHA1 query-string auth params.

        The signed content is ``"<uri>&<timestamp_ms>&<nonce>"``, keyed with
        the secret key; the digest is URL-safe base64 with '=' padding
        stripped. A fresh timestamp and nonce are generated on every call, so
        each request must be signed separately.
        """
        timestamp = str(int(time.time() * 1000))
        nonce = str(uuid.uuid4())
        content = "&".join((uri, timestamp, nonce))
        digest = hmac.new(self.secret_key.encode(), content.encode(), sha1).digest()
        signature = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
        return {
            "AccessKey": self.access_key,
            "Signature": signature,
            "Timestamp": timestamp,
            "SignatureNonce": nonce,
        }

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def submit_generation(self, prompt: str, size: str, style: str) -> dict:
        """Submit a text-to-image job and return a pending job dict.

        A ``size`` listed in ``_SIZE_TO_ASPECT_RATIO`` is sent as an
        ``aspectRatio`` preset; anything else is parsed as ``WxH`` by
        ``_parse_size`` and sent as an explicit ``imageSize``.

        Returns:
            ``{"status": "processing", "job_id": <generateUuid>, "image_url": None}``

        Raises:
            httpx.HTTPStatusError: non-2xx response (auth params redacted).
            RuntimeError: the API returned a non-zero ``code`` or no
                ``generateUuid``.
        """
        uri = "/api/generate/webui/text2img/ultra"
        url = self._build_url(uri)
        params = self._sign(uri)

        generate_params: dict = {
            "prompt": _apply_style_to_prompt(prompt, style),
            "imgCount": 1,
            "steps": 30,
        }
        aspect_ratio = _SIZE_TO_ASPECT_RATIO.get(size)
        if aspect_ratio:
            generate_params["aspectRatio"] = aspect_ratio
        else:
            # Not a preset ratio — pass explicit imageSize (width/height 512~2048).
            w, h = _parse_size(size)
            generate_params["imageSize"] = {"width": w, "height": h}

        body = {
            "templateUuid": self.template_uuid,
            "generateParams": generate_params,
        }
        response = self.http_client.post(
            url,
            params=params,
            headers={"Content-Type": "application/json"},
            json=body,
        )
        _raise_for_status_redacted(response)
        data = response.json()
        if data.get("code") != 0:
            raise RuntimeError(f"Liblib submit failed: {data.get('msg', data)}")

        # "data" may be missing or null; surface that as a RuntimeError rather
        # than an opaque KeyError/TypeError.
        generate_uuid = (data.get("data") or {}).get("generateUuid")
        if not generate_uuid:
            raise RuntimeError(f"Liblib submit returned no generateUuid: {data}")
        return {"status": "processing", "job_id": generate_uuid, "image_url": None}

    def poll_until_complete(
        self, job: dict, poll_interval_seconds: int, max_attempts: int
    ) -> dict:
        """Poll a submitted job until it reaches a terminal status.

        Makes at most ``max_attempts`` status requests, sleeping
        ``poll_interval_seconds`` between consecutive requests (there is no
        sleep after the final attempt).

        Returns:
            ``{"status": "completed", "image_url": <first image URL>, "job_id": ...}``

        Raises:
            httpx.HTTPStatusError: non-2xx response (auth params redacted).
            RuntimeError: non-zero API ``code``, status 6, or status 5 without
                an image URL.
            TimeoutError: status 7, or no terminal status within
                ``max_attempts`` requests.
        """
        uri = "/api/generate/webui/status"
        url = self._build_url(uri)
        job_id = job["job_id"] if max_attempts > 0 else None
        for attempt in range(1, max_attempts + 1):
            # Re-sign per request: the signature embeds a timestamp and nonce.
            response = self.http_client.post(
                url,
                params=self._sign(uri),
                headers={"Content-Type": "application/json"},
                json={"generateUuid": job_id},
            )
            _raise_for_status_redacted(response)
            data = response.json()
            if data.get("code") != 0:
                raise RuntimeError(f"Liblib status query failed: {data.get('msg', data)}")

            # "data" may be absent or explicitly null.
            result = data.get("data") or {}
            status = result.get("generateStatus")

            if status == 5:  # success
                images = result.get("images") or []
                image_url = images[0].get("imageUrl") if images else None
                if image_url:
                    return {
                        "status": "completed",
                        "image_url": image_url,
                        "job_id": job_id,
                    }
                raise RuntimeError(f"Liblib returned success but no images for {job_id}")
            if status == 6:  # failed
                raise RuntimeError(f"Liblib generation failed: {result.get('generateMsg', 'unknown')}")
            # Status 7 is treated as terminal and surfaced as a TimeoutError.
            # NOTE(review): confirm the meaning of 7 against the current LibLib
            # status table. Every other status keeps the loop polling.
            if status == 7:
                raise TimeoutError(f"Liblib returned undocumented status 7 for {job_id}")

            logger.debug(
                "Liblib poll attempt %d/%d, status=%s, job=%s",
                attempt,
                max_attempts,
                status,
                job_id,
            )
            if attempt < max_attempts:
                time.sleep(poll_interval_seconds)

        raise TimeoutError(f"Liblib image generation timed out after {max_attempts} attempts for {job['job_id']}")

    def download_image(self, job: dict) -> bytes:
        """Fetch the bytes at ``job["image_url"]`` after validating its host.

        Raises:
            ValueError: the URL is not https or its host is not allow-listed.
            httpx.HTTPStatusError: the download returned a non-2xx status.
        """
        image_url = job["image_url"]
        _validate_download_url(image_url, self.allowed_download_hosts)
        response = self.http_client.get(image_url)
        response.raise_for_status()
        return response.content

    def close(self) -> None:
        """Close the HTTP client, but only if this provider created it."""
        if self._owns_http_client:
            self.http_client.close()

    def __enter__(self) -> "LiblibImageProvider":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()


class _LiblibAuthRedactionFilter(logging.Filter):
    """Logging filter that redacts Liblib auth query params from records.

    Rewrites ``record.msg`` and each positional/mapping arg in place and
    always lets the record through.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        record.msg = _redact_sensitive_query_values(record.msg)
        if record.args:
            if isinstance(record.args, dict):
                record.args = {
                    key: _redact_sensitive_query_values(value)
                    for key, value in record.args.items()
                }
            else:
                record.args = tuple(_redact_sensitive_query_values(value) for value in record.args)
        return True


def _redact_sensitive_query_values(value):
    """Replace sensitive query-param values with ``[REDACTED]``.

    Strings are rewritten directly. ``httpx.URL`` objects are handled too,
    because httpx passes ``request.url`` as a URL object (not a str) when it
    logs requests. A URL is replaced by its redacted string form only if
    something was redacted. Any other value is returned unchanged.
    """
    if isinstance(value, str):
        return _SENSITIVE_QUERY_RE.sub(r"\1\2=[REDACTED]", value)
    if isinstance(value, httpx.URL):
        text = str(value)
        redacted = _SENSITIVE_QUERY_RE.sub(r"\1\2=[REDACTED]", text)
        return redacted if redacted != text else value
    return value


def _install_http_log_redaction() -> None:
    """Attach one redaction filter to each httpx/httpcore logger (idempotent).

    Logger-level filters only see records logged on that exact logger, so
    each relevant child logger is registered individually.
    """
    for logger_name in (
        "httpx",
        "httpcore",
        "httpcore.connection",
        "httpcore.http11",
        "httpcore.proxy",
    ):
        target_logger = logging.getLogger(logger_name)
        if getattr(target_logger, "_liblib_auth_redaction_installed", False):
            continue
        target_logger.addFilter(_LiblibAuthRedactionFilter())
        target_logger._liblib_auth_redaction_installed = True


def _raise_for_status_redacted(response: httpx.Response) -> None:
    """Call ``response.raise_for_status()``, redacting auth params from the error.

    The message of httpx's ``HTTPStatusError`` contains the full request URL,
    including the signed auth query params. The error is re-raised as the same
    exception type with a redacted message. ``from None`` keeps the original,
    unredacted message out of the exception chain.
    """
    try:
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        raise httpx.HTTPStatusError(
            _redact_sensitive_query_values(str(exc)),
            request=exc.request,
            response=exc.response,
        ) from None


def _build_allowed_download_hosts(
    base_url: str,
    allowed_download_hosts: tuple[str, ...] | None = None,
) -> tuple[str, ...]:
    """Return the sorted host allow-list used by ``download_image``.

    Explicit ``allowed_download_hosts`` take precedence over the
    comma-separated ``MEMOIR_IMAGE_DOWNLOAD_HOSTS`` env var. The base URL's
    hostname is always included. For Liblib base URLs, ``liblibai.cloud`` and
    ``liblib.cloud`` are added as well; subdomains match via
    ``_hostname_matches``.
    """
    configured_hosts = allowed_download_hosts
    if configured_hosts is None:
        configured_hosts = tuple(
            host.strip().lower()
            for host in os.getenv("MEMOIR_IMAGE_DOWNLOAD_HOSTS", "").split(",")
            if host.strip()
        )

    base_hostname = (urlparse(base_url).hostname or "").lower()
    default_hosts: set[str] = set()
    if base_hostname:
        default_hosts.add(base_hostname)
    if base_hostname.endswith(".liblibai.cloud") or base_hostname == "liblibai.cloud":
        default_hosts.add("liblibai.cloud")
        # Liblib returns generated image downloads from *.liblib.cloud.
        default_hosts.add("liblib.cloud")

    return tuple(sorted(default_hosts.union(configured_hosts)))


def _validate_download_url(image_url: str, allowed_hosts: tuple[str, ...]) -> None:
    """Raise ``ValueError`` unless ``image_url`` is https on an allowed host."""
    parsed = urlparse(image_url)
    hostname = (parsed.hostname or "").lower()
    if parsed.scheme != "https" or not hostname:
        raise ValueError(f"Unsupported image download URL: {image_url}")
    if not any(_hostname_matches(hostname, allowed_host) for allowed_host in allowed_hosts):
        raise ValueError(f"Image download host is not allowed: {hostname}")


def _hostname_matches(hostname: str, allowed_host: str) -> bool:
    """True if ``hostname`` equals ``allowed_host`` or is a subdomain of it.

    ``allowed_host`` is stripped and lower-cased; blank entries never match.
    """
    normalized_allowed = allowed_host.strip().lower()
    if not normalized_allowed:
        return False
    return hostname == normalized_allowed or hostname.endswith(f".{normalized_allowed}")


def _apply_style_to_prompt(prompt: str, style: str) -> str:
    """Prefix ``style`` to ``prompt`` as ``"<style>, <prompt>"``.

    Both inputs are stripped (``None`` is treated as empty). The style is not
    added if it is blank or already occurs in the prompt (case-insensitive).
    If the prompt is blank, the style alone is returned.
    """
    cleaned_prompt = (prompt or "").strip()
    cleaned_style = (style or "").strip()
    if not cleaned_style:
        return cleaned_prompt
    if cleaned_style.lower() in cleaned_prompt.lower():
        return cleaned_prompt
    if not cleaned_prompt:
        return cleaned_style
    return f"{cleaned_style}, {cleaned_prompt}"