Files
FishServer/FishMeasure/weight_estimator/train_dgcnn_weight_estimator.py

464 lines
17 KiB
Python

#!/usr/bin/env python3
"""
Train a DGCNN (Dynamic Graph CNN) regressor to predict fish weight from 3D point clouds.
DGCNN uses EdgeConv layers built on dynamic k-NN graphs. It often works well with
limited data because its architecture is simpler than Point Transformer's, while it
still captures local geometry better than vanilla PointNet.
Model:
- 4 EdgeConv layers (k-NN, k=20)
- Concat features -> global max+avg pool -> MLP -> 1 scalar (kg)
Uses the same FishWeightDataset and P0 augmentations as Point Transformer.
Typically needs lr=1e-3 and batch_size=32 (similar to PointNet++).
Example:
python weight_estimator/train_dgcnn_weight_estimator.py \\
--index-json weight_estimator/dataset_index.json \\
--epochs 400 --batch-size 32 --lr 1e-3
"""
from __future__ import annotations
import argparse
import json
import sys
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
# Directory holding this script, and its parent (the repo root used for
# package-style imports such as ``weight_estimator.dataloader``).
WEIGHT_EST_DIR = Path(__file__).resolve().parent
REPO_ROOT = WEIGHT_EST_DIR.parent
# Prepend both directories to the import path (repo root first, so the script
# directory ends up in front of it), skipping any that are already present.
for _search_dir in (REPO_ROOT, WEIGHT_EST_DIR):
    _search_dir_str = str(_search_dir)
    if _search_dir_str not in sys.path:
        sys.path.insert(0, _search_dir_str)
from weight_estimator.dataloader import FishWeightDataset, FishWeightSample
from dgcnn_weight_model import DGCNNWeightRegressor
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
@dataclass
class TrainConfig:
    """Snapshot of the settings used for one training run.

    ``main`` builds this from the parsed CLI arguments, writes ``asdict(cfg)``
    to ``config.json`` in the run directory, and ``save_checkpoint`` embeds the
    same dict under the ``"config"`` key. ``load_model_from_checkpoint`` reads
    ``k``, ``emb_dims`` and ``dropout`` back from it to rebuild the model.
    """

    # Dataset index JSON, already resolved to an absolute path by ``main``.
    index_json: str
    # Optional data root forwarded to FishWeightDataset (None when not given).
    data_root: str | None
    # Points per cloud requested from the dataset.
    num_points: int
    batch_size: int
    epochs: int
    # AdamW learning rate and weight decay.
    lr: float
    weight_decay: float
    # Fraction of *groups* (not samples) assigned to validation.
    val_ratio: float
    seed: int
    num_workers: int
    # Effective AMP flag: only True when requested and running on CUDA.
    amp: bool
    # Gradient-norm clip threshold; values <= 0 disable clipping.
    grad_clip: float
    # Epochs without val-MAE improvement before stopping; 0 disables.
    early_stop_patience: int
    # DGCNN hyper-parameters: k-NN neighbours, embedding size, dropout rate.
    k: int
    emb_dims: int
    dropout: float
    # Training-time augmentation settings forwarded to FishWeightDataset.
    random_rotate: bool
    random_translate: bool
    translate_range: float
    jitter_std: float
    jitter_clip: float
    point_dropout_prob: float
    point_dropout_ratio: float
    local_occlusion_prob: float
    local_occlusion_ratio: float
    # Run directory that holds config.json, last.pt and best.pt.
    out_dir: str
    # Model-family tag recorded with the config; not read in the visible file.
    backbone: str = "dgcnn"
