- Refactor app API and schemas; adjust surgery pipeline, repository, and session manager. - Improve consumption TSV logging and consumable vision integration; trim voice resolution. - Add Baidu Face 1:N search script, .env.example entries, and client API integration doc. - Update demo client, staging checklist, surgery interface doc, and related tests; add sample face image. Made-with: Cursor
421 lines
13 KiB
Python
421 lines
13 KiB
Python
"""手术室耗材视觉算法:可选手部检测 ROI + YOLO-cls(原离线双机位流水线核心逻辑)。
|
||
|
||
作为 FastAPI 内唯一的视频推理入口;撕扯动作分类已移除,由手部检测 + 耗材分类替代。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import math
|
||
import os
|
||
import sys
|
||
from collections import Counter
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
from threading import Lock
|
||
|
||
import numpy as np
|
||
from loguru import logger
|
||
from openpyxl import load_workbook
|
||
from ultralytics import YOLO
|
||
|
||
from app.config import Settings, settings
|
||
|
||
# Point Ultralytics' config/settings directory at /tmp.
# NOTE(review): this runs after `from ultralytics import YOLO` above. If the
# installed Ultralytics reads YOLO_CONFIG_DIR at import time, this assignment
# comes too late to have any effect — confirm, and if so set it before that import.
os.environ.update(YOLO_CONFIG_DIR="/tmp")
|
||
|
||
|
||
def resolve_inference_device(explicit: str) -> str | None:
    """Choose the ``device`` argument passed to Ultralytics ``predict``.

    A non-blank *explicit* value always wins and is returned stripped.
    Otherwise an accelerator is auto-detected:

    - on macOS, ``"mps"`` if available;
    - elsewhere, ``"cuda:0"`` if CUDA is available.

    Returns None when torch cannot be imported or no accelerator is found,
    in which case no device is forced on the model.
    """
    requested = (explicit or "").strip()
    if requested:
        return requested

    try:
        import torch
    except Exception:
        return None

    if sys.platform == "darwin":
        # macOS: only MPS is considered; CUDA is not probed here.
        return "mps" if torch.backends.mps.is_available() else None
    return "cuda:0" if torch.cuda.is_available() else None
|
||
|
||
|
||
@dataclass(frozen=True)
class PredictionCandidate:
    """One classifier candidate: a class label and its confidence."""

    label: str  # classifier class name (may be "" for the placeholder candidate)
    confidence: float  # confidence associated with this label
|
||
|
||
|
||
@dataclass(frozen=True)
class PredictionResult:
    """Top-1 prediction together with the ranked candidates it was taken from."""

    label: str  # top-1 label; may be "" when nothing was predicted
    confidence: float  # top-1 confidence
    topk: list[PredictionCandidate]  # candidates in rank order, best first
|
||
|
||
|
||
class ModelNotConfiguredError(RuntimeError):
    """Model weights are not configured, or the configured file does not exist."""
|
||
|
||
|
||
class PredictionError(RuntimeError):
    """Model inference failed; the underlying exception is chained as ``__cause__``."""
|
||
|
||
|
||
@dataclass
class ClsTop3:
    """Snapshot of the consumable classifier's top-3 output for one frame.

    For rank N (1..3): ``tN_name`` is the class label ("" when that rank is
    absent), ``tN_conf`` its confidence (0.0 when absent), and ``tN_pid`` the
    product code resolved from the name -> code mapping ("" when unmapped).
    """

    t1_name: str
    t1_conf: float
    t2_name: str
    t2_conf: float
    t3_name: str
    t3_conf: float
    t1_pid: str
    t2_pid: str
    t3_pid: str
|
||
|
||
|
||
def _find_col_idx(headers: list[object], want: str) -> int | None:
|
||
want = want.strip()
|
||
for i, h in enumerate(headers):
|
||
if str(h).strip() == want:
|
||
return i
|
||
return None
|
||
|
||
|
||
def _cell_empty(value: object) -> bool:
|
||
if value is None:
|
||
return True
|
||
if isinstance(value, float) and math.isnan(value):
|
||
return True
|
||
return False
|
||
|
||
|
||
def _norm_product_name(name: str) -> str:
|
||
s = (name or "").strip()
|
||
if s == "一次性医用垫单":
|
||
return "一次性使用手术单(一次性医用垫单)"
|
||
return s
|
||
|
||
|
||
def load_name_to_product_code(xlsx: Path) -> dict[str, str]:
    """Read the consumable catalog workbook and map product name -> product code.

    The first worksheet is used and its first row must contain the columns
    「产品编码」 (product code) and 「商品名称」 (product name). Names are
    normalized with ``_norm_product_name`` so the keys match the normalized
    whitelist built elsewhere.

    Rows are skipped when the name or the code is missing (None / NaN) or
    blank after stripping. When one name maps to several different codes, the
    first mapping is kept and the conflicting names are logged once.

    Raises:
        ValueError: the sheet is empty or a required column is missing.
    """
    wb = load_workbook(filename=str(xlsx), read_only=True, data_only=True)
    try:
        ws = wb.worksheets[0]
        rows = ws.iter_rows(values_only=True)
        header = next(rows, None)
        if header is None:
            raise ValueError("Excel 为空")
        headers = list(header)
        i_code = _find_col_idx(headers, "产品编码")
        i_name = _find_col_idx(headers, "商品名称")
        if i_code is None or i_name is None:
            raise ValueError("Excel 缺少「产品编码」或「商品名称」列")

        m: dict[str, str] = {}
        dups: set[str] = set()
        for row in rows:
            if not row:
                continue
            # Short rows (trailing empty cells trimmed) are treated as missing values.
            raw = row[i_name] if i_name < len(row) else None
            if _cell_empty(raw):
                continue
            n = _norm_product_name(str(raw).strip())
            if not n:
                continue
            code = row[i_code] if i_code < len(row) else None
            if _cell_empty(code):
                continue
            sc = str(code).strip()
            if not sc:
                # A blank/whitespace-only code would otherwise whitelist the
                # name with an empty code and shadow a real code on a later row.
                continue
            first = m.setdefault(n, sc)
            if first != sc:
                dups.add(n)
    finally:
        wb.close()

    if dups:
        logger.warning(
            "Excel 中以下商品名称对应多组产品编码,已保留首次映射: {}",
            ";".join(sorted(dups)[:12]) + (" …" if len(dups) > 12 else ""),
        )
    return m
|
||
|
||
|
||
def _is_hand_label(label: object) -> bool:
    """True if a detector class name denotes a hand ("hand" substring, or 手 / 手部)."""
    text = str(label).strip().lower()
    return "hand" in text or text in {"手", "手部"}


def collect_hand_boxes(model: YOLO, boxes) -> list[tuple[float, float, float, float]]:
    """Return the xyxy boxes (float tuples) of detections that are hands.

    A class counts as a hand when ``_is_hand_label`` accepts its name in
    ``model.names``. If the detector has no hand-named class at all (e.g. a
    single-class detector whose class is not called "hand"), every detection
    is kept. If it does have one, only detections of hand classes are kept,
    so other classes never leak into the hand ROI.

    Args:
        model: the detector whose ``names`` maps class ids to class names.
        boxes: the ``boxes`` of one detection result, or None.
    """
    if boxes is None or len(boxes) == 0:
        return []
    xyxy = boxes.xyxy.cpu().numpy()
    cls_ids = boxes.cls.cpu().numpy().astype(int)

    names = model.names
    if not isinstance(names, dict):
        # Tolerate list-style class names as well as the usual id -> name dict.
        names = dict(enumerate(names or []))
    hand_ids = {int(cid) for cid, label in names.items() if _is_hand_label(label)}

    if not hand_ids:
        # Detector without a hand class: treat every detection as a hand.
        return [tuple(float(v) for v in row) for row in xyxy]
    return [
        tuple(float(v) for v in row)
        for row, cid in zip(xyxy, cls_ids)
        if int(cid) in hand_ids
    ]
|
||
|
||
|
||
def union_boxes(
|
||
boxes: list[tuple[float, float, float, float]],
|
||
) -> tuple[float, float, float, float]:
|
||
xs1, ys1, xs2, ys2 = zip(*boxes, strict=True)
|
||
return min(xs1), min(ys1), max(xs2), max(ys2)
|
||
|
||
|
||
def pad_box(
    box: tuple[float, float, float, float],
    w: int,
    h: int,
    pad_ratio: float,
) -> tuple[int, int, int, int]:
    """Grow an xyxy *box* on every side by ``pad_ratio`` of its own size.

    Horizontal padding is ``pad_ratio * box width`` and vertical padding is
    ``pad_ratio * box height``. The left/top edges are clamped at 0 and the
    right/bottom edges at *w*/*h*. Each coordinate is then truncated with
    ``int()``.
    """
    left, top, right, bottom = box
    grow_x = (right - left) * pad_ratio
    grow_y = (bottom - top) * pad_ratio
    return (
        int(max(0, left - grow_x)),
        int(max(0, top - grow_y)),
        int(min(w, right + grow_x)),
        int(min(h, bottom + grow_y)),
    )
|
||
|
||
|
||
def cls_top3_from_result(
    cls: YOLO, r, name_to_code: dict[str, str]
) -> ClsTop3 | None:
    """Condense a classification result into a ``ClsTop3`` snapshot.

    Args:
        cls: the classifier; its ``names`` maps class ids to labels.
        r: the value returned by ``cls.predict``; only ``r[0].probs`` is read.
        name_to_code: label -> product code. Each label is looked up in
            normalized form first, then as-is.

    Returns:
        The snapshot, or None when the result has no top-5 ids or no top-5
        confidences. Missing ranks get "" as name, 0.0 as confidence and ""
        as product id.
    """
    probs = r[0].probs
    if probs is None or not hasattr(probs, "top5") or not probs.top5:
        return None
    ranked_ids = list(probs.top5)
    ranked_confs = probs.top5conf
    if ranked_confs is None:
        return None

    def conf_at(rank: int) -> float:
        # Out-of-range ranks and unconvertible values degrade to 0.0.
        if not 0 <= rank < len(ranked_confs):
            return 0.0
        try:
            value = ranked_confs[rank]
            return float(value.item() if hasattr(value, "item") else value)
        except (IndexError, ValueError, TypeError):
            return 0.0

    def label_of(class_id: int) -> str:
        return str(cls.names.get(class_id, "")).strip()

    def product_id(label: str) -> str:
        key = (label or "").strip()
        if not key:
            return ""
        found = name_to_code.get(_norm_product_name(key)) or name_to_code.get(key)
        return (found or "").strip()

    best_id = int(probs.top1)
    if ranked_ids and int(ranked_ids[0]) == best_id:
        best_conf = conf_at(0)
    else:
        # top5 does not start with top1: use the dedicated top1 confidence.
        raw_top1 = probs.top1conf
        best_conf = float(raw_top1.item() if hasattr(raw_top1, "item") else raw_top1)

    labels = [label_of(best_id), "", ""]
    confs = [best_conf, 0.0, 0.0]
    for rank in (1, 2):
        if len(ranked_ids) > rank:
            labels[rank] = label_of(int(ranked_ids[rank]))
            confs[rank] = conf_at(rank)

    return ClsTop3(
        t1_name=labels[0],
        t1_conf=confs[0],
        t2_name=labels[1],
        t2_conf=confs[1],
        t3_name=labels[2],
        t3_conf=confs[2],
        t1_pid=product_id(labels[0]),
        t2_pid=product_id(labels[1]),
        t3_pid=product_id(labels[2]),
    )
|
||
|
||
|
||
def cls_top3_to_prediction_result(snap: ClsTop3) -> PredictionResult:
    """Convert a ``ClsTop3`` snapshot into a ``PredictionResult``.

    Ranks with an empty name are left out of ``topk``. If every rank is
    empty, ``topk`` holds one placeholder candidate ("", 0.0). ``label`` and
    ``confidence`` always mirror the top-1 slot.
    """
    ranked = (
        (snap.t1_name, snap.t1_conf),
        (snap.t2_name, snap.t2_conf),
        (snap.t3_name, snap.t3_conf),
    )
    candidates = [PredictionCandidate(name, conf) for name, conf in ranked if name]
    return PredictionResult(
        label=snap.t1_name,
        confidence=snap.t1_conf,
        topk=candidates or [PredictionCandidate("", 0.0)],
    )
|
||
|
||
|
||
def _mode_lex(names: list[str]) -> str | None:
|
||
if not names:
|
||
return None
|
||
c = Counter(names)
|
||
best = max(c.values())
|
||
pool = [n for n, k in c.items() if k == best]
|
||
return min(pool)
|
||
|
||
|
||
def window_bucket_to_best_snap(
    bucket_pts: list[tuple[str, ClsTop3]],
) -> ClsTop3 | None:
    """Pick one snapshot to represent a time window.

    The winning name is the window's most frequent name (ties broken
    lexicographically). Among that name's snapshots, the first one with the
    highest top-1 confidence is returned. Returns None for an empty window.
    """
    winner = _mode_lex([name for name, _ in bucket_pts])
    if winner is None:
        return None
    matching = [snap for name, snap in bucket_pts if name == winner]
    # max() keeps the earliest element among equal keys, like a strict ">" scan.
    return max(matching, key=lambda snap: snap.t1_conf)
|
||
|
||
|
||
class ConsumableVisionAlgorithmService:
    """Optional hand detection + consumable classification on video frames.

    Meant to be called by CameraSessionManager from its video thread(s). Both
    YOLO models are loaded lazily on first use; the two locks only guard that
    one-time loading.

    NOTE(review): ``predict`` calls on the shared model instances are not
    serialized. If several camera threads share one service instance, confirm
    that concurrent inference on one model is safe for the installed
    Ultralytics version.
    """

    def __init__(self, app_settings: Settings | None = None) -> None:
        self._s = app_settings or settings
        self._det: YOLO | None = None
        self._cls: YOLO | None = None
        self._det_lock = Lock()
        self._cls_lock = Lock()
        # Configured-but-missing detector weights are reported once, not per frame.
        self._det_missing_warned = False

    def build_name_mapping(
        self, candidate_consumables: list[str]
    ) -> dict[str, str]:
        """Map classifier labels to business item ids for the given candidates.

        Candidate names are stripped and normalized with ``_norm_product_name``,
        and blank names are dropped. If ``consumable_catalog_xlsx_path`` points
        to an existing workbook, each candidate maps to its product code, and
        candidates absent from the catalog are dropped (and logged). Otherwise
        each candidate maps to its own normalized name.

        The keys of the result serve as the whitelist in ``infer_frame_bgr``.
        """
        stripped = [_norm_product_name(c.strip()) for c in candidate_consumables if c.strip()]
        # dict.fromkeys de-duplicates while preserving candidate order.
        candidates_norm = list(dict.fromkeys(stripped))
        xlsx_raw = (self._s.consumable_catalog_xlsx_path or "").strip()
        if xlsx_raw:
            path = Path(xlsx_raw).expanduser()
            if path.is_file():
                full = load_name_to_product_code(path)
                out = {n: full[n] for n in candidates_norm if n in full}
                missing = [n for n in candidates_norm if n not in full]
                if missing:
                    logger.warning(
                        "以下候选耗材不在耗材目录 Excel 中,已忽略: {}",
                        ";".join(missing),
                    )
                return out
            logger.warning("耗材目录 Excel 路径已配置但文件不存在: {}", path)
        return {n: n for n in candidates_norm}

    def _det_weights(self) -> Path | None:
        """Return the hand-detector weights path, or None when detection is off.

        Detection is off when ``hand_detection_weights`` is blank or points to a
        missing file. The missing-file case is logged once, because it silently
        switches ``infer_frame_bgr`` to full-frame classification.
        """
        raw = (self._s.hand_detection_weights or "").strip()
        if not raw:
            return None
        p = Path(raw).expanduser()
        if p.is_file():
            return p
        if not self._det_missing_warned:
            self._det_missing_warned = True
            logger.warning("手部检测权重不存在,将对整帧做耗材分类: {}", p)
        return None

    def _cls_weights(self) -> Path:
        """Return the resolved classifier weights path.

        Raises:
            ModelNotConfiguredError: the setting is blank or the file is missing.
        """
        raw = (self._s.consumable_classifier_weights or "").strip()
        if not raw:
            raise ModelNotConfiguredError(
                "未配置耗材分类权重。请设置 CONSUMABLE_CLASSIFIER_WEIGHTS。"
            )
        p = Path(raw).expanduser().resolve()
        if not p.is_file():
            raise ModelNotConfiguredError(f"耗材分类权重不存在: {p}")
        return p

    def _get_det(self) -> YOLO | None:
        """Return the lazily loaded hand detector, or None when detection is off."""
        path = self._det_weights()
        if path is None:
            return None
        if self._det is None:
            with self._det_lock:
                # Re-check under the lock so only one thread loads the model.
                if self._det is None:
                    logger.info("加载手部检测权重: {}", path)
                    self._det = YOLO(str(path))
        return self._det

    def _get_cls(self) -> YOLO:
        """Return the lazily loaded consumable classifier.

        Raises:
            ModelNotConfiguredError: classifier weights are not configured or missing.
        """
        if self._cls is None:
            with self._cls_lock:
                # Re-check under the lock so only one thread loads the model.
                if self._cls is None:
                    path = self._cls_weights()
                    logger.info("加载耗材分类权重: {}", path)
                    self._cls = YOLO(str(path))
        return self._cls

    def hand_crop(
        self,
        frame: np.ndarray,
        det_model: YOLO,
        *,
        det_conf: float,
        pad_ratio: float,
        min_crop_px: int,
        imgsz_det: int,
    ) -> np.ndarray | None:
        """Crop *frame* to the padded union of all detected hand boxes.

        Returns None when no hand is detected, or when the padded crop is
        narrower/shorter than ``min_crop_px`` (and never returns an empty
        crop). The returned array is a slice of *frame*.

        Raises:
            PredictionError: the detector's ``predict`` call failed.
        """
        h, w = frame.shape[:2]
        device = resolve_inference_device(self._s.hand_detection_device)
        try:
            results = det_model.predict(
                frame,
                conf=det_conf,
                imgsz=imgsz_det,
                device=device,
                verbose=False,
            )
        except Exception as exc:
            # Same wrapping as the classifier path in infer_frame_bgr.
            raise PredictionError(f"手部检测推理失败: {exc}") from exc
        hand_xyxys = collect_hand_boxes(det_model, results[0].boxes)
        if not hand_xyxys:
            return None
        merged = union_boxes(hand_xyxys)
        cx1, cy1, cx2, cy2 = pad_box(merged, w, h, pad_ratio)
        # At least 1 px per side, so a min_crop_px <= 0 cannot yield an empty crop.
        min_side = max(min_crop_px, 1)
        if (cx2 - cx1) < min_side or (cy2 - cy1) < min_side:
            return None
        return frame[cy1:cy2, cx1:cx2]

    def infer_frame_bgr(
        self,
        frame: np.ndarray,
        name_to_code: dict[str, str],
    ) -> ClsTop3 | None:
        """Classify the consumable visible in one BGR frame.

        With hand detection enabled, only the hand ROI is classified, and frames
        without a usable ROI yield None. The top-3 snapshot is returned only
        when the top-1 confidence reaches ``consumable_min_cls_confidence`` and
        the top-1 label is whitelisted. A label is whitelisted when it, or its
        normalized form, is a key of *name_to_code*.

        Raises:
            ModelNotConfiguredError: classifier weights are missing.
            PredictionError: detector or classifier inference failed.
        """
        det_model = self._get_det()
        cls_model = self._get_cls()

        if det_model is not None:
            crop = self.hand_crop(
                frame,
                det_model,
                det_conf=self._s.hand_detection_conf,
                pad_ratio=self._s.hand_detection_pad_ratio,
                min_crop_px=self._s.hand_detection_min_crop_px,
                imgsz_det=self._s.hand_detection_imgsz,
            )
            if crop is None:
                return None
        else:
            crop = frame

        device = resolve_inference_device(self._s.consumable_classifier_device)
        try:
            r = cls_model.predict(
                crop,
                imgsz=self._s.consumable_classifier_imgsz,
                device=device,
                verbose=False,
            )
        except Exception as exc:
            raise PredictionError(f"耗材分类推理失败: {exc}") from exc

        snap = cls_top3_from_result(cls_model, r, name_to_code)
        if snap is None:
            return None
        if snap.t1_conf < self._s.consumable_min_cls_confidence:
            return None
        pname = snap.t1_name
        if not pname:
            return None
        # Whitelist keys are normalized names (see build_name_mapping), while
        # classifier labels are raw. Accept either form, matching how
        # cls_top3_from_result resolves product ids.
        if pname not in name_to_code and _norm_product_name(pname) not in name_to_code:
            return None
        return snap
|