#!/usr/bin/env python3
"""
Train PointNet++ for fish point cloud quality classification.

This script trains a PointNet++ model to classify fish point clouds as "good"
(high quality) or "bad" (low quality). Supports loading pretrained weights if
available.

Usage:
    python train_pointnet.py \
        --data dataset/ \
        --num_points 1024 \
        --epochs 100 \
        --batch_size 32 \
        --lr 0.001 \
        --output checkpoints/ \
        --pretrained path/to/pretrained_weights.pth
"""

import argparse
import os
import sys
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR
import open3d as o3d
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import json
from tqdm import tqdm


def _seed_worker(worker_id):
    """Give each DataLoader worker its own NumPy random stream.

    Worker processes can start with a copy of the parent's NumPy RNG state,
    which would make every worker draw the same "random" point subsets and
    augmentations. PyTorch assigns each worker a distinct torch seed, so that
    seed is reused for NumPy.

    Args:
        worker_id: Index of the worker (unused; required by the DataLoader API)
    """
    np.random.seed(torch.initial_seed() % 2 ** 32)


class PointCloudDataset(Dataset):
    """Dataset for loading point clouds from PLY files."""

    def __init__(self, data_dir, num_points=1024, split='train', augment=False):
        """
        Args:
            data_dir: Root directory containing train/val/test splits
            num_points: Number of points to sample from each point cloud
            split: 'train', 'val', or 'test'
            augment: Whether to apply data augmentation (only for training)

        Raises:
            ValueError: If num_points is smaller than 1
        """
        if num_points < 1:
            raise ValueError(f"num_points must be >= 1, got {num_points}")

        self.data_dir = Path(data_dir)
        self.num_points = num_points
        self.split = split
        self.augment = augment and (split == 'train')

        # Load file paths. glob() order is filesystem dependent, so sort to
        # make sample indices reproducible between runs and machines.
        self.good_files = sorted((self.data_dir / split / 'good').glob('*.ply'))
        self.bad_files = sorted((self.data_dir / split / 'bad').glob('*.ply'))

        # Create labels: 0 for bad, 1 for good
        self.files = self.bad_files + self.good_files
        self.labels = [0] * len(self.bad_files) + [1] * len(self.good_files)

        print(f"Loaded {split} dataset:")
        print(f"  Good samples: {len(self.good_files)}")
        print(f"  Bad samples: {len(self.bad_files)}")
        print(f"  Total: {len(self.files)}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return (points [num_points, 3] float32, label int64 scalar).

        Unreadable or empty files deliberately yield an all-zero cloud with
        the original label (best-effort behavior) instead of aborting training.
        """
        file_path = self.files[idx]
        label = self.labels[idx]

        try:
            # Load point cloud
            pcd = o3d.io.read_point_cloud(str(file_path))
            points = np.asarray(pcd.points)

            # Drop NaN/Inf points: a single one would poison the centroid and
            # turn the whole normalized cloud (and then the loss) into NaN.
            if len(points) > 0:
                points = points[np.isfinite(points).all(axis=1)]

            # Handle empty point clouds
            if len(points) == 0:
                print(f"Warning: Empty point cloud {file_path}")
                points = np.zeros((self.num_points, 3), dtype=np.float32)

            # Center, scale, and sample or pad to fixed number of points
            points = self._normalize_points(points)

            # Data augmentation (only for training)
            if self.augment:
                points = self._augment_points(points)
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            # Return zero point cloud as fallback
            points = np.zeros((self.num_points, 3), dtype=np.float32)

        # Convert to tensors
        points = torch.from_numpy(np.ascontiguousarray(points, dtype=np.float32))
        label = torch.tensor(label, dtype=torch.long)
        return points, label

    def _normalize_points(self, points):
        """Normalize and sample points to fixed number.

        Args:
            points: Array of shape [n, 3]

        Returns:
            float32 array of shape [num_points, 3], centered on the centroid
            and scaled so the farthest point lies on the unit sphere.
        """
        n_points = len(points)
        if n_points == 0:
            # Nothing to sample from; mean() and choice() would fail.
            return np.zeros((self.num_points, 3), dtype=np.float32)

        # Center the point cloud
        centroid = points.mean(axis=0)
        points = points - centroid

        # Scale to unit sphere
        max_dist = np.max(np.linalg.norm(points, axis=1))
        if max_dist > 0:
            points = points / max_dist

        # Subsample without replacement when there are enough points,
        # otherwise pad by re-drawing existing points with replacement.
        replace = n_points < self.num_points
        indices = np.random.choice(n_points, self.num_points, replace=replace)
        return points[indices].astype(np.float32)

    def _augment_points(self, points):
        """Apply data augmentation.

        Each of rotation about z, uniform scaling in [0.9, 1.1] and Gaussian
        jitter (sigma 0.01) is applied independently with probability 0.5.

        Args:
            points: Array of shape [n, 3]

        Returns:
            float32 array of shape [n, 3]
        """
        # Random rotation around z-axis
        if np.random.rand() > 0.5:
            angle = np.random.uniform(0, 2 * np.pi)
            cos_a, sin_a = np.cos(angle), np.sin(angle)
            rotation = np.array([[cos_a, -sin_a, 0],
                                 [sin_a, cos_a, 0],
                                 [0, 0, 1]])
            points = points @ rotation.T

        # Random scaling
        if np.random.rand() > 0.5:
            scale = np.random.uniform(0.9, 1.1)
            points = points * scale

        # Random jitter
        if np.random.rand() > 0.5:
            jitter = np.random.normal(0, 0.01, points.shape)
            points = points + jitter

        # The float64 rotation matrix / jitter upcast the cloud; cast back.
        return points.astype(np.float32)


class PointNetSetAbstraction(nn.Module):
    """PointNet++ Set Abstraction layer.

    All tensors use the channels-last layout [B, N, C] at the layer boundary.
    Grouped features are ordered (relative xyz, point features).
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all=False):
        """
        Args:
            npoint: Number of centroids to sample (ignored when group_all)
            radius: Ball query radius (ignored when group_all)
            nsample: Neighbors gathered per centroid (ignored when group_all)
            in_channel: Input channels of the shared MLP (3 + feature channels)
            mlp: Output channel sizes of the shared MLP layers
            group_all: If True, pool all points into one global feature
        """
        super().__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.group_all = group_all

        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def forward(self, xyz, points):
        """
        Args:
            xyz: Input points position data, [B, N, 3]
            points: Input points data, [B, N, C] (or None)

        Returns:
            new_xyz: Sampled points position data, [B, S, 3]
                (zeros of shape [B, 1, 3] when group_all)
            new_points: Sample points feature data, [B, S, D']
        """
        if self.group_all:
            # A single group containing every point: S = 1, K = N.
            new_xyz = xyz.new_zeros(xyz.shape[0], 1, xyz.shape[-1])
            grouped = xyz.unsqueeze(1)  # [B, 1, N, 3]
            if points is not None:
                grouped = torch.cat([grouped, points.unsqueeze(1)], dim=-1)  # [B, 1, N, 3+C]
        else:
            # Farthest Point Sampling
            fps_idx = self._farthest_point_sample(xyz, self.npoint)  # [B, S]
            new_xyz = self._index_points(xyz, fps_idx)  # [B, S, 3]

            # Ball Query
            idx = self._ball_query(self.radius, self.nsample, xyz, new_xyz)  # [B, S, K]

            # Group points and express coordinates relative to each centroid
            grouped = self._index_points(xyz, idx) - new_xyz.unsqueeze(2)  # [B, S, K, 3]
            if points is not None:
                grouped_points = self._index_points(points, idx)  # [B, S, K, C]
                grouped = torch.cat([grouped, grouped_points], dim=-1)  # [B, S, K, 3+C]

        # Shared MLP (1x1 convolutions expect channels first)
        features = grouped.permute(0, 3, 1, 2)  # [B, C_in, S, K]
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            features = F.relu(bn(conv(features)))

        # Max-pool over the neighbors of each centroid
        new_points = torch.max(features, dim=3)[0]  # [B, D', S]
        return new_xyz, new_points.permute(0, 2, 1)  # [B, S, 3], [B, S, D']

    def _farthest_point_sample(self, xyz, npoint):
        """Farthest Point Sampling.

        Args:
            xyz: Point positions, [B, N, 3]
            npoint: Number of points to select

        Returns:
            Indices of the selected points, [B, npoint] (long). The first
            index of every batch element is drawn at random.
        """
        device = xyz.device
        B, N, C = xyz.shape
        centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
        distance = torch.full((B, N), 1e10, dtype=xyz.dtype, device=device)
        farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
        batch_indices = torch.arange(B, dtype=torch.long, device=device)
        for i in range(npoint):
            centroids[:, i] = farthest
            # Pick one point per batch element (pairs batch b with farthest[b]).
            centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
            dist = torch.sum((xyz - centroid) ** 2, dim=-1)
            # Track each point's distance to its nearest selected centroid.
            distance = torch.min(distance, dist)
            farthest = torch.max(distance, dim=-1)[1]
        return centroids

    def _ball_query(self, radius, nsample, xyz, new_xyz):
        """Ball Query.

        Args:
            radius: Search radius
            nsample: Number of neighbor indices returned per query point
            xyz: All points, [B, N, 3]
            new_xyz: Query points, [B, S, 3]

        Returns:
            Neighbor indices, [B, S, nsample]. Slots without a neighbor inside
            the radius repeat the first neighbor found.
        """
        device = xyz.device
        B, N, _ = xyz.shape
        S = new_xyz.shape[1]
        group_idx = torch.arange(N, dtype=torch.long, device=device).view(1, 1, N).repeat(B, S, 1)
        sqrdists = self._square_distance(new_xyz, xyz)  # [B, S, N]
        # Mark out-of-radius points with the sentinel N so they sort last.
        group_idx[sqrdists > radius ** 2] = N
        group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
        if group_idx.shape[-1] < nsample:
            # Fewer than nsample points exist at all: pad with sentinels.
            pad = group_idx.new_full((B, S, nsample - group_idx.shape[-1]), N)
            group_idx = torch.cat([group_idx, pad], dim=-1)
        group_first = group_idx[:, :, :1].expand(-1, -1, nsample)
        return torch.where(group_idx == N, group_first, group_idx)

    def _square_distance(self, src, dst):
        """Calculate squared euclidean distance.

        Args:
            src: [B, N, C]
            dst: [B, M, C]

        Returns:
            Pairwise squared distances, [B, N, M]
        """
        # |a - b|^2 = |a|^2 + |b|^2 - 2 a.b
        dist = -2 * torch.matmul(src, dst.transpose(1, 2))
        dist += torch.sum(src ** 2, dim=-1).unsqueeze(-1)
        dist += torch.sum(dst ** 2, dim=-1).unsqueeze(1)
        return dist

    def _index_points(self, points, idx):
        """Index points.

        Args:
            points: [B, N, C]
            idx: Long indices into dim 1 of points, [B, ...]

        Returns:
            Gathered points, [B, ..., C]
        """
        B = points.shape[0]
        view_shape = [B] + [1] * (idx.dim() - 1)
        batch_indices = torch.arange(B, dtype=torch.long, device=points.device)
        batch_indices = batch_indices.view(view_shape).expand_as(idx)
        return points[batch_indices, idx, :]


