459 lines
15 KiB
Python
Executable File
459 lines
15 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Train a PointNet++ regressor to predict fish weight from 3D point clouds.
|
|
|
|
Inputs:
|
|
- Dataset index JSON produced by: weight_estimator/dataset.py
|
|
(weights are in grams in the CSV; dataloader scales target by /1000 -> kg)
|
|
|
|
Preprocessing is handled by weight_estimator/dataloader.py:
|
|
- XYZ scaled by 0.001
|
|
- Centered to origin
|
|
- Sample fixed N points (default 768)
|
|
- Random rotation (train only)
|
|
|
|
Model:
|
|
- PointNet++ SSG or MSG backbone (--model ssg or msg)
|
|
- Regression head -> 1 scalar (kg)
|
|
|
|
Example:
|
|
python weight_estimator/train_pointnet_weigth_estimator.py \
|
|
--index-json weight_estimator/dataset_index.json \
|
|
--model msg --epochs 200 --batch-size 32 --lr 1e-3
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import math
|
|
import os
|
|
from dataclasses import asdict, dataclass
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Dict, List, Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import Subset
|
|
from tqdm import tqdm
|
|
|
|
# Ensure the repo root is on sys.path so `weight_estimator.*` imports work no matter
# which directory the script is launched from.
import sys

# parents[1] of this file's resolved path, i.e. the directory above the one that
# contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    # Prepend so the checkout takes precedence over later sys.path entries.
    sys.path.insert(0, str(REPO_ROOT))
|
|
|
|
# Local dataloader
|
|
from weight_estimator.dataloader import FishWeightDataset
|
|
|
|
|
|
def _add_pointnet_paths() -> Path:
|
|
"""
|
|
Add local PointNet++ library paths (weight_estimator/Pointnet_Pointnet2_pytorch).
|
|
Returns the base directory.
|
|
"""
|
|
base = Path(__file__).parent / "Pointnet_Pointnet2_pytorch"
|
|
models_dir = base / "models"
|
|
if not base.exists():
|
|
raise FileNotFoundError(f"PointNet++ library not found at: {base}")
|
|
import sys
|
|
|
|
if str(base) not in sys.path:
|
|
sys.path.insert(0, str(base))
|
|
if str(models_dir) not in sys.path:
|
|
sys.path.insert(0, str(models_dir))
|
|
return base
|
|
|
|
|
|
@dataclass
class TrainConfig:
    """
    Settings of a single training run.

    Built from the command-line arguments in ``main``, written to
    ``config.json`` in the run directory and embedded in every checkpoint
    under the ``"config"`` key.
    """

    index_json: str  # path to dataset_index.json (--index-json)
    data_root: str | None  # optional data-root override (--data-root); None if not given
    model: str  # PointNet++ backbone: "ssg" or "msg" (--model)
    num_points: int  # number of points sampled per cloud (--num-points)
    batch_size: int  # batch size used by both data loaders (--batch-size)
    epochs: int  # number of training epochs; also the cosine schedule's T_max (--epochs)
    lr: float  # Adam learning rate (--lr)
    weight_decay: float  # Adam weight decay (--weight-decay)
    val_ratio: float  # share of sample_ids held out for validation (--val-ratio)
    seed: int  # seed for torch, numpy and the train/val split (--seed)
    num_workers: int  # DataLoader worker count (--num-workers)
    amp: bool  # True unless --no-amp was passed; AMP is only active on CUDA
    out_dir: str  # resolved per-run output directory (<--out-dir>/<timestamp>)
