- Add FastAPI routes for surgery start/end, results, pending confirmation (WAV upload), and health checks.
- Implement RTSP/Hikvision capture, consumable classification, session manager, MinIO/Baidu voice resolution, and DB persistence.
- Add documentation (client API, video backends, staging checklist) and sample camera/RTSP config.
- Add pytest suite (API contract, session manager, voice, repositories, pipeline persistence) and httpx dev dependency.
- Replace deprecated HTTP_422_UNPROCESSABLE_ENTITY with HTTP_422_UNPROCESSABLE_CONTENT.
- Fix SurgeryPipeline DB reads to use an explicit transaction with autobegin disabled.

Made-with: Cursor
198 lines
5.8 KiB
Python
198 lines
5.8 KiB
Python
from __future__ import annotations
|
||
|
||
from dataclasses import dataclass
|
||
from io import BytesIO
|
||
import os
|
||
import sys
|
||
from pathlib import Path
|
||
from threading import Lock
|
||
|
||
import numpy as np
|
||
from fastapi.concurrency import run_in_threadpool
|
||
from loguru import logger
|
||
from PIL import Image, UnidentifiedImageError
|
||
|
||
# Must run before the `ultralytics` import below, which reads YOLO_CONFIG_DIR.
# Default to /tmp (presumably because the default per-user config dir may not
# be writable in the deployment — confirm), but respect a directory that the
# operator configured explicitly instead of silently overwriting it.
os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp")
|
||
|
||
from ultralytics import YOLO
|
||
|
||
from app.config import settings
|
||
|
||
|
||
def resolve_classifier_inference_device(explicit: str) -> str | None:
|
||
"""Ultralytics `device` string. If unset: macOS prefers MPS; Linux/Windows prefer CUDA when available."""
|
||
configured = (explicit or "").strip()
|
||
if configured:
|
||
return configured
|
||
try:
|
||
import torch
|
||
except Exception:
|
||
return None
|
||
if sys.platform == "darwin":
|
||
if torch.backends.mps.is_available():
|
||
return "mps"
|
||
return None
|
||
if torch.cuda.is_available():
|
||
return "cuda:0"
|
||
return None
|
||
|
||
|
||
@dataclass(frozen=True)
class PredictionCandidate:
    """A single ranked class guess produced by the consumable classifier.

    Instances are immutable (``frozen=True``) and compare/hash by value.
    """

    # Class name from the model's name table, or the stringified class index
    # when the model exposes no name for it.
    label: str
    # Score taken from the model's classification probabilities for this class.
    confidence: float
|
||
|
||
|
||
@dataclass(frozen=True)
class PredictionResult:
    """Outcome of classifying one image: the winning class plus the ranked list.

    ``label``/``confidence`` mirror the first entry of ``topk``. The dataclass
    is frozen, but ``topk`` is a plain list.
    """

    # Name of the highest-scoring class.
    label: str
    # Score of the highest-scoring class.
    confidence: float
    # Ranked candidates, best first.
    topk: list[PredictionCandidate]
|
||
|
||
|
||
class ModelNotConfiguredError(RuntimeError):
    """Signal that the classifier cannot run because its weights are unavailable.

    Raised when no weights path is configured, or when the configured path
    does not point to an existing file.
    """
|
||
|
||
|
||
class InvalidImageError(ValueError):
    """Signal that an uploaded payload could not be turned into an image.

    Raised for empty payloads and for bytes that Pillow cannot open or decode.
    """
|
||
|
||
|
||
class PredictionError(RuntimeError):
    """Signal that no usable classification could be obtained from the model.

    Raised when the inference call itself fails, or when its output carries no
    classification probabilities or no scores at all.
    """
