Initial commit: FishServer monorepo (FishAction, FishMeasure, fish_api)
Made-with: Cursor
This commit is contained in:
300
FishMeasure/pointcloud_classifier/dataset.py
Executable file
300
FishMeasure/pointcloud_classifier/dataset.py
Executable file
@@ -0,0 +1,300 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Prepare labeled point cloud data for training.

This script scans the output_preview4 directory structure, reads JSON labels,
and organizes PLY files into train/val/test splits for training.

Usage:
    python dataset.py \
        --source /home/ubuntu/projects/FishMeasure/output_preview4/ \
        --output dataset/ \
        --train_ratio 0.7 \
        --val_ratio 0.15 \
        --test_ratio 0.15
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def load_labeled_pointclouds(source_dir):
    """
    Scan source directory and collect all labeled point clouds.

    Every immediate subfolder of ``source_dir`` is checked for a ``cloud/``
    directory. Each ``*.json`` file found there is read as a label file; it is
    used only when its ``label`` field is ``'good'`` or ``'bad'`` and a PLY
    file can be located for it (first via the ``pointcloud_file`` field, then
    by swapping the JSON file's suffix for ``.ply``).

    Subfolders and label files are visited in sorted order so the returned
    lists do not depend on filesystem enumeration order; otherwise a seeded
    shuffle applied to them later would not be reproducible across machines.

    Args:
        source_dir: Root directory containing subfolders with cloud/ directories

    Returns:
        good_files: List of (ply_path, json_data, subfolder_name) tuples for
            good point clouds
        bad_files: List of (ply_path, json_data, subfolder_name) tuples for
            bad point clouds

    Raises:
        ValueError: If ``source_dir`` does not exist.
    """
    source_path = Path(source_dir)
    if not source_path.exists():
        raise ValueError(f"Source directory does not exist: {source_dir}")

    good_files = []
    bad_files = []
    skipped = []

    # Find all subfolders (sorted: iterdir() yields entries in arbitrary order)
    subfolders = sorted(d for d in source_path.iterdir() if d.is_dir())
    print(f"Found {len(subfolders)} subfolders")

    # Scan each subfolder's cloud/ directory
    for subfolder in tqdm(subfolders, desc="Scanning subfolders"):
        cloud_dir = subfolder / "cloud"
        if not cloud_dir.is_dir():
            continue

        for json_file in sorted(cloud_dir.glob("*.json")):
            # Best-effort per file: one unreadable label must not abort the scan.
            try:
                with open(json_file, 'r', encoding='utf-8') as f:
                    label_data = json.load(f)

                # A label file must be a JSON object to carry named fields.
                if not isinstance(label_data, dict):
                    print(f"Warning: Failed to process {json_file}: "
                          "top-level JSON value is not an object")
                    skipped.append(json_file)
                    continue

                label = label_data.get('label')
                if label not in ('good', 'bad'):
                    skipped.append(json_file)
                    continue

                # Prefer the PLY name recorded in the label; otherwise infer
                # it from the JSON filename.
                ply_name = label_data.get('pointcloud_file') or json_file.stem + '.ply'
                ply_path = cloud_dir / ply_name
                if not ply_path.is_file():
                    # Try alternative: same name as JSON
                    ply_path = json_file.with_suffix('.ply')
                    if not ply_path.is_file():
                        skipped.append(json_file)
                        continue

                file_info = (ply_path, label_data, subfolder.name)
                if label == 'good':
                    good_files.append(file_info)
                else:
                    bad_files.append(file_info)

            except Exception as e:
                print(f"Warning: Failed to process {json_file}: {e}")
                skipped.append(json_file)

    print("\nFound labeled point clouds:")
    print(f" Good: {len(good_files)}")
    print(f" Bad: {len(bad_files)}")
    print(f" Skipped: {len(skipped)}")

    if skipped:
        print("\nSkipped files (first 10):")
        for f in skipped[:10]:
            print(f" {f}")

    return good_files, bad_files
