refactor: 统一耗材视觉算法并扩展语音确认至全量候选清单
- 以 ConsumableVisionAlgorithmService 替代 consumable_classifier 与 tear_action;可选手部检测权重,未配置时全帧分类;时间窗众数与 Excel 白名单配置。
- 语音待确认:ASR 先匹配 pending topk,再匹配本台 candidate_consumables;记账 item_id 与 vision 一致使用 name_to_code。
- 更新 config、Compose、.env.example、依赖(pandas/openpyxl)与测试。

Made-with: Cursor
This commit is contained in:
@@ -14,11 +14,6 @@ def _default_consumable_classifier_weights() -> str:
|
||||
return str(_PACKAGE_DIR / "resources" / "consumable_classifier.pt")
|
||||
|
||||
|
||||
def _default_tear_action_weights() -> str:
    """Packaged default for the tear-action model: `app/resources/tear_action.pt`."""
    weights = _PACKAGE_DIR / "resources" / "tear_action.pt"
    return str(weights)
|
||||
|
||||
|
||||
def _default_camera_rtsp_urls_sample_path() -> str:
    """Path of the sample mapping (copy it to your own `camera_rtsp_urls.json` and reference that via env)."""
    sample = _PACKAGE_DIR / "resources" / "camera_rtsp_urls.sample.json"
    return str(sample)
|
||||
@@ -38,10 +33,19 @@ class Settings(BaseSettings):
|
||||
    #: Explicit Ultralytics device (e.g. cpu, mps, cuda:0). Empty -> macOS prefers MPS; Linux prefers CUDA if available.
    consumable_classifier_device: str = ""
    consumable_classifier_topk: int = 5
    tear_action_weights: str | None = None
    tear_action_imgsz: int = 224
    tear_action_device: str = ""
    tear_action_topk: int = 5
    #: Minimum top-1 confidence for consumable classification (gate applied after the hand ROI or full frame is classified).
    consumable_min_cls_confidence: float = Field(default=0.5, ge=0.0, le=1.0)
    #: Optional catalog workbook `视频中的商品信息表.xlsx` (columns 「商品名称」/「产品编码」); when empty, item ids fall back to the name itself.
    consumable_catalog_xlsx_path: str = ""
    #: Voting window in seconds (same as the offline script); the per-window mode of repeated inferences feeds auto-booking / voice follow-up.
    consumable_vision_window_sec: float = Field(default=15.0, ge=0.5, le=600.0)
    #: Hand-detection YOLO weights; empty or missing file degrades to "classify the full frame" (for environments with classifier weights only).
    hand_detection_weights: str = ""
    hand_detection_imgsz: int = Field(default=640, ge=32, le=4096)
    hand_detection_conf: float = Field(default=0.25, ge=0.0, le=1.0)
    hand_detection_pad_ratio: float = Field(default=0.30, ge=0.0, le=2.0)
    hand_detection_min_crop_px: int = Field(default=64, ge=8, le=4096)
    hand_detection_device: str = ""
    #: Max attempts (including the first try) when invoking the recording pipeline at surgery start/stop.
    surgery_recording_max_attempts: int = Field(default=3, ge=1, le=20)
|
||||
#: 两次尝试之间的等待秒数。
|
||||
@@ -138,13 +142,6 @@ class Settings(BaseSettings):
|
||||
return _default_consumable_classifier_weights()
|
||||
return str(value)
|
||||
|
||||
@field_validator("tear_action_weights", mode="before")
|
||||
@classmethod
|
||||
def tear_action_weights_default(cls, value: object) -> str:
|
||||
if value is None or value == "":
|
||||
return _default_tear_action_weights()
|
||||
return str(value)
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
|
||||
@@ -3,17 +3,15 @@ from loguru import logger
|
||||
from app.config import settings
|
||||
from app.repositories.surgery_results import SurgeryResultRepository
|
||||
from app.repositories.voice_audits import VoiceAuditRepository
|
||||
from app.services.consumable_classifier import ConsumableClassifierService
|
||||
from app.services.baidu_speech import BaiduSpeechService
|
||||
from app.services.consumable_vision_algorithm import ConsumableVisionAlgorithmService
|
||||
from app.services.minio_audio_storage import MinioAudioStorageService
|
||||
from app.services.surgery_pipeline import SurgeryPipeline
|
||||
from app.services.voice_resolution import VoiceConfirmationService
|
||||
from app.services.tear_action import TearActionService
|
||||
from app.services.video.hikvision_runtime import HikvisionRuntime
|
||||
from app.services.video.session_manager import CameraSessionManager
|
||||
|
||||
consumable_classifier_service = ConsumableClassifierService()
|
||||
tear_action_service = TearActionService()
|
||||
consumable_vision_algorithm_service = ConsumableVisionAlgorithmService()
|
||||
|
||||
hikvision_runtime = HikvisionRuntime.try_load(settings.hikvision_lib_dir)
|
||||
if settings.hikvision_sdk_enabled and hikvision_runtime is None:
|
||||
@@ -29,8 +27,7 @@ minio_audio_storage_service = MinioAudioStorageService(settings)
|
||||
|
||||
camera_session_manager = CameraSessionManager(
|
||||
settings=settings,
|
||||
consumable_classifier=consumable_classifier_service,
|
||||
tear_action=tear_action_service,
|
||||
vision_algorithm=consumable_vision_algorithm_service,
|
||||
hikvision_runtime=hikvision_runtime,
|
||||
result_repository=surgery_result_repository,
|
||||
)
|
||||
@@ -49,12 +46,8 @@ surgery_pipeline = SurgeryPipeline(
|
||||
)
|
||||
|
||||
|
||||
def get_consumable_classifier_service() -> ConsumableClassifierService:
    """FastAPI dependency provider for the module-level classifier singleton."""
    return consumable_classifier_service
