488 lines
21 KiB
Python
488 lines
21 KiB
Python
"""
|
|
Author: Benny
|
|
Date: Nov 2019
|
|
Modified for Fish Point Cloud Quality Classification
|
|
"""
|
|
from data_utils.ModelNetDataLoader import ModelNetDataLoader
|
|
from data_utils.FishPointCloudDataLoader import FishPointCloudDataLoader
|
|
import argparse
|
|
import numpy as np
|
|
import os
|
|
import torch
|
|
import logging
|
|
from tqdm import tqdm
|
|
import sys
|
|
import importlib
|
|
from pathlib import Path
|
|
import open3d as o3d
|
|
|
|
# Resolve paths from this script's own location (not the working directory)
# and make the model definitions in ./models importable by module name.
BASE_DIR = ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(ROOT_DIR, 'models'))
|
|
|
|
|
|
def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic use).

    Returns:
        argparse.Namespace with the parsed options. ``--test_file``,
        ``--test_folder`` and ``--test_files`` are mutually exclusive;
        argparse exits with an error if more than one of them is given.
    """
    parser = argparse.ArgumentParser('Testing Point Cloud Classification')
    parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode')
    parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
    parser.add_argument('--batch_size', type=int, default=24, help='batch size in testing')
    parser.add_argument('--num_category', default=2, type=int, help='number of classes (2 for good/bad)')
    parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
    parser.add_argument('--log_dir', type=str, default='fish_pointnet2_finetune',
                        help='Experiment root (default: fish_pointnet2_finetune)')
    parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
    parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampling')
    parser.add_argument('--num_votes', type=int, default=3, help='Aggregate classification scores with voting')
    parser.add_argument('--data_path', type=str, default=None, help='path to test dataset (for fish point cloud)')
    parser.add_argument('--use_fish_dataset', action='store_true', default=False,
                        help='use fish point cloud dataset instead of ModelNet')

    # Input options (mutually exclusive): classify PLY files directly instead
    # of evaluating on a labelled dataset.
    input_group = parser.add_mutually_exclusive_group()
    input_group.add_argument('--test_file', type=str, default=None,
                             help='path to single PLY file for classification')
    input_group.add_argument('--test_folder', type=str, default=None,
                             help='path to folder containing PLY files for batch classification')
    input_group.add_argument('--test_files', type=str, nargs='+', default=None,
                             help='list of PLY file paths for batch classification')

    parser.add_argument('--output', type=str, default=None,
                        help='output JSON file to save classification results (required for file/folder classification)')

    return parser.parse_args(argv)
|
|
|
|
|
|
def _dataset_file_path(dataset, idx):
    """Best-effort lookup of the source file of sample ``idx`` in ``dataset``.

    Returns ``"unknown"`` when the dataset exposes neither a ``files`` list
    nor a ``datapath`` list long enough to cover ``idx``.
    """
    if hasattr(dataset, 'files') and idx < len(dataset.files):
        return str(dataset.files[idx])
    if hasattr(dataset, 'datapath') and idx < len(dataset.datapath):
        # For FishPointCloudDataLoader: entries unpack as (label, path).
        _, file_path = dataset.datapath[idx]
        return str(file_path)
    return "unknown"


def test(model, loader, num_class=2, vote_num=1, use_cpu=False, dataset=None, save_results=False):
    """Test model on dataset with labels.

    Each batch is run through the model ``vote_num`` times and the scores are
    averaged before taking the arg-max prediction.

    Args:
        model: Classifier returning ``(scores, extra)`` for a ``[B, C, N]``
            input; ``scores`` is assumed to be ``[B, num_class]``.
        loader: Iterable of ``(points, target)`` batches with ``points`` as
            ``[B, N, C]``; must expose ``batch_size`` and ``__len__``.
        num_class: Number of classes.
        vote_num: Number of forward passes averaged per batch.
        use_cpu: Keep tensors on the CPU instead of moving them to CUDA.
        dataset: Dataset behind ``loader``, used only to look up file names
            when ``save_results`` is set. File names are mapped by position,
            which assumes the loader does not shuffle.
        save_results: Collect a per-sample result dict for every sample.

    Returns:
        ``(instance_acc, class_acc, all_results)`` where ``instance_acc`` is
        the fraction of correctly classified samples, ``class_acc`` is the mean
        per-class accuracy over the classes present in the data (both 0.0 when
        the loader is empty), and ``all_results`` is the list of per-sample
        dicts (empty unless ``save_results`` and ``dataset`` are given).
    """
    classifier = model.eval()
    class_correct = np.zeros(num_class)
    class_seen = np.zeros(num_class)
    total_correct = 0
    total_seen = 0
    all_results = []  # Store detailed results

    with torch.no_grad():
        for j, (points, target) in tqdm(enumerate(loader), total=len(loader)):
            if not use_cpu:
                points, target = points.cuda(), target.cuda()
            target = target.long()
            points = points.transpose(2, 1)
            batch_size = points.size(0)

            vote_pool = torch.zeros(batch_size, num_class, device=points.device)
            for _ in range(vote_num):
                pred, _ = classifier(points)
                vote_pool += pred
            pred = vote_pool / vote_num
            pred_choice = pred.argmax(dim=1)
            probs = torch.softmax(pred, dim=1)
            hits = pred_choice.eq(target)

            # Save detailed results if requested
            if save_results and dataset is not None:
                # Move to the CPU once per batch instead of once per element.
                true_labels = target.cpu().tolist()
                pred_labels = pred_choice.cpu().tolist()
                probs_cpu = probs.cpu()
                for i in range(batch_size):
                    true_label = int(true_labels[i])
                    pred_label = int(pred_labels[i])
                    all_results.append({
                        'file': _dataset_file_path(dataset, j * loader.batch_size + i),
                        'true_label': 'good' if true_label == 1 else 'bad',
                        'predicted_label': 'good' if pred_label == 1 else 'bad',
                        'correct': true_label == pred_label,
                        'confidence': float(probs_cpu[i, pred_label].item()),
                        'probabilities': {
                            'bad': float(probs_cpu[i, 0].item()),
                            'good': float(probs_cpu[i, 1].item()),
                        },
                    })

            # Accumulate raw counts so every sample carries equal weight,
            # regardless of how samples are split into batches.
            total_correct += int(hits.sum().item())
            total_seen += batch_size
            for cat in torch.unique(target).tolist():
                if 0 <= cat < num_class:
                    in_cat = target == cat
                    class_seen[cat] += int(in_cat.sum().item())
                    class_correct[cat] += int(hits[in_cat].sum().item())

    instance_acc = total_correct / total_seen if total_seen else 0.0
    # Average only over classes that occur, so an absent class does not count as 0%.
    present = class_seen > 0
    class_acc = float(np.mean(class_correct[present] / class_seen[present])) if present.any() else 0.0
    return instance_acc, class_acc, all_results
|
|
|
|
|
|
def classify_single_pointcloud(model, ply_path, num_point=1024, vote_num=3, use_cpu=False, use_normals=False):
    """Classify a single PLY point cloud as ``good`` (class 1) or ``bad`` (class 0).

    The cloud is read with Open3D and resampled to exactly ``num_point``
    points. It keeps the first ``num_point`` points, or resamples randomly
    with replacement when the cloud is smaller. Its xyz coordinates are
    normalised with ``pc_normalize``. The model is then run ``vote_num`` times
    and the averaged scores are converted to probabilities with a softmax.

    Args:
        model: Classifier returning ``(scores, extra)`` for a ``[1, C, N]`` input.
        ply_path: Path to the PLY file (``str`` or ``Path``).
        num_point: Number of points fed to the model.
        vote_num: Number of forward passes whose scores are averaged.
        use_cpu: Keep tensors on the CPU instead of moving them to CUDA.
        use_normals: Feed xyz plus per-point normals (6 channels) instead of
            xyz only. The PLY file must then contain normals.

    Returns:
        ``(result, None)`` on success. ``result`` has the keys ``file``,
        ``prediction``, ``class_id``, ``confidence`` and ``probabilities``.
        On failure returns ``(None, message)``. This happens if the cloud is
        empty, lacks the required normals, or any step raises. Exceptions are
        deliberately caught so that batch callers can record a per-file error
        and continue.
    """
    from data_utils.FishPointCloudDataLoader import pc_normalize

    num_class = 2  # result layout below assumes index 0 = bad, 1 = good

    try:
        pcd = o3d.io.read_point_cloud(str(ply_path))
        points = np.asarray(pcd.points).astype(np.float32)
        if len(points) == 0:
            return None, "Empty point cloud"

        if use_normals:
            # A model built with normal_channel=True expects xyz + normals;
            # feeding it xyz alone fails, so load them (or fail clearly).
            if not pcd.has_normals():
                return None, "Point cloud has no normals but use_normals is set"
            normals = np.asarray(pcd.normals).astype(np.float32)
            points = np.concatenate([points, normals], axis=1)

        # Fix the point count: truncate, or resample with replacement when short.
        if len(points) >= num_point:
            points = points[:num_point, :]
        else:
            indices = np.random.choice(len(points), num_point, replace=True)
            points = points[indices]

        # Normalize xyz only; normal columns (if any) are left untouched.
        points[:, 0:3] = pc_normalize(points[:, 0:3])
        if not use_normals:
            points = points[:, 0:3]

        tensor = torch.from_numpy(np.ascontiguousarray(points)).unsqueeze(0)  # [1, N, C]
        if not use_cpu:
            tensor = tensor.cuda()
        tensor = tensor.transpose(2, 1)  # [1, C, N]

        # Classify with voting
        model.eval()
        vote_pool = torch.zeros(1, num_class, device=tensor.device)
        with torch.no_grad():
            for _ in range(vote_num):
                pred, _ = model(tensor)
                vote_pool += pred
        pred = vote_pool / vote_num
        probs = torch.softmax(pred, dim=1)

        class_id = int(pred.argmax(dim=1).item())
        class_name = "good" if class_id == 1 else "bad"
        confidence = float(probs[0, class_id].item())

        return {
            'file': str(ply_path),
            'prediction': class_name,
            'class_id': class_id,
            'confidence': confidence,
            'probabilities': {
                'bad': float(probs[0, 0].item()),
                'good': float(probs[0, 1].item()),
            },
        }, None

    except Exception as e:
        return None, str(e)
|
|
|
|
|
|
def classify_folder(model, folder_path, num_point=1024, vote_num=3, use_cpu=False, use_normals=False):
    """Classify every PLY file directly inside ``folder_path`` (non-recursive).

    Files are matched by a case-insensitive ``.ply`` extension, the same rule
    ``classify_files`` uses, and are processed in sorted order. Directories
    are skipped, even if their names end in ``.ply``.

    Returns:
        ``(results, errors)``. ``results`` holds the result dicts from
        ``classify_single_pointcloud``. ``errors`` holds ``{'file', 'error'}``
        dicts. If the folder is missing, is not a directory, or contains no
        PLY files, a single error for the folder itself is returned.
    """
    folder = Path(folder_path)
    if not folder.exists():
        return [], [{'file': str(folder), 'error': 'Folder does not exist'}]
    if not folder.is_dir():
        return [], [{'file': str(folder), 'error': 'Path is not a folder'}]

    ply_files = sorted(
        p for p in folder.iterdir() if p.is_file() and p.suffix.lower() == '.ply'
    )
    if not ply_files:
        return [], [{'file': str(folder), 'error': 'No PLY files found in folder'}]

    results = []
    errors = []
    for ply_file in tqdm(ply_files, desc='Classifying'):
        result, error = classify_single_pointcloud(model, ply_file, num_point, vote_num, use_cpu, use_normals)
        if result is not None:
            results.append(result)
        else:
            errors.append({'file': str(ply_file), 'error': error})

    return results, errors
|
|
|
|
|
|
def classify_files(model, file_paths, num_point=1024, vote_num=3, use_cpu=False, use_normals=False):
    """Classify an explicit list of PLY files.

    Paths that do not exist, or that lack a ``.ply`` extension
    (case-insensitive), are reported as errors without being loaded.

    Returns:
        ``(results, errors)``. ``results`` holds the per-file result dicts
        from ``classify_single_pointcloud``. ``errors`` holds one
        ``{'file', 'error'}`` dict for every path that was skipped or failed.
    """
    results, errors = [], []

    def record_error(path, message):
        errors.append({'file': str(path), 'error': message})

    for raw_path in tqdm(file_paths, desc='Classifying'):
        path = Path(raw_path)

        # Validate the path before attempting to read it.
        if not path.exists():
            problem = 'File does not exist'
        elif path.suffix.lower() != '.ply':
            problem = 'Not a PLY file'
        else:
            problem = None

        if problem is not None:
            record_error(path, problem)
            continue

        outcome, failure = classify_single_pointcloud(
            model, path, num_point, vote_num, use_cpu, use_normals
        )
        if outcome:
            results.append(outcome)
        else:
            record_error(path, failure)

    return results, errors
|
|
|
|
|
|
def _write_json(path, data):
    """Write ``data`` to ``path`` as JSON indented by two spaces."""
    import json

    with open(path, 'w') as f:
        json.dump(data, f, indent=2)


def _mean_confidence(results, label):
    """Return the mean ``confidence`` of results predicted as ``label``, or 0.0 if none."""
    confidences = [r['confidence'] for r in results if r['prediction'] == label]
    # float() so the JSON always carries a plain Python float, not a numpy scalar.
    return float(np.mean(confidences)) if confidences else 0.0


def _save_dataset_results(all_results, instance_acc, class_acc, output_path, log_string):
    """Save labelled-dataset results with per-class statistics to JSON and log a summary."""
    good_total = sum(1 for r in all_results if r['true_label'] == 'good')
    good_correct = sum(1 for r in all_results if r['true_label'] == 'good' and r['correct'])
    bad_total = sum(1 for r in all_results if r['true_label'] == 'bad')
    bad_correct = sum(1 for r in all_results if r['true_label'] == 'bad' and r['correct'])
    num_correct = sum(1 for r in all_results if r['correct'])

    output_data = {
        'summary': {
            'instance_accuracy': float(instance_acc),
            'class_accuracy': float(class_acc),
            'total_samples': len(all_results),
            'correct': num_correct,
            'incorrect': len(all_results) - num_correct,
            'good': {
                'total': good_total,
                'correct': good_correct,
                'accuracy': good_correct / good_total if good_total > 0 else 0.0,
            },
            'bad': {
                'total': bad_total,
                'correct': bad_correct,
                'accuracy': bad_correct / bad_total if bad_total > 0 else 0.0,
            },
        },
        'results': all_results,
    }
    _write_json(output_path, output_data)

    log_string(f"\nDetailed results saved to {output_path}")
    log_string(f" Total samples: {len(all_results)}")
    log_string(f" Good accuracy: {output_data['summary']['good']['accuracy']:.4f} ({good_correct}/{good_total})")
    log_string(f" Bad accuracy: {output_data['summary']['bad']['accuracy']:.4f} ({bad_correct}/{bad_total})")


def _report_single_file(result, error, test_file, output_path, log_string):
    """Log the outcome of single-file classification and save it (or the error) to JSON."""
    if result:
        log_string(f"File: {result['file']}")
        log_string(f"Prediction: {result['prediction']} (confidence: {result['confidence']:.4f})")
        log_string(f"Probabilities: bad={result['probabilities']['bad']:.4f}, good={result['probabilities']['good']:.4f}")

        output_data = {
            'file': result['file'],
            'prediction': result['prediction'],
            'class_id': result['class_id'],
            'confidence': result['confidence'],
            'probabilities': result['probabilities'],
        }
        _write_json(output_path, output_data)
        log_string(f"Result saved to {output_path}")
    else:
        log_string(f"Error: {error}")
        _write_json(output_path, {'file': test_file, 'error': error})


def _report_batch_results(results, errors, output_path, log_string):
    """Log and save the outcome of a multi-file (folder or file-list) run.

    The JSON file is always written, even when every file failed, so callers
    get a machine-readable record of the errors.
    """
    good_count = sum(1 for r in results if r['prediction'] == 'good')
    bad_count = sum(1 for r in results if r['prediction'] == 'bad')

    log_string(f'\nClassification Results:')
    log_string(f' Total files: {len(results) + len(errors)}')
    log_string(f' Successful: {len(results)}')
    log_string(f' Errors: {len(errors)}')

    if results:
        log_string(f' Good: {good_count}, Bad: {bad_count}')

        # Show top predictions
        log_string(f'\nTop 5 most confident predictions:')
        sorted_results = sorted(results, key=lambda x: x['confidence'], reverse=True)
        for i, r in enumerate(sorted_results[:5], 1):
            log_string(f' {i}. {Path(r["file"]).name}: {r["prediction"]} ({r["confidence"]:.4f})')

    output_data = {
        'results': results,
        'errors': errors,
        'summary': {
            'total': len(results) + len(errors),
            'successful': len(results),
            'errors': len(errors),
            'good': good_count,
            'bad': bad_count,
            'good_confidence_avg': _mean_confidence(results, 'good'),
            'bad_confidence_avg': _mean_confidence(results, 'bad'),
        },
    }
    _write_json(output_path, output_data)
    log_string(f"\nResults saved to {output_path}")

    if errors:
        log_string(f'\nErrors:')
        for err in errors[:5]:  # Show first 5 errors
            log_string(f' {err["file"]}: {err["error"]}')


def main():
    """Evaluate a trained classifier on a labelled test set, or classify PLY files.

    Mode selection:
    - ``--test_file``, ``--test_folder`` or ``--test_files`` runs inference on
      raw PLY files. No labels are used, and ``--output`` is required.
    - Otherwise the model is evaluated on the Fish (``--use_fish_dataset``)
      or ModelNet test split.

    The model module name is taken from the first (sorted) entry of
    ``log/classification/<log_dir>/logs``. Weights are loaded from
    ``checkpoints/best_model.pth`` in the same experiment directory.
    """
    # PARSE ARGUMENTS
    args = parse_args()

    # HYPER PARAMETER
    # Only takes effect if CUDA has not been initialised yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # CREATE DIR
    experiment_dir = 'log/classification/' + args.log_dir

    # LOG: write to <experiment_dir>/eval.txt and echo to stdout.
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    os.makedirs(experiment_dir, exist_ok=True)
    file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    def log_string(msg):
        logger.info(msg)
        print(msg)

    log_string('PARAMETER ...')
    log_string(args)

    # DATA LOADING
    testDataLoader = None
    test_dataset = None

    if args.test_file or args.test_folder or args.test_files:
        log_string('Testing PLY file(s) (no labels required)')
        if args.output is None:
            log_string('Error: --output must be specified when using --test_file, --test_folder, or --test_files')
            return
    else:
        log_string('Load dataset ...')
        if args.use_fish_dataset:
            if args.data_path is None:
                log_string('Error: --data_path must be specified when using --use_fish_dataset')
                return
            data_path = args.data_path
            log_string(f'Using Fish Point Cloud Dataset from: {data_path}')
            test_dataset = FishPointCloudDataLoader(root=data_path, args=args, split='test', process_data=False)
        else:
            data_path = 'data/modelnet40_normal_resampled/'
            log_string(f'Using ModelNet Dataset from: {data_path}')
            test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test', process_data=False)

        testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    # MODEL LOADING
    num_class = args.num_category
    logs_dir = experiment_dir + '/logs'
    # Sorted so the choice is deterministic when several entries exist.
    log_entries = sorted(os.listdir(logs_dir)) if os.path.isdir(logs_dir) else []
    if not log_entries:
        log_string(f'Error: No training log found in {logs_dir} (needed to determine the model name)')
        return
    model_name = log_entries[0].split('.')[0]
    model = importlib.import_module(model_name)

    classifier = model.get_model(num_class, normal_channel=args.use_normals)
    if not args.use_cpu:
        classifier = classifier.cuda()

    checkpoint_path = str(experiment_dir) + '/checkpoints/best_model.pth'
    if not os.path.exists(checkpoint_path):
        log_string(f'Error: Checkpoint not found at {checkpoint_path}')
        return

    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    classifier.load_state_dict(checkpoint['model_state_dict'])
    log_string(f'Loaded model from {checkpoint_path}')

    if testDataLoader is not None:
        # Test on dataset with labels
        with torch.no_grad():
            instance_acc, class_acc, all_results = test(
                classifier.eval(), testDataLoader, vote_num=args.num_votes,
                num_class=num_class, use_cpu=args.use_cpu,
                dataset=test_dataset, save_results=args.output is not None,
            )
        log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))

        if args.output and all_results:
            _save_dataset_results(all_results, instance_acc, class_acc, args.output, log_string)

    elif args.test_file:
        log_string(f'Classifying single file: {args.test_file}')
        result, error = classify_single_pointcloud(
            classifier, args.test_file, args.num_point, args.num_votes,
            args.use_cpu, args.use_normals,
        )
        _report_single_file(result, error, args.test_file, args.output, log_string)

    elif args.test_folder:
        log_string(f'Classifying folder: {args.test_folder}')
        results, errors = classify_folder(
            classifier, args.test_folder, args.num_point, args.num_votes,
            args.use_cpu, args.use_normals,
        )
        _report_batch_results(results, errors, args.output, log_string)

    elif args.test_files:
        log_string(f'Classifying {len(args.test_files)} PLY files')
        results, errors = classify_files(
            classifier, args.test_files, args.num_point, args.num_votes,
            args.use_cpu, args.use_normals,
        )
        _report_batch_results(results, errors, args.output, log_string)


if __name__ == '__main__':
    main()
|