203 lines
9.4 KiB
Python
Executable File
203 lines
9.4 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Batch process multiple SVO2 files.
|
|
Loops over all .svo2 files in a folder and processes each one.
|
|
Loads YOLO and SAM once and reuses them for all files.
|
|
"""
|
|
|
|
import argparse
|
|
import torch
|
|
from pathlib import Path
|
|
from ultralytics import YOLO
|
|
from seg import init_models
|
|
from fish_video_weight_evaluation import process_single_svo2, load_pointcloud_classifier
|
|
|
|
|
|
def _default_pointcloud_classifier_path():
    """Return the classifier checkpoint path tried when none is supplied."""
    return (Path(__file__).parent / "pointcloud_classifier" / "Pointnet_Pointnet2_pytorch"
            / "log" / "classification" / "fish_pointnet2_finetune" / "checkpoints"
            / "best_model.pth")


def _load_pointcloud_classifier_if_requested(use_pointcloud_classifier, pointcloud_classifier_path,
                                             pointcloud_classifier_threshold, device):
    """Resolve the classifier checkpoint path and load the classifier.

    Returns:
        The object returned by ``load_pointcloud_classifier``, or None when the
        classifier was not requested, no checkpoint path is available, or loading
        returned None (a warning is printed in the last two cases).
    """
    if not use_pointcloud_classifier:
        return None

    if pointcloud_classifier_path is None:
        default_path = _default_pointcloud_classifier_path()
        if not default_path.exists():
            print("Warning: --use-pointcloud-classifier specified but --pointcloud-classifier not provided and default not found")
            print(f" Default path checked: {default_path}")
            return None
        pointcloud_classifier_path = str(default_path)
        print(f"Using default classifier path: {pointcloud_classifier_path}")
    elif not pointcloud_classifier_path:
        # An empty path cannot be loaded; report it and disable filtering instead of
        # leaving the classifier flag on with no classifier attached.
        print("Warning: Empty point cloud classifier path. Continuing without quality filtering.")
        return None

    classifier = load_pointcloud_classifier(
        pointcloud_classifier_path,
        num_classes=2,
        use_normals=False,
        device=device,
    )
    if classifier is None:
        print("Warning: Failed to load point cloud classifier. Continuing without quality filtering.")
    else:
        print(f" Point cloud classifier confidence threshold: {pointcloud_classifier_threshold}")
    return classifier


def _load_models(yolo_model_path, sam_device, use_pointcloud_classifier,
                 pointcloud_classifier_path, pointcloud_classifier_threshold):
    """Load YOLO, SAM and (optionally) the point cloud classifier once.

    Returns:
        Tuple ``(yolo_model, sam_predictor, sam_device_obj, pointcloud_classifier)``;
        ``pointcloud_classifier`` is None when it is not in use.
    """
    print(f"\nLoading YOLO model: {yolo_model_path}")
    yolo_model = YOLO(yolo_model_path)
    class_names = getattr(yolo_model, "names", {})
    print(f"✓ YOLO loaded. Classes: {class_names}")

    print(f"\nLoading SAM (device: {sam_device})...")
    sam_predictor = init_models(device=sam_device, seg_model="sam")
    sam_device_obj = torch.device(sam_device)
    print("✓ SAM loaded")

    pointcloud_classifier = _load_pointcloud_classifier_if_requested(
        use_pointcloud_classifier,
        pointcloud_classifier_path,
        pointcloud_classifier_threshold,
        sam_device,
    )

    print("\n" + "=" * 60)
    print("Models loaded. Processing SVO2 files...")
    print("=" * 60)
    return yolo_model, sam_predictor, sam_device_obj, pointcloud_classifier


def batch_process_svo2_folder(svo_folder, output_base, yolo_model_path, conf=0.25, imgsz=640,
                              sam_device="cuda", max_frames=0, save_images=False, filter_pointcloud=False,
                              use_clustering_filter=False, use_density_filter=False,
                              pointcloud_classifier_path=None, use_pointcloud_classifier=False,
                              pointcloud_classifier_threshold=0.7,
                              use_flatness_filter=False, flatness_threshold=50.0):
    """Process all SVO2 files in a folder. Loads YOLO and SAM once and reuses them.

    Every ``*.svo2`` file directly inside ``svo_folder`` is processed in sorted order
    by ``process_single_svo2``. A file is skipped when ``output_base / <file stem>``
    already exists as a directory. Models are loaded only when at least one file
    actually needs processing.

    Args:
        svo_folder: Folder containing the ``.svo2`` files (``~`` is expanded).
        output_base: Base output folder; created if missing.
        yolo_model_path: Path passed to ``YOLO(...)``.
        conf, imgsz, max_frames, save_images, filter_pointcloud, use_clustering_filter,
        use_density_filter, use_flatness_filter, flatness_threshold,
        pointcloud_classifier_threshold: Forwarded unchanged to ``process_single_svo2``.
        sam_device: Device string for SAM and the point cloud classifier (e.g. "cuda", "cpu").
        pointcloud_classifier_path: Classifier checkpoint; when None, a default
            checkpoint next to this script is tried.
        use_pointcloud_classifier: Request classifier-based filtering. It is only
            enabled if a classifier is actually loaded.

    Returns:
        None if ``svo_folder`` is not a directory or contains no ``.svo2`` files;
        otherwise a dict with ``"total"``, ``"success"``, ``"skipped"`` and ``"failed"`` counts.

    Raises:
        Whatever model loading raises (e.g. a bad ``yolo_model_path``). Errors from
        ``process_single_svo2`` are caught per file and counted as failures.
    """
    import traceback

    svo_folder = Path(svo_folder).expanduser().resolve()
    if not svo_folder.is_dir():
        print(f"ERROR: Folder not found: {svo_folder}")
        return None

    # Find all SVO2 files
    svo_files = sorted(svo_folder.glob("*.svo2"))
    if not svo_files:
        print(f"ERROR: No .svo2 files found in {svo_folder}")
        return None

    total = len(svo_files)
    print(f"Found {total} SVO2 file(s) to process")
    print("=" * 60)

    # Create the output folder before any expensive work so permission problems fail fast.
    output_base = Path(output_base).expanduser().resolve()
    output_base.mkdir(parents=True, exist_ok=True)

    # Models are loaded lazily on the first file that needs processing, so re-running
    # over a fully processed folder never pays the YOLO/SAM load. Loading is kept out
    # of the per-file try below so that a model failure aborts the batch.
    models = None
    success_count = 0
    failed_count = 0
    skipped_count = 0

    for idx, svo_file in enumerate(svo_files, start=1):
        print(f"\n{'=' * 60}")
        print(f"[{idx}/{total}] Processing: {svo_file.name}")
        print(f"{'=' * 60}")

        # NOTE(review): any existing output folder counts as "done". If
        # process_single_svo2 creates it before finishing, an interrupted run
        # leaves a partial folder that later runs will skip — confirm.
        output_folder = output_base / svo_file.stem
        if output_folder.is_dir():
            print(f" ⏭ Skipping {svo_file.name} - output folder already exists: {output_folder}")
            skipped_count += 1
            continue

        if models is None:
            models = _load_models(yolo_model_path, sam_device, use_pointcloud_classifier,
                                  pointcloud_classifier_path, pointcloud_classifier_threshold)
        yolo_model, sam_predictor, sam_device_obj, pointcloud_classifier = models

        try:
            # Process using pre-loaded models
            success = process_single_svo2(
                svo_path=svo_file,
                output_base=output_base,
                yolo_model=yolo_model,
                sam_predictor=sam_predictor,
                sam_device=sam_device_obj,
                conf=conf,
                imgsz=imgsz,
                max_frames=max_frames,
                save_images=save_images,
                filter_pointcloud=filter_pointcloud,
                use_clustering_filter=use_clustering_filter,
                use_density_filter=use_density_filter,
                pointcloud_classifier=pointcloud_classifier,
                # Only enable classifier filtering when a classifier was actually loaded.
                use_pointcloud_classifier=pointcloud_classifier is not None,
                pointcloud_classifier_threshold=pointcloud_classifier_threshold,
                use_flatness_filter=use_flatness_filter,
                flatness_threshold=flatness_threshold,
            )
        except Exception as e:
            # Batch boundary: report and continue with the remaining files.
            failed_count += 1
            print(f"\n✗ Error processing {svo_file.name}: {e}")
            traceback.print_exc()
            continue

        if success:
            success_count += 1
            print(f"\n✓ Successfully processed: {svo_file.name}")
        else:
            failed_count += 1
            print(f"\n✗ Failed to process: {svo_file.name}")

    print("\n" + "=" * 60)
    print("Batch Processing Summary:")
    print("=" * 60)
    print(f"Total files: {total}")
    print(f"Successfully processed: {success_count}")
    print(f"Skipped (already exists): {skipped_count}")
    print(f"Failed: {failed_count}")
    print(f"Output folder: {output_base}")
    print("=" * 60)

    return {
        "total": total,
        "success": success_count,
        "skipped": skipped_count,
        "failed": failed_count,
    }
|
|
|
|
|
|
def main():
    """Parse command-line options and run ``batch_process_svo2_folder`` with them."""
    parser = argparse.ArgumentParser(description="Batch process multiple SVO2 files")
    parser.add_argument("--svo-folder", required=True,
                        help="Folder containing SVO2 files to process")
    parser.add_argument("--output", default="output_preview",
                        help="Base output folder (default: output_preview)")
    # NOTE(review): machine-specific absolute default; pass --yolo-model explicitly on other hosts.
    parser.add_argument("--yolo-model",
                        default="/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt",
                        help="YOLO model path")
    parser.add_argument("--conf", type=float, default=0.7, help="Confidence threshold")
    parser.add_argument("--imgsz", type=int, default=640, help="Image size")
    parser.add_argument("--sam-device", type=str, default="cuda",
                        help="Device for SAM (cuda or cpu)")
    parser.add_argument("--max-frames", type=int, default=0,
                        help="Maximum frames per SVO2 file (0 = all frames)")
    parser.add_argument("--save-images", action="store_true",
                        help="Save individual image files instead of creating videos")
    parser.add_argument("--filter-pointcloud", action="store_true",
                        help="Apply filtering to remove outliers from point clouds (default: no filtering)")
    parser.add_argument("--use-clustering-filter", action="store_true",
                        help="Use clustering filter to keep only the largest cluster (requires --filter-pointcloud)")
    parser.add_argument("--use-density-filter", action="store_true",
                        help="Use density filter: keep only points with at least 200 neighbors within 100mm radius (requires --filter-pointcloud)")
    parser.add_argument("--pointcloud-classifier", type=str, default=None,
                        help="Path to point cloud quality classifier checkpoint")
    # The batch function falls back to a default checkpoint when no path is given,
    # and continues without the classifier (with a warning) if none can be loaded.
    parser.add_argument("--use-pointcloud-classifier", action="store_true",
                        help="Use point cloud classifier to filter out bad quality point clouds. "
                             "Uses --pointcloud-classifier if given, otherwise falls back to the "
                             "default fish_pointnet2_finetune best_model.pth checkpoint next to this "
                             "script; continues without the classifier (with a warning) if neither "
                             "can be loaded.")
    parser.add_argument("--pointcloud-classifier-threshold", type=float, default=0.7,
                        help="Confidence threshold for point cloud classifier (default: 0.7). Only 'good' point clouds with confidence >= threshold will be saved.")
    parser.add_argument("--use-flatness-filter", action="store_true",
                        help="Evaluate point cloud flatness before saving. Skip point clouds that are not flat enough.")
    # '%%' is required because argparse %-formats help strings.
    parser.add_argument("--flatness-threshold", type=float, default=50.0,
                        help="Minimum flatness score (0-100%%) required to save point cloud (default: 50.0%%). Higher values mean stricter flatness requirement.")

    args = parser.parse_args()

    batch_process_svo2_folder(
        svo_folder=args.svo_folder,
        output_base=args.output,
        yolo_model_path=args.yolo_model,
        conf=args.conf,
        imgsz=args.imgsz,
        sam_device=args.sam_device,
        max_frames=args.max_frames,
        save_images=args.save_images,
        filter_pointcloud=args.filter_pointcloud,
        use_clustering_filter=args.use_clustering_filter,
        use_density_filter=args.use_density_filter,
        pointcloud_classifier_path=args.pointcloud_classifier,
        use_pointcloud_classifier=args.use_pointcloud_classifier,
        pointcloud_classifier_threshold=args.pointcloud_classifier_threshold,
        use_flatness_filter=args.use_flatness_filter,
        flatness_threshold=args.flatness_threshold,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point. main() returns None, so this exits with status 0 on
    # success; uncaught exceptions still propagate with a traceback.
    raise SystemExit(main())
|
|
|