Files
life-echo/api/services/memoir_images/provider.py
Kevin 201dedb84c fix: 完善 Liblib 图片尺寸兼容处理并补充安全排查记录
- 调整 Liblib 图片生成参数,优先使用官方 aspectRatio 预设,非预设尺寸回退为 imageSize 显式宽高。
- 新增尺寸解析与边界钳制逻辑,并补充对 undocumented 状态 7 的防御性处理说明。
- 新增密钥排查备忘,记录 .env 中腾讯 COS 凭证硬编码问题,便于后续安全整改。
2026-03-11 15:36:58 +08:00

288 lines
10 KiB
Python

import base64
import hmac
import logging
import os
import re
import time
import uuid
from hashlib import sha1
from urllib.parse import urlparse
import httpx
from .settings import DEFAULT_LIBLIB_TEMPLATE_UUID
logger = logging.getLogger(__name__)

# Query-string auth parameters produced by LiblibImageProvider._sign; their
# values must be masked before any log line is emitted.
_SENSITIVE_QUERY_PARAMS = ("AccessKey", "Signature", "Timestamp", "SignatureNonce")
# Groups: 1 = the "?" or "&" separator, 2 = parameter name, 3 = its value
# (everything up to the next "&" or whitespace).
_SENSITIVE_QUERY_RE = re.compile(
    r"([?&])({names})=([^&\s]+)".format(names="|".join(_SENSITIVE_QUERY_PARAMS))
)

# LibLib Star-3 Alpha ``aspectRatio`` presets (per the official docs):
#   "square"    -> 1:1,  1024x1024
#   "portrait"  -> 3:4,  768x1024
#   "landscape" -> 16:9, 1280x720
# Any other "WxH" size is sent as an explicit ``imageSize`` instead
# (see LiblibImageProvider.submit_generation).
_SIZE_TO_ASPECT_RATIO = dict(
    [
        ("1024x1024", "square"),
        ("768x1024", "portrait"),
        ("1280x720", "landscape"),
    ]
)

# Fallback dimensions used when a size string cannot be parsed.
_DEFAULT_WIDTH = _DEFAULT_HEIGHT = 1024
def _parse_size(size: str) -> tuple[int, int]:
    """Turn a ``"WxH"`` string into a ``(width, height)`` pair.

    Parsing is case-insensitive ("1024X768" works) and each dimension is
    clamped into the 512..2048 range. Input that cannot be parsed (wrong shape,
    non-integer parts, or a non-string value) yields
    ``(_DEFAULT_WIDTH, _DEFAULT_HEIGHT)``.
    """

    def clamp(value: int) -> int:
        return min(max(value, 512), 2048)

    try:
        raw_width, raw_height = size.lower().split("x", 1)
        width, height = int(raw_width), int(raw_height)
    except (AttributeError, ValueError):
        return _DEFAULT_WIDTH, _DEFAULT_HEIGHT
    return clamp(width), clamp(height)
