198 lines
5.8 KiB
Python
198 lines
5.8 KiB
Python
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
from dataclasses import dataclass
|
|||
|
|
from io import BytesIO
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
from pathlib import Path
|
|||
|
|
from threading import Lock
|
|||
|
|
|
|||
|
|
import numpy as np
|
|||
|
|
from fastapi.concurrency import run_in_threadpool
|
|||
|
|
from loguru import logger
|
|||
|
|
from PIL import Image, UnidentifiedImageError
|
|||
|
|
|
|||
|
|
# Ultralytics persists its settings/config under the user's home directory by
# default; redirect it to /tmp so this works in read-only / non-home container
# environments. NOTE(review): must execute before `from ultralytics import
# YOLO` below, which is why it sits between the import groups.
os.environ["YOLO_CONFIG_DIR"] = "/tmp"
|
|||
|
|
|
|||
|
|
from ultralytics import YOLO
|
|||
|
|
|
|||
|
|
from app.config import settings
|
|||
|
|
|
|||
|
|
|
|||
|
|
def resolve_classifier_inference_device(explicit: str) -> str | None:
|
|||
|
|
"""Ultralytics `device` string. If unset: macOS prefers MPS; Linux/Windows prefer CUDA when available."""
|
|||
|
|
configured = (explicit or "").strip()
|
|||
|
|
if configured:
|
|||
|
|
return configured
|
|||
|
|
try:
|
|||
|
|
import torch
|
|||
|
|
except Exception:
|
|||
|
|
return None
|
|||
|
|
if sys.platform == "darwin":
|
|||
|
|
if torch.backends.mps.is_available():
|
|||
|
|
return "mps"
|
|||
|
|
return None
|
|||
|
|
if torch.cuda.is_available():
|
|||
|
|
return "cuda:0"
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass(frozen=True)
class PredictionCandidate:
    """One (label, confidence) entry in a top-k classification ranking."""

    # Human-readable class name resolved from the model's index->name mapping.
    label: str
    # Classifier score for this label; higher means more confident.
    confidence: float
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass(frozen=True)
class PredictionResult:
    """Final classification outcome: the best candidate plus the full top-k list."""

    # Label of the highest-scoring candidate (same as topk[0].label).
    label: str
    # Score of the highest-scoring candidate (same as topk[0].confidence).
    confidence: float
    # Candidates ordered by descending confidence, at most `topk` entries.
    topk: list[PredictionCandidate]
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ModelNotConfiguredError(RuntimeError):
    """Raised when the model weights are not configured or the weights file is missing."""
|
|||
|
|
|
|||
|
|
|
|||
|
|
class InvalidImageError(ValueError):
    """Raised when uploaded bytes are empty or cannot be decoded as an image."""
