Files
life-echo/api/services/memoir_images/provider.py
2026-03-11 11:26:42 +08:00

261 lines
9.1 KiB
Python

import base64
import hmac
import logging
import os
import re
import time
import uuid
from hashlib import sha1
from urllib.parse import urlparse
import httpx
from .settings import DEFAULT_LIBLIB_TEMPLATE_UUID
logger = logging.getLogger(__name__)
_SENSITIVE_QUERY_PARAMS = ("AccessKey", "Signature", "Timestamp", "SignatureNonce")
_SENSITIVE_QUERY_RE = re.compile(
r"([?&])(" + "|".join(_SENSITIVE_QUERY_PARAMS) + r")=([^&\s]+)"
)
_SIZE_TO_ASPECT_RATIO = {
"1024x1024": "square",
"768x1024": "portrait",
"1024x768": "landscape",
"1280x720": "landscape",
"720x1280": "portrait",
}
class LiblibImageProvider:
    """Liblib (https://openapi.liblibai.cloud) image generation adapter.

    Exposes three steps: submit a text-to-image job, poll its status until
    it finishes, and download the resulting image. Every API request is
    authenticated with signed query parameters built by ``_sign``.

    The instance can be used as a context manager; on exit the HTTP client
    is closed only if this instance created it.
    """

    def __init__(
        self,
        http_client: httpx.Client | None = None,
        access_key: str | None = None,
        secret_key: str | None = None,
        base_url: str | None = None,
        template_uuid: str | None = None,
        allowed_download_hosts: tuple[str, ...] | None = None,
    ):
        """Configure the provider from arguments, falling back to the environment.

        Args:
            http_client: Client to use for all requests. When omitted, a
                private ``httpx.Client(timeout=120)`` is created and later
                closed by ``close()``.
            access_key: Falls back to ``LIBLIB_ACCESS_KEY``, then ``""``.
            secret_key: Falls back to ``LIBLIB_SECRET_KEY``, then ``""``.
            base_url: Falls back to ``LIBLIB_BASE_URL``, then the public
                Liblib endpoint. Trailing slashes are stripped.
            template_uuid: Falls back to ``LIBLIB_TEMPLATE_UUID``, then
                ``DEFAULT_LIBLIB_TEMPLATE_UUID``.
            allowed_download_hosts: Extra hosts images may be downloaded
                from, merged with defaults derived from ``base_url``.
        """
        # Make sure signed query strings are masked in HTTP client logs
        # before any request can be issued.
        _install_http_log_redaction()
        self._owns_http_client = http_client is None
        self.http_client = http_client or httpx.Client(timeout=120)
        self.access_key = access_key or os.getenv("LIBLIB_ACCESS_KEY", "")
        self.secret_key = secret_key or os.getenv("LIBLIB_SECRET_KEY", "")
        self.base_url = (
            base_url or os.getenv("LIBLIB_BASE_URL", "https://openapi.liblibai.cloud")
        ).rstrip("/")
        self.template_uuid = (
            template_uuid
            or os.getenv("LIBLIB_TEMPLATE_UUID")
            or DEFAULT_LIBLIB_TEMPLATE_UUID
        )
        self.allowed_download_hosts = _build_allowed_download_hosts(
            self.base_url,
            allowed_download_hosts=allowed_download_hosts,
        )

    # ------------------------------------------------------------------
    # Signature helpers
    # ------------------------------------------------------------------

    def _build_url(self, uri: str) -> str:
        """Return the absolute URL for an API path such as ``/api/...``."""
        return f"{self.base_url}{uri}"

    def _sign(self, uri: str) -> dict[str, str]:
        """Build Liblib HMAC-SHA1 query-string auth params.

        The signed content is ``uri&timestamp&nonce`` where the timestamp is
        the current time in milliseconds and the nonce is a fresh UUID4. The
        digest is URL-safe base64 encoded with ``=`` padding removed.

        Returns:
            The ``AccessKey``, ``Signature``, ``Timestamp`` and
            ``SignatureNonce`` query parameters for one request.
        """
        timestamp = str(int(time.time() * 1000))
        nonce = str(uuid.uuid4())
        content = "&".join((uri, timestamp, nonce))
        digest = hmac.new(self.secret_key.encode(), content.encode(), sha1).digest()
        signature = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
        return {
            "AccessKey": self.access_key,
            "Signature": signature,
            "Timestamp": timestamp,
            "SignatureNonce": nonce,
        }

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def submit_generation(self, prompt: str, size: str, style: str) -> dict:
        """Submit a text-to-image job and return a job descriptor.

        Args:
            prompt: Text prompt; ``style`` is prepended when not already
                contained in it.
            size: ``"WIDTHxHEIGHT"`` string; unknown sizes fall back to the
                ``"square"`` aspect ratio.
            style: Style keyword(s) to apply to the prompt.

        Returns:
            ``{"status": "processing", "job_id": <generateUuid>, "image_url": None}``.

        Raises:
            httpx.HTTPStatusError: On a non-success HTTP status.
            RuntimeError: If the response body's ``code`` is not ``0``.
        """
        uri = "/api/generate/webui/text2img/ultra"
        url = self._build_url(uri)
        params = self._sign(uri)
        styled_prompt = _apply_style_to_prompt(prompt, style)
        aspect_ratio = _SIZE_TO_ASPECT_RATIO.get(size, "square")
        body = {
            "templateUuid": self.template_uuid,
            "generateParams": {
                "prompt": styled_prompt,
                "aspectRatio": aspect_ratio,
                "imgCount": 1,
                "steps": 30,
            },
        }
        response = self.http_client.post(
            url,
            params=params,
            headers={"Content-Type": "application/json"},
            json=body,
        )
        response.raise_for_status()
        data = response.json()
        if data.get("code") != 0:
            raise RuntimeError(f"Liblib submit failed: {data.get('msg', data)}")
        generate_uuid = data["data"]["generateUuid"]
        return {"status": "processing", "job_id": generate_uuid, "image_url": None}

    def poll_until_complete(
        self, job: dict, poll_interval_seconds: int, max_attempts: int
    ) -> dict:
        """Poll the job status until it succeeds, fails, or attempts run out.

        Args:
            job: Job descriptor containing ``"job_id"``.
            poll_interval_seconds: Delay between consecutive status queries.
            max_attempts: Maximum number of status queries to issue.

        Returns:
            ``{"status": "completed", "image_url": ..., "job_id": ...}`` using
            the first image in the response.

        Raises:
            httpx.HTTPStatusError: On a non-success HTTP status.
            RuntimeError: If the API reports an error code, a failed
                generation, or success without any image.
            TimeoutError: If Liblib reports a server-side timeout or the
                job is still unfinished after ``max_attempts`` queries.
        """
        uri = "/api/generate/webui/status"
        url = self._build_url(uri)
        job_id = job["job_id"]
        for attempt in range(max_attempts):
            # Re-sign on every attempt: each request needs a fresh
            # timestamp and nonce.
            params = self._sign(uri)
            response = self.http_client.post(
                url,
                params=params,
                headers={"Content-Type": "application/json"},
                json={"generateUuid": job_id},
            )
            response.raise_for_status()
            data = response.json()
            if data.get("code") != 0:
                raise RuntimeError(f"Liblib status query failed: {data.get('msg', data)}")
            # ``"data": null`` yields None from .get() even with a default,
            # so coerce it to an empty dict and treat it as "not ready yet".
            result = data.get("data") or {}
            status = result.get("generateStatus")
            if status == 5:  # success
                images = result.get("images") or []
                if images:
                    return {
                        "status": "completed",
                        "image_url": images[0]["imageUrl"],
                        "job_id": job_id,
                    }
                raise RuntimeError(f"Liblib returned success but no images for {job_id}")
            if status == 6:  # failed
                raise RuntimeError(
                    f"Liblib generation failed: {result.get('generateMsg', 'unknown')}"
                )
            if status == 7:  # timeout on Liblib side
                raise TimeoutError(f"Liblib generation timed out on server side: {job_id}")
            logger.debug(
                "Liblib poll attempt %d/%d, status=%s, job=%s",
                attempt + 1, max_attempts, status, job_id,
            )
            # No point waiting after the last query; fail immediately.
            if attempt + 1 < max_attempts:
                time.sleep(poll_interval_seconds)
        raise TimeoutError(
            f"Liblib image generation timed out after {max_attempts} attempts for {job_id}"
        )

    def download_image(self, job: dict) -> bytes:
        """Download the generated image and return its raw bytes.

        Args:
            job: Job descriptor containing ``"image_url"``.

        Raises:
            ValueError: If the URL is not HTTPS or its host is not allowed.
            httpx.HTTPStatusError: On a non-success HTTP status.
        """
        image_url = job["image_url"]
        _validate_download_url(image_url, self.allowed_download_hosts)
        # NOTE(review): the allowlist check covers only this URL. If an
        # injected client is configured to follow redirects, later hops are
        # not validated here; confirm the client configuration.
        response = self.http_client.get(image_url)
        response.raise_for_status()
        return response.content

    def close(self) -> None:
        """Close the HTTP client if it was created by this instance."""
        if self._owns_http_client:
            self.http_client.close()

    def __enter__(self) -> "LiblibImageProvider":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()
