213 lines
6.8 KiB
Python
213 lines
6.8 KiB
Python
#!/usr/bin/env python3
|
|
"""Bake torchvision / torch.hub checkpoints into TORCH_HOME during Docker image build.
|
|
|
|
Runtime batch jobs and doctor identity must not download from pytorch.org when the
|
|
api container runs as a non-root compose user (read-only /app/.cache/torch).
|
|
|
|
Place optional offline copies under backend/weights/ before build:
|
|
- swin3d_t-7615ae03.pth (VideoSwin / Swin3D-T, ~110MB)
|
|
- resnet50-0676ba61.pth (doctor ReID backbone, ~98MB)
|
|
|
|
Override mirrors: PYTORCH_MODELS_MIRROR (URL prefix, tried before the built-in mirrors),
PYTORCH_MODELS_URL (one full URL, applied only to swin3d_t; per-file URLs are not supported).
|
|
Legacy single-file: PYTORCH_MODELS_LOCAL_PATH (applies only to swin3d_t).
Local directory: PYTORCH_MODELS_LOCAL_DIR (searched by checkpoint filename for every file).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
# Base URLs that checkpoint filenames are appended to. Note that both values
# already end with the "/models" path segment.

# Upstream PyTorch model host.
OFFICIAL_PREFIX = "https://download.pytorch.org/models"

# Built-in mirror of the upstream host, used as a download candidate.
AGENTSMIRROR_PREFIX = "https://uv.agentsmirror.com/download.pytorch.org/models"
|
|
|
|
|
|
@dataclass(frozen=True)
class HubCheckpoint:
    """Immutable description of one checkpoint file to bake into TORCH_HOME."""

    # File name of the checkpoint, used both for the download URL and for the
    # destination under <TORCH_HOME>/hub/checkpoints/.
    filename: str
    # Smallest size in bytes accepted as a complete file; anything smaller is
    # rejected as a truncated or bogus download/copy.
    min_bytes: int
    # Human-readable name used in progress and error messages.
    label: str
|
|
|
|
|
|
# All torchvision hub files downloaded at runtime by production paths (as of 5.15 bundle).
# Each row is (filename, minimum acceptable size in bytes, display label).
HUB_CHECKPOINTS: tuple[HubCheckpoint, ...] = tuple(
    HubCheckpoint(filename=name, min_bytes=floor, label=title)
    for name, floor, title in (
        ("swin3d_t-7615ae03.pth", 100_000_000, "VideoSwin Swin3D-T (Kinetics-400)"),
        ("resnet50-0676ba61.pth", 90_000_000, "doctor ReID ResNet50 (ImageNet-1K)"),
    )
)
|
|
|
|
|
|
def _candidate_urls(filename: str) -> list[str]:
    """Return the download URLs to try for *filename*, in priority order.

    Order of candidates:
      1. ``PYTORCH_MODELS_MIRROR`` (if set), then the built-in mirror prefix.
      2. The official PyTorch download URL.

    ``PYTORCH_MODELS_URL`` is a legacy single-URL override: when it is set and
    *filename* is the VideoSwin checkpoint, it is returned as the only
    candidate. Other files ignore it and use the mirror list.

    A mirror prefix may be given with or without a trailing ``/models``
    segment; the segment is appended only when missing, so a prefix that
    already ends in ``/models`` (such as the built-in mirror constant) does not
    produce a doubled ``/models/models/`` path.

    Args:
        filename: Checkpoint file name, e.g. ``"resnet50-0676ba61.pth"``.

    Returns:
        De-duplicated list of URLs, first occurrence order preserved.
    """
    explicit = (os.environ.get("PYTORCH_MODELS_URL") or "").strip()
    # Legacy single-URL override applied only to VideoSwin (other files use mirror list).
    if explicit and filename == "swin3d_t-7615ae03.pth":
        return [explicit]

    mirror = (os.environ.get("PYTORCH_MODELS_MIRROR") or "").strip().rstrip("/")

    urls: list[str] = []
    for prefix in (mirror, AGENTSMIRROR_PREFIX):
        if not prefix:
            continue
        # Accept both "https://host" and "https://host/models" style prefixes.
        base = prefix if prefix.endswith("/models") else f"{prefix}/models"
        urls.append(f"{base}/{filename}")
    urls.append(f"{OFFICIAL_PREFIX}/{filename}")

    # dict keys keep insertion order, so this drops repeats (e.g. a mirror
    # that points at the official host) while keeping the priority order.
    return list(dict.fromkeys(urls))
|
|
|
|
|
|
def _local_source(filename: str) -> Path | None:
    """Find an existing offline copy of *filename*, or return ``None``.

    Candidates are checked in this order and the first one that is a regular
    file wins:
      1. ``PYTORCH_MODELS_LOCAL_PATH`` - legacy single-file override, only
         considered for the VideoSwin checkpoint.
      2. ``<PYTORCH_MODELS_LOCAL_DIR>/<filename>``.
    """
    candidates: list[Path] = []

    legacy_path = (os.environ.get("PYTORCH_MODELS_LOCAL_PATH") or "").strip()
    if legacy_path and filename == "swin3d_t-7615ae03.pth":
        candidates.append(Path(legacy_path))

    weights_dir = (os.environ.get("PYTORCH_MODELS_LOCAL_DIR") or "").strip()
    if weights_dir:
        candidates.append(Path(weights_dir) / filename)

    return next((found for found in candidates if found.is_file()), None)
|
|
|
|
|
|
def _copy_local(src: Path, dest: Path, *, min_bytes: int) -> None:
    """Copy the local checkpoint *src* to *dest* after a size sanity check.

    Args:
        src: Existing file to copy from.
        dest: Target path; missing parent directories are created.
        min_bytes: Minimum size *src* must have to be accepted.

    Raises:
        OSError: If *src* is smaller than *min_bytes* (nothing is written in
            that case), or if stat/mkdir/copy fails.
    """
    src_bytes = src.stat().st_size
    too_small = src_bytes < min_bytes
    if too_small:
        raise OSError(f"local file too small ({src_bytes} bytes): {src}")

    target_dir = dest.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src, dest)
|
|
|
|
|
|
def _download_with_curl(url: str, dest: Path, *, min_bytes: int) -> None:
    """Download *url* to *dest* using the ``curl`` executable.

    The payload is written to a sibling ``<dest>.part`` file and only renamed
    onto *dest* once curl has succeeded and the size check has passed, so a
    failed attempt never leaves a partial file at *dest*.

    Args:
        url: Source URL.
        dest: Final checkpoint path; missing parent directories are created.
        min_bytes: Minimum acceptable size of the downloaded file.

    Raises:
        RuntimeError: If ``curl`` is not on PATH, or curl exits non-zero (the
            message carries curl's stderr/stdout or the exit code).
        OSError: If the downloaded file is smaller than *min_bytes*, or a
            filesystem operation fails.
    """
    curl = shutil.which("curl")
    if curl is None:
        raise RuntimeError("curl not found")

    # `curl -o` does not create missing directories; without this every curl
    # attempt fails on a fresh TORCH_HOME. Mirrors what the local-copy and
    # torch download paths already do.
    dest.parent.mkdir(parents=True, exist_ok=True)

    tmp = dest.with_suffix(dest.suffix + ".part")
    # Drop any stale partial file from an earlier attempt.
    tmp.unlink(missing_ok=True)

    command = [
        curl,
        "-fL",
        "--retry",
        "5",
        "--retry-all-errors",
        "--retry-delay",
        "3",
        "--connect-timeout",
        "30",
        "--max-time",
        "1800",
        "-o",
        str(tmp),
        url,
    ]
    proc = subprocess.run(command, check=False, capture_output=True, text=True)
    if proc.returncode != 0:
        tmp.unlink(missing_ok=True)
        detail = (proc.stderr or proc.stdout or "").strip() or f"exit={proc.returncode}"
        raise RuntimeError(detail)

    size = tmp.stat().st_size
    if size < min_bytes:
        tmp.unlink(missing_ok=True)
        raise OSError(f"download too small ({size} bytes)")

    # Publish the verified file under its final name.
    tmp.replace(dest)
|
|
|
|
|
|
def _download_with_torch(url: str, dest: Path, *, min_bytes: int) -> None:
    """Download *url* to *dest* through ``torch.hub.download_url_to_file``.

    Args:
        url: Source URL.
        dest: Final checkpoint path; missing parent directories are created.
        min_bytes: Minimum acceptable size of the downloaded file.

    Raises:
        OSError: If the downloaded file is smaller than *min_bytes*; the
            undersized file is removed before raising.
    """
    # Imported here so torch is only loaded when this fallback is actually used.
    import torch

    checkpoint_dir = dest.parent
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    torch.hub.download_url_to_file(url, str(dest), progress=True)

    fetched = dest.stat().st_size
    if fetched >= min_bytes:
        return

    dest.unlink(missing_ok=True)
    raise OSError(f"download too small ({fetched} bytes)")
|
|
|
|
|
|
def bake_hub_checkpoint(
    spec: HubCheckpoint,
    *,
    torch_home: Path,
) -> None:
    """Make sure the checkpoint described by *spec* exists under *torch_home*.

    The target is ``<torch_home>/hub/checkpoints/<spec.filename>``. Sources are
    tried in this order, stopping at the first success:
      1. An existing target file of at least ``spec.min_bytes`` (nothing to do).
      2. A local offline copy located by ``_local_source``.
      3. Every URL from ``_candidate_urls``, fetched with curl.
      4. The official URL, fetched through torch.hub.

    Raises:
        SystemExit: With code 1 when every download attempt failed; the
            individual errors are printed to stderr first.
    """
    dest = torch_home / "hub" / "checkpoints" / spec.filename

    if dest.is_file():
        present = dest.stat().st_size
        if present >= spec.min_bytes:
            print(f"already baked [{spec.label}]: {dest} ({present} bytes)")
            return

    offline_copy = _local_source(spec.filename)
    if offline_copy is not None:
        try:
            _copy_local(offline_copy, dest, min_bytes=spec.min_bytes)
            print(
                f"baked [{spec.label}] {dest} ({dest.stat().st_size} bytes) from local {offline_copy}"
            )
            return
        except OSError as exc:
            # A bad local copy is not fatal: report it and fall back to downloading.
            print(f"local copy failed for {spec.filename}: {exc}", file=sys.stderr)

    # One row per attempt: (tool name for error lines, downloader, URL, start message).
    plan = [
        ("curl", _download_with_curl, mirror_url, f"downloading [{spec.label}] {mirror_url}")
        for mirror_url in _candidate_urls(spec.filename)
    ]
    official = f"{OFFICIAL_PREFIX}/{spec.filename}"
    plan.append(
        ("torch", _download_with_torch, official, f"torch.hub fallback [{spec.label}]: {official}")
    )

    failures: list[str] = []
    for tool, fetch, source_url, banner in plan:
        try:
            print(banner)
            fetch(source_url, dest, min_bytes=spec.min_bytes)
            print(f"baked [{spec.label}] {dest} ({dest.stat().st_size} bytes) from {source_url}")
            return
        except (OSError, RuntimeError) as exc:
            failures.append(f"{tool} {source_url}: {exc}")
            # Never leave a partial file behind for the next attempt.
            dest.unlink(missing_ok=True)

    print(f"failed to bake {spec.label} ({spec.filename}):", file=sys.stderr)
    for entry in failures:
        print(f" - {entry}", file=sys.stderr)
    raise SystemExit(1)
|
|
|
|
|
|
def warm_torchvision_hub_models() -> None:
    """Load models so torchvision verifies hub checkpoints (no network if baked).

    Builds ResNet50 (IMAGENET1K_V1) and then Swin3D-T (KINETICS400_V1) with
    their pretrained weights, printing a confirmation line after each. The
    constructed models are discarded.
    """
    from torchvision.models import ResNet50_Weights, resnet50
    from torchvision.models.video import Swin3D_T_Weights, swin3d_t

    warmups = (
        (resnet50, ResNet50_Weights.IMAGENET1K_V1, "resnet50 IMAGENET1K_V1"),
        (swin3d_t, Swin3D_T_Weights.KINETICS400_V1, "swin3d_t KINETICS400_V1"),
    )
    for build, weights, tag in warmups:
        build(weights=weights)
        print(f"{tag} ok")
|
|
|
|
|
|
def main() -> int:
    """Bake every checkpoint in ``HUB_CHECKPOINTS`` and then warm the models.

    The cache root is taken from ``TORCH_HOME``, defaulting to
    ``/app/.cache/torch`` when the variable is unset.

    Returns:
        0 when all checkpoints were baked and the warm-up completed.
    """
    cache_root = Path(os.getenv("TORCH_HOME", "/app/.cache/torch"))

    for checkpoint in HUB_CHECKPOINTS:
        bake_hub_checkpoint(checkpoint, torch_home=cache_root)

    warm_torchvision_hub_models()
    print("all pretrained hub weights baked")
    return 0
|
|
|
|
|
|
# Script entry point: the process exit status is whatever main() returns.
if __name__ == "__main__":
    sys.exit(main())
|