428 lines
16 KiB
Python
Executable File
428 lines
16 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Auto pre-label fish "scared" events using optical flow.
|
|
|
|
Goal
|
|
----
|
|
Scan videos and find time segments where motion is high, which can be used as a
|
|
proxy for "scared" events. By default, this script outputs a CSV with one segment per line:
|
|
|
|
video,start_t,end_t,label,score
|
|
|
|
Where:
|
|
- start_t / end_t are in seconds
|
|
- label defaults to "scared" for detected segments
|
|
- score is the aggregated optical-flow magnitude score for the segment
|
|
|
|
If you prefer a "packed" format (one CSV row per video, repeating video,start_t,end_t,label for every segment; no score column),
|
|
use: --output_format packed_per_video
|
|
|
|
Method (high level)
|
|
-------------------
|
|
1) Sample frames (optionally with stride) from each video.
|
|
2) Compute dense optical flow (Farneback by default) between consecutive sampled frames.
|
|
3) For each sliding window of W sampled frames (default 30), compute a window motion score (the mean of its W-1 per-transition flow scores).
|
|
4) Mark windows whose score is at or above a threshold (absolute, or a per-video percentile) as positive.
|
|
5) Merge overlapping/adjacent positive windows into longer segments.
|
|
|
|
Notes
|
|
-----
|
|
- Default optical flow method is Farneback (works with opencv-python).
|
|
- Optional TV-L1 requires opencv-contrib-python.
|
|
|
|
Example
|
|
-------
|
|
python dataset/auto_label_optical_flow.py \
|
|
--input_dir ~/data/fish/fish_action_videos \
|
|
--output_csv ./dataset/optical_flow_scared_segments.csv \
|
|
--window_frames 30 \
|
|
--threshold_percentile 95 \
|
|
--label scared
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import csv
|
|
import math
|
|
import os
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Iterable, Iterator, List, Optional, Sequence, Tuple
|
|
|
|
import numpy as np
|
|
|
|
# Import OpenCV up front. Any import failure (not only ImportError) is turned
# into a RuntimeError carrying install instructions.
try:
    import cv2  # type: ignore
except Exception as e:  # pragma: no cover
    raise RuntimeError(
        "Failed to import cv2. Please install OpenCV first:\n"
        " pip install opencv-python\n"
        "Optional (for TV-L1):\n"
        " pip install opencv-contrib-python\n"
    ) from e
|
|
|
|
|
|
# File extensions treated as videos when --exts is not given.
# iter_videos compares them against file suffixes case-insensitively.
VIDEO_EXTS_DEFAULT = (
    ".mp4",
    ".avi",
    ".mov",
    ".mkv",
    ".webm",
    ".flv",
    ".m4v",
)
|
|
|
|
|
|
@dataclass(frozen=True)
class Segment:
    """Immutable detected time span (in seconds) together with its motion score."""

    start_s: float  # segment start time, seconds
    end_s: float  # segment end time, seconds
    score: float  # highest window score among the windows merged into this span
|
|
|
|
|
|
def iter_videos(root: Path, recursive: bool, exts: Sequence[str]) -> List[Path]:
    """List video files under *root*, sorted.

    Only regular files are returned. A file matches when its suffix, lowercased,
    equals one of *exts* (also lowercased). With *recursive* the whole tree is
    searched; otherwise only the direct children of *root* are considered.
    """
    wanted = frozenset(ext.lower() for ext in exts)
    candidates = root.rglob("*") if recursive else root.glob("*")
    return sorted(
        path
        for path in candidates
        if path.is_file() and path.suffix.lower() in wanted
    )
