464 lines
17 KiB
Python
464 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Train a DGCNN (Dynamic Graph CNN) regressor to predict fish weight from 3D point clouds.
|
|
|
|
DGCNN uses EdgeConv layers over dynamic k-NN graphs. It often works well with limited
data because its architecture is simpler than Point Transformer's, while it captures
local geometry better than vanilla PointNet.
|
|
|
|
Model:
|
|
- 4 EdgeConv layers (k-NN, k=20)
|
|
- Concat features -> global max+avg pool -> MLP -> 1 scalar (kg)
|
|
|
|
Uses the same FishWeightDataset and P0 augmentations as Point Transformer.
|
|
Typically needs lr=1e-3 and batch_size=32 (similar to PointNet++).
|
|
|
|
Example:
|
|
python weight_estimator/train_dgcnn_weight_estimator.py \\
|
|
--index-json weight_estimator/dataset_index.json \\
|
|
--epochs 400 --batch-size 32 --lr 1e-3
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from dataclasses import asdict, dataclass
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import DataLoader, Subset
|
|
from tqdm import tqdm
|
|
|
|
# Directory containing this script, and the repository root one level above it.
WEIGHT_EST_DIR = Path(__file__).resolve().parent
REPO_ROOT = WEIGHT_EST_DIR.parent

# Prepend both directories to sys.path (skipping ones already present) so that
# `weight_estimator.*` and sibling modules import when this file runs as a script.
# Processing REPO_ROOT first means WEIGHT_EST_DIR ends up at the very front.
for _entry in map(str, (REPO_ROOT, WEIGHT_EST_DIR)):
    if _entry in sys.path:
        continue
    sys.path.insert(0, _entry)
|
|
|
|
from weight_estimator.dataloader import FishWeightDataset, FishWeightSample
|
|
from dgcnn_weight_model import DGCNNWeightRegressor
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@dataclass
class TrainConfig:
    """Hyperparameters and paths for one DGCNN training run.

    ``main`` serializes an instance with ``asdict`` into ``config.json`` in the run
    directory. ``save_checkpoint`` stores it in every checkpoint under ``"config"``.
    ``load_model_from_checkpoint`` reads ``k``, ``emb_dims`` and ``dropout`` back from
    that dict, so renaming those fields breaks loading of old checkpoints.
    """

    index_json: str  # resolved path to the dataset index JSON
    data_root: str | None  # optional data root forwarded to FishWeightDataset
    num_points: int  # points per cloud requested from FishWeightDataset
    batch_size: int  # DataLoader batch size (train and val)
    epochs: int  # total epochs; also T_max of the cosine LR schedule
    lr: float  # AdamW learning rate
    weight_decay: float  # AdamW weight decay
    val_ratio: float  # fraction of *groups* (not samples) held out for validation
    seed: int  # seeds the RNGs, the group split and the datasets
    num_workers: int  # DataLoader worker processes
    amp: bool  # True only if mixed precision was actually enabled (CUDA and not --no-amp)
    grad_clip: float  # max gradient norm; <= 0 disables clipping
    early_stop_patience: int  # epochs without val-MAE improvement before stopping; 0 disables
    k: int  # k-NN neighbors for EdgeConv
    emb_dims: int  # embedding dimension passed to DGCNNWeightRegressor
    dropout: float  # dropout rate passed to DGCNNWeightRegressor
    # Training-time augmentations forwarded to FishWeightDataset (validation disables them).
    random_rotate: bool
    random_translate: bool
    translate_range: float
    jitter_std: float
    jitter_clip: float
    point_dropout_prob: float
    point_dropout_ratio: float
    local_occlusion_prob: float
    local_occlusion_ratio: float
    out_dir: str  # run directory holding config.json, last.pt and best.pt
    backbone: str = "dgcnn"  # model-family tag recorded alongside the config
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Utilities
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def resolve_path(path: str) -> Path:
    """Turn a user-supplied path string into a concrete ``Path``.

    ``~`` is expanded first. An absolute result is returned as-is (not resolved).
    A relative path is tried, in order, against the current working directory, this
    script's directory and the repository root, and the first candidate that exists
    is returned (resolved). If none exists, the path resolved against the current
    working directory is returned.
    """
    raw = Path(path).expanduser()
    if raw.is_absolute():
        return raw

    search_roots = (Path.cwd(), WEIGHT_EST_DIR, REPO_ROOT)
    existing = next(
        (
            candidate
            for candidate in ((root / raw).resolve() for root in search_roots)
            if candidate.exists()
        ),
        None,
    )
    if existing is not None:
        return existing
    return (Path.cwd() / raw).resolve()
