from __future__ import annotations

from dataclasses import dataclass
from io import BytesIO
import os
import sys
from pathlib import Path
from threading import Lock

import numpy as np
from fastapi.concurrency import run_in_threadpool
from loguru import logger
from PIL import Image, UnidentifiedImageError

# Ultralytics persists its settings under YOLO_CONFIG_DIR; point it at a
# writable location before the import below so read-only deployments work.
os.environ["YOLO_CONFIG_DIR"] = "/tmp"

from ultralytics import YOLO  # noqa: E402

from app.config import settings  # noqa: E402


def resolve_classifier_inference_device(explicit: str) -> str | None:
    """Resolve the Ultralytics `device` string.

    An explicit value always wins. If unset: macOS prefers MPS; other
    platforms prefer CUDA when available; otherwise return None and let
    Ultralytics choose.
    """
    configured = (explicit or "").strip()
    if configured:
        return configured
    try:
        import torch
    except Exception:
        return None
    if sys.platform == "darwin":
        if torch.backends.mps.is_available():
            return "mps"
        return None
    if torch.cuda.is_available():
        return "cuda:0"
    return None
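
# Resolution examples (a sketch; actual results depend on the local torch
# build and hardware, and None defers to Ultralytics' own device selection):
#
#   resolve_classifier_inference_device("cpu") -> "cpu"     (explicit wins)
#   resolve_classifier_inference_device("")    -> "mps"     (macOS, MPS available)
#   resolve_classifier_inference_device("")    -> "cuda:0"  (CUDA available)
#   resolve_classifier_inference_device("")    -> None      (no torch / no accelerator)
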

@dataclass(frozen=True)
class PredictionCandidate:
    label: str
    confidence: float


@dataclass(frozen=True)
class PredictionResult:
    label: str
    confidence: float
    topk: list[PredictionCandidate]


class ModelNotConfiguredError(RuntimeError):
    """Raised when the model weights are not configured or missing."""


class InvalidImageError(ValueError):
    """Raised when uploaded bytes cannot be decoded as an image."""


class PredictionError(RuntimeError):
    """Raised when the model cannot produce a prediction."""


class ConsumableClassifierService:
    """Consumable recognition and classification (YOLO-cls).

    Identifies the consumable category visible in a frame; kept separate from
    the tear-action model `TearActionService`. Invoked by the internal
    pipeline only; not exposed over HTTP.
    """

    def __init__(self) -> None:
        self._model: YOLO | None = None
        self._model_lock = Lock()

    @property
    def weights_path(self) -> Path | None:
        if not settings.consumable_classifier_weights:
            return None
        return Path(settings.consumable_classifier_weights).expanduser()

    @property
    def configured(self) -> bool:
        return self.weights_path is not None

    @property
    def weights_found(self) -> bool:
        path = self.weights_path
        return path is not None and path.is_file()

    @property
    def model_loaded(self) -> bool:
        return self._model is not None

    async def predict_image_bytes(
        self,
        payload: bytes,
        *,
        topk: int | None = None,
    ) -> PredictionResult:
        # Inference is CPU/GPU-bound; run it off the event loop.
        return await run_in_threadpool(self._predict_image_bytes, payload, topk)

    def _predict_image_bytes(
        self,
        payload: bytes,
        topk: int | None,
    ) -> PredictionResult:
        model = self._get_model()
        image = self._decode_image(payload)
        try:
            result = model.predict(
                image,
                imgsz=settings.consumable_classifier_imgsz,
                device=resolve_classifier_inference_device(
                    settings.consumable_classifier_device
                ),
                verbose=False,
            )[0]
        except Exception as exc:  # pragma: no cover - ultralytics runtime errors vary.
            raise PredictionError(
                f"Failed to run consumable classifier inference: {exc}"
            ) from exc
        return self._build_prediction_result(result, model, topk=topk)

    def _get_model(self) -> YOLO:
        path = self.weights_path
        if path is None:
            raise ModelNotConfiguredError(
                "Consumable classifier weights are not configured. "
                "Set CONSUMABLE_CLASSIFIER_WEIGHTS."
            )
        path = path.resolve()
        if not path.is_file():
            raise ModelNotConfiguredError(
                f"Consumable classifier weights not found: {path}"
            )
        # Double-checked locking: the fast path skips the lock once loaded.
        if self._model is None:
            with self._model_lock:
                if self._model is None:
                    logger.info("Loading consumable classifier weights from {}", path)
                    self._model = YOLO(str(path))
        return self._model

    def _decode_image(self, payload: bytes) -> np.ndarray:
        if not payload:
            raise InvalidImageError("Uploaded image is empty.")
        try:
            with Image.open(BytesIO(payload)) as image:
                return np.asarray(image.convert("RGB"))
        except (UnidentifiedImageError, OSError) as exc:
            raise InvalidImageError("Uploaded file is not a valid image.") from exc

    def _build_prediction_result(
        self,
        result: object,
        model: YOLO,
        *,
        topk: int | None,
    ) -> PredictionResult:
        probs = getattr(result, "probs", None)
        data = getattr(probs, "data", None)
        if probs is None or data is None:
            raise PredictionError("Model did not return classification probabilities.")
        scores = data.tolist()
        if not isinstance(scores, list):
            # A 0-dim tensor yields a scalar; normalize to a one-element list.
            scores = [float(scores)]
        names = self._names(model)
        limit = max(1, topk or settings.consumable_classifier_topk)
        ranked = sorted(
            ((index, float(score)) for index, score in enumerate(scores)),
            key=lambda item: item[1],
            reverse=True,
        )[:limit]
        if not ranked:
            raise PredictionError("Model returned an empty prediction result.")
        candidates = [
            PredictionCandidate(
                label=names.get(index, str(index)),
                confidence=confidence,
            )
            for index, confidence in ranked
        ]
        return PredictionResult(
            label=candidates[0].label,
            confidence=candidates[0].confidence,
            topk=candidates,
        )

    def _names(self, model: YOLO) -> dict[int, str]:
        raw = getattr(model.model, "names", None) or {}
        return {int(key): str(value) for key, value in raw.items()}
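
# Minimal usage sketch (hypothetical smoke test, not part of the service API).
# Assumes CONSUMABLE_CLASSIFIER_WEIGHTS points at a trained YOLO-cls checkpoint
# and that "frame.jpg" (a placeholder filename) exists in the working directory.
if __name__ == "__main__":  # pragma: no cover - manual check only
    import asyncio

    async def _smoke_test() -> None:
        service = ConsumableClassifierService()
        payload = Path("frame.jpg").read_bytes()
        result = await service.predict_image_bytes(payload, topk=3)
        logger.info("predicted {} ({:.2%})", result.label, result.confidence)

    asyncio.run(_smoke_test())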