691 lines
24 KiB
Python
691 lines
24 KiB
Python
"""Tests for offline batch orchestration (app.algo_host)."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import shutil
|
||
import subprocess
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
import pytest
|
||
import yaml
|
||
from fastapi import FastAPI
|
||
from fastapi.testclient import TestClient
|
||
|
||
from app.algo_host import bundle as bundle_runtime
|
||
from app.algo_host.batch_service import BatchAlgorithmService, BatchRunResult
|
||
from app.algo_host.job_workspace import build_job_config
|
||
from app.algo_host.result_adapter import (
|
||
doctor_id_for_consumption_rows,
|
||
is_reference_result_complete,
|
||
parse_reference_doctor_info,
|
||
parse_reference_tsv,
|
||
)
|
||
from app.algo_host.subprocess_runner import (
|
||
build_batch_main_command,
|
||
build_visualization_command,
|
||
describe_batch_returncode,
|
||
format_batch_failure,
|
||
)
|
||
from app.algo_host.transcode import (
|
||
VISUALIZATION_MAX_WIDTH,
|
||
batch_input_needs_normalize,
|
||
browser_transcode_tmp_path,
|
||
ensure_batch_pipeline_input_video,
|
||
is_browser_compatible_mp4,
|
||
transcode_visualization_for_browser,
|
||
)
|
||
from app.api import router as api_router
|
||
from app.dependencies import get_surgery_pipeline
|
||
from app.domain.consumption import SurgeryConsumptionStored
|
||
from app.routers import recording_demo
|
||
from app.schemas import SurgeryConsumptionDetail
|
||
from app.services.video_batch_cleanup import VISUALIZATION_FILENAME, visualization_output_path
|
||
from tests.reference_bundle_fixtures import complete_result_tsv_body, write_minimal_reference_bundle
|
||
|
||
|
||
def test_build_job_config_does_not_keep_work_dir(tmp_path: Path) -> None:
    """``build_job_config`` must emit ``runtime.keep_work_dir`` as ``False``."""
    bundle_dir = tmp_path / "bundle"
    write_minimal_reference_bundle(bundle_dir)

    job_cfg = build_job_config(
        bundle_dir=bundle_dir,
        video_path=tmp_path / "input.mp4",
        output_path=tmp_path / "out.tsv",
        work_dir=tmp_path / "work",
        excel_path=tmp_path / "catalog.xlsx",
        whitelist_path=tmp_path / "whitelist.json",
    )

    runtime_section = job_cfg["runtime"]
    assert runtime_section["keep_work_dir"] is False
|
||
|
||
|
||
def test_latest_visualization_path_uses_vis_directory(tmp_path: Path) -> None:
    """``latest_visualization_path`` yields ``None`` for a missing or bogus visualization.

    A file written at ``visualization_output_path`` whose bytes are not a real
    MP4 must still not be reported (presumably rejected by a format check
    inside the service).
    """
    batch_root = tmp_path / "batch"
    service = BatchAlgorithmService(root_dir=batch_root)
    surgery = "100001"

    # Nothing has been written for this surgery yet.
    assert service.latest_visualization_path(surgery) is None

    placeholder = visualization_output_path(batch_root, surgery)
    placeholder.parent.mkdir(parents=True)
    placeholder.write_bytes(b"not-really-mp4")

    # A file exists at the expected location, but its content is not an MP4.
    assert service.latest_visualization_path(surgery) is None
|
||
|
||
|
||
def test_is_reference_result_complete_requires_footer_and_rows(tmp_path: Path) -> None:
    """Only a result TSV with data rows *and* the trailing footer counts as complete."""
    header = "rank\tstart_sec\tend_sec\tproduct_id_top1\ttop1_name\ttop1_conf\n"
    single_row = "1\t0\t1\tP1\t耗材1\t1.0\n"

    complete_tsv = tmp_path / "complete.tsv"
    complete_tsv.write_text(complete_result_tsv_body(), encoding="utf-8")

    # Header plus one row, but no footer line -> must be treated as incomplete.
    partial_tsv = tmp_path / "partial.tsv"
    partial_tsv.write_text(header + single_row, encoding="utf-8")

    assert is_reference_result_complete(complete_tsv) is True
    assert is_reference_result_complete(partial_tsv) is False
|
||
|
||
|
||
def test_parse_reference_doctor_info_name_and_id(tmp_path: Path) -> None:
    """A doctor-info footer with a name and ``id=`` yields name, id and display text.

    The parsed doctor is also propagated onto the consumption rows.
    """
    tsv_lines = [
        "rank\tstart_sec\tend_sec\tproduct_id_top1\ttop1_name\ttop1_conf\n",
        "1\t0\t1\tP1\t耗材1\t1.0\n",
        "医生信息:付玉峰 (id=24503, conf=0.9969)\n",
    ]
    result_tsv = tmp_path / "result.tsv"
    result_tsv.write_text("".join(tsv_lines), encoding="utf-8")

    info = parse_reference_doctor_info(result_tsv)
    assert info is not None
    assert info.doctor_name == "付玉峰"
    assert info.doctor_id == "24503"

    expected_display = "付玉峰 (24503)"
    assert info.display == expected_display
    assert doctor_id_for_consumption_rows(info) == expected_display

    parsed_rows = parse_reference_tsv(result_tsv, doctor=info)
    assert len(parsed_rows) == 1
    assert parsed_rows[0].doctor_id == expected_display
|
||
|
||
|
||
def test_parse_reference_doctor_info_failure_falls_back_to_vision(tmp_path: Path) -> None:
    """A doctor-info footer reporting a recognition failure maps to the ``vision`` id."""
    body = "".join(
        (
            "rank\tstart_sec\tend_sec\tproduct_id_top1\ttop1_name\ttop1_conf\n",
            "1\t0\t1\tP1\t耗材1\t1.0\n",
            "医生信息:识别失败(No module named 'mediapipe')\n",
        )
    )
    result_tsv = tmp_path / "result.tsv"
    result_tsv.write_text(body, encoding="utf-8")

    info = parse_reference_doctor_info(result_tsv)

    assert info is not None
    assert info.doctor_name is None
    assert info.doctor_id == "vision"
    # The failure text is retained in the human-readable display string.
    assert "识别失败" in info.display
|
||
|
||
|
||
def test_parse_reference_tsv_top1_and_empty_rows(tmp_path: Path) -> None:
    """Only the top-1 candidate becomes a consumption row.

    The malformed second row (non-numeric ``start_sec`` and empty top-1 columns)
    is dropped, and ``start_sec`` is applied as an offset to ``base_timestamp``.
    """
    # rank/start/end followed by (product_id, name, conf) for top1..top3.
    columns = ["rank", "start_sec", "end_sec"]
    for k in (1, 2, 3):
        columns += [f"product_id_top{k}", f"top{k}_name", f"top{k}_conf"]

    data_rows = (
        "1\t1.5\t3.0\tA/B\t一次性耗材\t1.0000\tC\t备选\t0.2\t\t\t\n",
        "2\tbad\t4.0\t\t\t\t\t\t\t\t\t\n",
    )
    result_tsv = tmp_path / "result.tsv"
    result_tsv.write_text("\t".join(columns) + "\n" + "".join(data_rows), encoding="utf-8")

    base = datetime(2026, 5, 8, tzinfo=timezone.utc)
    parsed = parse_reference_tsv(result_tsv, base_timestamp=base)

    assert len(parsed) == 1
    only = parsed[0]
    assert only.item_id == "A/B"
    assert only.item_name == "一次性耗材"
    assert only.qty == 1
    assert only.doctor_id == "vision"
    assert only.source == "video_batch"
    # 1.5 s start offset on top of midnight UTC.
    assert only.timestamp.isoformat() == "2026-05-08T00:00:01.500000+00:00"
|
||
|
||
|
||
def test_browser_transcode_tmp_path_keeps_mp4_extension() -> None:
    """The staging path for a browser transcode still ends in ``.mp4``.

    Presumably so that ffmpeg can infer the output container from the file
    extension — NOTE(review): confirm against ``transcode`` implementation.
    """
    final_path = Path("/cache/196707/output/result_vis.mp4")
    staging = browser_transcode_tmp_path(final_path)

    assert staging.name == "result_vis.part.mp4"
    assert staging.suffix == ".mp4"
    assert str(staging).endswith(".mp4")
|
||
|
||
|
||
@pytest.mark.skipif(shutil.which("ffmpeg") is None, reason="ffmpeg not installed")
def test_ensure_batch_pipeline_input_video_normalizes_non_h264(tmp_path: Path) -> None:
    """An ``mpeg4``-encoded (non-H.264) upload is normalized into a browser-compatible MP4."""
    ffmpeg_bin = shutil.which("ffmpeg")
    assert ffmpeg_bin is not None
    upload = tmp_path / "upload.mp4"
    normalized = tmp_path / "input.mp4"

    # Synthesize a 0.5 s clip from ffmpeg's lavfi test source, encoded with mpeg4.
    make_clip = [
        ffmpeg_bin,
        "-y",
        *("-f", "lavfi", "-i", "testsrc=size=640x360:rate=10"),
        *("-t", "0.5"),
        *("-c:v", "mpeg4", "-pix_fmt", "yuv420p"),
        str(upload),
    ]
    completed = subprocess.run(make_clip, check=False, capture_output=True, text=True)
    assert completed.returncode == 0, completed.stderr

    assert batch_input_needs_normalize(upload)
    ensure_batch_pipeline_input_video(source_path=upload, dest_path=normalized)
    assert normalized.is_file()
    assert is_browser_compatible_mp4(normalized)
    # Re-checking the normalized output must not request another pass.
    assert not batch_input_needs_normalize(normalized)
|
||
|
||
|
||
@pytest.mark.skipif(shutil.which("ffmpeg") is None, reason="ffmpeg not installed")
def test_transcode_visualization_for_browser_writes_h264_mp4(tmp_path: Path) -> None:
    """Transcoding a visualization yields a non-empty, browser-compatible MP4.

    The ``.part.mp4`` staging file must not be left behind afterwards.
    """
    ffmpeg_bin = shutil.which("ffmpeg")
    assert ffmpeg_bin is not None
    raw_vis = tmp_path / "result_vis_source.mp4"
    final_vis = tmp_path / "result_vis.mp4"

    # Produce a short mpeg4 source clip to feed into the browser transcode.
    lavfi_args = ["-f", "lavfi", "-i", "testsrc=size=640x360:rate=10", "-t", "0.5"]
    codec_args = ["-c:v", "mpeg4", "-pix_fmt", "yuv420p"]
    completed = subprocess.run(
        [ffmpeg_bin, "-y", *lavfi_args, *codec_args, str(raw_vis)],
        check=False,
        capture_output=True,
        text=True,
    )
    assert completed.returncode == 0, completed.stderr

    assert transcode_visualization_for_browser(raw_vis, final_vis)
    assert final_vis.is_file()
    assert final_vis.stat().st_size > 0
    assert not browser_transcode_tmp_path(final_vis).exists()
    assert is_browser_compatible_mp4(final_vis)
|
||
|
||
|
||
def test_build_visualization_command_uses_hand_model_and_result_tsv(
    tmp_path: Path,
) -> None:
    """The visualization command passes the result TSV, hand model and max width.

    ``--det-conf`` must not be passed.
    """
    bundle = tmp_path / "bundle"
    write_minimal_reference_bundle(bundle)
    weights_dir = bundle / "weights"
    weights_dir.mkdir()
    hand_weights = weights_dir / "hand_detect.pt"
    hand_weights.write_bytes(b"fake")
    (bundle / "visualize_result_video.py").write_text("# fake\n", encoding="utf-8")

    # Point the bundle's default config at the fake hand-detection weights.
    default_cfg_path = bundle / "configs" / "default_config.yaml"
    bundle_cfg = yaml.safe_load(default_cfg_path.read_text(encoding="utf-8"))
    bundle_cfg["weights"]["hand"] = "weights/hand_detect.pt"
    default_cfg_path.write_text(
        yaml.safe_dump(bundle_cfg, allow_unicode=True, sort_keys=False),
        encoding="utf-8",
    )

    result_tsv = tmp_path / "result.tsv"
    cmd = build_visualization_command(
        bundle_dir=bundle,
        video_path=tmp_path / "input.mp4",
        result_path=result_tsv,
        output_video_path=tmp_path / "result_vis.mp4",
    )

    assert "--result-txt" in cmd
    assert str(result_tsv) in cmd
    assert "--hand-model" in cmd
    assert str(hand_weights) in cmd
    assert "--det-conf" not in cmd
    assert "--max-width" in cmd
    assert cmd[cmd.index("--max-width") + 1] == str(VISUALIZATION_MAX_WIDTH)
|
||
|
||
|
||
def test_build_batch_main_command_uses_5_15_main_py(tmp_path: Path) -> None:
    """The batch command runs the bundle's ``main.py`` under ``uv run python -X faulthandler``.

    Expected layout::

        ["uv", "run", "python", "-X", "faulthandler", <bundle>/main.py,
         "--config", <config path>]
    """
    config_path = tmp_path / "config.yaml"
    cmd = build_batch_main_command(
        bundle_dir=tmp_path / "algorithm_subprocesses" / "5.15",
        config_path=config_path,
    )

    assert cmd[:3] == ["uv", "run", "python"]
    # ``-X faulthandler`` makes CPython dump a traceback on fatal signals (e.g. SIGSEGV).
    assert cmd[3:5] == ["-X", "faulthandler"]
    # Compare path components instead of a "/"-joined suffix: the entry script is
    # rendered via ``str(Path)``, which uses "\\" on Windows.
    assert Path(cmd[5]).parts[-3:] == ("algorithm_subprocesses", "5.15", "main.py")
    assert cmd[6:] == ["--config", str(config_path)]
|
||
|
||
|
||
def test_batch_service_respects_reference_bundle_relative_env(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``REFERENCE_BUNDLE_RELATIVE`` (relative to ``REPO_ROOT``) selects the bundle to run."""
    bundle = tmp_path / "algorithm_subprocesses" / "custom"
    write_minimal_reference_bundle(bundle)
    video = tmp_path / "case.mp4"
    video.write_bytes(b"same-video")
    invoked: list[list[str]] = []

    class _Proc:
        # Minimal stand-in for a successful ``subprocess.run`` result.
        returncode = 0
        stdout = ""
        stderr = ""

    def fake_run(cmd: list[str], **_kwargs: Any) -> _Proc:
        # Record the command, then emulate the algorithm by writing a complete
        # result TSV to the ``io.out`` path of the job config it was given.
        invoked.append(cmd)
        config_file = Path(cmd[cmd.index("--config") + 1])
        job_cfg = yaml.safe_load(config_file.read_text(encoding="utf-8"))
        out_path = Path(job_cfg["io"]["out"])
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(complete_result_tsv_body(), encoding="utf-8")
        return _Proc()

    monkeypatch.setenv("REFERENCE_BUNDLE_RELATIVE", "algorithm_subprocesses/custom")
    monkeypatch.setattr(bundle_runtime, "REPO_ROOT", tmp_path)
    monkeypatch.setattr("app.algo_host.subprocess_runner.subprocess.run", fake_run)

    service = BatchAlgorithmService(root_dir=tmp_path / "batch")
    outcome = service.run(
        surgery_id="100001",
        uploaded_video_path=video,
        original_filename="case.mp4",
        candidate_consumables=["耗材1"],
    )

    resolved_bundle = bundle.resolve()
    assert service.bundle_dir == resolved_bundle
    # Index 5 is the entry script, right after "uv run python -X faulthandler".
    assert invoked[0][5] == str(resolved_bundle / "main.py")
    assert outcome.details[0].item_name == "耗材1"