|
|
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed every global RNG this script may draw from, for reproducible runs.

    Seeds PyTorch (CPU and, when available, all CUDA devices), NumPy's legacy global
    RNG, and Python's built-in ``random`` module. ``random`` was previously left
    unseeded, so any code path drawing from it made runs non-reproducible.

    Args:
        seed: Integer seed applied to all generators.
    """
    import random  # local import: the module's top-level imports do not include it

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
def _group_id_from_item(item: Dict[str, Any]) -> str:
|
|
fish_id = str(item.get("fish_id", "")).strip()
|
|
if fish_id:
|
|
return fish_id
|
|
sample_id = str(item.get("sample_id", "")).strip()
|
|
if sample_id:
|
|
return sample_id
|
|
ply = str(item.get("ply", "")).strip()
|
|
if ply:
|
|
p = Path(ply)
|
|
if p.parent.name == "cloud":
|
|
return p.parent.parent.name
|
|
return p.parent.name
|
|
return "__unknown_group__"
|
|
|
|
|
|
def split_indices_by_group_id(
    ds: FishWeightDataset, val_ratio: float, seed: int,
) -> Tuple[List[int], List[int], List[str], List[str]]:
    """Split dataset indices into train/val so that no group appears in both.

    Each entry of ``ds.items`` is mapped to a group with ``_group_id_from_item``.
    The sorted group ids are shuffled with ``np.random.default_rng(seed)``, so the
    split is deterministic for a given seed. The first ``round(n_groups * val_ratio)``
    groups (at least one) become validation groups. When there are two or more
    groups, at least one group is always kept for training. Previously, a high
    ``val_ratio`` could move every group to validation and silently leave training
    empty.

    Args:
        ds: Dataset exposing an ``items`` sequence of per-sample dicts.
        val_ratio: Fraction of groups to hold out for validation.
        seed: Seed for the group shuffle.

    Returns:
        ``(train_idx, val_idx, train_groups, val_groups)``: dataset indices for each
        split and the group ids assigned to each split.
    """
    group_to_indices: Dict[str, List[int]] = {}
    for i, item in enumerate(ds.items):
        group_to_indices.setdefault(_group_id_from_item(item), []).append(i)

    group_ids = sorted(group_to_indices)
    rng = np.random.default_rng(seed)
    rng.shuffle(group_ids)

    n_groups = len(group_ids)
    n_val = max(1, int(round(n_groups * val_ratio)))
    if n_groups > 1:
        # Never let the validation split swallow every group.
        n_val = min(n_val, n_groups - 1)
    val_ids = set(group_ids[:n_val])

    train_idx: List[int] = []
    val_idx: List[int] = []
    train_groups: List[str] = []
    val_groups: List[str] = []
    for gid, indices in group_to_indices.items():
        if gid in val_ids:
            val_idx.extend(indices)
            val_groups.append(gid)
        else:
            train_idx.extend(indices)
            train_groups.append(gid)
    return train_idx, val_idx, train_groups, val_groups
|
|
|
|
|
|
def collate_fn(batch: List[FishWeightSample]) -> Tuple[torch.Tensor, torch.Tensor, List[Dict[str, Any]]]:
    """Merge a list of samples into ``(points, targets, metas)``.

    ``points`` and ``targets`` are the samples' ``.points`` / ``.target`` tensors
    stacked along a new leading batch dimension. ``metas`` is the plain list of
    each sample's ``.meta``, in batch order.
    """
    point_tensors = [sample.points for sample in batch]
    target_tensors = [sample.target for sample in batch]
    return (
        torch.stack(point_tensors),
        torch.stack(target_tensors),
        [sample.meta for sample in batch],
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Evaluation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    use_amp: bool = False,
) -> Dict[str, float]:
    """Compute validation loss and mean absolute error of ``model`` over ``loader``.

    The model is put in eval mode and left there. Each batch ``(x, y, meta)`` is moved
    to ``device``. Autocast is enabled only when ``use_amp`` is set and ``device`` is
    CUDA.

    Returns:
        Dict with
        - ``val_loss``: summed smooth-L1 loss divided by the number of samples;
        - ``val_mae_kg``: summed absolute error divided by the number of samples;
        - ``val_mae_g``: ``val_mae_kg * 1000``.

        All three are NaN when the loader yields no samples. Previously an empty
        loader reported an MAE of 0.0, a fake perfect score that would be recorded
        as the best model.
    """
    model.eval()
    amp_enabled = use_amp and device.type == "cuda"
    total_loss = 0.0
    total_ae = 0.0
    n_samples = 0
    for x, y, _meta in loader:
        x = x.to(device)
        y = y.to(device)
        batch_size = x.shape[0]
        with torch.amp.autocast("cuda", enabled=amp_enabled):
            pred = model(x)
            loss = F.smooth_l1_loss(pred, y, reduction="sum")
        total_loss += loss.item()
        total_ae += (pred - y).abs().sum().item()
        n_samples += batch_size

    if n_samples == 0:
        undefined = float("nan")
        return {"val_loss": undefined, "val_mae_kg": undefined, "val_mae_g": undefined}

    mae_kg = total_ae / n_samples
    return {
        "val_loss": total_loss / n_samples,
        "val_mae_kg": mae_kg,
        "val_mae_g": mae_kg * 1000.0,
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Checkpoint helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def save_checkpoint(
    path: Path,
    epoch: int,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    metrics: Dict[str, float],
    best_mae_g: float,
    cfg: TrainConfig,
) -> None:
    """Serialize the full training state to ``path`` atomically.

    The payload contains ``epoch``, ``model_state``, ``optimizer_state``,
    ``scheduler_state``, ``metrics``, ``best_mae_g`` and ``config`` (``asdict(cfg)``).

    It is first written to a sibling ``<name>.tmp`` file and then moved over ``path``
    with ``Path.replace``. An interrupted save (crash, Ctrl-C, full disk) therefore
    never leaves a truncated file in place of the previous checkpoint; ``last.pt`` is
    overwritten every epoch and is the resume point. The temp file is removed if
    the save fails.

    Args:
        path: Destination file (``str`` is also accepted).
        epoch: Epoch number just completed.
        model, optimizer, scheduler: Objects whose ``state_dict()`` is stored.
        metrics: Validation metrics for this epoch.
        best_mae_g: Best validation MAE (grams) seen so far.
        cfg: Run configuration dataclass.
    """
    path = Path(path)
    payload = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "metrics": metrics,
        "best_mae_g": best_mae_g,
        "config": asdict(cfg),
    }
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(payload, tmp_path)
        tmp_path.replace(path)  # atomic rename (os.replace) on the same filesystem
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
|
|
|
|
|
|
def load_model_from_checkpoint(ckpt_path: str | Path, device: torch.device) -> nn.Module:
    """Rebuild a trained DGCNNWeightRegressor from a checkpoint, ready for inference.

    The model is constructed with ``k`` / ``emb_dims`` / ``dropout`` from the
    checkpoint's ``"config"`` dict (defaults 20 / 1024 / 0.5 when absent). Weights are
    taken from ``"model_state"``, else ``"model_state_dict"``, else the checkpoint
    object itself (a bare state dict), and loaded with ``strict=True``. The returned
    model is on ``device`` and in eval mode.

    NOTE: ``weights_only=False`` unpickles arbitrary objects; only load trusted files.

    Raises:
        FileNotFoundError: if ``ckpt_path`` does not exist.
    """
    resolved = Path(ckpt_path).expanduser().resolve()
    if not resolved.exists():
        raise FileNotFoundError(f"Checkpoint not found: {resolved}")

    try:
        checkpoint = torch.load(str(resolved), map_location=device, weights_only=False)
    except TypeError:
        # Older torch versions do not accept the weights_only keyword.
        checkpoint = torch.load(str(resolved), map_location=device)

    hparams = checkpoint.get("config", {}) or {}
    model = DGCNNWeightRegressor(
        k=hparams.get("k", 20),
        emb_dims=hparams.get("emb_dims", 1024),
        dropout=hparams.get("dropout", 0.5),
    ).to(device)

    state_dict = next(
        (checkpoint[key] for key in ("model_state", "model_state_dict") if checkpoint.get(key)),
        checkpoint,
    )
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for DGCNN weight-regressor training."""
    parser = argparse.ArgumentParser(description="DGCNN fish weight regressor training")
    parser.add_argument("--index-json", type=str, required=True,
                        help="Path to dataset_index.json")
    parser.add_argument("--data-root", type=str, default=None)
    parser.add_argument("--resume", type=str, default=None)
    parser.add_argument("--num-points", type=int, default=768)
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size (default: 32, DGCNN is lighter than Point Transformer)")
    parser.add_argument("--epochs", type=int, default=400)
    parser.add_argument("--lr", type=float, default=1e-3,
                        help="Learning rate (default: 1e-3, similar to PointNet++)")
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--val-ratio", type=float, default=0.2)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--no-amp", action="store_true")
    parser.add_argument("--grad-clip", type=float, default=1.0)
    parser.add_argument("--early-stop-patience", type=int, default=0)
    parser.add_argument("--k", type=int, default=20,
                        help="k-NN neighbors for EdgeConv (default: 20)")
    parser.add_argument("--emb-dims", type=int, default=1024,
                        help="Embedding dimension (default: 1024)")
    parser.add_argument("--dropout", type=float, default=0.5,
                        help="Dropout rate (default: 0.5)")
    parser.add_argument("--no-random-rotate", action="store_true")
    parser.add_argument("--no-random-translate", action="store_true")
    parser.add_argument("--translate-range", type=float, default=0.005)
    parser.add_argument("--jitter-std", type=float, default=0.001)
    parser.add_argument("--jitter-clip", type=float, default=0.003)
    parser.add_argument("--point-dropout-prob", type=float, default=0.5)
    parser.add_argument("--point-dropout-ratio", type=float, default=0.1)
    parser.add_argument("--local-occlusion-prob", type=float, default=0.3)
    parser.add_argument("--local-occlusion-ratio", type=float, default=0.15)
    parser.add_argument("--out-dir", type=str, default=None)
    return parser