|
||||
|
||||
|
||||
def get_tear_action_service() -> TearActionService:
    """FastAPI dependency provider for the module-level tear-action singleton."""
    return tear_action_service
|
||||
def get_consumable_vision_algorithm_service() -> ConsumableVisionAlgorithmService:
    """FastAPI dependency provider for the module-level vision-algorithm singleton."""
    return consumable_vision_algorithm_service
|
||||
|
||||
|
||||
def get_surgery_pipeline() -> SurgeryPipeline:
|
||||
|
||||
@@ -1,197 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
import numpy as np
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from loguru import logger
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
|
||||
os.environ["YOLO_CONFIG_DIR"] = "/tmp"
|
||||
|
||||
from ultralytics import YOLO
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
def resolve_classifier_inference_device(explicit: str) -> str | None:
|
||||
"""Ultralytics `device` string. If unset: macOS prefers MPS; Linux/Windows prefer CUDA when available."""
|
||||
configured = (explicit or "").strip()
|
||||
if configured:
|
||||
return configured
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return None
|
||||
if sys.platform == "darwin":
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
return None
|
||||
if torch.cuda.is_available():
|
||||
return "cuda:0"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PredictionCandidate:
    """One ranked classification candidate (label + its classifier confidence)."""

    label: str
    confidence: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PredictionResult:
|
||||
label: str
|
||||
confidence: float
|
||||
topk: list[PredictionCandidate]
|
||||
|
||||
|
||||
class ModelNotConfiguredError(RuntimeError):
    """Raised when the model weights are not configured or missing."""
|
||||
|
||||
|
||||
class InvalidImageError(ValueError):
    """Raised when uploaded bytes cannot be decoded as an image."""
|
||||
|
||||
|
||||
class PredictionError(RuntimeError):
    """Raised when the model cannot produce a prediction."""
|
||||
|
||||
|
||||
class ConsumableClassifierService:
    """Consumable recognition/classification (YOLO-cls): identifies the consumable
    category in a frame; separate from the tear-action model `TearActionService`.
    Called from the internal pipeline only — not exposed over HTTP.

    The underlying YOLO model is lazily loaded on first prediction and cached;
    loading is guarded by a lock so concurrent first calls load it only once.
    """

    def __init__(self) -> None:
        # Lazily-created YOLO model; None until the first successful _get_model().
        self._model: YOLO | None = None
        # Guards the one-time model load against concurrent threads.
        self._model_lock = Lock()

    @property
    def weights_path(self) -> Path | None:
        """Configured weights path (with `~` expanded), or None when unset."""
        if not settings.consumable_classifier_weights:
            return None
        return Path(settings.consumable_classifier_weights).expanduser()

    @property
    def configured(self) -> bool:
        """True when a weights path is configured (the file may still be missing)."""
        return self.weights_path is not None

    @property
    def weights_found(self) -> bool:
        """True when the configured weights path exists as a regular file."""
        path = self.weights_path
        return path is not None and path.is_file()

    @property
    def model_loaded(self) -> bool:
        """True once the YOLO model has been loaded into memory."""
        return self._model is not None

    async def predict_image_bytes(
        self,
        payload: bytes,
        *,
        topk: int | None = None,
    ) -> PredictionResult:
        """Classify an encoded image; inference runs in a worker thread.

        Raises ModelNotConfiguredError, InvalidImageError or PredictionError.
        """
        return await run_in_threadpool(self._predict_image_bytes, payload, topk)

    def _predict_image_bytes(
        self,
        payload: bytes,
        topk: int | None,
    ) -> PredictionResult:
        # Synchronous worker: decode bytes, run YOLO, rank class probabilities.
        model = self._get_model()
        image = self._decode_image(payload)

        try:
            result = model.predict(
                image,
                imgsz=settings.consumable_classifier_imgsz,
                device=resolve_classifier_inference_device(
                    settings.consumable_classifier_device
                ),
                verbose=False,
            )[0]
        except Exception as exc:  # pragma: no cover - ultralytics runtime errors vary.
            raise PredictionError(
                f"Failed to run consumable classifier inference: {exc}"
            ) from exc

        return self._build_prediction_result(result, model, topk=topk)

    def _get_model(self) -> YOLO:
        """Return the cached YOLO model, loading it on first use (double-checked lock)."""
        path = self.weights_path
        if path is None:
            raise ModelNotConfiguredError(
                "Consumable classifier weights are not configured. "
                "Set CONSUMABLE_CLASSIFIER_WEIGHTS."
            )

        path = path.resolve()
        if not path.is_file():
            raise ModelNotConfiguredError(
                f"Consumable classifier weights not found: {path}"
            )

        if self._model is None:
            with self._model_lock:
                if self._model is None:
                    logger.info("Loading consumable classifier weights from {}", path)
                    self._model = YOLO(str(path))

        return self._model

    def _decode_image(self, payload: bytes) -> np.ndarray:
        """Decode image bytes into an RGB ndarray; raise InvalidImageError on bad input."""
        if not payload:
            raise InvalidImageError("Uploaded image is empty.")

        try:
            with Image.open(BytesIO(payload)) as image:
                return np.asarray(image.convert("RGB"))
        except (UnidentifiedImageError, OSError) as exc:
            raise InvalidImageError("Uploaded file is not a valid image.") from exc

    def _build_prediction_result(
        self,
        result: object,
        model: YOLO,
        *,
        topk: int | None,
    ) -> PredictionResult:
        """Rank class probabilities from a YOLO result and return the top-k candidates."""
        probs = getattr(result, "probs", None)
        data = getattr(probs, "data", None)
        if probs is None or data is None:
            raise PredictionError("Model did not return classification probabilities.")

        scores = data.tolist()
        # A single-class model may yield a scalar; normalise to a list.
        if not isinstance(scores, list):
            scores = [float(scores)]

        names = self._names(model)
        limit = max(1, topk or settings.consumable_classifier_topk)
        ranked = sorted(
            ((index, float(score)) for index, score in enumerate(scores)),
            key=lambda item: item[1],
            reverse=True,
        )[:limit]

        if not ranked:
            raise PredictionError("Model returned an empty prediction result.")

        candidates = [
            PredictionCandidate(
                label=names.get(index, str(index)),
                confidence=confidence,
            )
            for index, confidence in ranked
        ]
        return PredictionResult(
            label=candidates[0].label,
            confidence=candidates[0].confidence,
            topk=candidates,
        )

    def _names(self, model: YOLO) -> dict[int, str]:
        """Class-index -> label mapping from the underlying model."""
        raw = getattr(model.model, "names", None) or {}
        return {int(key): str(value) for key, value in raw.items()}