|
||
|
||
|
||
def test_batch_service_reuses_cache_on_repeat_run(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A second identical run is served from cache; the subprocess runs only once.

    Also checks that the single real run's job config points at an existing video
    and catalog, and at a whitelist JSON containing the candidate consumables.
    """
    bundle = tmp_path / "bundle"
    bundle.mkdir()
    (bundle / "main.py").write_text("# fake\n", encoding="utf-8")
    (bundle / "code").mkdir()
    (bundle / "code" / "repo_root.py").write_text("# fake\n", encoding="utf-8")
    configs_dir = bundle / "configs"
    configs_dir.mkdir()

    # Default bundle config with every section populated.
    default_config: dict[str, Any] = {
        "io": {"video": "", "excel": "", "out": "", "whitelist_json": None},
        "weights": {
            "actionformer": "weights/actionformer_epoch_045.pth.tar",
            "hand": "weights/hand_detect.pt",
            "goodbad": "weights/goodbad_frame.pt",
            "haocai": "weights/haocai_classify.pt",
            "tear": "weights/tear_classify.pt",
        },
        "runtime": {"work_dir": None, "keep_work_dir": False, "python": None},
        "device": {"type": "cuda", "half": False},
        "phase1": {"af_min_score": 0.1, "af_min_seg_seconds": 5, "feat_batch_size": 1},
        "phase2": {
            "seek_margin_sec": 3.0,
            "frame_stride": 1,
            "det_conf": 0.5,
            "pad_ratio": 0.3,
            "imgsz_det": 640,
            "merge_iou_gt": 0.0,
            "merge_center_dist_max_px": None,
            "merge_center_dist_max_frac_diag": None,
        },
        "classification": {
            "imgsz_cls": 224,
            "good_top1_conf_threshold": 0.9,
            "good_top1_retry_threshold": 0.8,
            "haocai_min_conf": 0.8,
            "haocai_min_conf_retry": 0.7,
            "empty_cache_every": 0,
        },
        "tear_merge": {
            "merge_adjacent_tear": True,
            "tear_merge_weights": None,
            "tear_merge_class": "tearing",
            "tear_merge_head_sec": 3.0,
            "tear_merge_prob": 0.9,
            "tear_merge_min_frames": 6,
            "tear_merge_verbose": False,
            "tear_merge_full_frame": False,
        },
        "output": {"legacy_12_col_only": True},
    }
    (configs_dir / "default_config.yaml").write_text(
        yaml.safe_dump(default_config, allow_unicode=True, sort_keys=False),
        encoding="utf-8",
    )

    video = tmp_path / "case.mp4"
    video.write_bytes(b"same-video")
    invoked: list[list[str]] = []

    class _Proc:
        # Minimal stand-in for a successful ``subprocess.run`` result.
        returncode = 0
        stdout = ""
        stderr = ""

    def fake_run(cmd: list[str], **_kwargs: Any) -> _Proc:
        # Record the call and write a complete result TSV where the job config asks.
        invoked.append(cmd)
        config_file = Path(cmd[cmd.index("--config") + 1])
        job_cfg = yaml.safe_load(config_file.read_text(encoding="utf-8"))
        out_path = Path(job_cfg["io"]["out"])
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(complete_result_tsv_body(), encoding="utf-8")
        return _Proc()

    monkeypatch.setattr("app.algo_host.subprocess_runner.subprocess.run", fake_run)
    # Skip visualization so only the batch subprocess is exercised.
    monkeypatch.setattr(
        "app.algo_host.batch_service.BatchAlgorithmService._generate_visualization",
        lambda *_a, **_k: None,
    )

    service = BatchAlgorithmService(bundle_dir=bundle, root_dir=tmp_path / "batch")
    first = service.run(
        surgery_id="100001",
        uploaded_video_path=video,
        original_filename="case.mp4",
        candidate_consumables=["耗材1"],
    )
    second = service.run(
        surgery_id="100001",
        uploaded_video_path=video,
        original_filename="case.mp4",
        candidate_consumables=["耗材1"],
    )

    assert len(invoked) == 1
    assert first.reused_cache is False
    assert second.reused_cache is True
    assert first.video_sha256 == second.video_sha256
    assert first.details[0].item_name == "耗材1"

    # Inspect the job config handed to the single real subprocess invocation.
    sole_cmd = invoked[0]
    job_config_path = Path(sole_cmd[sole_cmd.index("--config") + 1])
    job_cfg = yaml.safe_load(job_config_path.read_text(encoding="utf-8"))
    io_section = job_cfg["io"]
    assert Path(io_section["video"]).is_file()
    assert Path(io_section["excel"]).is_file()
    whitelist = json.loads(Path(io_section["whitelist_json"]).read_text(encoding="utf-8"))
    assert whitelist == {"allowed_names": ["耗材1"]}
|
||
|
||
|
||
def test_batch_service_shares_cache_across_surgeries_for_same_video(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Two surgeries uploading the same video share one cached result.

    The shared output lives under ``/cache/`` and is not keyed by surgery id.
    """
    bundle = tmp_path / "bundle"
    bundle.mkdir()
    (bundle / "main.py").write_text("# fake\n", encoding="utf-8")
    (bundle / "code").mkdir()
    (bundle / "code" / "repo_root.py").write_text("# fake\n", encoding="utf-8")
    (bundle / "configs").mkdir()

    # Skeleton config: populated io/runtime, every other section left empty.
    skeleton: dict[str, Any] = {
        "io": {"video": "", "excel": "", "out": "", "whitelist_json": None},
        "weights": {},
        "runtime": {"work_dir": None, "keep_work_dir": False, "python": None},
    }
    for section in ("device", "phase1", "phase2", "classification", "tear_merge", "output"):
        skeleton[section] = {}
    (bundle / "configs" / "default_config.yaml").write_text(
        yaml.safe_dump(skeleton, allow_unicode=True, sort_keys=False),
        encoding="utf-8",
    )

    video = tmp_path / "case.mp4"
    video.write_bytes(b"same-video")
    invoked: list[list[str]] = []

    class _Proc:
        # Minimal stand-in for a successful ``subprocess.run`` result.
        returncode = 0
        stdout = ""
        stderr = ""

    def fake_run(cmd: list[str], **_kwargs: Any) -> _Proc:
        # Record the call and write a complete result TSV where the job config asks.
        invoked.append(cmd)
        config_file = Path(cmd[cmd.index("--config") + 1])
        job_cfg = yaml.safe_load(config_file.read_text(encoding="utf-8"))
        out_path = Path(job_cfg["io"]["out"])
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(complete_result_tsv_body(), encoding="utf-8")
        return _Proc()

    monkeypatch.setattr("app.algo_host.subprocess_runner.subprocess.run", fake_run)
    monkeypatch.setattr(
        "app.algo_host.batch_service.BatchAlgorithmService._generate_visualization",
        lambda *_a, **_k: None,
    )

    service = BatchAlgorithmService(bundle_dir=bundle, root_dir=tmp_path / "batch")

    def run_for(surgery_id: str) -> BatchRunResult:
        # A fresh empty candidate list per call, as two separate requests would send.
        return service.run(
            surgery_id=surgery_id,
            uploaded_video_path=video,
            original_filename="case.mp4",
            candidate_consumables=[],
        )

    first = run_for("100001")
    second = run_for("100002")

    assert len(invoked) == 1
    assert first.reused_cache is False
    assert second.reused_cache is True
    assert first.video_sha256 == second.video_sha256
    assert first.output_path == second.output_path

    first_out = str(first.output_path)
    assert "/cache/" in first_out
    assert "100001" not in first_out
    assert "100002" not in str(second.output_path)
|
||
|
||
|
||
def test_batch_failure_message_keeps_stdout_stderr_and_decodes_245(tmp_path: Path) -> None:
    """Exit code 245 is decoded as a likely SIGSEGV and the failure text keeps context."""
    # 245 == 256 - 11: signal -11 (SIGSEGV) seen through an unsigned 8-bit exit status.
    decoded = "possibly propagated -11/SIGSEGV"
    assert describe_batch_returncode(245) == f"exit=245 ({decoded})"

    message = format_batch_failure(
        245,
        stdout="[run.py] 打包根: /bundle",
        stderr="Fatal Python error: Segmentation fault",
        work_dir=tmp_path / "work",
        output_path=tmp_path / "result.tsv",
    )

    for fragment in (decoded, "work_dir=", "stdout:", "stderr:"):
        assert fragment in message
|
||
|
||
|
||
def test_demo_video_batch_endpoint_writes_queryable_result(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Posting to the offline-batch demo endpoint stores rows readable via the client API.

    Without ``include_visualization`` no visualization is finalized, and both the
    cache directory and the per-surgery directory are gone after the request.
    """
    monkeypatch.setattr(recording_demo.settings, "demo_orchestrator_enabled", True)

    video_sha = "a" * 64
    stored_row = SurgeryConsumptionStored(
        item_id="P1",
        item_name="耗材1",
        qty=1,
        doctor_id="vision",
        timestamp=datetime(2026, 5, 8, tzinfo=timezone.utc),
        source="video_batch",
    )
    root_dir = tmp_path / "video_batch"
    finalized: list[str] = []

    class _FakeRunner:
        # Stand-in for ``BatchAlgorithmService`` that fabricates a cached result.
        def __init__(self) -> None:
            self.root_dir = root_dir

        def run(self, **kwargs: Any) -> BatchRunResult:
            # Verify what the endpoint forwards to the runner.
            assert kwargs["surgery_id"] == "100001"
            assert kwargs["uploaded_video_path"].is_file()
            assert kwargs["candidate_consumables"] == ["耗材1"]
            assert kwargs.get("include_visualization") is False

            cache_dir = root_dir / "cache" / video_sha / "c1"
            pipeline_input = cache_dir / "input" / "input.mp4"
            pipeline_input.parent.mkdir(parents=True)
            pipeline_input.write_bytes(b"pipeline-input")
            result_tsv = cache_dir / "output" / "result.tsv"
            result_tsv.parent.mkdir(parents=True)
            result_tsv.write_text(complete_result_tsv_body(), encoding="utf-8")
            return BatchRunResult(
                video_sha256=video_sha,
                candidate_cache_key="c1",
                input_path=root_dir / "100001" / "input" / "saved.mp4",
                work_dir=cache_dir / "work",
                output_path=result_tsv,
                details=[stored_row],
                reused_cache=False,
            )

        def finalize_visualization(self, *, surgery_id: str) -> None:
            finalized.append(surgery_id)

    class _FakePipeline:
        # In-memory pipeline: keeps saved rows and serves them back as client details.
        def __init__(self) -> None:
            self.rows: dict[str, list[SurgeryConsumptionStored]] = {}

        async def save_video_batch_result(
            self,
            surgery_id: str,
            details: list[SurgeryConsumptionStored],
        ) -> None:
            self.rows[surgery_id] = list(details)

        async def get_consumption_details_for_client(
            self,
            surgery_id: str,
        ) -> list[SurgeryConsumptionDetail] | None:
            saved = self.rows.get(surgery_id)
            if saved is None:
                return None
            return [
                SurgeryConsumptionDetail(
                    item_id=row.item_id,
                    item_name=row.item_name,
                    qty=row.qty,
                    doctor_id=row.doctor_id,
                    timestamp=row.timestamp,
                )
                for row in saved
            ]

    monkeypatch.setattr(recording_demo, "BatchAlgorithmService", _FakeRunner)
    pipeline = _FakePipeline()

    app = FastAPI()
    app.include_router(api_router)
    app.include_router(recording_demo.router)
    app.dependency_overrides[get_surgery_pipeline] = lambda: pipeline

    client = TestClient(app)
    posted = client.post(
        "/internal/demo/offline-batch",
        data={"surgery_id": "100001", "candidate_consumables_json": '["耗材1"]'},
        files={"video1": ("case.mp4", b"video-bytes", "video/mp4")},
    )
    assert posted.status_code == 200, posted.text
    payload = posted.json()
    assert payload["status"] == "accepted"
    assert payload["visualization_url"] is None
    assert finalized == []
    assert not (root_dir / "cache" / video_sha).exists()
    assert not (root_dir / "100001").exists()

    fetched = client.get("/client/surgeries/100001/result")
    assert fetched.status_code == 200, fetched.text
    result_body = fetched.json()
    assert result_body["details"][0]["item_id"] == "P1"
    assert result_body["summary"][0]["total_quantity"] == 1
|
||
|
||
|
||
def test_demo_video_batch_endpoint_stages_vis_and_purges_cache_when_requested(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With ``include_visualization`` the endpoint stages visualization inputs.

    The cache directory is removed, the pipeline input and result TSV are copied
    to ``vis_pending/<surgery_id>/``, and a visualization URL is returned.
    """
    monkeypatch.setattr(recording_demo.settings, "demo_orchestrator_enabled", True)

    video_sha = "b" * 64
    stored_row = SurgeryConsumptionStored(
        item_id="P1",
        item_name="耗材1",
        qty=1,
        doctor_id="vision",
        timestamp=datetime(2026, 5, 8, tzinfo=timezone.utc),
        source="video_batch",
    )
    root_dir = tmp_path / "video_batch"
    finalized: list[str] = []

    class _FakeRunner:
        # Stand-in for ``BatchAlgorithmService`` that fabricates a cached result.
        def __init__(self) -> None:
            self.root_dir = root_dir

        def run(self, **kwargs: Any) -> BatchRunResult:
            cache_dir = root_dir / "cache" / video_sha / "c1"
            pipeline_input = cache_dir / "input" / "input.mp4"
            pipeline_input.parent.mkdir(parents=True)
            pipeline_input.write_bytes(b"pipeline-input")
            result_tsv = cache_dir / "output" / "result.tsv"
            result_tsv.parent.mkdir(parents=True)
            result_tsv.write_text(complete_result_tsv_body(), encoding="utf-8")
            return BatchRunResult(
                video_sha256=video_sha,
                candidate_cache_key="c1",
                input_path=root_dir / "100001" / "input" / "saved.mp4",
                work_dir=cache_dir / "work",
                output_path=result_tsv,
                details=[stored_row],
                reused_cache=False,
            )

        def finalize_visualization(self, *, surgery_id: str) -> None:
            finalized.append(surgery_id)

    class _FakePipeline:
        # Accepts saved results and discards them; only the save path is exercised.
        async def save_video_batch_result(
            self,
            surgery_id: str,
            details: list[SurgeryConsumptionStored],
        ) -> None:
            return None

    monkeypatch.setattr(recording_demo, "BatchAlgorithmService", _FakeRunner)
    app = FastAPI()
    app.include_router(recording_demo.router)
    app.dependency_overrides[get_surgery_pipeline] = lambda: _FakePipeline()

    client = TestClient(app)
    form = {
        "surgery_id": "100001",
        "candidate_consumables_json": '["耗材1"]',
        "include_visualization": "true",
    }
    posted = client.post(
        "/internal/demo/offline-batch",
        data=form,
        files={"video1": ("case.mp4", b"video-bytes", "video/mp4")},
    )
    assert posted.status_code == 200, posted.text
    payload = posted.json()
    assert payload["visualization_url"] == "/internal/demo/offline-batch/100001/visualization"
    assert finalized == ["100001"]
    assert not (root_dir / "cache" / video_sha).exists()

    staging_dir = root_dir / "vis_pending" / "100001"
    assert (staging_dir / "input.mp4").read_bytes() == b"pipeline-input"
    assert "医生信息" in (staging_dir / "result.tsv").read_text(encoding="utf-8")
|