Files
FishServer/FishMeasure/batch_process_svo2.py
2026-04-08 19:32:23 +08:00

203 lines
9.4 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Batch process multiple SVO2 files.
Loops over all .svo2 files in a folder and processes each one.
Loads YOLO and SAM once and reuses them for all files.
"""
import argparse
import torch
from pathlib import Path
from ultralytics import YOLO
from seg import init_models
from fish_video_weight_evaluation import process_single_svo2, load_pointcloud_classifier
def _resolve_pointcloud_classifier_path(pointcloud_classifier_path):
    """Return the classifier checkpoint to load, or None if none is available.

    A non-empty ``pointcloud_classifier_path`` is returned unchanged. When it is
    None or empty, the default checkpoint located next to this script is used
    if it exists; otherwise a warning is printed and None is returned.
    """
    if pointcloud_classifier_path:
        return pointcloud_classifier_path
    default_path = (
        Path(__file__).parent / "pointcloud_classifier" / "Pointnet_Pointnet2_pytorch"
        / "log" / "classification" / "fish_pointnet2_finetune" / "checkpoints" / "best_model.pth"
    )
    if default_path.exists():
        print(f"Using default classifier path: {default_path}")
        return str(default_path)
    print("Warning: --use-pointcloud-classifier specified but --pointcloud-classifier not provided and default not found")
    print(f" Default path checked: {default_path}")
    return None


def _load_pointcloud_classifier_or_none(pointcloud_classifier_path, threshold, device):
    """Load the point cloud quality classifier; return None if it cannot be loaded."""
    checkpoint = _resolve_pointcloud_classifier_path(pointcloud_classifier_path)
    if checkpoint is None:
        return None
    classifier = load_pointcloud_classifier(
        checkpoint,
        num_classes=2,
        use_normals=False,
        device=device,
    )
    if classifier is None:
        print("Warning: Failed to load point cloud classifier. Continuing without quality filtering.")
    else:
        print(f" Point cloud classifier confidence threshold: {threshold}")
    return classifier


def _print_batch_summary(total, counts, output_base):
    """Print the end-of-batch totals (counts has keys success/skipped/failed)."""
    print("\n" + "=" * 60)
    print("Batch Processing Summary:")
    print("=" * 60)
    print(f"Total files: {total}")
    print(f"Successfully processed: {counts['success']}")
    print(f"Skipped (already exists): {counts['skipped']}")
    print(f"Failed: {counts['failed']}")
    print(f"Output folder: {output_base}")
    print("=" * 60)


def batch_process_svo2_folder(svo_folder, output_base, yolo_model_path, conf=0.25, imgsz=640,
                              sam_device="cuda", max_frames=0, save_images=False, filter_pointcloud=False,
                              use_clustering_filter=False, use_density_filter=False,
                              pointcloud_classifier_path=None, use_pointcloud_classifier=False,
                              pointcloud_classifier_threshold=0.7,
                              use_flatness_filter=False, flatness_threshold=50.0):
    """Process all SVO2 files in a folder. Loads YOLO and SAM once and reuses them.

    Every ``*.svo2`` file directly inside ``svo_folder`` is passed, in sorted
    order, to ``process_single_svo2`` together with the shared models. A file
    is skipped when ``output_base/<file stem>`` already exists as a directory.
    If one file fails (False return or exception), the failure is reported and
    the batch continues with the next file.

    Args:
        svo_folder: Folder scanned (non-recursively) for ``.svo2`` files.
        output_base: Base output folder; created if it does not exist.
        yolo_model_path: Path to the YOLO weights.
        conf, imgsz, max_frames, save_images, filter_pointcloud,
        use_clustering_filter, use_density_filter, use_flatness_filter,
        flatness_threshold, pointcloud_classifier_threshold:
            Forwarded unchanged to ``process_single_svo2``.
        sam_device: Torch device string used for SAM and for the point cloud
            classifier.
        pointcloud_classifier_path: Classifier checkpoint. When None or empty
            and ``use_pointcloud_classifier`` is set, a default checkpoint
            beside this script is tried.
        use_pointcloud_classifier: Request classifier-based quality filtering.
            It is forced off when no classifier could be loaded.

    Returns:
        A dict with keys ``"total"``, ``"success"``, ``"skipped"`` and
        ``"failed"`` once the batch has run. Returns None if the folder does not
        exist or contains no ``.svo2`` files.
    """
    svo_folder = Path(svo_folder).expanduser().resolve()
    if not svo_folder.is_dir():
        print(f"ERROR: Folder not found: {svo_folder}")
        return None
    svo_files = sorted(svo_folder.glob("*.svo2"))
    if not svo_files:
        print(f"ERROR: No .svo2 files found in {svo_folder}")
        return None
    total = len(svo_files)
    print(f"Found {total} SVO2 file(s) to process")
    print("=" * 60)

    # Load YOLO and SAM once; they are reused for every file in the batch.
    print(f"\nLoading YOLO model: {yolo_model_path}")
    yolo_model = YOLO(yolo_model_path)
    class_names = getattr(yolo_model, "names", {})
    print(f"✓ YOLO loaded. Classes: {class_names}")
    print(f"\nLoading SAM (device: {sam_device})...")
    sam_predictor = init_models(device=sam_device, seg_model="sam")
    sam_device_obj = torch.device(sam_device)
    print("✓ SAM loaded")

    pointcloud_classifier = None
    if use_pointcloud_classifier:
        pointcloud_classifier = _load_pointcloud_classifier_or_none(
            pointcloud_classifier_path, pointcloud_classifier_threshold, sam_device
        )
    # Keep the flag consistent with what was actually loaded, so that
    # process_single_svo2 never sees use_pointcloud_classifier=True together
    # with pointcloud_classifier=None (e.g. when an empty path was passed).
    use_pointcloud_classifier = pointcloud_classifier is not None

    print("\n" + "=" * 60)
    print("Models loaded. Processing SVO2 files...")
    print("=" * 60)
    output_base = Path(output_base).expanduser().resolve()
    output_base.mkdir(parents=True, exist_ok=True)

    # Arguments that are identical for every file in the batch.
    shared_kwargs = dict(
        output_base=output_base,
        yolo_model=yolo_model,
        sam_predictor=sam_predictor,
        sam_device=sam_device_obj,
        conf=conf,
        imgsz=imgsz,
        max_frames=max_frames,
        save_images=save_images,
        filter_pointcloud=filter_pointcloud,
        use_clustering_filter=use_clustering_filter,
        use_density_filter=use_density_filter,
        pointcloud_classifier=pointcloud_classifier,
        use_pointcloud_classifier=use_pointcloud_classifier,
        pointcloud_classifier_threshold=pointcloud_classifier_threshold,
        use_flatness_filter=use_flatness_filter,
        flatness_threshold=flatness_threshold,
    )

    counts = {"success": 0, "skipped": 0, "failed": 0}
    for idx, svo_file in enumerate(svo_files, start=1):
        print(f"\n{'=' * 60}")
        print(f"[{idx}/{total}] Processing: {svo_file.name}")
        print(f"{'=' * 60}")

        # An existing per-file output directory means this file is not reprocessed.
        output_folder = output_base / svo_file.stem
        if output_folder.is_dir():
            print(f" ⏭ Skipping {svo_file.name} - output folder already exists: {output_folder}")
            counts["skipped"] += 1
            continue

        try:
            success = process_single_svo2(svo_path=svo_file, **shared_kwargs)
        except Exception as e:
            # Report the error and keep going with the remaining files.
            counts["failed"] += 1
            print(f"\n✗ Error processing {svo_file.name}: {e}")
            import traceback
            traceback.print_exc()
            continue

        if success:
            counts["success"] += 1
            print(f"\n✓ Successfully processed: {svo_file.name}")
        else:
            counts["failed"] += 1
            print(f"\n✗ Failed to process: {svo_file.name}")

    _print_batch_summary(total, counts, output_base)
    return {"total": total, **counts}
def main():
    """Command-line entry point: parse the options and run the folder batch.

    Each parsed option is passed as the corresponding keyword argument of
    batch_process_svo2_folder.
    """
    parser = argparse.ArgumentParser(description="Batch process multiple SVO2 files")
    add_option = parser.add_argument

    # Input / output locations.
    add_option("--svo-folder", required=True,
               help="Folder containing SVO2 files to process")
    add_option("--output", default="output_preview",
               help="Base output folder (default: output_preview)")
    add_option("--yolo-model",
               default="/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt",
               help="YOLO model path")

    # Detection / segmentation settings.
    add_option("--conf", type=float, default=0.7, help="Confidence threshold")
    add_option("--imgsz", type=int, default=640, help="Image size")
    add_option("--sam-device", type=str, default="cuda",
               help="Device for SAM (cuda or cpu)")
    add_option("--max-frames", type=int, default=0,
               help="Maximum frames per SVO2 file (0 = all frames)")
    add_option("--save-images", action="store_true",
               help="Save individual image files instead of creating videos")

    # Point cloud filtering options.
    add_option("--filter-pointcloud", action="store_true",
               help="Apply filtering to remove outliers from point clouds (default: no filtering)")
    add_option("--use-clustering-filter", action="store_true",
               help="Use clustering filter to keep only the largest cluster (requires --filter-pointcloud)")
    add_option("--use-density-filter", action="store_true",
               help="Use density filter: keep only points with at least 200 neighbors within 100mm radius (requires --filter-pointcloud)")

    # Point cloud quality classifier.
    add_option("--pointcloud-classifier", type=str, default=None,
               help="Path to point cloud quality classifier checkpoint")
    add_option("--use-pointcloud-classifier", action="store_true",
               help="Use point cloud classifier to filter out bad quality point clouds (requires --pointcloud-classifier)")
    add_option("--pointcloud-classifier-threshold", type=float, default=0.7,
               help="Confidence threshold for point cloud classifier (default: 0.7). Only 'good' point clouds with confidence >= threshold will be saved.")

    # Flatness check. '%%' is argparse's escape for a literal '%' in help text.
    add_option("--use-flatness-filter", action="store_true",
               help="Evaluate point cloud flatness before saving. Skip point clouds that are not flat enough.")
    add_option("--flatness-threshold", type=float, default=50.0,
               help="Minimum flatness score (0-100%%) required to save point cloud (default: 50.0%%). Higher values mean stricter flatness requirement.")

    opts = parser.parse_args()

    run_kwargs = {
        "svo_folder": opts.svo_folder,
        "output_base": opts.output,
        "yolo_model_path": opts.yolo_model,
        "conf": opts.conf,
        "imgsz": opts.imgsz,
        "sam_device": opts.sam_device,
        "max_frames": opts.max_frames,
        "save_images": opts.save_images,
        "filter_pointcloud": opts.filter_pointcloud,
        "use_clustering_filter": opts.use_clustering_filter,
        "use_density_filter": opts.use_density_filter,
        "pointcloud_classifier_path": opts.pointcloud_classifier,
        "use_pointcloud_classifier": opts.use_pointcloud_classifier,
        "pointcloud_classifier_threshold": opts.pointcloud_classifier_threshold,
        "use_flatness_filter": opts.use_flatness_filter,
        "flatness_threshold": opts.flatness_threshold,
    }
    batch_process_svo2_folder(**run_kwargs)


if __name__ == "__main__":
    main()