# operating-room-monitor-server/app/services/consumable_classifier.py
# (198 lines, 5.8 KiB — header converted from pasted web-viewer chrome so the file parses)
from __future__ import annotations
from dataclasses import dataclass
from io import BytesIO
import os
import sys
from pathlib import Path
from threading import Lock
import numpy as np
from fastapi.concurrency import run_in_threadpool
from loguru import logger
from PIL import Image, UnidentifiedImageError
os.environ["YOLO_CONFIG_DIR"] = "/tmp"
from ultralytics import YOLO
from app.config import settings
def resolve_classifier_inference_device(explicit: str) -> str | None:
    """Choose the Ultralytics ``device`` string for classifier inference.

    An explicitly configured (non-blank) value always wins. Otherwise a
    platform default is probed: on macOS, MPS when available; elsewhere,
    CUDA device 0 when available. Returns ``None`` (let Ultralytics pick)
    when torch is missing or no accelerator is detected.
    """
    requested = (explicit or "").strip()
    if requested:
        return requested
    try:
        import torch
    except Exception:
        # torch absent or broken: defer device selection to Ultralytics.
        return None
    if sys.platform == "darwin":
        return "mps" if torch.backends.mps.is_available() else None
    return "cuda:0" if torch.cuda.is_available() else None
@dataclass(frozen=True)
class PredictionCandidate:
    """One (label, confidence) entry from the classifier's ranked output."""

    label: str  # human-readable class name, or the stringified index when unnamed
    confidence: float  # model score for this label
@dataclass(frozen=True)
class PredictionResult:
    """Final classification outcome: best label plus the ranked top-k list."""

    label: str  # label of the highest-confidence candidate
    confidence: float  # confidence of that best candidate
    topk: list[PredictionCandidate]  # candidates in descending confidence order
class ModelNotConfiguredError(RuntimeError):
    """Raised when the model weights are not configured or the file is missing."""
class InvalidImageError(ValueError):
    """Raised when uploaded bytes cannot be decoded as an image."""
class PredictionError(RuntimeError):
    """Raised when the model cannot produce a prediction."""