# ---------------------------------------------------------------------------
# Utilities
# ---------------------------------------------------------------------------
def resolve_path(path: str) -> Path:
    """Turn a user-supplied path into a concrete ``Path``.

    ``~`` is expanded first; absolute paths are returned unchanged (not
    resolved). A relative path is tried against the current working directory,
    this script's directory and the repository root, in that order, and the
    first existing match is returned fully resolved. If nothing exists, the
    path is resolved against the current working directory.
    """
    expanded = Path(path).expanduser()
    if expanded.is_absolute():
        return expanded
    search_roots = (Path.cwd(), WEIGHT_EST_DIR, REPO_ROOT)
    candidates = ((root / expanded).resolve() for root in search_roots)
    match = next((cand for cand in candidates if cand.exists()), None)
    if match is not None:
        return match
    return (Path.cwd() / expanded).resolve()
def set_seed(seed: int) -> None:
    """Seed every global RNG this script may draw from.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and torch
    (CPU, plus all CUDA devices when CUDA is available). Previously ``random``
    was left unseeded, so any augmentation or helper drawing from it made runs
    non-reproducible even with a fixed ``--seed``.

    Args:
        seed: Seed value; must be in ``[0, 2**32)`` for NumPy.
    """
    import random  # local import keeps the module import block unchanged

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def _group_id_from_item(item: Dict[str, Any]) -> str:
fish_id = str(item.get("fish_id", "")).strip()
if fish_id:
return fish_id
sample_id = str(item.get("sample_id", "")).strip()
if sample_id:
return sample_id
ply = str(item.get("ply", "")).strip()
if ply:
p = Path(ply)
if p.parent.name == "cloud":
return p.parent.parent.name
return p.parent.name
return "__unknown_group__"
def split_indices_by_group_id(
    ds: FishWeightDataset, val_ratio: float, seed: int,
) -> Tuple[List[int], List[int], List[str], List[str]]:
    """Split dataset indices into train/val so that no group spans both sides.

    Groups come from ``_group_id_from_item`` applied to each entry of
    ``ds.items``. Group ids are sorted, shuffled with a NumPy generator seeded
    by ``seed``, and the first ``round(n_groups * val_ratio)`` (at least one)
    go to validation.

    When there are two or more groups, at least one is always kept for
    training. Previously a large ``val_ratio`` (e.g. 0.8 with 2 groups) could
    send every group to validation and leave the train split empty. With a
    single group the split is impossible and that group goes to validation.

    Returns:
        ``(train_idx, val_idx, train_groups, val_groups)``.
    """
    group_to_indices: Dict[str, List[int]] = {}
    for i, item in enumerate(ds.items):
        group_to_indices.setdefault(_group_id_from_item(item), []).append(i)

    # Sort before shuffling so the split depends only on the seed.
    group_ids = sorted(group_to_indices.keys())
    rng = np.random.default_rng(seed)
    rng.shuffle(group_ids)

    n_groups = len(group_ids)
    n_val = max(1, int(round(n_groups * val_ratio)))
    if n_groups >= 2:
        n_val = min(n_val, n_groups - 1)
    val_ids = set(group_ids[:n_val])

    train_idx: List[int] = []
    val_idx: List[int] = []
    train_groups: List[str] = []
    val_groups: List[str] = []
    for gid, indices in group_to_indices.items():
        if gid in val_ids:
            val_idx.extend(indices)
            val_groups.append(gid)
        else:
            train_idx.extend(indices)
            train_groups.append(gid)
    return train_idx, val_idx, train_groups, val_groups
def collate_fn(batch: List[FishWeightSample]) -> Tuple[torch.Tensor, torch.Tensor, List[Dict[str, Any]]]:
    """Merge a list of samples into batched tensors for the DataLoader.

    ``.points`` and ``.target`` of each sample are stacked along a new leading
    batch dimension. The per-sample ``.meta`` objects are returned as a plain
    list in batch order.
    """
    point_clouds = [sample.points for sample in batch]
    stacked_points = torch.stack(point_clouds, dim=0)
    targets = [sample.target for sample in batch]
    stacked_targets = torch.stack(targets, dim=0)
    meta_list = [sample.meta for sample in batch]
    return stacked_points, stacked_targets, meta_list
# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------
@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    use_amp: bool = False,
) -> Dict[str, float]:
    """Compute per-sample validation metrics over ``loader`` without gradients.

    Puts ``model`` in eval mode, then sums the Smooth-L1 loss and the absolute
    error over every batch. Both sums are divided by the number of samples
    seen (treated as 1 for an empty loader, so the metrics are 0.0 then).
    Autocast is only enabled when ``use_amp`` is set and ``device`` is CUDA.

    Returns:
        ``{"val_loss", "val_mae_kg", "val_mae_g"}``; the gram value is the
        kilogram MAE times 1000.
    """
    model.eval()
    autocast_on = use_amp and device.type == "cuda"
    loss_sum = 0.0
    abs_err_sum = 0.0
    seen = 0
    for batch_x, batch_y, _ in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        with torch.amp.autocast("cuda", enabled=autocast_on):
            prediction = model(batch_x)
            batch_loss = F.smooth_l1_loss(prediction, batch_y, reduction="sum")
        loss_sum += batch_loss.item()
        abs_err_sum += (prediction - batch_y).abs().sum().item()
        seen += batch_x.shape[0]
    denom = seen if seen > 0 else 1
    mae_kg = abs_err_sum / denom
    return {
        "val_loss": loss_sum / denom,
        "val_mae_kg": mae_kg,
        "val_mae_g": mae_kg * 1000.0,
    }
# ---------------------------------------------------------------------------
# Checkpoint helpers
# ---------------------------------------------------------------------------
def save_checkpoint(
    path: Path,
    epoch: int,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    metrics: Dict[str, float],
    best_mae_g: float,
    cfg: TrainConfig,
) -> None:
    """Write a resumable training checkpoint to ``path`` atomically.

    The saved dict holds the following keys, which are read back by the
    ``--resume`` path in ``main`` and by ``load_model_from_checkpoint``:
    ``epoch``, ``model_state``, ``optimizer_state``, ``scheduler_state``,
    ``metrics``, ``best_mae_g`` and ``config`` (``asdict(cfg)``).

    The data is first written to a sibling ``<name>.tmp`` file and then moved
    over ``path`` with ``Path.replace``. An interruption during the write
    therefore leaves the previous checkpoint intact instead of a truncated
    file; previously ``torch.save`` wrote straight into ``path``.

    Args:
        path: Destination file; a ``str`` is accepted as well.
    """
    path = Path(path)
    payload = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "metrics": metrics,
        "best_mae_g": best_mae_g,
        "config": asdict(cfg),
    }
    # Same directory as the target, so the final rename stays on one filesystem.
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(payload, str(tmp_path))
        tmp_path.replace(path)
    except BaseException:
        # Don't leave a half-written temp file behind; the old checkpoint survives.
        tmp_path.unlink(missing_ok=True)
        raise
