Files
FishServer/FishMeasure/weight_estimator/dataloader.py

206 lines
6.6 KiB
Python

#!/usr/bin/env python3
"""
PyTorch Dataset/DataLoader for fish weight regression from point clouds.
Preprocessing (per sample):
- Load points from .ply
- Scale XYZ by 0.001
- Center to origin (subtract centroid)
- Sample exactly N points (default: 768)
- Random rotation augmentation (default: around Z axis, train mode only)
Label:
- Weights are read from the index JSON field `weight_g` (originally CSV column F), in grams
- The label is multiplied by `label_scale` (default 0.001) so the target is in kg
  (e.g. 0.1 == 0.1 kg == 100 g)
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import open3d as o3d
import torch
from torch.utils.data import Dataset, DataLoader
def _load_points_from_ply(ply_path: Path) -> np.ndarray:
    """
    Read XYZ coordinates from a PLY file as an Nx3 float32 array.

    The file is parsed as a point cloud first. If that yields no points, it is
    parsed again as a triangle mesh and the mesh vertices are returned instead,
    so both point-cloud and mesh PLY files are accepted.

    Raises:
        ValueError: if neither reader finds any points/vertices.
    """
    # (reader, attribute holding the coordinates), tried in order; the mesh
    # reader only runs when the point-cloud reader came back empty.
    readers = (
        (o3d.io.read_point_cloud, "points"),
        (o3d.io.read_triangle_mesh, "vertices"),
    )
    for reader, coord_attr in readers:
        coords = getattr(reader(str(ply_path)), coord_attr)
        if len(coords) > 0:
            return np.asarray(coords).astype(np.float32, copy=False)
    raise ValueError(f"No points/vertices found in: {ply_path}")
def _sample_points_deterministic(points: np.ndarray, num_points: int, rng: np.random.Generator) -> np.ndarray:
"""Deterministic sampling (used for eval)."""
n = points.shape[0]
if n <= 0:
raise ValueError("Empty point cloud")
if n >= num_points:
idx = rng.choice(n, size=num_points, replace=False)
else:
idx = rng.choice(n, size=num_points, replace=True)
return points[idx]
def _sample_points_random(points: np.ndarray, num_points: int) -> np.ndarray:
"""Random sampling using torch RNG (used for train; works well with DataLoader worker seeding)."""
n = points.shape[0]
if n <= 0:
raise ValueError("Empty point cloud")
if n >= num_points:
idx = torch.randperm(n)[:num_points].cpu().numpy()
else:
idx = torch.randint(low=0, high=n, size=(num_points,)).cpu().numpy()
return points[idx]
def _random_rotation_matrix_z(theta: float) -> np.ndarray:
c, s = float(np.cos(theta)), float(np.sin(theta))
return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32)
@dataclass(frozen=True)
class FishWeightSample:
    """Immutable record for one item, as returned by ``FishWeightDataset.__getitem__``."""

    # Sampled, centered point cloud (possibly rotated in train mode); (N, 3) float32.
    points: torch.Tensor
    # Scalar regression target; 0-d float32 tensor, in kg.
    target: torch.Tensor
    # Per-sample bookkeeping: source ply path, sample_id, raw weight_g, target_kg, scale factors.
    meta: Dict[str, Any]
class FishWeightDataset(Dataset):
    """
    Map-style dataset of (point cloud, weight) pairs described by an index JSON.

    The index JSON (produced by weight_estimator/dataset.py) must be an object
    with a non-empty ``"items"`` list. Each item needs ``"ply"`` (an absolute
    path, or a path relative to the data root) and ``"weight_g"``, and may
    carry ``"sample_id"``. An optional ``"meta": {"data_root": ...}`` gives the
    base directory for relative ``ply`` paths.

    Each ``__getitem__`` call loads the PLY, multiplies XYZ by ``xyz_scale``,
    centers it on its centroid and samples exactly ``num_points`` points:

    * train mode: torch-RNG sampling, plus an optional random rotation about Z;
    * eval mode: NumPy sampling seeded with ``seed + index`` (reproducible),
      with no augmentation.

    The target is ``weight_g * label_scale`` (grams -> kg with the default 0.001).
    """

    def __init__(
        self,
        index_json: str | Path,
        data_root: Optional[str | Path] = None,
        num_points: int = 768,
        xyz_scale: float = 0.001,
        label_scale: float = 0.001,  # grams -> kg
        train: bool = True,
        random_rotate: bool = True,
        seed: int = 42,
    ) -> None:
        """
        Read and validate the index JSON.

        Args:
            index_json: path to the index JSON file.
            data_root: base directory for relative ``ply`` paths. If omitted,
                ``meta.data_root`` from the JSON is used, and failing that the
                directory containing ``index_json``.
            num_points: number of points per returned sample (must be > 0).
            xyz_scale: factor applied to raw coordinates (e.g. mm -> m).
            label_scale: factor applied to ``weight_g`` to form the target.
            train: enables random sampling/augmentation; otherwise
                deterministic sampling is used.
            random_rotate: random Z rotation in train mode (ignored in eval).
            seed: base seed for eval-mode sampling.

        Raises:
            FileNotFoundError: if ``index_json`` does not exist.
            ValueError: if the JSON is not an object with a non-empty ``items``
                list, or if ``num_points`` is not positive.
        """
        self.index_json = Path(index_json).expanduser().resolve()
        if not self.index_json.exists():
            raise FileNotFoundError(f"Index JSON not found: {self.index_json}")
        with self.index_json.open("r", encoding="utf-8") as f:
            index = json.load(f)

        # Guard against a top-level list/scalar, which has no .get().
        items = index.get("items", None) if isinstance(index, dict) else None
        if not isinstance(items, list) or not items:
            raise ValueError(f"Invalid index JSON (missing/empty items): {self.index_json}")
        self.items: List[Dict[str, Any]] = items

        # If the JSON stored relative paths, meta.data_root is the base. "meta" may be
        # absent or null. With no explicit root at all, resolve relative paths against
        # the index file's own directory rather than the filesystem root.
        meta = index.get("meta")
        meta_root = meta.get("data_root") if isinstance(meta, dict) else None
        base = data_root or meta_root or self.index_json.parent
        self.data_root = Path(base).expanduser().resolve()

        self.num_points = int(num_points)
        if self.num_points <= 0:
            raise ValueError(f"num_points must be positive, got {self.num_points}")
        self.xyz_scale = float(xyz_scale)
        self.label_scale = float(label_scale)
        self.train = bool(train)
        self.random_rotate = bool(random_rotate)
        self._base_seed = int(seed)

    def __len__(self) -> int:
        """Number of items listed in the index JSON."""
        return len(self.items)

    def __getitem__(self, idx: int) -> FishWeightSample:
        """
        Load, preprocess and return sample ``idx``.

        Negative indices follow normal sequence semantics. They are normalized
        before seeding the eval-mode RNG so that ``ds[-1]`` equals
        ``ds[len(ds) - 1]``. A raw negative offset would pick a different seed,
        or a negative one that ``np.random.default_rng`` rejects.

        Raises:
            IndexError: if ``idx`` is out of range.
            ValueError: if the PLY contains no points.
        """
        idx = int(idx)
        item = self.items[idx]  # IndexError for out-of-range indices
        if idx < 0:
            idx += len(self.items)

        ply_str = item["ply"]
        sample_id = item.get("sample_id", "")
        weight_g = float(item["weight_g"])

        # Expand "~" before the absolute-path test so home-relative paths are not
        # mistakenly joined onto data_root.
        ply_path = Path(ply_str).expanduser()
        if not ply_path.is_absolute():
            ply_path = self.data_root / ply_path
        ply_path = ply_path.resolve()

        pts = _load_points_from_ply(ply_path)  # (M, 3)
        # Scale XYZ (e.g. mm -> m)
        pts = pts * self.xyz_scale
        # Center to origin (centroid of all loaded points, before sampling)
        centroid = pts.mean(axis=0, keepdims=True)
        pts = pts - centroid

        if self.train:
            # Random sampling/augmentation (torch RNG is seeded per-worker by DataLoader)
            pts = _sample_points_random(pts, self.num_points)
            if self.random_rotate:
                theta = float(torch.rand(1).item() * 2.0 * np.pi)
                R = _random_rotation_matrix_z(theta)
                pts = (pts @ R.T).astype(np.float32, copy=False)
        else:
            # Deterministic sampling, no augmentation
            rng = np.random.default_rng(self._base_seed + idx)
            pts = _sample_points_deterministic(pts, self.num_points, rng=rng)

        # Label scaling: grams -> kg (with the default label_scale)
        target = weight_g * self.label_scale

        points_t = torch.from_numpy(pts).to(dtype=torch.float32)  # (N, 3)
        target_t = torch.tensor(target, dtype=torch.float32)
        meta = {
            "ply": str(ply_path),
            "sample_id": sample_id,
            "weight_g": weight_g,
            "target_kg": float(target),
            "xyz_scale": self.xyz_scale,
            "label_scale": self.label_scale,
        }
        return FishWeightSample(points=points_t, target=target_t, meta=meta)
def _collate_fish_samples(
    batch: List[FishWeightSample],
) -> Tuple[torch.Tensor, torch.Tensor, List[Dict[str, Any]]]:
    """
    Stack a list of samples into ``(points, targets, metas)``.

    Returns:
        points: tensor of shape (B, N, 3), stacked from each sample's ``points``.
        targets: tensor of shape (B,), stacked from each sample's 0-d ``target``.
        metas: the per-sample meta dicts, in batch order.

    Defined at module level rather than inside ``create_dataloader`` so that it
    can be pickled. DataLoader workers started with the ``spawn``/``forkserver``
    method (the default on Windows/macOS) must pickle ``collate_fn``.
    """
    points = torch.stack([sample.points for sample in batch], dim=0)  # (B, N, 3)
    targets = torch.stack([sample.target for sample in batch], dim=0)  # (B,)
    metas = [sample.meta for sample in batch]
    return points, targets, metas


def create_dataloader(
    index_json: str | Path,
    batch_size: int = 16,
    num_workers: int = 4,
    shuffle: bool = True,
    train: bool = True,
    *,
    pin_memory: bool = True,
    **dataset_kwargs: Any,
) -> DataLoader:
    """
    Build a DataLoader over a ``FishWeightDataset``.

    Args:
        index_json: path to the index JSON consumed by ``FishWeightDataset``.
        batch_size: samples per batch.
        num_workers: number of DataLoader worker processes.
        shuffle: shuffle sample order; only honored when ``train`` is true
            (eval loaders are never shuffled).
        train: builds the dataset in train mode (random sampling/augmentation)
            and drops the last incomplete batch.
        pin_memory: forwarded to ``DataLoader`` (keyword-only).
        **dataset_kwargs: forwarded to ``FishWeightDataset``.

    Returns:
        A DataLoader yielding ``(points (B, N, 3), targets (B,), metas)``.
    """
    ds = FishWeightDataset(index_json=index_json, train=train, **dataset_kwargs)
    batch_size = int(batch_size)
    # drop_last keeps train batches full, but on a dataset smaller than one batch
    # it would yield no batches at all. Keep the single partial batch in that case.
    drop_last = bool(train) and len(ds) >= batch_size
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=bool(shuffle) if train else False,
        num_workers=int(num_workers),
        pin_memory=bool(pin_memory),
        drop_last=drop_last,
        collate_fn=_collate_fish_samples,
    )