"""
Liblib image-generation SDK wrapper for the adapters layer; its implementation
details are not exposed to features.

Features reach it through the ImageGenerator port; this module is only used by
app.adapters.image_gen.liblib.
"""

import base64
import hmac
import logging
import re
import time
import uuid
from hashlib import sha1
from urllib.parse import urlparse

import httpx

from app.core.config import settings
from app.core.logging import get_logger

logger = get_logger(__name__)

DEFAULT_LIBLIB_TEMPLATE_UUID = "5d7e67009b344550bc1aa6ccbfa1d7f4"

# Query parameters added by ``LiblibImageProvider._sign``. Their values are the
# request's authentication material and are scrubbed from httpx/httpcore logs.
_SENSITIVE_QUERY_PARAMS = ("AccessKey", "Signature", "Timestamp", "SignatureNonce")
_SENSITIVE_QUERY_RE = re.compile(
    r"([?&])(" + "|".join(_SENSITIVE_QUERY_PARAMS) + r")=([^&\s]+)"
)

# Sizes sent to Liblib as a preset ``aspectRatio``; every other size is sent as
# an explicit ``imageSize`` width/height pair (see ``_parse_size``).
_SIZE_TO_ASPECT_RATIO = {
    "1024x1024": "square",
    "768x1024": "portrait",
    "1280x720": "landscape",
}

_DEFAULT_WIDTH, _DEFAULT_HEIGHT = 1024, 1024

# stdlib loggers used by httpx / httpcore that get the auth-redaction filter.
# Listed individually because a filter attached to a logger is not applied to
# records propagated from its child loggers.
_HTTP_LOGGER_NAMES = (
    "httpx",
    "httpcore",
    "httpcore.connection",
    "httpcore.http11",
    "httpcore.http2",
    "httpcore.proxy",
)


def _parse_size(size: str) -> tuple[int, int]:
    """Parse a ``"WIDTHxHEIGHT"`` string, clamping each side to [512, 2048].

    Falls back to ``_DEFAULT_WIDTH`` x ``_DEFAULT_HEIGHT`` when ``size`` is not
    a string or cannot be parsed.
    """
    try:
        w_str, h_str = size.lower().split("x", 1)
        w = max(512, min(2048, int(w_str)))
        h = max(512, min(2048, int(h_str)))
        return w, h
    except (ValueError, AttributeError):
        return _DEFAULT_WIDTH, _DEFAULT_HEIGHT


