"""手术室耗材视觉算法:可选手部检测 ROI + YOLO-cls(原离线双机位流水线核心逻辑)。 作为 FastAPI 内唯一的视频推理入口;撕扯动作分类已移除,由手部检测 + 耗材分类替代。 """ from __future__ import annotations import os import sys from collections import Counter from dataclasses import dataclass from pathlib import Path from threading import Lock import numpy as np import pandas as pd from loguru import logger from ultralytics import YOLO from app.config import Settings, settings os.environ["YOLO_CONFIG_DIR"] = "/tmp" def resolve_inference_device(explicit: str) -> str | None: """Ultralytics `device`;空则 macOS 优先 MPS,Linux/Windows 优先 CUDA。""" configured = (explicit or "").strip() if configured: return configured try: import torch except Exception: return None if sys.platform == "darwin": if torch.backends.mps.is_available(): return "mps" return None if torch.cuda.is_available(): return "cuda:0" return None @dataclass(frozen=True) class PredictionCandidate: label: str confidence: float @dataclass(frozen=True) class PredictionResult: label: str confidence: float topk: list[PredictionCandidate] class ModelNotConfiguredError(RuntimeError): """权重未配置或文件不存在。""" class PredictionError(RuntimeError): """推理失败。""" @dataclass class ClsTop3: t1_name: str t1_conf: float t2_name: str t2_conf: float t3_name: str t3_conf: float t1_pid: str t2_pid: str t3_pid: str def _find_col(df: pd.DataFrame, want: str) -> str | None: want = want.strip() for c in df.columns: if str(c).strip() == want: return str(c) return None def _norm_product_name(name: str) -> str: s = (name or "").strip() if s == "一次性医用垫单": return "一次性使用手术单(一次性医用垫单)" return s def load_name_to_product_code(xlsx: Path) -> dict[str, str]: """商品名称 -> 产品编码(白名单键为归一化后的名称)。""" df = pd.read_excel(xlsx, sheet_name=0) c_code = _find_col(df, "产品编码") c_name = _find_col(df, "商品名称") if c_code is None or c_name is None: raise ValueError("Excel 缺少「产品编码」或「商品名称」列") m: dict[str, str] = {} dups: set[str] = set() for _, row in df.iterrows(): raw = row.get(c_name) if raw is None or (isinstance(raw, float) and pd.isna(raw)): continue n = _norm_product_name(str(raw).strip()) if not n: continue code = row.get(c_code) if code is None or (isinstance(code, float) and pd.isna(code)): continue sc = str(code).strip() if n in m and m[n] != sc: dups.add(n) continue if n not in m: m[n] = sc if dups: logger.warning( "Excel 中以下商品名称对应多组产品编码,已保留首次映射: {}", ";".join(sorted(dups)[:12]) + (" …" if len(dups) > 12 else ""), ) return m def collect_hand_boxes(model: YOLO, boxes) -> list[tuple[float, float, float, float]]: if boxes is None or len(boxes) == 0: return [] xyxy = boxes.xyxy.cpu().numpy() cls_ids = boxes.cls.cpu().numpy().astype(int) names = model.names out: list[tuple[float, float, float, float]] = [] for i, c in enumerate(cls_ids): label = str(names.get(int(c), "")).strip().lower() if "hand" in label or label in {"手", "手部"}: out.append(tuple(float(x) for x in xyxy[i])) if not out and len(xyxy) > 0: # 单类检测模型:无 hand 字样时保留全部框 for row in xyxy: out.append(tuple(float(x) for x in row)) return out def union_boxes( boxes: list[tuple[float, float, float, float]], ) -> tuple[float, float, float, float]: xs1, ys1, xs2, ys2 = zip(*boxes, strict=True) return min(xs1), min(ys1), max(xs2), max(ys2) def pad_box( box: tuple[float, float, float, float], w: int, h: int, pad_ratio: float, ) -> tuple[int, int, int, int]: x1, y1, x2, y2 = box bw, bh = x2 - x1, y2 - y1 pad_w, pad_h = bw * pad_ratio, bh * pad_ratio nx1 = int(max(0, x1 - pad_w)) ny1 = int(max(0, y1 - pad_h)) nx2 = int(min(w, x2 + pad_w)) ny2 = int(min(h, y2 + pad_h)) return nx1, ny1, nx2, ny2 def cls_top3_from_result( cls: YOLO, r, name_to_code: dict[str, str] ) -> ClsTop3 | None: pr = r[0].probs if pr is None or not hasattr(pr, "top5") or not pr.top5: return None t5i = list(pr.top5) tc = pr.top5conf if tc is None: return None def _ci(i: int) -> float: if i < 0 or i >= len(tc): return 0.0 try: v = tc[i] return float(v.item() if hasattr(v, "item") else v) except (IndexError, ValueError, TypeError): return 0.0 t1i = int(pr.top1) c1 = _ci(0) if t5i and int(t5i[0]) == t1i else float( pr.top1conf.item() if hasattr(pr.top1conf, "item") else pr.top1conf ) n1 = str(cls.names.get(t1i, "")).strip() n2 = n3 = "" c2 = c3 = 0.0 if len(t5i) > 1: n2 = str(cls.names.get(int(t5i[1]), "")).strip() c2 = _ci(1) if len(t5i) > 2: n3 = str(cls.names.get(int(t5i[2]), "")).strip() c3 = _ci(2) return ClsTop3( t1_name=n1, t1_conf=c1, t2_name=n2, t2_conf=c2, t3_name=n3, t3_conf=c3, t1_pid=name_to_code.get(n1, ""), t2_pid=name_to_code.get(n2, ""), t3_pid=name_to_code.get(n3, ""), ) def cls_top3_to_prediction_result(snap: ClsTop3) -> PredictionResult: topk: list[PredictionCandidate] = [] if snap.t1_name: topk.append(PredictionCandidate(snap.t1_name, snap.t1_conf)) if snap.t2_name: topk.append(PredictionCandidate(snap.t2_name, snap.t2_conf)) if snap.t3_name: topk.append(PredictionCandidate(snap.t3_name, snap.t3_conf)) if not topk: topk = [PredictionCandidate("", 0.0)] return PredictionResult( label=snap.t1_name, confidence=snap.t1_conf, topk=topk, ) def _mode_lex(names: list[str]) -> str | None: if not names: return None c = Counter(names) best = max(c.values()) pool = [n for n, k in c.items() if k == best] return min(pool) def window_bucket_to_best_snap( bucket_pts: list[tuple[str, ClsTop3]], ) -> ClsTop3 | None: """单个时间窗内:众数类名 + 该类下 top1 置信度最大的快照。""" pick = _mode_lex([a for a, _ in bucket_pts]) if pick is None: return None best: ClsTop3 | None = None for pname, sn in bucket_pts: if pname == pick and (best is None or sn.t1_conf > best.t1_conf): best = sn return best class ConsumableVisionAlgorithmService: """手部检测(可选)+ 耗材分类;供 CameraSessionManager 在视频线程中调用。""" def __init__(self, app_settings: Settings | None = None) -> None: self._s = app_settings or settings self._det: YOLO | None = None self._cls: YOLO | None = None self._det_lock = Lock() self._cls_lock = Lock() def build_name_mapping( self, candidate_consumables: list[str] ) -> dict[str, str]: """分类标签 -> 业务物品 id(Excel 产品编码;无表时用名称自身)。""" stripped = [_norm_product_name(c.strip()) for c in candidate_consumables if c.strip()] candidates_norm = {n: n for n in stripped} xlsx_raw = (self._s.consumable_catalog_xlsx_path or "").strip() if xlsx_raw: path = Path(xlsx_raw).expanduser() if path.is_file(): full = load_name_to_product_code(path) out: dict[str, str] = {} for norm in candidates_norm: if norm in full: out[norm] = full[norm] return out logger.warning("耗材目录 Excel 路径已配置但文件不存在: {}", path) return {n: n for n in candidates_norm} def _det_weights(self) -> Path | None: raw = (self._s.hand_detection_weights or "").strip() if not raw: return None p = Path(raw).expanduser() return p if p.is_file() else None def _cls_weights(self) -> Path: raw = (self._s.consumable_classifier_weights or "").strip() if not raw: raise ModelNotConfiguredError( "未配置耗材分类权重。请设置 CONSUMABLE_CLASSIFIER_WEIGHTS。" ) p = Path(raw).expanduser().resolve() if not p.is_file(): raise ModelNotConfiguredError(f"耗材分类权重不存在: {p}") return p def _get_det(self) -> YOLO | None: path = self._det_weights() if path is None: return None if self._det is None: with self._det_lock: if self._det is None: logger.info("加载手部检测权重: {}", path) self._det = YOLO(str(path)) return self._det def _get_cls(self) -> YOLO: if self._cls is None: with self._cls_lock: if self._cls is None: path = self._cls_weights() logger.info("加载耗材分类权重: {}", path) self._cls = YOLO(str(path)) return self._cls def hand_crop( self, frame: np.ndarray, det_model: YOLO, *, det_conf: float, pad_ratio: float, min_crop_px: int, imgsz_det: int, ) -> np.ndarray | None: h, w = frame.shape[:2] device = resolve_inference_device(self._s.hand_detection_device) results = det_model.predict( frame, conf=det_conf, imgsz=imgsz_det, device=device, verbose=False, ) hand_xyxys = collect_hand_boxes(det_model, results[0].boxes) if not hand_xyxys: return None merged = union_boxes(hand_xyxys) cx1, cy1, cx2, cy2 = pad_box(merged, w, h, pad_ratio) if (cx2 - cx1) < min_crop_px or (cy2 - cy1) < min_crop_px: return None return frame[cy1:cy2, cx1:cx2] def infer_frame_bgr( self, frame: np.ndarray, name_to_code: dict[str, str], ) -> ClsTop3 | None: """单帧 BGR;仅当 top1 通过置信度且落在白名单(name_to_code 键)时返回。""" whitelist = set(name_to_code.keys()) det_model = self._get_det() cls_model = self._get_cls() if det_model is not None: crop = self.hand_crop( frame, det_model, det_conf=self._s.hand_detection_conf, pad_ratio=self._s.hand_detection_pad_ratio, min_crop_px=self._s.hand_detection_min_crop_px, imgsz_det=self._s.hand_detection_imgsz, ) if crop is None: return None else: crop = frame device = resolve_inference_device(self._s.consumable_classifier_device) try: r = cls_model.predict( crop, imgsz=self._s.consumable_classifier_imgsz, device=device, verbose=False, ) except Exception as exc: raise PredictionError(f"耗材分类推理失败: {exc}") from exc snap = cls_top3_from_result(cls_model, r, name_to_code) if snap is None: return None if snap.t1_conf < self._s.consumable_min_cls_confidence: return None pname = snap.t1_name if not pname or pname not in whitelist: return None return snap