Files
FishServer/FishMeasure/weight_estimator/train_pointnet_weigth_estimator.py
2026-04-08 19:32:23 +08:00

459 lines
15 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Train a PointNet++ regressor to predict fish weight from 3D point clouds.
Inputs:
- Dataset index JSON produced by: weight_estimator/dataset.py
(weights are in grams in the CSV; dataloader scales target by /1000 -> kg)
Preprocessing is handled by weight_estimator/dataloader.py:
- XYZ scaled by 0.001
- Centered to origin
- Sample fixed N points (default 768)
- Random rotation (train only)
Model:
- PointNet++ SSG or MSG backbone (--model ssg or msg)
- Regression head -> 1 scalar (kg)
Example:
python weight_estimator/train_pointnet_weigth_estimator.py \
--index-json weight_estimator/dataset_index.json \
--model msg --epochs 200 --batch-size 32 --lr 1e-3
"""
from __future__ import annotations
import argparse
import json
import math
import os
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset
from tqdm import tqdm
# Ensure repo root is on sys.path so `weight_estimator.*` imports work no matter cwd
import sys
# Directory one level above this file's folder; it must be importable so that
# `weight_estimator.*` resolves regardless of the current working directory.
REPO_ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(REPO_ROOT) for entry in sys.path):
    # Prepend (rather than append) so this checkout takes precedence.
    sys.path.insert(0, str(REPO_ROOT))
# Local dataloader
from weight_estimator.dataloader import FishWeightDataset
def _add_pointnet_paths() -> Path:
    """
    Make the local PointNet++ sources importable.

    Looks for a ``Pointnet_Pointnet2_pytorch`` directory next to this file and
    prepends both it and its ``models`` subdirectory to ``sys.path`` (entries
    already present are left alone).

    Returns:
        The library base directory.

    Raises:
        FileNotFoundError: if the library directory does not exist.
    """
    import sys

    base = Path(__file__).parent / "Pointnet_Pointnet2_pytorch"
    if not base.exists():
        raise FileNotFoundError(f"PointNet++ library not found at: {base}")

    # Each insert goes to index 0, so the later candidate ("models") ends up
    # ahead of the base directory on sys.path.
    for candidate in (base, base / "models"):
        entry = str(candidate)
        if entry not in sys.path:
            sys.path.insert(0, entry)
    return base
@dataclass
class TrainConfig:
    """
    Settings for one training run.

    The training entry point fills this from the CLI arguments and serialises
    it (via ``asdict``) into ``config.json`` and into every checkpoint.
    """

    # --- data ---
    index_json: str          # path to the dataset index JSON
    data_root: str | None    # optional override for the data root
    # --- model ---
    model: str               # backbone selector: "ssg" or "msg"
    num_points: int          # points sampled per cloud
    # --- optimisation ---
    batch_size: int
    epochs: int
    lr: float
    weight_decay: float
    # --- split / reproducibility ---
    val_ratio: float         # fraction of sample_ids held out for validation
    seed: int
    # --- runtime ---
    num_workers: int         # DataLoader worker processes
    amp: bool                # mixed precision requested
    out_dir: str             # resolved run output directory
class PointNet2RegressorSSG(nn.Module):
    """
    PointNet++ single-scale (SSG) backbone with a scalar regression head.

    Adapted from pointnet2_cls_ssg.py: the log_softmax is removed and the final
    linear layer produces one value per sample.
    """

    def __init__(self, normal_channel: bool = False):
        super().__init__()
        # The local PointNet++ library has to be on sys.path before importing it.
        _add_pointnet_paths()
        from pointnet2_utils import PointNetSetAbstraction  # type: ignore

        self.normal_channel = normal_channel
        first_in_channel = 6 if normal_channel else 3

        # Set-abstraction stack (same settings as pointnet2_cls_ssg).
        self.sa1 = PointNetSetAbstraction(
            npoint=512,
            radius=0.2,
            nsample=32,
            in_channel=first_in_channel,
            mlp=[64, 64, 128],
            group_all=False,
        )
        self.sa2 = PointNetSetAbstraction(
            npoint=128,
            radius=0.4,
            nsample=64,
            in_channel=128 + 3,
            mlp=[128, 128, 256],
            group_all=False,
        )
        self.sa3 = PointNetSetAbstraction(
            npoint=None,
            radius=None,
            nsample=None,
            in_channel=256 + 3,
            mlp=[256, 512, 1024],
            group_all=True,
        )

        # Regression head: 1024 -> 512 -> 256 -> 1.
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, xyz: torch.Tensor) -> torch.Tensor:
        """
        Run the network on a batch of point clouds.

        xyz: (B, 3, N) if normal_channel=False
             (B, 6, N) if normal_channel=True (xyz + normals)
        returns: (B,) predicted weight in kg
        """
        batch, _channels, _num_pts = xyz.shape
        if self.normal_channel:
            # First three channels are coordinates, the rest are extra features.
            coords, extra = xyz[:, :3, :], xyz[:, 3:, :]
        else:
            coords, extra = xyz, None

        l1_xyz, l1_feats = self.sa1(coords, extra)
        l2_xyz, l2_feats = self.sa2(l1_xyz, l1_feats)
        _, global_feats = self.sa3(l2_xyz, l2_feats)

        # Flatten the global feature to (B, 1024) for the fully connected head.
        h = global_feats.view(batch, 1024)

        h = self.fc1(h)
        h = self.bn1(h)
        h = F.relu(h)
        h = self.drop1(h)

        h = self.fc2(h)
        h = self.bn2(h)
        h = F.relu(h)
        h = self.drop2(h)

        # (B, 1) -> (B,)
        return self.fc3(h).squeeze(-1)
class PointNet2RegressorMSG(nn.Module):
    """
    PointNet++ multi-scale-grouping (MSG) backbone with a scalar regression head.

    Adapted from pointnet2_cls_msg.py; the first two stages use
    PointNetSetAbstractionMsg to group neighbours at several radii.
    """

    def __init__(self, normal_channel: bool = False):
        super().__init__()
        # The local PointNet++ library has to be on sys.path before importing it.
        _add_pointnet_paths()
        from pointnet2_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction  # type: ignore

        self.normal_channel = normal_channel
        first_in_channel = 3 if normal_channel else 0

        # Set-abstraction stack (same settings as pointnet2_cls_msg).
        self.sa1 = PointNetSetAbstractionMsg(
            512,
            [0.1, 0.2, 0.4],
            [16, 32, 128],
            first_in_channel,
            [[32, 32, 64], [64, 64, 128], [64, 96, 128]],
        )
        self.sa2 = PointNetSetAbstractionMsg(
            128,
            [0.2, 0.4, 0.8],
            [32, 64, 128],
            320,
            [[64, 64, 128], [128, 128, 256], [128, 128, 256]],
        )
        self.sa3 = PointNetSetAbstraction(
            npoint=None,
            radius=None,
            nsample=None,
            in_channel=640 + 3,
            mlp=[256, 512, 1024],
            group_all=True,
        )

        # Regression head: 1024 -> 512 -> 256 -> 1.
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, xyz: torch.Tensor) -> torch.Tensor:
        """
        Run the network on a batch of point clouds.

        xyz: (B, 3, N) if normal_channel=False
             (B, 6, N) if normal_channel=True (xyz + normals)
        returns: (B,) predicted weight in kg
        """
        batch, _channels, _num_pts = xyz.shape
        if self.normal_channel:
            # First three channels are coordinates, the rest are extra features.
            coords, extra = xyz[:, :3, :], xyz[:, 3:, :]
        else:
            coords, extra = xyz, None

        l1_xyz, l1_feats = self.sa1(coords, extra)
        l2_xyz, l2_feats = self.sa2(l1_xyz, l1_feats)
        _, global_feats = self.sa3(l2_xyz, l2_feats)

        # Flatten the global feature to (B, 1024) for the fully connected head.
        h = global_feats.view(batch, 1024)

        h = self.fc1(h)
        h = self.bn1(h)
        h = F.relu(h)
        h = self.drop1(h)

        h = self.fc2(h)
        h = self.bn2(h)
        h = F.relu(h)
        h = self.drop2(h)

        # (B, 1) -> (B,)
        return self.fc3(h).squeeze(-1)
def build_model(model_type: str = "ssg", normal_channel: bool = False) -> nn.Module:
    """
    Instantiate a PointNet++ regressor.

    ``model_type == "msg"`` selects the multi-scale model; any other value
    (including the default "ssg") selects the single-scale model.
    """
    regressor_cls = PointNet2RegressorMSG if model_type == "msg" else PointNet2RegressorSSG
    return regressor_cls(normal_channel=normal_channel)
def load_model_from_checkpoint(
    ckpt_path: str | Path,
    device: torch.device,
    map_location: str | torch.device = "cuda",
) -> nn.Module:
    """
    Load a PointNet++ regressor from a checkpoint file and return it in eval mode.

    The backbone type is taken from ``config["model"]`` inside the checkpoint
    when present; otherwise SSG is assumed. Weights are read from the
    ``model_state`` key, then ``model_state_dict``, and finally the checkpoint
    object itself is treated as a bare state dict.

    Args:
        ckpt_path: Path to the checkpoint (``~`` is expanded).
        device: Device the returned model is moved to.
        map_location: Where ``torch.load`` places the stored tensors. If a CUDA
            location is requested (including the default) but CUDA is not
            available, tensors are deserialised on the CPU instead so that the
            checkpoint can still be loaded on CPU-only hosts.

    Returns:
        The model with weights loaded (``strict=True``), on ``device``, in eval mode.

    Raises:
        FileNotFoundError: if ``ckpt_path`` does not exist.
    """
    ckpt_path = Path(ckpt_path).expanduser().resolve()
    if not ckpt_path.exists():
        raise FileNotFoundError(f"checkpoint not found: {ckpt_path}")

    # torch.load raises when asked to map storages onto CUDA on a machine
    # without CUDA; fall back to CPU and let `.to(device)` place the model.
    if isinstance(map_location, (str, torch.device)):
        wants_cuda = torch.device(map_location).type == "cuda"
        if wants_cuda and not torch.cuda.is_available():
            map_location = "cpu"

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    ckpt = torch.load(str(ckpt_path), map_location=map_location)

    config = ckpt.get("config", {}) or {}
    model_type = str(config.get("model", "ssg")).lower()
    model = build_model(model_type=model_type, normal_channel=False).to(device)

    state = ckpt.get("model_state", None) or ckpt.get("model_state_dict", None) or ckpt
    model.load_state_dict(state, strict=True)
    model.eval()
    return model
def split_indices_by_sample_id(ds: FishWeightDataset, val_ratio: float, seed: int) -> Tuple[List[int], List[int]]:
    """
    Split dataset indices into train/val so that no sample_id spans both sides.

    Items in ``ds.items`` are grouped by their ``"sample_id"`` (missing ids are
    grouped under the empty string). The sorted ids are shuffled with a
    generator seeded by ``seed`` and the first
    ``max(1, round(n_ids * val_ratio))`` ids form the validation set.

    Returns:
        ``(train_idx, val_idx)`` lists of indices into ``ds``.
    """
    groups: Dict[str, List[int]] = {}
    for idx, entry in enumerate(ds.items):
        key = str(entry.get("sample_id", ""))
        if key not in groups:
            groups[key] = []
        groups[key].append(idx)

    # Sort before shuffling so the split depends only on the seed, not on
    # the order in which ids first appear.
    ordered_ids = sorted(groups)
    np.random.default_rng(seed).shuffle(ordered_ids)

    n_val = max(1, int(round(len(ordered_ids) * val_ratio)))
    held_out = set(ordered_ids[:n_val])

    train_idx: List[int] = []
    val_idx: List[int] = []
    for key, members in groups.items():
        if key in held_out:
            val_idx.extend(members)
        else:
            train_idx.extend(members)
    return train_idx, val_idx
@torch.no_grad()
def evaluate(model: nn.Module, loader, device: torch.device) -> Dict[str, float]:
    """
    Compute validation loss and mean absolute error over ``loader``.

    Metrics are averaged per sample (sums divided by the total sample count),
    so a smaller final batch does not get more weight than full batches.

    Args:
        model: Regressor called with a (B, 3, N) tensor and returning (B,) predictions.
        loader: Iterable yielding ``(x, y, meta)`` with x as (B, N, 3) and y as (B,).
        device: Device the batches are moved to.

    Returns:
        Dict with ``val_loss`` (smooth L1), ``val_mae_kg`` and ``val_mae_g``
        (``val_mae_kg * 1000``). All values are ``inf`` when the loader yields
        no samples.
    """
    model.eval()
    loss_sum = 0.0
    abs_err_sum = 0.0
    n_samples = 0
    for x, y, _meta in loader:
        x = x.to(device)  # (B, N, 3)
        y = y.to(device)  # (B,)
        x = x.transpose(1, 2).contiguous()  # (B, 3, N)
        pred = model(x)  # (B,)
        # Accumulate per-sample sums; dividing once at the end yields the
        # true dataset mean even when batch sizes differ.
        loss_sum += F.smooth_l1_loss(pred, y, reduction="sum").item()
        abs_err_sum += (pred - y).abs().sum().item()
        n_samples += int(y.numel())

    if n_samples == 0:
        return {
            "val_loss": float("inf"),
            "val_mae_kg": float("inf"),
            "val_mae_g": float("inf"),
        }

    mae_kg = abs_err_sum / n_samples
    return {
        "val_loss": loss_sum / n_samples,
        "val_mae_kg": mae_kg,
        "val_mae_g": mae_kg * 1000.0,
    }
def _collate_samples(batch):
    """
    Collate dataset samples into ``(points, targets, metas)``.

    Points and targets are stacked along a new batch dimension; metas stay a
    plain list. Kept at module level (not nested inside ``main``) so that
    DataLoader worker processes started with the "spawn" method can pickle it.
    """
    xs = torch.stack([b.points for b in batch], dim=0)
    ys = torch.stack([b.target for b in batch], dim=0)
    metas = [b.meta for b in batch]
    return xs, ys, metas


def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line options of the training script."""
    parser = argparse.ArgumentParser("PointNet++ fish weight regressor training")
    parser.add_argument("--index-json", type=str, required=True, help="Path to dataset_index.json")
    parser.add_argument("--data-root", type=str, default=None, help="Override data root (if JSON uses relative paths)")
    parser.add_argument("--model", type=str, default="ssg", choices=["ssg", "msg"],
                        help="PointNet++ backbone: ssg (single-scale) or msg (multi-scale). Default: ssg")
    parser.add_argument("--num-points", type=int, default=768, help="Number of points to sample (default: 768)")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size (default: 32)")
    parser.add_argument("--epochs", type=int, default=400, help="Epochs (default: 400)")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate (default: 1e-3)")
    parser.add_argument("--weight-decay", type=float, default=1e-4, help="Weight decay (default: 1e-4)")
    parser.add_argument("--val-ratio", type=float, default=0.2, help="Validation sample_id ratio (default: 0.2)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    parser.add_argument("--num-workers", type=int, default=4, help="DataLoader workers (default: 4)")
    parser.add_argument("--no-amp", action="store_true", help="Disable AMP")
    parser.add_argument("--out-dir", type=str, default="weight_estimator/runs", help="Output base dir")
    return parser.parse_args()


def _train_one_epoch(model, loader, optimizer, scaler, device, use_amp: bool, desc: str) -> float:
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    running = []
    pbar = tqdm(loader, desc=desc, leave=False)
    for x, y, _meta in pbar:
        x = x.to(device)  # (B, N, 3)
        y = y.to(device)  # (B,)
        x = x.transpose(1, 2).contiguous()  # (B, 3, N)
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            pred = model(x)
            loss = F.smooth_l1_loss(pred, y)
        # The scaler is a pass-through when AMP is disabled.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running.append(loss.item())
        pbar.set_postfix(loss=float(np.mean(running)))
    return float(np.mean(running))


def main() -> None:
    """
    Train a PointNet++ weight regressor from the command line.

    Steps: parse arguments, seed torch/numpy, create a timestamped run
    directory and write ``config.json``, build train/val datasets split by
    sample_id, then train with Adam + cosine annealing. After every epoch the
    model is evaluated; ``last.pt`` is always rewritten and ``best.pt`` is
    written whenever the validation MAE (grams) improves.

    Raises:
        ValueError: if the training split yields no batches (fewer training
            samples than ``--batch-size``, since incomplete batches are dropped).
    """
    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    run_name = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_dir = Path(args.out_dir).expanduser().resolve() / run_name
    out_dir.mkdir(parents=True, exist_ok=True)

    cfg = TrainConfig(
        index_json=args.index_json,
        data_root=args.data_root,
        model=args.model,
        num_points=args.num_points,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        val_ratio=args.val_ratio,
        seed=args.seed,
        num_workers=args.num_workers,
        amp=not args.no_amp,
        out_dir=str(out_dir),
    )
    (out_dir / "config.json").write_text(json.dumps(asdict(cfg), indent=2), encoding="utf-8")

    # Both datasets read the same index; only the augmentation flags differ.
    dataset_kwargs = dict(
        index_json=args.index_json,
        data_root=args.data_root,
        num_points=args.num_points,
        xyz_scale=0.001,
        label_scale=0.001,  # g -> kg
        seed=args.seed,
    )
    # Build full dataset (train=True for augmentations); we will subset train/val.
    full_ds = FishWeightDataset(train=True, random_rotate=True, **dataset_kwargs)
    train_idx, val_idx = split_indices_by_sample_id(full_ds, val_ratio=args.val_ratio, seed=args.seed)
    train_ds = Subset(full_ds, train_idx)
    # Validation dataset: same underlying index, but disable augmentation.
    val_base_ds = FishWeightDataset(train=False, random_rotate=False, **dataset_kwargs)
    val_ds = Subset(val_base_ds, val_idx)

    loader_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        collate_fn=_collate_samples,
    )
    train_loader = torch.utils.data.DataLoader(train_ds, shuffle=True, drop_last=True, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(val_ds, shuffle=False, drop_last=False, **loader_kwargs)

    # With drop_last=True a too-small training split produces zero batches;
    # fail early instead of running every epoch without a single update.
    if len(train_loader) == 0:
        raise ValueError(
            f"training split has {len(train_ds)} samples, fewer than batch size "
            f"{args.batch_size}; no training batches would be produced"
        )

    model = build_model(model_type=args.model, normal_channel=False).to(device)
    if args.model == "msg":
        print("Using PointNet++ MSG (multi-scale)")
    else:
        print("Using PointNet++ SSG (single-scale)")

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    # Mixed precision is only enabled on CUDA.
    use_amp = cfg.amp and device.type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    best_mae = float("inf")
    for epoch in range(1, args.epochs + 1):
        train_loss = _train_one_epoch(
            model, train_loader, optimizer, scaler, device, use_amp,
            desc=f"epoch {epoch}/{args.epochs}",
        )
        scheduler.step()
        metrics = evaluate(model, val_loader, device)
        lr_now = scheduler.get_last_lr()[0]
        print(
            f"[{epoch:03d}/{args.epochs}] "
            f"train_loss={train_loss:.6f} "
            f"val_loss={metrics['val_loss']:.6f} "
            f"val_mae={metrics['val_mae_g']:.2f}g "
            f"lr={lr_now:.2e}"
        )

        # Save last (includes optimizer/scheduler state).
        torch.save(
            {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "metrics": metrics,
                "config": asdict(cfg),
            },
            out_dir / "last.pt",
        )

        # Save best (weights, metrics and config only).
        if metrics["val_mae_g"] < best_mae:
            best_mae = metrics["val_mae_g"]
            torch.save(
                {
                    "epoch": epoch,
                    "model_state": model.state_dict(),
                    "metrics": metrics,
                    "config": asdict(cfg),
                },
                out_dir / "best.pt",
            )

    print(f"Done. Best val MAE: {best_mae:.2f} g")
    print(f"Run dir: {out_dir}")


if __name__ == "__main__":
    main()