class _LiblibAuthRedactionFilter(logging.Filter):
    """Logging filter that masks Liblib auth query values in log records."""

    def filter(self, record: logging.LogRecord) -> bool:
        """Redact the record's message and arguments in place.

        Always returns True: records are rewritten, never dropped.
        """
        scrub = _redact_sensitive_query_values
        record.msg = scrub(record.msg)
        args = record.args
        if not args:
            return True
        if isinstance(args, dict):
            # Mapping-style arguments: keep the keys, scrub the values.
            record.args = {name: scrub(item) for name, item in args.items()}
        else:
            record.args = tuple(scrub(item) for item in args)
        return True
def _redact_sensitive_query_values(value):
if isinstance(value, str):
return _SENSITIVE_QUERY_RE.sub(r"\1\2=[REDACTED]", value)
return value
def _install_http_log_redaction() -> None:
    """Attach the auth-redaction filter to the httpx/httpcore loggers.

    Safe to call repeatedly: each logger object is flagged once the filter
    has been added and is skipped on later calls.
    """
    marker = "_liblib_auth_redaction_installed"
    logger_names = (
        "httpx",
        "httpcore",
        "httpcore.connection",
        "httpcore.http11",
        "httpcore.proxy",
    )
    for name in logger_names:
        http_logger = logging.getLogger(name)
        if not getattr(http_logger, marker, False):
            http_logger.addFilter(_LiblibAuthRedactionFilter())
            setattr(http_logger, marker, True)
