#!/usr/bin/env python3
"""
Batch process multiple SVO2 files.

Loops over all .svo2 files in a folder and processes each one.
Loads YOLO and SAM once and reuses them for all files.

Exit status: 0 when every file was processed or skipped; 1 when the input
folder is missing, contains no .svo2 files, or any file failed.
"""

import argparse
import sys
import traceback
from pathlib import Path

import torch
from ultralytics import YOLO

from seg import init_models
from fish_video_weight_evaluation import process_single_svo2, load_pointcloud_classifier


# Checkpoint used when --use-pointcloud-classifier is given without an explicit path.
_DEFAULT_CLASSIFIER_PATH = (
    Path(__file__).parent
    / "pointcloud_classifier"
    / "Pointnet_Pointnet2_pytorch"
    / "log"
    / "classification"
    / "fish_pointnet2_finetune"
    / "checkpoints"
    / "best_model.pth"
)

# Default YOLO weights for the CLI (machine-specific absolute path).
DEFAULT_YOLO_MODEL = (
    "/home/ubuntu/projects/FishMeasure/runs/train/"
    "fish_detection_20251127_104658/weights/best.pt"
)


def _find_svo_files(svo_folder):
    """Return the sorted list of .svo2 files in *svo_folder*.

    Prints an error and returns None when the folder does not exist or
    contains no .svo2 files.
    """
    svo_folder = Path(svo_folder).expanduser().resolve()
    if not svo_folder.is_dir():
        print(f"ERROR: Folder not found: {svo_folder}")
        return None

    svo_files = sorted(svo_folder.glob("*.svo2"))
    if not svo_files:
        print(f"ERROR: No .svo2 files found in {svo_folder}")
        return None
    return svo_files


def _load_models(yolo_model_path, sam_device):
    """Load YOLO and SAM once.

    Returns a tuple ``(yolo_model, sam_predictor, sam_torch_device)``.
    """
    print(f"\nLoading YOLO model: {yolo_model_path}")
    yolo_model = YOLO(yolo_model_path)
    class_names = yolo_model.names if hasattr(yolo_model, 'names') else {}
    print(f"✓ YOLO loaded. Classes: {class_names}")

    print(f"\nLoading SAM (device: {sam_device})...")
    sam_predictor = init_models(device=sam_device, seg_model="sam")
    sam_device_obj = torch.device(sam_device)
    print("✓ SAM loaded")
    return yolo_model, sam_predictor, sam_device_obj


def _setup_pointcloud_classifier(use_pointcloud_classifier, pointcloud_classifier_path,
                                 pointcloud_classifier_threshold, device):
    """Resolve and load the optional point cloud quality classifier.

    Returns ``(classifier, enabled)``. ``enabled`` is False whenever no
    classifier was loaded, so processing is never asked to filter with a
    missing model.
    """
    if not use_pointcloud_classifier:
        return None, False

    # Treat an empty path like a missing one. Otherwise the flag would stay
    # enabled while no classifier is ever loaded.
    if not pointcloud_classifier_path:
        default_path = _DEFAULT_CLASSIFIER_PATH
        if not default_path.exists():
            print("Warning: --use-pointcloud-classifier specified but --pointcloud-classifier not provided and default not found")
            print(f" Default path checked: {default_path}")
            return None, False
        pointcloud_classifier_path = str(default_path)
        print(f"Using default classifier path: {pointcloud_classifier_path}")

    classifier = load_pointcloud_classifier(
        pointcloud_classifier_path,
        num_classes=2,
        use_normals=False,
        device=device,
    )
    if classifier is None:
        print("Warning: Failed to load point cloud classifier. Continuing without quality filtering.")
        return None, False

    print(f" Point cloud classifier confidence threshold: {pointcloud_classifier_threshold}")
    return classifier, True


def _warn_on_dependent_filter_flags(filter_pointcloud, use_clustering_filter, use_density_filter):
    """Warn when sub-filters are requested without --filter-pointcloud.

    The CLI help documents these sub-filters as requiring --filter-pointcloud.
    """
    if filter_pointcloud:
        return
    for enabled, flag in ((use_clustering_filter, "--use-clustering-filter"),
                          (use_density_filter, "--use-density-filter")):
        if enabled:
            print(f"Warning: {flag} requires --filter-pointcloud and may be ignored.")