|
|
|
|
|
|
def _resize_gray(frame_bgr: np.ndarray, short_side: int) -> np.ndarray:
    """Convert a BGR frame to grayscale and scale it so its shorter side is *short_side*.

    Aspect ratio is preserved; each output side is rounded and clamped to at
    least 1 px. The unscaled grayscale frame is returned when *short_side* <= 0,
    when the frame has a zero-sized side, or when the short side already matches.
    """
    gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
    height, width = gray.shape[:2]
    current_short = min(height, width)
    if short_side <= 0 or current_short == 0 or current_short == short_side:
        return gray

    factor = short_side / float(current_short)
    new_height = max(1, int(round(height * factor)))
    new_width = max(1, int(round(width * factor)))
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(gray, (new_width, new_height), interpolation=cv2.INTER_AREA)
|
|
|
|
|
|
def _make_flow_fn(method: str):
|
|
method = method.lower().strip()
|
|
if method == "farneback":
|
|
def _flow(prev: np.ndarray, cur: np.ndarray) -> np.ndarray:
|
|
return cv2.calcOpticalFlowFarneback(
|
|
prev, cur, None,
|
|
0.5, 3, 15, 3, 5, 1.2, 0
|
|
)
|
|
|
|
return _flow
|
|
|
|
if method in {"tvl1", "tv-l1", "dual_tvl1"}:
|
|
# Requires opencv-contrib-python.
|
|
if not hasattr(cv2, "optflow") or not hasattr(cv2.optflow, "DualTVL1OpticalFlow_create"):
|
|
raise RuntimeError(
|
|
"TV-L1 optical flow requested but OpenCV contrib is missing. Install:\n"
|
|
" pip install opencv-contrib-python"
|
|
)
|
|
tvl1 = cv2.optflow.DualTVL1OpticalFlow_create()
|
|
|
|
def _flow(prev: np.ndarray, cur: np.ndarray) -> np.ndarray:
|
|
return tvl1.calc(prev, cur, None)
|
|
|
|
return _flow
|
|
|
|
raise ValueError(f"Unknown optical flow method: {method}. Use 'farneback' or 'tvl1'.")
|
|
|
|
|
|
def _reduce_flow_magnitude(mag: np.ndarray, stat: str) -> float:
    """Collapse a per-pixel flow-magnitude map into one score.

    *stat* must already be normalized and validated: one of
    ``mean``, ``median``, ``p90`` or ``p95``.
    """
    if stat == "mean":
        return float(np.mean(mag))
    if stat == "median":
        return float(np.median(mag))
    if stat == "p90":
        return float(np.percentile(mag, 90))
    return float(np.percentile(mag, 95))


def _mean_over_windows(flow_scores: List[float], num_frames: int, window_frames: int) -> List[float]:
    """Average the ``window_frames - 1`` flow scores inside each sliding window.

    flow_scores[j] scores the transition between sampled frames j and j+1, so
    the window covering sampled frames [i, i+W-1] uses flow_scores[i : i+W-1].
    Returns an empty list when fewer than *window_frames* frames were sampled.
    """
    span = window_frames - 1
    return [
        float(np.mean(flow_scores[i : i + span]))
        for i in range(num_frames - window_frames + 1)
    ]


def compute_window_scores(
    video_path: Path,
    *,
    window_frames: int,
    frame_stride: int,
    resize_short_side: int,
    flow_method: str,
    per_flow_stat: str,
    max_frames: Optional[int],
) -> Tuple[float, List[int], List[float], List[float]]:
    """Score motion in *video_path* with dense optical flow over sliding windows.

    Frames are sampled every *frame_stride* frames. Each sample is the last
    frame of its stride group, so the first sample is frame ``frame_stride - 1``,
    and a trailing partial group is dropped. Samples are converted to grayscale
    and resized via _resize_gray. Optical flow is computed between consecutive
    samples, and each flow field is reduced to one number with *per_flow_stat*.

    Args:
        video_path: video file opened with ``cv2.VideoCapture``.
        window_frames: window size W in sampled frames; must be >= 2.
        frame_stride: sampling step; must be >= 1.
        resize_short_side: target short side passed to _resize_gray (<= 0 disables).
        flow_method: optical-flow method name passed to _make_flow_fn.
        per_flow_stat: mean | median | p90 | p95 (case-insensitive).
        max_frames: stop after this many sampled frames (None = no limit).

    Returns:
        ``(fps, sampled_idx, flow_scores, window_scores)`` where:

        - fps comes from the container metadata, or 30.0 when that value is
          missing or invalid;
        - sampled_idx holds the 0-based original frame index of each sample;
        - flow_scores[j] scores the flow between samples j and j+1
          (len = max(0, len(sampled_idx) - 1));
        - window_scores[i] is the mean of flow_scores[i : i + W - 1]
          (len = len(sampled_idx) - W + 1, or empty when fewer than W samples).

    Raises:
        ValueError: invalid window_frames, frame_stride, per_flow_stat or
            flow_method. These are checked before the video is opened.
        RuntimeError: the video cannot be opened, or TV-L1 was requested
            without opencv-contrib.
    """
    # Validate cheap arguments first, so bad settings fail before decoding
    # (and flow-processing) a whole video, and before any capture is opened.
    if window_frames < 2:
        raise ValueError("window_frames must be >= 2")
    if frame_stride < 1:
        raise ValueError("frame_stride must be >= 1")
    stat = per_flow_stat.lower().strip()
    if stat not in {"mean", "median", "p90", "p95"}:
        raise ValueError("per_flow_stat must be one of: mean, median, p90, p95")
    flow_fn = _make_flow_fn(flow_method)

    sampled_idx: List[int] = []
    flow_scores: List[float] = []

    cap = cv2.VideoCapture(str(video_path))
    try:  # always release the capture, even if decoding or flow raises
        if not cap.isOpened():
            raise RuntimeError(f"Could not open video: {video_path}")

        fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
        if not (fps > 0 and math.isfinite(fps)):
            # Metadata missing or invalid; fall back to a common frame rate.
            fps = 30.0

        actual_idx = -1
        prev_gray: Optional[np.ndarray] = None
        while True:
            # Advance by frame_stride frames; the last frame read is the sample.
            frame: Optional[np.ndarray] = None
            for _ in range(frame_stride):
                ok, frame_read = cap.read()
                if not ok:
                    frame = None  # drop a partial trailing stride group
                    break
                frame = frame_read
                actual_idx += 1
            if frame is None:
                break

            gray = _resize_gray(frame, resize_short_side)
            sampled_idx.append(actual_idx)

            if prev_gray is not None:
                flow = flow_fn(prev_gray, gray)
                mag = cv2.magnitude(flow[..., 0], flow[..., 1])
                flow_scores.append(_reduce_flow_magnitude(mag, stat))
            prev_gray = gray

            if max_frames is not None and len(sampled_idx) >= max_frames:
                break
    finally:
        cap.release()

    window_scores = _mean_over_windows(flow_scores, len(sampled_idx), window_frames)
    return fps, sampled_idx, flow_scores, window_scores
