Files
operating-room-monitor-server/backend/tests/test_algo_host_batch.py
2026-05-22 09:35:41 +08:00

691 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Tests for offline batch orchestration (app.algo_host)."""
from __future__ import annotations
import json
import shutil
import subprocess
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import pytest
import yaml
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.algo_host import bundle as bundle_runtime
from app.algo_host.batch_service import BatchAlgorithmService, BatchRunResult
from app.algo_host.job_workspace import build_job_config
from app.algo_host.result_adapter import (
doctor_id_for_consumption_rows,
is_reference_result_complete,
parse_reference_doctor_info,
parse_reference_tsv,
)
from app.algo_host.subprocess_runner import (
build_batch_main_command,
build_visualization_command,
describe_batch_returncode,
format_batch_failure,
)
from app.algo_host.transcode import (
VISUALIZATION_MAX_WIDTH,
batch_input_needs_normalize,
browser_transcode_tmp_path,
ensure_batch_pipeline_input_video,
is_browser_compatible_mp4,
transcode_visualization_for_browser,
)
from app.api import router as api_router
from app.dependencies import get_surgery_pipeline
from app.domain.consumption import SurgeryConsumptionStored
from app.routers import recording_demo
from app.schemas import SurgeryConsumptionDetail
from app.services.video_batch_cleanup import VISUALIZATION_FILENAME, visualization_output_path
from tests.reference_bundle_fixtures import complete_result_tsv_body, write_minimal_reference_bundle
def test_build_job_config_does_not_keep_work_dir(tmp_path: Path) -> None:
    """The generated job config must have ``runtime.keep_work_dir`` set to False."""
    bundle_root = tmp_path / "bundle"
    write_minimal_reference_bundle(bundle_root)

    # All job paths live under the per-test temp directory.
    job_paths = {
        "bundle_dir": bundle_root,
        "video_path": tmp_path / "input.mp4",
        "output_path": tmp_path / "out.tsv",
        "work_dir": tmp_path / "work",
        "excel_path": tmp_path / "catalog.xlsx",
        "whitelist_path": tmp_path / "whitelist.json",
    }
    job_config = build_job_config(**job_paths)

    runtime_section = job_config["runtime"]
    assert runtime_section["keep_work_dir"] is False
def test_latest_visualization_path_uses_vis_directory(tmp_path: Path) -> None:
    """``latest_visualization_path`` is None both before and after placeholder bytes are written."""
    surgery_id = "100001"
    batch_root = tmp_path / "batch"
    service = BatchAlgorithmService(root_dir=batch_root)

    # Nothing has been written for this surgery yet.
    assert service.latest_visualization_path(surgery_id) is None

    # Put placeholder bytes (not a real MP4) at the path computed by
    # ``visualization_output_path``.
    placeholder = visualization_output_path(batch_root, surgery_id)
    placeholder.parent.mkdir(parents=True)
    placeholder.write_bytes(b"not-really-mp4")

    assert service.latest_visualization_path(surgery_id) is None
def test_is_reference_result_complete_requires_footer_and_rows(tmp_path: Path) -> None:
    """The full fixture body is complete; a header plus a single data row is not."""
    header_and_one_row = (
        "rank\tstart_sec\tend_sec\tproduct_id_top1\ttop1_name\ttop1_conf\n"
        "1\t0\t1\tP1\t耗材1\t1.0\n"
    )

    complete_tsv = tmp_path / "complete.tsv"
    complete_tsv.write_text(complete_result_tsv_body(), encoding="utf-8")

    partial_tsv = tmp_path / "partial.tsv"
    partial_tsv.write_text(header_and_one_row, encoding="utf-8")

    complete_verdict = is_reference_result_complete(complete_tsv)
    assert complete_verdict is True
    partial_verdict = is_reference_result_complete(partial_tsv)
    assert partial_verdict is False
def test_parse_reference_doctor_info_name_and_id(tmp_path: Path) -> None:
    """A doctor footer line yields the name, the id and a combined display label."""
    expected_label = "付玉峰 (24503)"
    tsv_lines = [
        "rank\tstart_sec\tend_sec\tproduct_id_top1\ttop1_name\ttop1_conf\n",
        "1\t0\t1\tP1\t耗材1\t1.0\n",
        "医生信息:付玉峰 (id=24503, conf=0.9969)\n",
    ]
    result_file = tmp_path / "result.tsv"
    result_file.write_text("".join(tsv_lines), encoding="utf-8")

    info = parse_reference_doctor_info(result_file)
    assert info is not None
    assert info.doctor_name == "付玉峰"
    assert info.doctor_id == "24503"
    assert info.display == expected_label
    assert doctor_id_for_consumption_rows(info) == expected_label

    # Parsed consumption rows carry the same combined label as their doctor id.
    parsed = parse_reference_tsv(result_file, doctor=info)
    assert len(parsed) == 1
    only_row = parsed[0]
    assert only_row.doctor_id == expected_label
def test_parse_reference_doctor_info_failure_falls_back_to_vision(tmp_path: Path) -> None:
    """A recognition-failure footer gives no name, the ``vision`` id and a failure display."""
    tsv_lines = [
        "rank\tstart_sec\tend_sec\tproduct_id_top1\ttop1_name\ttop1_conf\n",
        "1\t0\t1\tP1\t耗材1\t1.0\n",
        "医生信息识别失败No module named 'mediapipe'\n",
    ]
    result_file = tmp_path / "result.tsv"
    result_file.write_text("".join(tsv_lines), encoding="utf-8")

    info = parse_reference_doctor_info(result_file)

    assert info is not None
    assert info.doctor_name is None
    assert info.doctor_id == "vision"
    assert "识别失败" in info.display
def test_parse_reference_tsv_top1_and_empty_rows(tmp_path: Path) -> None:
    """Only the row with usable values is returned, populated from its top-1 columns."""
    # Twelve-column header: three fixed columns, then id/name/conf for top1..top3.
    columns = ["rank", "start_sec", "end_sec"]
    for n in (1, 2, 3):
        columns += [f"product_id_top{n}", f"top{n}_name", f"top{n}_conf"]
    header_line = "\t".join(columns) + "\n"

    usable_row = "1\t1.5\t3.0\tA/B\t一次性耗材\t1.0000\tC\t备选\t0.2\t\t\t\n"
    # Non-numeric start and empty product columns; the assertions below
    # expect this row to be absent from the parsed result.
    unusable_row = "2\tbad\t4.0\t\t\t\t\t\t\t\t\t\n"

    result_file = tmp_path / "result.tsv"
    result_file.write_text(header_line + usable_row + unusable_row, encoding="utf-8")

    origin = datetime(2026, 5, 8, tzinfo=timezone.utc)
    parsed = parse_reference_tsv(result_file, base_timestamp=origin)

    assert len(parsed) == 1
    row = parsed[0]
    assert row.item_id == "A/B"
    assert row.item_name == "一次性耗材"
    assert row.qty == 1
    assert row.doctor_id == "vision"
    assert row.source == "video_batch"
    # The expected timestamp is the base timestamp plus start_sec (1.5 s).
    assert row.timestamp.isoformat() == "2026-05-08T00:00:01.500000+00:00"
def test_browser_transcode_tmp_path_keeps_mp4_extension() -> None:
    """The temporary transcode path is ``<stem>.part.mp4`` and still ends in ``.mp4``."""
    final_path = Path("/cache/196707/output/result_vis.mp4")

    partial_path = browser_transcode_tmp_path(final_path)

    assert partial_path.name == "result_vis.part.mp4"
    assert partial_path.suffix == ".mp4"
    assert str(partial_path).endswith(".mp4")
@pytest.mark.skipif(shutil.which("ffmpeg") is None, reason="ffmpeg not installed")
def test_ensure_batch_pipeline_input_video_normalizes_non_h264(tmp_path: Path) -> None:
    """An MPEG-4 Part 2 upload is normalized into a browser-compatible pipeline input."""
    ffmpeg = shutil.which("ffmpeg")
    assert ffmpeg is not None
    source = tmp_path / "upload.mp4"
    dest = tmp_path / "input.mp4"

    # Synthesize a half-second test clip encoded with the ``mpeg4`` codec.
    encode_cmd = [ffmpeg, "-y"]
    encode_cmd += ["-f", "lavfi"]
    encode_cmd += ["-i", "testsrc=size=640x360:rate=10"]
    encode_cmd += ["-t", "0.5"]
    encode_cmd += ["-c:v", "mpeg4"]
    encode_cmd += ["-pix_fmt", "yuv420p"]
    encode_cmd.append(str(source))
    encoded = subprocess.run(encode_cmd, check=False, capture_output=True, text=True)
    assert encoded.returncode == 0, encoded.stderr

    assert batch_input_needs_normalize(source)

    ensure_batch_pipeline_input_video(source_path=source, dest_path=dest)

    assert dest.is_file()
    assert is_browser_compatible_mp4(dest)
    assert not batch_input_needs_normalize(dest)
@pytest.mark.skipif(shutil.which("ffmpeg") is None, reason="ffmpeg not installed")
def test_transcode_visualization_for_browser_writes_h264_mp4(tmp_path: Path) -> None:
    """Transcoding leaves a non-empty browser-compatible file and no temporary file."""
    ffmpeg = shutil.which("ffmpeg")
    assert ffmpeg is not None
    source = tmp_path / "result_vis_source.mp4"
    output = tmp_path / "result_vis.mp4"

    # Synthesize a half-second test clip encoded with the ``mpeg4`` codec.
    encode_cmd = [ffmpeg, "-y"]
    encode_cmd += ["-f", "lavfi"]
    encode_cmd += ["-i", "testsrc=size=640x360:rate=10"]
    encode_cmd += ["-t", "0.5"]
    encode_cmd += ["-c:v", "mpeg4"]
    encode_cmd += ["-pix_fmt", "yuv420p"]
    encode_cmd.append(str(source))
    encoded = subprocess.run(encode_cmd, check=False, capture_output=True, text=True)
    assert encoded.returncode == 0, encoded.stderr

    assert transcode_visualization_for_browser(source, output)

    assert output.is_file()
    assert output.stat().st_size > 0
    # The intermediate ``.part.mp4`` file must not be left behind.
    assert not browser_transcode_tmp_path(output).exists()
    assert is_browser_compatible_mp4(output)
def test_build_visualization_command_uses_hand_model_and_result_tsv(
    tmp_path: Path,
) -> None:
    """The visualization command passes the result TSV, hand weights and max width."""
    bundle_root = tmp_path / "bundle"
    write_minimal_reference_bundle(bundle_root)

    # Add fake hand-detector weights and a stub visualization script.
    weights_dir = bundle_root / "weights"
    weights_dir.mkdir()
    hand_weights = weights_dir / "hand_detect.pt"
    hand_weights.write_bytes(b"fake")
    (bundle_root / "visualize_result_video.py").write_text("# fake\n", encoding="utf-8")

    # Point the bundle's default config at the fake hand weights.
    default_config = bundle_root / "configs" / "default_config.yaml"
    settings = yaml.safe_load(default_config.read_text(encoding="utf-8"))
    settings["weights"]["hand"] = "weights/hand_detect.pt"
    default_config.write_text(
        yaml.safe_dump(settings, allow_unicode=True, sort_keys=False),
        encoding="utf-8",
    )

    result_tsv = tmp_path / "result.tsv"
    argv = build_visualization_command(
        bundle_dir=bundle_root,
        video_path=tmp_path / "input.mp4",
        result_path=result_tsv,
        output_video_path=tmp_path / "result_vis.mp4",
    )

    assert "--result-txt" in argv
    assert str(result_tsv) in argv
    assert "--hand-model" in argv
    assert str(hand_weights) in argv
    assert "--det-conf" not in argv
    assert "--max-width" in argv
    width_value = argv[argv.index("--max-width") + 1]
    assert width_value == str(VISUALIZATION_MAX_WIDTH)
def test_build_batch_main_command_uses_5_15_main_py(tmp_path: Path) -> None:
    """The batch command runs the bundle's ``main.py`` via ``uv run python -X faulthandler``.

    The entry-point check compares the trailing path components instead of a
    ``/``-joined string suffix, so the assertion does not depend on the
    platform's path separator.
    """
    config_path = tmp_path / "config.yaml"
    cmd = build_batch_main_command(
        bundle_dir=tmp_path / "algorithm_subprocesses" / "5.15",
        config_path=config_path,
    )

    assert cmd[:3] == ["uv", "run", "python"]
    assert cmd[3:5] == ["-X", "faulthandler"]
    # Separator-agnostic equivalent of endswith("algorithm_subprocesses/5.15/main.py").
    assert Path(cmd[5]).parts[-3:] == ("algorithm_subprocesses", "5.15", "main.py")
    assert cmd[6:] == ["--config", str(config_path)]
def test_batch_service_respects_reference_bundle_relative_env(
    tmp_path: Path,
    monkeypatch,
) -> None:
    """``REFERENCE_BUNDLE_RELATIVE`` selects the bundle directory under the patched repo root."""
    custom_bundle = tmp_path / "algorithm_subprocesses" / "custom"
    write_minimal_reference_bundle(custom_bundle)
    upload = tmp_path / "case.mp4"
    upload.write_bytes(b"same-video")
    recorded: list[list[str]] = []

    class _Proc:
        returncode = 0
        stdout = ""
        stderr = ""

    def fake_run(cmd: list[str], **_kwargs: Any) -> _Proc:
        # Stand-in for subprocess.run: record the command and write a complete
        # result TSV to the output path named in the job config.
        recorded.append(cmd)
        config_file = Path(cmd[cmd.index("--config") + 1])
        job = yaml.safe_load(config_file.read_text(encoding="utf-8"))
        tsv = Path(job["io"]["out"])
        tsv.parent.mkdir(parents=True, exist_ok=True)
        tsv.write_text(complete_result_tsv_body(), encoding="utf-8")
        return _Proc()

    monkeypatch.setenv("REFERENCE_BUNDLE_RELATIVE", "algorithm_subprocesses/custom")
    monkeypatch.setattr(bundle_runtime, "REPO_ROOT", tmp_path)
    monkeypatch.setattr("app.algo_host.subprocess_runner.subprocess.run", fake_run)

    service = BatchAlgorithmService(root_dir=tmp_path / "batch")
    outcome = service.run(
        surgery_id="100001",
        uploaded_video_path=upload,
        original_filename="case.mp4",
        candidate_consumables=["耗材1"],
    )

    resolved_bundle = custom_bundle.resolve()
    assert service.bundle_dir == resolved_bundle
    first_command = recorded[0]
    assert first_command[5] == str(resolved_bundle / "main.py")
    assert outcome.details[0].item_name == "耗材1"
def test_batch_service_reuses_cache_on_repeat_run(
    tmp_path: Path,
    monkeypatch,
) -> None:
    """A second run with the same surgery, video and candidates reuses the cached result."""
    # Hand-built bundle: stub scripts plus a fully populated default config.
    bundle = tmp_path / "bundle"
    bundle.mkdir()
    (bundle / "main.py").write_text("# fake\n", encoding="utf-8")
    code_dir = bundle / "code"
    code_dir.mkdir()
    (code_dir / "repo_root.py").write_text("# fake\n", encoding="utf-8")
    configs_dir = bundle / "configs"
    configs_dir.mkdir()

    weights_section = {
        "actionformer": "weights/actionformer_epoch_045.pth.tar",
        "hand": "weights/hand_detect.pt",
        "goodbad": "weights/goodbad_frame.pt",
        "haocai": "weights/haocai_classify.pt",
        "tear": "weights/tear_classify.pt",
    }
    phase2_section = {
        "seek_margin_sec": 3.0,
        "frame_stride": 1,
        "det_conf": 0.5,
        "pad_ratio": 0.3,
        "imgsz_det": 640,
        "merge_iou_gt": 0.0,
        "merge_center_dist_max_px": None,
        "merge_center_dist_max_frac_diag": None,
    }
    classification_section = {
        "imgsz_cls": 224,
        "good_top1_conf_threshold": 0.9,
        "good_top1_retry_threshold": 0.8,
        "haocai_min_conf": 0.8,
        "haocai_min_conf_retry": 0.7,
        "empty_cache_every": 0,
    }
    tear_merge_section = {
        "merge_adjacent_tear": True,
        "tear_merge_weights": None,
        "tear_merge_class": "tearing",
        "tear_merge_head_sec": 3.0,
        "tear_merge_prob": 0.9,
        "tear_merge_min_frames": 6,
        "tear_merge_verbose": False,
        "tear_merge_full_frame": False,
    }
    # Section order is preserved in the dumped YAML (sort_keys=False).
    default_config = {
        "io": {"video": "", "excel": "", "out": "", "whitelist_json": None},
        "weights": weights_section,
        "runtime": {"work_dir": None, "keep_work_dir": False, "python": None},
        "device": {"type": "cuda", "half": False},
        "phase1": {"af_min_score": 0.1, "af_min_seg_seconds": 5, "feat_batch_size": 1},
        "phase2": phase2_section,
        "classification": classification_section,
        "tear_merge": tear_merge_section,
        "output": {"legacy_12_col_only": True},
    }
    config_yaml = yaml.safe_dump(default_config, allow_unicode=True, sort_keys=False)
    (configs_dir / "default_config.yaml").write_text(config_yaml, encoding="utf-8")

    upload = tmp_path / "case.mp4"
    upload.write_bytes(b"same-video")
    recorded: list[list[str]] = []

    class _Proc:
        returncode = 0
        stdout = ""
        stderr = ""

    def fake_run(cmd: list[str], **_kwargs: Any) -> _Proc:
        # Stand-in for subprocess.run: record the command and write a complete
        # result TSV to the output path named in the job config.
        recorded.append(cmd)
        config_file = Path(cmd[cmd.index("--config") + 1])
        job = yaml.safe_load(config_file.read_text(encoding="utf-8"))
        tsv = Path(job["io"]["out"])
        tsv.parent.mkdir(parents=True, exist_ok=True)
        tsv.write_text(complete_result_tsv_body(), encoding="utf-8")
        return _Proc()

    monkeypatch.setattr("app.algo_host.subprocess_runner.subprocess.run", fake_run)
    # Visualization generation is replaced with a no-op for this test.
    monkeypatch.setattr(
        "app.algo_host.batch_service.BatchAlgorithmService._generate_visualization",
        lambda *_a, **_k: None,
    )

    service = BatchAlgorithmService(bundle_dir=bundle, root_dir=tmp_path / "batch")
    run_kwargs = {
        "surgery_id": "100001",
        "uploaded_video_path": upload,
        "original_filename": "case.mp4",
        "candidate_consumables": ["耗材1"],
    }
    first = service.run(**run_kwargs)
    # Fresh candidate list for the second call, as two separate literals would give.
    second = service.run(**{**run_kwargs, "candidate_consumables": ["耗材1"]})

    # The fake subprocess ran only once; the second call came from cache.
    assert len(recorded) == 1
    assert first.reused_cache is False
    assert second.reused_cache is True
    assert first.video_sha256 == second.video_sha256
    assert first.details[0].item_name == "耗材1"

    # Inspect the job config that was handed to the (single) subprocess call.
    only_command = recorded[0]
    job_config_path = Path(only_command[only_command.index("--config") + 1])
    job_config = yaml.safe_load(job_config_path.read_text(encoding="utf-8"))
    io_section = job_config["io"]
    assert Path(io_section["video"]).is_file()
    assert Path(io_section["excel"]).is_file()
    whitelist_text = Path(io_section["whitelist_json"]).read_text(encoding="utf-8")
    whitelist = json.loads(whitelist_text)
    assert whitelist == {"allowed_names": ["耗材1"]}
def test_batch_service_shares_cache_across_surgeries_for_same_video(
    tmp_path: Path,
    monkeypatch,
) -> None:
    """Two surgeries uploading the same video share one cached result.

    The fake subprocess must run exactly once, both runs must report the same
    output path, and that path must contain a ``cache`` directory component
    but neither surgery id.
    """
    bundle = tmp_path / "bundle"
    bundle.mkdir()
    (bundle / "main.py").write_text("# fake\n", encoding="utf-8")
    (bundle / "code").mkdir()
    (bundle / "code" / "repo_root.py").write_text("# fake\n", encoding="utf-8")
    (bundle / "configs").mkdir()
    (bundle / "configs" / "default_config.yaml").write_text(
        yaml.safe_dump(
            {
                "io": {"video": "", "excel": "", "out": "", "whitelist_json": None},
                "weights": {},
                "runtime": {"work_dir": None, "keep_work_dir": False, "python": None},
                "device": {},
                "phase1": {},
                "phase2": {},
                "classification": {},
                "tear_merge": {},
                "output": {},
            },
            allow_unicode=True,
            sort_keys=False,
        ),
        encoding="utf-8",
    )
    video = tmp_path / "case.mp4"
    video.write_bytes(b"same-video")
    calls: list[list[str]] = []

    class _Proc:
        returncode = 0
        stdout = ""
        stderr = ""

    def fake_run(cmd: list[str], **_kwargs: Any) -> _Proc:
        # Stand-in for subprocess.run: record the command and write a complete
        # result TSV to the output path named in the job config.
        calls.append(cmd)
        config = yaml.safe_load(Path(cmd[cmd.index("--config") + 1]).read_text(encoding="utf-8"))
        output = Path(config["io"]["out"])
        output.parent.mkdir(parents=True, exist_ok=True)
        output.write_text(complete_result_tsv_body(), encoding="utf-8")
        return _Proc()

    monkeypatch.setattr("app.algo_host.subprocess_runner.subprocess.run", fake_run)
    # Visualization generation is replaced with a no-op for this test.
    monkeypatch.setattr(
        "app.algo_host.batch_service.BatchAlgorithmService._generate_visualization",
        lambda *_a, **_k: None,
    )
    runner = BatchAlgorithmService(bundle_dir=bundle, root_dir=tmp_path / "batch")
    first = runner.run(
        surgery_id="100001",
        uploaded_video_path=video,
        original_filename="case.mp4",
        candidate_consumables=[],
    )
    second = runner.run(
        surgery_id="100002",
        uploaded_video_path=video,
        original_filename="case.mp4",
        candidate_consumables=[],
    )

    assert len(calls) == 1
    assert first.reused_cache is False
    assert second.reused_cache is True
    assert first.video_sha256 == second.video_sha256
    assert first.output_path == second.output_path
    # Check for a "cache" path component rather than the "/cache/" substring,
    # so the assertion does not depend on the platform's path separator.
    assert "cache" in Path(first.output_path).parts
    assert "100001" not in str(first.output_path)
    assert "100002" not in str(second.output_path)
def test_batch_failure_message_keeps_stdout_stderr_and_decodes_245(tmp_path: Path) -> None:
    """Exit code 245 is described as a possible SIGSEGV and the failure text keeps its context."""
    sigsegv_hint = "possibly propagated -11/SIGSEGV"
    assert describe_batch_returncode(245) == f"exit=245 ({sigsegv_hint})"

    failure_text = format_batch_failure(
        245,
        stdout="[run.py] 打包根: /bundle",
        stderr="Fatal Python error: Segmentation fault",
        work_dir=tmp_path / "work",
        output_path=tmp_path / "result.tsv",
    )

    # The message must carry the decoded code, the work dir and both streams.
    for fragment in (sigsegv_hint, "work_dir=", "stdout:", "stderr:"):
        assert fragment in failure_text
def test_demo_video_batch_endpoint_writes_queryable_result(
    tmp_path: Path,
    monkeypatch,
) -> None:
    """Without visualization, the endpoint stores a result the client API can read.

    The response must be ``accepted`` with no visualization URL,
    ``finalize_visualization`` must not be called, and both the cache
    directory and the per-surgery directory must be gone afterwards.
    """
    monkeypatch.setattr(recording_demo.settings, "demo_orchestrator_enabled", True)
    video_digest = "a" * 64
    detail = SurgeryConsumptionStored(
        item_id="P1",
        item_name="耗材1",
        qty=1,
        doctor_id="vision",
        timestamp=datetime(2026, 5, 8, tzinfo=timezone.utc),
        source="video_batch",
    )
    root_dir = tmp_path / "video_batch"
    vis_calls: list[str] = []

    class _FakeRunner:
        """Replacement for ``BatchAlgorithmService`` that fabricates cache files."""

        def __init__(self) -> None:
            self.root_dir = root_dir

        def run(self, **kwargs: Any) -> BatchRunResult:
            # Check what the endpoint forwarded to the runner.
            assert kwargs["surgery_id"] == "100001"
            assert kwargs["uploaded_video_path"].is_file()
            assert kwargs["candidate_consumables"] == ["耗材1"]
            assert kwargs.get("include_visualization") is False

            cache_dir = root_dir / "cache" / video_digest / "c1"
            staged_input = cache_dir / "input" / "input.mp4"
            staged_input.parent.mkdir(parents=True)
            staged_input.write_bytes(b"pipeline-input")
            result_tsv = cache_dir / "output" / "result.tsv"
            result_tsv.parent.mkdir(parents=True)
            result_tsv.write_text(complete_result_tsv_body(), encoding="utf-8")
            return BatchRunResult(
                video_sha256=video_digest,
                candidate_cache_key="c1",
                input_path=root_dir / "100001" / "input" / "saved.mp4",
                work_dir=cache_dir / "work",
                output_path=result_tsv,
                details=[detail],
                reused_cache=False,
            )

        def finalize_visualization(self, *, surgery_id: str) -> None:
            vis_calls.append(surgery_id)

    class _FakePipeline:
        """In-memory pipeline that stores saved rows per surgery id."""

        def __init__(self) -> None:
            self.rows: dict[str, list[SurgeryConsumptionStored]] = {}

        async def save_video_batch_result(
            self,
            surgery_id: str,
            details: list[SurgeryConsumptionStored],
        ) -> None:
            self.rows[surgery_id] = list(details)

        async def get_consumption_details_for_client(
            self,
            surgery_id: str,
        ) -> list[SurgeryConsumptionDetail] | None:
            stored = self.rows.get(surgery_id)
            if stored is None:
                return None
            client_rows = []
            for entry in stored:
                client_rows.append(
                    SurgeryConsumptionDetail(
                        item_id=entry.item_id,
                        item_name=entry.item_name,
                        qty=entry.qty,
                        doctor_id=entry.doctor_id,
                        timestamp=entry.timestamp,
                    )
                )
            return client_rows

    monkeypatch.setattr(recording_demo, "BatchAlgorithmService", _FakeRunner)
    pipeline = _FakePipeline()
    app = FastAPI()
    app.include_router(api_router)
    app.include_router(recording_demo.router)
    app.dependency_overrides[get_surgery_pipeline] = lambda: pipeline
    client = TestClient(app)

    form_fields = {"surgery_id": "100001", "candidate_consumables_json": '["耗材1"]'}
    upload = {"video1": ("case.mp4", b"video-bytes", "video/mp4")}
    res = client.post("/internal/demo/offline-batch", data=form_fields, files=upload)

    assert res.status_code == 200, res.text
    body = res.json()
    assert body["status"] == "accepted"
    assert body["visualization_url"] is None
    assert vis_calls == []
    # Neither the cache entry nor the per-surgery directory may remain.
    assert not (root_dir / "cache" / video_digest).exists()
    assert not (root_dir / "100001").exists()

    got = client.get("/client/surgeries/100001/result")
    assert got.status_code == 200, got.text
    result_body = got.json()
    assert result_body["details"][0]["item_id"] == "P1"
    assert result_body["summary"][0]["total_quantity"] == 1
def test_demo_video_batch_endpoint_stages_vis_and_purges_cache_when_requested(
    tmp_path: Path,
    monkeypatch,
) -> None:
    """With ``include_visualization`` the endpoint stages inputs under ``vis_pending``.

    The response must carry the visualization URL, ``finalize_visualization``
    must be called once for the surgery, the cache entry must be removed, and
    the pipeline input and result TSV must exist under ``vis_pending/<id>``.
    """
    monkeypatch.setattr(recording_demo.settings, "demo_orchestrator_enabled", True)
    video_digest = "b" * 64
    detail = SurgeryConsumptionStored(
        item_id="P1",
        item_name="耗材1",
        qty=1,
        doctor_id="vision",
        timestamp=datetime(2026, 5, 8, tzinfo=timezone.utc),
        source="video_batch",
    )
    root_dir = tmp_path / "video_batch"
    vis_calls: list[str] = []

    class _FakeRunner:
        """Replacement for ``BatchAlgorithmService`` that fabricates cache files."""

        def __init__(self) -> None:
            self.root_dir = root_dir

        def run(self, **kwargs: Any) -> BatchRunResult:
            cache_dir = root_dir / "cache" / video_digest / "c1"
            staged_input = cache_dir / "input" / "input.mp4"
            staged_input.parent.mkdir(parents=True)
            staged_input.write_bytes(b"pipeline-input")
            result_tsv = cache_dir / "output" / "result.tsv"
            result_tsv.parent.mkdir(parents=True)
            result_tsv.write_text(complete_result_tsv_body(), encoding="utf-8")
            return BatchRunResult(
                video_sha256=video_digest,
                candidate_cache_key="c1",
                input_path=root_dir / "100001" / "input" / "saved.mp4",
                work_dir=cache_dir / "work",
                output_path=result_tsv,
                details=[detail],
                reused_cache=False,
            )

        def finalize_visualization(self, *, surgery_id: str) -> None:
            vis_calls.append(surgery_id)

    class _FakePipeline:
        """Pipeline stub whose save operation does nothing."""

        async def save_video_batch_result(
            self,
            surgery_id: str,
            details: list[SurgeryConsumptionStored],
        ) -> None:
            return None

    monkeypatch.setattr(recording_demo, "BatchAlgorithmService", _FakeRunner)
    app = FastAPI()
    app.include_router(recording_demo.router)
    app.dependency_overrides[get_surgery_pipeline] = lambda: _FakePipeline()
    client = TestClient(app)

    form_fields = {
        "surgery_id": "100001",
        "candidate_consumables_json": '["耗材1"]',
        "include_visualization": "true",
    }
    upload = {"video1": ("case.mp4", b"video-bytes", "video/mp4")}
    res = client.post("/internal/demo/offline-batch", data=form_fields, files=upload)

    assert res.status_code == 200, res.text
    body = res.json()
    assert body["visualization_url"] == "/internal/demo/offline-batch/100001/visualization"
    assert vis_calls == ["100001"]
    assert not (root_dir / "cache" / video_digest).exists()

    # The pipeline input and result TSV must have been copied for later rendering.
    pending_dir = root_dir / "vis_pending" / "100001"
    staged_video_bytes = (pending_dir / "input.mp4").read_bytes()
    assert staged_video_bytes == b"pipeline-input"
    staged_tsv_text = (pending_dir / "result.tsv").read_text(encoding="utf-8")
    assert "医生信息" in staged_tsv_text