532 lines
18 KiB
Python
532 lines
18 KiB
Python
#!/usr/bin/env python3
|
|
r"""
Train PointTransformer for fish point cloud quality classification.

This script trains a PointTransformer model (or, with ``--model pointnet++``,
a PointNet-style baseline) to classify fish point clouds as "good" (high
quality) or "bad" (low quality).

Expected dataset layout: ``<data>/{train,val}/{good,bad}/*.ply``.

Usage:
    python train_pointtransformer.py \
        --data dataset/ \
        --num_points 1024 \
        --epochs 100 \
        --batch_size 32 \
        --lr 0.001 \
        --output checkpoints/
"""
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from torch.optim.lr_scheduler import StepLR
|
|
import open3d as o3d
|
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
|
|
import json
|
|
from tqdm import tqdm
|
|
|
|
|
|
class PointCloudDataset(Dataset):
    """Dataset of point clouds stored as PLY files and labelled good/bad.

    Directory layout::

        data_dir/<split>/bad/*.ply    -> label 0
        data_dir/<split>/good/*.ply   -> label 1

    Each item is ``(points, label)``: ``points`` is a float32 tensor of shape
    ``(num_points, 3)`` (centred, scaled into the unit sphere and randomly
    sub/over-sampled) and ``label`` is a 0-d int64 tensor.
    """

    def __init__(self, data_dir, num_points=1024, split='train', augment=False):
        """
        Args:
            data_dir: Root directory containing train/val/test splits.
            num_points: Number of points to sample from each point cloud.
            split: 'train', 'val', or 'test'.
            augment: Whether to apply data augmentation. Forced off unless
                ``split == 'train'``.
        """
        self.data_dir = Path(data_dir)
        self.num_points = num_points
        self.split = split
        self.augment = augment and (split == 'train')

        # Sorted so the index -> file mapping is reproducible; the raw glob
        # order depends on the filesystem.
        self.good_files = sorted((self.data_dir / split / 'good').glob('*.ply'))
        self.bad_files = sorted((self.data_dir / split / 'bad').glob('*.ply'))

        # Labels: 0 for bad, 1 for good.
        self.files = self.bad_files + self.good_files
        self.labels = [0] * len(self.bad_files) + [1] * len(self.good_files)

        print(f"Loaded {split} dataset:")
        print(f" Good samples: {len(self.good_files)}")
        print(f" Bad samples: {len(self.bad_files)}")
        print(f" Total: {len(self.files)}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load, normalise, resample and (optionally) augment one cloud."""
        file_path = self.files[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        # Only the file read is guarded. An unreadable file degrades to an
        # all-zero cloud so the epoch can continue, but errors in the
        # preprocessing below are surfaced instead of silently zeroed.
        try:
            pcd = o3d.io.read_point_cloud(str(file_path))
            points = np.asarray(pcd.points)
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            return torch.zeros(self.num_points, 3, dtype=torch.float32), label

        if len(points) == 0:
            print(f"Warning: Empty point cloud {file_path}")
            points = np.zeros((self.num_points, 3), dtype=np.float32)

        # Centre, scale, and sample/pad to a fixed number of points.
        points = self._normalize_points(points)

        if self.augment:
            points = self._augment_points(points)

        # Augmentation can promote to float64; cast back before tensor creation.
        return torch.from_numpy(points.astype(np.float32)), label

    def _normalize_points(self, points):
        """Centre on the centroid, scale into the unit sphere, then resample.

        Exactly ``self.num_points`` points are drawn at random. Sampling is
        without replacement when the cloud has enough points and with
        replacement (padding by repetition) otherwise. It is random on every
        call, including for non-training splits. Returns a float32 array of
        shape ``(num_points, 3)``. ``points`` must be non-empty.
        """
        points = points - points.mean(axis=0)

        max_dist = np.max(np.linalg.norm(points, axis=1))
        if max_dist > 0:
            points = points / max_dist

        n_points = len(points)
        indices = np.random.choice(
            n_points, self.num_points, replace=n_points < self.num_points
        )
        return points[indices].astype(np.float32)

    def _augment_points(self, points):
        """Randomly rotate about z, scale by [0.9, 1.1], and add N(0, 0.01) jitter.

        Each transform is applied independently with probability 0.5.
        """
        if np.random.rand() > 0.5:
            angle = np.random.uniform(0, 2 * np.pi)
            cos_a, sin_a = np.cos(angle), np.sin(angle)
            rotation = np.array([[cos_a, -sin_a, 0],
                                 [sin_a, cos_a, 0],
                                 [0, 0, 1]])
            points = points @ rotation.T

        if np.random.rand() > 0.5:
            points = points * np.random.uniform(0.9, 1.1)

        if np.random.rand() > 0.5:
            points = points + np.random.normal(0, 0.01, points.shape)

        return points
|
|
|
|
|
|
def knn_points(points, k):
    """
    Find the k nearest neighbours of every point, excluding the point itself.

    Args:
        points: (B, N, 3) point coordinates.
        k: Requested number of neighbours. It is clamped to N - 1 so topk never
            asks for more candidates than exist.

    Returns:
        knn_idx: (B, N, min(k, N - 1)) long tensor of neighbour indices,
            ordered from nearest to farthest.
    """
    num_points = points.shape[1]
    k = min(k, num_points - 1)

    # cdist builds the (B, N, N) distance matrix directly rather than a
    # (B, N, N, 3) difference tensor, cutting peak memory roughly threefold.
    # The non-matmul mode is exact, so each point's distance to itself is 0.
    # Euclidean vs. squared distance gives the same neighbour ranking.
    distances = torch.cdist(points, points, compute_mode='donot_use_mm_for_euclid_dist')

    # Take the k + 1 closest and drop column 0 (the point itself).
    _, knn_idx = torch.topk(distances, k + 1, dim=2, largest=False)
    return knn_idx[:, :, 1:]
|
|
|
|
|
|
class PointTransformerBlock(nn.Module):
    """Simplified Point Transformer block.

    Each point runs dot-product attention over its nearest neighbours, chosen
    by position; farther neighbours get a score penalty equal to their
    Euclidean distance. This is followed by a feed-forward network. Both
    sub-layers are post-norm residuals: ``x = norm(x + f(x))``.
    """

    def __init__(self, dim, k=16):
        super().__init__()
        self.k = k  # requested neighbourhood size; knn_points clamps it to N - 1
        self.dim = dim

        self.linear_q = nn.Linear(dim, dim)
        self.linear_k = nn.Linear(dim, dim)
        self.linear_v = nn.Linear(dim, dim)
        self.linear_out = nn.Linear(dim, dim)

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(dim * 2, dim),
        )

    def forward(self, x, pos):
        """
        Args:
            x: Point features (B, N, dim).
            pos: Point positions (B, N, 3). Used for neighbour search and for
                the distance penalty on attention scores.

        Returns:
            Updated point features, (B, N, dim).
        """
        batch_size = x.shape[0]

        query = self.linear_q(x)  # (B, N, dim)
        key = self.linear_k(x)    # (B, N, dim)
        value = self.linear_v(x)  # (B, N, dim)

        # knn_points returns k_eff = min(self.k, N - 1) neighbours, which can
        # be fewer than self.k; all shapes below follow knn_idx, not self.k.
        knn_idx = knn_points(pos, self.k)  # (B, N, k_eff)

        # A (B, 1, 1) batch index broadcasts against knn_idx in advanced
        # indexing, so the gather works for any k_eff (including N <= self.k).
        batch_idx = torch.arange(batch_size, device=x.device).view(batch_size, 1, 1)
        key_nb = key[batch_idx, knn_idx]      # (B, N, k_eff, dim)
        value_nb = value[batch_idx, knn_idx]  # (B, N, k_eff, dim)
        pos_nb = pos[batch_idx, knn_idx]      # (B, N, k_eff, 3)

        # Distance-based positional term: subtracted from the scores below.
        dist = torch.norm(pos_nb - pos.unsqueeze(2), dim=3)  # (B, N, k_eff)

        scores = (query.unsqueeze(2) @ key_nb.transpose(2, 3)).squeeze(2)  # (B, N, k_eff)
        scores = scores - dist
        attn_weights = torch.softmax(scores, dim=2)  # (B, N, k_eff)

        attended = (attn_weights.unsqueeze(3) * value_nb).sum(dim=2)  # (B, N, dim)

        x = self.norm1(x + self.linear_out(attended))
        x = self.norm2(x + self.ffn(x))
        return x
|
|
|
|
|
|
class SimplePointTransformer(nn.Module):
    """Point-cloud classifier built from three PointTransformerBlocks.

    Pipeline: per-point linear embedding of xyz -> three stacked blocks
    (k=16 neighbours each) -> global max pooling over points -> MLP head
    producing ``num_classes`` logits.
    """

    def __init__(self, num_points=1024, dim=64, num_classes=2):
        super().__init__()
        # Stored for reference; forward() does not read num_points.
        self.num_points = num_points
        self.dim = dim

        self.input_proj = nn.Linear(3, dim)

        self.transformer1 = PointTransformerBlock(dim, k=16)
        self.transformer2 = PointTransformerBlock(dim, k=16)
        self.transformer3 = PointTransformerBlock(dim, k=16)

        hidden = dim * 2
        self.classifier = nn.Sequential(
            nn.Linear(dim, hidden), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(hidden, dim), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(dim, num_classes),
        )

    def forward(self, x):
        """
        Args:
            x: Point cloud (B, N, 3). The raw coordinates also serve as the
                positions passed to every transformer block.

        Returns:
            Class logits, (B, num_classes).
        """
        coords = x
        features = self.input_proj(coords)
        for block in (self.transformer1, self.transformer2, self.transformer3):
            features = block(features, coords)

        pooled, _ = torch.max(features, dim=1)  # (B, dim)
        return self.classifier(pooled)
|
|
|
|
|
|
class PointNetPlusPlus(nn.Module):
    """Alternative classifier (selected by ``--model pointnet++``).

    Despite the name, this is a PointNet-style global-feature network with no
    sampling/grouping hierarchy:
    - shared per-point MLPs (1x1 Conv1d + BatchNorm + ReLU) lift xyz to 1024
      features;
    - max pooling over points yields one global descriptor;
    - an MLP head maps the descriptor to class logits.
    For a full PointNet++ implementation, consider torch-geometric.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        def conv_bn_relu(in_ch, out_ch):
            # One shared per-point layer: kernel size 1 == per-point Linear.
            return [nn.Conv1d(in_ch, out_ch, 1), nn.BatchNorm1d(out_ch), nn.ReLU()]

        self.mlp1 = nn.Sequential(*conv_bn_relu(3, 64), *conv_bn_relu(64, 64))
        self.mlp2 = nn.Sequential(*conv_bn_relu(64, 128), *conv_bn_relu(128, 1024))
        self.classifier = nn.Sequential(
            nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """
        Args:
            x: Point cloud (B, N, 3).

        Returns:
            Class logits, (B, num_classes).
        """
        channels_first = x.transpose(1, 2)                 # (B, 3, N)
        per_point = self.mlp2(self.mlp1(channels_first))   # (B, 1024, N)
        global_feat = torch.max(per_point, dim=2).values   # (B, 1024)
        return self.classifier(global_feat)
|
|
|
|
|
|
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train the model for one pass over ``dataloader``.

    Args:
        model: Classifier returning (B, num_classes) logits.
        dataloader: Yields ``(points, labels)`` batches.
        criterion: Loss function; assumed to use mean reduction over the
            batch (as ``nn.CrossEntropyLoss`` does by default).
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        (avg_loss, accuracy) as floats. ``avg_loss`` is averaged per sample:
        each batch's loss is weighted by its size, so a short final batch is
        not over-weighted. ``accuracy`` is the fraction of correctly
        predicted samples (predictions taken before each optimizer step).

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.train()
    loss_sum = 0.0
    num_correct = 0
    num_samples = 0

    pbar = tqdm(dataloader, desc='Training')
    for points, labels in pbar:
        points = points.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(points)
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        batch_loss = loss.item()
        loss_sum += batch_loss * batch_size
        num_correct += (torch.argmax(logits, dim=1) == labels).sum().item()
        num_samples += batch_size

        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if num_samples == 0:
        raise ValueError("train_epoch: dataloader produced no samples")

    return loss_sum / num_samples, num_correct / num_samples
|
|
|
|
|
|
def validate(model, dataloader, criterion, device):
    """Evaluate the model on ``dataloader`` without updating weights.

    Args:
        model: Classifier returning (B, num_classes) logits.
        dataloader: Yields ``(points, labels)`` batches.
        criterion: Loss function; assumed to use mean reduction over the
            batch (as ``nn.CrossEntropyLoss`` does by default).
        device: Device the batches are moved to.

    Returns:
        (avg_loss, accuracy, precision, recall, f1) as floats. ``avg_loss``
        is averaged per sample (batch losses weighted by batch size).
        Precision/recall/F1 use sklearn with ``average='weighted'`` and
        ``zero_division=0``.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    num_samples = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Validation')
        for points, labels in pbar:
            points = points.to(device)
            labels = labels.to(device)

            logits = model(points)
            loss = criterion(logits, labels)

            batch_size = labels.size(0)
            batch_loss = loss.item()
            loss_sum += batch_loss * batch_size
            num_samples += batch_size
            all_preds.extend(torch.argmax(logits, dim=1).cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

            pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if num_samples == 0:
        raise ValueError("validate: dataloader produced no samples")

    avg_loss = loss_sum / num_samples
    # float() normalises NumPy scalars so callers get plain Python floats.
    accuracy = float(accuracy_score(all_labels, all_preds))
    precision = float(precision_score(all_labels, all_preds, average='weighted', zero_division=0))
    recall = float(recall_score(all_labels, all_preds, average='weighted', zero_division=0))
    f1 = float(f1_score(all_labels, all_preds, average='weighted', zero_division=0))

    return avg_loss, accuracy, precision, recall, f1
|
|
|
|
|
|
def _parse_args():
    """Define and parse the command-line interface."""
    parser = argparse.ArgumentParser(description='Train PointTransformer for point cloud classification')
    parser.add_argument('--data', type=str, required=True, help='Root directory of dataset (containing train/val/test)')
    parser.add_argument('--model', type=str, default='pointtransformer', choices=['pointtransformer', 'pointnet++'],
                        help='Model architecture')
    parser.add_argument('--num_points', type=int, default=1024, help='Number of points per point cloud')
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay')
    parser.add_argument('--output', type=str, default='checkpoints', help='Output directory for checkpoints')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to use (cuda/cpu)')
    parser.add_argument('--resume', type=str, default=None, help='Path to checkpoint to resume from')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of DataLoader worker processes')
    return parser.parse_args()


def _build_model(name, num_points):
    """Instantiate the requested architecture for binary (good/bad) classification."""
    if name == 'pointtransformer':
        return SimplePointTransformer(num_points=num_points, dim=64, num_classes=2)
    return PointNetPlusPlus(num_classes=2)  # 'pointnet++'


def _resume(path, model, optimizer, scheduler, device):
    """Restore model/optimizer/scheduler state from a checkpoint written by main().

    Returns:
        (start_epoch, best_val_acc, history), where ``history`` is None if the
        checkpoint does not contain one.
    """
    print(f"Resuming from {path}...")
    # map_location lets a checkpoint saved on GPU be resumed on CPU (and vice versa).
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']

    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    else:
        # No saved scheduler state: the optimizer already carries the current
        # LR, so only realign StepLR's epoch counter. Future decays then land
        # on the same epochs as in an uninterrupted run.
        scheduler.last_epoch = start_epoch

    return start_epoch, checkpoint.get('best_val_acc', 0.0), checkpoint.get('history')


def main():
    """Train the selected model, checkpointing every epoch.

    Writes to the output directory:
    - ``latest_checkpoint.pt`` every epoch;
    - ``best_checkpoint.pt`` whenever validation accuracy improves;
    - ``training_history.json`` every epoch.
    """
    args = _parse_args()

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device(args.device)
    print(f"Using device: {device}")

    print("\nLoading datasets...")
    train_dataset = PointCloudDataset(args.data, num_points=args.num_points, split='train', augment=True)
    val_dataset = PointCloudDataset(args.data, num_points=args.num_points, split='val', augment=False)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers)

    print(f"\nCreating {args.model} model...")
    model = _build_model(args.model, args.num_points).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=30, gamma=0.5)

    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [], 'val_precision': [], 'val_recall': [], 'val_f1': []
    }
    start_epoch = 0
    best_val_acc = 0.0
    if args.resume:
        start_epoch, best_val_acc, saved_history = _resume(args.resume, model, optimizer, scheduler, device)
        if saved_history:
            # Continue the saved curves instead of overwriting them.
            history = saved_history

    print("\nStarting training...")
    for epoch in range(start_epoch, args.epochs):
        print(f"\nEpoch {epoch + 1}/{args.epochs}")

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, val_precision, val_recall, val_f1 = validate(model, val_loader, criterion, device)

        scheduler.step()

        epoch_metrics = {
            'train_loss': train_loss, 'train_acc': train_acc,
            'val_loss': val_loss, 'val_acc': val_acc,
            'val_precision': val_precision, 'val_recall': val_recall, 'val_f1': val_f1,
        }
        for key, value in epoch_metrics.items():
            # float() turns NumPy scalars into plain floats for JSON/checkpoint storage.
            history.setdefault(key, []).append(float(value))

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
              f"Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1: {val_f1:.4f}")

        # Update the best score BEFORE building the checkpoint, so that
        # latest_checkpoint.pt also records it. Otherwise, resuming from
        # latest could lower the bar and overwrite a better best_checkpoint.pt.
        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = float(val_acc)

        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_acc': best_val_acc,
            'history': history
        }

        torch.save(checkpoint, output_dir / 'latest_checkpoint.pt')

        if is_best:
            torch.save(checkpoint, output_dir / 'best_checkpoint.pt')
            print(f"✓ Saved best model (val_acc: {val_acc:.4f})")

        with open(output_dir / 'training_history.json', 'w', encoding='utf-8') as f:
            json.dump(history, f, indent=2)

    print("\nTraining completed!")
    print(f"Best validation accuracy: {best_val_acc:.4f}")
    print(f"Checkpoints saved to: {output_dir}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()
|
|
|