2026-04-22 16:31:12 +08:00
|
|
|
|
"""手术室耗材视觉算法:可选手部检测 ROI + YOLO-cls(原离线双机位流水线核心逻辑)。
|
|
|
|
|
|
|
|
|
|
|
|
作为 FastAPI 内唯一的视频推理入口;撕扯动作分类已移除,由手部检测 + 耗材分类替代。
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
2026-04-24 14:27:56 +08:00
|
|
|
|
import functools
|
2026-04-22 16:31:12 +08:00
|
|
|
|
import os
|
2026-04-24 14:27:56 +08:00
|
|
|
|
from typing import Any
|
2026-04-22 16:31:12 +08:00
|
|
|
|
import sys
|
|
|
|
|
|
from collections import Counter
|
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
from threading import Lock
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
2026-04-24 11:05:17 +08:00
|
|
|
|
import yaml
|
2026-04-22 16:31:12 +08:00
|
|
|
|
from loguru import logger
|
|
|
|
|
|
from ultralytics import YOLO
|
|
|
|
|
|
|
2026-04-24 15:33:22 +08:00
|
|
|
|
from app.baked import algorithm as ba
|
2026-04-22 16:31:12 +08:00
|
|
|
|
|
2026-04-23 20:42:21 +08:00
|
|
|
|
|
|
|
|
|
|
def _ensure_yolo_config_dir() -> None:
    """Give Ultralytics a writable ``YOLO_CONFIG_DIR`` when none is configured.

    Ultralytics needs a writable settings directory. If the environment variable
    is unset or empty it is set to ``/tmp``; a non-empty user value is never
    overwritten.
    """
    key = "YOLO_CONFIG_DIR"
    if os.environ.get(key):
        return
    os.environ[key] = "/tmp"
|
2026-04-22 16:31:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resolve_inference_device(explicit: str) -> str | None:
    """Choose the Ultralytics ``device`` argument.

    A non-blank *explicit* value wins and is returned stripped. Otherwise torch
    is probed: on macOS ``"mps"`` is returned when MPS is available, on other
    platforms ``"cuda:0"`` when CUDA is available. ``None`` (let Ultralytics
    decide) is returned when neither applies or torch cannot be imported.
    """
    wanted = (explicit or "").strip()
    if wanted:
        return wanted
    try:
        import torch
    except Exception:
        return None
    if sys.platform == "darwin":
        return "mps" if torch.backends.mps.is_available() else None
    return "cuda:0" if torch.cuda.is_available() else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class PredictionCandidate:
    """One ranked classification candidate (immutable value object)."""

    # Class name reported by the classifier.
    label: str
    # Score the classifier assigned to this class.
    confidence: float
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class PredictionResult:
    """Classification outcome: the top-1 label/score plus the ranked candidates."""

    # Top-1 class name.
    label: str
    # Top-1 score.
    confidence: float
    # Ranked candidates (built from the top-3 ranks by cls_top3_to_prediction_result).
    topk: list[PredictionCandidate]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelNotConfiguredError(RuntimeError):
    """Model weights are not configured, or the configured weights file does not exist."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PredictionError(RuntimeError):
    """Model inference failed."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class ClsTop3:
    """Mutable snapshot of the top-3 classification ranks for one frame.

    When ``cls_top3_from_result`` fills it, ranks that do not exist get an empty
    name, ``0.0`` confidence and an empty id.
    """

    # Rank-1 class name and score.
    t1_name: str
    t1_conf: float
    # Rank-2 class name and score.
    t2_name: str
    t2_conf: float
    # Rank-3 class name and score.
    t3_name: str
    t3_conf: float
    # Business ids resolved for ranks 1-3 ("" when unresolved).
    t1_pid: str
    t2_pid: str
    t3_pid: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _norm_product_name(name: str) -> str:
    """Strip *name* and map known alias spellings onto their canonical product name.

    ``None`` or empty input yields ``""``; unknown names are returned stripped.
    """
    cleaned = (name or "").strip()
    aliases = {"一次性医用垫单": "一次性使用手术单(一次性医用垫单)"}
    return aliases.get(cleaned, cleaned)
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-04-24 11:05:17 +08:00
|
|
|
|
def load_name_to_label_id_from_yaml(path: Path) -> dict[str, str]:
    """Build ``normalized product name -> business label_id`` from the labels YAML.

    The document (e.g. ``consumable_classifier_labels.yaml``) is expected to hold
    two mappings keyed by class index: ``names`` (index -> product name) and
    ``label_id`` (index -> business id, possibly ``a/b/...`` for multi-spec
    items). Index keys may be ints or numeric strings. Entries with a
    non-integer index, an empty name or a missing/blank id are skipped.

    Read and YAML-parse failures are logged as warnings and yield ``{}``; a
    document of the wrong shape also yields ``{}``.
    """
    try:
        text = path.read_text(encoding="utf-8")
    except OSError as exc:
        logger.warning("无法读取耗材 label YAML {}: {}", path, exc)
        return {}
    try:
        doc: Any = yaml.safe_load(text)
    except yaml.YAMLError as exc:
        logger.warning("解析耗材 label YAML 失败 {}: {}", path, exc)
        return {}
    if not isinstance(doc, dict):
        return {}
    names_section = doc.get("names")
    ids_section = doc.get("label_id")
    if not (isinstance(names_section, dict) and isinstance(ids_section, dict)):
        return {}

    mapping: dict[str, str] = {}
    for key, value in names_section.items():
        try:
            index = int(key)
        except (TypeError, ValueError):
            continue
        product = "" if value is None else str(value).strip()
        if not product:
            continue
        # Prefer the integer key, then its string form (quoted YAML keys).
        if index in ids_section:
            raw_id = ids_section[index]
        elif str(index) in ids_section:
            raw_id = ids_section[str(index)]
        else:
            raw_id = None
        if raw_id is None or (isinstance(raw_id, str) and not raw_id.strip()):
            continue
        mapping[_norm_product_name(product)] = str(raw_id).strip()
    return mapping
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-04-24 14:27:56 +08:00
|
|
|
|
def load_index_to_label_id_from_yaml(path: Path) -> dict[int, str]:
    """Read the ``label_id`` section as ``class index -> business id string``.

    Keyed by class index rather than name, so a prediction can still be given
    the right id when the model's class name differs slightly from the YAML.
    Unreadable or unparsable files, a non-mapping document or a non-mapping
    ``label_id`` section all yield ``{}`` silently. Non-integer keys and
    ``None``/blank ids are skipped.
    """
    try:
        text = path.read_text(encoding="utf-8")
    except OSError:
        return {}
    try:
        doc: Any = yaml.safe_load(text)
    except yaml.YAMLError:
        return {}
    ids_section = doc.get("label_id") if isinstance(doc, dict) else None
    if not isinstance(ids_section, dict):
        return {}

    result: dict[int, str] = {}
    for key, value in ids_section.items():
        try:
            index = int(key)
        except (TypeError, ValueError):
            continue
        if value is None or (isinstance(value, str) and not value.strip()):
            continue
        result[index] = str(value).strip()
    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@functools.lru_cache(maxsize=8)
def _cached_index_to_label_id(path_resolved: str, mtime_ns: int) -> dict[int, str]:
    """Memoized ``load_index_to_label_id_from_yaml`` keyed on (path, mtime).

    ``mtime_ns`` is not read in the body. It only takes part in the cache key, so
    a modified file (new mtime) is parsed again. The cached dict object is shared
    by every caller with the same key and should be treated as read-only.
    """
    yaml_file = Path(path_resolved)
    return load_index_to_label_id_from_yaml(yaml_file)
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-04-24 11:05:17 +08:00
|
|
|
|
def list_sorted_class_names_from_yaml(path: Path) -> list[str]:
    """Return the normalized class names from ``names``, ordered by class index.

    This order matches the class indices used in training and in the weights.
    Unreadable or unparsable files and a missing/non-mapping ``names`` section
    yield ``[]``. Non-integer keys and blank names are skipped.
    """
    try:
        text = path.read_text(encoding="utf-8")
    except OSError:
        return []
    try:
        doc: Any = yaml.safe_load(text)
    except yaml.YAMLError:
        return []
    names_section = doc.get("names") if isinstance(doc, dict) else None
    if not isinstance(names_section, dict):
        return []

    indexed: list[tuple[int, str]] = []
    for key, value in names_section.items():
        try:
            index = int(key)
        except (TypeError, ValueError):
            continue
        label = "" if value is None else str(value).strip()
        if label:
            indexed.append((index, _norm_product_name(label)))
    # sorted() is stable: entries sharing an index keep their YAML order.
    return [name for _, name in sorted(indexed, key=lambda pair: pair[0])]
|
2026-04-22 16:31:12 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def collect_hand_boxes(model: YOLO, boxes) -> list[tuple[float, float, float, float]]:
    """Return the xyxy boxes (as float tuples) that belong to a hand class.

    A box counts as a hand when its class name (stripped, lower-cased) contains
    ``"hand"`` or equals ``"手"`` / ``"手部"``. If no box qualifies, every box is
    returned instead. This covers single-class detectors whose only class is
    not named "hand". ``None`` or empty *boxes* yields ``[]``.
    """
    if boxes is None or len(boxes) == 0:
        return []
    coords = boxes.xyxy.cpu().numpy()
    class_ids = boxes.cls.cpu().numpy().astype(int)
    class_names = model.names

    def _is_hand(class_id) -> bool:
        name = str(class_names.get(int(class_id), "")).strip().lower()
        return "hand" in name or name in {"手", "手部"}

    def _as_tuple(row) -> tuple[float, float, float, float]:
        return tuple(float(v) for v in row)

    hands = [_as_tuple(coords[i]) for i, cid in enumerate(class_ids) if _is_hand(cid)]
    if hands or len(coords) == 0:
        return hands
    # Fallback: no class name looks like a hand, so keep all boxes.
    return [_as_tuple(row) for row in coords]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def union_boxes(
|
|
|
|
|
|
boxes: list[tuple[float, float, float, float]],
|
|
|
|
|
|
) -> tuple[float, float, float, float]:
|
|
|
|
|
|
xs1, ys1, xs2, ys2 = zip(*boxes, strict=True)
|
|
|
|
|
|
return min(xs1), min(ys1), max(xs2), max(ys2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pad_box(
    box: tuple[float, float, float, float],
    w: int,
    h: int,
    pad_ratio: float,
) -> tuple[int, int, int, int]:
    """Grow *box* on every side by ``pad_ratio`` times its own width/height.

    The result is clamped to the ``[0, w] x [0, h]`` frame and truncated to ints
    with ``int()``.
    """
    left, top, right, bottom = box
    grow_x = (right - left) * pad_ratio
    grow_y = (bottom - top) * pad_ratio
    return (
        int(max(0, left - grow_x)),
        int(max(0, top - grow_y)),
        int(min(w, right + grow_x)),
        int(min(h, bottom + grow_y)),
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-04-23 20:42:21 +08:00
|
|
|
|
def _probs_data_to_numpy1d(raw) -> np.ndarray:
    """Flatten a classification score vector into a 1-D float64 NumPy array.

    ``None`` yields an empty array. For tensor-like input, each of ``detach()``,
    ``cpu()`` and ``numpy()`` is applied in that order when the current object
    has it. NumPy can only wrap host (CPU) memory, so a CUDA/MPS tensor must be
    copied to the host with ``.cpu()`` before ``.numpy()``. The copy is cheap
    when the tensor is already on the CPU. Whatever remains is passed to
    ``np.asarray``.
    """
    if raw is None:
        return np.zeros((0,), dtype=np.float64)
    value = raw
    for step in ("detach", "cpu", "numpy"):
        if hasattr(value, step):
            value = getattr(value, step)()
    return np.asarray(value, dtype=np.float64).reshape(-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-04-22 16:31:12 +08:00
|
|
|
|
def cls_top3_from_result(
    cls: YOLO,
    r,
    name_to_code: dict[str, str],
    *,
    index_to_label_id: dict[int, str] | None = None,
) -> ClsTop3 | None:
    """Summarise the top-3 classes of an Ultralytics classification result.

    Args:
        cls: classification model; only ``cls.names`` (index -> name) is read.
        r: ``predict`` output; only ``r[0].probs`` is used.
        name_to_code: class name -> business id. It is looked up with the
            normalized name first, then with the stripped raw name.
        index_to_label_id: optional class index -> business id, used only when
            the name lookup yields nothing.

    Returns:
        ``None`` when there are no probabilities. Otherwise a ``ClsTop3``; ranks
        beyond the number of classes get name ``""``, confidence ``0.0`` and
        id ``""``.
    """
    probs = r[0].probs
    if probs is None:
        return None
    scores = _probs_data_to_numpy1d(probs.data)
    if scores.size == 0:
        return None
    # Stable sort of the negated scores: descending, ties keep the lower index first.
    ranked = [int(i) for i in np.argsort(-scores, kind="stable")[:3]]
    by_index = index_to_label_id or {}

    def score_of(idx: int) -> float:
        if not 0 <= idx < scores.size:
            return 0.0
        try:
            v = scores[idx]
            return float(v.item() if hasattr(v, "item") else v)
        except (IndexError, ValueError, TypeError):
            return 0.0

    def business_id(label: str, class_idx: int) -> str:
        text = (label or "").strip()
        if not text:
            return ""
        code = (
            name_to_code.get(_norm_product_name(text))
            or name_to_code.get(text)
            or ""
        ).strip()
        if code:
            return code
        if class_idx >= 0 and class_idx in by_index:
            return by_index[class_idx]
        return ""

    slots: list[tuple[str, float, str]] = []
    for idx in ranked:
        label = str(cls.names.get(idx, "")).strip()
        slots.append((label, score_of(idx), business_id(label, idx)))
    # Pad missing ranks (models with fewer than 3 classes).
    slots.extend([("", 0.0, "")] * (3 - len(slots)))
    (n1, c1, p1), (n2, c2, p2), (n3, c3, p3) = slots
    return ClsTop3(
        t1_name=n1,
        t1_conf=c1,
        t2_name=n2,
        t2_conf=c2,
        t3_name=n3,
        t3_conf=c3,
        t1_pid=p1,
        t2_pid=p2,
        t3_pid=p3,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def cls_top3_to_prediction_result(snap: ClsTop3) -> PredictionResult:
    """Convert a ``ClsTop3`` snapshot into a ``PredictionResult``.

    Ranks with an empty name are left out of ``topk``. If all three are empty,
    ``topk`` holds a single ``("", 0.0)`` placeholder. ``label`` and
    ``confidence`` always mirror rank 1.
    """
    ranks = (
        (snap.t1_name, snap.t1_conf),
        (snap.t2_name, snap.t2_conf),
        (snap.t3_name, snap.t3_conf),
    )
    topk = [PredictionCandidate(name, conf) for name, conf in ranks if name]
    return PredictionResult(
        label=snap.t1_name,
        confidence=snap.t1_conf,
        topk=topk or [PredictionCandidate("", 0.0)],
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _mode_lex(names: list[str]) -> str | None:
    """Return the most frequent entry of *names*.

    Ties go to the lexicographically smallest name; an empty list yields ``None``.
    """
    if not names:
        return None
    tally = Counter(names)
    # Highest count first (negated), then the smallest name.
    return min(tally, key=lambda n: (-tally[n], n))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def window_bucket_to_best_snap(
    bucket_pts: list[tuple[str, ClsTop3]],
) -> ClsTop3 | None:
    """Pick the representative snapshot of one time window.

    The winning class is the mode of the names (ties go to the lexicographically
    smallest). Among that class's snapshots, the one with the highest ``t1_conf``
    is returned; on equal confidence the earliest wins. An empty bucket yields
    ``None``.
    """
    winner = _mode_lex([name for name, _ in bucket_pts])
    if winner is None:
        return None
    matching = [snap for name, snap in bucket_pts if name == winner]
    # max() keeps the first maximal element, so the earliest snapshot wins ties.
    return max(matching, key=lambda snap: snap.t1_conf)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConsumableVisionAlgorithmService:
    """Optional hand detection followed by consumable classification.

    Used by CameraSessionManager from its video threads. The detection and
    classification YOLO models are loaded lazily on first use; each lazy load
    is serialized by its own lock.
    """

    def __init__(self, *, labels_yaml_path: str | None = None) -> None:
        _ensure_yolo_config_dir()
        # Explicit labels YAML path; None/blank falls back to
        # ba.CONSUMABLE_CLASSIFIER_LABELS_YAML_PATH (see _labels_path).
        self._labels_yaml_path = labels_yaml_path
        self._det: YOLO | None = None
        self._cls: YOLO | None = None
        self._det_lock = Lock()
        self._cls_lock = Lock()

    def _labels_path(self) -> Path:
        """Path of the consumable labels YAML (user-expanded, not resolved)."""
        raw = self._labels_yaml_path
        if raw is not None and str(raw).strip():
            return Path(str(raw).strip()).expanduser()
        return Path(ba.CONSUMABLE_CLASSIFIER_LABELS_YAML_PATH).expanduser()

    def _index_to_label_id(self) -> dict[int, str]:
        """Class index -> business id from the labels YAML.

        Results are cached per (resolved path, mtime). Returns ``{}`` when the
        file is missing or disappears between the existence check and ``stat``.
        """
        yp = self._labels_path()
        try:
            if not yp.is_file():
                return {}
            st = yp.stat()
            resolved = str(yp.resolve())
        except OSError:
            # The file was removed or replaced after is_file(); treat it as absent.
            return {}
        return _cached_index_to_label_id(resolved, st.st_mtime_ns)

    def effective_candidate_consumables(self, requested: list[str]) -> list[str]:
        """Return the candidate consumable names to classify against.

        Non-blank entries of *requested* are normalized and de-duplicated (the
        first occurrence wins). If none remain, fall back to the ``names``
        section of ``consumable_classifier_labels.yaml`` (ordered by class
        index). If that is also unusable, fall back to the sorted distinct class
        names of the classification model.
        """
        out: list[str] = []
        seen: set[str] = set()
        for c in requested:
            n = _norm_product_name((c or "").strip())
            if not n or n in seen:
                continue
            seen.add(n)
            out.append(n)
        if out:
            return out

        yaml_path = self._labels_path()
        if yaml_path.is_file():
            ylist = list_sorted_class_names_from_yaml(yaml_path)
            if ylist:
                return ylist
            logger.warning("耗材 label YAML 中无有效 names: {}", yaml_path)

        # Last resort: class names stored in the classifier weights (loads the model).
        cls_model = self._get_cls()
        return sorted(
            {str(v).strip() for v in cls_model.names.values() if str(v).strip()}
        )

    def build_name_mapping(
        self, candidate_consumables: list[str]
    ) -> dict[str, str]:
        """Map each normalized candidate name to its business id.

        Ids come only from the ``label_id`` section of
        ``consumable_classifier_labels.yaml``. A candidate without a YAML id
        maps to its own normalized name. ``None`` and blank candidates are
        ignored. Returns ``{}`` when no candidate remains.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order.
        candidates_norm = dict.fromkeys(
            _norm_product_name(text)
            for text in ((c or "").strip() for c in candidate_consumables)
            if text
        )
        if not candidates_norm:
            return {}

        yaml_path = self._labels_path()
        yaml_map: dict[str, str] = {}
        if yaml_path.is_file():
            try:
                yaml_map = load_name_to_label_id_from_yaml(yaml_path)
            except Exception as exc:  # noqa: BLE001
                logger.warning("加载耗材 label YAML 失败 {}: {}", yaml_path, exc)
        else:
            logger.debug("耗材 label YAML 不存在: {}", yaml_path)

        return {norm: yaml_map.get(norm) or norm for norm in candidates_norm}

    def _det_weights(self) -> Path | None:
        """Hand-detection weights path, or ``None`` if unset or not an existing file."""
        raw = (ba.HAND_DETECTION_WEIGHTS or "").strip()
        if not raw:
            return None
        p = Path(raw).expanduser()
        return p if p.is_file() else None

    def _cls_weights(self) -> Path:
        """Resolved classifier weights path.

        Raises:
            ModelNotConfiguredError: if the configured file does not exist.
        """
        p = Path(ba.CONSUMABLE_CLASSIFIER_WEIGHTS).expanduser().resolve()
        if not p.is_file():
            raise ModelNotConfiguredError(f"耗材分类权重不存在: {p}")
        return p

    def _get_det(self) -> YOLO | None:
        """Lazily load the hand detector; ``None`` when no usable weights are configured.

        The weights path is checked on every call. Once loaded, the cached model
        is reused even if the configured path later changes.
        """
        path = self._det_weights()
        if path is None:
            return None
        if self._det is None:
            with self._det_lock:
                # Re-check under the lock so concurrent callers load only once.
                if self._det is None:
                    logger.info("加载手部检测权重: {}", path)
                    self._det = YOLO(str(path))
        return self._det

    def _get_cls(self) -> YOLO:
        """Lazily load the consumable classifier.

        Raises:
            ModelNotConfiguredError: if the weights file does not exist.
        """
        if self._cls is None:
            with self._cls_lock:
                # Re-check under the lock so concurrent callers load only once.
                if self._cls is None:
                    path = self._cls_weights()
                    logger.info("加载耗材分类权重: {}", path)
                    self._cls = YOLO(str(path))
        return self._cls

    def hand_crop(
        self,
        frame: np.ndarray,
        det_model: YOLO,
        *,
        det_conf: float,
        pad_ratio: float,
        min_crop_px: int,
        imgsz_det: int,
    ) -> np.ndarray | None:
        """Crop *frame* to the padded union of all detected hand boxes.

        Returns ``None`` when no box is found or the padded crop is narrower or
        shorter than ``min_crop_px``. The returned array is a slice (view) of
        *frame*.
        """
        h, w = frame.shape[:2]
        device = resolve_inference_device(ba.HAND_DETECTION_DEVICE)
        results = det_model.predict(
            frame,
            conf=det_conf,
            imgsz=imgsz_det,
            device=device,
            verbose=False,
        )
        hand_xyxys = collect_hand_boxes(det_model, results[0].boxes)
        if not hand_xyxys:
            return None
        merged = union_boxes(hand_xyxys)
        cx1, cy1, cx2, cy2 = pad_box(merged, w, h, pad_ratio)
        if (cx2 - cx1) < min_crop_px or (cy2 - cy1) < min_crop_px:
            return None
        return frame[cy1:cy2, cx1:cx2]

    def infer_frame_bgr(
        self,
        frame: np.ndarray,
        name_to_code: dict[str, str],
    ) -> ClsTop3 | None:
        """Classify one BGR frame and return its top-3 snapshot if it passes the filters.

        The frame is first cropped to the hands when a hand detector is
        configured; no hands means ``None``. The snapshot is returned only when
        the top-1 confidence reaches ``ba.CONSUMABLE_MIN_CLS_CONFIDENCE`` and the
        top-1 is whitelisted. A top-1 is whitelisted when its (normalized) name
        is a key of *name_to_code*, or its resolved id is one of
        *name_to_code*'s ids.

        Raises:
            ModelNotConfiguredError: classifier weights are missing.
            PredictionError: hand detection or classification inference failed.
        """
        whitelist = set(name_to_code.keys())
        det_model = self._get_det()
        cls_model = self._get_cls()

        if det_model is not None:
            try:
                crop = self.hand_crop(
                    frame,
                    det_model,
                    det_conf=ba.HAND_DETECTION_CONF,
                    pad_ratio=ba.HAND_DETECTION_PAD_RATIO,
                    min_crop_px=ba.HAND_DETECTION_MIN_CROP_PX,
                    imgsz_det=ba.HAND_DETECTION_IMGSZ,
                )
            except Exception as exc:
                # Report detector failures the same way as classifier failures.
                raise PredictionError(f"手部检测推理失败: {exc}") from exc
            if crop is None:
                return None
        else:
            crop = frame

        device = resolve_inference_device(ba.CONSUMABLE_CLASSIFIER_DEVICE)
        try:
            r = cls_model.predict(
                crop,
                imgsz=ba.CONSUMABLE_CLASSIFIER_IMGSZ,
                device=device,
                verbose=False,
            )
        except Exception as exc:
            raise PredictionError(f"耗材分类推理失败: {exc}") from exc

        snap = cls_top3_from_result(
            cls_model,
            r,
            name_to_code,
            index_to_label_id=self._index_to_label_id(),
        )
        if snap is None:
            return None
        if snap.t1_conf < ba.CONSUMABLE_MIN_CLS_CONFIDENCE:
            return None
        pname = snap.t1_name
        if not pname:
            return None
        if _norm_product_name(pname) in whitelist or pname in whitelist:
            return snap
        # The class name is not whitelisted verbatim, e.g. the model and YAML
        # spellings differ, but the class index may still map to an id. Accept it
        # only if that id belongs to a whitelisted consumable. Otherwise a
        # requested subset would be bypassed for every class with a label_id.
        pid = (snap.t1_pid or "").strip()
        if pid:
            allowed_ids = {
                str(v).strip() for v in name_to_code.values() if str(v).strip()
            }
            if pid in allowed_ids:
                return snap
        return None
|