"""有状态逐帧处理 + 停录时段级汇总(与 haocai_consumption demo main.run 同构)。""" from __future__ import annotations import time from collections import Counter from dataclasses import dataclass from pathlib import Path from threading import Lock import numpy as np from loguru import logger from ultralytics import YOLO from app.baked import algorithm as ba from app.services.consumable_vision_algorithm import ( _norm_product_name, resolve_inference_device, ) from app.services.tear_gated_segment_consumption.geometry import ( collect_hand_boxes, find_tearing_pair, pad_box, prob_tearing, union_boxes, ) from app.services.tear_gated_segment_consumption.segments import merge_tear_segments @dataclass(frozen=True) class TearGatedSegmentRecord: """单段输出,与离线 txt 一行语义一致。""" segment_index: int start_sec: float end_sec: float mid_stream_sec: float item_id: str item_name: str top1_conf: float top2_name: str top2_conf: float top3_name: str top3_conf: float majority_ref: str def is_good_frame( gb_model: YOLO, crop: np.ndarray, gb_names: dict, imgsz: int, device: str | None ) -> bool: if crop.size == 0: return False r = gb_model.predict(crop, imgsz=imgsz, verbose=False, device=device)[0] if r.probs is None: return False tid = int(r.probs.top1) label = str(gb_names.get(tid, "")) return label == "good" def haocai_mean_topk( probs_list: list[np.ndarray], names: dict, ) -> tuple[str, float, str, float, str, float]: if not probs_list: return "(无有效帧)", 0.0, "", 0.0, "", 0.0 p = np.mean(np.stack(probs_list, axis=0), axis=0) order = np.argsort(-p) t1, t2, t3 = (int(order[0]), int(order[1]), int(order[2])) return ( str(names.get(t1, str(t1))), float(p[t1]), str(names.get(t2, str(t2))), float(p[t2]), str(names.get(t3, str(t3))), float(p[t3]), ) class TearGatedSegmentRunner: """从首帧到 finalize:累积 timeline,停录时合并撕段并生成记录。""" def __init__( self, *, det_m: YOLO, tear_m: YOLO, gb_m: YOLO, haoc_m: YOLO, name_to_id: dict[str, str], ) -> None: self._det_m = det_m self._tear_m = tear_m self._gb_m = gb_m self._haoc_m = haoc_m self._tear_names = tear_m.names self._gb_names = gb_m.names self._haoc_names = haoc_m.names self._n_h = len(self._haoc_names) if isinstance(self._haoc_names, dict) else 41 self._name_to_id = name_to_id self._lock = Lock() self._frame_idx = 0 self._tear_buf: list[float] = [] self._timeline: list[tuple[int, float, bool, str]] = [] self._frame_probs: list[np.ndarray | None] = [] self._wall_t0: float | None = None self._start_seconds = 0.0 def _effective_fps(self) -> float: raw = float(ba.TEAR_SEGMENT_ASSUMED_FPS) return raw if raw > 0 else 25.0 def process_frame_bgr(self, frame: np.ndarray) -> None: """处理单帧 BGR(与 demo run() 主循环体一致)。""" with self._lock: if self._wall_t0 is None: self._wall_t0 = time.time() det_conf = ba.TEAR_SEGMENT_DET_CONF pad_ratio = ba.TEAR_SEGMENT_PAD_RATIO tear_conf = ba.TEAR_SEGMENT_TEAR_CONF tear_smooth = ba.TEAR_SEGMENT_TEAR_SMOOTH gap_ratio = ba.TEAR_SEGMENT_GAP_RATIO fps = self._effective_fps() w = int(frame.shape[1]) h = int(frame.shape[0]) start_seconds = self._start_seconds fidx = self._frame_idx t_abs = start_seconds + fidx / fps r = self._det_m.predict( frame, conf=det_conf, imgsz=ba.TEAR_SEGMENT_DET_IMGSZ, device=resolve_inference_device(ba.HAND_DETECTION_DEVICE), verbose=False, ) hand_xyxys = collect_hand_boxes(self._det_m, r[0].boxes) geom = ( len(hand_xyxys) >= 2 and find_tearing_pair(hand_xyxys, gap_ratio=gap_ratio) is not None ) f_probs: np.ndarray | None = None rec_label = "" is_tear = False max_p = 0.0 if len(hand_xyxys) >= 1: merged = union_boxes(hand_xyxys) cx1, cy1, cx2, cy2 = pad_box(merged, w, h, pad_ratio) gb_dev = resolve_inference_device(ba.TEAR_SEGMENT_GOODBAD_DEVICE) haoc_dev = resolve_inference_device(ba.TEAR_SEGMENT_HAOCAI_DEVICE) for hbox in hand_xyxys: hx1, hy1, hx2, hy2 = pad_box(hbox, w, h, pad_ratio) hc = frame[hy1:hy2, hx1:hx2] if hc.size > 0: tr = self._tear_m.predict( hc, imgsz=ba.TEAR_SEGMENT_TEAR_IMGSZ, verbose=False, device=resolve_inference_device( ba.TEAR_SEGMENT_TEAR_DEVICE ), ) max_p = max( max_p, prob_tearing(tr[0].probs, self._tear_names) ) if tear_smooth > 0: self._tear_buf.append(max_p) if len(self._tear_buf) > tear_smooth: self._tear_buf.pop(0) p_eff = sum(self._tear_buf) / len(self._tear_buf) else: p_eff = max_p eff = tear_conf * 0.55 if geom else tear_conf is_tear = p_eff >= eff if is_tear: cls_c = frame[cy1:cy2, cx1:cx2] if cls_c.size > 0 and is_good_frame( self._gb_m, cls_c, self._gb_names, ba.TEAR_SEGMENT_GOODBAD_IMGSZ, gb_dev, ): h_r = self._haoc_m.predict( cls_c, imgsz=ba.TEAR_SEGMENT_HAOCAI_IMGSZ, verbose=False, device=haoc_dev, )[0] pr = h_r.probs if pr is not None and pr.data is not None: v = pr.data.detach().float().cpu().numpy().ravel() n_exp = self._n_h if v.size < n_exp: v = np.resize(v, n_exp) v = v[:n_exp] s = v.sum() f_probs = (v / s) if s > 0 else v tid = int(np.argmax(f_probs)) rec_label = str(self._haoc_names.get(tid, str(tid))) else: self._tear_buf.clear() self._timeline.append((fidx, t_abs, is_tear, rec_label)) self._frame_probs.append(f_probs) self._frame_idx += 1 if self._frame_idx % 200 == 0: logger.info( "tear_segment: processed {} frames (surgery stream)", self._frame_idx, ) def finalize(self) -> list[TearGatedSegmentRecord]: """段合并 + 段内 topK 与 YAML 类名→label_id 映射;RTSP 无片尾,以停录为界。""" with self._lock: timeline = self._timeline frame_probs = self._frame_probs haoc_names = self._haoc_names name_to_id = self._name_to_id if not timeline: return [] segs = merge_tear_segments( timeline, min_tear_sec=ba.TEAR_SEGMENT_MIN_TEAR_SEC, min_gap_sec=ba.TEAR_SEGMENT_MIN_GAP_SEC, ) out: list[TearGatedSegmentRecord] = [] for s in segs: f0, f1 = s["start_frame"], s["end_frame"] probs_ok: list[np.ndarray] = [] lbs: list[str] = [] for fi in range(f0, f1 + 1): if 0 <= fi < len(frame_probs) and frame_probs[fi] is not None: probs_ok.append(frame_probs[fi]) for fi in range(f0, f1 + 1): if 0 <= fi < len(timeline): _, __, it, lab = timeline[fi] if it and lab: lbs.append(lab) if lbs: majority = Counter(lbs).most_common(1)[0][0] else: majority = "(本段无好帧+耗材)" t1, c1, t2, c2, t3, c3 = haocai_mean_topk(probs_ok, haoc_names) use_name = t1 if use_name in ("", "(无有效帧)"): use_name = majority if use_name.startswith("(") or use_name == "(本段无好帧+耗材)": item_id = "(无)" else: key = _norm_product_name(use_name) item_id = name_to_id.get(key, "(无匹配编码)") t_mid = 0.5 * (s["start_sec"] + s["end_sec"]) out.append( TearGatedSegmentRecord( segment_index=s["index"], start_sec=s["start_sec"], end_sec=s["end_sec"], mid_stream_sec=t_mid, item_id=item_id, item_name=use_name, top1_conf=c1, top2_name=t2, top2_conf=c2, top3_name=t3, top3_conf=c3, majority_ref=majority, ) ) return out def wall_time_for_record(self, rec: TearGatedSegmentRecord) -> float: """段中点对应的 Unix 时间(秒),用于落库时间戳。""" with self._lock: t0w = self._wall_t0 if t0w is None: return time.time() return t0w + rec.mid_stream_sec class TearGatedSegmentModelBundle: """四模型只加载一次,供多例 Runner 复用。""" def __init__(self) -> None: self._lock = Lock() self._det: YOLO | None = None self._tear: YOLO | None = None self._gb: YOLO | None = None self._haoc: YOLO | None = None def _p(self, key: str) -> Path: return Path((key or "").strip()).expanduser().resolve() def _load(self) -> None: with self._lock: if self._det is not None: return dp = self._p(ba.TEAR_SEGMENT_HAND_DET_WEIGHTS) tp = self._p(ba.TEAR_SEGMENT_TEAR_WEIGHTS) gp = self._p(ba.TEAR_SEGMENT_GOODBAD_WEIGHTS) hp = self._p(ba.TEAR_SEGMENT_HAOCAI_WEIGHTS) for p, label in ( (dp, "hand det"), (tp, "tear"), (gp, "good/bad"), (hp, "haocai 41"), ): if not p.is_file(): raise FileNotFoundError(f"tear_segment {label} 权重不存在: {p}") logger.info("加载撕段四模型: {} {} {} {}", dp, tp, gp, hp) self._det = YOLO(str(dp)) self._tear = YOLO(str(tp)) self._gb = YOLO(str(gp)) self._haoc = YOLO(str(hp)) def ensure_loaded(self) -> None: self._load() def create_runner(self, name_to_id: dict[str, str]) -> TearGatedSegmentRunner: self.ensure_loaded() assert self._det is not None assert self._tear is not None assert self._gb is not None assert self._haoc is not None return TearGatedSegmentRunner( det_m=self._det, tear_m=self._tear, gb_m=self._gb, haoc_m=self._haoc, name_to_id=name_to_id, )