320 lines
13 KiB
Python
Executable File
320 lines
13 KiB
Python
Executable File
"""
|
|
Author: Benny
|
|
Date: Nov 2019
|
|
Modified for Fish Point Cloud Quality Classification
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import torch
|
|
import numpy as np
|
|
|
|
import datetime
|
|
import logging
|
|
import provider
|
|
import importlib
|
|
import shutil
|
|
import argparse
|
|
|
|
from pathlib import Path
|
|
from tqdm import tqdm
|
|
from data_utils.ModelNetDataLoader import ModelNetDataLoader
|
|
from data_utils.FishPointCloudDataLoader import FishPointCloudDataLoader
|
|
|
|
# Anchor paths on this script's location rather than the current working
# directory, then expose the sibling 'models' folder on the import path so
# model modules can be imported by bare name.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
_MODELS_DIR = os.path.join(ROOT_DIR, 'models')
sys.path.append(_MODELS_DIR)
|
|
|
|
def parse_args(argv=None):
    """Parse the training command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list lets tests and
            other scripts build an options namespace without touching argv.

    Returns:
        argparse.Namespace holding every training option below.
    """
    parser = argparse.ArgumentParser('training')
    parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode')
    parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
    parser.add_argument('--batch_size', type=int, default=24, help='batch size in training')
    parser.add_argument('--model', default='pointnet_cls', help='model name [default: pointnet_cls]')
    parser.add_argument('--num_category', default=2, type=int, help='number of classes (2 for good/bad)')
    parser.add_argument('--epoch', default=200, type=int, help='number of epochs in training')
    parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training')
    parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
    parser.add_argument('--optimizer', type=str, default='Adam',
                        help="optimizer for training: 'Adam', any other value selects SGD")
    parser.add_argument('--log_dir', type=str, default=None, help='experiment root')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='decay rate')
    parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
    parser.add_argument('--process_data', action='store_true', default=False, help='save data offline')
    parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampling')
    parser.add_argument('--data_path', type=str, default=None, help='path to dataset (for fish point cloud)')
    parser.add_argument('--use_fish_dataset', action='store_true', default=False,
                        help='use fish point cloud dataset instead of ModelNet')
    parser.add_argument('--pretrained', type=str, default=None,
                        help='path to pretrained model checkpoint '
                             '(e.g., log/classification/pointnet2_cls_ssg/checkpoints/best_model.pth)')
    return parser.parse_args(argv)
|
|
|
|
|
|
def inplace_relu(m):
    """Put ReLU-family modules into in-place mode; leave every other module alone.

    Meant to be passed to ``nn.Module.apply``. Any module whose class name
    contains 'ReLU' (so ReLU, LeakyReLU, ...) gets ``inplace = True``.
    """
    if 'ReLU' not in m.__class__.__name__:
        return
    m.inplace = True
|
|
|
|
|
|
def test(model, loader, num_class=2):
    """Evaluate a classifier and return ``(instance_acc, class_acc)``.

    The model is put in eval mode and run on every batch of ``loader`` with
    gradients disabled. Batches are moved to the device of the model's
    parameters (CPU if it has none), so the function no longer depends on the
    script-level ``args`` global and can be called after importing this module.

    Args:
        model: Module called as ``pred, _ = model(points)``, where ``points``
            has been transposed with ``transpose(2, 1)``. Batches are assumed
            to arrive as (batch, num_points, channels) — TODO confirm against
            the data loaders.
        loader: Sized iterable yielding ``(points, target)`` tensor pairs.
        num_class: Number of classes. Labels outside ``[0, num_class)`` still
            count toward instance accuracy but are excluded from per-class stats.

    Returns:
        instance_acc: fraction of all evaluated samples predicted correctly
            (sample-weighted, so a short final batch is not over-weighted).
        class_acc: per-class accuracy averaged over the classes that actually
            occur in ``loader``; classes with no samples are ignored instead of
            being counted as 0%.
        Both values are NaN when the loader yields no samples.
    """
    classifier = model.eval()
    first_param = next(iter(classifier.parameters()), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    class_correct = np.zeros(num_class, dtype=np.float64)
    class_total = np.zeros(num_class, dtype=np.float64)
    total_correct = 0
    total_seen = 0

    with torch.no_grad():
        for points, target in tqdm(loader, total=len(loader)):
            points = points.to(device).transpose(2, 1)
            # Flatten so (B,) and (B, 1) label tensors both compare element-wise.
            target = target.to(device).long().reshape(-1)

            pred, _ = classifier(points)
            hits = pred.max(1)[1].eq(target)

            total_correct += int(hits.sum().item())
            total_seen += target.numel()

            labels = target.cpu().numpy()
            hit_mask = hits.cpu().numpy().astype(np.float64)
            valid = (labels >= 0) & (labels < num_class)
            class_total += np.bincount(labels[valid], minlength=num_class)[:num_class]
            class_correct += np.bincount(labels[valid], weights=hit_mask[valid],
                                         minlength=num_class)[:num_class]

    if total_seen == 0:
        return float('nan'), float('nan')

    instance_acc = total_correct / total_seen
    observed = class_total > 0
    if observed.any():
        class_acc = float(np.mean(class_correct[observed] / class_total[observed]))
    else:
        class_acc = float('nan')
    return instance_acc, class_acc
|
|
|
|
|
|
def _create_experiment_dirs(log_dir_name):
    """Create ./log/classification/<log_dir_name or timestamp>/{checkpoints,logs}.

    Returns:
        (exp_dir, checkpoints_dir, log_dir) as ``Path`` objects.
    """
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    exp_dir = Path('./log/').joinpath('classification')
    exp_dir = exp_dir.joinpath(timestr if log_dir_name is None else log_dir_name)
    exp_dir.mkdir(parents=True, exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)
    return exp_dir, checkpoints_dir, log_dir


def _setup_logger(log_dir, model_name):
    """Return the "Model" logger with an INFO file handler at <log_dir>/<model_name>.txt."""
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, model_name))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    return logger


def _load_datasets(args, log_string):
    """Build ``(train_dataset, test_dataset)``; return None if the configuration is invalid."""
    if args.use_fish_dataset:
        if args.data_path is None:
            log_string('Error: --data_path must be specified when using --use_fish_dataset')
            return None
        data_path = args.data_path
        log_string(f'Using Fish Point Cloud Dataset from: {data_path}')
        train_dataset = FishPointCloudDataLoader(root=data_path, args=args, split='train',
                                                 process_data=args.process_data)
        # The fish dataset's 'val' split doubles as the test set.
        test_dataset = FishPointCloudDataLoader(root=data_path, args=args, split='val',
                                                process_data=args.process_data)
    else:
        data_path = 'data/modelnet40_normal_resampled/'
        log_string(f'Using ModelNet Dataset from: {data_path}')
        train_dataset = ModelNetDataLoader(root=data_path, args=args, split='train',
                                           process_data=args.process_data)
        test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test',
                                          process_data=args.process_data)
    return train_dataset, test_dataset


def _backup_sources(model, exp_dir):
    """Copy the imported model file, its helper modules and this script into exp_dir.

    Paths come from the modules themselves rather than the CWD, so the copy
    works regardless of where the script is launched from or what it is named.
    """
    model_file = os.path.abspath(model.__file__)
    model_dir = os.path.dirname(model_file)
    shutil.copy(model_file, str(exp_dir))
    for helper in ('pointnet2_utils.py', 'pointnet_utils.py'):
        helper_path = os.path.join(model_dir, helper)
        if os.path.exists(helper_path):
            shutil.copy(helper_path, str(exp_dir))
    shutil.copy(os.path.abspath(__file__), str(exp_dir))


def _resolve_pretrained_path(pretrained, log_string):
    """Locate a pretrained checkpoint; return its Path, or None (with a warning).

    Relative paths are tried under ./log/classification first, then relative
    to the current working directory.
    """
    path = Path(pretrained)
    if path.is_absolute():
        if path.exists():
            return path
        log_string(f'Warning: Pretrained model not found at {path}')
        return None
    under_log = Path('./log/classification') / path
    if under_log.exists():
        return under_log
    if path.exists():
        return path
    log_string(f'Warning: Pretrained model not found at {pretrained}')
    return None


def _load_pretrained_weights(classifier, pretrained_path, log_string):
    """Best-effort copy of shape-compatible weights from a checkpoint into classifier.

    Layers missing from the model or with a different shape (e.g. a
    classification head trained for another ``num_category``) are skipped.
    Any failure is logged and training continues with the current weights.
    """
    try:
        log_string(f'Loading pretrained model from {pretrained_path}...')
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(str(pretrained_path), map_location='cpu')

        # Accept {'model_state_dict': ...}, {'state_dict': ...} or a bare state dict.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            pretrained_dict = checkpoint['model_state_dict']
        elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            pretrained_dict = checkpoint['state_dict']
        else:
            pretrained_dict = checkpoint

        model_dict = classifier.state_dict()
        compatible_dict = {}
        for k, v in pretrained_dict.items():
            if k not in model_dict:
                log_string(f' Skipping {k}: not in current model')
            elif model_dict[k].shape != v.shape:
                log_string(f' Skipping {k}: shape mismatch (model: {model_dict[k].shape}, pretrained: {v.shape})')
            else:
                compatible_dict[k] = v

        model_dict.update(compatible_dict)
        classifier.load_state_dict(model_dict, strict=False)

        log_string(f'✓ Loaded {len(compatible_dict)}/{len(pretrained_dict)} layers from pretrained model')
        if len(compatible_dict) < len(pretrained_dict):
            log_string(' Note: Some layers were skipped (e.g., classification head for different num_category)')
    except Exception as e:  # deliberate best effort: report and fall back to current weights
        log_string(f'Error loading pretrained model: {e}')
        log_string('Starting training from scratch...')


def _resume_from_checkpoint(classifier, checkpoint_path, log_string, pretrained_requested):
    """Reload this experiment's best checkpoint, if present, to resume training.

    ``start_epoch`` is only reported after the weights have actually been
    loaded, so a broken checkpoint can no longer make training skip epochs
    with un-restored weights.

    Returns:
        (start_epoch, checkpoint) on success, or (0, None) when there is
        nothing usable to resume from.
    """
    if not checkpoint_path.exists():
        if not pretrained_requested:
            log_string('No existing model, starting training from scratch...')
        return 0, None
    try:
        # NOTE(security): torch.load unpickles the file; only resume from trusted files.
        checkpoint = torch.load(str(checkpoint_path), map_location='cpu')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
    except Exception as e:  # corrupt or incompatible checkpoint: report it instead of crashing
        log_string(f'Could not resume from {checkpoint_path}: {e}')
        if not pretrained_requested:
            log_string('No existing model, starting training from scratch...')
        return 0, None
    log_string(f'Resumed training from epoch {start_epoch}')
    return start_epoch, checkpoint


def _build_optimizer(classifier, args):
    """Return Adam when ``args.optimizer == 'Adam'``, otherwise SGD."""
    if args.optimizer == 'Adam':
        return torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    # NOTE(review): the SGD branch uses a fixed lr of 0.01 and ignores --learning_rate.
    return torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)


def _train_one_epoch(classifier, criterion, optimizer, loader, use_cpu):
    """Run one augmented training pass and return the mean per-batch accuracy (NaN if no batches)."""
    classifier.train()
    mean_correct = []
    for points, target in tqdm(loader, total=len(loader), smoothing=0.9):
        optimizer.zero_grad()

        # Augment on the CPU copy: random point dropout, then scale and shift the xyz channels.
        points = points.data.numpy()
        points = provider.random_point_dropout(points)
        points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])
        points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
        points = torch.Tensor(points)
        points = points.transpose(2, 1)

        if not use_cpu:
            points, target = points.cuda(), target.cuda()

        pred, trans_feat = classifier(points)
        loss = criterion(pred, target.long(), trans_feat)
        pred_choice = pred.data.max(1)[1]

        correct = pred_choice.eq(target.long().data).cpu().sum()
        mean_correct.append(correct.item() / float(points.size()[0]))
        loss.backward()
        optimizer.step()
    return float(np.mean(mean_correct)) if mean_correct else float('nan')


def main(args):
    """Train a point-cloud classifier configured by ``args`` (see ``parse_args``).

    Creates the experiment directory and log file, builds the datasets and
    model, and optionally warm-starts from ``--pretrained`` weights. If the
    experiment already has a best checkpoint, training resumes from it. The
    model is then trained for ``args.epoch`` epochs, and whenever test instance
    accuracy ties or beats the best so far, the weights are saved to
    ``checkpoints/best_model.pth``.

    The ``args`` passed in are used as-is; they are no longer re-parsed from
    ``sys.argv``.
    """
    # Restrict visible GPUs before any CUDA context is created.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    exp_dir, checkpoints_dir, log_dir = _create_experiment_dirs(args.log_dir)
    logger = _setup_logger(log_dir, args.model)

    def log_string(msg):
        logger.info(msg)
        print(msg)

    log_string('PARAMETER ...')
    log_string(args)

    log_string('Load dataset ...')
    datasets = _load_datasets(args, log_string)
    if datasets is None:
        return
    train_dataset, test_dataset = datasets
    trainDataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                                  num_workers=4, drop_last=True)
    testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                                 num_workers=4)

    num_class = args.num_category
    model = importlib.import_module(args.model)
    _backup_sources(model, exp_dir)

    classifier = model.get_model(num_class, normal_channel=args.use_normals)
    criterion = model.get_loss()
    classifier.apply(inplace_relu)
    log_string(f'Model: {args.model}, Number of classes: {num_class}')

    if not args.use_cpu:
        classifier = classifier.cuda()
        criterion = criterion.cuda()

    if args.pretrained:
        pretrained_path = _resolve_pretrained_path(args.pretrained, log_string)
        if pretrained_path is not None:
            _load_pretrained_weights(classifier, pretrained_path, log_string)

    # An existing checkpoint in this experiment directory takes precedence over --pretrained.
    start_epoch, resume_ckpt = _resume_from_checkpoint(
        classifier, checkpoints_dir / 'best_model.pth', log_string,
        pretrained_requested=bool(args.pretrained))

    optimizer = _build_optimizer(classifier, args)
    # NOTE(review): optimizer and scheduler state are not restored on resume.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)

    global_epoch = 0
    best_epoch = start_epoch
    best_instance_acc = 0.0
    best_class_acc = 0.0
    if resume_ckpt is not None:
        # Seed the "best" tracker from the resumed checkpoint so a worse epoch
        # cannot overwrite best_model.pth right after resuming.
        best_instance_acc = float(resume_ckpt.get('instance_acc', 0.0))
        best_class_acc = float(resume_ckpt.get('class_acc', 0.0))

    logger.info('Start training...')
    for epoch in range(start_epoch, args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))

        train_instance_acc = _train_one_epoch(classifier, criterion, optimizer, trainDataLoader, args.use_cpu)
        # Since PyTorch 1.1 the scheduler must step after optimizer.step();
        # stepping at the start of the epoch decayed the LR one epoch early.
        scheduler.step()
        log_string('Train Instance Accuracy: %f' % train_instance_acc)

        with torch.no_grad():
            instance_acc, class_acc = test(classifier.eval(), testDataLoader, num_class=num_class)

        improved = instance_acc >= best_instance_acc
        if improved:
            best_instance_acc = instance_acc
            best_epoch = epoch + 1
        if class_acc >= best_class_acc:
            best_class_acc = class_acc
        log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
        log_string('Best Instance Accuracy: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc))

        if improved:
            logger.info('Save model...')
            savepath = str(checkpoints_dir) + '/best_model.pth'
            log_string('Saving at %s' % savepath)
            state = {
                'epoch': best_epoch,
                'instance_acc': instance_acc,
                'class_acc': class_acc,
                'model_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(state, savepath)
        global_epoch += 1

    logger.info('End of training...')
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point. The parsed flags stay bound to the module-level
    # name `args` because other code in this file may read that global.
    args = parse_args()
    main(args)
|