def load_model_from_checkpoint(ckpt_path: str | Path, device: torch.device) -> nn.Module:
    """Rebuild a trained DGCNNWeightRegressor from a checkpoint for inference.

    Hyper-parameters ``k``, ``emb_dims`` and ``dropout`` come from the stored
    ``config`` dict; the defaults are 20, 1024 and 0.5 when absent. Weights
    are taken from ``model_state``, then ``model_state_dict``, and otherwise
    the loaded object itself is treated as the state dict. The returned model
    is on ``device`` and in eval mode.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    resolved = Path(ckpt_path).expanduser().resolve()
    if not resolved.exists():
        raise FileNotFoundError(f"Checkpoint not found: {resolved}")
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    try:
        payload = torch.load(str(resolved), map_location=device, weights_only=False)
    except TypeError:
        # Older torch releases do not accept the weights_only keyword.
        payload = torch.load(str(resolved), map_location=device)
    saved_cfg = payload.get("config", {}) or {}
    model = DGCNNWeightRegressor(
        k=saved_cfg.get("k", 20),
        emb_dims=saved_cfg.get("emb_dims", 1024),
        dropout=saved_cfg.get("dropout", 0.5),
    ).to(device)
    weights = payload.get("model_state") or payload.get("model_state_dict") or payload
    model.load_state_dict(weights, strict=True)
    model.eval()
    return model
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for DGCNN weight-regressor training."""
    parser = argparse.ArgumentParser(description="DGCNN fish weight regressor training")
    parser.add_argument("--index-json", type=str, required=True,
                        help="Path to dataset_index.json")
    parser.add_argument("--data-root", type=str, default=None)
    parser.add_argument("--resume", type=str, default=None)
    parser.add_argument("--num-points", type=int, default=768)
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size (default: 32, DGCNN is lighter than Point Transformer)")
    parser.add_argument("--epochs", type=int, default=400)
    parser.add_argument("--lr", type=float, default=1e-3,
                        help="Learning rate (default: 1e-3, similar to PointNet++)")
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--val-ratio", type=float, default=0.2)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--no-amp", action="store_true")
    parser.add_argument("--grad-clip", type=float, default=1.0)
    parser.add_argument("--early-stop-patience", type=int, default=0)
    parser.add_argument("--k", type=int, default=20,
                        help="k-NN neighbors for EdgeConv (default: 20)")
    parser.add_argument("--emb-dims", type=int, default=1024,
                        help="Embedding dimension (default: 1024)")
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="Dropout rate (default: 0.5)")
    parser.add_argument("--no-random-rotate", action="store_true")
    parser.add_argument("--no-random-translate", action="store_true")
    parser.add_argument("--translate-range", type=float, default=0.005)
    parser.add_argument("--jitter-std", type=float, default=0.001)
    parser.add_argument("--jitter-clip", type=float, default=0.003)
    parser.add_argument("--point-dropout-prob", type=float, default=0.5)
    parser.add_argument("--point-dropout-ratio", type=float, default=0.1)
    parser.add_argument("--local-occlusion-prob", type=float, default=0.3)
    parser.add_argument("--local-occlusion-ratio", type=float, default=0.15)
    parser.add_argument("--out-dir", type=str, default=None)
    return parser


def _train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scaler: Any,
    device: torch.device,
    use_amp: bool,
    grad_clip: float,
    desc: str,
) -> Tuple[float, float]:
    """Run one optimisation pass over ``loader``.

    Uses Smooth-L1 loss, optional AMP via ``scaler``, and gradient-norm
    clipping when ``grad_clip > 0``. Shows a tqdm progress bar labelled
    ``desc``.

    Returns:
        ``(mean batch loss, mean batch MAE in kg)``; both are ``inf`` when the
        loader yields no batches.
    """
    model.train()
    running_loss: List[float] = []
    running_ae: List[float] = []
    pbar = tqdm(loader, desc=desc, leave=False, unit="batch")
    for x, y, _meta in pbar:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=use_amp):
            pred = model(x)
            loss = F.smooth_l1_loss(pred, y)
        scaler.scale(loss).backward()
        if grad_clip > 0:
            # Gradients must be unscaled before their norm is meaningful.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()
        running_loss.append(loss.item())
        with torch.no_grad():
            running_ae.append((pred - y).abs().mean().item())
        pbar.set_postfix(loss=f"{loss.item():.5f}", avg=f"{np.mean(running_loss):.5f}", mae_g=f"{np.mean(running_ae)*1000:.1f}")
    train_loss = float(np.mean(running_loss)) if running_loss else float("inf")
    train_mae_kg = float(np.mean(running_ae)) if running_ae else float("inf")
    return train_loss, train_mae_kg


