#!/usr/bin/env python3
"""
PyTorch Dataset/DataLoader for fish weight regression from point clouds.

Preprocessing (per sample):
- Load points from .ply
- Scale XYZ by 0.001
- Center to origin (subtract centroid)
- Sample exactly N points (default: 768)
- Random rotation augmentation (default: around Z axis, train mode only)

Label:
- CSV column F is in grams
- We scale label by /1000 so the target is in kg (e.g. 0.1 == 0.1 kg == 100 g)
"""

from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

# Open3D is only needed to read PLY files. Guarding the import keeps the
# sampling helpers and the DataLoader factory importable without it; the
# import position (before torch) is intentionally unchanged.
try:
    import open3d as o3d
except ImportError:
    o3d = None

import torch
from torch.utils.data import Dataset, DataLoader


def _load_points_from_ply(ply_path: Path) -> np.ndarray:
    """
    Load Nx3 points from a PLY.

    Supports either point cloud or mesh PLY: the file is first read as a
    point cloud, and if that yields no points it is re-read as a triangle
    mesh and the mesh vertices are used instead.

    Args:
        ply_path: Path of the PLY file to read.

    Returns:
        The points as a float32 array.

    Raises:
        ImportError: If Open3D is not installed.
        ValueError: If neither reader finds any points/vertices.
    """
    if o3d is None:
        raise ImportError("open3d is required to load PLY files")

    # Try point cloud first
    cloud = o3d.io.read_point_cloud(str(ply_path))
    if len(cloud.points) > 0:
        return np.asarray(cloud.points).astype(np.float32, copy=False)

    # Fallback: triangle mesh
    mesh = o3d.io.read_triangle_mesh(str(ply_path))
    if len(mesh.vertices) > 0:
        return np.asarray(mesh.vertices).astype(np.float32, copy=False)

    raise ValueError(f"No points/vertices found in: {ply_path}")


def _sample_points_deterministic(
    points: np.ndarray, num_points: int, rng: np.random.Generator
) -> np.ndarray:
    """
    Deterministic sampling (used for eval).

    Draws ``num_points`` rows of ``points`` using ``rng``. Rows are drawn
    without replacement when enough are available, otherwise with
    replacement so that exactly ``num_points`` rows are always returned.

    Raises:
        ValueError: If ``points`` has no rows.
    """
    n = points.shape[0]
    if n <= 0:
        raise ValueError("Empty point cloud")
    idx = rng.choice(n, size=num_points, replace=n < num_points)
    return points[idx]


def _sample_points_random(points: np.ndarray, num_points: int) -> np.ndarray:
    """
    Random sampling using torch RNG (used for train; works well with
    DataLoader worker seeding).

    Returns exactly ``num_points`` rows: a random subset without repetition
    when enough rows are available, otherwise uniform draws with repetition.

    Raises:
        ValueError: If ``points`` has no rows.
    """
    n = points.shape[0]
    if n <= 0:
        raise ValueError("Empty point cloud")
    if n >= num_points:
        idx = torch.randperm(n)[:num_points].cpu().numpy()
    else:
        idx = torch.randint(low=0, high=n, size=(num_points,)).cpu().numpy()
    return points[idx]


def _random_rotation_matrix_z(theta: float) -> np.ndarray:
    """Return the 3x3 float32 matrix rotating by ``theta`` radians about Z."""
    c, s = float(np.cos(theta)), float(np.sin(theta))
    return np.array(
        [
            [c, -s, 0.0],
            [s, c, 0.0],
            [0.0, 0.0, 1.0],
        ],
        dtype=np.float32,
    )


@dataclass(frozen=True)
class FishWeightSample:
    """One preprocessed sample as returned by ``FishWeightDataset``."""

    points: torch.Tensor  # (N, 3), float32
    target: torch.Tensor  # (), float32, kg
    meta: Dict[str, Any]


