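# Source listing metadata (from the repository file viewer; not part of the program):
#   File: FishServer/FishMeasure/pointcloud_classifier/Pointnet_Pointnet2_pytorch/test_classification.py
#   Last modified: 2026-04-08 19:32:23 +08:00
#   Size: 488 lines, 21 KiB
#   Type: Python, executable file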

"""
Author: Benny
Date: Nov 2019
Modified for Fish Point Cloud Quality Classification
"""
from data_utils.ModelNetDataLoader import ModelNetDataLoader
from data_utils.FishPointCloudDataLoader import FishPointCloudDataLoader
import argparse
import numpy as np
import os
import torch
import logging
from tqdm import tqdm
import sys
import importlib
from pathlib import Path
import open3d as o3d
# Absolute directory of this script. ROOT_DIR is an alias of BASE_DIR, and the
# sibling "models" directory is appended to sys.path so that model modules can
# be imported by bare name (see importlib.import_module in main()).
ROOT_DIR = BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(BASE_DIR, 'models'))
def parse_args(argv=None):
    """Build the command-line parser and parse the evaluation options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behavior.

    Returns:
        argparse.Namespace holding the parsed options. ``--test_file``,
        ``--test_folder`` and ``--test_files`` are mutually exclusive;
        argparse exits with an error if more than one is supplied.
    """
    parser = argparse.ArgumentParser('Testing Point Cloud Classification')
    parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode')
    parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
    parser.add_argument('--batch_size', type=int, default=24, help='batch size in testing')
    parser.add_argument('--num_category', default=2, type=int, help='number of classes (2 for good/bad)')
    parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
    parser.add_argument('--log_dir', type=str, default='fish_pointnet2_finetune',
                        help='Experiment root (default: fish_pointnet2_finetune)')
    parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
    parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampling')
    parser.add_argument('--num_votes', type=int, default=3, help='Aggregate classification scores with voting')
    parser.add_argument('--data_path', type=str, default=None, help='path to test dataset (for fish point cloud)')
    parser.add_argument('--use_fish_dataset', action='store_true', default=False,
                        help='use fish point cloud dataset instead of ModelNet')

    # Input options (mutually exclusive)
    input_group = parser.add_mutually_exclusive_group()
    input_group.add_argument('--test_file', type=str, default=None,
                             help='path to single PLY file for classification')
    input_group.add_argument('--test_folder', type=str, default=None,
                             help='path to folder containing PLY files for batch classification')
    input_group.add_argument('--test_files', type=str, nargs='+', default=None,
                             help='list of PLY file paths for batch classification')

    parser.add_argument('--output', type=str, default=None,
                        help='output JSON file to save classification results '
                             '(required for file/folder classification)')
    return parser.parse_args(argv)
def test(model, loader, num_class=2, vote_num=1, use_cpu=False, dataset=None, save_results=False):
    """Evaluate a classifier on a labelled data loader.

    Args:
        model: Network whose forward pass returns a ``(scores, extra)`` pair;
            it is switched to eval mode here.
        loader: Iterable of ``(points, target)`` batches that supports
            ``len()`` and exposes ``batch_size``. Each ``points`` batch has
            its dimensions 1 and 2 swapped before the forward pass.
        num_class: Number of classes scored by the model.
        vote_num: Number of forward passes whose scores are averaged.
        use_cpu: When False, every batch is moved to CUDA first.
        dataset: Optional dataset, used only to look up file paths (through a
            ``files`` or ``datapath`` attribute) for the detailed results.
        save_results: When True and ``dataset`` is given, collect one result
            dict per sample.

    Returns:
        Tuple ``(instance_acc, class_acc, all_results)``:
        ``instance_acc`` is the mean of the per-batch accuracies;
        ``class_acc`` is the mean over classes of the per-batch class
        accuracy, averaged over the batches in which the class occurs;
        ``all_results`` is the list of per-sample dicts (empty unless
        ``save_results`` and ``dataset`` are both set). Labels in those dicts
        are reported as 'good' for class 1 and 'bad' otherwise.
    """
    mean_correct = []
    classifier = model.eval()
    # Columns: [sum of per-batch accuracy, number of batches seen, final accuracy]
    class_acc = np.zeros((num_class, 3))
    all_results = []  # Store detailed results

    # Evaluation never needs gradients; guarding here keeps the function safe
    # even when the caller forgets to wrap it in torch.no_grad().
    with torch.no_grad():
        for j, (points, target) in tqdm(enumerate(loader), total=len(loader)):
            if not use_cpu:
                points, target = points.cuda(), target.cuda()
            points = points.transpose(2, 1)

            # Accumulate scores over the votes on the same device as the input.
            vote_pool = torch.zeros(target.size()[0], num_class, device=points.device)
            for _ in range(vote_num):
                pred, _ = classifier(points)
                vote_pool += pred
            pred = vote_pool / vote_num
            pred_choice = pred.data.max(1)[1]

            # Save detailed results if requested
            if save_results and dataset is not None:
                # One device-to-host transfer per batch instead of several
                # synchronizing .item() calls per sample.
                probs_rows = torch.softmax(pred, dim=1).cpu().tolist()
                true_labels = target.reshape(-1).cpu().tolist()
                pred_labels = pred_choice.cpu().tolist()

                for i in range(points.size(0)):
                    true_label = int(true_labels[i])
                    pred_label = int(pred_labels[i])
                    row = probs_rows[i]

                    # Try to get file path if available. The index assumes the
                    # loader walks the dataset in order (no shuffling).
                    file_path = "unknown"
                    idx = j * loader.batch_size + i
                    if hasattr(dataset, 'files') and idx < len(dataset.files):
                        file_path = str(dataset.files[idx])
                    elif hasattr(dataset, 'datapath') and idx < len(dataset.datapath):
                        # For FishPointCloudDataLoader: entries are (label, path) pairs
                        _label, file_path = dataset.datapath[idx]
                        file_path = str(file_path)

                    all_results.append({
                        'file': file_path,
                        'true_label': 'good' if true_label == 1 else 'bad',
                        'predicted_label': 'good' if pred_label == 1 else 'bad',
                        'correct': true_label == pred_label,
                        'confidence': float(row[pred_label]),
                        'probabilities': {
                            'bad': float(row[0]),
                            'good': float(row[1])
                        }
                    })

            for cat in np.unique(target.cpu().numpy()):
                cat = int(cat)
                if cat < num_class:
                    mask = target == cat
                    hits = pred_choice[mask].eq(target[mask].long().data).cpu().sum()
                    class_acc[cat, 0] += hits.item() / float(points[mask].size()[0])
                    class_acc[cat, 1] += 1

            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))

    class_acc[:, 2] = class_acc[:, 0] / (class_acc[:, 1] + 1e-8)  # Avoid division by zero
    class_acc = np.mean(class_acc[:, 2])
    instance_acc = np.mean(mean_correct)
    return instance_acc, class_acc, all_results
def classify_single_pointcloud(model, ply_path, num_point=1024, vote_num=3, use_cpu=False, use_normals=False):
    """Classify one PLY point cloud as "good" or "bad".

    Args:
        model: Network whose forward pass returns a ``(scores, extra)`` pair
            with two class scores; it is switched to eval mode here.
        ply_path: Path of the point cloud file, read with Open3D.
        num_point: Number of points fed to the model. Clouds with at least
            this many points are truncated to their first ``num_point``
            points; smaller clouds are resampled randomly with replacement.
        vote_num: Number of forward passes whose scores are averaged.
        use_cpu: When False, the input tensor is moved to CUDA.
        use_normals: When True, the per-point normals stored in the file are
            appended to the coordinates as three extra channels.

    Returns:
        Tuple ``(result, error)``. On success ``result`` is a dict with the
        keys ``file``, ``prediction``, ``class_id``, ``confidence`` and
        ``probabilities`` and ``error`` is None. On failure ``result`` is
        None and ``error`` is a message string; no exception escapes once
        loading has started.
    """
    from data_utils.FishPointCloudDataLoader import pc_normalize

    # Deliberately broad handler: this function is the per-file error
    # boundary for batch runs, so any failure becomes an error message
    # instead of aborting the whole batch.
    try:
        # Load point cloud
        pcd = o3d.io.read_point_cloud(str(ply_path))
        points = np.asarray(pcd.points).astype(np.float32)
        if len(points) == 0:
            return None, "Empty point cloud"

        if use_normals:
            # Without this the model would receive xyz only even though the
            # caller asked for normals.
            # NOTE(review): assumes the model expects normals as channels
            # 3..5 after xyz - confirm against the training data loader.
            if not pcd.has_normals():
                return None, "Point cloud has no normals (required when use_normals is set)"
            normals = np.asarray(pcd.normals).astype(np.float32)
            points = np.concatenate([points, normals], axis=1)

        # Sample points
        if len(points) >= num_point:
            points = points[0:num_point, :]
        else:
            indices = np.random.choice(len(points), num_point, replace=True)
            points = points[indices]

        # Normalize the coordinates only; extra channels are left untouched.
        points[:, 0:3] = pc_normalize(points[:, 0:3])
        if not use_normals:
            points = points[:, 0:3]

        # Convert to tensor
        points = torch.FloatTensor(points).unsqueeze(0)  # [1, N, C]
        if not use_cpu:
            points = points.cuda()
        points = points.transpose(2, 1)  # [1, C, N]

        # Classify with voting
        model.eval()
        num_class = 2
        vote_pool = torch.zeros(1, num_class, device=points.device)
        with torch.no_grad():
            for _ in range(vote_num):
                pred, _ = model(points)
                vote_pool += pred
        pred = vote_pool / vote_num
        pred_choice = pred.data.max(1)[1]
        probs = torch.softmax(pred, dim=1)

        class_id = pred_choice.item()
        class_name = "good" if class_id == 1 else "bad"
        confidence = probs[0, class_id].item()
        return {
            'file': str(ply_path),
            'prediction': class_name,
            'class_id': int(class_id),
            'confidence': float(confidence),
            'probabilities': {
                'bad': float(probs[0, 0].item()),
                'good': float(probs[0, 1].item())
            }
        }, None
    except Exception as e:
        return None, str(e)
def classify_folder(model, folder_path, num_point=1024, vote_num=3, use_cpu=False, use_normals=False):
    """Classify every PLY file found directly inside a folder.

    Files are matched by a case-insensitive ``.ply`` extension (the same rule
    classify_files applies) and processed in sorted path order.

    Args:
        model: Classifier passed through to classify_single_pointcloud.
        folder_path: Folder to scan (not recursive).
        num_point, vote_num, use_cpu, use_normals: Forwarded unchanged to
            classify_single_pointcloud.

    Returns:
        Tuple ``(results, errors)``: the result dicts of the files that were
        classified, and ``{'file', 'error'}`` dicts for the failures. A
        missing folder or a folder without PLY files yields no results and a
        single error entry for the folder itself.
    """
    folder = Path(folder_path)
    if not folder.exists():
        return [], [{'file': str(folder), 'error': 'Folder does not exist'}]

    # A path that exists but is not a directory simply contains no PLY files.
    if folder.is_dir():
        ply_files = sorted(p for p in folder.iterdir() if p.suffix.lower() == '.ply')
    else:
        ply_files = []
    if not ply_files:
        return [], [{'file': str(folder), 'error': 'No PLY files found in folder'}]

    results = []
    errors = []
    for ply_file in tqdm(ply_files, desc='Classifying'):
        result, error = classify_single_pointcloud(model, ply_file, num_point, vote_num, use_cpu, use_normals)
        if result:
            results.append(result)
        else:
            errors.append({'file': str(ply_file), 'error': error})
    return results, errors
def classify_files(model, file_paths, num_point=1024, vote_num=3, use_cpu=False, use_normals=False):
    """Classify an explicit list of PLY files.

    Each path is checked for existence and for a case-insensitive ``.ply``
    extension before being handed to classify_single_pointcloud with the
    remaining arguments forwarded unchanged.

    Returns:
        Tuple ``(results, errors)``: result dicts for the files that were
        classified, and ``{'file', 'error'}`` dicts for paths that were
        rejected or failed.
    """
    successes, failures = [], []
    for raw_path in tqdm(file_paths, desc='Classifying'):
        candidate = Path(raw_path)
        shown = str(candidate)

        if not candidate.exists():
            failures.append({'file': shown, 'error': 'File does not exist'})
        elif candidate.suffix.lower() != '.ply':
            failures.append({'file': shown, 'error': 'Not a PLY file'})
        else:
            outcome, message = classify_single_pointcloud(
                model, candidate, num_point, vote_num, use_cpu, use_normals
            )
            if outcome:
                successes.append(outcome)
            else:
                failures.append({'file': shown, 'error': message})
    return successes, failures
def _write_json(path, payload):
    """Write *payload* to *path* as indented JSON."""
    import json
    with open(path, 'w') as f:
        json.dump(payload, f, indent=2)


def _mean_confidence(results, label):
    """Return the mean confidence of the results predicted as *label* (0.0 if none)."""
    values = [r['confidence'] for r in results if r['prediction'] == label]
    return float(np.mean(values)) if values else 0.0


def _report_batch(results, errors, output_path, log):
    """Log the summary of a folder/file-list run and save it to *output_path* as JSON."""
    log('\nClassification Results:')
    log(f' Total files: {len(results) + len(errors)}')
    log(f' Successful: {len(results)}')
    log(f' Errors: {len(errors)}')

    good_count = sum(1 for r in results if r['prediction'] == 'good')
    bad_count = sum(1 for r in results if r['prediction'] == 'bad')
    if results:
        log(f' Good: {good_count}, Bad: {bad_count}')
        # Show top predictions
        log('\nTop 5 most confident predictions:')
        ranked = sorted(results, key=lambda x: x['confidence'], reverse=True)
        for i, r in enumerate(ranked[:5], 1):
            log(f' {i}. {Path(r["file"]).name}: {r["prediction"]} ({r["confidence"]:.4f})')

    # Save results to JSON (written even when every file failed, so the
    # errors are recorded).
    _write_json(output_path, {
        'results': results,
        'errors': errors,
        'summary': {
            'total': len(results) + len(errors),
            'successful': len(results),
            'errors': len(errors),
            'good': good_count,
            'bad': bad_count,
            'good_confidence_avg': _mean_confidence(results, 'good'),
            'bad_confidence_avg': _mean_confidence(results, 'bad')
        }
    })
    log(f"\nResults saved to {output_path}")

    if errors:
        log('\nErrors:')
        for err in errors[:5]:  # Show first 5 errors
            log(f' {err["file"]}: {err["error"]}')


def main():
    """Run the evaluation script.

    Parses the command line, sets up logging to
    ``log/classification/<log_dir>/eval.txt``, loads the model named after
    the first entry of the experiment's ``logs`` directory together with
    ``checkpoints/best_model.pth``, and then either evaluates a labelled
    dataset or classifies a single PLY file, a folder, or a list of files.
    Configuration problems (missing --output / --data_path, missing logs
    directory or checkpoint) are logged and end the run without raising.
    """
    def log_string(message):
        # Mirror every message to the log file and to stdout.
        logger.info(message)
        print(message)

    '''PARSE ARGUMENTS'''
    args = parse_args()

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    '''CREATE DIR'''
    experiment_dir = 'log/classification/' + args.log_dir

    '''LOG'''
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # Create log directory if it doesn't exist
    os.makedirs(experiment_dir, exist_ok=True)
    file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    '''DATA LOADING'''
    testDataLoader = None
    test_dataset = None
    # Check if testing single file, folder, or file list
    if args.test_file or args.test_folder or args.test_files:
        log_string('Testing PLY file(s) (no labels required)')
        # Validate that output is specified for file/folder classification
        if args.output is None:
            log_string('Error: --output must be specified when using --test_file, --test_folder, or --test_files')
            return
    else:
        log_string('Load dataset ...')
        if args.use_fish_dataset:
            if args.data_path is None:
                log_string('Error: --data_path must be specified when using --use_fish_dataset')
                return
            data_path = args.data_path
            log_string(f'Using Fish Point Cloud Dataset from: {data_path}')
            test_dataset = FishPointCloudDataLoader(root=data_path, args=args, split='test', process_data=False)
        else:
            data_path = 'data/modelnet40_normal_resampled/'
            log_string(f'Using ModelNet Dataset from: {data_path}')
            test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test', process_data=False)
        # shuffle=False keeps loader order aligned with dataset order, which
        # test() relies on to recover file paths.
        testDataLoader = torch.utils.data.DataLoader(
            test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4
        )

    '''MODEL LOADING'''
    num_class = args.num_category
    # The model module name is taken from the first entry of <experiment>/logs.
    # Guard against a missing or empty directory, which would otherwise abort
    # with a raw traceback.
    logs_dir = experiment_dir + '/logs'
    log_entries = os.listdir(logs_dir) if os.path.isdir(logs_dir) else []
    if not log_entries:
        log_string(f'Error: Cannot determine model name, no entries found in {logs_dir}')
        return
    model_name = log_entries[0].split('.')[0]
    model = importlib.import_module(model_name)
    classifier = model.get_model(num_class, normal_channel=args.use_normals)
    if not args.use_cpu:
        classifier = classifier.cuda()

    checkpoint_path = str(experiment_dir) + '/checkpoints/best_model.pth'
    if not os.path.exists(checkpoint_path):
        log_string(f'Error: Checkpoint not found at {checkpoint_path}')
        return
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    classifier.load_state_dict(checkpoint['model_state_dict'])
    log_string(f'Loaded model from {checkpoint_path}')

    # Test on dataset with labels
    if testDataLoader is not None:
        with torch.no_grad():
            save_results = args.output is not None
            instance_acc, class_acc, all_results = test(
                classifier.eval(), testDataLoader, vote_num=args.num_votes,
                num_class=num_class, use_cpu=args.use_cpu,
                dataset=test_dataset, save_results=save_results
            )
        log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))

        # Save detailed results if output specified
        if args.output and all_results:
            # Calculate per-class statistics
            good_correct = sum(1 for r in all_results if r['true_label'] == 'good' and r['correct'])
            good_total = sum(1 for r in all_results if r['true_label'] == 'good')
            bad_correct = sum(1 for r in all_results if r['true_label'] == 'bad' and r['correct'])
            bad_total = sum(1 for r in all_results if r['true_label'] == 'bad')
            good_acc = good_correct / good_total if good_total > 0 else 0.0
            bad_acc = bad_correct / bad_total if bad_total > 0 else 0.0

            output_data = {
                'summary': {
                    'instance_accuracy': float(instance_acc),
                    'class_accuracy': float(class_acc),
                    'total_samples': len(all_results),
                    'correct': sum(1 for r in all_results if r['correct']),
                    'incorrect': sum(1 for r in all_results if not r['correct']),
                    'good': {
                        'total': good_total,
                        'correct': good_correct,
                        'accuracy': good_acc
                    },
                    'bad': {
                        'total': bad_total,
                        'correct': bad_correct,
                        'accuracy': bad_acc
                    }
                },
                'results': all_results
            }
            _write_json(args.output, output_data)
            log_string(f"\nDetailed results saved to {args.output}")
            log_string(f" Total samples: {len(all_results)}")
            log_string(f" Good accuracy: {good_acc:.4f} ({good_correct}/{good_total})")
            log_string(f" Bad accuracy: {bad_acc:.4f} ({bad_correct}/{bad_total})")

    # Classify single file
    elif args.test_file:
        log_string(f'Classifying single file: {args.test_file}')
        result, error = classify_single_pointcloud(
            classifier, args.test_file, args.num_point, args.num_votes,
            args.use_cpu, args.use_normals
        )
        if result:
            log_string(f"File: {result['file']}")
            log_string(f"Prediction: {result['prediction']} (confidence: {result['confidence']:.4f})")
            log_string(
                f"Probabilities: bad={result['probabilities']['bad']:.4f}, "
                f"good={result['probabilities']['good']:.4f}"
            )
            # Save result to JSON
            _write_json(args.output, {
                'file': result['file'],
                'prediction': result['prediction'],
                'class_id': result['class_id'],
                'confidence': result['confidence'],
                'probabilities': result['probabilities']
            })
            log_string(f"Result saved to {args.output}")
        else:
            log_string(f"Error: {error}")
            # The failure is still written so downstream consumers see it.
            _write_json(args.output, {'file': args.test_file, 'error': error})

    # Classify folder
    elif args.test_folder:
        log_string(f'Classifying folder: {args.test_folder}')
        results, errors = classify_folder(
            classifier, args.test_folder, args.num_point, args.num_votes,
            args.use_cpu, args.use_normals
        )
        _report_batch(results, errors, args.output, log_string)

    # Classify list of files
    elif args.test_files:
        log_string(f'Classifying {len(args.test_files)} PLY files')
        results, errors = classify_files(
            classifier, args.test_files, args.num_point, args.num_votes,
            args.use_cpu, args.use_normals
        )
        _report_batch(results, errors, args.output, log_string)
# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()