"""手术室耗材视觉算法:可选手部检测 ROI + YOLO-cls(原离线双机位流水线核心逻辑)。 作为 FastAPI 内唯一的视频推理入口;撕扯动作分类已移除,由手部检测 + 耗材分类替代。 """ from __future__ import annotations import functools import os from typing import Any import sys from collections import Counter from dataclasses import dataclass from pathlib import Path from threading import Lock import numpy as np import yaml from loguru import logger from ultralytics import YOLO from app.config import Settings, settings def _ensure_yolo_config_dir() -> None: """Ultralytics 需要可写 YOLO_CONFIG_DIR;仅在未设置时给一个安全默认,不覆盖用户配置。""" if not os.environ.get("YOLO_CONFIG_DIR"): os.environ["YOLO_CONFIG_DIR"] = "/tmp" def resolve_inference_device(explicit: str) -> str | None: """Ultralytics `device`;空则 macOS 优先 MPS,Linux/Windows 优先 CUDA。""" configured = (explicit or "").strip() if configured: return configured try: import torch except Exception: return None if sys.platform == "darwin": if torch.backends.mps.is_available(): return "mps" return None if torch.cuda.is_available(): return "cuda:0" return None @dataclass(frozen=True) class PredictionCandidate: label: str confidence: float @dataclass(frozen=True) class PredictionResult: label: str confidence: float topk: list[PredictionCandidate] class ModelNotConfiguredError(RuntimeError): """权重未配置或文件不存在。""" class PredictionError(RuntimeError): """推理失败。""" @dataclass class ClsTop3: t1_name: str t1_conf: float t2_name: str t2_conf: float t3_name: str t3_conf: float t1_pid: str t2_pid: str t3_pid: str def _norm_product_name(name: str) -> str: s = (name or "").strip() if s == "一次性医用垫单": return "一次性使用手术单(一次性医用垫单)" return s def load_name_to_label_id_from_yaml(path: Path) -> dict[str, str]: """从 ``consumable_classifier_labels.yaml`` 得到:归一化商品名 -> 业务 label_id(可与 ``names`` 下标一一对应;多规格为 ``a/b/...``)。""" try: raw = path.read_text(encoding="utf-8") except OSError as exc: logger.warning("无法读取耗材 label YAML {}: {}", path, exc) return {} try: data: Any data = yaml.safe_load(raw) except yaml.YAMLError as exc: logger.warning("解析耗材 label YAML 失败 {}: {}", path, exc) return {} if not isinstance(data, dict): return {} names_raw = data.get("names") label_raw = data.get("label_id") if not isinstance(names_raw, dict) or not isinstance(label_raw, dict): return {} out: dict[str, str] = {} for k, v in names_raw.items(): try: i = int(k) except (TypeError, ValueError): continue name = str(v).strip() if v is not None else "" if not name: continue lid: Any = None if i in label_raw: lid = label_raw[i] elif str(i) in label_raw: lid = label_raw[str(i)] if lid is None or (isinstance(lid, str) and not str(lid).strip()): continue id_str = str(lid).strip() out[_norm_product_name(name)] = id_str return out def load_index_to_label_id_from_yaml(path: Path) -> dict[int, str]: """与 ``label_id`` 段:类索引 -> 业务 id 字符串;类名与 YAML 略有不一致时仍可落盘到正确 id。""" try: raw = path.read_text(encoding="utf-8") except OSError: return {} try: data: Any = yaml.safe_load(raw) except yaml.YAMLError: return {} if not isinstance(data, dict): return {} label_raw = data.get("label_id") if not isinstance(label_raw, dict): return {} out: dict[int, str] = {} for k, v in label_raw.items(): try: i = int(k) except (TypeError, ValueError): continue if v is None or (isinstance(v, str) and not str(v).strip()): continue out[i] = str(v).strip() return out @functools.lru_cache(maxsize=8) def _cached_index_to_label_id(path_resolved: str, mtime_ns: int) -> dict[int, str]: return load_index_to_label_id_from_yaml(Path(path_resolved)) def list_sorted_class_names_from_yaml(path: Path) -> list[str]: """自 ``names`` 段按类索引升序取类名字符串(与训练/权重一致)。""" try: raw = path.read_text(encoding="utf-8") except OSError: return [] try: data: Any = yaml.safe_load(raw) except yaml.YAMLError: return [] if not isinstance(data, dict): return [] names_raw = data.get("names") if not isinstance(names_raw, dict): return [] items: list[tuple[int, str]] = [] for k, v in names_raw.items(): try: i = int(k) except (TypeError, ValueError): continue s = str(v).strip() if v is not None else "" if not s: continue items.append((i, _norm_product_name(s))) items.sort(key=lambda t: t[0]) return [n for _, n in items] def collect_hand_boxes(model: YOLO, boxes) -> list[tuple[float, float, float, float]]: if boxes is None or len(boxes) == 0: return [] xyxy = boxes.xyxy.cpu().numpy() cls_ids = boxes.cls.cpu().numpy().astype(int) names = model.names out: list[tuple[float, float, float, float]] = [] for i, c in enumerate(cls_ids): label = str(names.get(int(c), "")).strip().lower() if "hand" in label or label in {"手", "手部"}: out.append(tuple(float(x) for x in xyxy[i])) if not out and len(xyxy) > 0: # 单类检测模型:无 hand 字样时保留全部框 for row in xyxy: out.append(tuple(float(x) for x in row)) return out def union_boxes( boxes: list[tuple[float, float, float, float]], ) -> tuple[float, float, float, float]: xs1, ys1, xs2, ys2 = zip(*boxes, strict=True) return min(xs1), min(ys1), max(xs2), max(ys2) def pad_box( box: tuple[float, float, float, float], w: int, h: int, pad_ratio: float, ) -> tuple[int, int, int, int]: x1, y1, x2, y2 = box bw, bh = x2 - x1, y2 - y1 pad_w, pad_h = bw * pad_ratio, bh * pad_ratio nx1 = int(max(0, x1 - pad_w)) ny1 = int(max(0, y1 - pad_h)) nx2 = int(min(w, x2 + pad_w)) ny2 = int(min(h, y2 + pad_h)) return nx1, ny1, nx2, ny2 def _probs_data_to_numpy1d(raw) -> np.ndarray: """分类 logits/probs 向量 → 1D float64 NumPy 数组。 PyTorch 张量若在 ``cuda``、``mps`` 等设备上,**必须先** ``.cpu()`` 再转 NumPy: NumPy 只支持 CPU(主机)内存,没有 CUDA/MPS 后端;``np.asarray(cuda_tensor)`` / ``tensor.numpy()``(设备上)都会失败。``.cpu()`` 会做一次设备→主机的拷贝(已是 CPU 时开销很小),因此 CUDA 与 MPS 共用同一路径即可。 """ if raw is None: return np.zeros((0,), dtype=np.float64) x = raw if hasattr(x, "detach"): x = x.detach() if hasattr(x, "cpu"): x = x.cpu() if hasattr(x, "numpy"): # torch.Tensor / ultralytics BaseTensor 等 x = x.numpy() return np.asarray(x, dtype=np.float64).reshape(-1) def cls_top3_from_result( cls: YOLO, r, name_to_code: dict[str, str], *, index_to_label_id: dict[int, str] | None = None, ) -> ClsTop3 | None: pr = r[0].probs if pr is None: return None arr = _probs_data_to_numpy1d(pr.data) if arr.size == 0: return None order = np.argsort(-arr, kind="stable") t5i = [int(order[i]) for i in range(min(5, int(order.size)))] def _conf_for_idx(idx: int) -> float: if idx < 0 or idx >= arr.size: return 0.0 try: v = arr[idx] return float(v.item() if hasattr(v, "item") else v) except (IndexError, ValueError, TypeError): return 0.0 t1i = int(t5i[0]) c1 = _conf_for_idx(t1i) n1 = str(cls.names.get(t1i, "")).strip() n2 = n3 = "" c2 = c3 = 0.0 i2 = i3 = -1 if len(t5i) > 1: i2 = int(t5i[1]) n2 = str(cls.names.get(i2, "")).strip() c2 = _conf_for_idx(i2) if len(t5i) > 2: i3 = int(t5i[2]) n3 = str(cls.names.get(i3, "")).strip() c3 = _conf_for_idx(i3) idx_extras = index_to_label_id or {} def _pid(label: str, class_idx: int) -> str: lb = (label or "").strip() if not lb: return "" norm = _norm_product_name(lb) c = (name_to_code.get(norm) or name_to_code.get(lb) or "").strip() if c: return c if class_idx >= 0 and class_idx in idx_extras: return idx_extras[class_idx] return "" return ClsTop3( t1_name=n1, t1_conf=c1, t2_name=n2, t2_conf=c2, t3_name=n3, t3_conf=c3, t1_pid=_pid(n1, t1i), t2_pid=_pid(n2, i2), t3_pid=_pid(n3, i3), ) def cls_top3_to_prediction_result(snap: ClsTop3) -> PredictionResult: topk: list[PredictionCandidate] = [] if snap.t1_name: topk.append(PredictionCandidate(snap.t1_name, snap.t1_conf)) if snap.t2_name: topk.append(PredictionCandidate(snap.t2_name, snap.t2_conf)) if snap.t3_name: topk.append(PredictionCandidate(snap.t3_name, snap.t3_conf)) if not topk: topk = [PredictionCandidate("", 0.0)] return PredictionResult( label=snap.t1_name, confidence=snap.t1_conf, topk=topk, ) def _mode_lex(names: list[str]) -> str | None: if not names: return None c = Counter(names) best = max(c.values()) pool = [n for n, k in c.items() if k == best] return min(pool) def window_bucket_to_best_snap( bucket_pts: list[tuple[str, ClsTop3]], ) -> ClsTop3 | None: """单个时间窗内:众数类名 + 该类下 top1 置信度最大的快照。""" pick = _mode_lex([a for a, _ in bucket_pts]) if pick is None: return None best: ClsTop3 | None = None for pname, sn in bucket_pts: if pname == pick and (best is None or sn.t1_conf > best.t1_conf): best = sn return best class ConsumableVisionAlgorithmService: """手部检测(可选)+ 耗材分类;供 CameraSessionManager 在视频线程中调用。""" def __init__(self, app_settings: Settings | None = None) -> None: _ensure_yolo_config_dir() self._s = app_settings or settings self._det: YOLO | None = None self._cls: YOLO | None = None self._det_lock = Lock() self._cls_lock = Lock() def effective_candidate_consumables(self, requested: list[str]) -> list[str]: """请求体中的耗材子集;未提供(缺省或仅空白)时先用 ``consumable_classifier_labels.yaml`` 的 ``names``,无有效 YAML 则分类模型类名。""" out: list[str] = [] seen: set[str] = set() for c in requested: n = _norm_product_name((c or "").strip()) if not n or n in seen: continue seen.add(n) out.append(n) if out: return out yaml_path = Path(self._s.consumable_classifier_labels_yaml_path).expanduser() if yaml_path.is_file(): ylist = list_sorted_class_names_from_yaml(yaml_path) if ylist: return ylist logger.warning("耗材 label YAML 中无有效 names: {}", yaml_path) cls_model = self._get_cls() labels = sorted( {str(v).strip() for v in cls_model.names.values() if str(v).strip()} ) return labels def build_name_mapping( self, candidate_consumables: list[str] ) -> dict[str, str]: """分类类名(归一化) -> 业务 id:仅 ``consumable_classifier_labels.yaml`` 的 ``label_id``;无映射时用语义类名作 id。""" stripped = [_norm_product_name(c.strip()) for c in candidate_consumables if c.strip()] candidates_norm = {n: n for n in stripped} if not candidates_norm: return {} yaml_path = Path(self._s.consumable_classifier_labels_yaml_path).expanduser() yaml_map: dict[str, str] = {} if yaml_path.is_file(): try: yaml_map = load_name_to_label_id_from_yaml(yaml_path) except Exception as exc: # noqa: BLE001 logger.warning("加载耗材 label YAML 失败 {}: {}", yaml_path, exc) else: logger.debug("耗材 label YAML 不存在: {}", yaml_path) out: dict[str, str] = {} for norm in candidates_norm: out[norm] = yaml_map.get(norm) or norm return out def _det_weights(self) -> Path | None: raw = (self._s.hand_detection_weights or "").strip() if not raw: return None p = Path(raw).expanduser() return p if p.is_file() else None def _cls_weights(self) -> Path: raw = (self._s.consumable_classifier_weights or "").strip() if not raw: raise ModelNotConfiguredError( "未配置耗材分类权重。请设置 CONSUMABLE_CLASSIFIER_WEIGHTS。" ) p = Path(raw).expanduser().resolve() if not p.is_file(): raise ModelNotConfiguredError(f"耗材分类权重不存在: {p}") return p def _get_det(self) -> YOLO | None: path = self._det_weights() if path is None: return None if self._det is None: with self._det_lock: if self._det is None: logger.info("加载手部检测权重: {}", path) self._det = YOLO(str(path)) return self._det def _get_cls(self) -> YOLO: if self._cls is None: with self._cls_lock: if self._cls is None: path = self._cls_weights() logger.info("加载耗材分类权重: {}", path) self._cls = YOLO(str(path)) return self._cls def hand_crop( self, frame: np.ndarray, det_model: YOLO, *, det_conf: float, pad_ratio: float, min_crop_px: int, imgsz_det: int, ) -> np.ndarray | None: h, w = frame.shape[:2] device = resolve_inference_device(self._s.hand_detection_device) results = det_model.predict( frame, conf=det_conf, imgsz=imgsz_det, device=device, verbose=False, ) hand_xyxys = collect_hand_boxes(det_model, results[0].boxes) if not hand_xyxys: return None merged = union_boxes(hand_xyxys) cx1, cy1, cx2, cy2 = pad_box(merged, w, h, pad_ratio) if (cx2 - cx1) < min_crop_px or (cy2 - cy1) < min_crop_px: return None return frame[cy1:cy2, cx1:cx2] def infer_frame_bgr( self, frame: np.ndarray, name_to_code: dict[str, str], ) -> ClsTop3 | None: """单帧 BGR;仅当 top1 通过置信度且落在白名单(name_to_code 键)时返回。""" whitelist = set(name_to_code.keys()) det_model = self._get_det() cls_model = self._get_cls() if det_model is not None: crop = self.hand_crop( frame, det_model, det_conf=self._s.hand_detection_conf, pad_ratio=self._s.hand_detection_pad_ratio, min_crop_px=self._s.hand_detection_min_crop_px, imgsz_det=self._s.hand_detection_imgsz, ) if crop is None: return None else: crop = frame device = resolve_inference_device(self._s.consumable_classifier_device) try: r = cls_model.predict( crop, imgsz=self._s.consumable_classifier_imgsz, device=device, verbose=False, ) except Exception as exc: raise PredictionError(f"耗材分类推理失败: {exc}") from exc yp = Path(self._s.consumable_classifier_labels_yaml_path).expanduser() if yp.is_file(): st = yp.stat() index_to_label_id = _cached_index_to_label_id( str(yp.resolve()), st.st_mtime_ns ) else: index_to_label_id = {} snap = cls_top3_from_result( cls_model, r, name_to_code, index_to_label_id=index_to_label_id, ) if snap is None: return None if snap.t1_conf < self._s.consumable_min_cls_confidence: return None pname = snap.t1_name if not pname: return None pnorm = _norm_product_name(pname) if pnorm in whitelist or pname in whitelist: return snap if (snap.t1_pid or "").strip(): return snap return None