class LiblibImageProvider:
    """Liblib (https://openapi.liblibai.cloud) image generation — adapter-layer implementation.

    Every API request is signed via ``_sign``. When no ``http_client`` is
    injected, the provider creates its own (120 s timeout), owns it, and closes
    it in ``close`` / on context-manager exit.
    """

    def __init__(
        self,
        http_client: httpx.Client | None = None,
        access_key: str | None = None,
        secret_key: str | None = None,
        base_url: str | None = None,
        template_uuid: str | None = None,
        allowed_download_hosts: tuple[str, ...] | None = None,
    ):
        """Configure the provider.

        ``access_key``, ``secret_key``, ``base_url`` and ``template_uuid`` fall
        back to the matching ``settings.liblib_*`` values when not given;
        ``base_url`` and ``template_uuid`` then fall back to built-in defaults.
        ``allowed_download_hosts`` is combined with the hosts derived from
        ``base_url`` (see ``_build_allowed_download_hosts``).
        """
        _install_http_log_redaction()
        self._owns_http_client = http_client is None
        self.http_client = http_client or httpx.Client(timeout=120)
        self.access_key = access_key or (settings.liblib_access_key or "")
        self.secret_key = secret_key or (settings.liblib_secret_key or "")
        self.base_url = (
            base_url or settings.liblib_base_url or "https://openapi.liblibai.cloud"
        ).rstrip("/")
        self.template_uuid = template_uuid or (
            settings.liblib_template_uuid or DEFAULT_LIBLIB_TEMPLATE_UUID
        )
        self.allowed_download_hosts = _build_allowed_download_hosts(
            self.base_url,
            allowed_download_hosts=allowed_download_hosts,
        )

    def _build_url(self, uri: str) -> str:
        """Join ``uri`` onto the configured base URL."""
        return f"{self.base_url}{uri}"

    def _sign(self, uri: str) -> dict[str, str]:
        """Return the signed query parameters for one request to ``uri``.

        The signature is the unpadded URL-safe base64 of
        HMAC-SHA1(secret_key, "uri&timestamp&nonce"), where the timestamp is
        epoch milliseconds and the nonce a random UUID4.
        """
        timestamp = str(int(time.time() * 1000))
        nonce = str(uuid.uuid4())
        content = "&".join((uri, timestamp, nonce))
        digest = hmac.new(self.secret_key.encode(), content.encode(), sha1).digest()
        signature = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
        return {
            "AccessKey": self.access_key,
            "Signature": signature,
            "Timestamp": timestamp,
            "SignatureNonce": nonce,
        }

    def submit_generation(self, prompt: str, size: str, style: str) -> dict:
        """Submit a text-to-image job and return a pending job dict.

        Args:
            prompt: User prompt; ``style`` is prepended unless already present.
            size: ``"WxH"``. Preset sizes map to an aspect ratio; other sizes
                are clamped to 512..2048 per side and sent explicitly.
            style: Style keywords; may be empty.

        Returns:
            ``{"status": "processing", "job_id": <generateUuid>, "image_url": None}``.

        Raises:
            httpx.HTTPStatusError: on a non-2xx HTTP response.
            RuntimeError: if Liblib reports a non-zero ``code`` or the response
                carries no ``generateUuid``.
        """
        uri = "/api/generate/webui/text2img/ultra"
        url = self._build_url(uri)
        params = self._sign(uri)
        styled_prompt = _apply_style_to_prompt(prompt, style)
        aspect_ratio = _SIZE_TO_ASPECT_RATIO.get(size)
        generate_params: dict = {
            "prompt": styled_prompt,
            "imgCount": 1,
            "steps": 30,
        }
        if aspect_ratio:
            generate_params["aspectRatio"] = aspect_ratio
        else:
            w, h = _parse_size(size)
            generate_params["imageSize"] = {"width": w, "height": h}
        body = {
            "templateUuid": self.template_uuid,
            "generateParams": generate_params,
        }
        response = self.http_client.post(
            url,
            params=params,
            headers={"Content-Type": "application/json"},
            json=body,
        )
        response.raise_for_status()
        data = response.json()
        if data.get("code") != 0:
            raise RuntimeError(f"Liblib submit failed: {data.get('msg', data)}")
        # Guard against a missing/null payload so callers get the module's
        # RuntimeError instead of a bare KeyError/TypeError.
        payload = data.get("data") or {}
        generate_uuid = payload.get("generateUuid") if isinstance(payload, dict) else None
        if not generate_uuid:
            raise RuntimeError(f"Liblib submit returned no generateUuid: {data}")
        return {"status": "processing", "job_id": generate_uuid, "image_url": None}

    def poll_until_complete(
        self, job: dict, poll_interval_seconds: int, max_attempts: int
    ) -> dict:
        """Poll the job status until it finishes, fails, or attempts run out.

        Args:
            job: Dict with a ``"job_id"`` (as returned by ``submit_generation``).
            poll_interval_seconds: Delay between status queries.
            max_attempts: Maximum number of status queries.

        Returns:
            ``{"status": "completed", "image_url": <url>, "job_id": <id>}``.

        Raises:
            httpx.HTTPStatusError: on a non-2xx HTTP response.
            RuntimeError: on a non-zero API ``code``, ``generateStatus`` 6, or
                ``generateStatus`` 5 without an image URL.
            TimeoutError: on ``generateStatus`` 7 or after ``max_attempts``
                queries without a terminal status.
        """
        uri = "/api/generate/webui/status"
        url = self._build_url(uri)
        job_id = job["job_id"]
        for attempt in range(max_attempts):
            # Re-sign each request so every poll carries a fresh timestamp/nonce.
            params = self._sign(uri)
            response = self.http_client.post(
                url,
                params=params,
                headers={"Content-Type": "application/json"},
                json={"generateUuid": job_id},
            )
            response.raise_for_status()
            data = response.json()
            if data.get("code") != 0:
                raise RuntimeError(
                    f"Liblib status query failed: {data.get('msg', data)}"
                )
            # ``or {}`` also covers a ``"data": null`` payload.
            result = data.get("data") or {}
            status = result.get("generateStatus")
            if status == 5:
                images = result.get("images") or []
                image_url = (images[0] or {}).get("imageUrl") if images else None
                if image_url:
                    return {
                        "status": "completed",
                        "image_url": image_url,
                        "job_id": job_id,
                    }
                raise RuntimeError(
                    f"Liblib returned success but no images for {job_id}"
                )
            if status == 6:
                raise RuntimeError(
                    f"Liblib generation failed: {result.get('generateMsg', 'unknown')}"
                )
            if status == 7:
                raise TimeoutError(
                    f"Liblib returned undocumented status 7 for {job_id}"
                )
            logger.debug(
                "Liblib poll attempt {}/{}, status={}, job={}",
                attempt + 1,
                max_attempts,
                status,
                job_id,
            )
            # Sleeping after the final attempt would only delay the TimeoutError.
            if attempt + 1 < max_attempts:
                time.sleep(poll_interval_seconds)
        raise TimeoutError(
            f"Liblib image generation timed out after {max_attempts} attempts for {job_id}"
        )

    def download_image(self, job: dict) -> bytes:
        """Download ``job["image_url"]`` after checking it against the host allow-list.

        Raises:
            ValueError: if the URL is not https or its host is not allowed.
            httpx.HTTPStatusError: on a non-2xx response.
        """
        image_url = job["image_url"]
        _validate_download_url(image_url, self.allowed_download_hosts)
        # Never follow redirects: only this exact URL was vetted, so an injected
        # client configured with follow_redirects=True must not be able to
        # fetch from a host outside the allow-list.
        response = self.http_client.get(image_url, follow_redirects=False)
        response.raise_for_status()
        return response.content

    def download_image_from_url(self, image_url: str) -> bytes:
        """Download an image by URL (backs the port's ``download_image``)."""
        return self.download_image({"image_url": image_url})

    def close(self) -> None:
        """Close the HTTP client, but only if this provider created it."""
        if self._owns_http_client:
            self.http_client.close()

    def __enter__(self) -> "LiblibImageProvider":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()