|
|||
|
|
|
|||
|
|
|
|||
|
|
class PredictionError(RuntimeError):
    """Raised when the model cannot produce a classification prediction."""
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ConsumableClassifierService:
    """Consumable recognition/classification service (YOLO-cls).

    Determines the consumable category visible in a frame; kept separate from
    the tear-action model ``TearActionService``. Invoked by the internal
    pipeline only — not exposed over HTTP.
    """

    def __init__(self) -> None:
        # Lazily-loaded YOLO classification model. `_model_lock` guards the
        # one-time load so concurrent first requests only load weights once.
        self._model: YOLO | None = None
        self._model_lock = Lock()

    @property
    def weights_path(self) -> Path | None:
        """Configured weights path with ``~`` expanded, or ``None`` when unset."""
        if not settings.consumable_classifier_weights:
            return None
        return Path(settings.consumable_classifier_weights).expanduser()

    @property
    def configured(self) -> bool:
        """True when a weights path is configured (the file may still be missing)."""
        return self.weights_path is not None

    @property
    def weights_found(self) -> bool:
        """True when the configured weights path points at an existing file."""
        path = self.weights_path
        return path is not None and path.is_file()

    @property
    def model_loaded(self) -> bool:
        """True once the YOLO model has been loaded into memory."""
        return self._model is not None

    async def predict_image_bytes(
        self,
        payload: bytes,
        *,
        topk: int | None = None,
    ) -> PredictionResult:
        """Classify an encoded image without blocking the event loop.

        Args:
            payload: Raw bytes of an encoded image (e.g. JPEG/PNG).
            topk: Number of candidates to return; ``None`` falls back to
                ``settings.consumable_classifier_topk``.

        Returns:
            The best label/confidence plus the top-k candidate ranking.

        Raises:
            ModelNotConfiguredError: Weights are unset or the file is missing.
            InvalidImageError: ``payload`` is empty or not a decodable image.
            PredictionError: Inference failed or returned no probabilities.
        """
        return await run_in_threadpool(self._predict_image_bytes, payload, topk)

    def _predict_image_bytes(
        self,
        payload: bytes,
        topk: int | None,
    ) -> PredictionResult:
        """Synchronous body of `predict_image_bytes`; runs in a worker thread."""
        model = self._get_model()
        image = self._decode_image(payload)

        try:
            result = model.predict(
                image,
                imgsz=settings.consumable_classifier_imgsz,
                # Resolved per call so a change in configured device takes
                # effect without reloading the service.
                device=resolve_classifier_inference_device(
                    settings.consumable_classifier_device
                ),
                verbose=False,
            )[0]
        except Exception as exc:  # pragma: no cover - ultralytics runtime errors vary.
            raise PredictionError(
                f"Failed to run consumable classifier inference: {exc}"
            ) from exc

        return self._build_prediction_result(result, model, topk=topk)

    def _get_model(self) -> YOLO:
        """Validate the configured weights and return the lazily-loaded model.

        Raises:
            ModelNotConfiguredError: No weights configured, or file not found.
        """
        path = self.weights_path
        if path is None:
            raise ModelNotConfiguredError(
                "Consumable classifier weights are not configured. "
                "Set CONSUMABLE_CLASSIFIER_WEIGHTS."
            )

        path = path.resolve()
        if not path.is_file():
            raise ModelNotConfiguredError(
                f"Consumable classifier weights not found: {path}"
            )

        # Double-checked locking: cheap unlocked read first, then re-check
        # under the lock so only one thread pays the weight-loading cost.
        if self._model is None:
            with self._model_lock:
                if self._model is None:
                    logger.info("Loading consumable classifier weights from {}", path)
                    self._model = YOLO(str(path))

        return self._model

    @staticmethod
    def _decode_image(payload: bytes) -> np.ndarray:
        """Decode image bytes into an RGB numpy array.

        Raises:
            InvalidImageError: ``payload`` is empty or not a valid image.
        """
        if not payload:
            raise InvalidImageError("Uploaded image is empty.")

        try:
            with Image.open(BytesIO(payload)) as image:
                return np.asarray(image.convert("RGB"))
        except (UnidentifiedImageError, OSError) as exc:
            raise InvalidImageError("Uploaded file is not a valid image.") from exc

    def _build_prediction_result(
        self,
        result: object,
        model: YOLO,
        *,
        topk: int | None,
    ) -> PredictionResult:
        """Convert a raw ultralytics classification result into a PredictionResult.

        Raises:
            PredictionError: The result has no probabilities or ranks empty.
        """
        probs = getattr(result, "probs", None)
        data = getattr(probs, "data", None)
        if probs is None or data is None:
            raise PredictionError("Model did not return classification probabilities.")

        scores = data.tolist()
        # A 0-dim tensor's tolist() yields a bare scalar; normalize to a list.
        if not isinstance(scores, list):
            scores = [float(scores)]

        names = self._names(model)
        # Guard against topk=0/None; settings provide the default breadth.
        limit = max(1, topk or settings.consumable_classifier_topk)
        ranked = sorted(
            ((index, float(score)) for index, score in enumerate(scores)),
            key=lambda item: item[1],
            reverse=True,
        )[:limit]

        if not ranked:
            raise PredictionError("Model returned an empty prediction result.")

        candidates = [
            PredictionCandidate(
                label=names.get(index, str(index)),
                confidence=confidence,
            )
            for index, confidence in ranked
        ]
        return PredictionResult(
            label=candidates[0].label,
            confidence=candidates[0].confidence,
            topk=candidates,
        )

    @staticmethod
    def _names(model: YOLO) -> dict[int, str]:
        """Return the model's index->label mapping (empty when absent)."""
        raw = getattr(model.model, "names", None) or {}
        # Ultralytics exposes `names` as a dict for most models, but some
        # exports carry a plain list/tuple; normalize both to dict[int, str].
        if isinstance(raw, (list, tuple)):
            return {index: str(value) for index, value in enumerate(raw)}
        return {int(key): str(value) for key, value in raw.items()}
|