class ConsumableClassifierService:
    """Consumable recognition and classification (YOLO-cls).

    Determines the consumable category visible in a frame; kept separate
    from the tear-action model (`TearActionService`). Invoked by the
    internal pipeline only; not exposed over HTTP.
    (Translated from the original Chinese docstring.)
    """

    def __init__(self) -> None:
        # Model is loaded lazily on first prediction; _model_lock guards
        # the one-time load so concurrent callers don't load it twice.
        self._model: YOLO | None = None
        self._model_lock = Lock()

    @property
    def weights_path(self) -> Path | None:
        """Configured weights path (with ``~`` expanded), or None if unset."""
        if not settings.consumable_classifier_weights:
            return None
        return Path(settings.consumable_classifier_weights).expanduser()

    @property
    def configured(self) -> bool:
        """True when a weights path is configured (existence not checked)."""
        return self.weights_path is not None

    @property
    def weights_found(self) -> bool:
        """True when the configured weights path points at an existing file."""
        path = self.weights_path
        return path is not None and path.is_file()

    @property
    def model_loaded(self) -> bool:
        """True once the YOLO model has been loaded into memory."""
        return self._model is not None

    async def predict_image_bytes(
        self,
        payload: bytes,
        *,
        topk: int | None = None,
    ) -> PredictionResult:
        """Classify an encoded image, running inference off the event loop.

        Args:
            payload: Raw encoded image bytes (any format PIL can decode).
            topk: Number of candidates to return; falls back to the
                configured default when None.

        Returns:
            The best label/confidence and the ranked top-k candidates.

        Raises:
            ModelNotConfiguredError: Weights are unset or the file is missing.
            InvalidImageError: Bytes could not be decoded as an image.
            PredictionError: Inference failed or returned no probabilities.
        """
        return await run_in_threadpool(self._predict_image_bytes, payload, topk)

    def _predict_image_bytes(
        self,
        payload: bytes,
        topk: int | None,
    ) -> PredictionResult:
        # Synchronous worker behind predict_image_bytes (runs in a threadpool).
        model = self._get_model()
        image = self._decode_image(payload)
        try:
            result = model.predict(
                image,
                imgsz=settings.consumable_classifier_imgsz,
                device=resolve_classifier_inference_device(
                    settings.consumable_classifier_device
                ),
                verbose=False,
            )[0]
        except Exception as exc:  # pragma: no cover - ultralytics runtime errors vary.
            raise PredictionError(
                f"Failed to run consumable classifier inference: {exc}"
            ) from exc
        return self._build_prediction_result(result, model, topk=topk)

    def _get_model(self) -> YOLO:
        """Return the YOLO model, loading weights on first use (thread-safe)."""
        path = self.weights_path
        if path is None:
            raise ModelNotConfiguredError(
                "Consumable classifier weights are not configured. "
                "Set CONSUMABLE_CLASSIFIER_WEIGHTS."
            )
        path = path.resolve()
        # Re-validated on every call, even after loading, so a deleted
        # weights file surfaces as ModelNotConfiguredError.
        if not path.is_file():
            raise ModelNotConfiguredError(
                f"Consumable classifier weights not found: {path}"
            )
        # Double-checked locking: cheap unlocked read first, then re-check
        # under the lock so concurrent first calls load the weights only once.
        if self._model is None:
            with self._model_lock:
                if self._model is None:
                    logger.info("Loading consumable classifier weights from {}", path)
                    self._model = YOLO(str(path))
        return self._model

    def _decode_image(self, payload: bytes) -> np.ndarray:
        """Decode uploaded bytes into an RGB HxWx3 numpy array.

        Raises:
            InvalidImageError: Empty payload or undecodable image data.
        """
        if not payload:
            raise InvalidImageError("Uploaded image is empty.")
        try:
            with Image.open(BytesIO(payload)) as image:
                return np.asarray(image.convert("RGB"))
        except (UnidentifiedImageError, OSError) as exc:
            raise InvalidImageError("Uploaded file is not a valid image.") from exc

    def _build_prediction_result(
        self,
        result: object,
        model: YOLO,
        *,
        topk: int | None,
    ) -> PredictionResult:
        """Convert an Ultralytics classification result into a PredictionResult.

        Ranks class scores descending and keeps at most ``topk`` entries
        (configured default when ``topk`` is None; always at least one).

        Raises:
            PredictionError: Missing probabilities or empty score list.
        """
        probs = getattr(result, "probs", None)
        data = getattr(probs, "data", None)
        if probs is None or data is None:
            raise PredictionError("Model did not return classification probabilities.")
        scores = data.tolist()
        # A 0-d tensor's tolist() yields a bare scalar; normalize to a list.
        if not isinstance(scores, list):
            scores = [float(scores)]
        names = self._names(model)
        limit = max(1, topk or settings.consumable_classifier_topk)
        ranked = sorted(
            ((index, float(score)) for index, score in enumerate(scores)),
            key=lambda item: item[1],
            reverse=True,
        )[:limit]
        if not ranked:
            raise PredictionError("Model returned an empty prediction result.")
        candidates = [
            PredictionCandidate(
                label=names.get(index, str(index)),
                confidence=confidence,
            )
            for index, confidence in ranked
        ]
        return PredictionResult(
            label=candidates[0].label,
            confidence=candidates[0].confidence,
            topk=candidates,
        )

    def _names(self, model: YOLO) -> dict[int, str]:
        """Map class indices to label strings from the wrapped model's names.

        NOTE(review): assumes ``model.model.names`` is a mapping — confirm
        that list-style ``names`` never occurs for these -cls weights.
        """
        raw = getattr(model.model, "names", None) or {}
        return {int(key): str(value) for key, value in raw.items()}