def _print_summary(total, counts, output_base):
    """Print the end-of-batch counters."""
    print("\n" + "=" * 60)
    print("Batch Processing Summary:")
    print("=" * 60)
    print(f"Total files: {total}")
    print(f"Successfully processed: {counts['success']}")
    print(f"Skipped (already exists): {counts['skipped']}")
    print(f"Failed: {counts['failed']}")
    print(f"Output folder: {output_base}")
    print("=" * 60)


def batch_process_svo2_folder(svo_folder, output_base, yolo_model_path, conf=0.25, imgsz=640,
                              sam_device="cuda", max_frames=0, save_images=False,
                              filter_pointcloud=False, use_clustering_filter=False,
                              use_density_filter=False, pointcloud_classifier_path=None,
                              use_pointcloud_classifier=False,
                              pointcloud_classifier_threshold=0.7,
                              use_flatness_filter=False, flatness_threshold=50.0):
    """Process all SVO2 files in a folder. Loads YOLO and SAM once and reuses them.

    Args:
        svo_folder: Folder scanned (non-recursively) for ``*.svo2`` files.
        output_base: Base output folder. It is created if missing. Each input
            maps to ``output_base / <file stem>``.
        yolo_model_path: Path to the YOLO weights.
        conf: YOLO confidence threshold. Note that the CLI default is 0.7.
        imgsz, max_frames, save_images, filter_pointcloud, use_clustering_filter,
        use_density_filter, use_flatness_filter, flatness_threshold,
        pointcloud_classifier_threshold: Forwarded unchanged to
            ``process_single_svo2``.
        sam_device: Device string for SAM. Also used for the point cloud classifier.
        pointcloud_classifier_path: Classifier checkpoint. If omitted, a
            default checkpoint next to this script is tried.
        use_pointcloud_classifier: Enable classifier-based filtering. This is
            turned off automatically if no classifier can be loaded.

    Returns:
        None if the folder is missing or contains no .svo2 files. Otherwise a
        dict with keys ``total``, ``success``, ``skipped`` and ``failed``.
    """
    svo_files = _find_svo_files(svo_folder)
    if svo_files is None:
        return None

    total = len(svo_files)
    print(f"Found {total} SVO2 file(s) to process")
    print("=" * 60)

    _warn_on_dependent_filter_flags(filter_pointcloud, use_clustering_filter, use_density_filter)

    # Load YOLO and SAM once for all files.
    yolo_model, sam_predictor, sam_device_obj = _load_models(yolo_model_path, sam_device)
    pointcloud_classifier, use_pointcloud_classifier = _setup_pointcloud_classifier(
        use_pointcloud_classifier,
        pointcloud_classifier_path,
        pointcloud_classifier_threshold,
        sam_device,
    )

    print("\n" + "=" * 60)
    print("Models loaded. Processing SVO2 files...")
    print("=" * 60)

    output_base = Path(output_base).expanduser().resolve()
    output_base.mkdir(parents=True, exist_ok=True)

    counts = {"total": total, "success": 0, "skipped": 0, "failed": 0}
    for idx, svo_file in enumerate(svo_files, start=1):
        print(f"\n{'=' * 60}")
        print(f"[{idx}/{total}] Processing: {svo_file.name}")
        print(f"{'=' * 60}")

        # An existing output folder is treated as already processed.
        # NOTE(review): if process_single_svo2 creates this folder before
        # finishing, a crashed run will also be skipped next time; confirm.
        output_folder = output_base / svo_file.stem
        if output_folder.is_dir():
            print(f" ⏭ Skipping {svo_file.name} - output folder already exists: {output_folder}")
            counts["skipped"] += 1
            continue

        try:
            # Process using pre-loaded models.
            success = process_single_svo2(
                svo_path=svo_file,
                output_base=output_base,
                yolo_model=yolo_model,
                sam_predictor=sam_predictor,
                sam_device=sam_device_obj,
                conf=conf,
                imgsz=imgsz,
                max_frames=max_frames,
                save_images=save_images,
                filter_pointcloud=filter_pointcloud,
                use_clustering_filter=use_clustering_filter,
                use_density_filter=use_density_filter,
                pointcloud_classifier=pointcloud_classifier,
                use_pointcloud_classifier=use_pointcloud_classifier,
                pointcloud_classifier_threshold=pointcloud_classifier_threshold,
                use_flatness_filter=use_flatness_filter,
                flatness_threshold=flatness_threshold,
            )
        except Exception as e:
            # One bad recording must not abort the rest of the batch.
            counts["failed"] += 1
            print(f"\n✗ Error processing {svo_file.name}: {e}")
            traceback.print_exc()
            continue

        if success:
            counts["success"] += 1
            print(f"\n✓ Successfully processed: {svo_file.name}")
        else:
            counts["failed"] += 1
            print(f"\n✗ Failed to process: {svo_file.name}")

    _print_summary(total, counts, output_base)
    return counts


