"""actionformer_gated 公共助手 + 段稳定 dedupe + ``append_confirmed_detail`` 行为。

替换原 ``tests/test_tear_gated_segment_consumption.py``:保留与 IPC / 业务层强相关
的测试,移除与已废弃 tear_gated 内部算法(merge_tear_segments / fishaction
postprocess / haocai_mean_topk)耦合的部分。
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import numpy as np
|
|||
|
|
import pytest
|
|||
|
|
|
|||
|
|
from app.algorithm_runner.actionformer_gated.runner import (
|
|||
|
|
ActionFormerSegmentRecord,
|
|||
|
|
actionformer_segment_stable_dedupe_key,
|
|||
|
|
)
|
|||
|
|
from app.algorithm_runner.actionformer_gated.segment_helpers import mask_probs_whitelist
|
|||
|
|
from app.algorithm_runner.actionformer_gated.whitelist_indices import (
|
|||
|
|
candidate_class_indices_for_haocai_model,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_mask_probs_whitelist_keeps_allowed_only() -> None:
    """Whitelisting only class 1 zeroes class 0 and moves all mass to class 1.

    Input probabilities [0.4, 0.6] with whitelist {1} are expected to come
    back as [0.0, 1.0]. The whole output vector is compared at once, so a
    wrong length or any stray non-zero entry is reported, rather than only
    the two elements inspected individually.
    """
    p = np.array([0.4, 0.6], dtype=np.float64)
    out = mask_probs_whitelist(p, frozenset({1}), 2)
    assert out is not None
    # rtol=0 keeps the check a pure absolute tolerance of 1e-9, matching the
    # original per-element `abs(...) < 1e-9` comparisons.
    np.testing.assert_allclose(out, [0.0, 1.0], rtol=0, atol=1e-9)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_mask_probs_whitelist_empty_returns_none() -> None:
    """A whitelist that selects no usable entry makes the helper return ``None``.

    The whitelist holds only index 5, which is not a valid position in the
    2-element probability vector used here.
    """
    probs = np.array([1.0, 0.0], dtype=np.float64)
    unreachable_whitelist = frozenset({5})
    masked = mask_probs_whitelist(probs, unreachable_whitelist, 2)
    assert masked is None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_candidate_class_indices_for_haocai_model() -> None:
    """Candidate names resolve to model class indices through an index->name map.

    A name present in the map yields the frozenset of its index. A list made
    up only of a name missing from the map yields ``None``.
    """
    index_to_name = {0: "手套A", 1: "纱布B"}

    resolved = candidate_class_indices_for_haocai_model(["纱布B"], index_to_name)
    assert resolved == frozenset({1})

    unresolved = candidate_class_indices_for_haocai_model(["未知"], index_to_name)
    assert unresolved is None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_actionformer_segment_stable_dedupe_key() -> None:
    """The stable dedupe key is built from start/end seconds and the item name.

    The expected key contains only start, end and ``item_name``. Fields such
    as ``segment_index`` (9), ``item_id`` and the confidences do not appear.
    The seconds are expected as "1.01" (from 1.006) and "12.0" (from 12.004),
    which is consistent with rounding to two decimals.
    """
    record_fields = dict(
        segment_index=9,
        start_sec=1.006,
        end_sec=12.004,
        mid_stream_sec=6.0,
        item_id="id",
        item_name="手套",
        top1_conf=0.88,
        top2_name="",
        top2_conf=0.0,
        top3_name="",
        top3_conf=0.0,
        majority_ref="",
    )
    record = ActionFormerSegmentRecord(**record_fields)

    expected_key = "1.01:12.0:手套"
    assert actionformer_segment_stable_dedupe_key(record) == expected_key
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.mark.asyncio
async def test_append_confirmed_detail_seg_cooldown_keys() -> None:
    """Several segments sharing one item_id are all written under distinct cooldown keys.

    Two details with the same ``item_id`` are appended, each with its own
    ``cooldown_key``. Both must end up in ``state.details``.
    """
    from app.services.video.session_registry import SurgerySessionRegistry, SurgerySessionState

    registry = SurgerySessionRegistry()
    session = SurgerySessionState(candidate_consumables=["X"], name_to_code={"X": "id1"})

    # Appended sequentially, in the same order as two back-to-back calls.
    for segment_key in ("s1:seg:1", "s1:seg:2"):
        await registry.append_confirmed_detail(
            state=session,
            item_id="SAME",
            item_name="A",
            doctor_id="d",
            source="tear_segment",
            cooldown_key=segment_key,
        )

    assert len(session.details) == 2
|