|
|
|
|
|
|
class PointNet2RegressorSSG(nn.Module):
    """
    PointNet++ single-scale-grouping (SSG) backbone with a scalar regression head.

    Follows pointnet2_cls_ssg.py, except that the classifier's log_softmax is
    removed and the final linear layer emits one value per point cloud.
    """

    def __init__(self, normal_channel: bool = False):
        """
        Args:
            normal_channel: when True the input carries 6 channels
                (xyz + normals); otherwise only the 3 xyz channels.
        """
        super().__init__()
        _add_pointnet_paths()
        from pointnet2_utils import PointNetSetAbstraction  # type: ignore

        self.normal_channel = normal_channel
        first_in_channel = 6 if normal_channel else 3

        # Same SA settings as pointnet2_cls_ssg. Attribute names and creation
        # order are kept stable so state_dict keys and seeded init are unchanged.
        self.sa1 = PointNetSetAbstraction(
            npoint=512,
            radius=0.2,
            nsample=32,
            in_channel=first_in_channel,
            mlp=[64, 64, 128],
            group_all=False,
        )
        self.sa2 = PointNetSetAbstraction(
            npoint=128,
            radius=0.4,
            nsample=64,
            in_channel=128 + 3,
            mlp=[128, 128, 256],
            group_all=False,
        )
        self.sa3 = PointNetSetAbstraction(
            npoint=None,
            radius=None,
            nsample=None,
            in_channel=256 + 3,
            mlp=[256, 512, 1024],
            group_all=True,
        )

        # Regression head: 1024 -> 512 -> 256 -> 1
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, xyz: torch.Tensor) -> torch.Tensor:
        """
        xyz: (B, 3, N) if normal_channel=False
             (B, 6, N) if normal_channel=True (xyz+normals)
        returns: (B,) predicted weight in kg
        """
        # Unpacking three values also rejects inputs that are not 3-dimensional.
        batch_size, _channels, _num_points = xyz.shape

        if self.normal_channel:
            coords, normals = xyz[:, :3, :], xyz[:, 3:, :]
        else:
            coords, normals = xyz, None

        l1_xyz, l1_feats = self.sa1(coords, normals)
        l2_xyz, l2_feats = self.sa2(l1_xyz, l1_feats)
        _l3_xyz, l3_feats = self.sa3(l2_xyz, l2_feats)

        hidden = l3_feats.view(batch_size, 1024)
        hidden = self.fc1(hidden)
        hidden = self.drop1(F.relu(self.bn1(hidden)))
        hidden = self.fc2(hidden)
        hidden = self.drop2(F.relu(self.bn2(hidden)))
        return self.fc3(hidden).squeeze(-1)
|
|
|
|
|
|
class PointNet2RegressorMSG(nn.Module):
    """
    PointNet++ multi-scale-grouping (MSG) backbone with a scalar regression head.

    Follows pointnet2_cls_msg.py: the first two stages use
    PointNetSetAbstractionMsg (several radii per stage), the last stage is a
    group-all PointNetSetAbstraction, and the classifier output is replaced by
    a single linear unit.
    """

    def __init__(self, normal_channel: bool = False):
        """
        Args:
            normal_channel: when True the input carries 6 channels
                (xyz + normals); otherwise only the 3 xyz channels.
        """
        super().__init__()
        _add_pointnet_paths()
        from pointnet2_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction  # type: ignore

        self.normal_channel = normal_channel
        # Extra per-point feature channels besides xyz (normals or nothing).
        extra_channels = 3 if normal_channel else 0

        # Same SA settings as pointnet2_cls_msg. Attribute names and creation
        # order are kept stable so state_dict keys and seeded init are unchanged.
        self.sa1 = PointNetSetAbstractionMsg(
            512,
            [0.1, 0.2, 0.4],
            [16, 32, 128],
            extra_channels,
            [[32, 32, 64], [64, 64, 128], [64, 96, 128]],
        )
        # 320 presumably = 64 + 128 + 128, the widths of sa1's three scale branches.
        self.sa2 = PointNetSetAbstractionMsg(
            128,
            [0.2, 0.4, 0.8],
            [32, 64, 128],
            320,
            [[64, 64, 128], [128, 128, 256], [128, 128, 256]],
        )
        # 640 presumably = 128 + 256 + 256 from sa2's branches.
        self.sa3 = PointNetSetAbstraction(
            npoint=None,
            radius=None,
            nsample=None,
            in_channel=640 + 3,
            mlp=[256, 512, 1024],
            group_all=True,
        )

        # Regression head: 1024 -> 512 -> 256 -> 1
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, xyz: torch.Tensor) -> torch.Tensor:
        """
        xyz: (B, 3, N) if normal_channel=False
             (B, 6, N) if normal_channel=True (xyz+normals)
        returns: (B,) predicted weight in kg
        """
        # Unpacking three values also rejects inputs that are not 3-dimensional.
        batch_size, _channels, _num_points = xyz.shape

        if self.normal_channel:
            coords, normals = xyz[:, :3, :], xyz[:, 3:, :]
        else:
            coords, normals = xyz, None

        l1_xyz, l1_feats = self.sa1(coords, normals)
        l2_xyz, l2_feats = self.sa2(l1_xyz, l1_feats)
        _l3_xyz, l3_feats = self.sa3(l2_xyz, l2_feats)

        hidden = l3_feats.view(batch_size, 1024)
        hidden = self.fc1(hidden)
        hidden = self.drop1(F.relu(self.bn1(hidden)))
        hidden = self.fc2(hidden)
        hidden = self.drop2(F.relu(self.bn2(hidden)))
        return self.fc3(hidden).squeeze(-1)
