Files
operating-room-monitor-server/backend/tests/test_bake_pretrained_weights.py
2026-05-27 16:22:55 +08:00

68 lines
2.5 KiB
Python

from __future__ import annotations
import importlib.util
from pathlib import Path
import pytest
def _load_module(path: Path | None = None):
    """Load the ``bake_pretrained_weights`` script as a fresh module object.

    The script is not part of an importable package, so it is loaded directly
    from its file path, following the importlib "import a source file
    directly" recipe. The module is registered in ``sys.modules`` before it
    executes, which some stdlib machinery (e.g. ``dataclasses``) relies on.

    Args:
        path: Optional override of the script location. Defaults to
            ``<backend>/scripts/bake_pretrained_weights.py`` resolved relative
            to this test file.

    Returns:
        The executed module. Every call re-executes the script, so each test
        gets a fresh module.

    Raises:
        ImportError: If no import spec or loader can be created for ``path``.
    """
    import sys

    if path is None:
        path = Path(__file__).resolve().parents[1] / "scripts" / "bake_pretrained_weights.py"
    name = "bake_pretrained_weights"
    spec = importlib.util.spec_from_file_location(name, path)
    # Validate before use: module_from_spec(None) would fail with an opaque
    # AttributeError instead of naming the offending path.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create an import spec for {path}")
    mod = importlib.util.module_from_spec(spec)
    sys.modules[name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Do not leave a half-initialised module registered for later tests.
        sys.modules.pop(name, None)
        raise
    return mod
def test_candidate_urls_default(monkeypatch: pytest.MonkeyPatch) -> None:
    """With no URL or mirror override set, check the order of candidate URLs.

    The first candidate must be a ``.../models/<file>`` URL, and the official
    PyTorch URL must come last.
    """
    module = _load_module()
    for env_var in ("PYTORCH_MODELS_URL", "PYTORCH_MODELS_MIRROR"):
        monkeypatch.delenv(env_var, raising=False)
    filename = "resnet50-0676ba61.pth"
    candidates = module._candidate_urls(filename)
    assert candidates[0].endswith(f"/models/{filename}")
    official = f"{module.OFFICIAL_PREFIX}/{filename}"
    assert candidates[-1] == official
def test_candidate_urls_explicit_override_applies_to_swin_only(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``PYTORCH_MODELS_URL`` replaces the candidate list for Swin3D-T only.

    For the Swin checkpoint, the override must be the sole candidate. For
    other checkpoints, it must not appear in *any* candidate. Checking only
    the first entry would miss the override being appended as a fallback.
    """
    mod = _load_module()
    # Keep the test hermetic: a mirror configured in the CI environment must
    # not influence which URLs are produced.
    monkeypatch.delenv("PYTORCH_MODELS_MIRROR", raising=False)
    override = "https://example.com/swin3d_t-7615ae03.pth"
    monkeypatch.setenv("PYTORCH_MODELS_URL", override)
    assert mod._candidate_urls("swin3d_t-7615ae03.pth") == [override]
    resnet_urls = mod._candidate_urls("resnet50-0676ba61.pth")
    assert resnet_urls, "expected at least one candidate URL for resnet50"
    assert all("example.com" not in url for url in resnet_urls), resnet_urls
def test_local_source_prefers_legacy_swin_path(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """``PYTORCH_MODELS_LOCAL_PATH`` is honoured for the Swin checkpoint only."""
    mod = _load_module()
    swin_name = "swin3d_t-7615ae03.pth"
    # Look the spec up by filename instead of relying on the order of
    # HUB_CHECKPOINTS; a KeyError here points straight at a renamed checkpoint.
    specs_by_name = {spec.filename: spec for spec in mod.HUB_CHECKPOINTS}
    swin_spec = specs_by_name[swin_name]
    swin = tmp_path / swin_name
    swin.write_bytes(b"x" * swin_spec.min_bytes)
    # Hermetic: a LOCAL_DIR inherited from the environment could otherwise
    # satisfy the resnet lookup and break the ``is None`` assertion.
    monkeypatch.delenv("PYTORCH_MODELS_LOCAL_DIR", raising=False)
    monkeypatch.setenv("PYTORCH_MODELS_LOCAL_PATH", str(swin))
    assert mod._local_source(swin_name) == swin
    assert mod._local_source("resnet50-0676ba61.pth") is None
def test_local_source_uses_local_dir(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """A checkpoint placed in ``PYTORCH_MODELS_LOCAL_DIR`` is found by filename."""
    mod = _load_module()
    local_dir = tmp_path / "weights"
    local_dir.mkdir()
    filename = "resnet50-0676ba61.pth"
    expected = local_dir / filename
    # A 1-byte placeholder is all this lookup needs; the assertion below
    # expects it to be returned.
    expected.write_bytes(b"x")
    monkeypatch.setenv("PYTORCH_MODELS_LOCAL_DIR", str(local_dir))
    found = mod._local_source(filename)
    assert found == expected
def test_bake_skips_when_dest_already_valid(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """An existing ``min_bytes``-sized checkpoint is treated as valid and left untouched."""
    mod = _load_module()
    torch_home = tmp_path / "torch"
    monkeypatch.setenv("TORCH_HOME", str(torch_home))
    spec = mod.HUB_CHECKPOINTS[1]
    dest = torch_home / "hub" / "checkpoints" / spec.filename
    dest.parent.mkdir(parents=True)
    payload = b"x" * spec.min_bytes
    dest.write_bytes(payload)
    mod.bake_hub_checkpoint(spec, torch_home=torch_home)
    # Without this check the test only proved "no exception was raised";
    # a re-download or overwrite of the existing file would go unnoticed.
    assert dest.read_bytes() == payload