def _train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scaler: torch.amp.GradScaler,
    device: torch.device,
    *,
    epoch: int,
    total_epochs: int,
    use_amp: bool,
    grad_clip: float,
) -> Tuple[float, float]:
    """Run one optimization pass over ``loader`` with a tqdm progress bar.

    Returns:
        ``(train_loss, train_mae_g)``: the mean per-batch smooth-L1 loss and the mean
        per-batch MAE multiplied by 1000. Both are ``inf`` if the loader yielded no
        batches.
    """
    model.train()
    running_loss: List[float] = []
    running_ae: List[float] = []

    pbar = tqdm(loader, desc=f"Epoch {epoch}/{total_epochs}", leave=False, unit="batch")
    for x, y, _meta in pbar:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=use_amp):
            pred = model(x)
            loss = F.smooth_l1_loss(pred, y)
        scaler.scale(loss).backward()
        if grad_clip > 0:
            # Unscale first so the clip threshold applies to the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)
        scaler.update()
        running_loss.append(loss.item())
        with torch.no_grad():
            running_ae.append((pred - y).abs().mean().item())
        pbar.set_postfix(loss=f"{loss.item():.5f}", avg=f"{np.mean(running_loss):.5f}", mae_g=f"{np.mean(running_ae)*1000:.1f}")

    train_loss = float(np.mean(running_loss)) if running_loss else float("inf")
    train_mae_g = float(np.mean(running_ae)) * 1000 if running_ae else float("inf")
    return train_loss, train_mae_g


