Files
FishServer/FishMeasure/pointcloud_classifier/train_pointtransformer.py

532 lines
18 KiB
Python

#!/usr/bin/env python3
"""
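Train PointTransformer (or a PointNet-style baseline) for fish point cloud quality classification.
This script trains a PointTransformer model (or, with --model pointnet++, a
PointNet-style baseline) to classify fish point clouds as "good" (high quality)
or "bad" (low quality). PLY files are read from <data>/<split>/good/*.ply and
<data>/<split>/bad/*.ply for the train and val splits.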
Usage:
python train_pointtransformer.py \
--data dataset/ \
--num_points 1024 \
--epochs 100 \
--batch_size 32 \
--lr 0.001 \
--output checkpoints/
"""
import argparse
import os
import sys
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR
import open3d as o3d
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import json
from tqdm import tqdm
class PointCloudDataset(Dataset):
    """Dataset of labelled point clouds loaded from PLY files.

    Directory layout (relative to ``data_dir``)::

        <split>/good/*.ply   -> label 1
        <split>/bad/*.ply    -> label 0

    ``__getitem__`` returns ``(points, label)``: ``points`` is a float32
    tensor of shape ``(num_points, 3)`` (centred, scaled into the unit
    sphere, resampled to a fixed size, optionally augmented) and ``label``
    is a 0-dim int64 tensor.
    """

    def __init__(self, data_dir, num_points=1024, split='train', augment=False):
        """
        Args:
            data_dir: Root directory containing train/val/test splits
            num_points: Number of points to sample from each point cloud
            split: 'train', 'val', or 'test'
            augment: Whether to apply data augmentation; ignored unless
                split == 'train'
        """
        self.data_dir = Path(data_dir)
        self.num_points = num_points
        self.split = split
        self.augment = augment and (split == 'train')

        # Path.glob() yields files in filesystem order, which differs between
        # machines; sort so sample indices (and therefore runs) are reproducible.
        split_dir = self.data_dir / split
        self.good_files = sorted((split_dir / 'good').glob('*.ply'))
        self.bad_files = sorted((split_dir / 'bad').glob('*.ply'))

        # Labels: 0 for bad, 1 for good.
        self.files = self.bad_files + self.good_files
        self.labels = [0] * len(self.bad_files) + [1] * len(self.good_files)

        print(f"Loaded {split} dataset:")
        print(f" Good samples: {len(self.good_files)}")
        print(f" Bad samples: {len(self.bad_files)}")
        print(f" Total: {len(self.files)}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_path = self.files[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        try:
            pcd = o3d.io.read_point_cloud(str(file_path))
            points = np.asarray(pcd.points)
        except Exception as e:
            # Best effort: an unreadable file becomes an all-zero cloud (not
            # normalised or augmented) so one bad sample does not abort an epoch.
            print(f"Error loading {file_path}: {e}")
            return torch.zeros((self.num_points, 3), dtype=torch.float32), label

        if len(points) == 0:
            print(f"Warning: Empty point cloud {file_path}")
            points = np.zeros((self.num_points, 3), dtype=np.float32)

        # Centre, scale and resample to a fixed number of points.
        points = self._normalize_points(points)
        if self.augment:
            points = self._augment_points(points)
        return torch.tensor(points, dtype=torch.float32), label

    def _normalize_points(self, points):
        """Centre, scale into the unit sphere, and resample to ``num_points``.

        Args:
            points: (n, 3) array with n >= 1.

        Returns:
            (num_points, 3) float32 array. Clouds with at least ``num_points``
            points are subsampled without replacement; smaller clouds are
            padded by resampling their own points with replacement.
        """
        points = points - points.mean(axis=0)
        max_dist = np.max(np.linalg.norm(points, axis=1))
        if max_dist > 0:
            points = points / max_dist

        n_points = len(points)
        indices = np.random.choice(n_points, self.num_points,
                                   replace=n_points < self.num_points)
        return points[indices].astype(np.float32)

    def _augment_points(self, points):
        """Randomly apply (each with probability 0.5): rotation about the z-axis,
        uniform scaling in [0.9, 1.1], and Gaussian jitter (sigma = 0.01)."""
        if np.random.rand() > 0.5:
            angle = np.random.uniform(0, 2 * np.pi)
            cos_a, sin_a = np.cos(angle), np.sin(angle)
            rotation = np.array([[cos_a, -sin_a, 0],
                                 [sin_a, cos_a, 0],
                                 [0, 0, 1]])
            points = points @ rotation.T

        if np.random.rand() > 0.5:
            points = points * np.random.uniform(0.9, 1.1)

        if np.random.rand() > 0.5:
            points = points + np.random.normal(0, 0.01, points.shape)

        return points
def knn_points(points, k):
    """
    Find the k nearest neighbours of every point, excluding the point itself.

    Args:
        points: (B, N, 3) point coordinates
        k: number of neighbours requested; clamped to [0, N - 1] because a
           point cannot have more than N - 1 other points as neighbours.

    Returns:
        knn_idx: (B, N, min(k, N - 1)) int64 indices into the N dimension,
                 ordered from nearest to farthest.
    """
    B, N, _ = points.shape
    k = max(0, min(k, N - 1))
    # (B, N, N) Euclidean distances. torch.cdist avoids materialising the
    # (B, N, N, 3) difference tensor of a naive broadcast (~400 MB in float32
    # for B=32, N=1024). Ranking by distance equals ranking by squared distance.
    distances = torch.cdist(points, points)
    # Exclude self explicitly. Dropping column 0 of a top-(k+1) result is not
    # reliable: duplicated points also sit at distance 0 and topk's tie order
    # is unspecified, so "self" could survive while a duplicate is dropped.
    self_mask = torch.eye(N, dtype=torch.bool, device=points.device)
    distances = distances.masked_fill(self_mask, float('inf'))
    _, knn_idx = torch.topk(distances, k, dim=2, largest=False)  # (B, N, k)
    return knn_idx
class PointTransformerBlock(nn.Module):
    """Simplified Point Transformer block: local kNN attention followed by an FFN.

    Each point attends over its nearest neighbours in coordinate space
    (found with ``knn_points``). Attention logits are the query/key dot
    product minus the Euclidean distance to the neighbour, so farther
    neighbours are down-weighted. Both sub-layers use a residual connection
    followed by LayerNorm (post-norm).

    NOTE: logits are not scaled by 1/sqrt(dim), unlike standard scaled
    dot-product attention; left as-is so saved checkpoints behave the same.
    """

    def __init__(self, dim, k=16):
        """
        Args:
            dim: feature width of the input and output.
            k: requested neighbours per point (fewer are used when the cloud
               has N <= k points).
        """
        super().__init__()
        self.k = k
        self.dim = dim
        self.linear_q = nn.Linear(dim, dim)
        self.linear_k = nn.Linear(dim, dim)
        self.linear_v = nn.Linear(dim, dim)
        self.linear_out = nn.Linear(dim, dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(dim * 2, dim),
        )

    def forward(self, x, pos):
        """
        Args:
            x: Point features (B, N, dim)
            pos: Point positions (B, N, 3)

        Returns:
            Updated point features (B, N, dim).
        """
        B = x.shape[0]

        q = self.linear_q(x)  # (B, N, dim)
        k = self.linear_k(x)  # (B, N, dim)
        v = self.linear_v(x)  # (B, N, dim)

        # knn_points clamps the neighbour count to N - 1, so it can be smaller
        # than self.k for small clouds; never size tensors from self.k.
        knn_idx = knn_points(pos, self.k)  # (B, N, k')
        # A (B, 1, 1) batch index broadcasts against knn_idx in advanced
        # indexing, whatever k' turned out to be.
        batch_idx = torch.arange(B, device=x.device).view(B, 1, 1)
        k_neighbors = k[batch_idx, knn_idx]      # (B, N, k', dim)
        v_neighbors = v[batch_idx, knn_idx]      # (B, N, k', dim)
        pos_neighbors = pos[batch_idx, knn_idx]  # (B, N, k', 3)

        # Distance-based positional term.
        dist = torch.norm(pos_neighbors - pos.unsqueeze(2), dim=3)  # (B, N, k')

        scores = (q.unsqueeze(2) @ k_neighbors.transpose(2, 3)).squeeze(2)  # (B, N, k')
        scores = scores - dist
        attn = torch.softmax(scores, dim=2)
        attended = (attn.unsqueeze(3) * v_neighbors).sum(dim=2)  # (B, N, dim)

        x = self.norm1(x + self.linear_out(attended))
        x = self.norm2(x + self.ffn(x))
        return x
class SimplePointTransformer(nn.Module):
    """Simplified PointTransformer for point cloud classification.

    Pipeline: linear lift of xyz to ``dim`` features -> three
    ``PointTransformerBlock`` layers (local kNN attention over xyz
    neighbourhoods) -> global max pool over points -> MLP head.
    """

    def __init__(self, num_points=1024, dim=64, num_classes=2, k=16):
        """
        Args:
            num_points: expected points per cloud; stored for reference only
                and not used by the forward pass.
            dim: feature width used throughout the network.
            num_classes: number of output logits.
            k: neighbours per point in each transformer block (default 16,
                the previously hard-coded value).
        """
        super().__init__()
        self.num_points = num_points
        self.dim = dim

        # Input projection
        self.input_proj = nn.Linear(3, dim)

        # Three explicitly named blocks (rather than a ModuleList) keep the
        # state_dict keys stable for previously saved checkpoints.
        self.transformer1 = PointTransformerBlock(dim, k=k)
        self.transformer2 = PointTransformerBlock(dim, k=k)
        self.transformer3 = PointTransformerBlock(dim, k=k)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(dim, num_classes),
        )

    def forward(self, x):
        """
        Args:
            x: Point cloud (B, N, 3)

        Returns:
            logits: (B, num_classes)
        """
        pos = x  # raw coordinates drive the kNN neighbourhoods in every block
        x = self.input_proj(x)

        x = self.transformer1(x, pos)
        x = self.transformer2(x, pos)
        x = self.transformer3(x, pos)

        x = torch.max(x, dim=1)[0]  # global max pool -> (B, dim)
        return self.classifier(x)
class PointNetPlusPlus(nn.Module):
    """PointNet-style global classifier, exposed as the 'pointnet++' option.

    Despite the name there are no set-abstraction / grouping layers: a shared
    per-point MLP (1x1 Conv1d -> BatchNorm1d -> ReLU) lifts xyz to 1024
    features, a max over points pools them, and an MLP head with dropout
    produces the logits. For a full PointNet++ consider torch-geometric.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.mlp1 = self._shared_mlp(3, 64, 64)
        self.mlp2 = self._shared_mlp(64, 128, 1024)

        head = []
        for width_in, width_out in ((1024, 512), (512, 256)):
            head += [nn.Linear(width_in, width_out), nn.ReLU(), nn.Dropout(0.5)]
        head.append(nn.Linear(256, num_classes))
        self.classifier = nn.Sequential(*head)

    @staticmethod
    def _shared_mlp(*channels):
        """Chain (Conv1d 1x1 -> BatchNorm1d -> ReLU) for each consecutive width pair."""
        layers = []
        for c_in, c_out in zip(channels, channels[1:]):
            layers.extend((nn.Conv1d(c_in, c_out, 1), nn.BatchNorm1d(c_out), nn.ReLU()))
        return nn.Sequential(*layers)

    def forward(self, x):
        """
        Args:
            x: Point cloud (B, N, 3)

        Returns:
            logits: (B, num_classes)
        """
        per_point = self.mlp2(self.mlp1(x.transpose(2, 1)))  # (B, 1024, N)
        pooled, _ = per_point.max(dim=2)                      # (B, 1024)
        return self.classifier(pooled)
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one training epoch.

    Args:
        model: module mapping a point batch to (B, num_classes) logits.
        dataloader: yields ``(points, labels)`` batches.
        criterion: loss function; assumed to return the batch *mean*
            (the default reduction of ``nn.CrossEntropyLoss``).
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to.

    Returns:
        (avg_loss, accuracy) over all samples seen this epoch. The loss is
        weighted by batch size so a short final batch is not over-counted.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    num_correct = 0
    num_samples = 0

    pbar = tqdm(dataloader, desc='Training')
    for points, labels in pbar:
        points = points.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(points)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_size = labels.size(0)
        total_loss += batch_loss * batch_size
        num_correct += (logits.argmax(dim=1) == labels).sum().item()
        num_samples += batch_size

        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if num_samples == 0:
        raise ValueError("train_epoch: dataloader produced no samples")
    return total_loss / num_samples, num_correct / num_samples
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    Args:
        model: module mapping a point batch to (B, num_classes) logits.
        dataloader: yields ``(points, labels)`` batches.
        criterion: loss function; assumed to return the batch *mean*
            (the default reduction of ``nn.CrossEntropyLoss``).
        device: device every batch is moved to.

    Returns:
        (avg_loss, accuracy, precision, recall, f1) as plain Python floats.
        The loss is weighted by batch size. Precision/recall/F1 are
        support-weighted averages over classes (sklearn ``average='weighted'``),
        with 0 used where a metric is undefined (``zero_division=0``).

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Validation')
        for points, labels in pbar:
            points = points.to(device)
            labels = labels.to(device)

            logits = model(points)
            batch_loss = criterion(logits, labels).item()

            total_loss += batch_loss * labels.size(0)
            all_preds.extend(logits.argmax(dim=1).cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

            pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if not all_labels:
        raise ValueError("validate: dataloader produced no samples")

    avg_loss = total_loss / len(all_labels)
    accuracy = float(accuracy_score(all_labels, all_preds))
    precision = float(precision_score(all_labels, all_preds, average='weighted', zero_division=0))
    recall = float(recall_score(all_labels, all_preds, average='weighted', zero_division=0))
    f1 = float(f1_score(all_labels, all_preds, average='weighted', zero_division=0))
    return avg_loss, accuracy, precision, recall, f1
def main():
    """Command-line entry point: parse arguments, train, validate and checkpoint.

    Under ``--output`` this writes ``latest_checkpoint.pt`` every epoch,
    ``best_checkpoint.pt`` whenever validation accuracy improves, and
    ``training_history.json`` once training finishes.
    """
    parser = argparse.ArgumentParser(description='Train PointTransformer for point cloud classification')
    parser.add_argument('--data', type=str, required=True, help='Root directory of dataset (containing train/val/test)')
    parser.add_argument('--model', type=str, default='pointtransformer', choices=['pointtransformer', 'pointnet++'],
                        help='Model architecture')
    parser.add_argument('--num_points', type=int, default=1024, help='Number of points per point cloud')
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of DataLoader worker processes')
    parser.add_argument('--output', type=str, default='checkpoints', help='Output directory for checkpoints')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to use (cuda/cpu)')
    parser.add_argument('--resume', type=str, default=None, help='Path to checkpoint to resume from')
    args = parser.parse_args()

    # Create output directory
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Device
    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Create datasets
    print("\nLoading datasets...")
    train_dataset = PointCloudDataset(args.data, num_points=args.num_points, split='train', augment=True)
    val_dataset = PointCloudDataset(args.data, num_points=args.num_points, split='val', augment=False)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers)

    # Create model
    print(f"\nCreating {args.model} model...")
    if args.model == 'pointtransformer':
        model = SimplePointTransformer(num_points=args.num_points, dim=64, num_classes=2)
    else:  # pointnet++
        model = PointNetPlusPlus(num_classes=2)
    model = model.to(device)

    # Loss, optimizer and LR schedule
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=30, gamma=0.5)

    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [], 'val_precision': [], 'val_recall': [], 'val_f1': []
    }

    # Resume from checkpoint
    start_epoch = 0
    best_val_acc = 0.0
    if args.resume:
        print(f"Resuming from {args.resume}...")
        # map_location lets a checkpoint written on GPU be resumed on CPU and
        # vice versa. NOTE: recent PyTorch versions default torch.load to
        # weights_only=True; checkpoints written by older versions of this
        # script may contain NumPy scalars that such loading rejects.
        checkpoint = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        best_val_acc = float(checkpoint.get('best_val_acc', 0.0))
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        else:
            # Older checkpoints have no scheduler state. The optimizer state
            # already carries the saved LR; aligning last_epoch with the number
            # of completed epochs keeps later StepLR decays on schedule.
            scheduler.last_epoch = start_epoch
        # Continue the saved history so training_history.json covers the whole run.
        for key, values in checkpoint.get('history', {}).items():
            if key in history:
                history[key] = list(values)

    # Training loop
    print("\nStarting training...")
    for epoch in range(start_epoch, args.epochs):
        print(f"\nEpoch {epoch + 1}/{args.epochs}")

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, val_precision, val_recall, val_f1 = validate(model, val_loader, criterion, device)
        scheduler.step()

        # Store plain Python floats so the checkpoint holds no NumPy scalar
        # types (friendlier to torch.load(weights_only=True)).
        epoch_metrics = {
            'train_loss': train_loss, 'train_acc': train_acc,
            'val_loss': val_loss, 'val_acc': val_acc, 'val_precision': val_precision,
            'val_recall': val_recall, 'val_f1': val_f1,
        }
        for key, value in epoch_metrics.items():
            history[key].append(float(value))

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
              f"Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1: {val_f1:.4f}")

        # Update the best score BEFORE building the checkpoint: otherwise the
        # "latest" checkpoint records a stale best_val_acc, and resuming from it
        # could overwrite best_checkpoint.pt with a worse model.
        improved = val_acc > best_val_acc
        if improved:
            best_val_acc = float(val_acc)

        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_acc': best_val_acc,
            'history': history,
        }
        torch.save(checkpoint, output_dir / 'latest_checkpoint.pt')
        if improved:
            torch.save(checkpoint, output_dir / 'best_checkpoint.pt')
            print(f"✓ Saved best model (val_acc: {val_acc:.4f})")

    # Save training history
    with open(output_dir / 'training_history.json', 'w') as f:
        json.dump(history, f, indent=2)

    print("\nTraining completed!")
    print(f"Best validation accuracy: {best_val_acc:.4f}")
    print(f"Checkpoints saved to: {output_dir}")
if __name__ == "__main__":
    # Train only when executed as a script, never on import.
    main()