Files
FishServer/FishMeasure/dump_fish_frames_from_svo2.py
2026-04-08 19:32:23 +08:00

316 lines
10 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Dump images from SVO2 files when fish are detected in consecutive frames.

Given a batch of input folders (each may contain multiple SVO2 files), uses the
trained YOLO model to detect fish. When fish are detected in consecutive frames,
samples every N-th frame of that run (default 5) and dumps those images to the
output folder with the same directory structure.

Usage:
    python dump_fish_frames_from_svo2.py \
        --input-root /home/ubuntu/data/fish/2016-1-22-last \
        --output /home/ubuntu/data/fish/2016-1-22-last_images \
        --yolo-model /home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt \
        --sample-every 5 \
        --recursive

Input structure: input_root/fish1/svo1.svo2, input_root/fish1/svo2.svo2, ...
Output structure: output/fish1/svo1/frame_000001.png, frame_000006.png, ...
"""
from __future__ import annotations
import argparse
import sys
import cv2
from pathlib import Path
from typing import List, Tuple
# Ensure project root is on path
_SCRIPT_DIR = Path(__file__).resolve().parent
if str(_SCRIPT_DIR) not in sys.path:
sys.path.insert(0, str(_SCRIPT_DIR))
from ultralytics import YOLO
try:
import pyzed.sl as sl
ZED_AVAILABLE = True
except ImportError:
ZED_AVAILABLE = False
from dataset.zed_reader import ZEDReader
def collect_svo2_files(
    input_folders: List[Path],
    recursive: bool = False,
) -> List[Tuple[Path, Path, Path]]:
    """
    Gather every SVO2 recording found under the given folders.

    Args:
        input_folders: Folders to scan. Entries that do not exist or are not
            directories are skipped without an error.
        recursive: When True, match "**/*.svo2" (descend into subfolders);
            otherwise only "*.svo2" directly inside each folder.

    Returns:
        One (input_folder, parent_folder, svo2_path) tuple per regular file,
        where input_folder is the resolved folder being scanned and
        parent_folder is the directory that holds the file. Matches are sorted
        within each folder; folders keep the order in which they were given.
        A (folder, file) pair is reported only once even if the same folder is
        listed several times.
    """
    glob_pattern = "**/*.svo2" if recursive else "*.svo2"
    collected: List[Tuple[Path, Path, Path]] = []
    visited = set()

    for raw_folder in input_folders:
        base = raw_folder.expanduser().resolve()
        # is_dir() is already False for paths that do not exist.
        if not base.is_dir():
            continue

        for candidate in sorted(base.glob(glob_pattern)):
            dedup_key = (str(base), str(candidate))
            # Skip repeats and anything that merely looks like an SVO2 file
            # (e.g. a directory whose name ends in ".svo2").
            if dedup_key in visited or not candidate.is_file():
                continue
            visited.add(dedup_key)
            collected.append((base, candidate.parent, candidate))

    return collected
def process_single_svo2(
    svo_path: Path,
    output_dir: Path,
    yolo_model: YOLO,
    conf: float = 0.25,
    imgsz: int = 640,
    sample_every: int = 5,
    max_frames: int = 0,
    draw_boxes: bool = True,
) -> int:
    """
    Process one SVO2 file and dump sampled frames that contain detections.

    Frames are read sequentially. Each uninterrupted run of frames with at
    least one detection is sampled: the first frame of the run is dumped, then
    every sample_every-th frame after it. A frame without detections ends the
    run. Images are written as "frame_<1-based index, 6 digits>.png".

    Args:
        svo_path: Path to the SVO2 file.
        output_dir: Directory that receives the PNG files (created if missing).
        yolo_model: Loaded YOLO model, called once per frame.
        conf: Confidence threshold passed to the model.
        imgsz: Inference image size passed to the model.
        sample_every: Sampling stride within a run of detections; must be >= 1.
        max_frames: Stop after this many frames (0 or negative = no limit).
        draw_boxes: If True, draw detection boxes and a count label on the
            image before saving it.

    Returns:
        Number of images successfully written. 0 is also returned when pyzed
        is unavailable, the file is missing, or it cannot be opened (an ERROR
        line is printed in those cases).

    Raises:
        ValueError: If sample_every is smaller than 1.
    """
    # Reject a bad stride up front instead of failing with ZeroDivisionError
    # on the first detection, after the file has already been opened.
    if sample_every < 1:
        raise ValueError(f"sample_every must be >= 1, got {sample_every}")

    if not ZED_AVAILABLE:
        print("ERROR: pyzed not available.")
        return 0

    svo_path = Path(svo_path).expanduser().resolve()
    if not svo_path.exists():
        print(f"ERROR: SVO2 not found: {svo_path}")
        return 0

    output_dir = Path(output_dir).expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    zed_reader = ZEDReader(svo_path=str(svo_path), camera_mode=False, use_yolo_detector=False)
    if not zed_reader.open():
        print(f"ERROR: Failed to open SVO2: {svo_path}")
        return 0

    consecutive_with_fish = 0
    dumped_count = 0
    idx = 0
    try:
        # Created inside the try so the reader is closed even if this fails.
        runtime_params = sl.RuntimeParameters()
        left_image_mat = sl.Mat()

        while max_frames <= 0 or idx < max_frames:
            err = zed_reader.zed.grab(runtime_params)
            if err != sl.ERROR_CODE.SUCCESS:
                # End of file is the normal way out; anything else is worth
                # reporting so a truncated result is not mistaken for success.
                if err != sl.ERROR_CODE.END_OF_SVOFILE_REACHED:
                    print(f"  WARNING: grab failed at frame {idx + 1}: {err}")
                break

            zed_reader.zed.retrieve_image(left_image_mat, sl.VIEW.LEFT)
            left_np = left_image_mat.get_data()
            # Keep only the first three channels when more are present; copy
            # so drawing never touches the buffer owned by the sl.Mat.
            if left_np.ndim == 3 and left_np.shape[2] > 3:
                img = left_np[:, :, :3].copy()
            else:
                img = left_np.copy()

            # YOLO detection (no tracking needed for simple dump)
            results = yolo_model(img, conf=conf, imgsz=imgsz, verbose=False)[0]
            boxes = results.boxes
            num_dets = len(boxes) if boxes is not None else 0

            if num_dets > 0:
                consecutive_with_fish += 1
                # Dump the 1st, (1 + N)-th, (1 + 2N)-th ... frame of the run.
                if (consecutive_with_fish - 1) % sample_every == 0:
                    out_path = output_dir / f"frame_{idx + 1:06d}.png"
                    if draw_boxes:
                        for box in boxes.xyxy.cpu().numpy():
                            x1, y1, x2, y2 = map(int, box)
                            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
                        cv2.putText(
                            img, f"fish: {num_dets}",
                            (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2
                        )
                    # imwrite reports failure through its return value; only
                    # count images that actually reached the disk.
                    if cv2.imwrite(str(out_path), img):
                        dumped_count += 1
                        if idx % 100 == 0 or dumped_count <= 3:
                            print(f"  Dumped: {out_path.name} (frame {idx + 1}, {num_dets} fish)")
                    else:
                        print(f"  WARNING: Failed to write image: {out_path}")
            else:
                consecutive_with_fish = 0

            idx += 1
        return dumped_count
    finally:
        zed_reader.close()
def main() -> None:
    """
    Command-line entry point.

    Parses the arguments, builds the list of input folders (the direct
    subfolders of --input-root plus any --input-folders), collects their SVO2
    files, loads the YOLO model and runs process_single_svo2 on each file.
    Output for a file goes to
    <output>/<input folder name>/<relative subpath>/<svo stem>/.

    Every error condition (pyzed missing, invalid --sample-every, missing
    input root, no folders, no SVO2 files, missing model) prints an "ERROR:"
    line and returns without processing anything.
    """
    parser = argparse.ArgumentParser(
        description="Dump images from SVO2 when fish detected (every N frames in consecutive runs)"
    )
    parser.add_argument(
        "--input-root",
        type=str,
        default="",
        help="Root folder containing subfolders with SVO2 files (e.g., fish1/, fish2/)",
    )
    parser.add_argument(
        "--input-folders",
        type=str,
        nargs="+",
        default=[],
        help="Explicit list of folders, each containing SVO2 files",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output base directory (preserves input folder structure)",
    )
    parser.add_argument(
        "--yolo-model",
        type=str,
        default="/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt",
        help="Path to YOLO model",
    )
    parser.add_argument(
        "--conf",
        type=float,
        default=0.8,
        # The help text used to advertise 0.25, which is not the CLI default.
        help="YOLO confidence threshold (default: 0.8)",
    )
    parser.add_argument(
        "--imgsz",
        type=int,
        default=640,
        help="YOLO image size (default: 640)",
    )
    parser.add_argument(
        "--sample-every",
        type=int,
        default=5,
        help="Within consecutive frames with fish, dump every N-th frame (default: 5)",
    )
    parser.add_argument(
        "--max-frames",
        type=int,
        default=0,
        help="Max frames per SVO2 (0 = all)",
    )
    parser.add_argument(
        "--no-draw-boxes",
        action="store_true",
        help="Save raw images without drawing detection boxes",
    )
    parser.add_argument(
        "--recursive",
        action="store_true",
        help="Recursively search for SVO2 files in subfolders",
    )
    args = parser.parse_args()

    # A stride below 1 would otherwise only fail deep inside the frame loop
    # (modulo by zero), after the model and the first file were loaded.
    if args.sample_every < 1:
        print(f"ERROR: --sample-every must be >= 1 (got {args.sample_every}).")
        return

    if not ZED_AVAILABLE:
        print("ERROR: pyzed not available. Install ZED SDK and pyzed.")
        return

    # Collect input folders
    input_folders: List[Path] = []
    if args.input_root:
        root = Path(args.input_root).expanduser().resolve()
        if not root.exists():
            print(f"ERROR: Input root not found: {root}")
            return
        # Subfolders directly under root (e.g., fish1, fish2). iterdir() yields
        # entries in arbitrary order, so sort for a reproducible run order.
        input_folders = sorted(p for p in root.iterdir() if p.is_dir())
    if args.input_folders:
        input_folders.extend(Path(p).expanduser().resolve() for p in args.input_folders)
    if not input_folders:
        print("ERROR: No input folders. Use --input-root or --input-folders.")
        return

    # Collect all SVO2 files
    svo_items = collect_svo2_files(input_folders, recursive=args.recursive)
    if not svo_items:
        print("ERROR: No SVO2 files found in input folders.")
        return

    print(f"Found {len(svo_items)} SVO2 file(s) in {len(input_folders)} folder(s)")
    print(f"Output base: {args.output}")
    print(f"Sample every: {args.sample_every} frames (within consecutive fish detections)")
    print("=" * 60)

    # Load YOLO
    yolo_path = Path(args.yolo_model).expanduser().resolve()
    if not yolo_path.exists():
        print(f"ERROR: YOLO model not found: {yolo_path}")
        return
    yolo_model = YOLO(str(yolo_path))
    print(f"✓ YOLO loaded: {yolo_path.name}")

    output_base = Path(args.output).expanduser().resolve()
    output_base.mkdir(parents=True, exist_ok=True)

    total_dumped = 0
    for i, (input_folder, parent_folder, svo_path) in enumerate(svo_items):
        # Output: output_base/{input_folder_name}/{rel_path}/svo_stem/
        # e.g. output_base/fish1/HD1080_xxx/ or output_base/fish1/session1/HD1080_xxx/
        try:
            rel = parent_folder.relative_to(input_folder)
            # rel.parts is an empty tuple when the file sits directly in the
            # input folder, so plain concatenation covers both cases.
            rel_parts = (input_folder.name,) + rel.parts
        except ValueError:
            rel_parts = (parent_folder.name,)
        out_dir = output_base / Path(*rel_parts) / svo_path.stem

        print(f"\n[{i + 1}/{len(svo_items)}] {Path(*rel_parts) / svo_path.name}")
        count = process_single_svo2(
            svo_path=svo_path,
            output_dir=out_dir,
            yolo_model=yolo_model,
            conf=args.conf,
            imgsz=args.imgsz,
            sample_every=args.sample_every,
            max_frames=args.max_frames,
            draw_boxes=not args.no_draw_boxes,
        )
        total_dumped += count
        print(f" Dumped {count} images to {out_dir}")

    print("\n" + "=" * 60)
    print(f"Done. Total images dumped: {total_dumped}")
    print(f"Output: {output_base}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()