Files
FishServer/FishMeasure/pointcloud_classifier/Pointnet_Pointnet2_pytorch/train_classification.py
2026-04-08 19:32:23 +08:00

320 lines
13 KiB
Python
Executable File

"""
Author: Benny
Date: Nov 2019
Modified for Fish Point Cloud Quality Classification
"""
import os
import sys
import torch
import numpy as np
import datetime
import logging
import provider
import importlib
import shutil
import argparse
from pathlib import Path
from tqdm import tqdm
from data_utils.ModelNetDataLoader import ModelNetDataLoader
from data_utils.FishPointCloudDataLoader import FishPointCloudDataLoader
# Resolve the directory holding this script and put its ``models`` folder on
# the import path, so that ``importlib.import_module(args.model)`` in main()
# can locate the model definition files.
ROOT_DIR = BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(ROOT_DIR, 'models'))
def parse_args(argv=None):
    """Parse the command-line options for classification training.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default, and what the script itself passes) means
            ``sys.argv[1:]``, exactly as before. Passing an explicit list lets
            other code and tests build a configuration without touching the
            real command line.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser('training')
    parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode')
    parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
    parser.add_argument('--batch_size', type=int, default=24, help='batch size in training')
    parser.add_argument('--model', default='pointnet_cls', help='model name [default: pointnet_cls]')
    parser.add_argument('--num_category', default=2, type=int, help='number of classes (2 for good/bad)')
    parser.add_argument('--epoch', default=200, type=int, help='number of epoch in training')
    parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training')
    parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
    parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training')
    parser.add_argument('--log_dir', type=str, default=None, help='experiment root')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='decay rate')
    parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
    parser.add_argument('--process_data', action='store_true', default=False, help='save data offline')
    parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampling')
    parser.add_argument('--data_path', type=str, default=None, help='path to dataset (for fish point cloud)')
    parser.add_argument('--use_fish_dataset', action='store_true', default=False, help='use fish point cloud dataset instead of ModelNet')
    parser.add_argument('--pretrained', type=str, default=None, help='path to pretrained model checkpoint (e.g., log/classification/pointnet2_cls_ssg/checkpoints/best_model.pth)')
    # argparse treats None as "use sys.argv[1:]", so script behaviour is unchanged.
    return parser.parse_args(argv)
def inplace_relu(m):
    """Put ReLU-family modules into in-place mode; intended for ``Module.apply``.

    Any module whose class name contains ``'ReLU'`` (so ``ReLU`` and
    ``LeakyReLU`` alike) gets ``inplace = True``; every other module is left
    untouched. Presumably used to cut activation memory during training.
    """
    if 'ReLU' not in m.__class__.__name__:
        return
    m.inplace = True
def test(model, loader, num_class=2):
    """Evaluate a classifier and return ``(instance_acc, class_acc)``.

    The model is put in eval mode and called as ``pred, _ = model(points)``,
    where each batch of ``points`` is first transposed with
    ``transpose(2, 1)``. The predicted label is ``pred.argmax(dim=1)``.
    Batches are moved to the device of the model's first parameter (CPU if it
    has none), so this function no longer depends on a module-level ``args``.

    Both metrics are computed from totals over the whole loader, not by
    averaging per-batch ratios, so a short final batch is not over-weighted:

    * ``instance_acc`` = correctly classified samples / all samples.
    * ``class_acc`` = mean per-class accuracy over the labels in
      ``[0, num_class)`` that actually occur in ``loader``. Classes that never
      occur are excluded instead of counting as 0 %.

    Args:
        model: ``torch.nn.Module`` returning a ``(pred, extra)`` tuple.
        loader: Iterable of ``(points, target)`` batches supporting ``len()``.
        num_class: Number of classes tracked for ``class_acc``. Labels outside
            ``[0, num_class)`` still count toward ``instance_acc``.

    Returns:
        ``(instance_acc, class_acc)`` as Python floats. Returns ``(0.0, 0.0)``
        when the loader yields no samples.
    """
    classifier = model.eval()
    first_param = next(classifier.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    total_correct = 0
    total_seen = 0
    class_correct = np.zeros(num_class)
    class_seen = np.zeros(num_class)

    # Evaluation never needs gradients; harmless if the caller already disabled them.
    with torch.no_grad():
        for points, target in tqdm(loader, total=len(loader)):
            points = points.to(device).transpose(2, 1)
            # Flatten so (B, 1) labels compare element-wise with (B,) predictions.
            target = target.to(device).long().reshape(-1)
            pred, _ = classifier(points)
            hits = pred.argmax(dim=1).eq(target)

            total_correct += int(hits.sum().item())
            total_seen += target.numel()

            for cat in torch.unique(target).tolist():
                if 0 <= cat < num_class:
                    mask = target == cat
                    class_correct[cat] += hits[mask].sum().item()
                    class_seen[cat] += mask.sum().item()

    if total_seen == 0:
        return 0.0, 0.0
    instance_acc = total_correct / total_seen
    present = class_seen > 0
    if present.any():
        class_acc = float(np.mean(class_correct[present] / class_seen[present]))
    else:
        class_acc = 0.0
    return instance_acc, class_acc
def _resolve_pretrained_path(pretrained, log_string):
    """Return the checkpoint ``Path`` named by ``--pretrained``, or ``None``.

    A relative path is looked up first under ``./log/classification`` and then
    relative to the current directory. A warning is logged when nothing exists
    at the candidate location(s).
    """
    candidate = Path(pretrained)
    if candidate.is_absolute():
        if candidate.exists():
            return candidate
        log_string(f'Warning: Pretrained model not found at {candidate}')
        return None
    under_log = Path('./log/classification') / candidate
    if under_log.exists():
        return under_log
    if candidate.exists():
        return candidate
    log_string(f'Warning: Pretrained model not found at {pretrained}')
    return None


def _load_pretrained_weights(classifier, pretrained_path, log_string):
    """Copy every shape-compatible tensor from a checkpoint into ``classifier``.

    Accepts ``{'model_state_dict': ...}``, ``{'state_dict': ...}`` or a bare
    state dict. Keys missing from the model, or with a different shape (e.g. a
    classification head trained for another ``num_category``), are skipped
    and logged. Any failure is logged and training continues without the
    pretrained weights (deliberate best-effort behaviour).
    """
    try:
        log_string(f'Loading pretrained model from {pretrained_path}...')
        checkpoint = torch.load(str(pretrained_path), map_location='cpu')
        # Handle the different checkpoint layouts.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            pretrained_dict = checkpoint['model_state_dict']
        elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            pretrained_dict = checkpoint['state_dict']
        else:
            pretrained_dict = checkpoint
        model_dict = classifier.state_dict()
        # Only keep entries that exist in the current model with the same shape.
        compatible_dict = {}
        for k, v in pretrained_dict.items():
            if k not in model_dict:
                log_string(f' Skipping {k}: not in current model')
            elif model_dict[k].shape != v.shape:
                log_string(f' Skipping {k}: shape mismatch (model: {model_dict[k].shape}, pretrained: {v.shape})')
            else:
                compatible_dict[k] = v
        model_dict.update(compatible_dict)
        classifier.load_state_dict(model_dict, strict=False)
        log_string(f'✓ Loaded {len(compatible_dict)}/{len(pretrained_dict)} layers from pretrained model')
        if len(compatible_dict) < len(pretrained_dict):
            log_string(' Note: Some layers were skipped (e.g., classification head for different num_category)')
    except Exception as e:
        log_string(f'Error loading pretrained model: {e}')
        log_string('Starting training from scratch...')


def main(args):
    """Train a point-cloud classifier and keep the best checkpoint.

    Creates ``./log/classification/<log_dir or timestamp>/``, containing
    ``checkpoints/`` and ``logs/``. The model source files and this script are
    copied there. Training optionally warm-starts from ``--pretrained``, then
    resumes from ``checkpoints/best_model.pth`` if that file exists (a resumed
    checkpoint overrides pretrained weights).

    After that it trains for ``args.epoch`` epochs and evaluates with
    ``test()`` after each one. ``best_model.pth`` is saved whenever the test
    instance accuracy matches or beats the best seen so far.

    Args:
        args: ``argparse.Namespace`` from ``parse_args()``. It is used as given
            and is no longer re-parsed from ``sys.argv`` inside this function.
    """
    def log_string(msg):
        logger.info(msg)
        print(msg)

    # ---- hyper-parameters ----
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # ---- experiment directories ----
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    exp_dir = Path('./log/').joinpath('classification')
    exp_dir = exp_dir.joinpath(timestr if args.log_dir is None else args.log_dir)
    exp_dir.mkdir(parents=True, exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    # ---- logging ----
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    # ---- data ----
    log_string('Load dataset ...')
    if args.use_fish_dataset:
        if args.data_path is None:
            log_string('Error: --data_path must be specified when using --use_fish_dataset')
            return
        data_path = args.data_path
        log_string(f'Using Fish Point Cloud Dataset from: {data_path}')
        train_dataset = FishPointCloudDataLoader(root=data_path, args=args, split='train', process_data=args.process_data)
        # The fish dataset's 'val' split doubles as the test set.
        test_dataset = FishPointCloudDataLoader(root=data_path, args=args, split='val', process_data=args.process_data)
    else:
        data_path = 'data/modelnet40_normal_resampled/'
        log_string(f'Using ModelNet Dataset from: {data_path}')
        train_dataset = ModelNetDataLoader(root=data_path, args=args, split='train', process_data=args.process_data)
        test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test', process_data=args.process_data)
    trainDataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True)
    testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # ---- model ----
    num_class = args.num_category
    model = importlib.import_module(args.model)
    # Snapshot the model sources and this script for reproducibility. Paths are
    # resolved next to this file (where ./models is put on sys.path), not
    # relative to the current working directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    models_dir = os.path.join(script_dir, 'models')
    shutil.copy(os.path.join(models_dir, '%s.py' % args.model), str(exp_dir))
    for helper in ('pointnet2_utils.py', 'pointnet_utils.py'):
        helper_path = os.path.join(models_dir, helper)
        if os.path.exists(helper_path):
            shutil.copy(helper_path, str(exp_dir))
    shutil.copy(os.path.abspath(__file__), str(exp_dir))

    classifier = model.get_model(num_class, normal_channel=args.use_normals)
    criterion = model.get_loss()
    classifier.apply(inplace_relu)
    log_string(f'Model: {args.model}, Number of classes: {num_class}')
    if not args.use_cpu:
        classifier = classifier.cuda()
        criterion = criterion.cuda()

    # ---- optional warm start from a pretrained checkpoint ----
    if args.pretrained:
        pretrained_path = _resolve_pretrained_path(args.pretrained, log_string)
        if pretrained_path is not None:
            _load_pretrained_weights(classifier, pretrained_path, log_string)

    # ---- resume from this experiment's own best checkpoint, if any ----
    start_epoch = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    resume_path = checkpoints_dir / 'best_model.pth'
    if resume_path.exists():
        try:
            checkpoint = torch.load(str(resume_path), map_location='cpu')
            # Read the metadata before touching the weights, so a malformed
            # checkpoint leaves both the model and start_epoch unchanged.
            resume_epoch = int(checkpoint['epoch'])
            resume_instance_acc = float(checkpoint.get('instance_acc', 0.0))
            resume_class_acc = float(checkpoint.get('class_acc', 0.0))
            classifier.load_state_dict(checkpoint['model_state_dict'])
        except Exception as e:
            log_string(f'Warning: could not resume from {resume_path} ({e}); not resuming.')
        else:
            start_epoch = resume_epoch
            # Restore the best score so a worse epoch cannot overwrite the
            # checkpoint we resumed from. best_class_acc is only an
            # approximation: it is the class accuracy at the best-instance epoch.
            best_instance_acc = resume_instance_acc
            best_class_acc = resume_class_acc
            log_string(f'Resumed training from epoch {start_epoch}')
    elif not args.pretrained:
        log_string('No existing model, starting training from scratch...')

    # ---- optimizer / schedule ----
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        # NOTE: SGD uses a hard-coded lr=0.01 and ignores --learning_rate / --decay_rate.
        optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
    global_epoch = 0
    global_step = 0

    # ---- training ----
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        mean_correct = []
        classifier = classifier.train()
        for points, target in tqdm(trainDataLoader, total=len(trainDataLoader), smoothing=0.9):
            optimizer.zero_grad()
            # NumPy-side augmentation via provider.* (dropout / scale / shift by
            # their names); scale and shift are applied to channels 0:3 only.
            points = points.data.numpy()
            points = provider.random_point_dropout(points)
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
            points = torch.Tensor(points)
            points = points.transpose(2, 1)
            if not args.use_cpu:
                points, target = points.cuda(), target.cuda()
            pred, trans_feat = classifier(points)
            loss = criterion(pred, target.long(), trans_feat)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))
            loss.backward()
            optimizer.step()
            global_step += 1
        # Step the LR schedule once per epoch, after the optimizer updates
        # (PyTorch >= 1.1 ordering). Stepping at the start of the epoch
        # skipped the initial learning-rate value.
        scheduler.step()
        train_instance_acc = np.mean(mean_correct)
        log_string('Train Instance Accuracy: %f' % train_instance_acc)

        with torch.no_grad():
            instance_acc, class_acc = test(classifier.eval(), testDataLoader, num_class=num_class)
            improved = instance_acc >= best_instance_acc
            if improved:
                best_instance_acc = instance_acc
                best_epoch = epoch + 1
            if class_acc >= best_class_acc:
                best_class_acc = class_acc
            log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
            log_string('Best Instance Accuracy: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc))
            if improved:
                logger.info('Save model...')
                savepath = str(checkpoints_dir) + '/best_model.pth'
                log_string('Saving at %s' % savepath)
                state = {
                    'epoch': best_epoch,
                    # Plain Python floats avoid pickling NumPy scalar types,
                    # which restricted (weights_only) loading may reject.
                    'instance_acc': float(instance_acc),
                    'class_acc': float(class_acc),
                    'model_state_dict': classifier.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, savepath)
        global_epoch += 1
    logger.info('End of training...')
if __name__ == '__main__':
    # Script entry point: parse the command line once and run training.
    # ``args`` is intentionally bound at module scope, because functions in
    # this file may read ``args`` as a global.
    args = parse_args()
    main(args)