|
||
|
||
|
||
class ConsumableClassifierService:
    """Consumable recognition and classification service (YOLO-cls).

    Determines which consumable category is visible in a frame. It is kept
    separate from the tear-action model ``TearActionService`` and is called
    only by the internal pipeline (it is not exposed over HTTP).

    Weights are loaded lazily on the first prediction and cached on the
    instance. ``predict_image_bytes`` runs inference on a thread pool, so calls
    into the shared YOLO instance are serialized with a lock.
    """

    def __init__(self) -> None:
        self._model: YOLO | None = None
        # Guards the one-time lazy load in `_get_model`.
        self._model_lock = Lock()
        # Serializes `model.predict`: one Ultralytics model instance is not
        # safe to drive from several threads at once (its predictor keeps
        # per-call state), and thread-pool callers may overlap.
        self._predict_lock = Lock()

    @property
    def weights_path(self) -> Path | None:
        """Configured weights path with ``~`` expanded, or None when unset/blank."""
        # str() keeps Path-typed settings working; strip() tolerates stray
        # whitespace from env files (mirrors the device setting handling).
        raw = str(settings.consumable_classifier_weights or "").strip()
        if not raw:
            return None
        return Path(raw).expanduser()

    @property
    def configured(self) -> bool:
        """True when a weights path is configured (the file may still be missing)."""
        return self.weights_path is not None

    @property
    def weights_found(self) -> bool:
        """True when the configured weights path points at an existing file."""
        path = self.weights_path
        return path is not None and path.is_file()

    @property
    def model_loaded(self) -> bool:
        """True once the YOLO model has been loaded into memory."""
        return self._model is not None

    async def predict_image_bytes(
        self,
        payload: bytes,
        *,
        topk: int | None = None,
    ) -> PredictionResult:
        """Classify an encoded image without blocking the event loop.

        Args:
            payload: Encoded image bytes in any format Pillow can decode.
            topk: Number of ranked candidates to return. Falls back to
                ``settings.consumable_classifier_topk`` when None or 0, and is
                clamped to at least 1.

        Raises:
            ModelNotConfiguredError: Weights are unset or missing on disk.
            InvalidImageError: The payload is empty or not a decodable image.
            PredictionError: Inference failed or produced no usable scores.
        """
        return await run_in_threadpool(self._predict_image_bytes, payload, topk)

    def _predict_image_bytes(
        self,
        payload: bytes,
        topk: int | None,
    ) -> PredictionResult:
        """Synchronous body of ``predict_image_bytes`` (runs on a worker thread)."""
        model = self._get_model()
        image = self._decode_image(payload)

        try:
            device = resolve_classifier_inference_device(
                settings.consumable_classifier_device
            )
            with self._predict_lock:
                result = model.predict(
                    image,
                    imgsz=settings.consumable_classifier_imgsz,
                    device=device,
                    verbose=False,
                )[0]
        except Exception as exc:  # pragma: no cover - ultralytics runtime errors vary.
            raise PredictionError(
                f"Failed to run consumable classifier inference: {exc}"
            ) from exc

        return self._build_prediction_result(result, model, topk=topk)

    def _get_model(self) -> YOLO:
        """Return the cached YOLO model, loading it from disk on first use.

        Raises:
            ModelNotConfiguredError: No weights path is configured, or the
                resolved path is not an existing file.
        """
        path = self.weights_path
        if path is None:
            raise ModelNotConfiguredError(
                "Consumable classifier weights are not configured. "
                "Set CONSUMABLE_CLASSIFIER_WEIGHTS."
            )

        path = path.resolve()
        if not path.is_file():
            raise ModelNotConfiguredError(
                f"Consumable classifier weights not found: {path}"
            )

        # Re-check under the lock so concurrent first calls load only once.
        if self._model is None:
            with self._model_lock:
                if self._model is None:
                    logger.info("Loading consumable classifier weights from {}", path)
                    self._model = YOLO(str(path))

        return self._model

    def _decode_image(self, payload: bytes) -> np.ndarray:
        """Decode encoded image bytes into an (H, W, 3) uint8 array in BGR order.

        Ultralytics treats ``numpy.ndarray`` sources as OpenCV-style BGR and
        converts them to RGB itself before classification. Pillow yields RGB,
        so the channels are flipped here. Passing RGB straight through would
        swap the red and blue channels the model sees.

        Raises:
            InvalidImageError: The payload is empty, cannot be identified or
                decoded by Pillow, or trips Pillow's decompression-bomb guard.
        """
        if not payload:
            raise InvalidImageError("Uploaded image is empty.")

        try:
            with Image.open(BytesIO(payload)) as image:
                rgb = np.asarray(image.convert("RGB"))
        except (UnidentifiedImageError, Image.DecompressionBombError, OSError) as exc:
            raise InvalidImageError("Uploaded file is not a valid image.") from exc

        # RGB -> BGR; made contiguous because a reversed view has negative strides.
        return np.ascontiguousarray(rgb[:, :, ::-1])

    def _build_prediction_result(
        self,
        result: object,
        model: YOLO,
        *,
        topk: int | None,
    ) -> PredictionResult:
        """Convert one Ultralytics classification result into a PredictionResult.

        Scores from ``result.probs.data`` are ranked highest first (ties keep
        class-index order because the sort is stable). They are truncated to
        ``topk`` entries, or to ``settings.consumable_classifier_topk`` when
        ``topk`` is None or 0, with the limit clamped to at least 1.

        Raises:
            PredictionError: The result has no probabilities or no scores.
        """
        probs = getattr(result, "probs", None)
        data = getattr(probs, "data", None)
        if probs is None or data is None:
            raise PredictionError("Model did not return classification probabilities.")

        scores = data.tolist()
        if not isinstance(scores, list):
            # A 0-d tensor/array yields a bare scalar from tolist().
            scores = [float(scores)]

        names = self._names(model)
        limit = max(1, topk or settings.consumable_classifier_topk)
        ranked = sorted(
            ((index, float(score)) for index, score in enumerate(scores)),
            key=lambda item: item[1],
            reverse=True,
        )[:limit]

        if not ranked:
            raise PredictionError("Model returned an empty prediction result.")

        candidates = [
            PredictionCandidate(
                label=names.get(index, str(index)),
                confidence=confidence,
            )
            for index, confidence in ranked
        ]
        return PredictionResult(
            label=candidates[0].label,
            confidence=candidates[0].confidence,
            topk=candidates,
        )

    def _names(self, model: YOLO) -> dict[int, str]:
        """Return the model's class-index -> class-name table ({} if unavailable)."""
        raw = getattr(model.model, "names", None) or {}
        if isinstance(raw, (list, tuple)):
            # Tolerate names stored as a plain sequence (position == class index).
            raw = dict(enumerate(raw))
        return {int(key): str(value) for key, value in raw.items()}
|