"""
Author: Benny
Date: Nov 2019
Modified for Fish Point Cloud Quality Classification

Train a point-cloud classifier (e.g. PointNet / PointNet++) either on
ModelNet or, with --use_fish_dataset, on the fish point-cloud dataset.
"""
import os
import sys
import datetime
import logging
import importlib
import shutil
import argparse
from pathlib import Path

import torch
import numpy as np
from tqdm import tqdm

import provider
from data_utils.ModelNetDataLoader import ModelNetDataLoader
from data_utils.FishPointCloudDataLoader import FishPointCloudDataLoader

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'models'))


def parse_args():
    """Parse and return the command-line training parameters."""
    parser = argparse.ArgumentParser('training')
    parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode')
    parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
    parser.add_argument('--batch_size', type=int, default=24, help='batch size in training')
    parser.add_argument('--model', default='pointnet_cls', help='model name [default: pointnet_cls]')
    parser.add_argument('--num_category', default=2, type=int, help='number of classes (2 for good/bad)')
    parser.add_argument('--epoch', default=200, type=int, help='number of epoch in training')
    parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training')
    parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
    parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training')
    parser.add_argument('--log_dir', type=str, default=None, help='experiment root')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='decay rate')
    parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
    parser.add_argument('--process_data', action='store_true', default=False, help='save data offline')
    parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampiling')
    parser.add_argument('--data_path', type=str, default=None, help='path to dataset (for fish point cloud)')
    parser.add_argument('--use_fish_dataset', action='store_true', default=False,
                        help='use fish point cloud dataset instead of ModelNet')
    parser.add_argument('--pretrained', type=str, default=None,
                        help='path to pretrained model checkpoint (e.g., log/classification/pointnet2_cls_ssg/checkpoints/best_model.pth)')
    return parser.parse_args()


def inplace_relu(m):
    """Switch any ReLU-like module to in-place mode (use with ``Module.apply``)."""
    if 'ReLU' in m.__class__.__name__:
        m.inplace = True


