301 lines
10 KiB
Python
Executable File
301 lines
10 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
Prepare labeled point cloud data for training.

This script scans the output_preview4 directory structure, reads the JSON
label file of each point cloud, and copies the corresponding PLY files into
train/val/test splits for training.

Usage:
    python dataset.py \
        --source /home/ubuntu/projects/FishMeasure/output_preview4/ \
        --output dataset/ \
        --train_ratio 0.7 \
        --val_ratio 0.15 \
        --test_ratio 0.15
"""
|
|
|
|
import argparse
|
|
import json
|
|
import shutil
|
|
from pathlib import Path
|
|
import random
|
|
from collections import defaultdict
|
|
from tqdm import tqdm
|
|
|
|
|
|
def load_labeled_pointclouds(source_dir):
    """
    Scan source directory and collect all labeled point clouds.

    Every immediate subfolder of ``source_dir`` is checked for a ``cloud/``
    directory containing ``*.json`` label files. A label file is used only
    when its ``label`` field is ``'good'`` or ``'bad'`` and a matching PLY
    file can be located in the same ``cloud/`` directory.

    Subfolders and label files are visited in sorted order so that the
    returned lists do not depend on filesystem enumeration order (which
    would otherwise make a seeded shuffle of the result irreproducible
    across machines).

    Args:
        source_dir: Root directory containing subfolders with cloud/ directories

    Returns:
        good_files: List of (ply_path, json_data, subfolder_name) tuples for
            point clouds labeled 'good'
        bad_files: List of (ply_path, json_data, subfolder_name) tuples for
            point clouds labeled 'bad'

    Raises:
        ValueError: If ``source_dir`` does not exist.
    """
    source_path = Path(source_dir)
    if not source_path.exists():
        raise ValueError(f"Source directory does not exist: {source_dir}")

    good_files = []
    bad_files = []
    skipped = []

    # Sorted: iterdir() yields entries in arbitrary, OS-dependent order.
    subfolders = sorted(d for d in source_path.iterdir() if d.is_dir())
    print(f"Found {len(subfolders)} subfolders")

    for subfolder in tqdm(subfolders, desc="Scanning subfolders"):
        cloud_dir = subfolder / "cloud"
        if not cloud_dir.is_dir():
            continue

        # Sorted for the same reason as the subfolders above.
        for json_file in sorted(cloud_dir.glob("*.json")):
            try:
                # JSON is UTF-8 by specification; do not rely on the locale.
                with open(json_file, 'r', encoding='utf-8') as f:
                    label_data = json.load(f)

                label = label_data.get('label', None)
                if label not in ('good', 'bad'):
                    skipped.append(json_file)
                    continue

                # Prefer the PLY name recorded in the label file; otherwise
                # infer it from the JSON filename.
                ply_name = label_data.get('pointcloud_file', None)
                if not ply_name:
                    ply_name = json_file.stem + '.ply'

                ply_path = cloud_dir / ply_name
                if not ply_path.exists():
                    # Fall back to a PLY that shares the JSON file's name.
                    ply_path = json_file.with_suffix('.ply')
                    if not ply_path.exists():
                        skipped.append(json_file)
                        continue

                file_info = (ply_path, label_data, subfolder.name)
                if label == 'good':
                    good_files.append(file_info)
                else:
                    bad_files.append(file_info)

            except Exception as e:
                # Best effort: one unreadable or malformed label file must
                # not abort the whole scan; record it and move on.
                print(f"Warning: Failed to process {json_file}: {e}")
                skipped.append(json_file)

    print("\nFound labeled point clouds:")
    print(f" Good: {len(good_files)}")
    print(f" Bad: {len(bad_files)}")
    print(f" Skipped: {len(skipped)}")

    if skipped:
        print("\nSkipped files (first 10):")
        for f in skipped[:10]:
            print(f" {f}")

    return good_files, bad_files
|
|
|
|
|
|
def split_dataset(files, train_ratio, val_ratio, test_ratio):
    """
    Split files into train/val/test sets.

    The items are shuffled with the module-level ``random`` generator, so
    seed it beforehand for a reproducible split. The shuffle is performed on
    a copy: the caller's sequence keeps its original order.

    Args:
        files: List of file info tuples
        train_ratio: Ratio for training set
        val_ratio: Ratio for validation set
        test_ratio: Ratio for test set. It is not used to size the split:
            the test set receives every item left after the train and
            validation slices, so items lost to rounding down end up here.

    Returns:
        train_files, val_files, test_files
    """
    # Copy before shuffling so the caller's list is not reordered in place.
    shuffled = list(files)
    random.shuffle(shuffled)

    total = len(shuffled)
    train_count = int(total * train_ratio)
    val_count = int(total * val_ratio)
    val_end = train_count + val_count

    train_files = shuffled[:train_count]
    val_files = shuffled[train_count:val_end]
    test_files = shuffled[val_end:]

    return train_files, val_files, test_files
|
|
|
|
|
|
def copy_files_to_structure(files, output_dir, split_name, label_name):
    """
    Copy PLY files to the target directory structure.

    Files are written to ``output_dir/split_name/label_name``, which is
    created if needed. Each copy is named ``<subfolder>_<original name>``;
    if that name is already taken, ``_1``, ``_2``, ... is appended to the
    stem until a free name is found. A file that fails to copy is reported
    and counted, and the remaining files are still processed.

    Args:
        files: List of (source_path, json_data, subfolder_name) tuples
        output_dir: Root output directory
        split_name: 'train', 'val', or 'test'
        label_name: 'good' or 'bad'

    Returns:
        (copied, failed): Number of files copied and number that failed.
    """
    target_dir = Path(output_dir) / split_name / label_name
    target_dir.mkdir(parents=True, exist_ok=True)

    copied = 0
    failed = 0

    for source_path, _json_data, subfolder_name in tqdm(files, desc=f"Copying {split_name}/{label_name}"):
        try:
            # Prefix with the subfolder name so equally named clouds from
            # different subfolders do not collide.
            unique_name = f"{subfolder_name}_{source_path.name}"
            target_path = target_dir / unique_name

            # Take the stem once: re-reading it from the previous candidate
            # would compound the suffixes (name_1_2.ply, name_1_2_3.ply, ...).
            base_stem = target_path.stem
            counter = 1
            while target_path.exists():
                target_path = target_dir / f"{base_stem}_{counter}.ply"
                counter += 1

            shutil.copy2(source_path, target_path)
            copied += 1

        except Exception as e:
            # Best effort: keep copying the remaining files.
            print(f"Error copying {source_path}: {e}")
            failed += 1

    print(f" Copied: {copied}, Failed: {failed}")
    return copied, failed
|
|
|
|
|
|
def create_dataset_structure(source_dir, output_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=42):
    """
    Main function to create dataset structure.

    Loads the labeled point clouds from ``source_dir``, splits the 'good'
    and 'bad' files independently into train/val/test, copies them to
    ``output_dir/<split>/<label>/`` and prints a summary. The summary counts
    the ``*.ply`` files present in the output directories, so files left
    there by earlier runs are included. If no labeled point cloud is found,
    an error message is printed and the function returns without copying.

    Args:
        source_dir: Source directory containing labeled point clouds
        output_dir: Output directory for organized dataset
        train_ratio: Ratio for training set (default: 0.7)
        val_ratio: Ratio for validation set (default: 0.15)
        test_ratio: Ratio for test set (default: 0.15)
        seed: Random seed for reproducibility

    Raises:
        ValueError: If any ratio lies outside [0, 1] or the ratios do not
            sum to 1.0. Raised before anything is created on disk.
    """
    # Validate ratios. The sum check alone lets through combinations such as
    # 1.2 / -0.1 / -0.1, which would yield meaningless slices.
    named_ratios = (
        ('train_ratio', train_ratio),
        ('val_ratio', val_ratio),
        ('test_ratio', test_ratio),
    )
    for ratio_name, ratio in named_ratios:
        if not 0.0 <= ratio <= 1.0:
            raise ValueError(f"{ratio_name} must be between 0 and 1, got {ratio}")

    ratio_sum = train_ratio + val_ratio + test_ratio
    if abs(ratio_sum - 1.0) > 1e-6:
        raise ValueError(f"Ratios must sum to 1.0, got {ratio_sum}")

    # Seed the module-level generator used by split_dataset.
    random.seed(seed)

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    banner = "=" * 60
    print(banner)
    print("Point Cloud Dataset Preparation")
    print(banner)
    print(f"Source: {source_dir}")
    print(f"Output: {output_dir}")
    print(f"Split ratios: Train={train_ratio:.2f}, Val={val_ratio:.2f}, Test={test_ratio:.2f}")
    print(banner)

    print("\nStep 1: Loading labeled point clouds...")
    good_files, bad_files = load_labeled_pointclouds(source_dir)

    if not good_files and not bad_files:
        print("Error: No labeled point clouds found!")
        return

    # Split each label separately so both classes appear in every split
    # in roughly the requested proportions.
    print("\nStep 2: Splitting datasets...")
    good_train, good_val, good_test = split_dataset(good_files, train_ratio, val_ratio, test_ratio)
    bad_train, bad_val, bad_test = split_dataset(bad_files, train_ratio, val_ratio, test_ratio)

    print("\nGood point clouds:")
    print(f" Train: {len(good_train)}, Val: {len(good_val)}, Test: {len(good_test)}")
    print("\nBad point clouds:")
    print(f" Train: {len(bad_train)}, Val: {len(bad_val)}, Test: {len(bad_test)}")

    print("\nStep 3: Copying files to dataset structure...")

    split_names = ('train', 'val', 'test')
    copy_plan = (
        ("GOOD", 'good', (good_train, good_val, good_test)),
        ("BAD", 'bad', (bad_train, bad_val, bad_test)),
    )
    for heading, label_name, parts in copy_plan:
        print(f"\nProcessing {heading} point clouds:")
        for split_name, split_files in zip(split_names, parts):
            copy_files_to_structure(split_files, output_path, split_name, label_name)

    print("\n" + banner)
    print("Dataset Summary")
    print(banner)

    for split in split_names:
        good_dir = output_path / split / 'good'
        bad_dir = output_path / split / 'bad'
        # Counted from disk rather than from the split lists.
        good_count = len(list(good_dir.glob('*.ply'))) if good_dir.exists() else 0
        bad_count = len(list(bad_dir.glob('*.ply'))) if bad_dir.exists() else 0
        total = good_count + bad_count

        print(f"\n{split.upper()}:")
        print(f" Good: {good_count}")
        print(f" Bad: {bad_count}")
        print(f" Total: {total}")

    print("\n" + banner)
    print("Dataset preparation completed!")
    print(f"Dataset saved to: {output_path}")
    print(banner)
|
|
|
|
|
|
def main(argv=None):
    """
    Command-line entry point: parse arguments and build the dataset.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the arguments from ``sys.argv``.

    Raises:
        SystemExit: With status 1 if dataset creation raises, after the
            error and its traceback have been printed. argparse itself also
            exits on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(description='Prepare labeled point cloud dataset for training')
    parser.add_argument('--source', type=str,
                        default='/home/ubuntu/projects/FishMeasure/output_preview4/',
                        help='Source directory containing labeled point clouds')
    parser.add_argument('--output', type=str, default='dataset',
                        help='Output directory for organized dataset')
    parser.add_argument('--train_ratio', type=float, default=0.7,
                        help='Ratio for training set (default: 0.7)')
    parser.add_argument('--val_ratio', type=float, default=0.15,
                        help='Ratio for validation set (default: 0.15)')
    parser.add_argument('--test_ratio', type=float, default=0.15,
                        help='Ratio for test set (default: 0.15)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for reproducibility (default: 42)')

    args = parser.parse_args(argv)

    try:
        create_dataset_structure(
            source_dir=args.source,
            output_dir=args.output,
            train_ratio=args.train_ratio,
            val_ratio=args.val_ratio,
            test_ratio=args.test_ratio,
            seed=args.seed,
        )
    except Exception as e:
        print(f"\nError: {e}")
        import traceback
        traceback.print_exc()
        # Signal failure to the shell; previously the script exited with
        # status 0 even when dataset creation had failed.
        raise SystemExit(1) from e
|
|
|
|
|
|
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|