|
|
|
|
|
|
def windows_to_segments(
    sampled_idx: List[int],
    window_scores: List[float],
    *,
    fps: float,
    window_frames: int,
    threshold: float,
    min_duration_s: float,
    merge_gap_s: float,
) -> List[Segment]:
    """Turn per-window scores into merged time segments.

    Window i spans sampled frames ``sampled_idx[i]`` .. ``sampled_idx[i + W - 1]``,
    and its times are those frame indices divided by *fps*. Windows scoring at or
    above *threshold* are kept. Kept windows whose start is within *merge_gap_s*
    of the running segment's end are merged, and the merged segment takes the
    maximum window score. Merged segments shorter than *min_duration_s* are
    discarded.

    Returns:
        Segments in start-time order; empty if nothing qualifies.
    """
    if not window_scores:
        return []

    last_offset = window_frames - 1
    # (start_s, end_s, score) for each qualifying window. "not score < threshold"
    # mirrors the skip test exactly, including for NaN scores.
    hits = [
        (sampled_idx[i] / fps, sampled_idx[i + last_offset] / fps, float(score))
        for i, score in enumerate(window_scores)
        if not score < threshold
    ]
    if not hits:
        return []

    def long_enough(begin: float, finish: float) -> bool:
        return (finish - begin) >= min_duration_s

    hits.sort(key=lambda hit: hit[0])
    segments: List[Segment] = []
    run_start, run_end, run_score = hits[0]
    for start, end, score in hits[1:]:
        if start <= run_end + merge_gap_s:
            # Overlapping or close enough: extend the running segment.
            run_end = max(run_end, end)
            run_score = max(run_score, score)
            continue
        if long_enough(run_start, run_end):
            segments.append(Segment(run_start, run_end, run_score))
        run_start, run_end, run_score = start, end, score

    if long_enough(run_start, run_end):
        segments.append(Segment(run_start, run_end, run_score))
    return segments
|
|
|
|
|
|
def choose_threshold(window_scores: List[float], threshold: Optional[float], threshold_percentile: float) -> float:
    """Pick the window-score threshold for one video.

    Args:
        window_scores: per-window motion scores for the video (may be empty).
        threshold: absolute threshold; when not None it is returned unchanged
            (as float) and *threshold_percentile* is ignored.
        threshold_percentile: percentile in the open interval (0, 100) used to
            derive an adaptive, per-video threshold.

    Returns:
        The threshold. Windows scoring >= this value count as positive. When
        there are no window scores, ``inf`` is returned (nothing can qualify).

    Raises:
        ValueError: adaptive mode with a percentile outside (0, 100). This is
            checked even when *window_scores* is empty.
    """
    if threshold is not None:
        return float(threshold)

    p = float(threshold_percentile)
    if not (0.0 < p < 100.0):
        raise ValueError("threshold_percentile must be in (0, 100)")
    if not window_scores:
        return float("inf")

    scores = np.asarray(window_scores, dtype=np.float64)
    thr = float(np.percentile(scores, p))
    # Windows are kept when score >= thr. If the percentile lands on the minimum
    # score (e.g. a static or mostly frozen video whose windows share one value),
    # every window would be selected and the whole video labelled. Nudge the
    # threshold just above the minimum so only strictly higher windows qualify.
    if thr <= float(scores.min()):
        thr = float(np.nextafter(thr, np.inf))
    return thr