class _LiblibAuthRedactionFilter(logging.Filter):
    """stdlib logging filter that scrubs Liblib auth query values from records."""

    def filter(self, record: logging.LogRecord) -> bool:
        """Redact ``record.msg`` and ``record.args`` in place; never drops records."""
        record.msg = _redact_sensitive_query_values(record.msg)
        if record.args:
            if isinstance(record.args, dict):
                record.args = {
                    k: _redact_sensitive_query_values(v)
                    for k, v in record.args.items()
                }
            else:
                record.args = tuple(
                    _redact_sensitive_query_values(v) for v in record.args
                )
        return True


def _redact_sensitive_query_values(value):
    """Return ``value`` with Liblib auth query values replaced by ``[REDACTED]``.

    Strings are scrubbed directly. Other objects (such as an ``httpx.URL``
    passed as a logging argument) are scrubbed through their ``str()``: if that
    text contains a sensitive parameter, the redacted string is returned in
    their place; otherwise the original object is returned untouched so that
    ``%d``-style formatting of numeric arguments keeps working.
    """
    if isinstance(value, str):
        return _SENSITIVE_QUERY_RE.sub(r"\1\2=[REDACTED]", value)
    if value is None or isinstance(value, (int, float)):
        return value
    try:
        text = str(value)
    except Exception:  # a broken __str__ must never break logging
        return value
    redacted = _SENSITIVE_QUERY_RE.sub(r"\1\2=[REDACTED]", text)
    return value if redacted == text else redacted