class PointNetPlusPlus(nn.Module):
    """PointNet++ for classification."""

    def __init__(self, num_classes=2, num_points=1024, use_pretrained=False, pretrained_path=None):
        """
        Args:
            num_classes: Number of output classes
            num_points: Expected number of points per cloud (stored only)
            use_pretrained: Whether to load weights from pretrained_path
            pretrained_path: Path to a checkpoint with pretrained weights
        """
        super().__init__()
        self.num_points = num_points

        # SA1: 512 points
        self.sa1 = PointNetSetAbstraction(
            npoint=512, radius=0.2, nsample=32,
            in_channel=3, mlp=[64, 64, 128]
        )

        # SA2: 128 points
        self.sa2 = PointNetSetAbstraction(
            npoint=128, radius=0.4, nsample=64,
            in_channel=128 + 3, mlp=[128, 128, 256]
        )

        # SA3: Global
        self.sa3 = PointNetSetAbstraction(
            npoint=None, radius=None, nsample=None,
            in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True
        )

        # Classification head
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, num_classes)

        # Load pretrained weights if provided
        if use_pretrained and pretrained_path:
            self._load_pretrained(pretrained_path)

    def _load_pretrained(self, pretrained_path):
        """Load pretrained weights.

        Best effort: a missing file or any loading error only prints a warning
        and leaves the randomly initialized weights in place. Tensors are
        copied only when name and shape match; the final layer (fc3) is always
        skipped so the classification head is trained from scratch.
        """
        if not os.path.exists(pretrained_path):
            print(f"Warning: Pretrained weights not found at {pretrained_path}")
            print("Training from scratch...")
            return

        try:
            # NOTE: torch.load unpickles the file; only load checkpoints from
            # sources you trust.
            checkpoint = torch.load(pretrained_path, map_location='cpu')

            # Handle different checkpoint formats
            state_dict = checkpoint
            if isinstance(checkpoint, dict):
                if 'model_state_dict' in checkpoint:
                    state_dict = checkpoint['model_state_dict']
                elif 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']

            # Checkpoints saved from nn.DataParallel prefix every key with
            # "module."; strip it so the names can match this model.
            state_dict = {
                (k[len('module.'):] if k.startswith('module.') else k): v
                for k, v in state_dict.items()
            }

            # Keep tensors whose name and shape match; exclude the
            # classification head for transfer learning.
            model_dict = self.state_dict()
            pretrained_dict = {
                k: v for k, v in state_dict.items()
                if k in model_dict
                and model_dict[k].shape == v.shape
                and not k.startswith('fc3')
            }

            if not pretrained_dict:
                print(f"Warning: No compatible weights found in {pretrained_path}")
                print("Training from scratch...")
                return

            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict, strict=False)

            print(f"✓ Loaded pretrained weights from {pretrained_path}")
            print(f"  Loaded {len(pretrained_dict)} layers")
        except Exception as e:
            print(f"Warning: Failed to load pretrained weights: {e}")
            print("Training from scratch...")

    def forward(self, xyz):
        """
        Args:
            xyz: Point cloud [B, N, 3]

        Returns:
            logits: Classification logits [B, num_classes]

        Raises:
            ValueError: If xyz is not a [B, N, 3] tensor
        """
        if xyz.dim() != 3 or xyz.shape[-1] != 3:
            raise ValueError(f"Expected input of shape [B, N, 3], got {tuple(xyz.shape)}")
        B = xyz.shape[0]

        # Set Abstraction layers
        l1_xyz, l1_points = self.sa1(xyz, None)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        _, l3_points = self.sa3(l2_xyz, l2_points)  # [B, 1, 1024]

        # Global feature
        x = l3_points.reshape(B, 1024)

        # Classification head
        x = self.drop1(F.relu(self.bn1(self.fc1(x))))
        x = self.drop2(F.relu(self.bn2(self.fc2(x))))
        x = self.fc3(x)

        return x