|
|
|
|
|
|
def video_id_for_output(video_path: Path, input_root: Path, mode: str) -> str:
    """Format the ``video`` field written to the CSV.

    Args:
        video_path: path of the processed video.
        input_root: scanned root directory, used by ``relative`` mode.
        mode: ``relative`` | ``absolute`` | ``basename`` (case-insensitive,
            surrounding whitespace ignored).

    Returns:
        ``relative``: *video_path* relative to *input_root*, or *video_path* as
        given when it does not lie under *input_root*. ``absolute``:
        *video_path* as given (it is not re-resolved here). ``basename``: the
        final path component.

    Raises:
        ValueError: unknown *mode*.
    """
    mode = mode.lower().strip()
    if mode == "relative":
        try:
            return str(video_path.relative_to(input_root))
        except ValueError:
            # Path.relative_to raises ValueError when video_path is not under input_root.
            return str(video_path)
    if mode == "absolute":
        return str(video_path)
    if mode == "basename":
        return video_path.name
    raise ValueError("path_mode must be one of: relative, absolute, basename")
|
|
|
|
|
|
def _normalize_choice(value: str) -> str:
    """argparse ``type`` hook: lowercase and trim a value before ``choices`` is checked."""
    return value.lower().strip()


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Enumerated options are normalized with _normalize_choice and restricted via
    ``choices``. A typo is therefore rejected once at startup, instead of making
    every video fail individually after it has been decoded.
    """
    ap = argparse.ArgumentParser(description="Auto pre-label high-motion segments using optical flow.")
    ap.add_argument("--input_dir", type=str, required=True, help="Directory containing videos (can be nested).")
    ap.add_argument("--output_csv", type=str, required=True, help="Output CSV path.")
    ap.add_argument("--recursive", action="store_true", help="Recursively scan input_dir.")
    ap.add_argument("--exts", type=str, default=",".join(VIDEO_EXTS_DEFAULT), help="Comma-separated video extensions.")
    ap.add_argument("--label", type=str, default="scared", help="Label name for detected segments.")
    ap.add_argument("--window_frames", type=int, default=30, help="Sliding window size in frames (default: 30).")
    ap.add_argument("--frame_stride", type=int, default=1, help="Sample every N frames (default: 1).")
    ap.add_argument("--resize_short_side", type=int, default=224, help="Resize grayscale frames so short side == N (0 disables).")
    ap.add_argument(
        "--flow_method",
        type=_normalize_choice,
        choices=("farneback", "tvl1", "tv-l1", "dual_tvl1"),
        default="farneback",
        help="Optical flow method: farneback | tvl1",
    )
    ap.add_argument(
        "--per_flow_stat",
        type=_normalize_choice,
        choices=("mean", "median", "p90", "p95"),
        default="mean",
        help="Per-flow magnitude aggregation: mean | median | p90 | p95",
    )
    ap.add_argument(
        "--threshold",
        type=float,
        default=None,
        help="Absolute threshold on window score. If omitted, uses --threshold_percentile.",
    )
    ap.add_argument(
        "--threshold_percentile",
        type=float,
        default=95.0,
        help="Adaptive threshold percentile computed per-video from window scores (default: 95).",
    )
    ap.add_argument("--min_duration_s", type=float, default=0.0, help="Drop segments shorter than this duration (seconds).")
    ap.add_argument("--merge_gap_s", type=float, default=0.25, help="Merge segments whose gap <= this (seconds).")
    ap.add_argument("--max_sampled_frames", type=int, default=None, help="For debugging: stop after N sampled frames per video.")
    ap.add_argument(
        "--path_mode",
        type=_normalize_choice,
        choices=("relative", "absolute", "basename"),
        default="relative",
        help="How to write the video field: relative | absolute | basename (default: relative)",
    )
    ap.add_argument("--no_header", action="store_true", help="Do not write CSV header.")
    ap.add_argument(
        "--output_format",
        type=_normalize_choice,
        choices=("segments", "packed_per_video"),
        default="segments",
        help="CSV format: segments (one segment per line, includes score) | packed_per_video (one line per video)",
    )
    ap.add_argument("--quiet", action="store_true", help="Reduce logging.")
    return ap


def main() -> None:
    """Command-line entry point: scan videos, detect high-motion segments, write a CSV.

    Each video found under --input_dir gets per-window optical-flow scores.
    These are thresholded (absolute --threshold or per-video
    --threshold_percentile), merged into segments, and written to --output_csv
    in the ``segments`` or ``packed_per_video`` layout. A video that raises is
    reported on stderr (even with --quiet) and skipped; the remaining videos
    are still processed.
    """
    import sys  # only main needs it, to route warnings to stderr

    ap = _build_arg_parser()
    args = ap.parse_args()

    # Reject invalid numeric settings once, at startup, rather than once per
    # video inside the per-video error handler below.
    if args.window_frames < 2:
        ap.error("--window_frames must be >= 2")
    if args.frame_stride < 1:
        ap.error("--frame_stride must be >= 1")
    if args.threshold is None and not (0.0 < args.threshold_percentile < 100.0):
        ap.error("--threshold_percentile must be in (0, 100)")

    input_root = Path(args.input_dir).expanduser().resolve()
    output_csv = Path(args.output_csv).expanduser().resolve()

    # Normalize extensions: trim, lowercase, ensure a leading dot.
    exts = [e.strip().lower() for e in args.exts.split(",") if e.strip()]
    exts = [e if e.startswith(".") else f".{e}" for e in exts]
    videos = iter_videos(input_root, args.recursive, exts)
    if not videos:
        raise SystemExit(f"No videos found under: {input_root}")

    output_csv.parent.mkdir(parents=True, exist_ok=True)
    output_format = args.output_format  # already normalized and validated by argparse

    total_segments = 0
    failed = 0
    # Explicit UTF-8: video paths may contain characters the platform default
    # encoding cannot represent.
    with open(output_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        if not args.no_header and output_format == "segments":
            w.writerow(["video", "start_t", "end_t", "label", "score"])

        for vp in videos:
            try:
                fps, sampled_idx, _flow_scores, window_scores = compute_window_scores(
                    vp,
                    window_frames=args.window_frames,
                    frame_stride=args.frame_stride,
                    resize_short_side=args.resize_short_side,
                    flow_method=args.flow_method,
                    per_flow_stat=args.per_flow_stat,
                    max_frames=args.max_sampled_frames,
                )
                thr = choose_threshold(window_scores, args.threshold, args.threshold_percentile)
                segments = windows_to_segments(
                    sampled_idx,
                    window_scores,
                    fps=fps,
                    window_frames=args.window_frames,
                    threshold=thr,
                    min_duration_s=args.min_duration_s,
                    merge_gap_s=args.merge_gap_s,
                )

                vid = video_id_for_output(vp, input_root, args.path_mode)
                if output_format == "segments":
                    for seg in segments:
                        w.writerow([vid, f"{seg.start_s:.3f}", f"{seg.end_s:.3f}", args.label, f"{seg.score:.6f}"])
                else:
                    # One row per video, repeating video,start_t,end_t,label for
                    # each segment; videos without segments produce no row.
                    row: List[str] = []
                    for seg in segments:
                        row.extend([vid, f"{seg.start_s:.3f}", f"{seg.end_s:.3f}", args.label])
                    if row:
                        w.writerow(row)
                total_segments += len(segments)

                if not args.quiet:
                    if window_scores:
                        ws = np.asarray(window_scores, dtype=np.float32)
                        print(
                            f"[{vid}] fps={fps:.3f} sampled={len(sampled_idx)} "
                            f"windows={len(window_scores)} thr={thr:.6f} "
                            f"segments={len(segments)} "
                            f"(win_score p50={float(np.percentile(ws,50)):.6f}, p95={float(np.percentile(ws,95)):.6f})"
                        )
                    else:
                        print(f"[{vid}] too short for window_frames={args.window_frames}; skipped.")
            except Exception as e:
                # Per-video boundary: always report (even with --quiet) and move on.
                failed += 1
                print(f"[WARN] failed on {vp}: {e}", file=sys.stderr)

    if not args.quiet:
        print(f"\nWrote {total_segments} segments to: {output_csv}")
    if failed:
        print(f"[WARN] {failed} of {len(videos)} video(s) failed.", file=sys.stderr)
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|
|