def _install_http_log_redaction() -> None:
    """Attach the redaction filter to the httpx/httpcore stdlib loggers.

    These are standard-library loggers, not the loguru-style ``get_logger``.
    Idempotent: a marker attribute on each logger prevents adding the filter twice.
    """
    for logger_name in _HTTP_LOGGER_NAMES:
        target_logger = logging.getLogger(logger_name)
        if getattr(target_logger, "_liblib_auth_redaction_installed", False):
            continue
        target_logger.addFilter(_LiblibAuthRedactionFilter())
        target_logger._liblib_auth_redaction_installed = True


def _build_allowed_download_hosts(
    base_url: str,
    allowed_download_hosts: tuple[str, ...] | None = None,
) -> tuple[str, ...]:
    """Return the sorted, de-duplicated hosts images may be downloaded from.

    Combines the ``base_url`` hostname (plus ``liblibai.cloud`` and
    ``liblib.cloud`` when the base URL is on ``liblibai.cloud``) with either
    ``allowed_download_hosts`` or, when that is ``None``, the comma-separated
    ``settings.memoir_image_download_hosts``. All entries are stripped and
    lower-cased; blank entries are dropped.
    """
    if allowed_download_hosts is None:
        raw_hosts = (settings.memoir_image_download_hosts or "").split(",")
    else:
        raw_hosts = allowed_download_hosts
    # Normalize both sources identically so explicit and configured hosts agree.
    configured_hosts = {
        host.strip().lower() for host in raw_hosts if host and host.strip()
    }

    base_hostname = (urlparse(base_url).hostname or "").lower()
    default_hosts: set[str] = set()
    if base_hostname:
        default_hosts.add(base_hostname)
    if base_hostname.endswith(".liblibai.cloud") or base_hostname == "liblibai.cloud":
        default_hosts.add("liblibai.cloud")
        default_hosts.add("liblib.cloud")
    return tuple(sorted(default_hosts | configured_hosts))


def _validate_download_url(image_url: str, allowed_hosts: tuple[str, ...]) -> None:
    """Raise ``ValueError`` unless ``image_url`` is https on an allowed host.

    A host is allowed when it equals an entry of ``allowed_hosts`` or is a
    subdomain of one (see ``_hostname_matches``).
    """
    if not isinstance(image_url, str) or not image_url:
        raise ValueError(f"Unsupported image download URL: {image_url}")
    parsed = urlparse(image_url)
    hostname = (parsed.hostname or "").lower()
    if parsed.scheme != "https" or not hostname:
        raise ValueError(f"Unsupported image download URL: {image_url}")
    if not any(
        _hostname_matches(hostname, allowed_host) for allowed_host in allowed_hosts
    ):
        raise ValueError(f"Image download host is not allowed: {hostname}")


def _hostname_matches(hostname: str, allowed_host: str) -> bool:
    """True if ``hostname`` equals ``allowed_host`` or is a subdomain of it.

    The leading-dot check prevents ``evilliblibai.cloud`` from matching
    ``liblibai.cloud``; a blank ``allowed_host`` never matches.
    """
    normalized_allowed = allowed_host.strip().lower()
    if not normalized_allowed:
        return False
    return hostname == normalized_allowed or hostname.endswith(f".{normalized_allowed}")


def _apply_style_to_prompt(prompt: str, style: str) -> str:
    """Prefix ``style`` to ``prompt`` as ``"style, prompt"``.

    The style is skipped when it is blank or already appears in the prompt
    (case-insensitively). If the prompt is blank, the style alone is returned.
    """
    cleaned_prompt = (prompt or "").strip()
    cleaned_style = (style or "").strip()
    if not cleaned_style:
        return cleaned_prompt
    if cleaned_style.lower() in cleaned_prompt.lower():
        return cleaned_prompt
    if not cleaned_prompt:
        return cleaned_style
    return f"{cleaned_style}, {cleaned_prompt}"