|
||||
|
||||
|
||||
def split_dataset(files, train_ratio, val_ratio, test_ratio):
    """
    Split files into train/val/test sets.

    The input list is left untouched: a shuffled copy is partitioned. The
    shuffle uses the module-level ``random`` state, so seed it beforehand for
    a reproducible split.

    Args:
        files: List of file info tuples
        train_ratio: Ratio for training set
        val_ratio: Ratio for validation set
        test_ratio: Ratio for test set. Not used in the computation: the test
            set receives every item left after the train and validation
            slices, so no item is dropped by rounding.

    Returns:
        train_files, val_files, test_files
    """
    # Shuffle a copy so the caller's list keeps its original order.
    shuffled = list(files)
    random.shuffle(shuffled)

    total = len(shuffled)
    # int() truncates, so rounding losses fall into the test remainder.
    train_end = int(total * train_ratio)
    val_end = train_end + int(total * val_ratio)

    train_files = shuffled[:train_end]
    val_files = shuffled[train_end:val_end]
    test_files = shuffled[val_end:]

    return train_files, val_files, test_files
|
||||
|
||||
|
||||
def copy_files_to_structure(files, output_dir, split_name, label_name):
    """
    Copy PLY files to the target directory structure.

    Files are written to ``output_dir/split_name/label_name``; the directory
    is created if needed. Each copy is named ``<subfolder>_<original name>``.
    If that name is already taken, ``_1``, ``_2``, ... is appended to the
    base name until a free one is found.

    Args:
        files: List of (source_path, json_data, subfolder_name) tuples
        output_dir: Root output directory (str or Path)
        split_name: 'train', 'val', or 'test'
        label_name: 'good' or 'bad'

    Returns:
        (copied, failed): Number of files copied and number that failed.
    """
    target_dir = Path(output_dir) / split_name / label_name
    target_dir.mkdir(parents=True, exist_ok=True)

    copied = 0
    failed = 0

    progress = tqdm(files, desc=f"Copying {split_name}/{label_name}")
    for source_path, _json_data, subfolder_name in progress:
        try:
            source_path = Path(source_path)

            # Prefix with the subfolder name so equally named clouds from
            # different subfolders do not overwrite each other.
            unique_name = f"{subfolder_name}_{source_path.name}"
            target_path = target_dir / unique_name

            # Take the stem once: re-reading it inside the loop would stack
            # suffixes (name_1_2_3.ply) instead of counting (name_3.ply).
            base_stem = target_path.stem
            counter = 1
            while target_path.exists():
                target_path = target_dir / f"{base_stem}_{counter}.ply"
                counter += 1

            shutil.copy2(source_path, target_path)
            copied += 1

        except OSError as e:
            # Best-effort: report the failure and continue with the rest.
            print(f"Error copying {source_path}: {e}")
            failed += 1

    print(f" Copied: {copied}, Failed: {failed}")
    return copied, failed