def test(model, loader, num_class=2):
    """Evaluate ``model`` on ``loader`` and return (instance_acc, class_acc).

    Batches are moved to the device of the model's parameters, so no global
    CLI state is needed.

    instance_acc is total correct / total samples. class_acc is the mean
    per-class accuracy (recall), averaged only over classes that occur in
    the loader. Labels outside ``[0, num_class)`` count towards instance_acc
    but not class_acc. Both values are NaN if the loader yields no samples.
    """
    classifier = model.eval()
    first_param = next(classifier.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    class_correct = np.zeros(num_class)
    class_total = np.zeros(num_class)
    total_correct = 0
    total_seen = 0

    with torch.no_grad():
        for points, target in tqdm(loader, total=len(loader)):
            points, target = points.to(device), target.to(device)
            points = points.transpose(2, 1)
            pred, _ = classifier(points)
            pred_np = pred.data.max(1)[1].cpu().numpy()
            target_np = target.long().cpu().numpy()
            hits = pred_np == target_np

            total_correct += int(hits.sum())
            total_seen += target_np.shape[0]
            # Accumulate raw counts (not per-batch ratios) so every sample has equal weight.
            for cat in np.unique(target_np):
                cat = int(cat)
                if 0 <= cat < num_class:
                    mask = target_np == cat
                    class_correct[cat] += hits[mask].sum()
                    class_total[cat] += mask.sum()

    if total_seen == 0:
        return float('nan'), float('nan')
    instance_acc = total_correct / total_seen
    seen = class_total > 0
    # Classes absent from the eval set are excluded instead of counted as 0%.
    class_acc = float(np.mean(class_correct[seen] / class_total[seen])) if seen.any() else float('nan')
    return instance_acc, class_acc


def _create_experiment_dirs(log_dir_name):
    """Create ./log/classification/<run>/{checkpoints,logs}; return (exp_dir, checkpoints_dir, log_dir).

    ``<run>`` is ``log_dir_name``, or a timestamp when it is None.
    """
    timestr = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
    exp_dir = Path('./log/').joinpath('classification')
    exp_dir = exp_dir.joinpath(timestr if log_dir_name is None else log_dir_name)
    exp_dir.mkdir(parents=True, exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    return exp_dir, checkpoints_dir, log_dir


def _resolve_pretrained_path(pretrained, log_string):
    """Map the --pretrained argument to an existing checkpoint Path, or None (with a warning)."""
    pretrained_path = Path(pretrained)
    if pretrained_path.is_absolute():
        if pretrained_path.exists():
            return pretrained_path
        log_string(f'Warning: Pretrained model not found at {pretrained_path}')
        return None
    # Relative paths are tried under ./log/classification first, then as given.
    under_log_root = Path('./log/classification') / pretrained_path
    if under_log_root.exists():
        return under_log_root
    if pretrained_path.exists():
        return pretrained_path
    log_string(f'Warning: Pretrained model not found at {pretrained}')
    return None


def _load_pretrained_weights(classifier, pretrained_path, log_string):
    """Copy every shape-compatible tensor from a checkpoint into ``classifier``.

    Accepts a raw state dict or a dict with 'model_state_dict' / 'state_dict'.
    Layers whose name or shape differ (e.g. the classification head when
    num_category changed) are skipped and logged. Any failure is logged and
    training continues from the current (random) weights.
    """
    try:
        log_string(f'Loading pretrained model from {pretrained_path}...')
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        checkpoint = torch.load(str(pretrained_path), map_location='cpu')
        pretrained_dict = checkpoint
        if isinstance(checkpoint, dict):
            for key in ('model_state_dict', 'state_dict'):
                if key in checkpoint:
                    pretrained_dict = checkpoint[key]
                    break

        model_dict = classifier.state_dict()
        compatible_dict = {}
        for k, v in pretrained_dict.items():
            if k not in model_dict:
                log_string(f'  Skipping {k}: not in current model')
            elif model_dict[k].shape != v.shape:
                log_string(f'  Skipping {k}: shape mismatch (model: {model_dict[k].shape}, pretrained: {v.shape})')
            else:
                compatible_dict[k] = v

        model_dict.update(compatible_dict)
        classifier.load_state_dict(model_dict, strict=False)
        log_string(f'✓ Loaded {len(compatible_dict)}/{len(pretrained_dict)} layers from pretrained model')
        if len(compatible_dict) < len(pretrained_dict):
            log_string('  Note: Some layers were skipped (e.g., classification head for different num_category)')
    except Exception as e:  # best-effort: fall back to training from scratch
        log_string(f'Error loading pretrained model: {e}')
        log_string('Starting training from scratch...')


def _train_one_epoch(classifier, criterion, optimizer, loader, use_cpu):
    """Run one augmented training pass over ``loader``.

    Returns (mean per-batch training accuracy, number of optimizer steps).
    The accuracy is NaN if the loader yields no batches.
    """
    batch_accs = []
    for points, target in tqdm(loader, total=len(loader), smoothing=0.9):
        optimizer.zero_grad()

        # Augment on the CPU with numpy: point dropout, then scale/shift of xyz only.
        points = points.data.numpy()
        points = provider.random_point_dropout(points)
        points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
        points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
        points = torch.Tensor(points)
        points = points.transpose(2, 1)
        if not use_cpu:
            points, target = points.cuda(), target.cuda()

        pred, trans_feat = classifier(points)
        loss = criterion(pred, target.long(), trans_feat)
        pred_choice = pred.data.max(1)[1]
        correct = pred_choice.eq(target.long().data).cpu().sum()
        batch_accs.append(correct.item() / float(points.size()[0]))

        loss.backward()
        optimizer.step()
    return np.mean(batch_accs), len(batch_accs)


def main(args):
    """Train the classifier described by ``args`` and keep the best checkpoint.

    Steps:
      1. Create the experiment directory and set up logging.
      2. Build the datasets and snapshot the model source files.
      3. Optionally load --pretrained weights.
      4. Resume from this experiment's own best_model.pth if one exists.
      5. Train, evaluating every epoch and saving the best instance-accuracy
         model to <exp>/checkpoints/best_model.pth.

    ``args`` is used as given; it is no longer re-parsed from sys.argv.
    """
    def log_string(msg):
        logger.info(msg)
        print(msg)

    # HYPER PARAMETER: must be set before CUDA is first initialised.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # CREATE DIR
    exp_dir, checkpoints_dir, log_dir = _create_experiment_dirs(args.log_dir)

    # LOG
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    # DATA LOADING
    log_string('Load dataset ...')
    if args.use_fish_dataset:
        if args.data_path is None:
            log_string('Error: --data_path must be specified when using --use_fish_dataset')
            return
        data_path = args.data_path
        log_string(f'Using Fish Point Cloud Dataset from: {data_path}')
        train_dataset = FishPointCloudDataLoader(root=data_path, args=args, split='train',
                                                 process_data=args.process_data)
        # The validation split doubles as the test set.
        test_dataset = FishPointCloudDataLoader(root=data_path, args=args, split='val',
                                                process_data=args.process_data)
    else:
        data_path = 'data/modelnet40_normal_resampled/'
        log_string(f'Using ModelNet Dataset from: {data_path}')
        train_dataset = ModelNetDataLoader(root=data_path, args=args, split='train',
                                           process_data=args.process_data)
        test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test',
                                          process_data=args.process_data)
    trainDataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                                  shuffle=True, num_workers=4, drop_last=True)
    testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
                                                 shuffle=False, num_workers=4)

    # MODEL LOADING
    num_class = args.num_category
    model = importlib.import_module(args.model)

    # Snapshot the code that produced this run. Paths are resolved from the
    # script's directory so this works from any working directory.
    models_dir = os.path.join(ROOT_DIR, 'models')
    shutil.copy(os.path.join(models_dir, '%s.py' % args.model), str(exp_dir))
    for helper in ('pointnet2_utils.py', 'pointnet_utils.py'):
        helper_path = os.path.join(models_dir, helper)
        if os.path.exists(helper_path):
            shutil.copy(helper_path, str(exp_dir))
    shutil.copy(os.path.abspath(__file__), str(exp_dir))

    classifier = model.get_model(num_class, normal_channel=args.use_normals)
    criterion = model.get_loss()
    classifier.apply(inplace_relu)
    log_string(f'Model: {args.model}, Number of classes: {num_class}')

    if not args.use_cpu:
        classifier = classifier.cuda()
        criterion = criterion.cuda()

    # Optional warm start from a pretrained checkpoint.
    if args.pretrained:
        pretrained_path = _resolve_pretrained_path(args.pretrained, log_string)
        if pretrained_path is not None:
            _load_pretrained_weights(classifier, pretrained_path, log_string)

    # Resume from this experiment's own best checkpoint if one exists. This
    # runs after --pretrained, so an existing checkpoint takes precedence.
    start_epoch = 0
    best_instance_acc = 0.0
    best_class_acc = 0.0
    resume_path = checkpoints_dir / 'best_model.pth'
    if resume_path.exists():
        try:
            checkpoint = torch.load(str(resume_path), map_location='cpu')
            resumed_epoch = checkpoint['epoch']
            classifier.load_state_dict(checkpoint['model_state_dict'])
        except Exception as e:  # corrupt file or incompatible architecture
            log_string(f'Warning: could not resume from {resume_path}: {e}')
            if not args.pretrained:
                log_string('No existing model, starting training from scratch...')
        else:
            # Only advance the epoch once the weights have actually been restored.
            start_epoch = resumed_epoch
            # Restore the best score so a worse epoch cannot overwrite best_model.pth.
            best_instance_acc = checkpoint.get('instance_acc', 0.0)
            best_class_acc = checkpoint.get('class_acc', 0.0)
            log_string(f'Resumed training from epoch {start_epoch}')
    elif not args.pretrained:
        log_string('No existing model, starting training from scratch...')

    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)

    # NOTE(review): on resume the optimizer state and LR schedule restart from
    # scratch; only the weights, epoch and best score are restored.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)
    global_epoch = 0
    global_step = 0

    # TRAINING
    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        classifier = classifier.train()
        train_instance_acc, steps = _train_one_epoch(
            classifier, criterion, optimizer, trainDataLoader, args.use_cpu)
        global_step += steps
        # The scheduler must step after the epoch's optimizer steps (PyTorch >= 1.1).
        scheduler.step()
        log_string('Train Instance Accuracy: %f' % train_instance_acc)

        instance_acc, class_acc = test(classifier.eval(), testDataLoader, num_class=num_class)

        # NaN (empty eval set) compares False, so it never counts as a new best.
        is_best = instance_acc >= best_instance_acc
        if is_best:
            best_instance_acc = instance_acc
            best_epoch = epoch + 1
        if class_acc >= best_class_acc:
            best_class_acc = class_acc
        log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
        log_string('Best Instance Accuracy: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc))

        if is_best:
            logger.info('Save model...')
            savepath = str(checkpoints_dir) + '/best_model.pth'
            log_string('Saving at %s' % savepath)
            state = {
                'epoch': best_epoch,
                'instance_acc': instance_acc,
                'class_acc': class_acc,
                'model_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(state, savepath)
        global_epoch += 1

    logger.info('End of training...')


if __name__ == '__main__':
    args = parse_args()
    main(args)