|
|
|
|
|
|
def build_model(model_type: str = "ssg", normal_channel: bool = False) -> nn.Module:
    """
    Build a PointNet++ regressor.

    Args:
        model_type: ``"ssg"`` (single-scale) or ``"msg"`` (multi-scale).
            Matched case-insensitively; surrounding whitespace is ignored.
        normal_channel: forwarded to the model constructor.

    Returns:
        A freshly constructed PointNet2RegressorSSG or PointNet2RegressorMSG.

    Raises:
        ValueError: if ``model_type`` names neither backbone. Falling back to
            SSG silently would hide typos and load the wrong architecture.
    """
    kind = str(model_type).strip().lower()
    if kind == "msg":
        return PointNet2RegressorMSG(normal_channel=normal_channel)
    if kind == "ssg":
        return PointNet2RegressorSSG(normal_channel=normal_channel)
    raise ValueError(f"unknown model_type {model_type!r}; expected 'ssg' or 'msg'")
|
|
|
|
|
|
def load_model_from_checkpoint(
    ckpt_path: str | Path,
    device: torch.device,
    map_location: str | torch.device = "cuda",
) -> nn.Module:
    """
    Load a PointNet++ regressor from a checkpoint and put it in eval mode.

    The backbone is chosen from ``config["model"]`` stored in the checkpoint
    if present; otherwise SSG is used. Weights are read from the
    ``"model_state"`` key, then ``"model_state_dict"``, and finally the
    checkpoint object itself is treated as the state dict.

    Args:
        ckpt_path: checkpoint file; ``~`` is expanded and the path resolved.
        device: device the returned model is moved to.
        map_location: passed to ``torch.load``. A CUDA target (including the
            default) is replaced by ``"cpu"`` when CUDA is not available, so
            GPU-trained checkpoints can be loaded on CPU-only machines.

    Returns:
        The model with weights loaded (strictly) and ``eval()`` applied.

    Raises:
        FileNotFoundError: if ``ckpt_path`` does not exist.
    """
    ckpt_path = Path(ckpt_path).expanduser().resolve()
    if not ckpt_path.exists():
        raise FileNotFoundError(f"checkpoint not found: {ckpt_path}")

    # torch.load raises when asked to map storages to CUDA on a host without it.
    if str(map_location).startswith("cuda") and not torch.cuda.is_available():
        map_location = "cpu"

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources (or pass weights_only=True where the torch version allows).
    ckpt = torch.load(str(ckpt_path), map_location=map_location)

    config = ckpt.get("config", {}) or {}
    model_type = str(config.get("model", "ssg")).lower()
    model = build_model(model_type=model_type, normal_channel=False).to(device)

    state = ckpt.get("model_state", None) or ckpt.get("model_state_dict", None) or ckpt
    model.load_state_dict(state, strict=True)
    model.eval()
    return model
|
|
|
|
|
|
def split_indices_by_sample_id(ds: FishWeightDataset, val_ratio: float, seed: int) -> Tuple[List[int], List[int]]:
    """
    Split dataset indices into train/val so that no sample_id spans both sides.

    Items are grouped by ``str(item.get("sample_id", ""))``. The sorted group
    ids are shuffled with a generator seeded by ``seed`` and the first
    ``max(1, round(n_groups * val_ratio))`` of them form the validation set.

    Args:
        ds: dataset exposing ``items``, a sequence of dict-like records.
        val_ratio: share of sample_ids (not of items) assigned to validation.
        seed: seed of the shuffle, making the split reproducible.

    Returns:
        ``(train_idx, val_idx)``; within each list, indices are ordered by
        first appearance of their group and then by position in ``ds.items``.
    """
    groups: Dict[str, List[int]] = {}
    for position, record in enumerate(ds.items):
        group_id = str(record.get("sample_id", ""))
        if group_id not in groups:
            groups[group_id] = []
        groups[group_id].append(position)

    # Sort before shuffling so the result depends only on the seed, not on item order.
    shuffled_ids = sorted(groups)
    np.random.default_rng(seed).shuffle(shuffled_ids)

    val_group_count = max(1, int(round(len(shuffled_ids) * val_ratio)))
    held_out = set(shuffled_ids[:val_group_count])

    train_idx: List[int] = [
        position
        for group_id, members in groups.items()
        if group_id not in held_out
        for position in members
    ]
    val_idx: List[int] = [
        position
        for group_id, members in groups.items()
        if group_id in held_out
        for position in members
    ]
    return train_idx, val_idx
