219 lines
7.5 KiB
Python
219 lines
7.5 KiB
Python
"""FastAPI → 算法子进程调用链单元测试。
|
||
|
||
覆盖两条生产路径:
|
||
1. ``POST /internal/demo/offline-batch`` → ``BatchAlgorithmService`` → ``subprocess.run``(reference bundle ``main.py``)
|
||
2. ``POST /client/surgeries/start`` → ``CameraSessionManager`` → ``RtspSegmentRecorder`` + ``SliceBatchProcessor``
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from pathlib import Path
|
||
from typing import Any
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import pytest
|
||
import yaml
|
||
from fastapi import FastAPI
|
||
from fastapi.testclient import TestClient
|
||
from httpx import ASGITransport, AsyncClient
|
||
|
||
from app.api import router as api_router
|
||
from app.config import Settings
|
||
from app.dependencies import build_container, get_surgery_pipeline, get_voice_terminal_hub
|
||
from app.routers import recording_demo
|
||
from app.services.video.rtsp_segment_recorder import RtspSegmentRecorder
|
||
from app.algo_host.batch_service import BatchAlgorithmService
|
||
from app.algo_host.subprocess_runner import build_batch_main_command
|
||
from tests.reference_bundle_fixtures import complete_result_tsv_body, write_minimal_reference_bundle
|
||
|
||
|
||
def _fake_reference_subprocess_run(captured: list[dict[str, Any]]):
|
||
class _Proc:
|
||
returncode = 0
|
||
stdout = ""
|
||
stderr = ""
|
||
|
||
def _run(cmd: list[str], **kwargs: Any) -> _Proc:
|
||
if len(cmd) >= 3 and cmd[1] == "-c" and "import numpy" in cmd[2]:
|
||
return _Proc()
|
||
config_path = Path(cmd[cmd.index("--config") + 1])
|
||
config = yaml.safe_load(config_path.read_text(encoding="utf-8"))
|
||
captured.append(
|
||
{
|
||
"cmd": list(cmd),
|
||
"kwargs": dict(kwargs),
|
||
"config": config,
|
||
}
|
||
)
|
||
output = Path(config["io"]["out"])
|
||
output.parent.mkdir(parents=True, exist_ok=True)
|
||
output.write_text(complete_result_tsv_body(), encoding="utf-8")
|
||
return _Proc()
|
||
|
||
return _run
|
||
|
||
|
||
@pytest.fixture
def reference_bundle(tmp_path: Path) -> Path:
    """Write a minimal reference bundle under ``tmp_path`` and return its directory.

    The bundle is placed at ``<tmp_path>/algorithm_subprocesses/5.15`` and is
    populated by ``write_minimal_reference_bundle``.
    """
    bundle_root = tmp_path.joinpath("algorithm_subprocesses", "5.15")
    write_minimal_reference_bundle(bundle_root)
    return bundle_root
|
||
|
||
|
||
def test_video_batch_endpoint_invokes_reference_bundle_subprocess(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
    reference_bundle: Path,
    sqlite_session_factory,
) -> None:
    """POST to the offline-batch endpoint and inspect the single subprocess launch.

    ``subprocess.run`` is replaced by a recording fake, so the test checks the
    command line, the ``cwd`` / ``env`` keyword arguments and the YAML config
    handed to the reference bundle rather than any real algorithm output.
    """
    monkeypatch.setattr(recording_demo.settings, "demo_orchestrator_enabled", True)
    monkeypatch.setattr(
        "app.algo_host.bundle.resolve_reference_bundle_dir",
        lambda _override=None: reference_bundle.resolve(),
    )

    def _make_batch_service() -> BatchAlgorithmService:
        # Point the service at the temporary bundle and a scratch root directory.
        return BatchAlgorithmService(
            bundle_dir=reference_bundle,
            root_dir=tmp_path / "video_batch",
        )

    monkeypatch.setattr(recording_demo, "BatchAlgorithmService", _make_batch_service)

    invocations: list[dict[str, Any]] = []
    monkeypatch.setattr(
        "app.algo_host.subprocess_runner.subprocess.run",
        _fake_reference_subprocess_run(invocations),
    )

    container = build_container(recording_demo.settings, session_factory=sqlite_session_factory)
    app = FastAPI()
    app.include_router(recording_demo.router)
    app.dependency_overrides[get_surgery_pipeline] = lambda: container.surgery_pipeline

    form_fields = {
        "surgery_id": "100001",
        "candidate_consumables_json": '["耗材1"]',
        "include_visualization": "false",
    }
    uploads = {"video1": ("case.mp4", b"fake-mp4-bytes", "video/mp4")}
    http = TestClient(app)
    response = http.post("/internal/demo/offline-batch", data=form_fields, files=uploads)
    assert response.status_code == 200, response.text

    assert len(invocations) == 1, "expected exactly one reference bundle subprocess invocation"
    recorded = invocations[0]
    launched_cmd: list[str] = recorded["cmd"]
    launch_kwargs: dict[str, Any] = recorded["kwargs"]

    # Rebuild the expected command from the config path that was actually
    # passed, since that path is generated by the service.
    config_arg = Path(launched_cmd[launched_cmd.index("--config") + 1])
    assert launched_cmd == build_batch_main_command(
        bundle_dir=reference_bundle,
        config_path=config_arg,
    )
    assert launch_kwargs.get("cwd") == str(reference_bundle.resolve())
    assert launch_kwargs.get("env", {}).get("PYTHONFAULTHANDLER") == "1"

    written_config = recorded["config"]
    io_section = written_config["io"]
    assert Path(io_section["video"]).name == "pipeline.mp4"
    assert "/input/" in io_section["video"]
    assert str(io_section["excel"]).endswith("商品信息表.xlsx")
    assert str(io_section["whitelist_json"]).endswith("whitelist.json")
    assert written_config["runtime"]["keep_work_dir"] is False
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_start_surgery_endpoint_starts_rtsp_segment_recorders(
    monkeypatch: pytest.MonkeyPatch,
    sqlite_session_factory,
    tmp_path: Path,
) -> None:
    """Start and end a surgery over HTTP and check which recorder was launched.

    ``RtspSegmentRecorder.run`` is replaced by a stub that records
    ``(surgery_id, camera_id)``. The test expects exactly one recorder, for
    camera ``or-cam-03``, plus a whitelist JSON file under the patched logs
    directory and an awaited ``drain`` on the slice batch processor.
    """
    logs_dir = tmp_path / "logs"

    async def _check_db_ok() -> None:
        return None

    monkeypatch.setattr("app.api.check_database", _check_db_ok)
    monkeypatch.setattr("app.services.video.session_manager.LOGS_DIR", logs_dir)

    settings = Settings(
        video_rtsp_url_template="rtsp://lab/{camera_id}/live",
        video_open_timeout_sec=5.0,
        rtsp_primary_camera_id="or-cam-03",
    )
    container = build_container(settings, session_factory=sqlite_session_factory)

    async def _fake_resolve_rtsp(
        self,
        *,
        camera_id: str,
        kind: Any,
    ) -> tuple[str, int | None, bool]:
        # Returns a fixed URL built from the camera id, so no lookup happens.
        return f"rtsp://unittest/{camera_id}/live", None, False

    monkeypatch.setattr(
        "app.services.video.session_manager.CameraSessionManager._resolve_rtsp_url",
        _fake_resolve_rtsp,
    )

    recorder_starts: list[tuple[str, str]] = []

    async def fake_recorder_run(self: RtspSegmentRecorder, stop_event: Any) -> None:
        recorder_starts.append((self._surgery_id, self._camera_id))
        ready = self._ready_event
        if ready is not None:
            ready.set()
        # Stay "running" until the stop event is set.
        await stop_event.wait()

    monkeypatch.setattr(RtspSegmentRecorder, "run", fake_recorder_run)

    container.camera_session_manager._slice_batch.drain = AsyncMock()

    async def _instant_sleep(_delay: float) -> None:
        return None

    # Make every sleep reached through these two module paths return at once.
    for sleep_target in (
        "app.services.video.session_manager.asyncio.sleep",
        "app.api.asyncio.sleep",
    ):
        monkeypatch.setattr(sleep_target, _instant_sleep)

    async def _noop_voice_assign(*args: Any, **kwargs: Any) -> None:
        return None

    monkeypatch.setattr(
        "app.services.recording_live.assign_voice_terminal_after_recording_started",
        _noop_voice_assign,
    )

    app = FastAPI()
    app.include_router(api_router)
    app.dependency_overrides[get_surgery_pipeline] = lambda: container.surgery_pipeline
    app.dependency_overrides[get_voice_terminal_hub] = lambda: container.voice_terminal_hub

    surgery_id = "123456"
    start_payload = {
        "surgery_id": surgery_id,
        "camera_ids": ["cam1", "or-cam-03"],
        "candidate_consumables": ["纱布"],
    }
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        started = await client.post("/client/surgeries/start", json=start_payload)
        assert started.status_code == 200, started.text
        assert started.json()["status"] == "accepted"

        ended = await client.post("/client/surgeries/end", json={"surgery_id": surgery_id})
        assert ended.status_code == 200, ended.text

    # Two camera ids were submitted, but only one recorder start is expected.
    assert len(recorder_starts) == 1
    assert recorder_starts[0] == (surgery_id, "or-cam-03")

    whitelist_file = logs_dir / f"surgery_{surgery_id}_whitelist.json"
    assert whitelist_file.is_file()
    stored_whitelist = json.loads(whitelist_file.read_text(encoding="utf-8"))
    assert stored_whitelist["candidate_consumables"] == ["纱布"]

    container.camera_session_manager._slice_batch.drain.assert_awaited()
|