Files
operating-room-monitor-server/backend/tests/test_actionformer_segment_consumption.py

84 lines
2.7 KiB
Python
Raw Permalink Normal View History

2026-05-21 15:48:03 +08:00
"""actionformer_gated 公共助手 + 段稳定 dedupe + ``append_confirmed_detail`` 行为。
替换原 ``tests/test_tear_gated_segment_consumption.py``保留与 IPC / 业务层强相关
的测试移除与已废弃 tear_gated 内部算法merge_tear_segments / fishaction
postprocess / haocai_mean_topk耦合的部分
"""
from __future__ import annotations
import numpy as np
import pytest
from app.algorithm_runner.actionformer_gated.runner import (
ActionFormerSegmentRecord,
actionformer_segment_stable_dedupe_key,
)
from app.algorithm_runner.actionformer_gated.segment_helpers import mask_probs_whitelist
from app.algorithm_runner.actionformer_gated.whitelist_indices import (
candidate_class_indices_for_haocai_model,
)
def test_mask_probs_whitelist_keeps_allowed_only() -> None:
    """Check that masking ``[0.4, 0.6]`` with whitelist ``{1}`` yields ``[0, 1]``.

    The disallowed entry (index 0) is expected to be ~0 and the single
    allowed entry (index 1) is expected to be ~1.0, within 1e-9.
    """
    probs = np.array([0.4, 0.6], dtype=np.float64)
    allowed = frozenset({1})
    masked = mask_probs_whitelist(probs, allowed, 2)
    assert masked is not None
    # Compare each position against its expected value with an absolute tolerance.
    for idx, expected in ((0, 0.0), (1, 1.0)):
        assert abs(float(masked[idx]) - expected) < 1e-9
def test_mask_probs_whitelist_empty_returns_none() -> None:
    """Check that a whitelist naming only index 5 gives ``None`` for a 2-entry vector."""
    probs = np.array([1.0, 0.0], dtype=np.float64)
    result = mask_probs_whitelist(probs, frozenset({5}), 2)
    assert result is None
def test_candidate_class_indices_for_haocai_model() -> None:
    """Check name-to-index resolution against an index-to-name class mapping.

    A candidate name present in the mapping is expected to resolve to a
    frozenset holding its index; a candidate list with no match is expected
    to produce ``None``.
    """
    index_to_name = {0: "手套A", 1: "纱布B"}

    matched = candidate_class_indices_for_haocai_model(["纱布B"], index_to_name)
    assert matched == frozenset({1})

    unmatched = candidate_class_indices_for_haocai_model(["未知"], index_to_name)
    assert unmatched is None
def test_actionformer_segment_stable_dedupe_key() -> None:
    """Check the stable dedupe key produced for a single segment record.

    For start 1.006 s, end 12.004 s and item name "手套", the expected key is
    ``"1.01:12.0:手套"``.
    """
    # NOTE(review): the expected key looks like two-decimal rounding of
    # start/end joined with the item name — confirm against the runner.
    record_fields = {
        "segment_index": 9,
        "start_sec": 1.006,
        "end_sec": 12.004,
        "mid_stream_sec": 6.0,
        "item_id": "id",
        "item_name": "手套",
        "top1_conf": 0.88,
        "top2_name": "",
        "top2_conf": 0.0,
        "top3_name": "",
        "top3_conf": 0.0,
        "majority_ref": "",
    }
    record = ActionFormerSegmentRecord(**record_fields)
    dedupe_key = actionformer_segment_stable_dedupe_key(record)
    assert dedupe_key == "1.01:12.0:手套"
@pytest.mark.asyncio
async def test_append_confirmed_detail_seg_cooldown_keys() -> None:
    """Segments sharing one item_id but using distinct cooldown keys should all be written."""
    from app.services.video.session_registry import SurgerySessionRegistry, SurgerySessionState

    registry = SurgerySessionRegistry()
    session = SurgerySessionState(candidate_consumables=["X"], name_to_code={"X": "id1"})

    # Same item and doctor for both appends; only the cooldown key differs.
    for segment_key in ("s1:seg:1", "s1:seg:2"):
        await registry.append_confirmed_detail(
            state=session,
            item_id="SAME",
            item_name="A",
            doctor_id="d",
            source="tear_segment",
            cooldown_key=segment_key,
        )

    assert len(session.details) == 2