206 lines
6.6 KiB
Python
206 lines
6.6 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
PyTorch Dataset/DataLoader for fish weight regression from point clouds.

Preprocessing (per sample):
- Load points from a .ply file (point cloud; mesh vertices as a fallback)
- Scale XYZ by 0.001 (default ``xyz_scale``)
- Center at the origin (subtract the centroid)
- Sample exactly N points (default: 768)
- Random rotation about the Z axis (train mode only, when ``random_rotate`` is set)

Label:
- Weights are in grams (CSV column F; the index JSON's ``weight_g`` field)
- The label is multiplied by 0.001 (default ``label_scale``) so the target is in kg
  (e.g. 0.1 == 0.1 kg == 100 g)
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import json
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any, Dict, List, Optional, Tuple
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
import open3d as o3d
|
||
|
|
import torch
|
||
|
|
from torch.utils.data import Dataset, DataLoader
|
||
|
|
|
||
|
|
|
||
|
|
def _load_points_from_ply(ply_path: Path) -> np.ndarray:
    """Read an (N, 3) float32 array of XYZ coordinates from a PLY file.

    The file is read as a point cloud first. If that yields no points, it is
    re-read as a triangle mesh and the mesh vertices are returned instead.

    Raises:
        ValueError: if neither reader finds any points/vertices.
    """
    path_str = str(ply_path)

    cloud = o3d.io.read_point_cloud(path_str)
    if len(cloud.points) == 0:
        # Nothing came back as a point cloud; fall back to mesh vertices.
        mesh = o3d.io.read_triangle_mesh(path_str)
        if len(mesh.vertices) == 0:
            raise ValueError(f"No points/vertices found in: {ply_path}")
        xyz = np.asarray(mesh.vertices)
    else:
        xyz = np.asarray(cloud.points)

    return xyz.astype(np.float32, copy=False)
|
||
|
|
|
||
|
|
|
||
|
|
def _sample_points_deterministic(points: np.ndarray, num_points: int, rng: np.random.Generator) -> np.ndarray:
|
||
|
|
"""Deterministic sampling (used for eval)."""
|
||
|
|
n = points.shape[0]
|
||
|
|
if n <= 0:
|
||
|
|
raise ValueError("Empty point cloud")
|
||
|
|
if n >= num_points:
|
||
|
|
idx = rng.choice(n, size=num_points, replace=False)
|
||
|
|
else:
|
||
|
|
idx = rng.choice(n, size=num_points, replace=True)
|
||
|
|
return points[idx]
|
||
|
|
|
||
|
|
|
||
|
|
def _sample_points_random(points: np.ndarray, num_points: int) -> np.ndarray:
|
||
|
|
"""Random sampling using torch RNG (used for train; works well with DataLoader worker seeding)."""
|
||
|
|
n = points.shape[0]
|
||
|
|
if n <= 0:
|
||
|
|
raise ValueError("Empty point cloud")
|
||
|
|
if n >= num_points:
|
||
|
|
idx = torch.randperm(n)[:num_points].cpu().numpy()
|
||
|
|
else:
|
||
|
|
idx = torch.randint(low=0, high=n, size=(num_points,)).cpu().numpy()
|
||
|
|
return points[idx]
|
||
|
|
|
||
|
|
|
||
|
|
def _random_rotation_matrix_z(theta: float) -> np.ndarray:
|
||
|
|
c, s = float(np.cos(theta)), float(np.sin(theta))
|
||
|
|
return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]], dtype=np.float32)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True)
class FishWeightSample:
    """One preprocessed example as returned by ``FishWeightDataset.__getitem__``.

    Instances are frozen, so fields cannot be reassigned after construction.
    The ``meta`` dict itself is still mutable.
    """

    # Sampled, scaled and centered coordinates: (N, 3), float32.
    points: torch.Tensor
    # Scalar regression target: 0-d float32 tensor, in kg with the default label scale.
    target: torch.Tensor
    # Bookkeeping carried alongside the tensors (source path, ids, scales).
    meta: Dict[str, Any]
