Files
operating-room-monitor-server/tests/test_probs_numpy_device.py

37 lines
1.2 KiB
Python
Raw Normal View History

"""_probs_data_to_numpy1dCPU / CUDA / MPS 上均能离设备再转 NumPy。"""
from __future__ import annotations
import numpy as np
import pytest
torch = pytest.importorskip("torch")
from app.services.consumable_vision_algorithm import _probs_data_to_numpy1d
def test_probs_numpy_cpu_tensor() -> None:
    """Convert a float32 CPU tensor and check the resulting dtype and values.

    The converted array must report dtype ``np.float64`` and match the
    source values within a relative tolerance of 1e-5.
    """
    expected = [0.1, 0.3, 0.6]
    cpu_probs = torch.tensor(expected, dtype=torch.float32)

    converted = _probs_data_to_numpy1d(cpu_probs)

    assert converted.dtype == np.float64
    np.testing.assert_allclose(converted, expected, rtol=1e-5)
@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="CUDA 不可用,跳过设备张量用例",
)
def test_probs_numpy_cuda_tensor() -> None:
    """Convert a float32 tensor created on the CUDA device.

    Skipped when ``torch.cuda.is_available()`` is false. The converted array
    must report dtype ``np.float64`` and match the source values within a
    relative tolerance of 1e-5.
    """
    expected = [0.0, 1.0]
    cuda_probs = torch.tensor(expected, dtype=torch.float32, device="cuda")

    converted = _probs_data_to_numpy1d(cuda_probs)

    assert converted.dtype == np.float64
    np.testing.assert_allclose(converted, expected, rtol=1e-5)
@pytest.mark.skipif(
    # Skip unless this torch build exposes the MPS backend AND it is usable.
    not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()),
    reason="MPS 不可用,跳过设备张量用例",
)
def test_probs_numpy_mps_tensor() -> None:
    """Convert a float32 tensor created on the MPS device.

    Skipped when ``torch.backends`` has no ``mps`` attribute or the backend
    reports itself unavailable. The converted array must report dtype
    ``np.float64`` and match the source values within a relative tolerance
    of 1e-5.
    """
    expected = [0.25, 0.75]
    mps_probs = torch.tensor(expected, dtype=torch.float32, device="mps")

    converted = _probs_data_to_numpy1d(mps_probs)

    assert converted.dtype == np.float64
    np.testing.assert_allclose(converted, expected, rtol=1e-5)