class LiblibImageProvider:
    """Liblib (https://openapi.liblibai.cloud) image generation adapter.

    Typical flow: ``submit_generation`` starts a text-to-image job,
    ``poll_until_complete`` waits for it to finish, and ``download_image``
    fetches the image bytes from an allow-listed HTTPS host.

    Constructor arguments take precedence over the ``LIBLIB_ACCESS_KEY``,
    ``LIBLIB_SECRET_KEY``, ``LIBLIB_BASE_URL`` and ``LIBLIB_TEMPLATE_UUID``
    environment variables. The provider can be used as a context manager; it
    only closes an ``httpx.Client`` that it created itself.
    """

    def __init__(
        self,
        http_client: httpx.Client | None = None,
        access_key: str | None = None,
        secret_key: str | None = None,
        base_url: str | None = None,
        template_uuid: str | None = None,
        allowed_download_hosts: tuple[str, ...] | None = None,
    ):
        """Configure credentials, endpoint and the download allow-list.

        Args:
            http_client: Client to reuse. When omitted, a private client with
                ``timeout=120`` is created and later closed by ``close``.
            access_key: Liblib access key (fallback: ``LIBLIB_ACCESS_KEY``).
            secret_key: Liblib secret key used by ``_sign``
                (fallback: ``LIBLIB_SECRET_KEY``).
            base_url: API root; trailing slashes are removed. An unset *or
                empty* ``LIBLIB_BASE_URL`` falls back to the public endpoint.
            template_uuid: Generation template (fallback: the
                ``LIBLIB_TEMPLATE_UUID`` env var, then
                ``DEFAULT_LIBLIB_TEMPLATE_UUID``).
            allowed_download_hosts: Extra hosts ``download_image`` may fetch from.
        """
        # Keep the signed query string (AccessKey/Signature/...) out of
        # httpx/httpcore log output.
        _install_http_log_redaction()
        self._owns_http_client = http_client is None
        self.http_client = http_client or httpx.Client(timeout=120)
        self.access_key = access_key or os.getenv("LIBLIB_ACCESS_KEY", "")
        self.secret_key = secret_key or os.getenv("LIBLIB_SECRET_KEY", "")
        # ``or`` chain (rather than a getenv default) so that an empty env var
        # also falls back, matching how template_uuid is resolved below.
        self.base_url = (
            base_url
            or os.getenv("LIBLIB_BASE_URL")
            or "https://openapi.liblibai.cloud"
        ).rstrip("/")
        self.template_uuid = (
            template_uuid
            or os.getenv("LIBLIB_TEMPLATE_UUID")
            or DEFAULT_LIBLIB_TEMPLATE_UUID
        )
        self.allowed_download_hosts = _build_allowed_download_hosts(
            self.base_url,
            allowed_download_hosts=allowed_download_hosts,
        )

    # ------------------------------------------------------------------
    # Signature helpers
    # ------------------------------------------------------------------
    def _build_url(self, uri: str) -> str:
        """Return the absolute URL for an API path such as ``/api/...``."""
        return f"{self.base_url}{uri}"

    def _sign(self, uri: str) -> dict[str, str]:
        """Build Liblib HMAC-SHA1 query-string auth params for *uri*.

        The signed text is ``"<uri>&<timestamp_ms>&<nonce>"``. The HMAC-SHA1
        digest, keyed with ``secret_key``, is URL-safe base64 encoded with the
        ``=`` padding stripped. A new millisecond timestamp and UUID4 nonce are
        generated on every call.
        """
        timestamp = str(int(time.time() * 1000))
        nonce = str(uuid.uuid4())
        content = "&".join((uri, timestamp, nonce))
        digest = hmac.new(self.secret_key.encode(), content.encode(), sha1).digest()
        signature = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
        return {
            "AccessKey": self.access_key,
            "Signature": signature,
            "Timestamp": timestamp,
            "SignatureNonce": nonce,
        }

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def submit_generation(self, prompt: str, size: str, style: str) -> dict:
        """Submit a text-to-image job and return a pending job descriptor.

        Args:
            prompt: Free-text prompt; *style* is prepended unless already present.
            size: ``"WxH"``. Sizes in ``_SIZE_TO_ASPECT_RATIO`` (compared after
                stripping and lower-casing) are sent as an ``aspectRatio`` preset;
                anything else is sent as an explicit, clamped ``imageSize``.
            style: Style hint merged into the prompt.

        Returns:
            ``{"status": "processing", "job_id": <generateUuid>, "image_url": None}``

        Raises:
            httpx.HTTPStatusError: if the HTTP response status is not a success.
            RuntimeError: if the API answers with a non-zero ``code``.
        """
        uri = "/api/generate/webui/text2img/ultra"
        generate_params: dict = {
            "prompt": _apply_style_to_prompt(prompt, style),
            "imgCount": 1,
            "steps": 30,
        }
        # Normalise the preset key the same way _parse_size normalises its
        # input, so e.g. "1024X1024" still maps to the official "square" preset.
        size_key = size.strip().lower() if isinstance(size, str) else size
        aspect_ratio = _SIZE_TO_ASPECT_RATIO.get(size_key)
        if aspect_ratio:
            generate_params["aspectRatio"] = aspect_ratio
        else:
            # Not a preset ratio: pass explicit imageSize (width/height 512~2048).
            width, height = _parse_size(size)
            generate_params["imageSize"] = {"width": width, "height": height}
        body = {
            "templateUuid": self.template_uuid,
            "generateParams": generate_params,
        }
        response = self.http_client.post(
            self._build_url(uri),
            params=self._sign(uri),
            headers={"Content-Type": "application/json"},
            json=body,
        )
        response.raise_for_status()
        data = response.json()
        if data.get("code") != 0:
            raise RuntimeError(f"Liblib submit failed: {data.get('msg', data)}")
        generate_uuid = data["data"]["generateUuid"]
        return {"status": "processing", "job_id": generate_uuid, "image_url": None}

    def poll_until_complete(
        self, job: dict, poll_interval_seconds: int, max_attempts: int
    ) -> dict:
        """Poll the job status until it finishes, fails, or attempts run out.

        Args:
            job: Descriptor returned by ``submit_generation`` (needs ``job_id``).
            poll_interval_seconds: Sleep between consecutive status queries.
            max_attempts: Maximum number of status queries.

        Returns:
            ``{"status": "completed", "image_url": <first image URL>, "job_id": ...}``
            once ``generateStatus`` is 5.

        Raises:
            RuntimeError: non-zero API ``code``, status 6 (failed), or status 5
                without any images.
            TimeoutError: status 7, or still unfinished after *max_attempts*.
            httpx.HTTPStatusError: if an HTTP response status is not a success.
        """
        uri = "/api/generate/webui/status"
        url = self._build_url(uri)
        job_id = job["job_id"]
        for attempt in range(1, max_attempts + 1):
            response = self.http_client.post(
                url,
                params=self._sign(uri),  # fresh timestamp/nonce per request
                headers={"Content-Type": "application/json"},
                json={"generateUuid": job_id},
            )
            response.raise_for_status()
            data = response.json()
            if data.get("code") != 0:
                raise RuntimeError(f"Liblib status query failed: {data.get('msg', data)}")
            # "data" may be present but null; treat that as "no result yet".
            result = data.get("data") or {}
            status = result.get("generateStatus")
            if status == 5:  # success
                images = result.get("images") or []
                if images:
                    return {
                        "status": "completed",
                        "image_url": images[0]["imageUrl"],
                        "job_id": job_id,
                    }
                raise RuntimeError(f"Liblib returned success but no images for {job['job_id']}")
            if status == 6:  # failed
                raise RuntimeError(f"Liblib generation failed: {result.get('generateMsg', 'unknown')}")
            # Per the original note, the official docs list statuses 1-6 only.
            # Status 7 is treated as a terminal timeout; every other status
            # (including a missing one) keeps polling.
            if status == 7:
                raise TimeoutError(f"Liblib returned undocumented status 7 for {job['job_id']}")
            logger.debug(
                "Liblib poll attempt %d/%d, status=%s, job=%s",
                attempt, max_attempts, status, job_id,
            )
            # No point sleeping after the final attempt: we are about to give up.
            if attempt < max_attempts:
                time.sleep(poll_interval_seconds)
        raise TimeoutError(f"Liblib image generation timed out after {max_attempts} attempts for {job['job_id']}")

    def download_image(self, job: dict) -> bytes:
        """Download the finished image referenced by ``job["image_url"]``.

        Raises:
            KeyError: if *job* has no ``image_url`` key.
            ValueError: if ``image_url`` is empty/None (job not completed yet),
                or if ``_validate_download_url`` rejects it.
            httpx.HTTPStatusError: if the HTTP response status is not a success.
        """
        image_url = job["image_url"]
        if not image_url:
            # submit_generation returns image_url=None; fail with a clear message
            # instead of validating an empty URL.
            raise ValueError(
                f"Liblib job {job.get('job_id')} has no image_url; poll it to completion first"
            )
        _validate_download_url(image_url, self.allowed_download_hosts)
        response = self.http_client.get(image_url)
        response.raise_for_status()
        return response.content

    def close(self) -> None:
        """Close the HTTP client, but only if this provider created it."""
        if self._owns_http_client:
            self.http_client.close()

    def __enter__(self) -> "LiblibImageProvider":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()
