574 lines
21 KiB
Python
574 lines
21 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Train PointNet++ for fish point cloud quality classification.
|
|
|
|
This script trains a PointNet++ model to classify fish point clouds
|
|
as "good" (high quality) or "bad" (low quality).
|
|
|
|
Supports loading pretrained weights if available.
|
|
|
|
Usage:
|
|
python train_pointnet.py \
|
|
--data dataset/ \
|
|
--num_points 1024 \
|
|
--epochs 100 \
|
|
--batch_size 32 \
|
|
--lr 0.001 \
|
|
--output checkpoints/ \
|
|
--pretrained path/to/pretrained_weights.pth
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.optim as optim
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from torch.optim.lr_scheduler import StepLR
|
|
import open3d as o3d
|
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
|
import json
|
|
from tqdm import tqdm
|
|
|
|
|
|
class PointCloudDataset(Dataset):
    """Dataset of PLY point clouds labelled by their parent directory.

    Expected layout (relative to ``data_dir``)::

        <split>/good/*.ply   -> label 1
        <split>/bad/*.ply    -> label 0

    Each item is a ``(points, label)`` pair: ``points`` is a float32 tensor of
    shape ``[num_points, 3]`` (centred, scaled into the unit sphere and
    resampled to a fixed size) and ``label`` is a 0-d int64 tensor.
    """

    def __init__(self, data_dir, num_points=1024, split='train', augment=False):
        """
        Args:
            data_dir: Root directory containing train/val/test splits
            num_points: Number of points to sample from each point cloud
            split: 'train', 'val', or 'test'
            augment: Whether to apply data augmentation (only honoured for
                the 'train' split)
        """
        self.data_dir = Path(data_dir)
        self.num_points = num_points
        self.split = split
        self.augment = augment and (split == 'train')

        # Sort so the index -> file mapping (and evaluation order) does not
        # depend on the filesystem's directory-listing order.
        self.good_files = sorted((self.data_dir / split / 'good').glob('*.ply'))
        self.bad_files = sorted((self.data_dir / split / 'bad').glob('*.ply'))

        # Create labels: 0 for bad, 1 for good
        self.files = self.bad_files + self.good_files
        self.labels = [0] * len(self.bad_files) + [1] * len(self.good_files)

        print(f"Loaded {split} dataset:")
        print(f" Good samples: {len(self.good_files)}")
        print(f" Bad samples: {len(self.bad_files)}")
        print(f" Total: {len(self.files)}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load, normalize and (optionally) augment one point cloud."""
        file_path = self.files[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        # Only the file I/O is guarded: an unreadable file should not abort a
        # whole epoch, but bugs in the processing below must not be masked.
        try:
            pcd = o3d.io.read_point_cloud(str(file_path))
            points = np.asarray(pcd.points)
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            # Return zero point cloud as fallback
            return torch.zeros(self.num_points, 3, dtype=torch.float32), label

        # Handle empty point clouds
        if len(points) == 0:
            print(f"Warning: Empty point cloud {file_path}")
            points = np.zeros((self.num_points, 3), dtype=np.float32)

        # Centre, scale and sample/pad to a fixed number of points
        points = self._normalize_points(points)

        # Data augmentation (only for training)
        if self.augment:
            points = self._augment_points(points)

        return torch.from_numpy(np.ascontiguousarray(points, dtype=np.float32)), label

    def _normalize_points(self, points):
        """Centre, scale into the unit sphere and resample to ``num_points``.

        Args:
            points: Non-empty array of shape [n, 3].

        Returns:
            float32 array of shape [num_points, 3]. Points are drawn without
            replacement when n >= num_points, otherwise with replacement.
        """
        # Center the point cloud
        points = points - points.mean(axis=0)

        # Scale to unit sphere (skip degenerate all-identical clouds)
        max_dist = np.max(np.linalg.norm(points, axis=1))
        if max_dist > 0:
            points = points / max_dist

        n_points = len(points)
        replace = n_points < self.num_points
        indices = np.random.choice(n_points, self.num_points, replace=replace)
        return points[indices].astype(np.float32)

    def _augment_points(self, points):
        """Apply random z-rotation, scaling and jitter (each with p=0.5).

        Returns:
            float32 array with the same shape as ``points``.
        """
        # Random rotation around z-axis
        if np.random.rand() > 0.5:
            angle = np.random.uniform(0, 2 * np.pi)
            cos_a, sin_a = np.cos(angle), np.sin(angle)
            rotation = np.array([[cos_a, -sin_a, 0],
                                 [sin_a, cos_a, 0],
                                 [0, 0, 1]])
            points = points @ rotation.T

        # Random scaling
        if np.random.rand() > 0.5:
            scale = np.random.uniform(0.9, 1.1)
            points = points * scale

        # Random jitter
        if np.random.rand() > 0.5:
            jitter = np.random.normal(0, 0.01, points.shape)
            points = points + jitter

        # The float64 rotation/jitter terms would otherwise upcast the array.
        return points.astype(np.float32)
|
|
|
|
|
|
class PointNetSetAbstraction(nn.Module):
    """PointNet++ Set Abstraction layer (single-scale grouping).

    ``forward`` takes and returns point-major tensors (positions ``[B, N, 3]``,
    features ``[B, N, C]``), so layers can be chained directly. Internally all
    tensors are channel-first (``[B, C, ...]``) so the shared MLP can be
    expressed as 1x1 ``Conv2d`` + ``BatchNorm2d`` + ReLU.
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all=False):
        """
        Args:
            npoint: Number of centroids chosen by farthest point sampling
                (unused when ``group_all`` is True).
            radius: Ball-query radius, in the same units as ``xyz``.
            nsample: Maximum number of neighbours gathered per centroid.
            in_channel: Channels entering the first MLP layer: 3 coordinate
                channels plus the input feature channels (if any).
            mlp: Output channel count of each shared MLP layer.
            group_all: If True, treat the whole cloud as a single group.
        """
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.group_all = group_all
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def forward(self, xyz, points):
        """
        Args:
            xyz: Input point positions, [B, N, 3].
            points: Input point features, [B, N, C], or None.

        Returns:
            new_xyz: Sampled centroid positions, [B, S, 3]. For ``group_all``
                this is a single all-zero centroid, [B, 1, 3].
            new_points: Max-pooled feature per centroid, [B, S, D'] with
                D' == mlp[-1] (S == 1 for ``group_all``).
        """
        xyz_cf = xyz.permute(0, 2, 1)  # [B, 3, N]
        feats_cf = points.permute(0, 2, 1) if points is not None else None  # [B, C, N]

        if self.group_all:
            batch_size = xyz.shape[0]
            # One group containing every point, using absolute coordinates
            # (the implicit centroid is the origin).
            new_xyz = torch.zeros(batch_size, 1, 3, dtype=xyz.dtype, device=xyz.device)
            grouped = xyz_cf if feats_cf is None else torch.cat([xyz_cf, feats_cf], dim=1)
            grouped = grouped.unsqueeze(2)  # [B, 3+C, 1, N]
        else:
            # Farthest Point Sampling -> centroid coordinates
            centroids_cf = self._farthest_point_sample(xyz_cf, self.npoint)  # [B, 3, S]

            # Ball Query -> neighbour indices
            idx = self._ball_query(self.radius, self.nsample, xyz_cf, centroids_cf)  # [B, S, K]

            # Group neighbours and express them relative to their centroid
            grouped_xyz = self._index_points(xyz_cf, idx)  # [B, 3, S, K]
            grouped_xyz_norm = grouped_xyz - centroids_cf.unsqueeze(-1)  # [B, 3, S, K]

            if feats_cf is not None:
                grouped_feats = self._index_points(feats_cf, idx)  # [B, C, S, K]
                # Coordinates first, then features: the channel order used by
                # the common PyTorch PointNet++ port, so that pretrained conv
                # weights line up channel-for-channel.
                grouped = torch.cat([grouped_xyz_norm, grouped_feats], dim=1)  # [B, 3+C, S, K]
            else:
                grouped = grouped_xyz_norm

            new_xyz = centroids_cf.permute(0, 2, 1)  # [B, S, 3]

        # Shared MLP
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            grouped = F.relu(bn(conv(grouped)))

        # Max-pool over the members of each group: [B, D', S]
        new_points = torch.max(grouped, 3)[0]
        return new_xyz, new_points.permute(0, 2, 1)  # [B, S, D']

    def _farthest_point_sample(self, xyz, npoint):
        """Pick ``npoint`` well-spread points by iterative farthest point sampling.

        Args:
            xyz: Point positions, [B, 3, N].
            npoint: Number of points to select.

        Returns:
            Coordinates of the selected points, [B, 3, npoint].
        """
        device = xyz.device
        B, C, N = xyz.shape
        centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
        distance = torch.full((B, N), 1e10, dtype=xyz.dtype, device=device)
        farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
        batch_indices = torch.arange(B, dtype=torch.long, device=device)

        for i in range(npoint):
            centroids[:, i] = farthest
            # Pair each cloud with its own current centroid index -> [B, 3, 1]
            centroid = xyz[batch_indices, :, farthest].view(B, C, 1)
            dist = torch.sum((xyz - centroid) ** 2, 1)
            mask = dist < distance
            distance[mask] = dist[mask]
            farthest = torch.max(distance, -1)[1]

        # Per-cloud gather of the selected coordinates -> [B, 3, npoint]
        return torch.gather(xyz, 2, centroids.unsqueeze(1).expand(-1, C, -1))

    def _ball_query(self, radius, nsample, xyz, new_xyz):
        """Find up to ``nsample`` neighbours within ``radius`` of each centroid.

        Args:
            xyz: All point positions, [B, 3, N].
            new_xyz: Centroid positions, [B, 3, S].

        Returns:
            Neighbour indices into the N axis, [B, S, nsample]. Missing slots
            are filled with the first in-radius index of that centroid.
        """
        device = xyz.device
        B, C, N = xyz.shape
        _, _, S = new_xyz.shape
        group_idx = torch.arange(N, dtype=torch.long, device=device).view(1, 1, N).repeat([B, S, 1])
        sqrdists = self._square_distance(new_xyz, xyz)  # [B, S, N]
        # Out-of-radius points get the sentinel N, which sorts last.
        group_idx[sqrdists > radius ** 2] = N
        group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
        group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
        mask = group_idx == N
        group_idx[mask] = group_first[mask]
        return group_idx

    def _square_distance(self, src, dst):
        """Pairwise squared euclidean distance: [B, C, N] x [B, C, M] -> [B, N, M]."""
        B, C, N = src.shape
        _, _, M = dst.shape
        dist = -2 * torch.matmul(src.permute(0, 2, 1), dst)
        dist += torch.sum(src ** 2, dim=1).view(B, N, 1)
        dist += torch.sum(dst ** 2, dim=1).view(B, 1, M)
        return dist

    def _index_points(self, points, idx):
        """Gather per-cloud neighbours.

        Args:
            points: Per-point data, [B, C, N].
            idx: Indices into the N axis, [B, S, K].

        Returns:
            Gathered data, [B, C, S, K].
        """
        B, C, _ = points.shape
        _, S, K = idx.shape
        flat_idx = idx.reshape(B, 1, S * K).expand(-1, C, -1)
        return torch.gather(points, 2, flat_idx).reshape(B, C, S, K)
|
|
|
|
|
|
class PointNetPlusPlus(nn.Module):
    """PointNet++ (single-scale grouping) classifier.

    Three set-abstraction stages (512 -> 128 -> 1 global group) feed a
    1024-d global feature into a 3-layer fully connected head.
    """

    def __init__(self, num_classes=2, num_points=1024, use_pretrained=False, pretrained_path=None):
        """
        Args:
            num_classes: Number of output logits.
            num_points: Points per input cloud (stored; not used by the layers).
            use_pretrained: Load weights from ``pretrained_path`` if True.
            pretrained_path: Checkpoint file to load backbone weights from.
        """
        super(PointNetPlusPlus, self).__init__()
        self.num_points = num_points

        # SA1: 512 points
        self.sa1 = PointNetSetAbstraction(
            npoint=512, radius=0.2, nsample=32,
            in_channel=3, mlp=[64, 64, 128]
        )

        # SA2: 128 points
        self.sa2 = PointNetSetAbstraction(
            npoint=128, radius=0.4, nsample=64,
            in_channel=128 + 3, mlp=[128, 128, 256]
        )

        # SA3: Global
        self.sa3 = PointNetSetAbstraction(
            npoint=None, radius=None, nsample=None,
            in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True
        )

        # Classification head
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, num_classes)

        # Load pretrained weights if provided
        if use_pretrained and pretrained_path:
            self._load_pretrained(pretrained_path)

    def _load_pretrained(self, pretrained_path):
        """Copy compatible weights from a checkpoint into this model.

        Only entries whose name and shape match this model are copied, and the
        final classifier layer (``fc3``) is always skipped so the head is
        trained from scratch. Any failure is reported and the model keeps its
        random initialization.
        """
        if not os.path.exists(pretrained_path):
            print(f"Warning: Pretrained weights not found at {pretrained_path}")
            print("Training from scratch...")
            return

        try:
            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(pretrained_path, map_location='cpu')

            # Handle different checkpoint formats
            if isinstance(checkpoint, nn.Module):
                state_dict = checkpoint.state_dict()
            elif isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint

            # Checkpoints saved from nn.DataParallel prefix every key with
            # 'module.'; strip it so names can match this model's.
            prefix = 'module.'
            state_dict = {
                (k[len(prefix):] if isinstance(k, str) and k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }

            # Keep shape-compatible entries and exclude the classification
            # head for transfer learning.
            model_dict = self.state_dict()
            pretrained_dict = {
                k: v for k, v in state_dict.items()
                if k in model_dict
                and model_dict[k].shape == v.shape
                and not k.startswith('fc3')
            }

            if not pretrained_dict:
                # Do not report success when nothing was actually transferred.
                print(f"Warning: No compatible weights found in {pretrained_path}")
                print("Training from scratch...")
                return

            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict, strict=False)

            print(f"✓ Loaded pretrained weights from {pretrained_path}")
            print(f" Loaded {len(pretrained_dict)} layers")

        except Exception as e:
            print(f"Warning: Failed to load pretrained weights: {e}")
            print("Training from scratch...")

    def forward(self, xyz):
        """
        Args:
            xyz: Point cloud [B, N, 3]
        Returns:
            logits: Classification logits [B, num_classes]
        """
        batch_size = xyz.shape[0]

        # Set Abstraction layers
        l1_xyz, l1_points = self.sa1(xyz, None)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        _, l3_points = self.sa3(l2_xyz, l2_points)

        # Global feature; reshape tolerates a non-contiguous layer output.
        x = l3_points.reshape(batch_size, 1024)

        # Classification head
        x = self.drop1(F.relu(self.bn1(self.fc1(x))))
        x = self.drop2(F.relu(self.bn2(self.fc2(x))))
        return self.fc3(x)
|
|
|
|
|
|
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Network to train (switched to train mode).
        dataloader: Yields ``(points, labels)`` batches.
        criterion: Loss function; assumed to return the batch *mean* (as
            ``nn.CrossEntropyLoss`` does by default).
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        (avg_loss, accuracy): loss averaged over samples (not batches) and
        accuracy over all predictions of the epoch.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    all_preds = []
    all_labels = []

    pbar = tqdm(dataloader, desc='Training')
    for points, labels in pbar:
        points = points.to(device)
        labels = labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        logits = model(points)
        loss = criterion(logits, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics: weight the batch mean by batch size so a smaller
        # trailing batch does not skew the epoch average.
        batch_size = labels.size(0)
        batch_loss = loss.item()
        total_loss += batch_loss * batch_size
        num_samples += batch_size
        all_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        # Update progress bar
        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if num_samples == 0:
        raise ValueError("train_epoch: dataloader yielded no samples")

    avg_loss = total_loss / num_samples
    accuracy = accuracy_score(all_labels, all_preds)

    return avg_loss, accuracy
|
|
|
|
|
|
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating weights.

    Args:
        model: Network to evaluate (switched to eval mode).
        dataloader: Yields ``(points, labels)`` batches.
        criterion: Loss function; assumed to return the batch *mean* (as
            ``nn.CrossEntropyLoss`` does by default).
        device: Device the batches are moved to.

    Returns:
        (avg_loss, accuracy, precision, recall, f1): loss averaged over
        samples; precision/recall/F1 are sklearn support-weighted averages
        across classes (``average='weighted'``, ``zero_division=0``).

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Validation')
        for points, labels in pbar:
            points = points.to(device)
            labels = labels.to(device)

            # Forward pass
            logits = model(points)
            loss = criterion(logits, labels)

            # Statistics, weighted by batch size (see train_epoch)
            batch_size = labels.size(0)
            batch_loss = loss.item()
            total_loss += batch_loss * batch_size
            num_samples += batch_size
            all_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if num_samples == 0:
        raise ValueError("validate: dataloader yielded no samples")

    avg_loss = total_loss / num_samples
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    return avg_loss, accuracy, precision, recall, f1
|
|
|
|
|
|
def main():
    """Parse CLI arguments, train PointNet++ and write checkpoints/history.

    Writes ``latest_checkpoint.pt`` every epoch, ``best_checkpoint.pt``
    whenever validation accuracy improves, and ``training_history.json`` once
    training finishes.
    """
    parser = argparse.ArgumentParser(description='Train PointNet++ for point cloud classification')
    parser.add_argument('--data', type=str, required=True,
                        help='Root directory of dataset (containing train/val/test)')
    parser.add_argument('--num_points', type=int, default=1024,
                        help='Number of points per point cloud')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='Weight decay')
    parser.add_argument('--output', type=str, default='checkpoints',
                        help='Output directory for checkpoints')
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to use (cuda/cpu)')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')
    parser.add_argument('--pretrained', type=str, default=None,
                        help='Path to pretrained PointNet++ weights (.pth file)')

    args = parser.parse_args()

    # Create output directory
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Device
    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Create datasets
    print("\nLoading datasets...")
    train_dataset = PointCloudDataset(args.data, num_points=args.num_points, split='train', augment=True)
    val_dataset = PointCloudDataset(args.data, num_points=args.num_points, split='val', augment=False)

    # The BatchNorm1d layers in the classifier head raise in training mode on
    # a batch holding a single sample, so drop a trailing batch of size 1.
    drop_last = (len(train_dataset) > args.batch_size
                 and len(train_dataset) % args.batch_size == 1)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=4, drop_last=drop_last)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # Create model
    print("\nCreating PointNet++ model...")
    use_pretrained = args.pretrained is not None
    model = PointNetPlusPlus(
        num_classes=2,
        num_points=args.num_points,
        use_pretrained=use_pretrained,
        pretrained_path=args.pretrained
    )
    model = model.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=30, gamma=0.5)

    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [], 'val_precision': [], 'val_recall': [], 'val_f1': []
    }

    # Resume from checkpoint
    start_epoch = 0
    best_val_acc = 0.0
    if args.resume:
        print(f"Resuming from {args.resume}...")
        # NOTE(review): torch.load unpickles; only resume from trusted files.
        checkpoint = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        best_val_acc = checkpoint.get('best_val_acc', 0.0)

        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        else:
            # Older checkpoints carry no scheduler state. The restored
            # optimizer already holds the current LR; realign the scheduler's
            # epoch counter so the next decay lands on the right global epoch.
            scheduler.last_epoch = start_epoch

        # Continue the previous history instead of overwriting it.
        saved_history = checkpoint.get('history') or {}
        for key in history:
            history[key] = list(saved_history.get(key, []))

    # Training loop
    print("\nStarting training...")
    for epoch in range(start_epoch, args.epochs):
        print(f"\nEpoch {epoch + 1}/{args.epochs}")

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc, val_precision, val_recall, val_f1 = validate(model, val_loader, criterion, device)

        # Update learning rate
        scheduler.step()

        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_precision'].append(val_precision)
        history['val_recall'].append(val_recall)
        history['val_f1'].append(val_f1)

        # Print metrics
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
              f"Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1: {val_f1:.4f}")

        # Update the best score *before* building the checkpoint so that
        # latest_checkpoint.pt never records a stale best accuracy.
        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc

        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_acc': best_val_acc,
            'history': history
        }

        # Save latest
        torch.save(checkpoint, output_dir / 'latest_checkpoint.pt')

        # Save best
        if is_best:
            torch.save(checkpoint, output_dir / 'best_checkpoint.pt')
            print(f"✓ Saved best model (val_acc: {val_acc:.4f})")

    # Save training history
    with open(output_dir / 'training_history.json', 'w', encoding='utf-8') as f:
        json.dump(history, f, indent=2)

    print("\nTraining completed!")
    print(f"Best validation accuracy: {best_val_acc:.4f}")
    print(f"Checkpoints saved to: {output_dir}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()
|
|
|