|
||
|
|
|
||
|
|
|
||
|
|
class FishWeightDataset(Dataset):
    """
    Dataset backed by the JSON produced by weight_estimator/dataset.py.

    The index JSON must be an object with a non-empty ``items`` list. Each item
    needs a ``ply`` path and a ``weight_g`` value, and may carry a
    ``sample_id``. An optional ``meta.data_root`` serves as the base for
    relative ``ply`` paths when ``data_root`` is not passed explicitly.

    ``__getitem__`` returns a ``FishWeightSample``. Its points are multiplied by
    ``xyz_scale``, centered on their centroid and resampled to exactly
    ``num_points`` rows. In train mode, sampling uses the torch RNG and, if
    ``random_rotate`` is set, a random rotation about Z is applied. In eval
    mode, sampling is seeded with ``seed + index``, so results are reproducible.
    """

    def __init__(
        self,
        index_json: str | Path,
        data_root: Optional[str | Path] = None,
        num_points: int = 768,
        xyz_scale: float = 0.001,
        label_scale: float = 0.001,  # grams -> kg
        train: bool = True,
        random_rotate: bool = True,
        seed: int = 42,
    ) -> None:
        """Load and validate the index JSON and store the preprocessing settings.

        Raises:
            FileNotFoundError: if ``index_json`` does not exist.
            ValueError: if the JSON is not an object with a non-empty ``items`` list.
        """
        self.index_json = Path(index_json).expanduser().resolve()
        if not self.index_json.exists():
            raise FileNotFoundError(f"Index JSON not found: {self.index_json}")

        with self.index_json.open("r", encoding="utf-8") as f:
            index = json.load(f)

        # A top-level list/scalar would otherwise crash on `.get` with an
        # AttributeError instead of the documented ValueError.
        if not isinstance(index, dict):
            raise ValueError(f"Invalid index JSON (expected a JSON object): {self.index_json}")

        items = index.get("items")
        if not isinstance(items, list) or not items:
            raise ValueError(f"Invalid index JSON (missing/empty items): {self.index_json}")
        self.items: List[Dict[str, Any]] = items

        # If the JSON stored relative paths, meta.data_root is the base.
        # A null/non-dict "meta" is treated as absent rather than crashing.
        meta = index.get("meta")
        meta_root = meta.get("data_root") if isinstance(meta, dict) else None
        # NOTE(review): with neither data_root nor meta.data_root, relative
        # paths resolve against "/" — confirm this fallback is intended.
        self.data_root = Path(data_root or meta_root or "/").expanduser().resolve()

        self.num_points = int(num_points)
        self.xyz_scale = float(xyz_scale)
        self.label_scale = float(label_scale)
        self.train = bool(train)
        self.random_rotate = bool(random_rotate)
        self._base_seed = int(seed)

    def __len__(self) -> int:
        """Number of items listed in the index JSON."""
        return len(self.items)

    def __getitem__(self, idx: int) -> FishWeightSample:
        """Load, preprocess and return the sample at position ``idx``.

        Raises:
            IndexError, TypeError: propagated from the list lookup for an
                out-of-range or non-integer ``idx``.
            KeyError: if the item lacks ``ply`` or ``weight_g``.
            ValueError: if the PLY yields no points.
        """
        item = self.items[idx]
        # The lookup above succeeded, so idx is in [-len, len). Normalise it so
        # items[-1] and items[len - 1] share one eval seed; a negative seed
        # would also make np.random.default_rng raise.
        pos = int(idx)
        if pos < 0:
            pos += len(self.items)

        ply_str = item["ply"]
        sample_id = item.get("sample_id", "")
        weight_g = float(item["weight_g"])
        ply_path = self._resolve_ply_path(ply_str)

        pts = _load_points_from_ply(ply_path)  # (M, 3)
        # Scale XYZ (e.g. mm -> m)
        pts = pts * self.xyz_scale
        # Center to origin (centroid)
        pts = pts - pts.mean(axis=0, keepdims=True)
        pts = self._sample_and_augment(pts, pos)  # (N, 3)

        # Label scaling: grams -> kg with the default label_scale
        target = weight_g * self.label_scale

        meta = {
            "ply": str(ply_path),
            "sample_id": sample_id,
            "weight_g": weight_g,
            "target_kg": float(target),
            "xyz_scale": self.xyz_scale,
            "label_scale": self.label_scale,
        }
        return FishWeightSample(
            points=torch.from_numpy(pts).to(dtype=torch.float32),  # (N, 3)
            target=torch.tensor(target, dtype=torch.float32),
            meta=meta,
        )

    def _resolve_ply_path(self, ply_str: str) -> Path:
        """Return the resolved path of a PLY entry, anchoring relative paths at ``data_root``."""
        ply_path = Path(ply_str)
        if not ply_path.is_absolute():
            ply_path = self.data_root / ply_path
        return ply_path.expanduser().resolve()

    def _sample_and_augment(self, pts: np.ndarray, pos: int) -> np.ndarray:
        """Resample ``pts`` to ``num_points`` rows; in train mode optionally rotate about Z."""
        if not self.train:
            # Deterministic sampling, no augmentation (seeded by base seed + index).
            rng = np.random.default_rng(self._base_seed + pos)
            return _sample_points_deterministic(pts, self.num_points, rng=rng)

        # Random sampling/augmentation (torch RNG is seeded per-worker by DataLoader)
        pts = _sample_points_random(pts, self.num_points)
        if self.random_rotate:
            theta = float(torch.rand(1).item() * 2.0 * np.pi)
            R = _random_rotation_matrix_z(theta)
            pts = (pts @ R.T).astype(np.float32, copy=False)
        return pts
