156 lines
4.7 KiB
Python
156 lines
4.7 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
from io import BytesIO
|
||
|
|
import os
|
||
|
|
from pathlib import Path
|
||
|
|
from threading import Lock
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
from fastapi.concurrency import run_in_threadpool
|
||
|
|
from loguru import logger
|
||
|
|
from PIL import Image, UnidentifiedImageError
|
||
|
|
|
||
|
|
# Must stay above `from ultralytics import YOLO`: Ultralytics resolves its
# settings/config directory from YOLO_CONFIG_DIR when the package is imported.
# NOTE(review): presumably pinned to /tmp because the default user config dir
# is not writable in the deployment image — confirm. This unconditionally
# overrides any YOLO_CONFIG_DIR already present in the environment.
os.environ["YOLO_CONFIG_DIR"] = "/tmp"
|
||
|
|
|
||
|
|
from ultralytics import YOLO
|
||
|
|
|
||
|
|
from app.config import settings
|
||
|
|
from app.services.consumable_classifier import (
|
||
|
|
InvalidImageError,
|
||
|
|
ModelNotConfiguredError,
|
||
|
|
PredictionCandidate,
|
||
|
|
PredictionError,
|
||
|
|
PredictionResult,
|
||
|
|
resolve_classifier_inference_device,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class TearActionService:
    """Tear-action recognition backed by a dedicated YOLO classification model.

    Decides whether (and how) consumables are being torn. It uses its own
    weights (``settings.tear_action_weights``) and is kept separate from the
    consumable classifier ``ConsumableClassifierService``. It is called
    internally by the pipeline and is not exposed over HTTP.
    """

    def __init__(self) -> None:
        # Lazily loaded on first use by ``_get_model``.
        self._model: YOLO | None = None
        # Guards the one-time model load (double-checked in ``_get_model``).
        self._model_lock = Lock()
        # Serializes inference. ``predict_image_bytes`` runs in a threadpool,
        # so concurrent requests would otherwise share one Ultralytics
        # model/predictor instance, which is not thread-safe.
        self._predict_lock = Lock()

    @property
    def weights_path(self) -> Path | None:
        """Return the configured weights path with ``~`` expanded, or ``None`` if unset."""
        if not settings.tear_action_weights:
            return None
        return Path(settings.tear_action_weights).expanduser()

    @property
    def configured(self) -> bool:
        """Return whether a weights path has been configured (existence not checked)."""
        return self.weights_path is not None

    @property
    def weights_found(self) -> bool:
        """Return whether the configured weights path points to an existing file."""
        path = self.weights_path
        return path is not None and path.is_file()

    @property
    def model_loaded(self) -> bool:
        """Return whether the model has already been loaded into memory."""
        return self._model is not None

    async def predict_image_bytes(
        self,
        payload: bytes,
        *,
        topk: int | None = None,
    ) -> PredictionResult:
        """Classify an encoded image without blocking the event loop.

        Args:
            payload: Raw encoded image bytes (any format Pillow can open).
            topk: Number of ranked candidates to return. ``None`` or ``0``
                falls back to ``settings.tear_action_topk``; values below 1
                are clamped to 1.

        Returns:
            The top prediction plus up to ``topk`` ranked candidates.

        Raises:
            ModelNotConfiguredError: Weights are unset or missing on disk.
            InvalidImageError: The payload is empty or not a decodable image.
            PredictionError: Inference failed or produced no probabilities.
        """
        return await run_in_threadpool(self._predict_image_bytes, payload, topk)

    def _predict_image_bytes(
        self,
        payload: bytes,
        topk: int | None,
    ) -> PredictionResult:
        """Run the synchronous load → decode → infer → rank pipeline (threadpool worker)."""
        model = self._get_model()
        image = self._decode_image(payload)

        try:
            # Only one thread may drive the shared predictor at a time.
            with self._predict_lock:
                result = model.predict(
                    image,
                    imgsz=settings.tear_action_imgsz,
                    device=resolve_classifier_inference_device(settings.tear_action_device),
                    verbose=False,
                )[0]
        except Exception as exc:  # pragma: no cover
            # Boundary: surface any backend failure as a single domain error.
            raise PredictionError(
                f"Failed to run tear-action inference: {exc}"
            ) from exc

        return self._build_prediction_result(result, model, topk=topk)

    def _get_model(self) -> YOLO:
        """Return the cached model, loading it from the configured weights on first use.

        The weights file is re-validated on every call. Once loaded, however,
        the model is cached and is not reloaded if the configured path changes
        later.

        Raises:
            ModelNotConfiguredError: Weights are unset or the file does not exist.
        """
        path = self.weights_path
        if path is None:
            raise ModelNotConfiguredError(
                "Tear-action weights are not configured. Set TEAR_ACTION_WEIGHTS."
            )

        path = path.resolve()
        if not path.is_file():
            raise ModelNotConfiguredError(f"Tear-action weights not found: {path}")

        if self._model is None:
            with self._model_lock:
                # Re-check under the lock so that concurrent first calls load only once.
                if self._model is None:
                    logger.info("Loading tear-action weights from {}", path)
                    self._model = YOLO(str(path))

        return self._model

    def _decode_image(self, payload: bytes) -> np.ndarray:
        """Decode encoded image bytes into an HxWx3 uint8 array in **BGR** order.

        Ultralytics interprets ``numpy.ndarray`` sources as BGR (OpenCV
        convention), whereas Pillow yields RGB. The channels are therefore
        reversed before the array is handed to ``model.predict``.

        Raises:
            InvalidImageError: The payload is empty, cannot be decoded, or
                exceeds Pillow's decompression-bomb limit.
        """
        if not payload:
            raise InvalidImageError("Uploaded image is empty.")

        try:
            with Image.open(BytesIO(payload)) as image:
                rgb = np.asarray(image.convert("RGB"))
        except (UnidentifiedImageError, Image.DecompressionBombError, OSError) as exc:
            # DecompressionBombError is not an OSError, so it is listed explicitly.
            raise InvalidImageError("Uploaded file is not a valid image.") from exc

        # RGB -> BGR; make the array contiguous because the reversed view is strided.
        return np.ascontiguousarray(rgb[:, :, ::-1])

    def _build_prediction_result(
        self,
        result: object,
        model: YOLO,
        *,
        topk: int | None,
    ) -> PredictionResult:
        """Rank per-class probabilities from one Ultralytics result into a ``PredictionResult``.

        Raises:
            PredictionError: The result carries no classification probabilities,
                or they are empty.
        """
        probs = getattr(result, "probs", None)
        data = getattr(probs, "data", None)
        if probs is None or data is None:
            raise PredictionError("Model did not return classification probabilities.")

        scores = data.tolist()
        # A 0-d tensor's tolist() returns a bare scalar; normalise it to a list.
        if not isinstance(scores, list):
            scores = [float(scores)]

        names = self._names(model, result)
        # ``topk or ...`` makes 0 behave like None; max() clamps negatives to 1.
        limit = max(1, topk or settings.tear_action_topk)
        ranked = sorted(
            ((index, float(score)) for index, score in enumerate(scores)),
            key=lambda item: item[1],
            reverse=True,
        )[:limit]

        if not ranked:
            raise PredictionError("Model returned an empty prediction result.")

        candidates = [
            PredictionCandidate(
                label=names.get(index, str(index)),
                confidence=confidence,
            )
            for index, confidence in ranked
        ]
        return PredictionResult(
            label=candidates[0].label,
            confidence=candidates[0].confidence,
            topk=candidates,
        )

    def _names(self, model: YOLO, result: object = None) -> dict[int, str]:
        """Return the class-index → label mapping, or ``{}`` if none is available.

        Prefers ``model.model.names``. If that is missing or empty, falls back
        to ``result.names`` (needed for backends whose ``model.model`` carries
        no names). Accepts either a mapping or an index-ordered sequence.
        """
        raw = getattr(getattr(model, "model", None), "names", None)
        if not raw:
            raw = getattr(result, "names", None)
        if not raw:
            return {}
        if hasattr(raw, "items"):
            return {int(key): str(value) for key, value in raw.items()}
        return {index: str(value) for index, value in enumerate(raw)}
|