def main() -> None:
    """Train a DGCNN fish-weight regressor from the command line.

    Steps:
    1. Parse CLI flags and write ``config.json`` into a fresh timestamped run
       directory.
    2. Split the dataset by group id so no group is in both train and val.
    3. Train with AdamW + cosine annealing (optional AMP and gradient clipping).
    4. After every epoch, save ``last.pt``, plus ``best.pt`` whenever validation MAE
       improves.

    ``--resume`` continues from a checkpoint written by this script; optional early
    stopping is controlled by ``--early-stop-patience``.

    Raises:
        SystemExit: if the resume checkpoint is missing, or the group split leaves no
            training samples.
        RuntimeError: if train and validation groups overlap.
    """
    args = _build_arg_parser().parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(args.seed)

    use_amp = (not args.no_amp) and (device.type == "cuda")

    out_dir_base = Path(args.out_dir or str(WEIGHT_EST_DIR / "runs")).expanduser()
    if not out_dir_base.is_absolute():
        out_dir_base = (Path.cwd() / out_dir_base).resolve()
    run_name = f"dgcnn_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    out_dir = out_dir_base / run_name
    out_dir.mkdir(parents=True, exist_ok=True)
    index_json_resolved = resolve_path(args.index_json)

    cfg = TrainConfig(
        index_json=str(index_json_resolved),
        data_root=args.data_root,
        num_points=args.num_points,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        val_ratio=args.val_ratio,
        seed=args.seed,
        num_workers=args.num_workers,
        amp=use_amp,
        grad_clip=args.grad_clip,
        early_stop_patience=args.early_stop_patience,
        k=args.k,
        emb_dims=args.emb_dims,
        dropout=args.dropout,
        random_rotate=not args.no_random_rotate,
        random_translate=not args.no_random_translate,
        translate_range=args.translate_range,
        jitter_std=args.jitter_std,
        jitter_clip=args.jitter_clip,
        point_dropout_prob=args.point_dropout_prob,
        point_dropout_ratio=args.point_dropout_ratio,
        local_occlusion_prob=args.local_occlusion_prob,
        local_occlusion_ratio=args.local_occlusion_ratio,
        out_dir=str(out_dir),
    )
    (out_dir / "config.json").write_text(json.dumps(asdict(cfg), indent=2), encoding="utf-8")

    # Augmented view of the data; only its training subset is used.
    full_ds = FishWeightDataset(
        index_json=str(index_json_resolved),
        data_root=args.data_root,
        num_points=args.num_points,
        xyz_scale=0.001,
        label_scale=0.001,
        train=True,
        random_rotate=not args.no_random_rotate,
        random_translate=not args.no_random_translate,
        translate_range=args.translate_range,
        jitter_std=args.jitter_std,
        jitter_clip=args.jitter_clip,
        point_dropout_prob=args.point_dropout_prob,
        point_dropout_ratio=args.point_dropout_ratio,
        local_occlusion_prob=args.local_occlusion_prob,
        local_occlusion_ratio=args.local_occlusion_ratio,
        seed=args.seed,
    )

    train_idx, val_idx, train_groups, val_groups = split_indices_by_group_id(full_ds, args.val_ratio, args.seed)
    if not train_idx:
        raise SystemExit(
            f"No training samples after the group split ({len(val_groups)} group(s) all "
            "assigned to validation); lower --val-ratio or add more data."
        )
    train_ds = Subset(full_ds, train_idx)

    # Separate, augmentation-free dataset instance for validation.
    val_base_ds = FishWeightDataset(
        index_json=str(index_json_resolved),
        data_root=args.data_root,
        num_points=args.num_points,
        xyz_scale=0.001,
        label_scale=0.001,
        train=False,
        random_rotate=False,
        random_translate=False,
        translate_range=0.0,
        jitter_std=0.0,
        jitter_clip=0.0,
        point_dropout_prob=0.0,
        point_dropout_ratio=0.0,
        local_occlusion_prob=0.0,
        local_occlusion_ratio=0.0,
        seed=args.seed,
    )
    val_ds = Subset(val_base_ds, val_idx)
    if set(train_groups).intersection(set(val_groups)):
        raise RuntimeError("Data leakage: train/val groups overlap")

    loader_kwargs = dict(
        num_workers=args.num_workers,
        pin_memory=(device.type == "cuda"),
        collate_fn=collate_fn,
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        # Drop the ragged final batch, but never when that would leave zero batches
        # (train set no larger than one batch).
        drop_last=len(train_ds) > args.batch_size,
        **loader_kwargs,
    )
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, drop_last=False, **loader_kwargs)

    model = DGCNNWeightRegressor(k=args.k, emb_dims=args.emb_dims, dropout=args.dropout).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    start_epoch = 1
    best_mae_g = float("inf")
    epochs_no_improve = 0

    if args.resume:
        ckpt_path = Path(args.resume).expanduser().resolve()
        if not ckpt_path.exists():
            raise SystemExit(f"Checkpoint not found: {ckpt_path}")
        ckpt = torch.load(str(ckpt_path), map_location=device)
        model.load_state_dict(ckpt["model_state"], strict=True)
        if "optimizer_state" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer_state"])
        if "scheduler_state" in ckpt:
            scheduler.load_state_dict(ckpt["scheduler_state"])
        start_epoch = ckpt.get("epoch", 0) + 1
        best_mae_g = ckpt.get("best_mae_g", float("inf"))
        print(f"Resumed from {ckpt_path} | epoch {start_epoch}, best_mae_g={best_mae_g:.2f}g")

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model: DGCNNWeightRegressor (k={args.k}, emb={args.emb_dims}, {n_params:,} params)")
    print(f"Train: {len(train_ds)} samples, groups={len(train_groups)}")
    print(f"Val: {len(val_ds)} samples, groups={len(val_groups)}")
    print(f"Device: {device}, AMP: {use_amp}")
    print(f"Output: {out_dir}")
    print(flush=True)

    for epoch in range(start_epoch, args.epochs + 1):
        train_loss, train_mae_g = _train_one_epoch(
            model,
            train_loader,
            optimizer,
            scaler,
            device,
            epoch=epoch,
            total_epochs=args.epochs,
            use_amp=use_amp,
            grad_clip=args.grad_clip,
        )
        scheduler.step()
        if device.type == "cuda":
            torch.cuda.empty_cache()

        metrics = evaluate(model, val_loader, device, use_amp=use_amp)
        lr_now = scheduler.get_last_lr()[0]

        print(
            f"[{epoch:03d}/{args.epochs}] "
            f"train_loss={train_loss:.6f} train_mae={train_mae_g:.1f}g | "
            f"val_loss={metrics['val_loss']:.6f} val_mae={metrics['val_mae_g']:.1f}g | "
            f"lr={lr_now:.2e}"
        )

        improved = metrics["val_mae_g"] < best_mae_g
        if improved:
            best_mae_g = metrics["val_mae_g"]
        # Save last.pt only after best_mae_g is updated. Otherwise a resume from it
        # would restore a stale best and could overwrite best.pt with a worse model.
        save_checkpoint(out_dir / "last.pt", epoch, model, optimizer, scheduler, metrics, best_mae_g, cfg)

        if improved:
            save_checkpoint(out_dir / "best.pt", epoch, model, optimizer, scheduler, metrics, best_mae_g, cfg)
            print(f" -> New best: {best_mae_g:.1f}g")
            epochs_no_improve = 0
        elif args.early_stop_patience > 0:
            epochs_no_improve += 1
            if epochs_no_improve >= args.early_stop_patience:
                print(f"Early stop: no val MAE improvement for {args.early_stop_patience} epochs (best={best_mae_g:.1f}g)")
                break

    print(f"\nDone. Best val MAE: {best_mae_g:.1f}g")
    print(f"Run dir: {out_dir}")


if __name__ == "__main__":
    main()
|