def main():
    """Parse CLI arguments, run the batch, and exit non-zero on any error."""
    parser = argparse.ArgumentParser(description="Batch process multiple SVO2 files")
    parser.add_argument("--svo-folder", required=True,
                        help="Folder containing SVO2 files to process")
    parser.add_argument("--output", default="output_preview",
                        help="Base output folder (default: output_preview)")
    parser.add_argument("--yolo-model", default=DEFAULT_YOLO_MODEL,
                        help="YOLO model path")
    parser.add_argument("--conf", type=float, default=0.7,
                        help="Confidence threshold")
    parser.add_argument("--imgsz", type=int, default=640,
                        help="Image size")
    parser.add_argument("--sam-device", type=str, default="cuda",
                        help="Device for SAM (cuda or cpu)")
    parser.add_argument("--max-frames", type=int, default=0,
                        help="Maximum frames per SVO2 file (0 = all frames)")
    parser.add_argument("--save-images", action="store_true",
                        help="Save individual image files instead of creating videos")
    parser.add_argument("--filter-pointcloud", action="store_true",
                        help="Apply filtering to remove outliers from point clouds (default: no filtering)")
    parser.add_argument("--use-clustering-filter", action="store_true",
                        help="Use clustering filter to keep only the largest cluster (requires --filter-pointcloud)")
    parser.add_argument("--use-density-filter", action="store_true",
                        help="Use density filter: keep only points with at least 200 neighbors within 100mm radius (requires --filter-pointcloud)")
    parser.add_argument("--pointcloud-classifier", type=str, default=None,
                        help="Path to point cloud quality classifier checkpoint")
    parser.add_argument("--use-pointcloud-classifier", action="store_true",
                        help="Use point cloud classifier to filter out bad quality point clouds (requires --pointcloud-classifier)")
    parser.add_argument("--pointcloud-classifier-threshold", type=float, default=0.7,
                        help="Confidence threshold for point cloud classifier (default: 0.7). Only 'good' point clouds with confidence >= threshold will be saved.")
    parser.add_argument("--use-flatness-filter", action="store_true",
                        help="Evaluate point cloud flatness before saving. Skip point clouds that are not flat enough.")
    parser.add_argument("--flatness-threshold", type=float, default=50.0,
                        help="Minimum flatness score (0-100%%) required to save point cloud (default: 50.0%%). Higher values mean stricter flatness requirement.")

    args = parser.parse_args()

    summary = batch_process_svo2_folder(
        svo_folder=args.svo_folder,
        output_base=args.output,
        yolo_model_path=args.yolo_model,
        conf=args.conf,
        imgsz=args.imgsz,
        sam_device=args.sam_device,
        max_frames=args.max_frames,
        save_images=args.save_images,
        filter_pointcloud=args.filter_pointcloud,
        use_clustering_filter=args.use_clustering_filter,
        use_density_filter=args.use_density_filter,
        pointcloud_classifier_path=args.pointcloud_classifier,
        use_pointcloud_classifier=args.use_pointcloud_classifier,
        pointcloud_classifier_threshold=args.pointcloud_classifier_threshold,
        use_flatness_filter=args.use_flatness_filter,
        flatness_threshold=args.flatness_threshold,
    )

    # A non-zero exit status lets shell scripts and schedulers detect a failed batch.
    if summary is None or summary["failed"]:
        sys.exit(1)


if __name__ == "__main__":
    main()