class _LiblibAuthRedactionFilter(logging.Filter):
    """Logging filter that scrubs Liblib auth query parameters from records.

    ``record.msg`` and every positional or mapping argument are passed through
    ``_redact_sensitive_query_values``. The filter never suppresses a record
    (it always returns True).
    """

    def filter(self, record: logging.LogRecord) -> bool:
        scrub = _redact_sensitive_query_values
        record.msg = scrub(record.msg)
        args = record.args
        if not args:
            return True
        if isinstance(args, dict):
            # Mapping-style args, used by "%(name)s" format strings.
            record.args = {name: scrub(arg) for name, arg in args.items()}
        else:
            record.args = tuple(map(scrub, args))
        return True
def _redact_sensitive_query_values(value):
    """Mask Liblib auth query parameters in a log message or log argument.

    Strings are rewritten so that the values of ``AccessKey``, ``Signature``,
    ``Timestamp`` and ``SignatureNonce`` become ``[REDACTED]``.

    Non-string objects are checked through their ``str()`` form. This matters
    because httpx's "HTTP Request: ..." log line passes the request URL as an
    ``httpx.URL`` object, not a ``str``. When that text contains a sensitive
    parameter, the redacted *string* is returned in place of the object.
    Otherwise the original object is returned unchanged, so numeric
    placeholders such as ``%d`` keep receiving numbers.
    """
    if isinstance(value, str):
        return _SENSITIVE_QUERY_RE.sub(r"\1\2=[REDACTED]", value)
    # Fast path for values that can never carry a query string (bytes are left
    # alone so they are not turned into their "b'...'" repr).
    if value is None or isinstance(value, (bool, int, float, bytes)):
        return value
    try:
        text = str(value)
    except Exception:  # a broken __str__ must never break logging
        return value
    redacted = _SENSITIVE_QUERY_RE.sub(r"\1\2=[REDACTED]", text)
    return redacted if redacted != text else value