|
|
|
|
|
|
@torch.no_grad()
def evaluate(model: nn.Module, loader, device: torch.device) -> Dict[str, float]:
    """
    Compute validation loss and mean absolute error over ``loader``.

    Both metrics are averaged per sample across the whole loader. Averaging
    per-batch means instead would over-weight a smaller final batch, which
    occurs whenever the loader does not drop the last batch.

    Args:
        model: regressor called with the transposed points; it is switched to
            eval mode here and not switched back.
        loader: iterable yielding ``(points, targets, meta)`` batches.
        device: device the batch tensors are moved to.

    Returns:
        Dict with ``"val_loss"`` (smooth L1), ``"val_mae_kg"`` and
        ``"val_mae_g"`` (= ``val_mae_kg * 1000``). All three are ``inf`` when
        the loader yields no samples.
    """
    model.eval()
    loss_total = 0.0
    abs_err_total = 0.0
    sample_count = 0

    for x, y, _meta in loader:
        x = x.to(device)  # (B, N, 3)
        y = y.to(device)  # (B,)
        x = x.transpose(1, 2).contiguous()  # (B, 3, N)

        pred = model(x)  # (B,)

        # Sum over the batch; the division by the sample count happens once at the end.
        loss_total += F.smooth_l1_loss(pred, y, reduction="sum").item()
        abs_err_total += (pred - y).abs().sum().item()
        sample_count += y.numel()

    if sample_count == 0:
        return {
            "val_loss": float("inf"),
            "val_mae_kg": float("inf"),
            "val_mae_g": float("inf"),
        }

    mae_kg = abs_err_total / sample_count
    return {
        "val_loss": loss_total / sample_count,
        "val_mae_kg": mae_kg,
        "val_mae_g": mae_kg * 1000.0,
    }
|
|
|
|
|
|
def _collate_samples(batch):
    """
    Stack a list of dataset samples into one batch.

    Kept at module level (rather than nested inside ``main``) so DataLoader
    worker processes can pickle it under the "spawn" start method.

    Args:
        batch: samples exposing ``points``, ``target`` and ``meta`` attributes
            (FishWeightSample objects from the underlying dataset).

    Returns:
        ``(points, targets, metas)``: two tensors stacked along a new leading
        dimension and the list of per-sample metadata.
    """
    points = torch.stack([sample.points for sample in batch], dim=0)
    targets = torch.stack([sample.target for sample in batch], dim=0)
    metas = [sample.meta for sample in batch]
    return points, targets, metas