class FishWeightDataset(Dataset):
    """
    Dataset backed by the JSON produced by weight_estimator/dataset.py.

    The index JSON must contain a non-empty ``items`` list; each item is read
    for the keys ``ply`` and ``weight_g`` (required) and ``sample_id``
    (optional). Relative ``ply`` paths are resolved against ``data_root``.
    """

    def __init__(
        self,
        index_json: str | Path,
        data_root: Optional[str | Path] = None,
        num_points: int = 768,
        xyz_scale: float = 0.001,
        label_scale: float = 0.001,  # grams -> kg
        train: bool = True,
        random_rotate: bool = True,
        seed: int = 42,
    ) -> None:
        """
        Load and validate the index.

        Args:
            index_json: Path of the index JSON file.
            data_root: Base directory for relative ``ply`` paths. Falls back
                to ``meta.data_root`` from the index, then to ``/``.
            num_points: Number of points returned per sample.
            xyz_scale: Factor applied to the point coordinates.
            label_scale: Factor applied to ``weight_g`` to form the target.
            train: Random sampling/augmentation when true, deterministic
                sampling without augmentation otherwise.
            random_rotate: Apply a random Z rotation (train mode only).
            seed: Base seed of the per-item eval-mode sampling.

        Raises:
            FileNotFoundError: If ``index_json`` does not exist.
            ValueError: If the index has no non-empty ``items`` list.
        """
        self.index_json = Path(index_json).expanduser().resolve()
        if not self.index_json.exists():
            raise FileNotFoundError(f"Index JSON not found: {self.index_json}")

        with self.index_json.open("r", encoding="utf-8") as f:
            index = json.load(f)

        items = index.get("items", None)
        if not isinstance(items, list) or not items:
            raise ValueError(
                f"Invalid index JSON (missing/empty items): {self.index_json}"
            )
        self.items: List[Dict[str, Any]] = items

        # If the JSON stored relative paths, meta.data_root is the base.
        meta_root = index.get("meta", {}).get("data_root", None)
        self.data_root = Path(data_root or meta_root or "/").expanduser().resolve()

        self.num_points = int(num_points)
        self.xyz_scale = float(xyz_scale)
        self.label_scale = float(label_scale)
        self.train = bool(train)
        self.random_rotate = bool(random_rotate)
        self._base_seed = int(seed)

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.items)

    def __getitem__(self, idx: int) -> FishWeightSample:
        """
        Load, preprocess and return the sample at ``idx``.

        Negative indices address items from the end and, in eval mode, yield
        the same deterministic subsample as the equivalent positive index.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        item = self.items[idx]  # out-of-range indices raise IndexError here

        # Normalize negative indices so the eval seed depends only on the
        # item, and so the derived seed does not go negative because of the
        # index (a negative seed is rejected by default_rng).
        position = int(idx)
        if position < 0:
            position += len(self.items)

        ply_str = item["ply"]
        sample_id = item.get("sample_id", "")
        weight_g = float(item["weight_g"])

        ply_path = Path(ply_str)
        if not ply_path.is_absolute():
            ply_path = self.data_root / ply_path
        ply_path = ply_path.expanduser().resolve()

        pts = _load_points_from_ply(ply_path)  # (M, 3)

        # Scale XYZ (e.g. mm -> m)
        pts = pts * self.xyz_scale

        # Center to origin (centroid)
        pts = pts - pts.mean(axis=0, keepdims=True)

        if self.train:
            # Random sampling/augmentation (torch RNG is seeded per-worker
            # by DataLoader)
            pts = _sample_points_random(pts, self.num_points)
            if self.random_rotate:
                theta = float(torch.rand(1).item() * 2.0 * np.pi)
                rotation = _random_rotation_matrix_z(theta)
                pts = (pts @ rotation.T).astype(np.float32, copy=False)
        else:
            # Deterministic sampling, no augmentation
            rng = np.random.default_rng(self._base_seed + position)
            pts = _sample_points_deterministic(pts, self.num_points, rng=rng)

        # Label scaling: grams -> kg
        target = weight_g * self.label_scale

        points_t = torch.from_numpy(pts).to(dtype=torch.float32)  # (N, 3)
        target_t = torch.tensor(target, dtype=torch.float32)

        meta = {
            "ply": str(ply_path),
            "sample_id": sample_id,
            "weight_g": weight_g,
            "target_kg": float(target),
            "xyz_scale": self.xyz_scale,
            "label_scale": self.label_scale,
        }
        return FishWeightSample(points=points_t, target=target_t, meta=meta)


def _collate_samples(
    batch: List[FishWeightSample],
) -> Tuple[torch.Tensor, torch.Tensor, List[Dict[str, Any]]]:
    """
    Stack a list of samples into ``(points, targets, metas)``.

    Defined at module level (not nested inside ``create_dataloader``) so it
    can be pickled, which worker processes started with the ``spawn`` method
    require.
    """
    x = torch.stack([sample.points for sample in batch], dim=0)  # (B, N, 3)
    y = torch.stack([sample.target for sample in batch], dim=0)  # (B,)
    meta = [sample.meta for sample in batch]
    return x, y, meta


def create_dataloader(
    index_json: str | Path,
    batch_size: int = 16,
    num_workers: int = 4,
    shuffle: bool = True,
    train: bool = True,
    **dataset_kwargs: Any,
) -> DataLoader:
    """
    Build a ``DataLoader`` over a ``FishWeightDataset``.

    Args:
        index_json: Path of the index JSON file.
        batch_size: Samples per batch.
        num_workers: Number of loader worker processes.
        shuffle: Shuffle the samples; ignored (no shuffling) when ``train``
            is false.
        train: Train mode; also enables dropping the last incomplete batch.
        **dataset_kwargs: Forwarded to ``FishWeightDataset``.

    Returns:
        A loader yielding ``(points, targets, metas)`` batches.
    """
    ds = FishWeightDataset(index_json=index_json, train=train, **dataset_kwargs)
    is_train = bool(train)

    return DataLoader(
        ds,
        batch_size=int(batch_size),
        shuffle=bool(shuffle) and is_train,
        num_workers=int(num_workers),
        pin_memory=True,
        drop_last=is_train,
        collate_fn=_collate_samples,
    )