Initial commit: FishServer monorepo (FishAction, FishMeasure, fish_api)
Made-with: Cursor
This commit is contained in:
202
FishMeasure/batch_process_svo2.py
Executable file
202
FishMeasure/batch_process_svo2.py
Executable file
@@ -0,0 +1,202 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Batch process multiple SVO2 files.
|
||||
Loops over all .svo2 files in a folder and processes each one.
|
||||
Loads YOLO and SAM once and reuses them for all files.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from ultralytics import YOLO
|
||||
from seg import init_models
|
||||
from fish_video_weight_evaluation import process_single_svo2, load_pointcloud_classifier
|
||||
|
||||
|
||||
def batch_process_svo2_folder(svo_folder, output_base, yolo_model_path, conf=0.25, imgsz=640,
|
||||
sam_device="cuda", max_frames=0, save_images=False, filter_pointcloud=False,
|
||||
use_clustering_filter=False, use_density_filter=False,
|
||||
pointcloud_classifier_path=None, use_pointcloud_classifier=False,
|
||||
pointcloud_classifier_threshold=0.7,
|
||||
use_flatness_filter=False, flatness_threshold=50.0):
|
||||
"""Process all SVO2 files in a folder. Loads YOLO and SAM once and reuses them."""
|
||||
svo_folder = Path(svo_folder).expanduser().resolve()
|
||||
if not svo_folder.is_dir():
|
||||
print(f"ERROR: Folder not found: {svo_folder}")
|
||||
return
|
||||
|
||||
# Find all SVO2 files
|
||||
svo_files = sorted(svo_folder.glob("*.svo2"))
|
||||
if not svo_files:
|
||||
print(f"ERROR: No .svo2 files found in {svo_folder}")
|
||||
return
|
||||
|
||||
print(f"Found {len(svo_files)} SVO2 file(s) to process")
|
||||
print("="*60)
|
||||
|
||||
# Load YOLO and SAM once for all files
|
||||
print(f"\nLoading YOLO model: {yolo_model_path}")
|
||||
yolo_model = YOLO(yolo_model_path)
|
||||
class_names = yolo_model.names if hasattr(yolo_model, 'names') else {}
|
||||
print(f"✓ YOLO loaded. Classes: {class_names}")
|
||||
|
||||
print(f"\nLoading SAM (device: {sam_device})...")
|
||||
sam_predictor = init_models(device=sam_device, seg_model="sam")
|
||||
sam_device_obj = torch.device(sam_device)
|
||||
print(f"✓ SAM loaded")
|
||||
|
||||
# Load point cloud classifier if specified
|
||||
pointcloud_classifier = None
|
||||
if use_pointcloud_classifier:
|
||||
if pointcloud_classifier_path is None:
|
||||
# Try default path
|
||||
default_path = Path(__file__).parent / "pointcloud_classifier" / "Pointnet_Pointnet2_pytorch" / "log" / "classification" / "fish_pointnet2_finetune" / "checkpoints" / "best_model.pth"
|
||||
if default_path.exists():
|
||||
pointcloud_classifier_path = str(default_path)
|
||||
print(f"Using default classifier path: {pointcloud_classifier_path}")
|
||||
else:
|
||||
print("Warning: --use-pointcloud-classifier specified but --pointcloud-classifier not provided and default not found")
|
||||
print(f" Default path checked: {default_path}")
|
||||
use_pointcloud_classifier = False
|
||||
|
||||
if use_pointcloud_classifier and pointcloud_classifier_path:
|
||||
pointcloud_classifier = load_pointcloud_classifier(
|
||||
pointcloud_classifier_path,
|
||||
num_classes=2,
|
||||
use_normals=False,
|
||||
device=sam_device
|
||||
)
|
||||
if pointcloud_classifier is None:
|
||||
print("Warning: Failed to load point cloud classifier. Continuing without quality filtering.")
|
||||
use_pointcloud_classifier = False
|
||||
else:
|
||||
print(f" Point cloud classifier confidence threshold: {pointcloud_classifier_threshold}")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("Models loaded. Processing SVO2 files...")
|
||||
print("="*60)
|
||||
|
||||
output_base = Path(output_base).expanduser().resolve()
|
||||
output_base.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
success_count = 0
|
||||
failed_count = 0
|
||||
skipped_count = 0
|
||||
|
||||
for idx, svo_file in enumerate(svo_files):
|
||||
print(f"\n{'='*60}")
|
||||
print(f"[{idx + 1}/{len(svo_files)}] Processing: {svo_file.name}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Check if output folder already exists
|
||||
svo_name = svo_file.stem
|
||||
output_folder = output_base / svo_name
|
||||
if output_folder.exists() and output_folder.is_dir():
|
||||
print(f" ⏭ Skipping {svo_file.name} - output folder already exists: {output_folder}")
|
||||
skipped_count += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
# Process using pre-loaded models
|
||||
success = process_single_svo2(
|
||||
svo_path=svo_file,
|
||||
output_base=output_base,
|
||||
yolo_model=yolo_model,
|
||||
sam_predictor=sam_predictor,
|
||||
sam_device=sam_device_obj,
|
||||
conf=conf,
|
||||
imgsz=imgsz,
|
||||
max_frames=max_frames,
|
||||
save_images=save_images,
|
||||
filter_pointcloud=filter_pointcloud,
|
||||
use_clustering_filter=use_clustering_filter,
|
||||
use_density_filter=use_density_filter,
|
||||
pointcloud_classifier=pointcloud_classifier,
|
||||
use_pointcloud_classifier=use_pointcloud_classifier,
|
||||
pointcloud_classifier_threshold=pointcloud_classifier_threshold,
|
||||
use_flatness_filter=use_flatness_filter,
|
||||
flatness_threshold=flatness_threshold
|
||||
)
|
||||
|
||||
if success:
|
||||
success_count += 1
|
||||
print(f"\n✓ Successfully processed: {svo_file.name}")
|
||||
else:
|
||||
failed_count += 1
|
||||
print(f"\n✗ Failed to process: {svo_file.name}")
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
print(f"\n✗ Error processing {svo_file.name}: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("Batch Processing Summary:")
|
||||
print("="*60)
|
||||
print(f"Total files: {len(svo_files)}")
|
||||
print(f"Successfully processed: {success_count}")
|
||||
print(f"Skipped (already exists): {skipped_count}")
|
||||
print(f"Failed: {failed_count}")
|
||||
print(f"Output folder: {output_base.resolve()}")
|
||||
print("="*60)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Batch process multiple SVO2 files")
|
||||
parser.add_argument("--svo-folder", required=True,
|
||||
help="Folder containing SVO2 files to process")
|
||||
parser.add_argument("--output", default="output_preview",
|
||||
help="Base output folder (default: output_preview)")
|
||||
parser.add_argument("--yolo-model",
|
||||
default="/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt",
|
||||
help="YOLO model path")
|
||||
parser.add_argument("--conf", type=float, default=0.7, help="Confidence threshold")
|
||||
parser.add_argument("--imgsz", type=int, default=640, help="Image size")
|
||||
parser.add_argument("--sam-device", type=str, default="cuda",
|
||||
help="Device for SAM (cuda or cpu)")
|
||||
parser.add_argument("--max-frames", type=int, default=0,
|
||||
help="Maximum frames per SVO2 file (0 = all frames)")
|
||||
parser.add_argument("--save-images", action="store_true",
|
||||
help="Save individual image files instead of creating videos")
|
||||
parser.add_argument("--filter-pointcloud", action="store_true",
|
||||
help="Apply filtering to remove outliers from point clouds (default: no filtering)")
|
||||
parser.add_argument("--use-clustering-filter", action="store_true",
|
||||
help="Use clustering filter to keep only the largest cluster (requires --filter-pointcloud)")
|
||||
parser.add_argument("--use-density-filter", action="store_true",
|
||||
help="Use density filter: keep only points with at least 200 neighbors within 100mm radius (requires --filter-pointcloud)")
|
||||
parser.add_argument("--pointcloud-classifier", type=str, default=None,
|
||||
help="Path to point cloud quality classifier checkpoint")
|
||||
parser.add_argument("--use-pointcloud-classifier", action="store_true",
|
||||
help="Use point cloud classifier to filter out bad quality point clouds (requires --pointcloud-classifier)")
|
||||
parser.add_argument("--pointcloud-classifier-threshold", type=float, default=0.7,
|
||||
help="Confidence threshold for point cloud classifier (default: 0.7). Only 'good' point clouds with confidence >= threshold will be saved.")
|
||||
parser.add_argument("--use-flatness-filter", action="store_true",
|
||||
help="Evaluate point cloud flatness before saving. Skip point clouds that are not flat enough.")
|
||||
parser.add_argument("--flatness-threshold", type=float, default=50.0,
|
||||
help="Minimum flatness score (0-100%%) required to save point cloud (default: 50.0%%). Higher values mean stricter flatness requirement.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
batch_process_svo2_folder(
|
||||
svo_folder=args.svo_folder,
|
||||
output_base=args.output,
|
||||
yolo_model_path=args.yolo_model,
|
||||
conf=args.conf,
|
||||
imgsz=args.imgsz,
|
||||
sam_device=args.sam_device,
|
||||
max_frames=args.max_frames,
|
||||
save_images=args.save_images,
|
||||
filter_pointcloud=args.filter_pointcloud,
|
||||
use_clustering_filter=args.use_clustering_filter,
|
||||
use_density_filter=args.use_density_filter,
|
||||
pointcloud_classifier_path=args.pointcloud_classifier,
|
||||
use_pointcloud_classifier=args.use_pointcloud_classifier,
|
||||
pointcloud_classifier_threshold=args.pointcloud_classifier_threshold,
|
||||
use_flatness_filter=args.use_flatness_filter,
|
||||
flatness_threshold=args.flatness_threshold
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user