68 lines
2.5 KiB
Python
68 lines
2.5 KiB
Python
from __future__ import annotations
|
|
|
|
import importlib.util
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
|
|
def _load_module():
    """Import ``scripts/bake_pretrained_weights.py`` by path and return the module.

    The script is located relative to this test file (one directory above the
    directory containing it, under ``scripts/``), executed as a module named
    ``bake_pretrained_weights``, and registered in ``sys.modules`` under that
    name. Every call executes the script again and replaces any previously
    registered module object.

    Raises:
        AssertionError: if no import spec/loader can be built for the path.
        Exception: whatever the script raises while executing; in that case
            the partially initialised module is removed from ``sys.modules``.
    """
    import sys

    path = Path(__file__).resolve().parents[1] / "scripts" / "bake_pretrained_weights.py"
    spec = importlib.util.spec_from_file_location("bake_pretrained_weights", path)
    # Validate *before* module_from_spec(): a None spec would otherwise fail
    # there with an opaque AttributeError instead of a clear message.
    assert spec is not None and spec.loader is not None, f"cannot build import spec for {path}"

    mod = importlib.util.module_from_spec(spec)
    # Register before executing, as the regular import machinery does, so the
    # script can find itself in sys.modules while its top level runs.
    sys.modules[spec.name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Mirror the import system: a module that failed to load must not stay
        # registered, or later lookups would see a half-initialised module.
        sys.modules.pop(spec.name, None)
        raise
    return mod
|
|
|
|
|
|
def test_candidate_urls_default(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check the candidate URL list when no URL/mirror override is set.

    With ``PYTORCH_MODELS_URL`` and ``PYTORCH_MODELS_MIRROR`` removed from the
    environment, the first candidate must end in ``/models/<filename>`` and
    the last one must be ``OFFICIAL_PREFIX/<filename>``.
    """
    bake = _load_module()
    for env_var in ("PYTORCH_MODELS_URL", "PYTORCH_MODELS_MIRROR"):
        monkeypatch.delenv(env_var, raising=False)

    candidates = bake._candidate_urls("resnet50-0676ba61.pth")

    first_url = candidates[0]
    assert first_url.endswith("/models/resnet50-0676ba61.pth")
    last_url = candidates[-1]
    assert last_url == f"{bake.OFFICIAL_PREFIX}/resnet50-0676ba61.pth"
|
|
|
|
|
|
def test_candidate_urls_explicit_override_applies_to_swin_only(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that ``PYTORCH_MODELS_URL`` overrides only the swin checkpoint.

    For the swin file the override must be the sole candidate; for the resnet
    file the first candidate must not point at the override host.
    """
    bake = _load_module()
    override_url = "https://example.com/swin3d_t-7615ae03.pth"
    monkeypatch.setenv("PYTORCH_MODELS_URL", override_url)

    swin_candidates = bake._candidate_urls("swin3d_t-7615ae03.pth")
    assert swin_candidates == [override_url]

    resnet_first = bake._candidate_urls("resnet50-0676ba61.pth")[0]
    assert "example.com" not in resnet_first
|
|
|
|
|
|
def test_local_source_prefers_legacy_swin_path(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """Check that ``PYTORCH_MODELS_LOCAL_PATH`` is used for the swin file only.

    A file of ``HUB_CHECKPOINTS[0].min_bytes`` bytes is created and exposed
    through ``PYTORCH_MODELS_LOCAL_PATH``. ``_local_source`` must return it for
    the swin checkpoint and return ``None`` for the resnet checkpoint.
    """
    mod = _load_module()
    # PYTORCH_MODELS_LOCAL_DIR is another location `_local_source` honours
    # (see test_local_source_uses_local_dir). Clear it so the `is None`
    # expectation below does not depend on the developer's/CI environment.
    monkeypatch.delenv("PYTORCH_MODELS_LOCAL_DIR", raising=False)

    swin = tmp_path / "swin3d_t-7615ae03.pth"
    # NOTE(review): presumably HUB_CHECKPOINTS[0] is the swin entry and the
    # file must reach min_bytes to be accepted — confirm against the script.
    swin.write_bytes(b"x" * mod.HUB_CHECKPOINTS[0].min_bytes)
    monkeypatch.setenv("PYTORCH_MODELS_LOCAL_PATH", str(swin))

    assert mod._local_source("swin3d_t-7615ae03.pth") == swin
    assert mod._local_source("resnet50-0676ba61.pth") is None
|
|
|
|
|
|
def test_local_source_uses_local_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """Check that ``_local_source`` finds a checkpoint in ``PYTORCH_MODELS_LOCAL_DIR``.

    A one-byte resnet file is placed in a fresh directory which is then
    exported via ``PYTORCH_MODELS_LOCAL_DIR``; the lookup must return exactly
    that file path.
    """
    bake = _load_module()

    local_dir = tmp_path / "weights"
    local_dir.mkdir()
    checkpoint = local_dir / "resnet50-0676ba61.pth"
    checkpoint.write_bytes(b"x")
    monkeypatch.setenv("PYTORCH_MODELS_LOCAL_DIR", str(local_dir))

    found = bake._local_source("resnet50-0676ba61.pth")
    assert found == checkpoint
|
|
|
|
|
|
def test_bake_skips_when_dest_already_valid(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that baking leaves an already-present checkpoint untouched.

    The destination ``<torch_home>/hub/checkpoints/<spec.filename>`` is
    pre-seeded with ``spec.min_bytes`` bytes, then ``bake_hub_checkpoint`` is
    called with that ``torch_home``. The call must complete and the file must
    still hold exactly the seeded bytes afterwards.
    """
    mod = _load_module()
    torch_home = tmp_path / "torch"
    monkeypatch.setenv("TORCH_HOME", str(torch_home))

    spec = mod.HUB_CHECKPOINTS[1]
    dest = torch_home / "hub" / "checkpoints" / spec.filename
    dest.parent.mkdir(parents=True)
    seeded = b"x" * spec.min_bytes
    dest.write_bytes(seeded)

    mod.bake_hub_checkpoint(spec, torch_home=torch_home)

    # Without a post-condition the test only proves "no exception": a bake
    # that re-fetched and overwrote the file would still pass. Compare via a
    # boolean so a failure does not dump the whole payload into the report.
    unchanged = dest.read_bytes() == seeded
    assert unchanged, f"{dest} was modified although it was already valid"
|