def main() -> None:
    """
    Command-line entry point: train a PointNet++ weight regressor.

    Steps: parse arguments, create ``<out-dir>/<timestamp>/`` and write
    ``config.json``, build train/val subsets split by sample_id (augmentation
    only on the training dataset), train with Adam + cosine annealing and
    optional CUDA AMP, and after every epoch evaluate and save ``last.pt``
    plus ``best.pt`` whenever the validation MAE improves.

    Raises:
        ValueError: if the split leaves no full training batch or no
            validation samples.
    """
    parser = argparse.ArgumentParser("PointNet++ fish weight regressor training")
    parser.add_argument("--index-json", type=str, required=True, help="Path to dataset_index.json")
    parser.add_argument("--data-root", type=str, default=None, help="Override data root (if JSON uses relative paths)")
    parser.add_argument("--model", type=str, default="ssg", choices=["ssg", "msg"],
                        help="PointNet++ backbone: ssg (single-scale) or msg (multi-scale). Default: ssg")
    parser.add_argument("--num-points", type=int, default=768, help="Number of points to sample (default: 768)")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size (default: 32)")
    parser.add_argument("--epochs", type=int, default=400, help="Epochs (default: 400)")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate (default: 1e-3)")
    parser.add_argument("--weight-decay", type=float, default=1e-4, help="Weight decay (default: 1e-4)")
    parser.add_argument("--val-ratio", type=float, default=0.2, help="Validation sample_id ratio (default: 0.2)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    parser.add_argument("--num-workers", type=int, default=4, help="DataLoader workers (default: 4)")
    parser.add_argument("--no-amp", action="store_true", help="Disable AMP")
    parser.add_argument("--out-dir", type=str, default="weight_estimator/runs", help="Output base dir")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # One timestamped directory per run.
    run_name = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_dir = Path(args.out_dir).expanduser().resolve() / run_name
    out_dir.mkdir(parents=True, exist_ok=True)

    cfg = TrainConfig(
        index_json=args.index_json,
        data_root=args.data_root,
        model=args.model,
        num_points=args.num_points,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        val_ratio=args.val_ratio,
        seed=args.seed,
        num_workers=args.num_workers,
        amp=not args.no_amp,
        out_dir=str(out_dir),
    )
    (out_dir / "config.json").write_text(json.dumps(asdict(cfg), indent=2), encoding="utf-8")

    # Arguments shared by the augmented (train) and plain (val) dataset instances.
    dataset_kwargs = dict(
        index_json=args.index_json,
        data_root=args.data_root,
        num_points=args.num_points,
        xyz_scale=0.001,
        label_scale=0.001,  # g -> kg
        seed=args.seed,
    )

    # Build full dataset (train=True for augmentations); we will subset train/val.
    full_ds = FishWeightDataset(train=True, random_rotate=True, **dataset_kwargs)
    train_idx, val_idx = split_indices_by_sample_id(full_ds, val_ratio=args.val_ratio, seed=args.seed)
    train_ds = Subset(full_ds, train_idx)

    # Validation dataset: same underlying index, but disable augmentation
    val_base_ds = FishWeightDataset(train=False, random_rotate=False, **dataset_kwargs)
    val_ds = Subset(val_base_ds, val_idx)

    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=_collate_samples,
    )
    val_loader = torch.utils.data.DataLoader(
        val_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=_collate_samples,
    )

    # Fail early: with drop_last=True a training split smaller than one batch
    # yields no optimizer steps, and an empty validation split never produces
    # a finite metric (so best.pt would never be written).
    if len(train_loader) == 0:
        raise ValueError(
            f"training split has {len(train_ds)} samples, fewer than one batch of {args.batch_size}"
        )
    if len(val_ds) == 0:
        raise ValueError("validation split is empty")

    model = build_model(model_type=args.model, normal_channel=False).to(device)
    if args.model == "msg":
        print("Using PointNet++ MSG (multi-scale)")
    else:
        print("Using PointNet++ SSG (single-scale)")

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    use_amp = cfg.amp and device.type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    best_mae = float("inf")

    for epoch in range(1, args.epochs + 1):
        model.train()
        # Running sum/count instead of re-averaging a growing list every step.
        loss_sum = 0.0
        step_count = 0
        pbar = tqdm(train_loader, desc=f"epoch {epoch}/{args.epochs}", leave=False)
        for x, y, _meta in pbar:
            x = x.to(device)  # (B, N, 3)
            y = y.to(device)  # (B,)
            x = x.transpose(1, 2).contiguous()  # (B, 3, N)

            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
                pred = model(x)
                loss = F.smooth_l1_loss(pred, y)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            loss_sum += loss.item()
            step_count += 1
            pbar.set_postfix(loss=loss_sum / step_count)

        scheduler.step()

        metrics = evaluate(model, val_loader, device)
        lr_now = scheduler.get_last_lr()[0]
        train_loss = loss_sum / step_count
        print(
            f"[{epoch:03d}/{args.epochs}] "
            f"train_loss={train_loss:.6f} "
            f"val_loss={metrics['val_loss']:.6f} "
            f"val_mae={metrics['val_mae_g']:.2f}g "
            f"lr={lr_now:.2e}"
        )

        # Save last (includes optimizer and scheduler state)
        torch.save(
            {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "metrics": metrics,
                "config": asdict(cfg),
            },
            out_dir / "last.pt",
        )

        # Save best (lowest validation MAE so far)
        if metrics["val_mae_g"] < best_mae:
            best_mae = metrics["val_mae_g"]
            torch.save(
                {
                    "epoch": epoch,
                    "model_state": model.state_dict(),
                    "metrics": metrics,
                    "config": asdict(cfg),
                },
                out_dir / "best.pt",
            )

    print(f"Done. Best val MAE: {best_mae:.2f} g")
    print(f"Run dir: {out_dir}")
|
|
|
|
|
|
# Start training only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
|
|