def _build_allowed_download_hosts(
base_url: str,
allowed_download_hosts: tuple[str, ...] | None = None,
) -> tuple[str, ...]:
configured_hosts = allowed_download_hosts
if configured_hosts is None:
configured_hosts = tuple(
host.strip().lower()
for host in os.getenv("MEMOIR_IMAGE_DOWNLOAD_HOSTS", "").split(",")
if host.strip()
)
base_hostname = (urlparse(base_url).hostname or "").lower()
default_hosts: set[str] = set()
if base_hostname:
default_hosts.add(base_hostname)
if base_hostname.endswith(".liblibai.cloud") or base_hostname == "liblibai.cloud":
default_hosts.add("liblibai.cloud")
return tuple(sorted(default_hosts.union(configured_hosts)))
def _validate_download_url(image_url: str, allowed_hosts: tuple[str, ...]) -> None:
parsed = urlparse(image_url)
hostname = (parsed.hostname or "").lower()
if parsed.scheme != "https" or not hostname:
raise ValueError(f"Unsupported image download URL: {image_url}")
if not any(_hostname_matches(hostname, allowed_host) for allowed_host in allowed_hosts):
raise ValueError(f"Image download host is not allowed: {hostname}")
def _hostname_matches(hostname: str, allowed_host: str) -> bool:
normalized_allowed = allowed_host.strip().lower()
if not normalized_allowed:
return False
return hostname == normalized_allowed or hostname.endswith(f".{normalized_allowed}")
def _apply_style_to_prompt(prompt: str, style: str) -> str:
cleaned_prompt = (prompt or "").strip()
cleaned_style = (style or "").strip()
if not cleaned_style:
return cleaned_prompt
if cleaned_style.lower() in cleaned_prompt.lower():
return cleaned_prompt
if not cleaned_prompt:
return cleaned_style
return f"{cleaned_style}, {cleaned_prompt}"