Files
operating-room-monitor-server/backend/scripts/bake_pretrained_weights.py
2026-05-27 16:22:55 +08:00

213 lines
6.8 KiB
Python

#!/usr/bin/env python3
"""Bake torchvision / torch.hub checkpoints into TORCH_HOME during Docker image build.
Runtime batch jobs and doctor-identity code cannot download from pytorch.org: the api
container runs as a non-root compose user and /app/.cache/torch is read-only there.
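Place optional offline copies under backend/weights/ before build; they are picked up via
PYTORCH_MODELS_LOCAL_DIR, which must point at that directory inside the build:
- swin3d_t-7615ae03.pth (VideoSwin / Swin3D-T, ~110MB)
- resnet50-0676ba61.pth (doctor ReID backbone, ~98MB)
Download order when no local copy is found: PYTORCH_MODELS_MIRROR (if set), the built-in
agentsmirror mirror, download.pytorch.org via curl, and finally download.pytorch.org via
torch.hub. PYTORCH_MODELS_URL is one full URL that replaces the curl list for swin3d_t only
(per-file URLs for the other checkpoints are not supported).
Legacy single-file: PYTORCH_MODELS_LOCAL_PATH (applies only to swin3d_t).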
"""
from __future__ import annotations
import os
import shutil
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
# Official torchvision checkpoint host; checkpoint URLs are f"{OFFICIAL_PREFIX}/{filename}".
OFFICIAL_PREFIX = "https://download.pytorch.org/models"
# Built-in mirror with the same path layout as OFFICIAL_PREFIX (note: it already ends in
# "/models"). Tried after PYTORCH_MODELS_MIRROR and before the official host.
AGENTSMIRROR_PREFIX = "https://uv.agentsmirror.com/download.pytorch.org/models"
@dataclass(frozen=True)
class HubCheckpoint:
    """One torch.hub checkpoint file to place under ``<TORCH_HOME>/hub/checkpoints``."""

    # Exact filename torchvision looks up in the hub cache, e.g. "resnet50-0676ba61.pth".
    filename: str
    # Size floor in bytes: anything smaller is treated as a truncated/failed copy or download.
    min_bytes: int
    # Human-readable name used in progress and error messages.
    label: str
# Every torchvision hub file that production code paths would otherwise download at
# runtime (as of the 5.15 bundle). Filenames must match the torchvision weight URLs
# exactly so torch.hub finds them in its cache; min_bytes sits a little below each
# file's real size (~110MB / ~98MB) to reject truncated transfers.
HUB_CHECKPOINTS: tuple[HubCheckpoint, ...] = (
    HubCheckpoint(
        filename="swin3d_t-7615ae03.pth",
        min_bytes=100_000_000,
        label="VideoSwin Swin3D-T (Kinetics-400)",
    ),
    HubCheckpoint(
        filename="resnet50-0676ba61.pth",
        min_bytes=90_000_000,
        label="doctor ReID ResNet50 (ImageNet-1K)",
    ),
)
def _candidate_urls(filename: str) -> list[str]:
explicit = (os.environ.get("PYTORCH_MODELS_URL") or "").strip()
# Legacy single-URL override applied only to VideoSwin (other files use mirror list).
if explicit and filename == "swin3d_t-7615ae03.pth":
return [explicit]
raw = (os.environ.get("PYTORCH_MODELS_MIRROR") or "").strip().rstrip("/")
prefixes = [p for p in (raw, AGENTSMIRROR_PREFIX) if p]
urls: list[str] = []
for prefix in prefixes:
urls.append(f"{prefix}/models/{filename}")
urls.append(f"{OFFICIAL_PREFIX}/{filename}")
seen: set[str] = set()
ordered: list[str] = []
for url in urls:
if url in seen:
continue
seen.add(url)
ordered.append(url)
return ordered
def _local_source(filename: str) -> Path | None:
legacy = (os.environ.get("PYTORCH_MODELS_LOCAL_PATH") or "").strip()
if legacy and filename == "swin3d_t-7615ae03.pth":
path = Path(legacy)
if path.is_file():
return path
local_dir = (os.environ.get("PYTORCH_MODELS_LOCAL_DIR") or "").strip()
if local_dir:
path = Path(local_dir) / filename
if path.is_file():
return path
return None
def _copy_local(src: Path, dest: Path, *, min_bytes: int) -> None:
size = src.stat().st_size
if size < min_bytes:
raise OSError(f"local file too small ({size} bytes): {src}")
dest.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(src, dest)
def _download_with_curl(url: str, dest: Path, *, min_bytes: int) -> None:
curl = shutil.which("curl")
if curl is None:
raise RuntimeError("curl not found")
tmp = dest.with_suffix(dest.suffix + ".part")
tmp.unlink(missing_ok=True)
proc = subprocess.run(
[
curl,
"-fL",
"--retry",
"5",
"--retry-all-errors",
"--retry-delay",
"3",
"--connect-timeout",
"30",
"--max-time",
"1800",
"-o",
str(tmp),
url,
],
check=False,
capture_output=True,
text=True,
)
if proc.returncode != 0:
tmp.unlink(missing_ok=True)
detail = (proc.stderr or proc.stdout or "").strip() or f"exit={proc.returncode}"
raise RuntimeError(detail)
size = tmp.stat().st_size
if size < min_bytes:
tmp.unlink(missing_ok=True)
raise OSError(f"download too small ({size} bytes)")
tmp.replace(dest)
def _download_with_torch(url: str, dest: Path, *, min_bytes: int) -> None:
import torch
dest.parent.mkdir(parents=True, exist_ok=True)
torch.hub.download_url_to_file(url, str(dest), progress=True)
size = dest.stat().st_size
if size < min_bytes:
dest.unlink(missing_ok=True)
raise OSError(f"download too small ({size} bytes)")
def bake_hub_checkpoint(
spec: HubCheckpoint,
*,
torch_home: Path,
) -> None:
dest = torch_home / "hub" / "checkpoints" / spec.filename
if dest.is_file() and dest.stat().st_size >= spec.min_bytes:
print(f"already baked [{spec.label}]: {dest} ({dest.stat().st_size} bytes)")
return
local = _local_source(spec.filename)
if local is not None:
try:
_copy_local(local, dest, min_bytes=spec.min_bytes)
print(f"baked [{spec.label}] {dest} ({dest.stat().st_size} bytes) from local {local}")
return
except OSError as exc:
print(f"local copy failed for {spec.filename}: {exc}", file=sys.stderr)
errors: list[str] = []
for url in _candidate_urls(spec.filename):
try:
print(f"downloading [{spec.label}] {url}")
_download_with_curl(url, dest, min_bytes=spec.min_bytes)
print(f"baked [{spec.label}] {dest} ({dest.stat().st_size} bytes) from {url}")
return
except (OSError, RuntimeError) as exc:
errors.append(f"curl {url}: {exc}")
dest.unlink(missing_ok=True)
official = f"{OFFICIAL_PREFIX}/{spec.filename}"
try:
print(f"torch.hub fallback [{spec.label}]: {official}")
_download_with_torch(official, dest, min_bytes=spec.min_bytes)
print(f"baked [{spec.label}] {dest} ({dest.stat().st_size} bytes) from {official}")
return
except (OSError, RuntimeError) as exc:
errors.append(f"torch {official}: {exc}")
dest.unlink(missing_ok=True)
print(f"failed to bake {spec.label} ({spec.filename}):", file=sys.stderr)
for line in errors:
print(f" - {line}", file=sys.stderr)
raise SystemExit(1)
def warm_torchvision_hub_models() -> None:
    """Build each pretrained model once so a missing or unloadable checkpoint fails the build.

    The models resolve their weights through the torch.hub cache; with the files baked,
    no download should be needed. Prints one "<model> <weights> ok" line per model.
    """
    from torchvision.models import ResNet50_Weights, resnet50
    from torchvision.models.video import Swin3D_T_Weights, swin3d_t

    warmups = (
        (resnet50, ResNet50_Weights.IMAGENET1K_V1, "resnet50 IMAGENET1K_V1 ok"),
        (swin3d_t, Swin3D_T_Weights.KINETICS400_V1, "swin3d_t KINETICS400_V1 ok"),
    )
    for build_model, pretrained, ok_message in warmups:
        build_model(weights=pretrained)
        print(ok_message)
def main() -> int:
    """Bake every checkpoint in HUB_CHECKPOINTS, then load the models once as a check.

    TORCH_HOME defaults to /app/.cache/torch. The resolved directory is written back to
    the environment because torch.hub derives its cache directory from TORCH_HOME
    (falling back to $XDG_CACHE_HOME/torch, usually ~/.cache/torch). Without this, an
    unset TORCH_HOME would bake into /app/.cache/torch while the warm-up step looked
    elsewhere and downloaded again. ``~`` is expanded the same way torch does.

    Returns:
        0 on success; a failed bake exits via SystemExit(1) from bake_hub_checkpoint.
    """
    torch_home = Path(os.environ.get("TORCH_HOME") or "/app/.cache/torch").expanduser()
    os.environ["TORCH_HOME"] = str(torch_home)
    for spec in HUB_CHECKPOINTS:
        bake_hub_checkpoint(spec, torch_home=torch_home)
    warm_torchvision_hub_models()
    print("all pretrained hub weights baked")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())