""" Liblib 图生 SDK 封装,位于 adapters 层;实现细节不暴露给 feature。 Feature 通过 port ImageGenerator 使用,本模块仅被 app.adapters.image_gen.liblib 使用。 """ import base64 import hmac import logging from app.core.logging import get_logger import re import time import uuid from hashlib import sha1 from urllib.parse import urlparse import httpx from app.core.config import settings logger = get_logger(__name__) DEFAULT_LIBLIB_TEMPLATE_UUID = "5d7e67009b344550bc1aa6ccbfa1d7f4" _SENSITIVE_QUERY_PARAMS = ("AccessKey", "Signature", "Timestamp", "SignatureNonce") _SENSITIVE_QUERY_RE = re.compile( r"([?&])(" + "|".join(_SENSITIVE_QUERY_PARAMS) + r")=([^&\s]+)" ) _SIZE_TO_ASPECT_RATIO = { "1024x1024": "square", "768x1024": "portrait", "1280x720": "landscape", } _DEFAULT_WIDTH, _DEFAULT_HEIGHT = 1024, 1024 def _parse_size(size: str) -> tuple[int, int]: try: w_str, h_str = size.lower().split("x", 1) w = max(512, min(2048, int(w_str))) h = max(512, min(2048, int(h_str))) return w, h except (ValueError, AttributeError): return _DEFAULT_WIDTH, _DEFAULT_HEIGHT class LiblibImageProvider: """Liblib (https://openapi.liblibai.cloud) image generation — adapter 层实现。""" def __init__( self, http_client: httpx.Client | None = None, access_key: str | None = None, secret_key: str | None = None, base_url: str | None = None, template_uuid: str | None = None, allowed_download_hosts: tuple[str, ...] | None = None, ): _install_http_log_redaction() self._owns_http_client = http_client is None self.http_client = http_client or httpx.Client(timeout=120) self.access_key = access_key or (settings.liblib_access_key or "") self.secret_key = secret_key or (settings.liblib_secret_key or "") self.base_url = ( base_url or settings.liblib_base_url or "https://openapi.liblibai.cloud" ).rstrip("/") self.template_uuid = template_uuid or ( settings.liblib_template_uuid or DEFAULT_LIBLIB_TEMPLATE_UUID ) self.allowed_download_hosts = _build_allowed_download_hosts( self.base_url, allowed_download_hosts=allowed_download_hosts, ) def _build_url(self, uri: str) -> str: return f"{self.base_url}{uri}" def _sign(self, uri: str) -> dict[str, str]: timestamp = str(int(time.time() * 1000)) nonce = str(uuid.uuid4()) content = "&".join((uri, timestamp, nonce)) digest = hmac.new(self.secret_key.encode(), content.encode(), sha1).digest() signature = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() return { "AccessKey": self.access_key, "Signature": signature, "Timestamp": timestamp, "SignatureNonce": nonce, } def submit_generation(self, prompt: str, size: str, style: str) -> dict: uri = "/api/generate/webui/text2img/ultra" url = self._build_url(uri) params = self._sign(uri) styled_prompt = _apply_style_to_prompt(prompt, style) aspect_ratio = _SIZE_TO_ASPECT_RATIO.get(size) generate_params: dict = { "prompt": styled_prompt, "imgCount": 1, "steps": 30, } if aspect_ratio: generate_params["aspectRatio"] = aspect_ratio else: w, h = _parse_size(size) generate_params["imageSize"] = {"width": w, "height": h} body = { "templateUuid": self.template_uuid, "generateParams": generate_params, } response = self.http_client.post( url, params=params, headers={"Content-Type": "application/json"}, json=body, ) response.raise_for_status() data = response.json() if data.get("code") != 0: raise RuntimeError(f"Liblib submit failed: {data.get('msg', data)}") generate_uuid = data["data"]["generateUuid"] return {"status": "processing", "job_id": generate_uuid, "image_url": None} def poll_until_complete( self, job: dict, poll_interval_seconds: int, max_attempts: int ) -> dict: uri = "/api/generate/webui/status" for attempt in range(max_attempts): url = self._build_url(uri) params = self._sign(uri) response = self.http_client.post( url, params=params, headers={"Content-Type": "application/json"}, json={"generateUuid": job["job_id"]}, ) response.raise_for_status() data = response.json() if data.get("code") != 0: raise RuntimeError( f"Liblib status query failed: {data.get('msg', data)}" ) result = data.get("data", {}) status = result.get("generateStatus") if status == 5: images = result.get("images") or [] if images: return { "status": "completed", "image_url": images[0]["imageUrl"], "job_id": job["job_id"], } raise RuntimeError( f"Liblib returned success but no images for {job['job_id']}" ) if status == 6: raise RuntimeError( f"Liblib generation failed: {result.get('generateMsg', 'unknown')}" ) if status == 7: raise TimeoutError( f"Liblib returned undocumented status 7 for {job['job_id']}" ) logger.debug( "Liblib poll attempt %d/%d, status=%s, job=%s", attempt + 1, max_attempts, status, job["job_id"], ) time.sleep(poll_interval_seconds) raise TimeoutError( f"Liblib image generation timed out after {max_attempts} attempts for {job['job_id']}" ) def download_image(self, job: dict) -> bytes: image_url = job["image_url"] _validate_download_url(image_url, self.allowed_download_hosts) response = self.http_client.get(image_url) response.raise_for_status() return response.content def download_image_from_url(self, image_url: str) -> bytes: """按 URL 下载图片(用于 port 的 download_image)。""" return self.download_image({"image_url": image_url}) def close(self) -> None: if self._owns_http_client: self.http_client.close() def __enter__(self) -> "LiblibImageProvider": return self def __exit__(self, exc_type, exc, tb) -> None: self.close() class _LiblibAuthRedactionFilter(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: record.msg = _redact_sensitive_query_values(record.msg) if record.args: if isinstance(record.args, dict): record.args = { k: _redact_sensitive_query_values(v) for k, v in record.args.items() } else: record.args = tuple( _redact_sensitive_query_values(v) for v in record.args ) return True def _redact_sensitive_query_values(value): if isinstance(value, str): return _SENSITIVE_QUERY_RE.sub(r"\1\2=[REDACTED]", value) return value def _install_http_log_redaction() -> None: for logger_name in ( "httpx", "httpcore", "httpcore.connection", "httpcore.http11", "httpcore.proxy", ): target_logger = get_logger(logger_name) if getattr(target_logger, "_liblib_auth_redaction_installed", False): continue target_logger.addFilter(_LiblibAuthRedactionFilter()) target_logger._liblib_auth_redaction_installed = True def _build_allowed_download_hosts( base_url: str, allowed_download_hosts: tuple[str, ...] | None = None, ) -> tuple[str, ...]: configured_hosts = allowed_download_hosts if configured_hosts is None: configured_hosts = tuple( host.strip().lower() for host in (settings.memoir_image_download_hosts or "").split(",") if host.strip() ) base_hostname = (urlparse(base_url).hostname or "").lower() default_hosts: set[str] = set() if base_hostname: default_hosts.add(base_hostname) if ( base_hostname.endswith(".liblibai.cloud") or base_hostname == "liblibai.cloud" ): default_hosts.add("liblibai.cloud") default_hosts.add("liblib.cloud") return tuple(sorted(default_hosts.union(configured_hosts))) def _validate_download_url(image_url: str, allowed_hosts: tuple[str, ...]) -> None: parsed = urlparse(image_url) hostname = (parsed.hostname or "").lower() if parsed.scheme != "https" or not hostname: raise ValueError(f"Unsupported image download URL: {image_url}") if not any( _hostname_matches(hostname, allowed_host) for allowed_host in allowed_hosts ): raise ValueError(f"Image download host is not allowed: {hostname}") def _hostname_matches(hostname: str, allowed_host: str) -> bool: normalized_allowed = allowed_host.strip().lower() if not normalized_allowed: return False return hostname == normalized_allowed or hostname.endswith(f".{normalized_allowed}") def _apply_style_to_prompt(prompt: str, style: str) -> str: cleaned_prompt = (prompt or "").strip() cleaned_style = (style or "").strip() if not cleaned_style: return cleaned_prompt if cleaned_style.lower() in cleaned_prompt.lower(): return cleaned_prompt if not cleaned_prompt: return cleaned_style return f"{cleaned_style}, {cleaned_prompt}"