#!/usr/bin/env python3
"""
Train PointTransformer for fish point cloud quality classification.

This script trains a PointTransformer model to classify fish point clouds
as "good" (high quality) or "bad" (low quality).

Usage:
    python train_pointtransformer.py \
        --data dataset/ \
        --num_points 1024 \
        --epochs 100 \
        --batch_size 32 \
        --lr 0.001 \
        --output checkpoints/
"""

import argparse
import json
import os
import sys
from pathlib import Path

import numpy as np
import open3d as o3d
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class PointCloudDataset(Dataset):
    """Dataset for loading point clouds from PLY files.

    Files are read from ``<data_dir>/<split>/good/*.ply`` (label 1) and
    ``<data_dir>/<split>/bad/*.ply`` (label 0).
    """

    def __init__(self, data_dir, num_points=1024, split='train', augment=False):
        """
        Args:
            data_dir: Root directory containing train/val/test splits
            num_points: Number of points to sample from each point cloud
            split: 'train', 'val', or 'test'
            augment: Whether to apply data augmentation (only for training)
        """
        self.data_dir = Path(data_dir)
        self.num_points = num_points
        self.split = split
        # Augmentation is silently disabled for every split except 'train'.
        self.augment = augment and (split == 'train')

        # Load file paths. glob() order depends on the filesystem, so sort to
        # keep the index -> sample mapping reproducible between runs.
        self.good_files = sorted((self.data_dir / split / 'good').glob('*.ply'))
        self.bad_files = sorted((self.data_dir / split / 'bad').glob('*.ply'))

        # Create labels: 0 for bad, 1 for good
        self.files = self.bad_files + self.good_files
        self.labels = [0] * len(self.bad_files) + [1] * len(self.good_files)

        print(f"Loaded {split} dataset:")
        print(f" Good samples: {len(self.good_files)}")
        print(f" Bad samples: {len(self.bad_files)}")
        print(f" Total: {len(self.files)}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(points, label)`` for sample ``idx``.

        ``points`` is a float32 tensor of shape (num_points, 3) and ``label``
        is a 0-dim int64 tensor. If the file cannot be processed, an all-zero
        point cloud is returned together with the sample's real label.
        """
        file_path = self.files[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        try:
            # Load point cloud
            pcd = o3d.io.read_point_cloud(str(file_path))
            points = np.asarray(pcd.points)

            # Handle empty point clouds
            if len(points) == 0:
                print(f"Warning: Empty point cloud {file_path}")
                points = np.zeros((self.num_points, 3), dtype=np.float32)

            # Normalize, then sample or pad to a fixed number of points
            points = self._normalize_points(points)

            # Data augmentation (only for training)
            if self.augment:
                points = self._augment_points(points)
        except Exception as e:
            # Deliberate best-effort: one unreadable file should not abort a
            # whole training run, so report it and fall back to a zero cloud.
            print(f"Error loading {file_path}: {e}")
            points = np.zeros((self.num_points, 3), dtype=np.float32)

        # Augmentation promotes the array to float64; the model expects float32.
        return torch.as_tensor(points, dtype=torch.float32), label

    def _normalize_points(self, points):
        """Center, scale to the unit sphere and resample to ``num_points``.

        Args:
            points: (n, 3) array with n >= 1.

        Returns:
            (num_points, 3) float32 array.
        """
        # Center the point cloud
        centroid = points.mean(axis=0)
        points = points - centroid

        # Scale to unit sphere (skipped when all points coincide)
        max_dist = np.max(np.linalg.norm(points, axis=1))
        if max_dist > 0:
            points = points / max_dist

        # Subsample without replacement when there are enough points;
        # otherwise pad by drawing with replacement from the existing points.
        n_points = len(points)
        indices = np.random.choice(
            n_points, self.num_points, replace=n_points < self.num_points
        )
        return points[indices].astype(np.float32)

    def _augment_points(self, points):
        """Apply random rotation, scaling and jitter, each with probability 0.5.

        Args:
            points: (n, 3) array.

        Returns:
            (n, 3) float32 array.
        """
        # Random rotation around z-axis
        if np.random.rand() > 0.5:
            angle = np.random.uniform(0, 2 * np.pi)
            cos_a, sin_a = np.cos(angle), np.sin(angle)
            rotation = np.array([
                [cos_a, -sin_a, 0],
                [sin_a, cos_a, 0],
                [0, 0, 1],
            ])
            points = points @ rotation.T

        # Random scaling
        if np.random.rand() > 0.5:
            scale = np.random.uniform(0.9, 1.1)
            points = points * scale

        # Random jitter
        if np.random.rand() > 0.5:
            jitter = np.random.normal(0, 0.01, points.shape)
            points = points + jitter

        return points.astype(np.float32)


def knn_points(points, k):
    """
    Find k nearest neighbors for each point.

    Args:
        points: (B, N, 3) point coordinates
        k: number of neighbors

    Returns:
        knn_idx: (B, N, k') indices of the nearest neighbors, where
            k' = min(k, N - 1). Callers must use the returned shape rather
            than assume k' == k.
    """
    B, N, _ = points.shape
    # A point has at most N - 1 neighbors other than itself.
    k = max(min(k, N - 1), 0)

    # Compute pairwise squared distances.
    # NOTE: this materializes a (B, N, N, 3) intermediate tensor.
    points_expanded = points.unsqueeze(2)    # (B, N, 1, 3)
    points_transposed = points.unsqueeze(1)  # (B, 1, N, 3)
    distances = torch.sum((points_expanded - points_transposed) ** 2, dim=3)  # (B, N, N)

    # Take k + 1 smallest distances and drop the first column, which is the
    # zero self-distance (or an exact duplicate of the point).
    _, knn_idx = torch.topk(distances, k + 1, dim=2, largest=False)  # (B, N, k+1)
    knn_idx = knn_idx[:, :, 1:]  # (B, N, k)

    return knn_idx


class PointTransformerBlock(nn.Module):
    """Simplified Point Transformer block with KNN-restricted attention.

    Each point attends to its nearest neighbors; the Euclidean distance to a
    neighbor is subtracted from its attention score. The attention output and
    a feed-forward network are each applied with a residual connection
    followed by LayerNorm.
    """

    def __init__(self, dim, k=16):
        """
        Args:
            dim: Feature dimension of the block's input and output
            k: Maximum number of neighbors each point attends to
        """
        super().__init__()
        self.k = k
        self.dim = dim

        self.linear_q = nn.Linear(dim, dim)
        self.linear_k = nn.Linear(dim, dim)
        self.linear_v = nn.Linear(dim, dim)
        self.linear_out = nn.Linear(dim, dim)

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(dim * 2, dim),
        )

    def forward(self, x, pos):
        """
        Args:
            x: Point features (B, N, dim)
            pos: Point positions (B, N, 3)

        Returns:
            Updated point features (B, N, dim)
        """
        B, N, _ = x.shape

        # Self-attention projections
        q = self.linear_q(x)    # (B, N, dim)
        key = self.linear_k(x)  # (B, N, dim)
        v = self.linear_v(x)    # (B, N, dim)

        # KNN. The neighbor count can be smaller than self.k when the cloud
        # has N <= self.k points, so size everything from the returned index.
        knn_idx = knn_points(pos, self.k)  # (B, N, k')
        num_neighbors = knn_idx.shape[-1]

        # Gather neighbors
        batch_idx = torch.arange(B, device=x.device).view(B, 1, 1).expand(B, N, num_neighbors)
        k_neighbors = key[batch_idx, knn_idx]    # (B, N, k', dim)
        v_neighbors = v[batch_idx, knn_idx]      # (B, N, k', dim)
        pos_neighbors = pos[batch_idx, knn_idx]  # (B, N, k', 3)

        # Positional encoding (distance-based)
        pos_diff = pos_neighbors - pos.unsqueeze(2)               # (B, N, k', 3)
        pos_encoding = torch.norm(pos_diff, dim=3, keepdim=True)  # (B, N, k', 1)

        # Attention scores
        q_expanded = q.unsqueeze(2)  # (B, N, 1, dim)
        scores = (q_expanded @ k_neighbors.transpose(2, 3)).squeeze(2)  # (B, N, k')
        scores = scores - pos_encoding.squeeze(3)  # Farther neighbors score lower
        attn_weights = torch.softmax(scores, dim=2)  # (B, N, k')

        # Apply attention
        attn_weights_expanded = attn_weights.unsqueeze(3)  # (B, N, k', 1)
        attended = (attn_weights_expanded * v_neighbors).sum(dim=2)  # (B, N, dim)

        # Residual connection
        x = x + self.linear_out(attended)
        x = self.norm1(x)

        # FFN
        x = x + self.ffn(x)
        x = self.norm2(x)

        return x


class SimplePointTransformer(nn.Module):
    """Simplified PointTransformer for classification."""

    def __init__(self, num_points=1024, dim=64, num_classes=2, k=16):
        """
        Args:
            num_points: Nominal number of points per cloud (stored only; the
                forward pass works for any N)
            dim: Feature dimension used throughout the network
            num_classes: Number of output classes
            k: Neighborhood size used by every transformer block
        """
        super().__init__()
        self.num_points = num_points
        self.dim = dim

        # Input projection
        self.input_proj = nn.Linear(3, dim)

        # Transformer blocks
        self.transformer1 = PointTransformerBlock(dim, k=k)
        self.transformer2 = PointTransformerBlock(dim, k=k)
        self.transformer3 = PointTransformerBlock(dim, k=k)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(dim, num_classes),
        )

    def forward(self, x):
        """
        Args:
            x: Point cloud (B, N, 3)

        Returns:
            Class logits (B, num_classes)
        """
        # The raw coordinates serve both as input features and as positions.
        pos = x
        x = self.input_proj(x)

        # Apply transformer blocks
        x = self.transformer1(x, pos)
        x = self.transformer2(x, pos)
        x = self.transformer3(x, pos)

        # Global max pooling
        x = torch.max(x, dim=1)[0]  # (B, dim)

        # Classification
        logits = self.classifier(x)

        return logits


class PointNetPlusPlus(nn.Module):
    """Alternative model selected with ``--model pointnet++``.

    This is a simplified stand-in: per-point shared MLPs (1x1 convolutions)
    followed by a global max pool, without the hierarchical sampling and
    grouping of a full PointNet++. For a full implementation, consider using
    torch-geometric.
    """

    def __init__(self, num_classes=2):
        """
        Args:
            num_classes: Number of output classes
        """
        super().__init__()

        self.mlp1 = nn.Sequential(
            nn.Conv1d(3, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
        )

        self.mlp2 = nn.Sequential(
            nn.Conv1d(64, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
        )

        self.classifier = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """
        Args:
            x: Point cloud (B, N, 3)

        Returns:
            Class logits (B, num_classes)
        """
        x = x.transpose(2, 1)  # (B, 3, N), channels-first for Conv1d
        x = self.mlp1(x)
        x = self.mlp2(x)
        x = torch.max(x, dim=2)[0]  # Global max pooling
        logits = self.classifier(x)
        return logits


def _run_epoch(model, dataloader, criterion, device, optimizer=None, desc=''):
    """Run one pass over ``dataloader``; shared by train_epoch and validate.

    Args:
        model: Network mapping (B, N, 3) points to (B, num_classes) logits
        dataloader: Yields ``(points, labels)`` batches
        criterion: Loss taking ``(logits, labels)``; assumed to return the
            mean loss over the batch
        device: Device the batches are moved to
        optimizer: If given, the model is put in train mode and updated after
            every batch; if None, the model is put in eval mode and gradients
            are disabled
        desc: Label shown on the progress bar

    Returns:
        (avg_loss, all_labels, all_preds), where avg_loss is the per-sample
        mean loss and the two lists hold one int per sample.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    num_samples = 0
    all_preds = []
    all_labels = []

    with torch.set_grad_enabled(training):
        pbar = tqdm(dataloader, desc=desc)
        for points, labels in pbar:
            points = points.to(device)
            labels = labels.to(device)

            # Forward pass
            if training:
                optimizer.zero_grad()
            logits = model(points)
            loss = criterion(logits, labels)

            # Backward pass
            if training:
                loss.backward()
                optimizer.step()

            # Statistics. Weight each batch by its size so that a short final
            # batch does not count as much as a full one.
            batch_loss = loss.item()
            batch_size = labels.size(0)
            total_loss += batch_loss * batch_size
            num_samples += batch_size
            all_preds.extend(torch.argmax(logits, dim=1).cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

            # Update progress bar
            pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    avg_loss = total_loss / max(num_samples, 1)
    return avg_loss, all_labels, all_preds


def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train for one epoch.

    Returns:
        (avg_loss, accuracy) over all training samples seen in the epoch.
    """
    avg_loss, all_labels, all_preds = _run_epoch(
        model, dataloader, criterion, device, optimizer=optimizer, desc='Training'
    )
    accuracy = accuracy_score(all_labels, all_preds)
    return avg_loss, accuracy


def validate(model, dataloader, criterion, device):
    """Validate the model.

    Returns:
        (avg_loss, accuracy, precision, recall, f1). Precision, recall and F1
        are weighted averages over the classes; undefined values count as 0.
    """
    avg_loss, all_labels, all_preds = _run_epoch(
        model, dataloader, criterion, device, optimizer=None, desc='Validation'
    )
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    return avg_loss, accuracy, precision, recall, f1


def main():
    """Parse command-line arguments and run the training loop.

    Writes ``latest_checkpoint.pt`` after every epoch, ``best_checkpoint.pt``
    whenever validation accuracy improves, and ``training_history.json`` at
    the end, all inside the ``--output`` directory.
    """
    parser = argparse.ArgumentParser(description='Train PointTransformer for point cloud classification')
    parser.add_argument('--data', type=str, required=True,
                        help='Root directory of dataset (containing train/val/test)')
    parser.add_argument('--model', type=str, default='pointtransformer',
                        choices=['pointtransformer', 'pointnet++'],
                        help='Model architecture')
    parser.add_argument('--num_points', type=int, default=1024,
                        help='Number of points per point cloud')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='Weight decay')
    parser.add_argument('--output', type=str, default='checkpoints',
                        help='Output directory for checkpoints')
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to use (cuda/cpu)')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Number of DataLoader worker processes')

    args = parser.parse_args()

    # Create output directory
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Device
    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Create datasets
    print("\nLoading datasets...")
    train_dataset = PointCloudDataset(args.data, num_points=args.num_points,
                                      split='train', augment=True)
    val_dataset = PointCloudDataset(args.data, num_points=args.num_points,
                                    split='val', augment=False)

    # Fail early with a clear message instead of an obscure sampler or
    # division error further down.
    for split_name, dataset in (('train', train_dataset), ('val', val_dataset)):
        if len(dataset) == 0:
            sys.exit(f"No .ply files found for split '{split_name}' under "
                     f"{Path(args.data) / split_name}")

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.num_workers)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                            shuffle=False, num_workers=args.num_workers)

    # Create model
    print(f"\nCreating {args.model} model...")
    if args.model == 'pointtransformer':
        model = SimplePointTransformer(num_points=args.num_points, dim=64, num_classes=2)
    else:  # pointnet++
        model = PointNetPlusPlus(num_classes=2)

    model = model.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=30, gamma=0.5)

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'val_precision': [],
        'val_recall': [],
        'val_f1': []
    }

    # Resume from checkpoint
    start_epoch = 0
    best_val_acc = 0.0
    if args.resume:
        print(f"Resuming from {args.resume}...")
        # map_location lets a checkpoint written on one device be resumed on
        # another (e.g. trained on GPU, resumed on CPU).
        checkpoint = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        best_val_acc = checkpoint.get('best_val_acc', 0.0)

        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        else:
            # Checkpoints written before the scheduler state was saved: the
            # restored optimizer already holds the decayed learning rate, so
            # only the epoch counter needs realigning for later decay steps.
            scheduler.last_epoch = start_epoch

        # Continue the metric history instead of starting it over.
        saved_history = checkpoint.get('history') or {}
        for key in history:
            history[key] = list(saved_history.get(key, []))

    # Training loop
    print("\nStarting training...")

    for epoch in range(start_epoch, args.epochs):
        print(f"\nEpoch {epoch + 1}/{args.epochs}")

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc, val_precision, val_recall, val_f1 = validate(model, val_loader, criterion, device)

        # Update learning rate
        scheduler.step()

        # Save history. Store builtin floats so the history does not depend
        # on whichever scalar type the metric functions return.
        history['train_loss'].append(float(train_loss))
        history['train_acc'].append(float(train_acc))
        history['val_loss'].append(float(val_loss))
        history['val_acc'].append(float(val_acc))
        history['val_precision'].append(float(val_precision))
        history['val_recall'].append(float(val_recall))
        history['val_f1'].append(float(val_f1))

        # Print metrics
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
              f"Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1: {val_f1:.4f}")

        # Decide whether this epoch is the best so far *before* building the
        # checkpoint, so latest_checkpoint.pt also records the new best value.
        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc

        # Save checkpoint
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_acc': best_val_acc,
            'history': history
        }

        # Save latest
        torch.save(checkpoint, output_dir / 'latest_checkpoint.pt')

        # Save best
        if is_best:
            torch.save(checkpoint, output_dir / 'best_checkpoint.pt')
            print(f"✓ Saved best model (val_acc: {val_acc:.4f})")

    # Save training history
    with open(output_dir / 'training_history.json', 'w') as f:
        json.dump(history, f, indent=2)

    print("\nTraining completed!")
    print(f"Best validation accuracy: {best_val_acc:.4f}")
    print(f"Checkpoints saved to: {output_dir}")


if __name__ == '__main__':
    main()