"""Tests for ``_probs_data_to_numpy1d``.

The helper is expected to move a tensor off its device (CPU, CUDA or MPS)
before converting it to NumPy.
"""

from __future__ import annotations

import numpy as np
import pytest

# Skip the whole module when PyTorch is not installed; ``importorskip``
# returns the imported module, so ``torch`` is usable below.
torch = pytest.importorskip("torch")

# Imported only after the torch check so that, without torch, the module is
# skipped before the service module is ever loaded.
from app.services.consumable_vision_algorithm import _probs_data_to_numpy1d  # noqa: E402


def test_probs_numpy_cpu_tensor() -> None:
    """A float32 CPU tensor converts to a float64 result with the same values."""
    expected = [0.1, 0.3, 0.6]
    probs = torch.tensor(expected, dtype=torch.float32)

    result = _probs_data_to_numpy1d(probs)

    assert result.dtype == np.float64
    # float32 cannot represent 0.1 etc. exactly, hence the relative tolerance.
    np.testing.assert_allclose(result, expected, rtol=1e-5)


# Skip reason reads: "CUDA unavailable, skipping device-tensor case".
@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="CUDA 不可用,跳过设备张量用例",
)
def test_probs_numpy_cuda_tensor() -> None:
    """A float32 tensor allocated on CUDA converts to a float64 result."""
    expected = [0.0, 1.0]
    probs = torch.tensor(expected, dtype=torch.float32, device="cuda")

    result = _probs_data_to_numpy1d(probs)

    assert result.dtype == np.float64
    np.testing.assert_allclose(result, expected, rtol=1e-5)


# Skip unless this torch build has an MPS backend AND it is usable.
# Skip reason reads: "MPS unavailable, skipping device-tensor case".
@pytest.mark.skipif(
    not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()),
    reason="MPS 不可用,跳过设备张量用例",
)
def test_probs_numpy_mps_tensor() -> None:
    """A float32 tensor allocated on MPS converts to a float64 result."""
    expected = [0.25, 0.75]
    probs = torch.tensor(expected, dtype=torch.float32, device="mps")

    result = _probs_data_to_numpy1d(probs)

    assert result.dtype == np.float64
    np.testing.assert_allclose(result, expected, rtol=1e-5)