#!/usr/bin/env python3 """ Train a PointNet++ regressor to predict fish weight from 3D point clouds. Inputs: - Dataset index JSON produced by: weight_estimator/dataset.py (weights are in grams in the CSV; dataloader scales target by /1000 -> kg) Preprocessing is handled by weight_estimator/dataloader.py: - XYZ scaled by 0.001 - Centered to origin - Sample fixed N points (default 768) - Random rotation (train only) Model: - PointNet++ SSG or MSG backbone (--model ssg or msg) - Regression head -> 1 scalar (kg) Example: python weight_estimator/train_pointnet_weigth_estimator.py \ --index-json weight_estimator/dataset_index.json \ --model msg --epochs 200 --batch-size 32 --lr 1e-3 """ from __future__ import annotations import argparse import json import math import os from dataclasses import asdict, dataclass from datetime import datetime from pathlib import Path from typing import Dict, List, Tuple import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Subset from tqdm import tqdm # Ensure repo root is on sys.path so `weight_estimator.*` imports work no matter cwd import sys REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) # Local dataloader from weight_estimator.dataloader import FishWeightDataset def _add_pointnet_paths() -> Path: """ Add local PointNet++ library paths (weight_estimator/Pointnet_Pointnet2_pytorch). Returns the base directory. """ base = Path(__file__).parent / "Pointnet_Pointnet2_pytorch" models_dir = base / "models" if not base.exists(): raise FileNotFoundError(f"PointNet++ library not found at: {base}") import sys if str(base) not in sys.path: sys.path.insert(0, str(base)) if str(models_dir) not in sys.path: sys.path.insert(0, str(models_dir)) return base @dataclass class TrainConfig: index_json: str data_root: str | None model: str num_points: int batch_size: int epochs: int lr: float weight_decay: float val_ratio: float seed: int num_workers: int amp: bool out_dir: str class PointNet2RegressorSSG(nn.Module): """ Minimal regression adaptation of pointnet2_cls_ssg.py (remove log_softmax; output scalar). """ def __init__(self, normal_channel: bool = False): super().__init__() _add_pointnet_paths() from pointnet2_utils import PointNetSetAbstraction # type: ignore in_channel = 6 if normal_channel else 3 self.normal_channel = normal_channel # Same SA settings as pointnet2_cls_ssg self.sa1 = PointNetSetAbstraction( npoint=512, radius=0.2, nsample=32, in_channel=in_channel, mlp=[64, 64, 128], group_all=False ) self.sa2 = PointNetSetAbstraction( npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False ) self.sa3 = PointNetSetAbstraction( npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True ) self.fc1 = nn.Linear(1024, 512) self.bn1 = nn.BatchNorm1d(512) self.drop1 = nn.Dropout(0.4) self.fc2 = nn.Linear(512, 256) self.bn2 = nn.BatchNorm1d(256) self.drop2 = nn.Dropout(0.4) self.fc3 = nn.Linear(256, 1) def forward(self, xyz: torch.Tensor) -> torch.Tensor: """ xyz: (B, 3, N) if normal_channel=False (B, 6, N) if normal_channel=True (xyz+normals) returns: (B,) predicted weight in kg """ b, c, n = xyz.shape if self.normal_channel: norm = xyz[:, 3:, :] xyz3 = xyz[:, :3, :] else: norm = None xyz3 = xyz l1_xyz, l1_points = self.sa1(xyz3, norm) l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) _l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) x = l3_points.view(b, 1024) x = self.drop1(F.relu(self.bn1(self.fc1(x)))) x = self.drop2(F.relu(self.bn2(self.fc2(x)))) x = self.fc3(x).squeeze(-1) return x class PointNet2RegressorMSG(nn.Module): """ Minimal regression adaptation of pointnet2_cls_msg.py (MSA: Multi-Scale Aggregation). Uses PointNetSetAbstractionMsg for multi-scale grouping. """ def __init__(self, normal_channel: bool = False): super().__init__() _add_pointnet_paths() from pointnet2_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction # type: ignore in_channel = 3 if normal_channel else 0 self.normal_channel = normal_channel # Same SA settings as pointnet2_cls_msg self.sa1 = PointNetSetAbstractionMsg( 512, [0.1, 0.2, 0.4], [16, 32, 128], in_channel, [[32, 32, 64], [64, 64, 128], [64, 96, 128]] ) self.sa2 = PointNetSetAbstractionMsg( 128, [0.2, 0.4, 0.8], [32, 64, 128], 320, [[64, 64, 128], [128, 128, 256], [128, 128, 256]] ) self.sa3 = PointNetSetAbstraction( npoint=None, radius=None, nsample=None, in_channel=640 + 3, mlp=[256, 512, 1024], group_all=True ) self.fc1 = nn.Linear(1024, 512) self.bn1 = nn.BatchNorm1d(512) self.drop1 = nn.Dropout(0.4) self.fc2 = nn.Linear(512, 256) self.bn2 = nn.BatchNorm1d(256) self.drop2 = nn.Dropout(0.4) self.fc3 = nn.Linear(256, 1) def forward(self, xyz: torch.Tensor) -> torch.Tensor: """ xyz: (B, 3, N) if normal_channel=False (B, 6, N) if normal_channel=True (xyz+normals) returns: (B,) predicted weight in kg """ b, c, n = xyz.shape if self.normal_channel: norm = xyz[:, 3:, :] xyz3 = xyz[:, :3, :] else: norm = None xyz3 = xyz l1_xyz, l1_points = self.sa1(xyz3, norm) l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) _l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) x = l3_points.view(b, 1024) x = self.drop1(F.relu(self.bn1(self.fc1(x)))) x = self.drop2(F.relu(self.bn2(self.fc2(x)))) x = self.fc3(x).squeeze(-1) return x def build_model(model_type: str = "ssg", normal_channel: bool = False) -> nn.Module: """Build PointNet++ regressor. model_type: 'ssg' or 'msg'.""" if model_type == "msg": return PointNet2RegressorMSG(normal_channel=normal_channel) return PointNet2RegressorSSG(normal_channel=normal_channel) def load_model_from_checkpoint( ckpt_path: str | Path, device: torch.device, map_location: str | torch.device = "cuda", ) -> nn.Module: """ Load PointNet++ regressor from checkpoint. Uses config.model from checkpoint if present; otherwise defaults to SSG. """ ckpt_path = Path(ckpt_path).expanduser().resolve() if not ckpt_path.exists(): raise FileNotFoundError(f"checkpoint not found: {ckpt_path}") ckpt = torch.load(str(ckpt_path), map_location=map_location) config = ckpt.get("config", {}) or {} model_type = str(config.get("model", "ssg")).lower() model = build_model(model_type=model_type, normal_channel=False).to(device) state = ckpt.get("model_state", None) or ckpt.get("model_state_dict", None) or ckpt model.load_state_dict(state, strict=True) model.eval() return model def split_indices_by_sample_id(ds: FishWeightDataset, val_ratio: float, seed: int) -> Tuple[List[int], List[int]]: # group indices by sample_id sample_to_indices: Dict[str, List[int]] = {} for i, item in enumerate(ds.items): sid = str(item.get("sample_id", "")) sample_to_indices.setdefault(sid, []).append(i) sample_ids = sorted(sample_to_indices.keys()) rng = np.random.default_rng(seed) rng.shuffle(sample_ids) n_val = max(1, int(round(len(sample_ids) * val_ratio))) val_ids = set(sample_ids[:n_val]) train_idx: List[int] = [] val_idx: List[int] = [] for sid, indices in sample_to_indices.items(): (val_idx if sid in val_ids else train_idx).extend(indices) return train_idx, val_idx @torch.no_grad() def evaluate(model: nn.Module, loader, device: torch.device) -> Dict[str, float]: model.eval() losses = [] abs_errs = [] for x, y, _meta in loader: x = x.to(device) # (B, N, 3) y = y.to(device) # (B,) x = x.transpose(1, 2).contiguous() # (B, 3, N) pred = model(x) # (B,) loss = F.smooth_l1_loss(pred, y) losses.append(loss.item()) abs_err = (pred - y).abs() abs_errs.append(abs_err.mean().item()) mae_kg = float(np.mean(abs_errs)) if abs_errs else float("inf") return { "val_loss": float(np.mean(losses)) if losses else float("inf"), "val_mae_kg": mae_kg, "val_mae_g": mae_kg * 1000.0, } def main() -> None: parser = argparse.ArgumentParser("PointNet++ fish weight regressor training") parser.add_argument("--index-json", type=str, required=True, help="Path to dataset_index.json") parser.add_argument("--data-root", type=str, default=None, help="Override data root (if JSON uses relative paths)") parser.add_argument("--model", type=str, default="ssg", choices=["ssg", "msg"], help="PointNet++ backbone: ssg (single-scale) or msg (multi-scale). Default: ssg") parser.add_argument("--num-points", type=int, default=768, help="Number of points to sample (default: 768)") parser.add_argument("--batch-size", type=int, default=32, help="Batch size (default: 32)") parser.add_argument("--epochs", type=int, default=400, help="Epochs (default: 400)") parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate (default: 1e-3)") parser.add_argument("--weight-decay", type=float, default=1e-4, help="Weight decay (default: 1e-4)") parser.add_argument("--val-ratio", type=float, default=0.2, help="Validation sample_id ratio (default: 0.2)") parser.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)") parser.add_argument("--num-workers", type=int, default=4, help="DataLoader workers (default: 4)") parser.add_argument("--no-amp", action="store_true", help="Disable AMP") parser.add_argument("--out-dir", type=str, default="weight_estimator/runs", help="Output base dir") args = parser.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch.manual_seed(args.seed) np.random.seed(args.seed) run_name = datetime.now().strftime("%Y%m%d_%H%M%S") out_dir = Path(args.out_dir).expanduser().resolve() / run_name out_dir.mkdir(parents=True, exist_ok=True) cfg = TrainConfig( index_json=args.index_json, data_root=args.data_root, model=args.model, num_points=args.num_points, batch_size=args.batch_size, epochs=args.epochs, lr=args.lr, weight_decay=args.weight_decay, val_ratio=args.val_ratio, seed=args.seed, num_workers=args.num_workers, amp=not args.no_amp, out_dir=str(out_dir), ) (out_dir / "config.json").write_text(json.dumps(asdict(cfg), indent=2), encoding="utf-8") # Build full dataset (train=True for augmentations); we will subset train/val. full_ds = FishWeightDataset( index_json=args.index_json, data_root=args.data_root, num_points=args.num_points, xyz_scale=0.001, label_scale=0.001, # g -> kg train=True, random_rotate=True, seed=args.seed, ) train_idx, val_idx = split_indices_by_sample_id(full_ds, val_ratio=args.val_ratio, seed=args.seed) train_ds = Subset(full_ds, train_idx) # Validation dataset: same underlying index, but disable augmentation val_base_ds = FishWeightDataset( index_json=args.index_json, data_root=args.data_root, num_points=args.num_points, xyz_scale=0.001, label_scale=0.001, train=False, random_rotate=False, seed=args.seed, ) val_ds = Subset(val_base_ds, val_idx) def collate(batch): # batch items are FishWeightSample from the underlying dataset xs = torch.stack([b.points for b in batch], dim=0) ys = torch.stack([b.target for b in batch], dim=0) metas = [b.meta for b in batch] return xs, ys, metas train_loader = torch.utils.data.DataLoader( train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True, drop_last=True, collate_fn=collate, ) val_loader = torch.utils.data.DataLoader( val_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=False, collate_fn=collate, ) if args.model == "msg": model = PointNet2RegressorMSG(normal_channel=False).to(device) print(f"Using PointNet++ MSG (multi-scale)") else: model = PointNet2RegressorSSG(normal_channel=False).to(device) print(f"Using PointNet++ SSG (single-scale)") optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) scaler = torch.cuda.amp.GradScaler(enabled=(cfg.amp and device.type == "cuda")) best_mae = float("inf") for epoch in range(1, args.epochs + 1): model.train() running = [] pbar = tqdm(train_loader, desc=f"epoch {epoch}/{args.epochs}", leave=False) for x, y, _meta in pbar: x = x.to(device) # (B, N, 3) y = y.to(device) # (B,) x = x.transpose(1, 2).contiguous() # (B, 3, N) optimizer.zero_grad(set_to_none=True) with torch.cuda.amp.autocast(enabled=(cfg.amp and device.type == "cuda")): pred = model(x) loss = F.smooth_l1_loss(pred, y) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() running.append(loss.item()) pbar.set_postfix(loss=float(np.mean(running))) scheduler.step() metrics = evaluate(model, val_loader, device) lr_now = scheduler.get_last_lr()[0] print( f"[{epoch:03d}/{args.epochs}] " f"train_loss={float(np.mean(running)):.6f} " f"val_loss={metrics['val_loss']:.6f} " f"val_mae={metrics['val_mae_g']:.2f}g " f"lr={lr_now:.2e}" ) # Save last torch.save( { "epoch": epoch, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "metrics": metrics, "config": asdict(cfg), }, out_dir / "last.pt", ) # Save best if metrics["val_mae_g"] < best_mae: best_mae = metrics["val_mae_g"] torch.save( { "epoch": epoch, "model_state": model.state_dict(), "metrics": metrics, "config": asdict(cfg), }, out_dir / "best.pt", ) print(f"Done. Best val MAE: {best_mae:.2f} g") print(f"Run dir: {out_dir}") if __name__ == "__main__": main()