|
||||
|
||||
|
||||
def create_dataset_structure(source_dir, output_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=42):
    """
    Main function to create dataset structure.

    Loads the labeled point clouds, splits the good and bad sets separately
    into train/val/test, copies the PLY files into
    ``output_dir/<split>/<label>/`` and prints a summary.

    Args:
        source_dir: Source directory containing labeled point clouds
        output_dir: Output directory for organized dataset
        train_ratio: Ratio for training set (default: 0.7)
        val_ratio: Ratio for validation set (default: 0.15)
        test_ratio: Ratio for test set (default: 0.15)
        seed: Random seed for reproducibility

    Raises:
        ValueError: If any ratio is outside [0, 1], if the ratios do not sum
            to 1.0, or if the source directory does not exist.
    """
    # Validate ratios before touching the filesystem. Checking only the sum
    # would accept combinations such as 1.2 / -0.1 / -0.1, which yield
    # nonsensical slice boundaries.
    named_ratios = (
        ('train_ratio', train_ratio),
        ('val_ratio', val_ratio),
        ('test_ratio', test_ratio),
    )
    for ratio_name, ratio_value in named_ratios:
        if not 0.0 <= ratio_value <= 1.0:
            raise ValueError(f"{ratio_name} must be between 0 and 1, got {ratio_value}")

    ratio_sum = train_ratio + val_ratio + test_ratio
    if abs(ratio_sum - 1.0) > 1e-6:
        raise ValueError(f"Ratios must sum to 1.0, got {ratio_sum}")

    # Set random seed
    random.seed(seed)

    # Create output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    print("="*60)
    print("Point Cloud Dataset Preparation")
    print("="*60)
    print(f"Source: {source_dir}")
    print(f"Output: {output_dir}")
    print(f"Split ratios: Train={train_ratio:.2f}, Val={val_ratio:.2f}, Test={test_ratio:.2f}")
    print("="*60)

    # Load labeled files
    print("\nStep 1: Loading labeled point clouds...")
    good_files, bad_files = load_labeled_pointclouds(source_dir)

    if not good_files and not bad_files:
        print("Error: No labeled point clouds found!")
        return

    # Split each label separately so every split keeps the good/bad balance.
    print("\nStep 2: Splitting datasets...")
    good_train, good_val, good_test = split_dataset(good_files, train_ratio, val_ratio, test_ratio)
    bad_train, bad_val, bad_test = split_dataset(bad_files, train_ratio, val_ratio, test_ratio)

    print("\nGood point clouds:")
    print(f" Train: {len(good_train)}, Val: {len(good_val)}, Test: {len(good_test)}")
    print("\nBad point clouds:")
    print(f" Train: {len(bad_train)}, Val: {len(bad_val)}, Test: {len(bad_test)}")

    # Copy files
    print("\nStep 3: Copying files to dataset structure...")

    # Good files
    print("\nProcessing GOOD point clouds:")
    copy_files_to_structure(good_train, output_path, 'train', 'good')
    copy_files_to_structure(good_val, output_path, 'val', 'good')
    copy_files_to_structure(good_test, output_path, 'test', 'good')

    # Bad files
    print("\nProcessing BAD point clouds:")
    copy_files_to_structure(bad_train, output_path, 'train', 'bad')
    copy_files_to_structure(bad_val, output_path, 'val', 'bad')
    copy_files_to_structure(bad_test, output_path, 'test', 'bad')

    # Create summary
    print("\n" + "="*60)
    print("Dataset Summary")
    print("="*60)

    # Counts are read back from disk, so they also include any .ply files
    # that were already present in the output directory before this run.
    for split in ['train', 'val', 'test']:
        good_dir = output_path / split / 'good'
        bad_dir = output_path / split / 'bad'
        good_count = len(list(good_dir.glob('*.ply'))) if good_dir.exists() else 0
        bad_count = len(list(bad_dir.glob('*.ply'))) if bad_dir.exists() else 0
        total = good_count + bad_count

        print(f"\n{split.upper()}:")
        print(f" Good: {good_count}")
        print(f" Bad: {bad_count}")
        print(f" Total: {total}")

    print("\n" + "="*60)
    print("Dataset preparation completed!")
    print(f"Dataset saved to: {output_path}")
    print("="*60)
|
||||
|
||||
|
||||
def main():
    """
    Command-line entry point.

    Parses the CLI options and runs create_dataset_structure() with them.
    Any exception is reported with its traceback and the process then exits
    with status 1, so shells and pipelines can detect the failure.
    """
    # Function-scope imports: only this entry point needs them.
    import sys
    import traceback

    parser = argparse.ArgumentParser(description='Prepare labeled point cloud dataset for training')
    parser.add_argument('--source', type=str,
                        default='/home/ubuntu/projects/FishMeasure/output_preview4/',
                        help='Source directory containing labeled point clouds')
    parser.add_argument('--output', type=str, default='dataset',
                        help='Output directory for organized dataset')
    parser.add_argument('--train_ratio', type=float, default=0.7,
                        help='Ratio for training set (default: 0.7)')
    parser.add_argument('--val_ratio', type=float, default=0.15,
                        help='Ratio for validation set (default: 0.15)')
    parser.add_argument('--test_ratio', type=float, default=0.15,
                        help='Ratio for test set (default: 0.15)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for reproducibility (default: 42)')

    args = parser.parse_args()

    try:
        create_dataset_structure(
            source_dir=args.source,
            output_dir=args.output,
            train_ratio=args.train_ratio,
            val_ratio=args.val_ratio,
            test_ratio=args.test_ratio,
            seed=args.seed,
        )
    except Exception as e:
        print(f"\nError: {e}")
        traceback.print_exc()
        # Without a non-zero status the caller would see a failed run as success.
        sys.exit(1)
|
||||
|
||||
|
||||
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|
||||
|
||||
Reference in New Issue
Block a user