def _install_http_log_redaction() -> None:
    """Attach one ``_LiblibAuthRedactionFilter`` to each httpx/httpcore logger.

    Each logger gets a marker attribute once the filter is attached, so calling
    this repeatedly (e.g. once per provider instance) does not stack filters.
    Child logger names are listed explicitly: logger-level filters are not
    applied to records that propagate up from descendant loggers.
    """
    marker = "_liblib_auth_redaction_installed"
    names = (
        "httpx",
        "httpcore",
        "httpcore.connection",
        "httpcore.http11",
        "httpcore.proxy",
    )
    for target in map(logging.getLogger, names):
        if not getattr(target, marker, False):
            target.addFilter(_LiblibAuthRedactionFilter())
            setattr(target, marker, True)
def _build_allowed_download_hosts(
base_url: str,
allowed_download_hosts: tuple[str, ...] | None = None,
) -> tuple[str, ...]:
configured_hosts = allowed_download_hosts
if configured_hosts is None:
configured_hosts = tuple(
host.strip().lower()
for host in os.getenv("MEMOIR_IMAGE_DOWNLOAD_HOSTS", "").split(",")
if host.strip()
)
base_hostname = (urlparse(base_url).hostname or "").lower()
default_hosts: set[str] = set()
if base_hostname:
default_hosts.add(base_hostname)
if base_hostname.endswith(".liblibai.cloud") or base_hostname == "liblibai.cloud":
default_hosts.add("liblibai.cloud")
# Liblib returns generated image downloads from *.liblib.cloud.
default_hosts.add("liblib.cloud")
return tuple(sorted(default_hosts.union(configured_hosts)))
def _validate_download_url(image_url: str, allowed_hosts: tuple[str, ...]) -> None:
    """Reject image URLs that are not HTTPS or whose host is not allow-listed.

    The hostname is lower-cased before matching. A host is accepted when
    ``_hostname_matches`` reports it equal to, or a subdomain of, any entry in
    *allowed_hosts*.

    Raises:
        ValueError: if the scheme is not ``https`` or there is no hostname, or
            if no allow-list entry matches the hostname.
    """
    parts = urlparse(image_url)
    host = (parts.hostname or "").lower()
    is_https_with_host = parts.scheme == "https" and bool(host)
    if not is_https_with_host:
        raise ValueError(f"Unsupported image download URL: {image_url}")
    for candidate in allowed_hosts:
        if _hostname_matches(host, candidate):
            return
    raise ValueError(f"Image download host is not allowed: {host}")
def _hostname_matches(hostname: str, allowed_host: str) -> bool:
    """Return True when *hostname* equals *allowed_host* or is a subdomain of it.

    *allowed_host* is stripped and lower-cased before comparing, and a blank
    entry never matches. *hostname* is compared as given; the caller in this
    module lower-cases it first.
    """
    suffix = allowed_host.strip().lower()
    if suffix:
        return hostname == suffix or hostname.endswith("." + suffix)
    return False
def _apply_style_to_prompt(prompt: str, style: str) -> str:
    """Prefix *prompt* with *style* unless the prompt already mentions it.

    Both values are stripped, and ``None`` is treated as empty. The result is
    ``"<style>, <prompt>"``; if either part is empty, the other is returned alone.

    The "already mentions it" check is case-insensitive. When the style starts
    or ends with an ASCII letter/digit, a match only counts if it is not glued
    to another ASCII letter/digit on that side. A plain substring test would
    drop short English styles found inside unrelated words (e.g. "art" in
    "a party", "oil" in "boiling"). CJK styles, whose text has no spaces
    between words, still match as plain substrings.
    """
    cleaned_prompt = (prompt or "").strip()
    cleaned_style = (style or "").strip()
    if not cleaned_style:
        return cleaned_prompt
    if not cleaned_prompt:
        return cleaned_style

    pattern = re.escape(cleaned_style)
    first, last = cleaned_style[0], cleaned_style[-1]
    if first.isascii() and first.isalnum():
        pattern = r"(?<![A-Za-z0-9])" + pattern
    if last.isascii() and last.isalnum():
        pattern += r"(?![A-Za-z0-9])"
    if re.search(pattern, cleaned_prompt, re.IGNORECASE):
        return cleaned_prompt
    return f"{cleaned_style}, {cleaned_prompt}"