fix ffmpeg
This commit is contained in:
@@ -4,9 +4,11 @@
|
||||
#
|
||||
# 5-6 ActionFormer 实时算法(默认开启):
|
||||
# - app/resources/actionformer_epoch_045.pth.tar 必须存在(离线下发,~110MB,未入 git)。
|
||||
# - VideoSwin Swin3D-T 走 torchvision Kinetics-400 预训练,**首次运行需联网下载**;
|
||||
# 可设 TORCH_HOME 指向持久化目录(容器中已默认 /root/.cache/torch),或将权重
|
||||
# 提前预热到该目录避免开录时拉取。
|
||||
# - VideoSwin Swin3D-T 权重在 Docker 构建时预下载到 /app/.cache/torch(见 scripts/bake_torch_hub_checkpoint.py);
|
||||
# 运行时不再访问 pytorch.org。「首次运行」指 torch 缓存为空时才会联网下载;现已改为构建时烘焙进镜像。
|
||||
# 国内 PyPI 镜像(南大/清华/阿里)不同步 /models/*.pth,构建默认先试 uv.agentsmirror.com 再回退官方源。
|
||||
# 离线/弱网:先 wget 权重到 backend/weights/swin3d_t-7615ae03.pth,再 docker compose build api。
|
||||
# 或:docker compose build --build-arg PYTORCH_MODELS_URL=https://your-mirror/.../swin3d_t-7615ae03.pth api
|
||||
# - Linux GPU 机:镜像内 torch / torchvision / torchaudio 为 cu130 wheel;
|
||||
# 宿主机需 NVIDIA 驱动 + NVIDIA Container Toolkit;`api` 服务已配置 `gpus: all`。
|
||||
# 启动后可验证:docker compose exec api python -c "import torch; print(torch.cuda.is_available())"
|
||||
@@ -28,9 +30,9 @@ POSTGRES_PORT=45432
|
||||
HOST_UID=1000
|
||||
HOST_GID=1000
|
||||
DOCKER_GID=999
|
||||
# 非 root 运行时 uv/torch 缓存目录(compose 内已设为 /tmp/*,一般无需改)
|
||||
# 非 root 运行时 uv/torch 缓存目录(compose 内已设;TORCH_HOME 为镜像内预烘焙路径)
|
||||
# UV_CACHE_DIR=/tmp/uv-cache
|
||||
# TORCH_HOME=/tmp/torch-cache
|
||||
# TORCH_HOME=/app/.cache/torch
|
||||
|
||||
# --- HTTP(API 对外端口)---
|
||||
# 局域网语音确认终端 / Demo 客户端访问 API 时,填写
|
||||
|
||||
@@ -14,6 +14,8 @@ RUN sed -i \
|
||||
|
||||
# OpenCV / MediaPipe (doctor pose) need GLVND + Mesa GLES/EGL in slim images; omit X11/GUI stack.
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
curl \
|
||||
docker.io \
|
||||
ffmpeg \
|
||||
fontconfig \
|
||||
@@ -42,13 +44,23 @@ WORKDIR /app
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
UV_HTTP_TIMEOUT=600 \
|
||||
UV_LINK_MODE=copy \
|
||||
TORCH_HOME=/root/.cache/torch
|
||||
TORCH_HOME=/app/.cache/torch \
|
||||
YOLO_CONFIG_DIR=/app/.cache/ultralytics
|
||||
|
||||
COPY pyproject.toml uv.lock main.py alembic.ini ./
|
||||
COPY scripts ./scripts/
|
||||
COPY app ./app/
|
||||
COPY alembic ./alembic/
|
||||
# 离线批处理 / demo 直调 algorithm_subprocesses/5.15/main.py(含 weights/)
|
||||
COPY algorithm_subprocesses ./algorithm_subprocesses/
|
||||
# Bake runtime patches/assets so non-root api never writes the read-only bundle tree.
|
||||
COPY app/algorithm_runner/actionformer_release/libs/utils/nms.py \
|
||||
algorithm_subprocesses/5.15/code/actionformer_release/libs/utils/nms.py
|
||||
RUN mkdir -p algorithm_subprocesses/5.15/doctor_identity_package/.mediapipe_models && \
|
||||
curl -fsSL --retry 3 \
|
||||
-o algorithm_subprocesses/5.15/doctor_identity_package/.mediapipe_models/pose_landmarker_lite.task \
|
||||
"https://storage.googleapis.com/mediapipe-models/pose_landmarker/pose_landmarker_lite/float16/1/pose_landmarker_lite.task" && \
|
||||
test -s algorithm_subprocesses/5.15/doctor_identity_package/.mediapipe_models/pose_landmarker_lite.task
|
||||
|
||||
# uv.lock pins uv.agentsmirror.com artifact URLs. Rewrite to mainland mirrors (same /packages/... paths).
|
||||
# PyPI: Tsinghua | PyTorch wheel index: 南大 (syncs download.pytorch.org / download-r2)
|
||||
@@ -61,11 +73,27 @@ RUN sed -i \
|
||||
|
||||
ENV UV_DEFAULT_INDEX=https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
# VideoSwin (Swin3D-T) hub weights: bake at build time so RTSP batch jobs never hit pytorch.org at runtime.
|
||||
# Domestic PyPI mirrors (NJU/Tsinghua/Aliyun) only sync pip wheels, not /models/*.pth; default tries
|
||||
# uv.agentsmirror.com (same ecosystem as uv.lock) then download.pytorch.org. Optional offline bake:
|
||||
# backend/weights/swin3d_t-7615ae03.pth (see weights/.gitkeep)
|
||||
# Override: --build-arg PYTORCH_MODELS_URL=... or PYTORCH_MODELS_MIRROR=...
|
||||
ARG PYTORCH_MODELS_MIRROR=
|
||||
ARG PYTORCH_MODELS_URL=
|
||||
ENV PYTORCH_MODELS_MIRROR=${PYTORCH_MODELS_MIRROR} \
|
||||
PYTORCH_MODELS_URL=${PYTORCH_MODELS_URL}
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=weights/swin3d_t-7615ae03.pth,target=/tmp/prebaked/swin3d_t-7615ae03.pth,readonly,required=false \
|
||||
uv sync --frozen --no-dev --no-compile --refresh-package numpy --refresh-package mediapipe && \
|
||||
.venv/bin/python -c "import alembic" && \
|
||||
.venv/bin/python -c "import numpy; import numpy.lib._index_tricks_impl" && \
|
||||
.venv/bin/python -c "import mediapipe as mp; print('mediapipe', mp.__version__)"
|
||||
.venv/bin/python -c "import mediapipe as mp; print('mediapipe', mp.__version__)" && \
|
||||
mkdir -p /app/.cache/ultralytics && \
|
||||
PYTORCH_MODELS_LOCAL_PATH=/tmp/prebaked/swin3d_t-7615ae03.pth \
|
||||
.venv/bin/python scripts/bake_torch_hub_checkpoint.py && \
|
||||
TORCH_HOME=/app/.cache/torch .venv/bin/python -c "from torchvision.models.video import Swin3D_T_Weights, swin3d_t; swin3d_t(weights=Swin3D_T_Weights.KINETICS400_V1); print('swin3d_t cached ok')" && \
|
||||
chmod -R a+rX /app/.venv /app/algorithm_subprocesses /app/.cache/torch /app/.cache/ultralytics
|
||||
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
|
||||
|
||||
@@ -14,7 +14,8 @@ def _rel(pack_root: Path, raw: str | None) -> Path | None:
|
||||
return None
|
||||
path = Path(raw)
|
||||
if path.is_absolute():
|
||||
return path.resolve()
|
||||
# Do not resolve(): symlinks to bundle weights must keep writable parent dirs.
|
||||
return path
|
||||
return (pack_root / path).resolve()
|
||||
|
||||
|
||||
|
||||
@@ -4,12 +4,18 @@ from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from app.algo_host.bundle import load_reference_default_config, resolve_reference_bundle_dir
|
||||
from app.algo_host.bundle import (
|
||||
load_reference_default_config,
|
||||
resolve_bundle_relative_path,
|
||||
resolve_reference_bundle_dir,
|
||||
)
|
||||
from app.consumable_catalog import build_name_mapping
|
||||
|
||||
|
||||
@@ -51,6 +57,37 @@ def write_reference_whitelist_json(path: Path, *, candidate_consumables: list[st
|
||||
)
|
||||
|
||||
|
||||
def stage_actionformer_checkpoint(*, bundle_dir: Path, work_dir: Path) -> Path:
|
||||
"""Place ActionFormer ckpt under writable work_dir so eval can write eval_results.pkl."""
|
||||
|
||||
cfg = load_reference_default_config(bundle_dir)
|
||||
weights = cfg.get("weights") if isinstance(cfg.get("weights"), dict) else {}
|
||||
raw = str((weights or {}).get("actionformer") or "").strip()
|
||||
if not raw:
|
||||
raise ValueError("reference bundle missing weights.actionformer")
|
||||
source = resolve_bundle_relative_path(bundle_dir, raw)
|
||||
if not source.is_file():
|
||||
raise FileNotFoundError(f"ActionFormer checkpoint not found: {source}")
|
||||
|
||||
staged = work_dir / source.name
|
||||
if staged.exists():
|
||||
if staged.is_symlink() and os.path.samefile(staged, source):
|
||||
return staged
|
||||
if staged.is_file() and not staged.is_symlink():
|
||||
if staged.stat().st_size == source.stat().st_size:
|
||||
return staged
|
||||
staged.unlink()
|
||||
try:
|
||||
staged.symlink_to(source, target_is_directory=False)
|
||||
except OSError:
|
||||
tmp = staged.with_suffix(staged.suffix + ".part")
|
||||
if tmp.exists():
|
||||
tmp.unlink()
|
||||
shutil.copy2(source, tmp)
|
||||
tmp.replace(staged)
|
||||
return staged
|
||||
|
||||
|
||||
def build_job_config(
|
||||
*,
|
||||
bundle_dir: Path,
|
||||
@@ -59,6 +96,7 @@ def build_job_config(
|
||||
work_dir: Path,
|
||||
excel_path: Path,
|
||||
whitelist_path: Path,
|
||||
actionformer_ckpt: Path | None = None,
|
||||
) -> dict:
|
||||
cfg = copy.deepcopy(load_reference_default_config(bundle_dir))
|
||||
cfg["io"]["video"] = str(video_path.resolve())
|
||||
@@ -67,6 +105,8 @@ def build_job_config(
|
||||
cfg["io"]["whitelist_json"] = str(whitelist_path.resolve())
|
||||
cfg["runtime"]["work_dir"] = str(work_dir.resolve())
|
||||
cfg["runtime"]["keep_work_dir"] = False
|
||||
if actionformer_ckpt is not None:
|
||||
cfg["weights"]["actionformer"] = str(actionformer_ckpt)
|
||||
return cfg
|
||||
|
||||
|
||||
@@ -92,6 +132,7 @@ def prepare_batch_job(
|
||||
|
||||
write_reference_catalog_excel(excel_path, candidate_consumables=candidate_consumables)
|
||||
write_reference_whitelist_json(whitelist_path, candidate_consumables=candidate_consumables)
|
||||
staged_ckpt = stage_actionformer_checkpoint(bundle_dir=root, work_dir=cache_work_dir)
|
||||
config = build_job_config(
|
||||
bundle_dir=root,
|
||||
video_path=pipeline_video,
|
||||
@@ -99,6 +140,7 @@ def prepare_batch_job(
|
||||
work_dir=cache_work_dir.resolve(),
|
||||
excel_path=excel_path.resolve(),
|
||||
whitelist_path=whitelist_path.resolve(),
|
||||
actionformer_ckpt=staged_ckpt,
|
||||
)
|
||||
config_path.write_text(
|
||||
yaml.safe_dump(config, allow_unicode=True, sort_keys=False),
|
||||
|
||||
@@ -10,6 +10,7 @@ from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from app.algorithm_runner.reference_bundle_runtime import verify_reference_nms_for_subprocess
|
||||
from app.algo_host.bundle import load_reference_default_config, resolve_bundle_relative_path
|
||||
from app.algo_host.cjk_font import resolve_cjk_font_path
|
||||
from app.algo_host.transcode import VISUALIZATION_MAX_WIDTH
|
||||
@@ -19,12 +20,53 @@ def build_reference_env() -> dict[str, str]:
|
||||
env = os.environ.copy()
|
||||
env["PYTHONFAULTHANDLER"] = "1"
|
||||
env["PYTHONUNBUFFERED"] = "1"
|
||||
env.setdefault("PYTHONDONTWRITEBYTECODE", "1")
|
||||
# Headless containers: avoid OpenGL/EGL init noise from CV/MediaPipe defaults.
|
||||
env.setdefault("OPENCV_OPENCL_RUNTIME", "")
|
||||
env.setdefault("QT_QPA_PLATFORM", "offscreen")
|
||||
# Non-root compose user lacks passwd entries; keep caches and downloads under /tmp.
|
||||
env.setdefault("HOME", "/tmp")
|
||||
env.setdefault("USER", "app")
|
||||
env.setdefault("LOGNAME", "app")
|
||||
env.setdefault("TORCH_HOME", "/app/.cache/torch")
|
||||
env.setdefault("TORCHINDUCTOR_CACHE_DIR", "/tmp/torchinductor-cache")
|
||||
env.setdefault("YOLO_CONFIG_DIR", "/app/.cache/ultralytics")
|
||||
# Reduce fragile torch/CUDA init paths in short-lived batch subprocesses.
|
||||
env.setdefault("TORCHDYNAMO_DISABLE", "1")
|
||||
env.setdefault("TORCH_COMPILE_DISABLE", "1")
|
||||
env.setdefault("CUDA_MODULE_LOADING", "LAZY")
|
||||
env.setdefault("OMP_NUM_THREADS", "1")
|
||||
env.setdefault("MKL_NUM_THREADS", "1")
|
||||
env.setdefault("OPENBLAS_NUM_THREADS", "1")
|
||||
env.setdefault("MALLOC_ARENA_MAX", "2")
|
||||
return env
|
||||
|
||||
|
||||
def preflight_reference_runtime(*, python: str, env: dict[str, str]) -> None:
|
||||
"""Fail fast when the venv cannot import core native deps (ld.so/torch issues)."""
|
||||
|
||||
probe = [
|
||||
python,
|
||||
"-c",
|
||||
"import numpy; import torch; print('reference_runtime_ok')",
|
||||
]
|
||||
proc = subprocess.run(
|
||||
probe,
|
||||
env=env,
|
||||
cwd="/tmp",
|
||||
check=False,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
timeout=120,
|
||||
)
|
||||
if proc.returncode == 0:
|
||||
return
|
||||
stderr = (proc.stderr or "").strip()
|
||||
stdout = (proc.stdout or "").strip()
|
||||
detail = stderr or stdout or f"exit={proc.returncode}"
|
||||
raise RuntimeError(f"reference runtime preflight failed: {detail}")
|
||||
|
||||
|
||||
def build_batch_main_command(*, bundle_dir: Path, config_path: Path) -> list[str]:
|
||||
# Use the image venv interpreter directly. ``uv run`` would try to update /app/uv.lock,
|
||||
# which is root-owned in the image and fails under compose ``user: HOST_UID``.
|
||||
@@ -133,13 +175,15 @@ def run_subprocess(
|
||||
output_path: Path,
|
||||
log_label: str,
|
||||
) -> None:
|
||||
env = build_reference_env()
|
||||
preflight_reference_runtime(python=cmd[0], env=env)
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
cwd=str(cwd),
|
||||
check=False,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
env=build_reference_env(),
|
||||
env=env,
|
||||
)
|
||||
if proc.returncode != 0:
|
||||
msg = format_batch_failure(
|
||||
@@ -154,6 +198,7 @@ def run_subprocess(
|
||||
|
||||
|
||||
def run_batch_main(*, bundle_dir: Path, config_path: Path, work_dir: Path, output_path: Path) -> None:
|
||||
verify_reference_nms_for_subprocess(bundle_dir)
|
||||
cmd = build_batch_main_command(bundle_dir=bundle_dir, config_path=config_path)
|
||||
logger.info("reference batch starting: {}", " ".join(cmd))
|
||||
run_subprocess(
|
||||
@@ -172,6 +217,7 @@ def run_visualization_script(
|
||||
result_path: Path,
|
||||
raw_output_video_path: Path,
|
||||
) -> None:
|
||||
verify_reference_nms_for_subprocess(bundle_dir)
|
||||
cmd = build_visualization_command(
|
||||
bundle_dir=bundle_dir,
|
||||
video_path=video_path,
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
from app.algo_host.bundle import (
|
||||
@@ -26,8 +27,49 @@ __all__ = [
|
||||
"load_reference_default_config",
|
||||
"reference_weight_path",
|
||||
"resolve_reference_bundle_dir",
|
||||
"verify_reference_nms_for_subprocess",
|
||||
]
|
||||
|
||||
_NMS_OVERRIDE_MODULE = "_orm_reference_actionformer_nms"
|
||||
|
||||
|
||||
def _reference_nms_source() -> Path:
|
||||
source = REPO_ROOT / "app" / "algorithm_runner" / "actionformer_release" / "libs" / "utils" / "nms.py"
|
||||
if not source.is_file():
|
||||
raise FileNotFoundError(f"vendored ActionFormer nms.py not found: {source}")
|
||||
return source
|
||||
|
||||
|
||||
def _reference_nms_target(bundle_root: Path) -> Path | None:
|
||||
targets = [
|
||||
bundle_root / "code" / "actionformer_release" / "libs" / "utils" / "nms.py",
|
||||
bundle_root / "actionformer_release" / "libs" / "utils" / "nms.py",
|
||||
]
|
||||
return next((p for p in targets if p.is_file()), None)
|
||||
|
||||
|
||||
def _load_nms_override_module(source: Path) -> types.ModuleType:
|
||||
spec = importlib.util.spec_from_file_location(_NMS_OVERRIDE_MODULE, source)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"cannot load ActionFormer NMS override from {source}")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[_NMS_OVERRIDE_MODULE] = module
|
||||
spec.loader.exec_module(module)
|
||||
if not hasattr(module, "batched_nms"):
|
||||
raise ImportError(f"ActionFormer NMS override missing batched_nms: {source}")
|
||||
return module
|
||||
|
||||
|
||||
def _inject_nms_override(source: Path) -> None:
|
||||
"""Register backend pure-PyTorch NMS before ActionFormer imports libs.utils.nms."""
|
||||
|
||||
module = _load_nms_override_module(source)
|
||||
sys.modules["libs.utils.nms"] = module
|
||||
utils_pkg = sys.modules.get("libs.utils")
|
||||
if utils_pkg is not None:
|
||||
utils_pkg.nms = module
|
||||
utils_pkg.batched_nms = module.batched_nms
|
||||
|
||||
|
||||
def ensure_reference_bundle_on_path(bundle_dir: Path | None = None) -> Path:
|
||||
"""Make the configured reference bundle importable in the current process."""
|
||||
@@ -51,25 +93,36 @@ def ensure_reference_bundle_on_path(bundle_dir: Path | None = None) -> Path:
|
||||
|
||||
|
||||
def ensure_reference_nms_patch(bundle_dir: Path | None = None) -> bool:
|
||||
"""Patch the reference ActionFormer to use the backend's pure-PyTorch NMS (realtime only)."""
|
||||
"""Ensure ActionFormer uses backend pure-PyTorch NMS without writing the bundle."""
|
||||
|
||||
root = resolve_reference_bundle_dir(bundle_dir)
|
||||
source = REPO_ROOT / "app" / "algorithm_runner" / "actionformer_release" / "libs" / "utils" / "nms.py"
|
||||
targets = [
|
||||
root / "code" / "actionformer_release" / "libs" / "utils" / "nms.py",
|
||||
root / "actionformer_release" / "libs" / "utils" / "nms.py",
|
||||
]
|
||||
target = next((p for p in targets if p.is_file()), None)
|
||||
source = _reference_nms_source()
|
||||
target = _reference_nms_target(root)
|
||||
if target is None:
|
||||
return False
|
||||
if not source.is_file():
|
||||
raise FileNotFoundError(f"vendored ActionFormer nms.py not found: {source}")
|
||||
if target.read_bytes() == source.read_bytes():
|
||||
return False
|
||||
shutil.copy2(source, target)
|
||||
_inject_nms_override(source)
|
||||
return True
|
||||
|
||||
|
||||
def verify_reference_nms_for_subprocess(bundle_dir: Path | None = None) -> None:
|
||||
"""Fail fast before spawning bundle subprocesses that cannot receive in-memory NMS overrides."""
|
||||
|
||||
root = resolve_reference_bundle_dir(bundle_dir)
|
||||
source = _reference_nms_source()
|
||||
target = _reference_nms_target(root)
|
||||
if target is None:
|
||||
return
|
||||
if target.read_bytes() == source.read_bytes():
|
||||
return
|
||||
raise RuntimeError(
|
||||
"ActionFormer NMS patch is missing in the read-only reference bundle. "
|
||||
f"Expected {target} to match backend override {source}. "
|
||||
"Rebuild the api image (Dockerfile bakes this file) or update the bundle copy."
|
||||
)
|
||||
|
||||
|
||||
def reference_weight_path(key: str, bundle_dir: Path | None = None) -> Path:
|
||||
cfg = load_reference_default_config(bundle_dir)
|
||||
raw = (((cfg.get("weights") or {}) if isinstance(cfg, dict) else {}).get(key) or "").strip()
|
||||
|
||||
@@ -86,7 +86,10 @@ services:
|
||||
HOME: /tmp
|
||||
XDG_CACHE_HOME: /tmp
|
||||
UV_CACHE_DIR: /tmp/uv-cache
|
||||
TORCH_HOME: /tmp/torch-cache
|
||||
TORCH_HOME: /app/.cache/torch
|
||||
YOLO_CONFIG_DIR: /app/.cache/ultralytics
|
||||
# Numeric UID in compose has no passwd entry; PyTorch inductor cache must not call getpass.getuser().
|
||||
TORCHINDUCTOR_CACHE_DIR: /tmp/torchinductor-cache
|
||||
POSTGRES_USER: ${POSTGRES_USER:-postgres}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-postgres}
|
||||
POSTGRES_DB: ${POSTGRES_DB:-operation_room}
|
||||
@@ -138,7 +141,7 @@ services:
|
||||
DEMO_HLS_PREVIEW_CONTAINER_NAME: ${DEMO_HLS_PREVIEW_CONTAINER_NAME:-orm-mediamtx-hls}
|
||||
MEDIAMTX_DOCKER_IMAGE: ${MEDIAMTX_DOCKER_IMAGE:-m.daocloud.io/docker.io/bluenviron/mediamtx:latest}
|
||||
command: >
|
||||
sh -c "mkdir -p /tmp/uv-cache /tmp/torch-cache &&
|
||||
sh -c "mkdir -p /tmp/uv-cache /tmp/torchinductor-cache &&
|
||||
uv run --no-sync alembic upgrade head &&
|
||||
uv run --no-sync uvicorn main:app --host 0.0.0.0 --port 8000"
|
||||
ports:
|
||||
|
||||
150
backend/scripts/bake_torch_hub_checkpoint.py
Normal file
150
backend/scripts/bake_torch_hub_checkpoint.py
Normal file
@@ -0,0 +1,150 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Bake torchvision hub checkpoints into TORCH_HOME during Docker image build."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
CHECKPOINT = "swin3d_t-7615ae03.pth"
|
||||
OFFICIAL_URL = f"https://download.pytorch.org/models/{CHECKPOINT}"
|
||||
AGENTSMIRROR_URL = f"https://uv.agentsmirror.com/download.pytorch.org/models/{CHECKPOINT}"
|
||||
MIN_BYTES = 100_000_000
|
||||
|
||||
|
||||
def _candidate_urls() -> list[str]:
|
||||
explicit = (os.environ.get("PYTORCH_MODELS_URL") or "").strip()
|
||||
if explicit:
|
||||
return [explicit]
|
||||
|
||||
raw = (os.environ.get("PYTORCH_MODELS_MIRROR") or "").strip().rstrip("/")
|
||||
prefixes = [p for p in (raw, "https://uv.agentsmirror.com/download.pytorch.org") if p]
|
||||
urls: list[str] = []
|
||||
for prefix in prefixes:
|
||||
urls.append(f"{prefix}/models/{CHECKPOINT}")
|
||||
urls.extend([AGENTSMIRROR_URL, OFFICIAL_URL])
|
||||
|
||||
seen: set[str] = set()
|
||||
ordered: list[str] = []
|
||||
for url in urls:
|
||||
if url in seen:
|
||||
continue
|
||||
seen.add(url)
|
||||
ordered.append(url)
|
||||
return ordered
|
||||
|
||||
|
||||
def _copy_local(src: Path, dest: Path) -> None:
|
||||
size = src.stat().st_size
|
||||
if size < MIN_BYTES:
|
||||
raise OSError(f"local file too small ({size} bytes): {src}")
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src, dest)
|
||||
|
||||
|
||||
def _download_with_curl(url: str, dest: Path) -> None:
|
||||
curl = shutil.which("curl")
|
||||
if curl is None:
|
||||
raise RuntimeError("curl not found")
|
||||
tmp = dest.with_suffix(dest.suffix + ".part")
|
||||
tmp.unlink(missing_ok=True)
|
||||
proc = subprocess.run(
|
||||
[
|
||||
curl,
|
||||
"-fL",
|
||||
"--retry",
|
||||
"5",
|
||||
"--retry-all-errors",
|
||||
"--retry-delay",
|
||||
"3",
|
||||
"--connect-timeout",
|
||||
"30",
|
||||
"--max-time",
|
||||
"1800",
|
||||
"-o",
|
||||
str(tmp),
|
||||
url,
|
||||
],
|
||||
check=False,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if proc.returncode != 0:
|
||||
tmp.unlink(missing_ok=True)
|
||||
detail = (proc.stderr or proc.stdout or "").strip() or f"exit={proc.returncode}"
|
||||
raise RuntimeError(detail)
|
||||
size = tmp.stat().st_size
|
||||
if size < MIN_BYTES:
|
||||
tmp.unlink(missing_ok=True)
|
||||
raise OSError(f"download too small ({size} bytes)")
|
||||
tmp.replace(dest)
|
||||
|
||||
|
||||
def _download_with_torch(url: str, dest: Path) -> None:
|
||||
import torch
|
||||
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.hub.download_url_to_file(url, str(dest), progress=True)
|
||||
size = dest.stat().st_size
|
||||
if size < MIN_BYTES:
|
||||
dest.unlink(missing_ok=True)
|
||||
raise OSError(f"download too small ({size} bytes)")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
torch_home = Path(os.environ.get("TORCH_HOME", "/app/.cache/torch"))
|
||||
dest = torch_home / "hub" / "checkpoints" / CHECKPOINT
|
||||
if dest.is_file() and dest.stat().st_size >= MIN_BYTES:
|
||||
print(f"already baked: {dest} ({dest.stat().st_size} bytes)")
|
||||
return 0
|
||||
|
||||
local_raw = (os.environ.get("PYTORCH_MODELS_LOCAL_PATH") or "").strip()
|
||||
if local_raw:
|
||||
local = Path(local_raw)
|
||||
if local.is_file():
|
||||
try:
|
||||
_copy_local(local, dest)
|
||||
print(f"baked {dest} ({dest.stat().st_size} bytes) from local {local}")
|
||||
return 0
|
||||
except OSError as exc:
|
||||
print(f"local copy failed: {exc}", file=sys.stderr)
|
||||
|
||||
errors: list[str] = []
|
||||
for url in _candidate_urls():
|
||||
try:
|
||||
print(f"downloading {url}")
|
||||
_download_with_curl(url, dest)
|
||||
print(f"baked {dest} ({dest.stat().st_size} bytes) from {url}")
|
||||
return 0
|
||||
except (OSError, RuntimeError) as exc:
|
||||
errors.append(f"curl {url}: {exc}")
|
||||
dest.unlink(missing_ok=True)
|
||||
|
||||
for url in (OFFICIAL_URL,):
|
||||
try:
|
||||
print(f"torch.hub fallback: {url}")
|
||||
_download_with_torch(url, dest)
|
||||
print(f"baked {dest} ({dest.stat().st_size} bytes) from {url}")
|
||||
return 0
|
||||
except (OSError, RuntimeError) as exc:
|
||||
errors.append(f"torch {url}: {exc}")
|
||||
dest.unlink(missing_ok=True)
|
||||
|
||||
print("failed to bake VideoSwin checkpoint:", file=sys.stderr)
|
||||
for line in errors:
|
||||
print(f" - {line}", file=sys.stderr)
|
||||
print(
|
||||
"hint: domestic PyPI mirrors (NJU/Tsinghua/Aliyun) do not sync /models/*.pth; "
|
||||
"pre-download once and set PYTORCH_MODELS_LOCAL_PATH, or pass "
|
||||
"PYTORCH_MODELS_URL / --build-arg PYTORCH_MODELS_MIRROR to a mirror that hosts "
|
||||
f"/models/{CHECKPOINT}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -6,6 +6,11 @@ from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
REFERENCE_NMS_SOURCE = (
|
||||
REPO_ROOT / "app" / "algorithm_runner" / "actionformer_release" / "libs" / "utils" / "nms.py"
|
||||
)
|
||||
|
||||
|
||||
def complete_result_tsv_body() -> str:
|
||||
return (
|
||||
@@ -20,12 +25,21 @@ def write_minimal_reference_bundle(bundle: Path) -> None:
|
||||
(bundle / "main.py").write_text("# fake\n", encoding="utf-8")
|
||||
(bundle / "code").mkdir()
|
||||
(bundle / "code" / "repo_root.py").write_text("# fake\n", encoding="utf-8")
|
||||
nms_target = bundle / "code" / "actionformer_release" / "libs" / "utils"
|
||||
nms_target.mkdir(parents=True, exist_ok=True)
|
||||
if REFERENCE_NMS_SOURCE.is_file():
|
||||
(nms_target / "nms.py").write_bytes(REFERENCE_NMS_SOURCE.read_bytes())
|
||||
else:
|
||||
(nms_target / "nms.py").write_text("# fake nms\n", encoding="utf-8")
|
||||
weights_dir = bundle / "weights"
|
||||
weights_dir.mkdir()
|
||||
(weights_dir / "actionformer_epoch_045.pth.tar").write_bytes(b"fake-ckpt")
|
||||
(bundle / "configs").mkdir()
|
||||
(bundle / "configs" / "default_config.yaml").write_text(
|
||||
yaml.safe_dump(
|
||||
{
|
||||
"io": {"video": "", "excel": "", "out": "", "whitelist_json": None},
|
||||
"weights": {},
|
||||
"weights": {"actionformer": "weights/actionformer_epoch_045.pth.tar"},
|
||||
"runtime": {"work_dir": None, "keep_work_dir": False, "python": None},
|
||||
"device": {},
|
||||
"phase1": {},
|
||||
|
||||
@@ -16,7 +16,7 @@ from fastapi.testclient import TestClient
|
||||
|
||||
from app.algo_host import bundle as bundle_runtime
|
||||
from app.algo_host.batch_service import BatchAlgorithmService, BatchRunResult
|
||||
from app.algo_host.job_workspace import build_job_config
|
||||
from app.algo_host.job_workspace import build_job_config, stage_actionformer_checkpoint
|
||||
from app.algo_host.result_adapter import (
|
||||
doctor_id_for_consumption_rows,
|
||||
is_reference_result_complete,
|
||||
@@ -25,6 +25,7 @@ from app.algo_host.result_adapter import (
|
||||
)
|
||||
from app.algo_host.subprocess_runner import (
|
||||
build_batch_main_command,
|
||||
build_reference_env,
|
||||
build_visualization_command,
|
||||
describe_batch_returncode,
|
||||
format_batch_failure,
|
||||
@@ -46,6 +47,10 @@ from app.services.video_batch_cleanup import VISUALIZATION_FILENAME, visualizati
|
||||
from tests.reference_bundle_fixtures import complete_result_tsv_body, write_minimal_reference_bundle
|
||||
|
||||
|
||||
def _is_runtime_preflight_cmd(cmd: list[str]) -> bool:
|
||||
return len(cmd) >= 3 and cmd[1] == "-c" and "import numpy" in cmd[2]
|
||||
|
||||
|
||||
def test_build_job_config_does_not_keep_work_dir(tmp_path: Path) -> None:
|
||||
bundle = tmp_path / "bundle"
|
||||
write_minimal_reference_bundle(bundle)
|
||||
@@ -60,6 +65,56 @@ def test_build_job_config_does_not_keep_work_dir(tmp_path: Path) -> None:
|
||||
assert cfg["runtime"]["keep_work_dir"] is False
|
||||
|
||||
|
||||
def test_stage_actionformer_checkpoint_uses_writable_work_dir(tmp_path: Path) -> None:
|
||||
bundle = tmp_path / "bundle"
|
||||
write_minimal_reference_bundle(bundle)
|
||||
ckpt = bundle / "weights" / "actionformer_epoch_045.pth.tar"
|
||||
|
||||
work_dir = tmp_path / "work"
|
||||
work_dir.mkdir()
|
||||
staged = stage_actionformer_checkpoint(bundle_dir=bundle, work_dir=work_dir)
|
||||
assert staged.parent == work_dir
|
||||
assert staged.name == ckpt.name
|
||||
assert staged.is_file()
|
||||
|
||||
job_cfg = build_job_config(
|
||||
bundle_dir=bundle,
|
||||
video_path=tmp_path / "input.mp4",
|
||||
output_path=tmp_path / "out.tsv",
|
||||
work_dir=work_dir,
|
||||
excel_path=tmp_path / "catalog.xlsx",
|
||||
whitelist_path=tmp_path / "whitelist.json",
|
||||
actionformer_ckpt=staged,
|
||||
)
|
||||
assert job_cfg["weights"]["actionformer"] == str(staged)
|
||||
|
||||
|
||||
def test_stage_actionformer_checkpoint_survives_bundle_config_resolve(tmp_path: Path) -> None:
|
||||
bundle = tmp_path / "bundle"
|
||||
write_minimal_reference_bundle(bundle)
|
||||
work_dir = tmp_path / "work"
|
||||
work_dir.mkdir()
|
||||
source = bundle / "weights" / "actionformer_epoch_045.pth.tar"
|
||||
staged = work_dir / source.name
|
||||
staged.symlink_to(source)
|
||||
|
||||
# Mirrors algorithm_subprocesses/5.15/src/config.py::_rel for absolute paths.
|
||||
loaded = Path(str(staged))
|
||||
assert loaded.parent == work_dir
|
||||
assert loaded.name == source.name
|
||||
|
||||
|
||||
def test_build_reference_env_sets_container_safe_defaults(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.delenv("TORCHINDUCTOR_CACHE_DIR", raising=False)
|
||||
monkeypatch.delenv("USER", raising=False)
|
||||
env = build_reference_env()
|
||||
assert env["TORCHINDUCTOR_CACHE_DIR"] == "/tmp/torchinductor-cache"
|
||||
assert env["USER"] == "app"
|
||||
assert env["PYTHONDONTWRITEBYTECODE"] == "1"
|
||||
assert env["TORCHDYNAMO_DISABLE"] == "1"
|
||||
assert env["CUDA_MODULE_LOADING"] == "LAZY"
|
||||
|
||||
|
||||
def test_latest_visualization_path_uses_vis_directory(tmp_path: Path) -> None:
|
||||
root = tmp_path / "batch"
|
||||
runner = BatchAlgorithmService(root_dir=root)
|
||||
@@ -276,7 +331,6 @@ def test_build_visualization_command_uses_hand_model_and_result_tsv(
|
||||
) -> None:
|
||||
bundle = tmp_path / "bundle"
|
||||
write_minimal_reference_bundle(bundle)
|
||||
(bundle / "weights").mkdir()
|
||||
(bundle / "weights" / "hand_detect.pt").write_bytes(b"fake")
|
||||
(bundle / "visualize_result_video.py").write_text("# fake\n", encoding="utf-8")
|
||||
cfg_path = bundle / "configs" / "default_config.yaml"
|
||||
@@ -306,7 +360,6 @@ def test_build_visualization_command_passes_font_path_when_available(
|
||||
) -> None:
|
||||
bundle = tmp_path / "bundle"
|
||||
write_minimal_reference_bundle(bundle)
|
||||
(bundle / "weights").mkdir()
|
||||
(bundle / "weights" / "hand_detect.pt").write_bytes(b"fake")
|
||||
(bundle / "visualize_result_video.py").write_text("# fake\n", encoding="utf-8")
|
||||
font = tmp_path / "NotoSansCJK-Regular.ttc"
|
||||
@@ -354,6 +407,8 @@ def test_batch_service_respects_reference_bundle_relative_env(
|
||||
stderr = ""
|
||||
|
||||
def fake_run(cmd: list[str], **_kwargs: Any) -> _Proc:
|
||||
if _is_runtime_preflight_cmd(cmd):
|
||||
return _Proc()
|
||||
calls.append(cmd)
|
||||
config = yaml.safe_load(Path(cmd[cmd.index("--config") + 1]).read_text(encoding="utf-8"))
|
||||
output = Path(config["io"]["out"])
|
||||
@@ -374,7 +429,7 @@ def test_batch_service_respects_reference_bundle_relative_env(
|
||||
)
|
||||
|
||||
assert runner.bundle_dir == bundle.resolve()
|
||||
assert calls[0][5] == str(bundle.resolve() / "main.py")
|
||||
assert calls[0][3] == str(bundle.resolve() / "main.py")
|
||||
assert result.details[0].item_name == "耗材1"
|
||||
config = yaml.safe_load(Path(calls[0][calls[0].index("--config") + 1]).read_text(encoding="utf-8"))
|
||||
assert Path(config["io"]["video"]).name == "pipeline.mp4"
|
||||
@@ -387,61 +442,7 @@ def test_batch_service_reuses_cache_on_repeat_run(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
bundle = tmp_path / "bundle"
|
||||
bundle.mkdir()
|
||||
(bundle / "main.py").write_text("# fake\n", encoding="utf-8")
|
||||
(bundle / "code").mkdir()
|
||||
(bundle / "code" / "repo_root.py").write_text("# fake\n", encoding="utf-8")
|
||||
cfg_dir = bundle / "configs"
|
||||
cfg_dir.mkdir()
|
||||
cfg_dir.joinpath("default_config.yaml").write_text(
|
||||
yaml.safe_dump(
|
||||
{
|
||||
"io": {"video": "", "excel": "", "out": "", "whitelist_json": None},
|
||||
"weights": {
|
||||
"actionformer": "weights/actionformer_epoch_045.pth.tar",
|
||||
"hand": "weights/hand_detect.pt",
|
||||
"goodbad": "weights/goodbad_frame.pt",
|
||||
"haocai": "weights/haocai_classify.pt",
|
||||
"tear": "weights/tear_classify.pt",
|
||||
},
|
||||
"runtime": {"work_dir": None, "keep_work_dir": False, "python": None},
|
||||
"device": {"type": "cuda", "half": False},
|
||||
"phase1": {"af_min_score": 0.1, "af_min_seg_seconds": 5, "feat_batch_size": 1},
|
||||
"phase2": {
|
||||
"seek_margin_sec": 3.0,
|
||||
"frame_stride": 1,
|
||||
"det_conf": 0.5,
|
||||
"pad_ratio": 0.3,
|
||||
"imgsz_det": 640,
|
||||
"merge_iou_gt": 0.0,
|
||||
"merge_center_dist_max_px": None,
|
||||
"merge_center_dist_max_frac_diag": None,
|
||||
},
|
||||
"classification": {
|
||||
"imgsz_cls": 224,
|
||||
"good_top1_conf_threshold": 0.9,
|
||||
"good_top1_retry_threshold": 0.8,
|
||||
"haocai_min_conf": 0.8,
|
||||
"haocai_min_conf_retry": 0.7,
|
||||
"empty_cache_every": 0,
|
||||
},
|
||||
"tear_merge": {
|
||||
"merge_adjacent_tear": True,
|
||||
"tear_merge_weights": None,
|
||||
"tear_merge_class": "tearing",
|
||||
"tear_merge_head_sec": 3.0,
|
||||
"tear_merge_prob": 0.9,
|
||||
"tear_merge_min_frames": 6,
|
||||
"tear_merge_verbose": False,
|
||||
"tear_merge_full_frame": False,
|
||||
},
|
||||
"output": {"legacy_12_col_only": True},
|
||||
},
|
||||
allow_unicode=True,
|
||||
sort_keys=False,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
write_minimal_reference_bundle(bundle)
|
||||
video = tmp_path / "case.mp4"
|
||||
video.write_bytes(b"same-video")
|
||||
calls: list[list[str]] = []
|
||||
@@ -452,6 +453,8 @@ def test_batch_service_reuses_cache_on_repeat_run(
|
||||
stderr = ""
|
||||
|
||||
def fake_run(cmd: list[str], **_kwargs: Any) -> _Proc:
|
||||
if _is_runtime_preflight_cmd(cmd):
|
||||
return _Proc()
|
||||
calls.append(cmd)
|
||||
config = yaml.safe_load(Path(cmd[cmd.index("--config") + 1]).read_text(encoding="utf-8"))
|
||||
output = Path(config["io"]["out"])
|
||||
@@ -499,29 +502,7 @@ def test_batch_service_shares_cache_across_surgeries_for_same_video(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
bundle = tmp_path / "bundle"
|
||||
bundle.mkdir()
|
||||
(bundle / "main.py").write_text("# fake\n", encoding="utf-8")
|
||||
(bundle / "code").mkdir()
|
||||
(bundle / "code" / "repo_root.py").write_text("# fake\n", encoding="utf-8")
|
||||
(bundle / "configs").mkdir()
|
||||
(bundle / "configs" / "default_config.yaml").write_text(
|
||||
yaml.safe_dump(
|
||||
{
|
||||
"io": {"video": "", "excel": "", "out": "", "whitelist_json": None},
|
||||
"weights": {},
|
||||
"runtime": {"work_dir": None, "keep_work_dir": False, "python": None},
|
||||
"device": {},
|
||||
"phase1": {},
|
||||
"phase2": {},
|
||||
"classification": {},
|
||||
"tear_merge": {},
|
||||
"output": {},
|
||||
},
|
||||
allow_unicode=True,
|
||||
sort_keys=False,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
write_minimal_reference_bundle(bundle)
|
||||
video = tmp_path / "case.mp4"
|
||||
video.write_bytes(b"same-video")
|
||||
calls: list[list[str]] = []
|
||||
@@ -532,6 +513,8 @@ def test_batch_service_shares_cache_across_surgeries_for_same_video(
|
||||
stderr = ""
|
||||
|
||||
def fake_run(cmd: list[str], **_kwargs: Any) -> _Proc:
|
||||
if _is_runtime_preflight_cmd(cmd):
|
||||
return _Proc()
|
||||
calls.append(cmd)
|
||||
config = yaml.safe_load(Path(cmd[cmd.index("--config") + 1]).read_text(encoding="utf-8"))
|
||||
output = Path(config["io"]["out"])
|
||||
|
||||
28
backend/tests/test_bake_torch_hub_checkpoint.py
Normal file
28
backend/tests/test_bake_torch_hub_checkpoint.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _load_module():
|
||||
path = Path(__file__).resolve().parents[1] / "scripts" / "bake_torch_hub_checkpoint.py"
|
||||
spec = importlib.util.spec_from_file_location("bake_torch_hub_checkpoint", path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
def test_candidate_urls_default(monkeypatch):
|
||||
mod = _load_module()
|
||||
monkeypatch.delenv("PYTORCH_MODELS_URL", raising=False)
|
||||
monkeypatch.delenv("PYTORCH_MODELS_MIRROR", raising=False)
|
||||
urls = mod._candidate_urls()
|
||||
assert urls[0].startswith("https://uv.agentsmirror.com/download.pytorch.org/models/")
|
||||
assert urls[-1] == mod.OFFICIAL_URL
|
||||
|
||||
|
||||
def test_candidate_urls_explicit_override(monkeypatch):
|
||||
mod = _load_module()
|
||||
monkeypatch.setenv("PYTORCH_MODELS_URL", "https://example.com/swin3d_t-7615ae03.pth")
|
||||
assert mod._candidate_urls() == ["https://example.com/swin3d_t-7615ae03.pth"]
|
||||
@@ -35,6 +35,8 @@ def _fake_reference_subprocess_run(captured: list[dict[str, Any]]):
|
||||
stderr = ""
|
||||
|
||||
def _run(cmd: list[str], **kwargs: Any) -> _Proc:
|
||||
if len(cmd) >= 3 and cmd[1] == "-c" and "import numpy" in cmd[2]:
|
||||
return _Proc()
|
||||
config_path = Path(cmd[cmd.index("--config") + 1])
|
||||
config = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||||
captured.append(
|
||||
|
||||
54
backend/tests/test_reference_bundle_runtime.py
Normal file
54
backend/tests/test_reference_bundle_runtime.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Tests for reference bundle runtime helpers (NMS patch without writing bundle)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from app.algorithm_runner.reference_bundle_runtime import (
|
||||
_reference_nms_source,
|
||||
ensure_reference_nms_patch,
|
||||
verify_reference_nms_for_subprocess,
|
||||
)
|
||||
from tests.reference_bundle_fixtures import write_minimal_reference_bundle
|
||||
|
||||
|
||||
def _write_bundle_with_nms(bundle: Path, *, nms_bytes: bytes) -> Path:
|
||||
write_minimal_reference_bundle(bundle)
|
||||
target = bundle / "code" / "actionformer_release" / "libs" / "utils"
|
||||
target.mkdir(parents=True, exist_ok=True)
|
||||
nms_path = target / "nms.py"
|
||||
nms_path.write_bytes(nms_bytes)
|
||||
return nms_path
|
||||
|
||||
|
||||
def test_ensure_reference_nms_patch_injects_without_writing_bundle(tmp_path: Path) -> None:
|
||||
bundle = tmp_path / "bundle"
|
||||
nms_path = _write_bundle_with_nms(bundle, nms_bytes=b"# stale vendor nms\n")
|
||||
|
||||
source = _reference_nms_source()
|
||||
patched = ensure_reference_nms_patch(bundle)
|
||||
assert patched is True
|
||||
assert nms_path.read_text(encoding="utf-8") == "# stale vendor nms\n"
|
||||
|
||||
injected = sys.modules["libs.utils.nms"]
|
||||
assert injected.batched_nms is not None
|
||||
assert injected.__file__ == str(source)
|
||||
|
||||
|
||||
def test_verify_reference_nms_for_subprocess_raises_when_bundle_stale(tmp_path: Path) -> None:
|
||||
bundle = tmp_path / "bundle"
|
||||
_write_bundle_with_nms(bundle, nms_bytes=b"# stale vendor nms\n")
|
||||
|
||||
with pytest.raises(RuntimeError, match="NMS patch is missing"):
|
||||
verify_reference_nms_for_subprocess(bundle)
|
||||
|
||||
|
||||
def test_verify_reference_nms_for_subprocess_passes_when_baked(tmp_path: Path) -> None:
|
||||
bundle = tmp_path / "bundle"
|
||||
source = _reference_nms_source()
|
||||
_write_bundle_with_nms(bundle, nms_bytes=source.read_bytes())
|
||||
|
||||
verify_reference_nms_for_subprocess(bundle)
|
||||
0
backend/weights/.gitkeep
Normal file
0
backend/weights/.gitkeep
Normal file
BIN
backend/weights/swin3d_t-7615ae03.pth
Normal file
BIN
backend/weights/swin3d_t-7615ae03.pth
Normal file
Binary file not shown.
Reference in New Issue
Block a user