|
||||
392
app/services/consumable_vision_algorithm.py
Normal file
392
app/services/consumable_vision_algorithm.py
Normal file
@@ -0,0 +1,392 @@
|
||||
"""手术室耗材视觉算法:可选手部检测 ROI + YOLO-cls(原离线双机位流水线核心逻辑)。
|
||||
|
||||
作为 FastAPI 内唯一的视频推理入口;撕扯动作分类已移除,由手部检测 + 耗材分类替代。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
from ultralytics import YOLO
|
||||
|
||||
from app.config import Settings, settings
|
||||
|
||||
os.environ["YOLO_CONFIG_DIR"] = "/tmp"
|
||||
|
||||
|
||||
def resolve_inference_device(explicit: str) -> str | None:
|
||||
"""Ultralytics `device`;空则 macOS 优先 MPS,Linux/Windows 优先 CUDA。"""
|
||||
configured = (explicit or "").strip()
|
||||
if configured:
|
||||
return configured
|
||||
try:
|
||||
import torch
|
||||
except Exception:
|
||||
return None
|
||||
if sys.platform == "darwin":
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
return None
|
||||
if torch.cuda.is_available():
|
||||
return "cuda:0"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PredictionCandidate:
    """One ranked classification candidate (label + its classifier confidence)."""

    label: str
    confidence: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PredictionResult:
|
||||
label: str
|
||||
confidence: float
|
||||
topk: list[PredictionCandidate]
|
||||
|
||||
|
||||
class ModelNotConfiguredError(RuntimeError):
    """Weights are not configured or the weights file does not exist."""
|
||||
|
||||
|
||||
class PredictionError(RuntimeError):
    """Inference failed at runtime."""
|
||||
|
||||
|
||||
@dataclass
class ClsTop3:
    """Top-3 classification snapshot: names, confidences, and resolved product ids."""

    # Field order matters: instances are also constructed positionally.
    t1_name: str
    t1_conf: float
    t2_name: str
    t2_conf: float
    t3_name: str
    t3_conf: float
    t1_pid: str
    t2_pid: str
    t3_pid: str
||||
|
||||
|
||||
def _find_col(df: pd.DataFrame, want: str) -> str | None:
|
||||
want = want.strip()
|
||||
for c in df.columns:
|
||||
if str(c).strip() == want:
|
||||
return str(c)
|
||||
return None
|
||||
|
||||
|
||||
def _norm_product_name(name: str) -> str:
|
||||
s = (name or "").strip()
|
||||
if s == "一次性医用垫单":
|
||||
return "一次性使用手术单(一次性医用垫单)"
|
||||
return s
|
||||
|
||||
|
||||
def load_name_to_product_code(xlsx: Path) -> dict[str, str]:
    """Build a mapping from normalised product name to product code.

    Reads the first sheet of *xlsx* (columns「产品编码」and「商品名称」,
    matched after stripping). Rows with a missing name or code are skipped.
    When one name maps to several distinct codes only the first mapping is
    kept and a single warning lists the conflicting names.

    Raises:
        ValueError: when either required column is absent.
    """
    df = pd.read_excel(xlsx, sheet_name=0)
    c_code = _find_col(df, "产品编码")
    c_name = _find_col(df, "商品名称")
    if c_code is None or c_name is None:
        raise ValueError("Excel 缺少「产品编码」或「商品名称」列")

    mapping: dict[str, str] = {}
    dups: set[str] = set()
    for _, row in df.iterrows():
        raw_name = row.get(c_name)
        # pd.isna covers None, NaN, NaT and pd.NA in one call; the previous
        # `isinstance(raw, float) and pd.isna(raw)` check missed non-float
        # missing markers that pandas can produce from Excel cells.
        if pd.isna(raw_name):
            continue
        name = _norm_product_name(str(raw_name).strip())
        if not name:
            continue
        raw_code = row.get(c_code)
        if pd.isna(raw_code):
            continue
        code = str(raw_code).strip()
        if name in mapping:
            # Keep the first mapping; remember genuine conflicts for the warning.
            if mapping[name] != code:
                dups.add(name)
            continue
        mapping[name] = code
    if dups:
        logger.warning(
            "Excel 中以下商品名称对应多组产品编码,已保留首次映射: {}",
            ";".join(sorted(dups)[:12]) + (" …" if len(dups) > 12 else ""),
        )
    return mapping
|
||||
|
||||
|
||||
def collect_hand_boxes(model: YOLO, boxes) -> list[tuple[float, float, float, float]]:
    """Extract hand bounding boxes (xyxy) from a detection result.

    Boxes whose class label contains "hand" (or equals 手/手部) are kept.
    If none match but boxes exist — e.g. a single-class detector without a
    "hand" label — every box is returned instead.
    """
    if boxes is None or len(boxes) == 0:
        return []
    coords = boxes.xyxy.cpu().numpy()
    class_ids = boxes.cls.cpu().numpy().astype(int)
    labels = model.names
    kept: list[tuple[float, float, float, float]] = []
    for idx, cid in enumerate(class_ids):
        label = str(labels.get(int(cid), "")).strip().lower()
        if "hand" in label or label in {"手", "手部"}:
            kept.append(tuple(float(v) for v in coords[idx]))
    if not kept:
        # Fallback for single-class detectors with no "hand" label.
        kept = [tuple(float(v) for v in row) for row in coords]
    return kept
|
||||
|
||||
|
||||
def union_boxes(
    boxes: list[tuple[float, float, float, float]],
) -> tuple[float, float, float, float]:
    """Smallest axis-aligned box covering every box in *boxes* (must be non-empty)."""
    left = min(b[0] for b in boxes)
    top = min(b[1] for b in boxes)
    right = max(b[2] for b in boxes)
    bottom = max(b[3] for b in boxes)
    return left, top, right, bottom
|
||||
|
||||
|
||||
def pad_box(
    box: tuple[float, float, float, float],
    w: int,
    h: int,
    pad_ratio: float,
) -> tuple[int, int, int, int]:
    """Grow *box* by *pad_ratio* of its own size on every side, clamped to [0, w] x [0, h]."""
    x1, y1, x2, y2 = box
    dx = (x2 - x1) * pad_ratio
    dy = (y2 - y1) * pad_ratio
    left = int(max(0, x1 - dx))
    top = int(max(0, y1 - dy))
    right = int(min(w, x2 + dx))
    bottom = int(min(h, y2 + dy))
    return left, top, right, bottom
|
||||
|
||||
|
||||
def cls_top3_from_result(
    cls: YOLO, r, name_to_code: dict[str, str]
) -> ClsTop3 | None:
    """Extract a top-3 snapshot from an Ultralytics classification result.

    Args:
        cls: the classifier model (used only for its index -> name mapping).
        r: the list returned by ``model.predict``; the first result is used.
        name_to_code: name -> product-code map used to attach product ids.

    Returns:
        A ClsTop3 with up to three (name, confidence, product-id) entries,
        or None when the result carries no usable probabilities.
    """
    pr = r[0].probs
    if pr is None or not hasattr(pr, "top5") or not pr.top5:
        return None
    t5i = list(pr.top5)
    tc = pr.top5conf
    if tc is None:
        return None

    def _ci(i: int) -> float:
        # Defensive read of top5conf[i]; tensors expose .item(), plain floats do not.
        if i < 0 or i >= len(tc):
            return 0.0
        try:
            v = tc[i]
            return float(v.item() if hasattr(v, "item") else v)
        except (IndexError, ValueError, TypeError):
            return 0.0

    t1i = int(pr.top1)
    # Use top5conf[0] when it corresponds to top1; otherwise fall back to top1conf.
    c1 = _ci(0) if t5i and int(t5i[0]) == t1i else float(
        pr.top1conf.item() if hasattr(pr.top1conf, "item") else pr.top1conf
    )
    n1 = str(cls.names.get(t1i, "")).strip()

    # Second and third candidates are optional (model may have fewer classes).
    n2 = n3 = ""
    c2 = c3 = 0.0
    if len(t5i) > 1:
        n2 = str(cls.names.get(int(t5i[1]), "")).strip()
        c2 = _ci(1)
    if len(t5i) > 2:
        n3 = str(cls.names.get(int(t5i[2]), "")).strip()
        c3 = _ci(2)

    return ClsTop3(
        t1_name=n1,
        t1_conf=c1,
        t2_name=n2,
        t2_conf=c2,
        t3_name=n3,
        t3_conf=c3,
        t1_pid=name_to_code.get(n1, ""),
        t2_pid=name_to_code.get(n2, ""),
        t3_pid=name_to_code.get(n3, ""),
    )
|
||||
|
||||
|
||||
def cls_top3_to_prediction_result(snap: ClsTop3) -> PredictionResult:
    """Convert a ClsTop3 snapshot into the generic PredictionResult shape.

    Empty candidate slots are dropped; when none remain, a single blank
    candidate is emitted so `topk` is never empty.
    """
    ranked = [
        PredictionCandidate(name, conf)
        for name, conf in (
            (snap.t1_name, snap.t1_conf),
            (snap.t2_name, snap.t2_conf),
            (snap.t3_name, snap.t3_conf),
        )
        if name
    ]
    if not ranked:
        ranked = [PredictionCandidate("", 0.0)]
    return PredictionResult(
        label=snap.t1_name,
        confidence=snap.t1_conf,
        topk=ranked,
    )
|
||||
|
||||
|
||||
def _mode_lex(names: list[str]) -> str | None:
|
||||
if not names:
|
||||
return None
|
||||
c = Counter(names)
|
||||
best = max(c.values())
|
||||
pool = [n for n, k in c.items() if k == best]
|
||||
return min(pool)
|
||||
|
||||
|
||||
def window_bucket_to_best_snap(
    bucket_pts: list[tuple[str, ClsTop3]],
) -> ClsTop3 | None:
    """Within one time window: pick the majority class name, then return the
    snapshot of that class with the highest top-1 confidence."""
    winner = _mode_lex([name for name, _ in bucket_pts])
    if winner is None:
        return None
    best: ClsTop3 | None = None
    for name, snap in bucket_pts:
        if name != winner:
            continue
        if best is None or snap.t1_conf > best.t1_conf:
            best = snap
    return best
|
||||
|
||||
|
||||
class ConsumableVisionAlgorithmService:
    """Optional hand detection + consumable classification; called by
    CameraSessionManager on the video thread.

    When hand-detection weights are configured and present on disk, the frame
    is cropped to the padded union of detected hand boxes before
    classification; otherwise the full frame is classified.
    """

    def __init__(self, app_settings: Settings | None = None) -> None:
        # Settings injection point; defaults to the module-level settings.
        self._s = app_settings or settings
        # Lazily-loaded YOLO models, each guarded by its own lock.
        self._det: YOLO | None = None
        self._cls: YOLO | None = None
        self._det_lock = Lock()
        self._cls_lock = Lock()

    def build_name_mapping(
        self, candidate_consumables: list[str]
    ) -> dict[str, str]:
        """Classifier label -> business item id (Excel product code; the name
        itself when no catalog workbook is configured).

        When a catalog file is available, candidates absent from it are
        dropped, so the returned keys double as a whitelist.
        """
        stripped = [_norm_product_name(c.strip()) for c in candidate_consumables if c.strip()]
        # Dict used as an ordered, de-duplicated set of normalised names.
        candidates_norm = {n: n for n in stripped}
        xlsx_raw = (self._s.consumable_catalog_xlsx_path or "").strip()
        if xlsx_raw:
            path = Path(xlsx_raw).expanduser()
            if path.is_file():
                full = load_name_to_product_code(path)
                out: dict[str, str] = {}
                for norm in candidates_norm:
                    if norm in full:
                        out[norm] = full[norm]
                return out
            logger.warning("耗材目录 Excel 路径已配置但文件不存在: {}", path)
        # No catalog: identity mapping (item id == normalised name).
        return {n: n for n in candidates_norm}

    def _det_weights(self) -> Path | None:
        """Hand-detection weights path, or None when unset or missing on disk."""
        raw = (self._s.hand_detection_weights or "").strip()
        if not raw:
            return None
        p = Path(raw).expanduser()
        return p if p.is_file() else None

    def _cls_weights(self) -> Path:
        """Classifier weights path; raises ModelNotConfiguredError when unset or missing."""
        raw = (self._s.consumable_classifier_weights or "").strip()
        if not raw:
            raise ModelNotConfiguredError(
                "未配置耗材分类权重。请设置 CONSUMABLE_CLASSIFIER_WEIGHTS。"
            )
        p = Path(raw).expanduser().resolve()
        if not p.is_file():
            raise ModelNotConfiguredError(f"耗材分类权重不存在: {p}")
        return p

    def _get_det(self) -> YOLO | None:
        """Return the hand detector, loading it once; None when not configured."""
        path = self._det_weights()
        if path is None:
            return None
        if self._det is None:
            with self._det_lock:
                if self._det is None:
                    logger.info("加载手部检测权重: {}", path)
                    self._det = YOLO(str(path))
        return self._det

    def _get_cls(self) -> YOLO:
        """Return the classifier, loading it once (double-checked lock)."""
        if self._cls is None:
            with self._cls_lock:
                if self._cls is None:
                    path = self._cls_weights()
                    logger.info("加载耗材分类权重: {}", path)
                    self._cls = YOLO(str(path))
        return self._cls

    def hand_crop(
        self,
        frame: np.ndarray,
        det_model: YOLO,
        *,
        det_conf: float,
        pad_ratio: float,
        min_crop_px: int,
        imgsz_det: int,
    ) -> np.ndarray | None:
        """Crop *frame* to the padded union of detected hand boxes.

        Returns None when no hand is found or when the crop is smaller than
        *min_crop_px* in either dimension.
        """
        h, w = frame.shape[:2]
        device = resolve_inference_device(self._s.hand_detection_device)
        results = det_model.predict(
            frame,
            conf=det_conf,
            imgsz=imgsz_det,
            device=device,
            verbose=False,
        )
        hand_xyxys = collect_hand_boxes(det_model, results[0].boxes)
        if not hand_xyxys:
            return None
        merged = union_boxes(hand_xyxys)
        cx1, cy1, cx2, cy2 = pad_box(merged, w, h, pad_ratio)
        # Reject crops too small to classify reliably.
        if (cx2 - cx1) < min_crop_px or (cy2 - cy1) < min_crop_px:
            return None
        return frame[cy1:cy2, cx1:cx2]

    def infer_frame_bgr(
        self,
        frame: np.ndarray,
        name_to_code: dict[str, str],
    ) -> ClsTop3 | None:
        """Single BGR frame; returns a snapshot only when top-1 passes the
        confidence gate AND its name is in the whitelist (keys of *name_to_code*).
        """
        whitelist = set(name_to_code.keys())
        det_model = self._get_det()
        cls_model = self._get_cls()

        if det_model is not None:
            crop = self.hand_crop(
                frame,
                det_model,
                det_conf=self._s.hand_detection_conf,
                pad_ratio=self._s.hand_detection_pad_ratio,
                min_crop_px=self._s.hand_detection_min_crop_px,
                imgsz_det=self._s.hand_detection_imgsz,
            )
            # Detector configured but no usable hand region: skip this frame.
            if crop is None:
                return None
        else:
            # No detector configured: classify the full frame.
            crop = frame

        device = resolve_inference_device(self._s.consumable_classifier_device)
        try:
            r = cls_model.predict(
                crop,
                imgsz=self._s.consumable_classifier_imgsz,
                device=device,
                verbose=False,
            )
        except Exception as exc:
            raise PredictionError(f"耗材分类推理失败: {exc}") from exc

        snap = cls_top3_from_result(cls_model, r, name_to_code)
        if snap is None:
            return None
        if snap.t1_conf < self._s.consumable_min_cls_confidence:
            return None
        pname = snap.t1_name
        if not pname or pname not in whitelist:
            return None
        return snap
|
||||
@@ -1,155 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
import os
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
import numpy as np
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from loguru import logger
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
|
||||
os.environ["YOLO_CONFIG_DIR"] = "/tmp"
|
||||
|
||||
from ultralytics import YOLO
|
||||
|
||||
from app.config import settings
|
||||
from app.services.consumable_classifier import (
|
||||
InvalidImageError,
|
||||
ModelNotConfiguredError,
|
||||
PredictionCandidate,
|
||||
PredictionError,
|
||||
PredictionResult,
|
||||
resolve_classifier_inference_device,
|
||||
)
|
||||
|
||||
|
||||
class TearActionService:
    """Tear-action recognition (separate weights): detects tearing-open of a
    consumable; distinct from the consumable classifier
    `ConsumableClassifierService`. Called from the internal pipeline only —
    not exposed over HTTP.
    """

    def __init__(self) -> None:
        # Lazily-loaded YOLO model guarded by a lock (loaded once on first use).
        self._model: YOLO | None = None
        self._model_lock = Lock()

    @property
    def weights_path(self) -> Path | None:
        """Configured weights path (with `~` expanded), or None when unset."""
        if not settings.tear_action_weights:
            return None
        return Path(settings.tear_action_weights).expanduser()

    @property
    def configured(self) -> bool:
        """True when a weights path is configured (the file may still be missing)."""
        return self.weights_path is not None

    @property
    def weights_found(self) -> bool:
        """True when the configured weights file exists."""
        path = self.weights_path
        return path is not None and path.is_file()

    @property
    def model_loaded(self) -> bool:
        """True once the YOLO model has been loaded into memory."""
        return self._model is not None

    async def predict_image_bytes(
        self,
        payload: bytes,
        *,
        topk: int | None = None,
    ) -> PredictionResult:
        """Classify an encoded image; inference runs in a worker thread."""
        return await run_in_threadpool(self._predict_image_bytes, payload, topk)

    def _predict_image_bytes(
        self,
        payload: bytes,
        topk: int | None,
    ) -> PredictionResult:
        # Synchronous worker: decode bytes, run YOLO-cls, rank probabilities.
        model = self._get_model()
        image = self._decode_image(payload)

        try:
            result = model.predict(
                image,
                imgsz=settings.tear_action_imgsz,
                device=resolve_classifier_inference_device(settings.tear_action_device),
                verbose=False,
            )[0]
        except Exception as exc:  # pragma: no cover
            raise PredictionError(
                f"Failed to run tear-action inference: {exc}"
            ) from exc

        return self._build_prediction_result(result, model, topk=topk)

    def _get_model(self) -> YOLO:
        """Return the cached model, loading it on first use (double-checked lock)."""
        path = self.weights_path
        if path is None:
            raise ModelNotConfiguredError(
                "Tear-action weights are not configured. Set TEAR_ACTION_WEIGHTS."
            )

        path = path.resolve()
        if not path.is_file():
            raise ModelNotConfiguredError(f"Tear-action weights not found: {path}")

        if self._model is None:
            with self._model_lock:
                if self._model is None:
                    logger.info("Loading tear-action weights from {}", path)
                    self._model = YOLO(str(path))

        return self._model

    def _decode_image(self, payload: bytes) -> np.ndarray:
        """Decode image bytes into an RGB ndarray; raise InvalidImageError on bad input."""
        if not payload:
            raise InvalidImageError("Uploaded image is empty.")

        try:
            with Image.open(BytesIO(payload)) as image:
                return np.asarray(image.convert("RGB"))
        except (UnidentifiedImageError, OSError) as exc:
            raise InvalidImageError("Uploaded file is not a valid image.") from exc

    def _build_prediction_result(
        self,
        result: object,
        model: YOLO,
        *,
        topk: int | None,
    ) -> PredictionResult:
        """Rank class probabilities from a YOLO result and return the top-k candidates."""
        probs = getattr(result, "probs", None)
        data = getattr(probs, "data", None)
        if probs is None or data is None:
            raise PredictionError("Model did not return classification probabilities.")

        scores = data.tolist()
        # A single-class model may yield a scalar; normalise to a list.
        if not isinstance(scores, list):
            scores = [float(scores)]

        names = self._names(model)
        limit = max(1, topk or settings.tear_action_topk)
        ranked = sorted(
            ((index, float(score)) for index, score in enumerate(scores)),
            key=lambda item: item[1],
            reverse=True,
        )[:limit]

        if not ranked:
            raise PredictionError("Model returned an empty prediction result.")

        candidates = [
            PredictionCandidate(
                label=names.get(index, str(index)),
                confidence=confidence,
            )
            for index, confidence in ranked
        ]
        return PredictionResult(
            label=candidates[0].label,
            confidence=candidates[0].confidence,
            topk=candidates,
        )

    def _names(self, model: YOLO) -> dict[int, str]:
        """Class-index -> label mapping from the underlying model."""
        raw = getattr(model.model, "names", None) or {}
        return {int(key): str(value) for key, value in raw.items()}
|
||||
@@ -13,14 +13,16 @@ from app.config import Settings
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.repositories.surgery_results import SurgeryResultRepository
|
||||
from app.schemas import SurgeryConsumptionDetail
|
||||
from app.services.consumable_classifier import (
|
||||
ConsumableClassifierService,
|
||||
from app.services.consumable_vision_algorithm import (
|
||||
ClsTop3,
|
||||
ConsumableVisionAlgorithmService,
|
||||
PredictionCandidate,
|
||||
PredictionResult,
|
||||
_norm_product_name,
|
||||
cls_top3_to_prediction_result,
|
||||
window_bucket_to_best_snap,
|
||||
)
|
||||
from app.services.tear_action import TearActionService
|
||||
from app.services.video.backend_resolver import BackendResolver
|
||||
from app.services.video.frame_encode import frame_to_jpeg_bytes
|
||||
from app.services.video.hikvision_runtime import HikvisionInitRefCount, HikvisionRuntime
|
||||
from app.services.video.rtsp_capture import RtspCapture
|
||||
from app.services.video.types import VideoBackendKind
|
||||
@@ -41,9 +43,21 @@ class PendingConsumableConfirmation:
|
||||
model_top1_confidence: float
|
||||
|
||||
|
||||
@dataclass
class CameraStreamInferState:
    """Per-camera-stream time-window voting state (mirrors the offline algorithm)."""

    # (relative timestamp, top-1 name, full snapshot) votes not yet bucketed.
    votes: list[tuple[float, str, ClsTop3]] = field(default_factory=list)
    # Monotonic time of the first vote on this stream; None until the first frame.
    stream_t0: float | None = None
    # Index of the next window bucket to finalise.
    next_bucket: int = 0
||||
|
||||
|
||||
@dataclass
class SurgerySessionState:
    """Mutable per-surgery state shared between camera tasks and the voice flow."""

    # Consumables allowed for this surgery, as supplied at start time.
    candidate_consumables: list[str]
    # Normalised classifier name -> business item id (Excel product code or the name).
    name_to_code: dict[str, str] = field(default_factory=dict)
    # Per-camera time-window voting state.
    camera_infer: dict[str, CameraStreamInferState] = field(default_factory=dict)
    # Confirmed consumption details accumulated so far.
    details: list[SurgeryConsumptionDetail] = field(default_factory=list)
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    ready: asyncio.Event = field(default_factory=asyncio.Event)
|
||||
@@ -94,14 +108,12 @@ class CameraSessionManager:
|
||||
self,
|
||||
*,
|
||||
settings: Settings,
|
||||
consumable_classifier: ConsumableClassifierService,
|
||||
tear_action: TearActionService,
|
||||
vision_algorithm: ConsumableVisionAlgorithmService,
|
||||
hikvision_runtime: HikvisionRuntime | None,
|
||||
result_repository: SurgeryResultRepository | None = None,
|
||||
) -> None:
|
||||
self._s = settings
|
||||
self._classifier = consumable_classifier
|
||||
self._tear = tear_action
|
||||
self._vision = vision_algorithm
|
||||
self._hik = hikvision_runtime
|
||||
self._repo = result_repository
|
||||
self._resolver = BackendResolver(settings, hikvision_runtime=hikvision_runtime)
|
||||
@@ -221,8 +233,10 @@ class CameraSessionManager:
|
||||
"该手术号存在尚未写入数据库的历史结果,请修复数据库或等待自动重试成功后再开始。",
|
||||
)
|
||||
|
||||
name_to_code = self._vision.build_name_mapping(candidate_consumables)
|
||||
state = SurgerySessionState(
|
||||
candidate_consumables=list(candidate_consumables),
|
||||
name_to_code=name_to_code,
|
||||
)
|
||||
stop_event = asyncio.Event()
|
||||
readies = [asyncio.Event() for _ in camera_ids]
|
||||
@@ -388,6 +402,12 @@ class CameraSessionManager:
|
||||
return None
|
||||
return p
|
||||
|
||||
def get_surgery_candidate_consumables(self, surgery_id: str) -> list[str]:
|
||||
"""本台手术开始手术时传入的耗材候选清单(语音可任选其中一项,不限于模型 topk)。"""
|
||||
if surgery_id not in self._active:
|
||||
return []
|
||||
return list(self._active[surgery_id].state.candidate_consumables)
|
||||
|
||||
def next_pending_confirmation(
|
||||
self, surgery_id: str
|
||||
) -> PendingConsumableConfirmation | None:
|
||||
@@ -436,20 +456,23 @@ class CameraSessionManager:
|
||||
"CONFIRMATION_INVALID",
|
||||
"请提供 chosen_label 或设置 rejected=true。",
|
||||
)
|
||||
allowed = {lbl.strip() for lbl, _ in pending.options if lbl.strip()}
|
||||
allowed_pending = {lbl.strip() for lbl, _ in pending.options if lbl.strip()}
|
||||
allowed_surgery = {c.strip() for c in st.candidate_consumables if c.strip()}
|
||||
if rejected:
|
||||
pending.status = "rejected"
|
||||
else:
|
||||
label = chosen_label.strip() if chosen_label else ""
|
||||
if label not in allowed:
|
||||
if label not in allowed_pending and label not in allowed_surgery:
|
||||
raise SurgeryPipelineError(
|
||||
"CONFIRMATION_INVALID",
|
||||
f"所选耗材不在候选列表中:{chosen_label!r}",
|
||||
f"所选耗材不在本台手术候选清单或本次追问选项中:{chosen_label!r}",
|
||||
)
|
||||
pending.status = "confirmed"
|
||||
norm = _norm_product_name(label)
|
||||
item_id = st.name_to_code.get(norm, label)
|
||||
self._append_confirmed_detail_locked(
|
||||
state=st,
|
||||
item_id=label,
|
||||
item_id=item_id,
|
||||
item_name=label,
|
||||
doctor_id=self._s.video_voice_confirm_doctor_id,
|
||||
source="voice",
|
||||
@@ -582,13 +605,11 @@ class CameraSessionManager:
|
||||
continue
|
||||
last_infer = now
|
||||
try:
|
||||
jpeg = await asyncio.to_thread(
|
||||
frame_to_jpeg_bytes,
|
||||
snap = await asyncio.to_thread(
|
||||
self._vision.infer_frame_bgr,
|
||||
frame,
|
||||
quality=self._s.video_jpeg_quality,
|
||||
state.name_to_code,
|
||||
)
|
||||
cls_res = await self._classifier.predict_image_bytes(jpeg)
|
||||
tear_res = await self._tear.predict_image_bytes(jpeg)
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"Inference skip camera={} surgery={}: {}",
|
||||
@@ -598,11 +619,45 @@ class CameraSessionManager:
|
||||
)
|
||||
continue
|
||||
|
||||
await self._handle_classification_result(
|
||||
state=state,
|
||||
cls_res=cls_res,
|
||||
tear_label=tear_res.label,
|
||||
)
|
||||
if snap is None:
|
||||
continue
|
||||
|
||||
wsec = self._s.consumable_vision_window_sec
|
||||
pending_preds: list[PredictionResult] = []
|
||||
async with state.lock:
|
||||
cis = state.camera_infer.setdefault(
|
||||
camera_id, CameraStreamInferState()
|
||||
)
|
||||
if cis.stream_t0 is None:
|
||||
cis.stream_t0 = time.monotonic()
|
||||
t_rel = time.monotonic() - cis.stream_t0
|
||||
cis.votes.append((t_rel, snap.t1_name, snap))
|
||||
current_b = int(t_rel // wsec)
|
||||
while cis.next_bucket < current_b:
|
||||
b = cis.next_bucket
|
||||
cis.next_bucket += 1
|
||||
lo, hi = b * wsec, (b + 1) * wsec
|
||||
bucket_pts = [
|
||||
(p, sn) for (t, p, sn) in cis.votes if lo <= t < hi
|
||||
]
|
||||
cis.votes = [
|
||||
(t, p, sn)
|
||||
for (t, p, sn) in cis.votes
|
||||
if not (lo <= t < hi)
|
||||
]
|
||||
if not bucket_pts:
|
||||
continue
|
||||
best = window_bucket_to_best_snap(bucket_pts)
|
||||
if best is not None:
|
||||
pending_preds.append(
|
||||
cls_top3_to_prediction_result(best)
|
||||
)
|
||||
|
||||
for cls_res in pending_preds:
|
||||
await self._handle_classification_result(
|
||||
state=state,
|
||||
cls_res=cls_res,
|
||||
)
|
||||
finally:
|
||||
if cap is not None:
|
||||
await asyncio.to_thread(cap.release)
|
||||
@@ -616,11 +671,10 @@ class CameraSessionManager:
|
||||
*,
|
||||
state: SurgerySessionState,
|
||||
cls_res: PredictionResult,
|
||||
tear_label: str,
|
||||
) -> None:
|
||||
_ = tear_label
|
||||
conf = cls_res.confidence
|
||||
label = (cls_res.label or "").strip()
|
||||
item_id = state.name_to_code.get(label, label)
|
||||
voice_floor = self._s.video_voice_confirm_min_confidence
|
||||
if conf < voice_floor:
|
||||
return
|
||||
@@ -639,7 +693,7 @@ class CameraSessionManager:
|
||||
if conf >= auto_th and in_allowed(label):
|
||||
await self._append_confirmed_detail(
|
||||
state=state,
|
||||
item_id=label or "unknown",
|
||||
item_id=item_id or label or "unknown",
|
||||
item_name=label or "unknown",
|
||||
doctor_id=self._s.video_result_doctor_id,
|
||||
source="vision",
|
||||
|
||||
@@ -78,6 +78,25 @@ def parse_voice_choice(asr_text: str, options: list[str]) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def match_voice_choice_against_candidates(
|
||||
asr_text: str, candidates: list[str]
|
||||
) -> str | None:
|
||||
"""
|
||||
在未匹配 pending 展示的 topk 话术时,按本台手术「候选耗材清单」做名称子串匹配。
|
||||
长名优先,减少短名误命中(如「纱」同时匹配多种耗材时优先更长全称)。
|
||||
"""
|
||||
raw = (asr_text or "").strip()
|
||||
if not raw:
|
||||
return None
|
||||
stripped = [c.strip() for c in candidates if c and str(c).strip()]
|
||||
if not stripped:
|
||||
return None
|
||||
for c in sorted(stripped, key=len, reverse=True):
|
||||
if c in raw:
|
||||
return c
|
||||
return None
|
||||
|
||||
|
||||
def is_rejection_phrase(asr_text: str) -> bool:
|
||||
"""医生明确否认全部候选时返回 True(须在 parse_voice_choice 之前调用)。"""
|
||||
raw = (asr_text or "").strip()
|
||||
@@ -88,7 +107,10 @@ def is_rejection_phrase(asr_text: str) -> bool:
|
||||
|
||||
|
||||
def build_prompt_text(options: list[tuple[str, float]]) -> str:
|
||||
parts = ["请确认刚才使用的耗材是下面哪一项,可以说序号或名称。"]
|
||||
parts = [
|
||||
"请确认刚才使用的耗材是下面哪一项,可以说序号或名称;"
|
||||
"若是清单内其它耗材,也可以直接说该耗材名称。"
|
||||
]
|
||||
for i, (name, _conf) in enumerate(options, start=1):
|
||||
parts.append(f"第{i}个,{name}。")
|
||||
parts.append("若都不是请说不是。")
|
||||
|
||||
@@ -15,7 +15,11 @@ from app.services.audio_wav import WavDecodeError, wav_bytes_to_pcm16k_mono_s16l
|
||||
from app.services.baidu_speech import BaiduSpeechNotConfiguredError, BaiduSpeechService
|
||||
from app.services.minio_audio_storage import MinioAudioStorageService, StoredAudio
|
||||
from app.services.video.session_manager import CameraSessionManager
|
||||
from app.services.voice_confirm import is_rejection_phrase, parse_voice_choice
|
||||
from app.services.voice_confirm import (
|
||||
is_rejection_phrase,
|
||||
match_voice_choice_against_candidates,
|
||||
parse_voice_choice,
|
||||
)
|
||||
from app.surgery_errors import SurgeryPipelineError
|
||||
|
||||
|
||||
@@ -256,9 +260,19 @@ class VoiceConfirmationService:
|
||||
chosen: str | None = None
|
||||
if not rejected:
|
||||
chosen = parse_voice_choice(text, option_labels)
|
||||
if chosen is None:
|
||||
surgery_candidates = self._sessions.get_surgery_candidate_consumables(
|
||||
surgery_id
|
||||
)
|
||||
chosen = match_voice_choice_against_candidates(
|
||||
text, surgery_candidates
|
||||
)
|
||||
|
||||
if not rejected and not chosen:
|
||||
msg = "无法从语音中匹配候选项,请重试或说「不是」否认全部"
|
||||
msg = (
|
||||
"无法从语音中匹配候选项或本台手术候选清单中的耗材名称,"
|
||||
"请重试或说「不是」否认全部"
|
||||
)
|
||||
await self._persist_audit(
|
||||
surgery_id=surgery_id,
|
||||
confirmation_id=confirmation_id,
|
||||
|
||||
Reference in New Issue
Block a user