def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train for one epoch.

    Args:
        model: Model to optimize (switched to train mode)
        dataloader: Iterable of (points, labels) batches
        criterion: Loss function called as criterion(logits, labels)
        optimizer: Optimizer stepped once per batch
        device: Device the batches are moved to

    Returns:
        (avg_loss, accuracy): per-sample mean loss and accuracy over the epoch

    Raises:
        ValueError: If the dataloader yields no samples
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    all_preds = []
    all_labels = []

    pbar = tqdm(dataloader, desc='Training')
    for points, labels in pbar:
        points = points.to(device)
        labels = labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        logits = model(points)
        loss = criterion(logits, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics (weight by batch size so a short final batch does not
        # skew the epoch average)
        batch_loss = loss.item()
        batch_size = labels.size(0)
        total_loss += batch_loss * batch_size
        num_samples += batch_size
        all_preds.extend(torch.argmax(logits, dim=1).cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

        # Update progress bar
        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if num_samples == 0:
        raise ValueError("Training dataloader produced no samples")

    avg_loss = total_loss / num_samples
    accuracy = float(accuracy_score(all_labels, all_preds))

    return avg_loss, accuracy


def validate(model, dataloader, criterion, device):
    """Validate the model.

    Args:
        model: Model to evaluate (switched to eval mode)
        dataloader: Iterable of (points, labels) batches
        criterion: Loss function called as criterion(logits, labels)
        device: Device the batches are moved to

    Returns:
        (avg_loss, accuracy, precision, recall, f1); precision, recall and F1
        are class-frequency weighted averages.

    Raises:
        ValueError: If the dataloader yields no samples
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Validation')
        for points, labels in pbar:
            points = points.to(device)
            labels = labels.to(device)

            # Forward pass
            logits = model(points)
            loss = criterion(logits, labels)

            # Statistics
            batch_loss = loss.item()
            batch_size = labels.size(0)
            total_loss += batch_loss * batch_size
            num_samples += batch_size
            all_preds.extend(torch.argmax(logits, dim=1).cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

            pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if num_samples == 0:
        raise ValueError("Validation dataloader produced no samples")

    avg_loss = total_loss / num_samples
    # Cast to plain floats so the values serialize cleanly to JSON and into
    # checkpoints without embedding NumPy scalar objects.
    accuracy = float(accuracy_score(all_labels, all_preds))
    precision = float(precision_score(all_labels, all_preds, average='weighted', zero_division=0))
    recall = float(recall_score(all_labels, all_preds, average='weighted', zero_division=0))
    f1 = float(f1_score(all_labels, all_preds, average='weighted', zero_division=0))

    return avg_loss, accuracy, precision, recall, f1


def main():
    """Parse command-line arguments and run the training loop.

    Writes latest_checkpoint.pt every epoch, best_checkpoint.pt whenever the
    validation accuracy improves, and training_history.json at the end, all
    into the --output directory.
    """
    parser = argparse.ArgumentParser(description='Train PointNet++ for point cloud classification')
    parser.add_argument('--data', type=str, required=True,
                        help='Root directory of dataset (containing train/val/test)')
    parser.add_argument('--num_points', type=int, default=1024,
                        help='Number of points per point cloud')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='Weight decay')
    parser.add_argument('--output', type=str, default='checkpoints',
                        help='Output directory for checkpoints')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to use (cuda/cpu)')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')
    parser.add_argument('--pretrained', type=str, default=None,
                        help='Path to pretrained PointNet++ weights (.pth file)')

    args = parser.parse_args()

    # BatchNorm cannot compute batch statistics from a single sample in
    # training mode, so batches of one are not usable.
    if args.batch_size < 2:
        parser.error('--batch_size must be at least 2 (BatchNorm needs more than one sample per batch)')

    # Create output directory
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Device
    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Create datasets
    print("\nLoading datasets...")
    train_dataset = PointCloudDataset(args.data, num_points=args.num_points, split='train', augment=True)
    val_dataset = PointCloudDataset(args.data, num_points=args.num_points, split='val', augment=False)

    if len(train_dataset) < 2:
        parser.error(f"Need at least 2 training samples under {Path(args.data) / 'train'}")
    if len(val_dataset) == 0:
        parser.error(f"No validation samples found under {Path(args.data) / 'val'}")

    # Drop a trailing batch that would contain exactly one sample; it would
    # crash BatchNorm in training mode.
    drop_last = len(train_dataset) % args.batch_size == 1

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=4, drop_last=drop_last, worker_init_fn=_seed_worker)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=4, worker_init_fn=_seed_worker)

    # Create model
    print("\nCreating PointNet++ model...")
    use_pretrained = args.pretrained is not None
    model = PointNetPlusPlus(
        num_classes=2,
        num_points=args.num_points,
        use_pretrained=use_pretrained,
        pretrained_path=args.pretrained
    )
    model = model.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=30, gamma=0.5)

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'val_precision': [],
        'val_recall': [],
        'val_f1': []
    }

    # Resume from checkpoint
    start_epoch = 0
    best_val_acc = 0.0
    if args.resume:
        print(f"Resuming from {args.resume}...")
        checkpoint = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        best_val_acc = checkpoint.get('best_val_acc', 0.0)

        # Continue the learning-rate schedule instead of restarting it.
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        else:
            # Checkpoint without scheduler state: the optimizer state already
            # holds the decayed learning rate, so only the epoch counter
            # needs to be realigned.
            scheduler.last_epoch = start_epoch

        # Keep the metrics of the epochs that were already trained.
        saved_history = checkpoint.get('history') or {}
        for key in history:
            history[key] = list(saved_history.get(key, []))

    # Training loop
    print("\nStarting training...")

    for epoch in range(start_epoch, args.epochs):
        print(f"\nEpoch {epoch + 1}/{args.epochs}")

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc, val_precision, val_recall, val_f1 = validate(model, val_loader, criterion, device)

        # Update learning rate
        scheduler.step()

        # Save history (plain floats keep the checkpoint and the JSON file
        # free of NumPy scalar objects)
        history['train_loss'].append(float(train_loss))
        history['train_acc'].append(float(train_acc))
        history['val_loss'].append(float(val_loss))
        history['val_acc'].append(float(val_acc))
        history['val_precision'].append(float(val_precision))
        history['val_recall'].append(float(val_recall))
        history['val_f1'].append(float(val_f1))

        # Print metrics
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
              f"Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1: {val_f1:.4f}")

        # Decide whether this epoch is the best one before building the
        # checkpoint, so the latest checkpoint also records the updated value.
        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc

        # Save checkpoint
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_acc': best_val_acc,
            'history': history
        }

        # Save latest
        torch.save(checkpoint, output_dir / 'latest_checkpoint.pt')

        # Save best
        if is_best:
            torch.save(checkpoint, output_dir / 'best_checkpoint.pt')
            print(f"✓ Saved best model (val_acc: {val_acc:.4f})")

    # Save training history
    with open(output_dir / 'training_history.json', 'w') as f:
        json.dump(history, f, indent=2)

    print("\nTraining completed!")
    print(f"Best validation accuracy: {best_val_acc:.4f}")
    print(f"Checkpoints saved to: {output_dir}")


if __name__ == '__main__':
    main()