#!/usr/bin/env python3 """ Train a DGCNN (Dynamic Graph CNN) regressor to predict fish weight from 3D point clouds. DGCNN uses EdgeConv layers with k-NN dynamic graphs. Often works well with limited data due to simpler architecture than Point Transformer while capturing local geometry better than vanilla PointNet. Model: - 4 EdgeConv layers (k-NN, k=20) - Concat features -> global max+avg pool -> MLP -> 1 scalar (kg) Uses the same FishWeightDataset and P0 augmentations as Point Transformer. Typically needs lr=1e-3 and batch_size=32 (similar to PointNet++). Example: python weight_estimator/train_dgcnn_weight_estimator.py \\ --index-json weight_estimator/dataset_index.json \\ --epochs 400 --batch-size 32 --lr 1e-3 """ from __future__ import annotations import argparse import json import sys from dataclasses import asdict, dataclass from datetime import datetime from pathlib import Path from typing import Any, Dict, List, Tuple import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Subset from tqdm import tqdm REPO_ROOT = Path(__file__).resolve().parents[1] WEIGHT_EST_DIR = Path(__file__).resolve().parent for _p in (str(REPO_ROOT), str(WEIGHT_EST_DIR)): if _p not in sys.path: sys.path.insert(0, _p) from weight_estimator.dataloader import FishWeightDataset, FishWeightSample from dgcnn_weight_model import DGCNNWeightRegressor # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- @dataclass class TrainConfig: index_json: str data_root: str | None num_points: int batch_size: int epochs: int lr: float weight_decay: float val_ratio: float seed: int num_workers: int amp: bool grad_clip: float early_stop_patience: int k: int emb_dims: int dropout: float random_rotate: bool random_translate: bool translate_range: float jitter_std: float jitter_clip: float point_dropout_prob: float point_dropout_ratio: float local_occlusion_prob: float local_occlusion_ratio: float out_dir: str backbone: str = "dgcnn" # --------------------------------------------------------------------------- # Utilities # --------------------------------------------------------------------------- def resolve_path(path: str) -> Path: p = Path(path).expanduser() if p.is_absolute(): return p for base in (Path.cwd(), WEIGHT_EST_DIR, REPO_ROOT): candidate = (base / p).resolve() if candidate.exists(): return candidate return (Path.cwd() / p).resolve() def set_seed(seed: int) -> None: torch.manual_seed(seed) np.random.seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def _group_id_from_item(item: Dict[str, Any]) -> str: fish_id = str(item.get("fish_id", "")).strip() if fish_id: return fish_id sample_id = str(item.get("sample_id", "")).strip() if sample_id: return sample_id ply = str(item.get("ply", "")).strip() if ply: p = Path(ply) if p.parent.name == "cloud": return p.parent.parent.name return p.parent.name return "__unknown_group__" def split_indices_by_group_id( ds: FishWeightDataset, val_ratio: float, seed: int, ) -> Tuple[List[int], List[int], List[str], List[str]]: group_to_indices: Dict[str, List[int]] = {} for i, item in enumerate(ds.items): gid = _group_id_from_item(item) group_to_indices.setdefault(gid, []).append(i) group_ids = sorted(group_to_indices.keys()) rng = np.random.default_rng(seed) rng.shuffle(group_ids) n_val = max(1, int(round(len(group_ids) * val_ratio))) val_ids = set(group_ids[:n_val]) train_idx: List[int] = [] val_idx: List[int] = [] train_groups: List[str] = [] val_groups: List[str] = [] for gid, indices in group_to_indices.items(): if gid in val_ids: val_idx.extend(indices) val_groups.append(gid) else: train_idx.extend(indices) train_groups.append(gid) return train_idx, val_idx, train_groups, val_groups def collate_fn(batch: List[FishWeightSample]) -> Tuple[torch.Tensor, torch.Tensor, List[Dict[str, Any]]]: xs = torch.stack([b.points for b in batch], dim=0) ys = torch.stack([b.target for b in batch], dim=0) metas = [b.meta for b in batch] return xs, ys, metas # --------------------------------------------------------------------------- # Evaluation # --------------------------------------------------------------------------- @torch.no_grad() def evaluate( model: nn.Module, loader: DataLoader, device: torch.device, use_amp: bool = False, ) -> Dict[str, float]: model.eval() total_loss = 0.0 total_ae = 0.0 n_samples = 0 for x, y, _meta in loader: x = x.to(device) y = y.to(device) B = x.shape[0] with torch.amp.autocast("cuda", enabled=(use_amp and device.type == "cuda")): pred = model(x) loss = F.smooth_l1_loss(pred, y, reduction="sum") total_loss += loss.item() total_ae += (pred - y).abs().sum().item() n_samples += B n_samples = max(n_samples, 1) mae_kg = total_ae / n_samples return { "val_loss": total_loss / n_samples, "val_mae_kg": mae_kg, "val_mae_g": mae_kg * 1000.0, } # --------------------------------------------------------------------------- # Checkpoint helpers # --------------------------------------------------------------------------- def save_checkpoint( path: Path, epoch: int, model: nn.Module, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, metrics: Dict[str, float], best_mae_g: float, cfg: TrainConfig, ) -> None: torch.save( { "epoch": epoch, "model_state": model.state_dict(), "optimizer_state": optimizer.state_dict(), "scheduler_state": scheduler.state_dict(), "metrics": metrics, "best_mae_g": best_mae_g, "config": asdict(cfg), }, path, ) def load_model_from_checkpoint(ckpt_path: str | Path, device: torch.device) -> nn.Module: """Load a trained DGCNNWeightRegressor for inference.""" ckpt_path = Path(ckpt_path).expanduser().resolve() if not ckpt_path.exists(): raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") try: ckpt = torch.load(str(ckpt_path), map_location=device, weights_only=False) except TypeError: ckpt = torch.load(str(ckpt_path), map_location=device) config = ckpt.get("config", {}) or {} k = config.get("k", 20) emb_dims = config.get("emb_dims", 1024) dropout = config.get("dropout", 0.5) model = DGCNNWeightRegressor(k=k, emb_dims=emb_dims, dropout=dropout).to(device) state = ckpt.get("model_state") or ckpt.get("model_state_dict") or ckpt model.load_state_dict(state, strict=True) model.eval() return model # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main() -> None: parser = argparse.ArgumentParser(description="DGCNN fish weight regressor training") parser.add_argument("--index-json", type=str, required=True, help="Path to dataset_index.json") parser.add_argument("--data-root", type=str, default=None) parser.add_argument("--resume", type=str, default=None) parser.add_argument("--num-points", type=int, default=768) parser.add_argument("--batch-size", type=int, default=32, help="Batch size (default: 32, DGCNN is lighter than Point Transformer)") parser.add_argument("--epochs", type=int, default=400) parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate (default: 1e-3, similar to PointNet++)") parser.add_argument("--weight-decay", type=float, default=1e-4) parser.add_argument("--val-ratio", type=float, default=0.2) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--num-workers", type=int, default=4) parser.add_argument("--no-amp", action="store_true") parser.add_argument("--grad-clip", type=float, default=1.0) parser.add_argument("--early-stop-patience", type=int, default=0) parser.add_argument("--k", type=int, default=20, help="k-NN neighbors for EdgeConv (default: 20)") parser.add_argument("--emb-dims", type=int, default=1024, help="Embedding dimension (default: 1024)") parser.add_argument("--dropout", type=float, default=0.5, help="Dropout rate (default: 0.5)") parser.add_argument("--no-random-rotate", action="store_true") parser.add_argument("--no-random-translate", action="store_true") parser.add_argument("--translate-range", type=float, default=0.005) parser.add_argument("--jitter-std", type=float, default=0.001) parser.add_argument("--jitter-clip", type=float, default=0.003) parser.add_argument("--point-dropout-prob", type=float, default=0.5) parser.add_argument("--point-dropout-ratio", type=float, default=0.1) parser.add_argument("--local-occlusion-prob", type=float, default=0.3) parser.add_argument("--local-occlusion-ratio", type=float, default=0.15) parser.add_argument("--out-dir", type=str, default=None) args = parser.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") set_seed(args.seed) use_amp = (not args.no_amp) and (device.type == "cuda") out_dir_base = args.out_dir or str(WEIGHT_EST_DIR / "runs") out_dir_base = Path(out_dir_base).expanduser() if not out_dir_base.is_absolute(): out_dir_base = (Path.cwd() / out_dir_base).resolve() run_name = f"dgcnn_{datetime.now().strftime('%Y%m%d_%H%M%S')}" out_dir = out_dir_base / run_name out_dir.mkdir(parents=True, exist_ok=True) index_json_resolved = resolve_path(args.index_json) cfg = TrainConfig( index_json=str(index_json_resolved), data_root=args.data_root, num_points=args.num_points, batch_size=args.batch_size, epochs=args.epochs, lr=args.lr, weight_decay=args.weight_decay, val_ratio=args.val_ratio, seed=args.seed, num_workers=args.num_workers, amp=use_amp, grad_clip=args.grad_clip, early_stop_patience=args.early_stop_patience, k=args.k, emb_dims=args.emb_dims, dropout=args.dropout, random_rotate=not args.no_random_rotate, random_translate=not args.no_random_translate, translate_range=args.translate_range, jitter_std=args.jitter_std, jitter_clip=args.jitter_clip, point_dropout_prob=args.point_dropout_prob, point_dropout_ratio=args.point_dropout_ratio, local_occlusion_prob=args.local_occlusion_prob, local_occlusion_ratio=args.local_occlusion_ratio, out_dir=str(out_dir), ) (out_dir / "config.json").write_text(json.dumps(asdict(cfg), indent=2), encoding="utf-8") full_ds = FishWeightDataset( index_json=str(index_json_resolved), data_root=args.data_root, num_points=args.num_points, xyz_scale=0.001, label_scale=0.001, train=True, random_rotate=not args.no_random_rotate, random_translate=not args.no_random_translate, translate_range=args.translate_range, jitter_std=args.jitter_std, jitter_clip=args.jitter_clip, point_dropout_prob=args.point_dropout_prob, point_dropout_ratio=args.point_dropout_ratio, local_occlusion_prob=args.local_occlusion_prob, local_occlusion_ratio=args.local_occlusion_ratio, seed=args.seed, ) train_idx, val_idx, train_groups, val_groups = split_indices_by_group_id(full_ds, args.val_ratio, args.seed) train_ds = Subset(full_ds, train_idx) val_base_ds = FishWeightDataset( index_json=str(index_json_resolved), data_root=args.data_root, num_points=args.num_points, xyz_scale=0.001, label_scale=0.001, train=False, random_rotate=False, random_translate=False, translate_range=0.0, jitter_std=0.0, jitter_clip=0.0, point_dropout_prob=0.0, point_dropout_ratio=0.0, local_occlusion_prob=0.0, local_occlusion_ratio=0.0, seed=args.seed, ) val_ds = Subset(val_base_ds, val_idx) if set(train_groups).intersection(set(val_groups)): raise RuntimeError("Data leakage: train/val groups overlap") loader_kwargs = dict( num_workers=args.num_workers, pin_memory=(device.type == "cuda"), collate_fn=collate_fn, ) train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, drop_last=True, **loader_kwargs) val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, drop_last=False, **loader_kwargs) model = DGCNNWeightRegressor(k=args.k, emb_dims=args.emb_dims, dropout=args.dropout).to(device) optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) scaler = torch.amp.GradScaler("cuda", enabled=use_amp) start_epoch = 1 best_mae_g = float("inf") epochs_no_improve = 0 if args.resume: ckpt_path = Path(args.resume).expanduser().resolve() if not ckpt_path.exists(): raise SystemExit(f"Checkpoint not found: {ckpt_path}") ckpt = torch.load(str(ckpt_path), map_location=device) model.load_state_dict(ckpt["model_state"], strict=True) if "optimizer_state" in ckpt: optimizer.load_state_dict(ckpt["optimizer_state"]) if "scheduler_state" in ckpt: scheduler.load_state_dict(ckpt["scheduler_state"]) start_epoch = ckpt.get("epoch", 0) + 1 best_mae_g = ckpt.get("best_mae_g", float("inf")) print(f"Resumed from {ckpt_path} | epoch {start_epoch}, best_mae_g={best_mae_g:.2f}g") n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"Model: DGCNNWeightRegressor (k={args.k}, emb={args.emb_dims}, {n_params:,} params)") print(f"Train: {len(train_ds)} samples, groups={len(train_groups)}") print(f"Val: {len(val_ds)} samples, groups={len(val_groups)}") print(f"Device: {device}, AMP: {use_amp}") print(f"Output: {out_dir}") print(flush=True) for epoch in range(start_epoch, args.epochs + 1): model.train() running_loss: List[float] = [] running_ae: List[float] = [] pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{args.epochs}", leave=False, unit="batch") for x, y, _meta in pbar: x = x.to(device) y = y.to(device) optimizer.zero_grad(set_to_none=True) with torch.amp.autocast("cuda", enabled=use_amp): pred = model(x) loss = F.smooth_l1_loss(pred, y) scaler.scale(loss).backward() if args.grad_clip > 0: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) scaler.step(optimizer) scaler.update() running_loss.append(loss.item()) with torch.no_grad(): running_ae.append((pred - y).abs().mean().item()) pbar.set_postfix(loss=f"{loss.item():.5f}", avg=f"{np.mean(running_loss):.5f}", mae_g=f"{np.mean(running_ae)*1000:.1f}") scheduler.step() if device.type == "cuda": torch.cuda.empty_cache() metrics = evaluate(model, val_loader, device, use_amp=use_amp) train_loss = float(np.mean(running_loss)) if running_loss else float("inf") train_mae_g = float(np.mean(running_ae)) * 1000 if running_ae else float("inf") lr_now = scheduler.get_last_lr()[0] print( f"[{epoch:03d}/{args.epochs}] " f"train_loss={train_loss:.6f} train_mae={train_mae_g:.1f}g | " f"val_loss={metrics['val_loss']:.6f} val_mae={metrics['val_mae_g']:.1f}g | " f"lr={lr_now:.2e}" ) save_checkpoint(out_dir / "last.pt", epoch, model, optimizer, scheduler, metrics, best_mae_g, cfg) if metrics["val_mae_g"] < best_mae_g: best_mae_g = metrics["val_mae_g"] save_checkpoint(out_dir / "best.pt", epoch, model, optimizer, scheduler, metrics, best_mae_g, cfg) print(f" -> New best: {best_mae_g:.1f}g") epochs_no_improve = 0 elif args.early_stop_patience > 0: epochs_no_improve += 1 if epochs_no_improve >= args.early_stop_patience: print(f"Early stop: no val MAE improvement for {args.early_stop_patience} epochs (best={best_mae_g:.1f}g)") break print(f"\nDone. Best val MAE: {best_mae_g:.1f}g") print(f"Run dir: {out_dir}") if __name__ == "__main__": main()