|
||
|
|
|
||
|
|
|
||
|
|
def _collate_fish_samples(batch: List[FishWeightSample]) -> Tuple[torch.Tensor, torch.Tensor, List[Dict[str, Any]]]:
    """Stack samples into (points (B, N, 3), targets (B,), list of per-sample meta dicts).

    Defined at module level rather than inside ``create_dataloader`` so it can
    be pickled when DataLoader worker processes are started with "spawn".
    """
    x = torch.stack([b.points for b in batch], dim=0)  # (B, N, 3)
    y = torch.stack([b.target for b in batch], dim=0)  # (B,)
    meta = [b.meta for b in batch]
    return x, y, meta


def create_dataloader(
    index_json: str | Path,
    batch_size: int = 16,
    num_workers: int = 4,
    shuffle: bool = True,
    train: bool = True,
    pin_memory: bool = True,
    **dataset_kwargs: Any,
) -> DataLoader:
    """Build a ``FishWeightDataset`` and wrap it in a ``DataLoader``.

    Args:
        index_json: path to the index JSON consumed by ``FishWeightDataset``.
        batch_size: samples per batch.
        num_workers: DataLoader worker processes (0 = load in the main process).
        shuffle: shuffle order each epoch; only honoured when ``train`` is True.
        train: train mode (random sampling/augmentation, shuffling, dropping
            the trailing partial batch) vs. deterministic eval mode.
        pin_memory: forwarded to ``DataLoader``.
        **dataset_kwargs: extra keyword arguments for ``FishWeightDataset``.

    Returns:
        A DataLoader yielding ``(points, targets, meta_list)`` tuples.
    """
    ds = FishWeightDataset(index_json=index_json, train=train, **dataset_kwargs)
    batch_size = int(batch_size)

    # Drop the trailing partial batch in training, unless the whole dataset is
    # smaller than one batch: then drop_last would yield zero batches per epoch.
    drop_last = bool(train) and len(ds) >= batch_size

    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=bool(shuffle) if train else False,
        num_workers=int(num_workers),
        pin_memory=bool(pin_memory),
        drop_last=drop_last,
        collate_fn=_collate_fish_samples,
    )
|
||
|
|
|