def main() -> None:
    """Train a DGCNN weight regressor from the command line.

    Creates a timestamped run directory containing ``config.json``,
    ``last.pt`` (every epoch) and ``best.pt`` (lowest validation MAE).

    Changes from the previous version:
      * ``best_mae_g`` is updated *before* ``last.pt`` is written. Previously
        ``last.pt`` carried a stale best on improving epochs, so a resume
        could later overwrite ``best.pt`` with a worse model.
      * The train loader only drops the last partial batch when the split
        holds at least one full batch. Before, a train split smaller than
        ``--batch-size`` produced zero batches, and training silently did
        nothing.
      * An empty train split aborts with an explicit message.
      * ``--resume`` accepts ``model_state_dict`` as well as ``model_state``,
        matching ``load_model_from_checkpoint``.
    """
    args = _build_arg_parser().parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(args.seed)
    use_amp = (not args.no_amp) and (device.type == "cuda")

    out_dir_base = Path(args.out_dir or str(WEIGHT_EST_DIR / "runs")).expanduser()
    if not out_dir_base.is_absolute():
        out_dir_base = (Path.cwd() / out_dir_base).resolve()
    run_name = f"dgcnn_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    out_dir = out_dir_base / run_name
    out_dir.mkdir(parents=True, exist_ok=True)

    index_json_resolved = resolve_path(args.index_json)
    cfg = TrainConfig(
        index_json=str(index_json_resolved),
        data_root=args.data_root,
        num_points=args.num_points,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        val_ratio=args.val_ratio,
        seed=args.seed,
        num_workers=args.num_workers,
        amp=use_amp,
        grad_clip=args.grad_clip,
        early_stop_patience=args.early_stop_patience,
        k=args.k,
        emb_dims=args.emb_dims,
        dropout=args.dropout,
        random_rotate=not args.no_random_rotate,
        random_translate=not args.no_random_translate,
        translate_range=args.translate_range,
        jitter_std=args.jitter_std,
        jitter_clip=args.jitter_clip,
        point_dropout_prob=args.point_dropout_prob,
        point_dropout_ratio=args.point_dropout_ratio,
        local_occlusion_prob=args.local_occlusion_prob,
        local_occlusion_ratio=args.local_occlusion_ratio,
        out_dir=str(out_dir),
    )
    (out_dir / "config.json").write_text(json.dumps(asdict(cfg), indent=2), encoding="utf-8")

    # Augmented dataset for training; the split itself only reads .items.
    full_ds = FishWeightDataset(
        index_json=str(index_json_resolved),
        data_root=args.data_root,
        num_points=args.num_points,
        xyz_scale=0.001,
        label_scale=0.001,
        train=True,
        random_rotate=not args.no_random_rotate,
        random_translate=not args.no_random_translate,
        translate_range=args.translate_range,
        jitter_std=args.jitter_std,
        jitter_clip=args.jitter_clip,
        point_dropout_prob=args.point_dropout_prob,
        point_dropout_ratio=args.point_dropout_ratio,
        local_occlusion_prob=args.local_occlusion_prob,
        local_occlusion_ratio=args.local_occlusion_ratio,
        seed=args.seed,
    )
    train_idx, val_idx, train_groups, val_groups = split_indices_by_group_id(full_ds, args.val_ratio, args.seed)
    if set(train_groups).intersection(set(val_groups)):
        raise RuntimeError("Data leakage: train/val groups overlap")
    if not train_idx:
        raise SystemExit(
            f"No training samples after group split ({len(val_groups)} group(s) assigned to val); "
            "add data or lower --val-ratio"
        )
    train_ds = Subset(full_ds, train_idx)

    # Separate, augmentation-free dataset for validation over the held-out groups.
    val_base_ds = FishWeightDataset(
        index_json=str(index_json_resolved),
        data_root=args.data_root,
        num_points=args.num_points,
        xyz_scale=0.001,
        label_scale=0.001,
        train=False,
        random_rotate=False,
        random_translate=False,
        translate_range=0.0,
        jitter_std=0.0,
        jitter_clip=0.0,
        point_dropout_prob=0.0,
        point_dropout_ratio=0.0,
        local_occlusion_prob=0.0,
        local_occlusion_ratio=0.0,
        seed=args.seed,
    )
    val_ds = Subset(val_base_ds, val_idx)

    loader_kwargs = dict(
        num_workers=args.num_workers,
        pin_memory=(device.type == "cuda"),
        collate_fn=collate_fn,
    )
    # Drop the trailing partial batch only when at least one full batch exists;
    # otherwise drop_last=True would leave the train loader with zero batches.
    drop_last_train = len(train_ds) >= args.batch_size
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, drop_last=drop_last_train, **loader_kwargs)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, drop_last=False, **loader_kwargs)

    model = DGCNNWeightRegressor(k=args.k, emb_dims=args.emb_dims, dropout=args.dropout).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    start_epoch = 1
    best_mae_g = float("inf")
    epochs_no_improve = 0
    if args.resume:
        ckpt_path = Path(args.resume).expanduser().resolve()
        if not ckpt_path.exists():
            raise SystemExit(f"Checkpoint not found: {ckpt_path}")
        ckpt = torch.load(str(ckpt_path), map_location=device)
        state = ckpt.get("model_state") or ckpt.get("model_state_dict") or ckpt
        model.load_state_dict(state, strict=True)
        if "optimizer_state" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer_state"])
        if "scheduler_state" in ckpt:
            scheduler.load_state_dict(ckpt["scheduler_state"])
        start_epoch = ckpt.get("epoch", 0) + 1
        best_mae_g = ckpt.get("best_mae_g", float("inf"))
        print(f"Resumed from {ckpt_path} | epoch {start_epoch}, best_mae_g={best_mae_g:.2f}g")

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model: DGCNNWeightRegressor (k={args.k}, emb={args.emb_dims}, {n_params:,} params)")
    print(f"Train: {len(train_ds)} samples, groups={len(train_groups)}")
    print(f"Val: {len(val_ds)} samples, groups={len(val_groups)}")
    print(f"Device: {device}, AMP: {use_amp}")
    print(f"Output: {out_dir}")
    print(flush=True)

    for epoch in range(start_epoch, args.epochs + 1):
        train_loss, train_mae_kg = _train_one_epoch(
            model,
            train_loader,
            optimizer,
            scaler,
            device,
            use_amp,
            args.grad_clip,
            desc=f"Epoch {epoch}/{args.epochs}",
        )
        train_mae_g = train_mae_kg * 1000
        scheduler.step()
        if device.type == "cuda":
            torch.cuda.empty_cache()
        metrics = evaluate(model, val_loader, device, use_amp=use_amp)
        lr_now = scheduler.get_last_lr()[0]
        print(
            f"[{epoch:03d}/{args.epochs}] "
            f"train_loss={train_loss:.6f} train_mae={train_mae_g:.1f}g | "
            f"val_loss={metrics['val_loss']:.6f} val_mae={metrics['val_mae_g']:.1f}g | "
            f"lr={lr_now:.2e}"
        )

        improved = metrics["val_mae_g"] < best_mae_g
        if improved:
            best_mae_g = metrics["val_mae_g"]
        # Write last.pt only after best_mae_g is current, so resuming from it
        # compares later epochs against the true best.
        save_checkpoint(out_dir / "last.pt", epoch, model, optimizer, scheduler, metrics, best_mae_g, cfg)
        if improved:
            save_checkpoint(out_dir / "best.pt", epoch, model, optimizer, scheduler, metrics, best_mae_g, cfg)
            print(f" -> New best: {best_mae_g:.1f}g")
            epochs_no_improve = 0
        elif args.early_stop_patience > 0:
            epochs_no_improve += 1
            if epochs_no_improve >= args.early_stop_patience:
                print(f"Early stop: no val MAE improvement for {args.early_stop_patience} epochs (best={best_mae_g:.1f}g)")
                break

    print(f"\nDone. Best val MAE: {best_mae_g:.1f}g")
    print(f"Run dir: {out_dir}")
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported
    # (e.g. by code that only needs load_model_from_checkpoint).
    main()