Files
FishServer/FishMeasure/pointcloud_classifier/train_pointnet.py
2026-05-06 15:59:38 +08:00

574 lines
21 KiB
Python

#!/usr/bin/env python3
"""
Train PointNet++ for fish point cloud quality classification.
This script trains a PointNet++ model to classify fish point clouds
as "good" (high quality) or "bad" (low quality).
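Optionally initializes the network from pretrained PointNet++ weights (--pretrained); checkpoint tensors whose names or shapes do not match this model are skipped.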
Usage:
python train_pointnet.py \
--data dataset/ \
--num_points 1024 \
--epochs 100 \
--batch_size 32 \
--lr 0.001 \
--output checkpoints/ \
--pretrained path/to/pretrained_weights.pth
"""
import argparse
import os
import sys
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR
import open3d as o3d
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import json
from tqdm import tqdm
class PointCloudDataset(Dataset):
    """Labelled point clouds read from ``<data_dir>/<split>/{good,bad}/*.ply``.

    Each item is ``(points, label)``:
      * ``points`` -- float32 tensor [num_points, 3], centred on the mean,
        scaled so the farthest point lies on the unit sphere, and resampled to
        exactly ``num_points`` points (optionally augmented for training).
      * ``label`` -- 0-dim long tensor, 0 for 'bad' and 1 for 'good'.

    Files that fail to load are replaced by an all-zero cloud (best effort) so
    one corrupt file does not abort an epoch; the error is printed.
    """

    def __init__(self, data_dir, num_points=1024, split='train', augment=False):
        """
        Args:
            data_dir: Root directory containing train/val/test splits
            num_points: Number of points to sample from each point cloud
            split: 'train', 'val', or 'test'
            augment: Whether to apply data augmentation (only honoured for 'train')
        """
        self.data_dir = Path(data_dir)
        self.num_points = num_points
        self.split = split
        self.augment = augment and (split == 'train')

        # Sorted so the idx -> file mapping does not depend on filesystem
        # enumeration order (reproducible runs, stable error reports).
        self.good_files = sorted((self.data_dir / split / 'good').glob('*.ply'))
        self.bad_files = sorted((self.data_dir / split / 'bad').glob('*.ply'))

        # Labels: 0 for bad, 1 for good
        self.files = self.bad_files + self.good_files
        self.labels = [0] * len(self.bad_files) + [1] * len(self.good_files)

        print(f"Loaded {split} dataset:")
        print(f" Good samples: {len(self.good_files)}")
        print(f" Bad samples: {len(self.bad_files)}")
        print(f" Total: {len(self.files)}")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_path = self.files[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        try:
            pcd = o3d.io.read_point_cloud(str(file_path))
            points = np.asarray(pcd.points)

            if len(points) == 0:
                print(f"Warning: Empty point cloud {file_path}")
                points = np.zeros((self.num_points, 3), dtype=np.float32)

            # Centre, scale and resample to a fixed number of points
            points = self._normalize_points(points)

            if self.augment:
                points = self._augment_points(points)

            points = np.ascontiguousarray(points, dtype=np.float32)
            return torch.from_numpy(points), label
        except Exception as e:
            # Best effort: keep the epoch running with a zero cloud.
            print(f"Error loading {file_path}: {e}")
            return torch.zeros(self.num_points, 3, dtype=torch.float32), label

    def _normalize_points(self, points):
        """Centre, scale to the unit sphere and resample to ``self.num_points``.

        Args:
            points: array of shape [N, 3].

        Returns:
            float32 array of shape [num_points, 3]. Clouds with more points are
            randomly subsampled without replacement; sparser clouds keep every
            original point once and are padded with random repeats.
        """
        n_points = len(points)
        if n_points == 0:
            return np.zeros((self.num_points, 3), dtype=np.float32)

        points = points - points.mean(axis=0)
        max_dist = np.max(np.linalg.norm(points, axis=1))
        if max_dist > 0:
            points = points / max_dist

        if n_points >= self.num_points:
            indices = np.random.choice(n_points, self.num_points, replace=False)
        else:
            # Keep all originals so sparse clouds never lose information,
            # then shuffle so the ordering is random like the subsample branch.
            extra = np.random.choice(n_points, self.num_points - n_points, replace=True)
            indices = np.random.permutation(np.concatenate([np.arange(n_points), extra]))

        return points[indices].astype(np.float32)

    def _augment_points(self, points):
        """Apply random z-rotation, scaling and jitter (each with probability 0.5)."""
        # Random rotation around z-axis
        if np.random.rand() > 0.5:
            angle = np.random.uniform(0, 2 * np.pi)
            cos_a, sin_a = np.cos(angle), np.sin(angle)
            rotation = np.array([[cos_a, -sin_a, 0],
                                 [sin_a, cos_a, 0],
                                 [0, 0, 1]])
            points = points @ rotation.T

        # Random scaling
        if np.random.rand() > 0.5:
            scale = np.random.uniform(0.9, 1.1)
            points = points * scale

        # Random jitter
        if np.random.rand() > 0.5:
            jitter = np.random.normal(0, 0.01, points.shape)
            points = points + jitter

        return points
class PointNetSetAbstraction(nn.Module):
    """PointNet++ Set Abstraction layer (single-scale grouping).

    Samples ``npoint`` centroids by farthest point sampling, gathers up to
    ``nsample`` neighbours within ``radius`` of each centroid, runs a shared
    point-wise MLP (1x1 Conv2d + BatchNorm2d + ReLU) on every group and
    max-pools over the group members. With ``group_all=True`` the whole cloud
    is a single group, giving one global feature vector per sample.

    Tensors are channels-last at the module boundary, so layers can be chained:
    ``xyz`` is [B, N, 3], ``points`` is [B, N, C] or None, and ``forward``
    returns ``new_xyz`` [B, S, 3] and ``new_points`` [B, S, D'] (S == 1 for
    group_all). Grouped features are ordered [relative xyz, point features],
    the same order the group_all branch uses.
    NOTE(review): that xyz-first order also matches the common PyTorch
    PointNet++ port whose module names (mlp_convs / mlp_bns) this class uses;
    relevant if --pretrained weights come from it -- confirm against the
    checkpoint's source.
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all=False):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.group_all = group_all
        # Attribute names are state_dict keys; keep them stable for checkpoints.
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def forward(self, xyz, points):
        """
        Args:
            xyz: Input points position data, [B, N, 3]
            points: Input points data, [B, N, C], or None
        Returns:
            new_xyz: Sampled points position data, [B, S, 3]
                (a single all-zero origin, [B, 1, 3], when group_all)
            new_points: Sample points feature data, [B, S, D']
        """
        if self.group_all:
            new_xyz, grouped = self._group_all(xyz, points)
        else:
            new_xyz, grouped = self._sample_and_group(xyz, points)

        # [B, S, K, C_in] -> [B, C_in, K, S] so each 1x1 conv acts per point.
        grouped = grouped.permute(0, 3, 2, 1)
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            grouped = F.relu(bn(conv(grouped)))

        # Max-pool over the K group members, then back to channels-last.
        new_points = torch.max(grouped, 2)[0].permute(0, 2, 1)  # [B, S, D']
        return new_xyz, new_points

    def _sample_and_group(self, xyz, points):
        """Pick centroids by FPS and gather their ball neighbourhoods.

        Returns ``(new_xyz [B, S, 3], grouped [B, S, K, 3 + C])`` where the
        first 3 channels are neighbour coordinates relative to their centroid.
        """
        fps_idx = self._farthest_point_sample(xyz, self.npoint)  # [B, S]
        new_xyz = self._index_points(xyz, fps_idx)  # [B, S, 3]
        idx = self._ball_query(self.radius, self.nsample, xyz, new_xyz)  # [B, S, K]
        grouped_xyz = self._index_points(xyz, idx)  # [B, S, K, 3]
        grouped_xyz_norm = grouped_xyz - new_xyz.unsqueeze(2)
        if points is None:
            return new_xyz, grouped_xyz_norm
        grouped_points = self._index_points(points, idx)  # [B, S, K, C]
        return new_xyz, torch.cat([grouped_xyz_norm, grouped_points], dim=-1)

    @staticmethod
    def _group_all(xyz, points):
        """Treat the whole cloud as one group: ([B, 1, 3] zeros, [B, 1, N, 3 + C])."""
        B, _, C = xyz.shape
        new_xyz = torch.zeros(B, 1, C, dtype=xyz.dtype, device=xyz.device)
        grouped = xyz.unsqueeze(1)
        if points is not None:
            grouped = torch.cat([grouped, points.unsqueeze(1)], dim=-1)
        return new_xyz, grouped

    @staticmethod
    def _farthest_point_sample(xyz, npoint):
        """Farthest Point Sampling: return [B, npoint] indices into xyz [B, N, 3]."""
        device = xyz.device
        B, N, _ = xyz.shape
        centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
        distance = torch.full((B, N), 1e10, dtype=xyz.dtype, device=device)
        farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
        batch_indices = torch.arange(B, dtype=torch.long, device=device)
        for i in range(npoint):
            centroids[:, i] = farthest
            centroid = xyz[batch_indices, farthest, :].unsqueeze(1)  # [B, 1, 3]
            dist = torch.sum((xyz - centroid) ** 2, -1)  # [B, N]
            # Distance of every point to its nearest already-chosen centroid.
            distance = torch.minimum(distance, dist)
            farthest = torch.max(distance, -1)[1]
        return centroids

    def _ball_query(self, radius, nsample, xyz, new_xyz):
        """Ball Query: [B, S, nsample] indices of xyz points within radius of each centroid.

        Groups with fewer than ``nsample`` in-ball points (or clouds with fewer
        than ``nsample`` points) are filled by repeating the lowest-index
        in-ball point.
        """
        device = xyz.device
        B, N, _ = xyz.shape
        S = new_xyz.shape[1]
        group_idx = torch.arange(N, dtype=torch.long, device=device).view(1, 1, N).repeat(B, S, 1)
        sqrdists = self._square_distance(new_xyz, xyz)  # [B, S, N]
        # Out-of-ball points get the sentinel N, which sorts after all real indices.
        group_idx[sqrdists > radius ** 2] = N
        group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
        if group_idx.shape[-1] < nsample:
            pad = group_idx[:, :, :1].expand(-1, -1, nsample - group_idx.shape[-1])
            group_idx = torch.cat([group_idx, pad], dim=-1)
        group_first = group_idx[:, :, :1].expand_as(group_idx)
        return torch.where(group_idx == N, group_first, group_idx)

    @staticmethod
    def _square_distance(src, dst):
        """Pairwise squared euclidean distance: src [B, N, 3], dst [B, M, 3] -> [B, N, M]."""
        dist = -2 * torch.matmul(src, dst.transpose(1, 2))
        dist += torch.sum(src ** 2, -1).unsqueeze(-1)
        dist += torch.sum(dst ** 2, -1).unsqueeze(1)
        return dist

    @staticmethod
    def _index_points(points, idx):
        """Gather per batch: points [B, N, C], idx [B, ...] -> [B, ..., C]."""
        B = points.shape[0]
        view_shape = [B] + [1] * (idx.dim() - 1)
        batch_indices = torch.arange(B, dtype=torch.long, device=points.device).view(view_shape).expand_as(idx)
        return points[batch_indices, idx, :]
class PointNetPlusPlus(nn.Module):
    """PointNet++ (single-scale grouping) classifier.

    Three set-abstraction stages (512 -> 128 -> global) followed by a
    1024 -> 512 -> 256 -> num_classes MLP head with BatchNorm and dropout.
    ``forward`` returns raw logits, suitable for ``nn.CrossEntropyLoss``.
    """

    def __init__(self, num_classes=2, num_points=1024, use_pretrained=False, pretrained_path=None):
        super(PointNetPlusPlus, self).__init__()
        self.num_points = num_points

        # SA1: 512 points
        self.sa1 = PointNetSetAbstraction(
            npoint=512, radius=0.2, nsample=32,
            in_channel=3, mlp=[64, 64, 128]
        )
        # SA2: 128 points
        self.sa2 = PointNetSetAbstraction(
            npoint=128, radius=0.4, nsample=64,
            in_channel=128 + 3, mlp=[128, 128, 256]
        )
        # SA3: Global
        self.sa3 = PointNetSetAbstraction(
            npoint=None, radius=None, nsample=None,
            in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True
        )

        # Classification head
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, num_classes)

        # Load pretrained weights if provided
        if use_pretrained and pretrained_path:
            self._load_pretrained(pretrained_path)

    def _load_pretrained(self, pretrained_path):
        """Copy compatible tensors from a checkpoint into this model (best effort).

        Accepts a bare state_dict or a dict wrapping one under
        ``'model_state_dict'`` or ``'state_dict'``. A leading ``'module.'``
        (what ``nn.DataParallel`` prepends when a wrapped model is saved) is
        stripped from keys. Only tensors whose name and shape match this model
        are copied; the ``fc3`` classifier is always left freshly initialised.
        On any failure, or if nothing matches, training starts from scratch.
        """
        if not os.path.exists(pretrained_path):
            print(f"Warning: Pretrained weights not found at {pretrained_path}")
            print("Training from scratch...")
            return

        try:
            checkpoint = torch.load(pretrained_path, map_location='cpu')

            # Handle different checkpoint formats
            state_dict = checkpoint
            if isinstance(checkpoint, dict):
                if 'model_state_dict' in checkpoint:
                    state_dict = checkpoint['model_state_dict']
                elif 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']

            prefix = 'module.'
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }

            # Keep only tensors this model has with an identical shape, and
            # exclude the classification head for transfer learning.
            model_dict = self.state_dict()
            pretrained_dict = {
                k: v for k, v in state_dict.items()
                if isinstance(v, torch.Tensor)
                and k in model_dict
                and model_dict[k].shape == v.shape
                and not k.startswith('fc3')
            }

            if not pretrained_dict:
                print(f"Warning: No compatible weights found in {pretrained_path}")
                print("Training from scratch...")
                return

            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict, strict=False)

            print(f"✓ Loaded pretrained weights from {pretrained_path}")
            print(f" Loaded {len(pretrained_dict)} layers")
            skipped = len(state_dict) - len(pretrained_dict)
            if skipped:
                print(f" Skipped {skipped} checkpoint entries (name/shape mismatch or fc3)")
        except Exception as e:
            print(f"Warning: Failed to load pretrained weights: {e}")
            print("Training from scratch...")

    def forward(self, xyz):
        """
        Args:
            xyz: Point cloud [B, N, 3]
        Returns:
            logits: Classification logits [B, num_classes]
        """
        batch_size = xyz.shape[0]

        # Set Abstraction layers
        l1_xyz, l1_points = self.sa1(xyz, None)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        _, l3_points = self.sa3(l2_xyz, l2_points)

        # Global feature; reshape (not view) tolerates non-contiguous SA output.
        x = l3_points.reshape(batch_size, -1)

        # Classification head
        x = self.drop1(F.relu(self.bn1(self.fc1(x))))
        x = self.drop2(F.relu(self.bn2(self.fc2(x))))
        x = self.fc3(x)
        return x
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: network returning logits for a batch of point clouds.
        dataloader: yields ``(points, labels)`` batches.
        criterion: loss function; assumed to return the batch mean (the
            ``nn.CrossEntropyLoss`` default reduction).
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the batches are moved to.

    Returns:
        (avg_loss, accuracy): loss averaged per sample, so a short final batch
        is not over-weighted, and accuracy over all predictions. Both are NaN
        when the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    all_preds = []
    all_labels = []

    pbar = tqdm(dataloader, desc='Training')
    for points, labels in pbar:
        points = points.to(device)
        labels = labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        logits = model(points)
        loss = criterion(logits, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics: undo the per-batch mean to accumulate a per-sample sum.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size
        all_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    if num_samples == 0:
        return float('nan'), float('nan')

    avg_loss = total_loss / num_samples
    accuracy = accuracy_score(all_labels, all_preds)
    return avg_loss, accuracy
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` in eval mode without gradients.

    Args:
        model: network returning logits for a batch of point clouds.
        dataloader: yields ``(points, labels)`` batches.
        criterion: loss function; assumed to return the batch mean (the
            ``nn.CrossEntropyLoss`` default reduction).
        device: device the batches are moved to.

    Returns:
        (avg_loss, accuracy, precision, recall, f1). The loss is averaged per
        sample, so a short final batch is not over-weighted. Precision, recall
        and F1 are support-weighted over classes, with 0 where a score is
        undefined. All five are NaN when the loader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Validation')
        for points, labels in pbar:
            points = points.to(device)
            labels = labels.to(device)

            # Forward pass
            logits = model(points)
            loss = criterion(logits, labels)

            # Statistics: undo the per-batch mean to accumulate a per-sample sum.
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size
            all_preds.extend(torch.argmax(logits, dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    if num_samples == 0:
        nan = float('nan')
        return nan, nan, nan, nan, nan

    avg_loss = total_loss / num_samples
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    return avg_loss, accuracy, precision, recall, f1
def _build_arg_parser():
    """Return the command-line parser for this training script."""
    parser = argparse.ArgumentParser(description='Train PointNet++ for point cloud classification')
    parser.add_argument('--data', type=str, required=True,
                        help='Root directory of dataset (containing train/val/test)')
    parser.add_argument('--num_points', type=int, default=1024,
                        help='Number of points per point cloud')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='Weight decay')
    parser.add_argument('--output', type=str, default='checkpoints',
                        help='Output directory for checkpoints')
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to use (cuda/cpu)')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')
    parser.add_argument('--pretrained', type=str, default=None,
                        help='Path to pretrained PointNet++ weights (.pth file)')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='Number of DataLoader worker processes')
    return parser


def main():
    """Parse CLI arguments, train the classifier and write checkpoints/history.

    Each epoch writes ``latest_checkpoint.pt``; ``best_checkpoint.pt`` is
    rewritten whenever validation accuracy improves. Both contain model,
    optimizer and LR-scheduler state plus the metric history, so ``--resume``
    continues the run (including the LR schedule and the history JSON).
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    # BatchNorm layers cannot normalise a single-sample batch in training mode.
    if args.batch_size < 2:
        parser.error('--batch_size must be at least 2 (BatchNorm needs more than one sample per training batch)')

    # Create output directory
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Device
    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Create datasets
    print("\nLoading datasets...")
    train_dataset = PointCloudDataset(args.data, num_points=args.num_points, split='train', augment=True)
    val_dataset = PointCloudDataset(args.data, num_points=args.num_points, split='val', augment=False)

    # Fail early with a clear message rather than inside DataLoader or metrics.
    if len(train_dataset) < 2:
        parser.error(f"need at least 2 training samples under {Path(args.data) / 'train'}/{{good,bad}}, "
                     f"found {len(train_dataset)}")
    if len(val_dataset) == 0:
        parser.error(f"no validation samples found under {Path(args.data) / 'val'}/{{good,bad}}")

    # A final training batch of exactly one sample would crash BatchNorm1d in
    # the head; drop it only in that case so no other data is discarded.
    drop_last = len(train_dataset) % args.batch_size == 1
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, drop_last=drop_last)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.num_workers)

    # Create model
    print("\nCreating PointNet++ model...")
    use_pretrained = args.pretrained is not None
    model = PointNetPlusPlus(
        num_classes=2,
        num_points=args.num_points,
        use_pretrained=use_pretrained,
        pretrained_path=args.pretrained
    )
    model = model.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer, step_size=30, gamma=0.5)

    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [], 'val_precision': [], 'val_recall': [], 'val_f1': []
    }

    # Resume from checkpoint
    start_epoch = 0
    best_val_acc = 0.0
    if args.resume:
        print(f"Resuming from {args.resume}...")
        checkpoint = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        best_val_acc = checkpoint.get('best_val_acc', 0.0)
        # Keep earlier epochs so training_history.json stays complete.
        history = checkpoint.get('history', history)
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        else:
            # Older checkpoints: the optimizer already carries the decayed LR;
            # align StepLR's counter so future decays land on the same epochs.
            scheduler.last_epoch = start_epoch

    # Training loop
    print("\nStarting training...")
    for epoch in range(start_epoch, args.epochs):
        print(f"\nEpoch {epoch + 1}/{args.epochs}")

        # Train
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

        # Validate
        val_loss, val_acc, val_precision, val_recall, val_f1 = validate(model, val_loader, criterion, device)

        # Update learning rate
        scheduler.step()

        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_precision'].append(val_precision)
        history['val_recall'].append(val_recall)
        history['val_f1'].append(val_f1)

        # Print metrics
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
              f"Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1: {val_f1:.4f}")

        # Update the best score BEFORE saving so latest_checkpoint.pt records
        # the true best; otherwise resuming could overwrite a better best model.
        improved = val_acc > best_val_acc
        if improved:
            best_val_acc = val_acc

        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_acc': best_val_acc,
            'history': history
        }

        # Save latest
        torch.save(checkpoint, output_dir / 'latest_checkpoint.pt')

        # Save best
        if improved:
            torch.save(checkpoint, output_dir / 'best_checkpoint.pt')
            print(f"✓ Saved best model (val_acc: {val_acc:.4f})")

        # Save training history
        with open(output_dir / 'training_history.json', 'w') as f:
            json.dump(history, f, indent=2)

    print(f"\nTraining completed!")
    print(f"Best validation accuracy: {best_val_acc:.4f}")
    print(f"Checkpoints saved to: {output_dir}")
if __name__ == '__main__':
    # Run training only when executed as a script, so the dataset, model and
    # training helpers can be imported elsewhere without side effects.
    main()