Files
FishServer/FishMeasure/pointcloud_classifier/dataset.py

301 lines
10 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
Prepare labeled point cloud data for training.
This script scans the output_preview4 directory structure, reads the JSON label
files found in each subfolder's cloud/ directory, and copies the matching PLY
files into train/val/test splits (each split has a good/ and a bad/ directory).
Usage:
python dataset.py \
--source /home/ubuntu/projects/FishMeasure/output_preview4/ \
--output dataset/ \
--train_ratio 0.7 \
--val_ratio 0.15 \
--test_ratio 0.15 \
--seed 42
"""
import argparse
import json
import shutil
from pathlib import Path
import random
from collections import defaultdict
from tqdm import tqdm
def _resolve_ply_path(json_file, label_data, cloud_dir):
    """Return the PLY file belonging to a label JSON, or None if none exists.

    The JSON's ``pointcloud_file`` entry (resolved relative to ``cloud_dir``)
    is tried first; if it is absent or does not exist on disk, the PLY that
    shares the JSON file's stem (``foo.json`` -> ``foo.ply``) is tried.
    """
    candidates = []
    ply_name = label_data.get('pointcloud_file', None)
    if ply_name:
        candidates.append(cloud_dir / ply_name)
    # json_file lives directly in cloud_dir, so this equals cloud_dir / f"{stem}.ply".
    candidates.append(json_file.with_suffix('.ply'))
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None


def load_labeled_pointclouds(source_dir):
    """
    Scan source directory and collect all labeled point clouds.

    Every ``<source_dir>/<subfolder>/cloud/*.json`` file whose ``label`` is
    'good' or 'bad' and whose PLY file can be found is collected. Subfolders
    and JSON files are visited in sorted order so the returned lists -- and
    therefore the seeded shuffle in split_dataset -- are reproducible
    regardless of the filesystem's directory-listing order.

    Args:
        source_dir: Root directory containing subfolders with cloud/ directories
    Returns:
        good_files: List of (ply_path, json_data, subfolder_name) tuples for good point clouds
        bad_files: List of (ply_path, json_data, subfolder_name) tuples for bad point clouds
    Raises:
        ValueError: If source_dir does not exist.
    """
    source_path = Path(source_dir)
    if not source_path.exists():
        raise ValueError(f"Source directory does not exist: {source_dir}")
    good_files = []
    bad_files = []
    skipped = []
    # Sorted: iterdir()/glob() order is arbitrary and would make the seeded split non-reproducible.
    subfolders = sorted(d for d in source_path.iterdir() if d.is_dir())
    print(f"Found {len(subfolders)} subfolders")
    for subfolder in tqdm(subfolders, desc="Scanning subfolders"):
        cloud_dir = subfolder / "cloud"
        if not cloud_dir.exists():
            continue
        for json_file in sorted(cloud_dir.glob("*.json")):
            try:
                # JSON is UTF-8 by specification; do not depend on the platform locale.
                with open(json_file, 'r', encoding='utf-8') as fh:
                    label_data = json.load(fh)
                label = label_data.get('label', None)
                if label not in ['good', 'bad']:
                    skipped.append(json_file)
                    continue
                ply_path = _resolve_ply_path(json_file, label_data, cloud_dir)
                if ply_path is None:
                    skipped.append(json_file)
                    continue
                file_info = (ply_path, label_data, subfolder.name)
                if label == 'good':
                    good_files.append(file_info)
                else:
                    bad_files.append(file_info)
            except Exception as e:
                # Best effort: one malformed label file must not abort the whole scan.
                print(f"Warning: Failed to process {json_file}: {e}")
                skipped.append(json_file)
    print("\nFound labeled point clouds:")
    print(f" Good: {len(good_files)}")
    print(f" Bad: {len(bad_files)}")
    print(f" Skipped: {len(skipped)}")
    if skipped:
        print("\nSkipped files (first 10):")
        for skipped_file in skipped[:10]:
            print(f" {skipped_file}")
    return good_files, bad_files
def split_dataset(files, train_ratio, val_ratio, test_ratio, rng=None):
    """
    Split files into train/val/test sets.

    The input list is NOT modified; a shuffled copy is split instead.
    Train and validation sizes are ``int(total * ratio)`` (rounded down) and
    the test set receives every remaining item, so ``test_ratio`` is implied
    by the other two ratios and is not used in the computation.

    Args:
        files: List of file info tuples
        train_ratio: Ratio for training set
        val_ratio: Ratio for validation set
        test_ratio: Ratio for test set (informational; test takes the remainder)
        rng: Optional ``random.Random`` instance used for shuffling. When None,
            the module-level ``random`` state is used (as seeded by ``random.seed``).
    Returns:
        train_files, val_files, test_files
    """
    shuffler = random if rng is None else rng
    # Shuffle a copy so the caller's list keeps its original order.
    shuffled = list(files)
    shuffler.shuffle(shuffled)
    total = len(shuffled)
    train_end = int(total * train_ratio)
    val_end = train_end + int(total * val_ratio)
    return shuffled[:train_end], shuffled[train_end:val_end], shuffled[val_end:]
def _unique_target_path(target_dir, file_name):
    """Return target_dir/file_name, or the first free '<stem>_<N><suffix>' variant."""
    candidate = target_dir / file_name
    # Capture stem/suffix once so repeated conflicts yield x_1, x_2, ... (not x_1_2).
    stem, suffix = candidate.stem, candidate.suffix
    counter = 1
    while candidate.exists():
        candidate = target_dir / f"{stem}_{counter}{suffix}"
        counter += 1
    return candidate


def copy_files_to_structure(files, output_dir, split_name, label_name):
    """
    Copy PLY files to the target directory structure.

    Each file is copied (with metadata, via shutil.copy2) to
    ``output_dir/split_name/label_name/<subfolder>_<original name>``. If that
    name is already taken, a numeric suffix is appended to the stem
    (``x_1.ply``, ``x_2.ply``, ...) and the original extension is kept.

    Args:
        files: List of (source_path, json_data, subfolder_name) tuples
        output_dir: Root output directory (Path or str)
        split_name: 'train', 'val', or 'test'
        label_name: 'good' or 'bad'
    Returns:
        (copied, failed): number of files copied and number whose copy failed
    """
    target_dir = Path(output_dir) / split_name / label_name
    target_dir.mkdir(parents=True, exist_ok=True)
    copied = 0
    failed = 0
    for source_path, _json_data, subfolder_name in tqdm(files, desc=f"Copying {split_name}/{label_name}"):
        try:
            # Prefix with the subfolder name: different subfolders may reuse file names.
            target_path = _unique_target_path(target_dir, f"{subfolder_name}_{source_path.name}")
            shutil.copy2(source_path, target_path)
            copied += 1
        except OSError as e:  # shutil.Error / SameFileError are OSError subclasses
            print(f"Error copying {source_path}: {e}")
            failed += 1
    print(f" Copied: {copied}, Failed: {failed}")
    return copied, failed
def _existing_output_files(output_path):
    """Return files already present in any output_path/<split>/<good|bad> directory."""
    existing = []
    for split in ('train', 'val', 'test'):
        for label in ('good', 'bad'):
            label_dir = output_path / split / label
            if label_dir.is_dir():
                existing.extend(p for p in label_dir.iterdir() if p.is_file())
    return existing


def _print_dataset_summary(output_path):
    """Print the number of *.ply files currently on disk in each split/label directory."""
    print("\n" + "="*60)
    print("Dataset Summary")
    print("="*60)
    for split in ['train', 'val', 'test']:
        good_dir = output_path / split / 'good'
        bad_dir = output_path / split / 'bad'
        good_count = len(list(good_dir.glob('*.ply'))) if good_dir.exists() else 0
        bad_count = len(list(bad_dir.glob('*.ply'))) if bad_dir.exists() else 0
        total = good_count + bad_count
        print(f"\n{split.upper()}:")
        print(f" Good: {good_count}")
        print(f" Bad: {bad_count}")
        print(f" Total: {total}")


def create_dataset_structure(source_dir, output_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=42):
    """
    Main function to create dataset structure.

    Loads the labeled point clouds, splits the good and bad classes
    independently (so each split keeps the class proportions), copies the PLY
    files into ``output_dir/<split>/<good|bad>/`` and prints a summary.

    Args:
        source_dir: Source directory containing labeled point clouds
        output_dir: Output directory for organized dataset
        train_ratio: Ratio for training set (default: 0.7)
        val_ratio: Ratio for validation set (default: 0.15)
        test_ratio: Ratio for test set (default: 0.15)
        seed: Random seed for reproducibility (seeds the global ``random`` module)
    Raises:
        ValueError: If a ratio is outside [0, 1] (or NaN), if the ratios do
            not sum to 1.0, or if source_dir does not exist.
    """
    # Validate each ratio: the sum check alone accepts e.g. 1.5/-0.5/0.0 or NaN.
    for name, value in (('train_ratio', train_ratio), ('val_ratio', val_ratio), ('test_ratio', test_ratio)):
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"{name} must be between 0 and 1, got {value}")
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
        raise ValueError(f"Ratios must sum to 1.0, got {train_ratio + val_ratio + test_ratio}")
    random.seed(seed)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    print("="*60)
    print("Point Cloud Dataset Preparation")
    print("="*60)
    print(f"Source: {source_dir}")
    print(f"Output: {output_dir}")
    print(f"Split ratios: Train={train_ratio:.2f}, Val={val_ratio:.2f}, Test={test_ratio:.2f}")
    print("="*60)
    # Files left over from a previous run would be mixed into the new splits
    # (possibly leaking samples across splits) and counted in the summary.
    stale = _existing_output_files(output_path)
    if stale:
        print(f"Warning: {len(stale)} file(s) already exist in {output_path}; "
              "they will be mixed into the new splits. Use an empty output directory for a clean dataset.")
    print("\nStep 1: Loading labeled point clouds...")
    good_files, bad_files = load_labeled_pointclouds(source_dir)
    if not good_files and not bad_files:
        print("Error: No labeled point clouds found!")
        return
    print("\nStep 2: Splitting datasets...")
    good_train, good_val, good_test = split_dataset(good_files, train_ratio, val_ratio, test_ratio)
    bad_train, bad_val, bad_test = split_dataset(bad_files, train_ratio, val_ratio, test_ratio)
    print("\nGood point clouds:")
    print(f" Train: {len(good_train)}, Val: {len(good_val)}, Test: {len(good_test)}")
    print("\nBad point clouds:")
    print(f" Train: {len(bad_train)}, Val: {len(bad_val)}, Test: {len(bad_test)}")
    print("\nStep 3: Copying files to dataset structure...")
    print("\nProcessing GOOD point clouds:")
    copy_files_to_structure(good_train, output_path, 'train', 'good')
    copy_files_to_structure(good_val, output_path, 'val', 'good')
    copy_files_to_structure(good_test, output_path, 'test', 'good')
    print("\nProcessing BAD point clouds:")
    copy_files_to_structure(bad_train, output_path, 'train', 'bad')
    copy_files_to_structure(bad_val, output_path, 'val', 'bad')
    copy_files_to_structure(bad_test, output_path, 'test', 'bad')
    _print_dataset_summary(output_path)
    print("\n" + "="*60)
    print("Dataset preparation completed!")
    print(f"Dataset saved to: {output_path}")
    print("="*60)
def main(argv=None):
    """Parse command-line arguments and build the dataset.

    Args:
        argv: Optional argument list (defaults to sys.argv[1:] via argparse).
    Raises:
        SystemExit: with status 1 when dataset creation fails (after printing
            the error and traceback), so shell scripts and CI detect the failure;
            argparse raises SystemExit(2) for invalid arguments.
    """
    parser = argparse.ArgumentParser(description='Prepare labeled point cloud dataset for training')
    parser.add_argument('--source', type=str,
                        default='/home/ubuntu/projects/FishMeasure/output_preview4/',
                        help='Source directory containing labeled point clouds')
    parser.add_argument('--output', type=str, default='dataset',
                        help='Output directory for organized dataset')
    parser.add_argument('--train_ratio', type=float, default=0.7,
                        help='Ratio for training set (default: 0.7)')
    parser.add_argument('--val_ratio', type=float, default=0.15,
                        help='Ratio for validation set (default: 0.15)')
    parser.add_argument('--test_ratio', type=float, default=0.15,
                        help='Ratio for test set (default: 0.15)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for reproducibility (default: 42)')
    args = parser.parse_args(argv)
    try:
        create_dataset_structure(
            source_dir=args.source,
            output_dir=args.output,
            train_ratio=args.train_ratio,
            val_ratio=args.val_ratio,
            test_ratio=args.test_ratio,
            seed=args.seed
        )
    except Exception as e:
        print(f"\nError: {e}")
        import traceback
        traceback.print_exc()
        # Report failure through the exit status instead of silently exiting 0.
        raise SystemExit(1) from e


if __name__ == '__main__':
    main()