Files
FishServer/FishAction/dataset/auto_label_optical_flow.py
2026-05-06 15:59:38 +08:00

428 lines
16 KiB
Python

#!/usr/bin/env python3
"""
Auto pre-label fish "scared" events using optical flow.
Goal
----
Scan videos and find time segments where motion is high, which can be used as a
proxy for "scared" events. By default, this script outputs a CSV with one segment per line:
video,start_t,end_t,label,score
Where:
- start_t / end_t are in seconds
- label defaults to "scared" for detected segments
- score is the segment's motion score: the maximum window score among the windows merged into it
If you prefer a "packed" format (one CSV row per video, with repeated triplets),
use: --output_format packed_per_video
Method (high level)
-------------------
1) Sample frames (optionally with stride) from each video.
2) Compute dense optical flow (Farneback by default) between consecutive sampled frames.
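3) For each window of W sampled frames (default 30), compute a window motion score:
the mean of the W-1 per-transition flow scores inside the window.
4) Mark windows whose score is at or above a threshold as positive. The threshold is
either absolute (--threshold) or, by default, the per-video 95th percentile of the
window scores (--threshold_percentile).
5) Merge overlapping positive windows, and those separated by at most --merge_gap_s
seconds, into longer segments; segments shorter than --min_duration_s are dropped.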
Notes
-----
- Default optical flow method is Farneback (works with opencv-python).
- Optional TV-L1 requires opencv-contrib-python.
Example
-------
python dataset/auto_label_optical_flow.py \
--input_dir ~/data/fish/fish_action_videos \
--output_csv ./dataset/optical_flow_scared_segments.csv \
--window_frames 30 \
--threshold_percentile 95 \
--label scared
"""
from __future__ import annotations
import argparse
import csv
import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Iterator, List, Optional, Sequence, Tuple
import numpy as np
# Default value of the --exts option: lower-case suffixes with a leading dot.
VIDEO_EXTS_DEFAULT = (
    ".mp4",
    ".avi",
    ".mov",
    ".mkv",
    ".webm",
    ".flv",
    ".m4v",
)

# OpenCV is required for decoding and optical flow; on any import failure, stop
# immediately with installation instructions instead of the raw import error.
try:
    import cv2  # type: ignore
except Exception as exc:  # pragma: no cover
    raise RuntimeError(
        "Failed to import cv2. Please install OpenCV first:\n"
        " pip install opencv-python\n"
        "Optional (for TV-L1):\n"
        " pip install opencv-contrib-python\n"
    ) from exc
@dataclass(frozen=True)
class Segment:
    """One detected high-motion time span inside a video.

    Immutable (``frozen=True``), so instances are hashable and can be compared
    by value.
    """

    # Segment start, in seconds from the beginning of the video.
    start_s: float
    # Segment end, in seconds from the beginning of the video.
    end_s: float
    # Motion score; windows_to_segments stores the maximum merged window score.
    score: float
def iter_videos(root: Path, recursive: bool, exts: Sequence[str]) -> List[Path]:
    """Collect video files under *root*, sorted.

    Args:
        root: Directory to scan.
        recursive: If True, descend into sub-directories (``rglob``); otherwise
            only the direct children of *root* are examined (``glob``).
        exts: Accepted suffixes including the leading dot (e.g. ``".mp4"``),
            since they are compared against ``Path.suffix``. Matching is
            case-insensitive.

    Returns:
        Sorted paths of regular files whose suffix is one of *exts*.
    """
    wanted = {ext.lower() for ext in exts}
    scan = root.rglob if recursive else root.glob
    return sorted(
        candidate
        for candidate in scan("*")
        if candidate.is_file() and candidate.suffix.lower() in wanted
    )
def _resize_gray(frame_bgr: np.ndarray, short_side: int) -> np.ndarray:
    """Convert a BGR frame to grayscale and scale it so its short edge is *short_side*.

    The aspect ratio is kept; each output dimension is rounded and clamped to at
    least 1 pixel. Frames are scaled up or down as needed, always with
    ``cv2.INTER_AREA``.

    Args:
        frame_bgr: BGR image (the caller passes frames from ``cv2.VideoCapture.read``).
        short_side: Target length of the shorter edge; ``<= 0`` disables resizing.

    Returns:
        The grayscale image. It is returned unresized when resizing is disabled,
        the image has a zero-sized edge, or the short edge already matches.
    """
    gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
    height, width = gray.shape[:2]
    current_short = min(height, width)
    if short_side <= 0 or current_short == 0 or current_short == short_side:
        return gray

    factor = short_side / float(current_short)
    new_h = max(1, int(round(height * factor)))
    new_w = max(1, int(round(width * factor)))
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(gray, (new_w, new_h), interpolation=cv2.INTER_AREA)
def _make_flow_fn(method: str):
    """Return a dense optical-flow function ``flow(prev_gray, cur_gray)`` for *method*.

    Args:
        method: ``"farneback"`` (works with plain opencv-python) or one of
            ``"tvl1"``, ``"tv-l1"``, ``"dual_tvl1"`` (needs opencv-contrib-python).
            Case and surrounding whitespace are ignored.

    Returns:
        A callable taking two grayscale frames and returning the OpenCV flow
        field; callers read its components as ``flow[..., 0]`` / ``flow[..., 1]``.

    Raises:
        RuntimeError: TV-L1 requested but ``cv2.optflow.DualTVL1OpticalFlow_create``
            is unavailable.
        ValueError: Unrecognised *method*.
    """
    key = method.lower().strip()

    if key == "farneback":
        # Positional Farneback arguments after `flow`:
        # pyr_scale, levels, winsize, iterations, poly_n, poly_sigma, flags
        farneback_params = (0.5, 3, 15, 3, 5, 1.2, 0)

        def farneback_flow(prev: np.ndarray, cur: np.ndarray) -> np.ndarray:
            return cv2.calcOpticalFlowFarneback(prev, cur, None, *farneback_params)

        return farneback_flow

    if key not in {"tvl1", "tv-l1", "dual_tvl1"}:
        raise ValueError(f"Unknown optical flow method: {key}. Use 'farneback' or 'tvl1'.")

    # TV-L1 lives in the contrib `optflow` module (opencv-contrib-python).
    optflow = getattr(cv2, "optflow", None)
    if optflow is None or not hasattr(optflow, "DualTVL1OpticalFlow_create"):
        raise RuntimeError(
            "TV-L1 optical flow requested but OpenCV contrib is missing. Install:\n"
            " pip install opencv-contrib-python"
        )
    # One solver is created here and shared by every call of the returned function.
    solver = optflow.DualTVL1OpticalFlow_create()

    def tvl1_flow(prev: np.ndarray, cur: np.ndarray) -> np.ndarray:
        return solver.calc(prev, cur, None)

    return tvl1_flow
def compute_window_scores(
    video_path: Path,
    *,
    window_frames: int,
    frame_stride: int,
    resize_short_side: int,
    flow_method: str,
    per_flow_stat: str,
    max_frames: Optional[int],
) -> Tuple[float, List[int], List[float], List[float]]:
    """Compute per-transition and sliding-window optical-flow scores for one video.

    Frames are read in groups of ``frame_stride``. The last frame of each group is
    the sample; sampling stops at the first failed read, discarding a partial
    group. Each sample is converted to grayscale and resized. Dense optical flow
    is computed between consecutive samples and reduced to one number by
    applying *per_flow_stat* to the per-pixel flow magnitude. A window of
    ``window_frames`` samples contains ``window_frames - 1`` transitions; its
    score is their mean.

    Args:
        video_path: Video file to read.
        window_frames: Window length in sampled frames; must be >= 2.
        frame_stride: Sample one frame out of every N; must be >= 1.
        resize_short_side: Passed to ``_resize_gray`` (``<= 0`` disables resizing).
        flow_method: Passed to ``_make_flow_fn`` (``farneback`` | ``tvl1`` ...).
        per_flow_stat: ``mean`` | ``median`` | ``p90`` | ``p95`` (case-insensitive).
        max_frames: Stop after this many sampled frames; None reads the whole video.

    Returns:
        ``(fps, sampled_idx, flow_scores, window_scores)``:
        - fps: container frame rate, or 30.0 if missing or invalid;
        - sampled_idx: 0-based original-video indices of the sampled frames;
        - flow_scores: ``flow_scores[j]`` scores the transition between sampled
          frames j and j+1 (one fewer entry than sampled_idx, or none);
        - window_scores: ``window_scores[i]`` covers sampled frames
          ``i .. i + window_frames - 1``; empty if the video is shorter than
          one window.

    Raises:
        ValueError: Invalid window_frames, frame_stride, per_flow_stat or flow_method.
        RuntimeError: The video cannot be opened, or TV-L1 was requested without
            opencv-contrib.
    """
    # Validate everything before opening the video. Previously the
    # window_frames / frame_stride checks ran only after the whole file had
    # been decoded and every flow field computed.
    if window_frames < 2:
        raise ValueError("window_frames must be >= 2")
    if frame_stride < 1:
        raise ValueError("frame_stride must be >= 1")
    stat_key = per_flow_stat.lower().strip()
    reducers = {
        "mean": np.mean,
        "median": np.median,
        "p90": lambda mag: np.percentile(mag, 90),
        "p95": lambda mag: np.percentile(mag, 95),
    }
    if stat_key not in reducers:
        raise ValueError("per_flow_stat must be one of: mean, median, p90, p95")
    reduce_magnitude = reducers[stat_key]
    flow_fn = _make_flow_fn(flow_method)

    sampled_idx: List[int] = []
    flow_scores: List[float] = []
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Could not open video: {video_path}")
        fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
        if not (fps > 0 and math.isfinite(fps)):
            # FPS metadata missing or invalid: assume 30 fps. Timestamps are off
            # if the real rate differs, and the caller is not told about it.
            fps = 30.0

        actual_idx = -1  # 0-based index of the most recently decoded frame
        prev_gray: Optional[np.ndarray] = None
        while True:
            # Read frame_stride frames and keep the last one as the sample.
            frame = None
            for _ in range(frame_stride):
                ok, candidate = cap.read()
                if not ok:
                    frame = None
                    break
                frame = candidate
                actual_idx += 1
            if frame is None:
                break

            gray = _resize_gray(frame, resize_short_side)
            sampled_idx.append(actual_idx)
            if prev_gray is not None:
                flow = flow_fn(prev_gray, gray)
                mag = cv2.magnitude(flow[..., 0], flow[..., 1])
                flow_scores.append(float(reduce_magnitude(mag)))
            prev_gray = gray

            if max_frames is not None and len(sampled_idx) >= max_frames:
                break
    finally:
        # Release the capture even if opening failed or flow computation raised
        # (previously leaked on any exception).
        cap.release()

    # flow_scores[j] is the transition between sampled frames j and j+1, so the
    # window over sampled frames [i, i+W-1] uses flow_scores[i : i+W-1]. That
    # slice always holds exactly W-1 >= 1 entries.
    n_transitions = window_frames - 1
    window_scores = [
        float(np.mean(flow_scores[i : i + n_transitions]))
        for i in range(len(sampled_idx) - window_frames + 1)
    ]
    return fps, sampled_idx, flow_scores, window_scores
def windows_to_segments(
    sampled_idx: List[int],
    window_scores: List[float],
    *,
    fps: float,
    window_frames: int,
    threshold: float,
    min_duration_s: float,
    merge_gap_s: float,
) -> List[Segment]:
    """Convert sliding-window scores into merged, time-stamped segments.

    Window ``i`` covers sampled frames ``i .. i + window_frames - 1`` and is
    positive unless its score is below *threshold*. Its time span runs from
    ``sampled_idx[i] / fps`` to ``sampled_idx[i + window_frames - 1] / fps``.
    Positive spans are ordered by start. Each span is folded into the running
    segment when it starts no more than *merge_gap_s* seconds after that
    segment's end. The merged score is the largest window score involved.
    Merged segments shorter than *min_duration_s* are discarded.

    Args:
        sampled_idx: Original-video frame index of every sampled frame.
        window_scores: One score per window (as returned by compute_window_scores).
        fps: Frame rate used to convert frame indices to seconds.
        window_frames: Window length in sampled frames.
        threshold: Windows scoring below this are ignored.
        min_duration_s: Minimum duration (seconds) of a kept segment.
        merge_gap_s: Largest gap (seconds) bridged when merging.

    Returns:
        Kept segments in increasing start order; empty if none qualify.
    """
    if not window_scores:
        return []

    last = window_frames - 1
    # Written as `not (score < threshold)`: a NaN score or threshold compares
    # False, so such windows count as positive.
    spans = [
        (sampled_idx[i] / fps, sampled_idx[i + last] / fps, float(score))
        for i, score in enumerate(window_scores)
        if not (score < threshold)
    ]
    spans.sort(key=lambda span: span[0])

    merged: List[List[float]] = []
    for start, end, score in spans:
        if merged and start <= merged[-1][1] + merge_gap_s:
            merged[-1][1] = max(merged[-1][1], end)
            merged[-1][2] = max(merged[-1][2], score)
        else:
            merged.append([start, end, score])

    return [
        Segment(start, end, score)
        for start, end, score in merged
        if (end - start) >= min_duration_s
    ]
def choose_threshold(window_scores: List[float], threshold: Optional[float], threshold_percentile: float) -> float:
    """Pick the score threshold used to mark windows as positive.

    Args:
        window_scores: Window scores of a single video.
        threshold: Absolute threshold. When not None it is returned (as float)
            and *threshold_percentile* is ignored.
        threshold_percentile: Percentile of *window_scores*, in the open interval
            (0, 100), used as an adaptive per-video threshold.

    Returns:
        The threshold. With no window scores (video shorter than one window)
        it is ``inf``, since there is nothing to threshold.

    Raises:
        ValueError: *threshold* is None and *threshold_percentile* is outside
            (0, 100).
    """
    if threshold is not None:
        return float(threshold)
    p = float(threshold_percentile)
    # Validate before looking at the data. Previously a bad percentile was
    # silently accepted for videos too short to produce any window scores.
    if not (0.0 < p < 100.0):
        raise ValueError("threshold_percentile must be in (0, 100)")
    if not window_scores:
        return float("inf")
    return float(np.percentile(window_scores, p))
def video_id_for_output(video_path: Path, input_root: Path, mode: str) -> str:
    """Format *video_path* for the CSV ``video`` column.

    Args:
        video_path: Path of the video file.
        input_root: Directory the videos were discovered under.
        mode: One of the following; case and surrounding whitespace are ignored.
            - ``relative``: path relative to *input_root*, or the full path when
              *video_path* is not inside it.
            - ``absolute``: ``str(video_path)`` as given.
            - ``basename``: the file name only.

    Returns:
        The identifier string.

    Raises:
        ValueError: Unknown *mode*.
    """
    mode = mode.lower().strip()
    if mode == "relative":
        try:
            return str(video_path.relative_to(input_root))
        except ValueError:
            # relative_to raises ValueError when video_path is not under
            # input_root; fall back to the full path. Other errors (e.g. a
            # non-path argument) are no longer swallowed.
            return str(video_path)
    if mode == "absolute":
        return str(video_path)
    if mode == "basename":
        return video_path.name
    raise ValueError("path_mode must be one of: relative, absolute, basename")
def _normalize_choice(value: str) -> str:
    """argparse ``type=`` hook: make choice-style options ignore case and surrounding spaces."""
    return value.strip().lower()


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser.

    Option names, defaults and help texts are unchanged. The enumerated options
    (--flow_method, --per_flow_stat, --path_mode, --output_format) declare
    ``choices``, so a typo is rejected up front instead of making every video
    fail individually.
    """
    ap = argparse.ArgumentParser(description="Auto pre-label high-motion segments using optical flow.")
    ap.add_argument("--input_dir", type=str, required=True, help="Directory containing videos (can be nested).")
    ap.add_argument("--output_csv", type=str, required=True, help="Output CSV path.")
    ap.add_argument("--recursive", action="store_true", help="Recursively scan input_dir.")
    ap.add_argument("--exts", type=str, default=",".join(VIDEO_EXTS_DEFAULT), help="Comma-separated video extensions.")
    ap.add_argument("--label", type=str, default="scared", help="Label name for detected segments.")
    ap.add_argument("--window_frames", type=int, default=30, help="Sliding window size in frames (default: 30).")
    ap.add_argument("--frame_stride", type=int, default=1, help="Sample every N frames (default: 1).")
    ap.add_argument("--resize_short_side", type=int, default=224, help="Resize grayscale frames so short side == N (0 disables).")
    ap.add_argument(
        "--flow_method",
        type=_normalize_choice,
        choices=("farneback", "tvl1", "tv-l1", "dual_tvl1"),
        default="farneback",
        help="Optical flow method: farneback | tvl1",
    )
    ap.add_argument(
        "--per_flow_stat",
        type=_normalize_choice,
        choices=("mean", "median", "p90", "p95"),
        default="mean",
        help="Per-flow magnitude aggregation: mean | median | p90 | p95",
    )
    ap.add_argument(
        "--threshold",
        type=float,
        default=None,
        help="Absolute threshold on window score. If omitted, uses --threshold_percentile.",
    )
    ap.add_argument(
        "--threshold_percentile",
        type=float,
        default=95.0,
        help="Adaptive threshold percentile computed per-video from window scores (default: 95).",
    )
    ap.add_argument("--min_duration_s", type=float, default=0.0, help="Drop segments shorter than this duration (seconds).")
    ap.add_argument("--merge_gap_s", type=float, default=0.25, help="Merge segments whose gap <= this (seconds).")
    ap.add_argument("--max_sampled_frames", type=int, default=None, help="For debugging: stop after N sampled frames per video.")
    ap.add_argument(
        "--path_mode",
        type=_normalize_choice,
        choices=("relative", "absolute", "basename"),
        default="relative",
        help="How to write the video field: relative | absolute | basename (default: relative)",
    )
    ap.add_argument("--no_header", action="store_true", help="Do not write CSV header.")
    ap.add_argument(
        "--output_format",
        type=_normalize_choice,
        choices=("segments", "packed_per_video"),
        default="segments",
        help="CSV format: segments (one segment per line, includes score) | packed_per_video (one line per video)",
    )
    ap.add_argument("--quiet", action="store_true", help="Reduce logging.")
    return ap


def _rows_for_video(vid: str, segments: List["Segment"], label: str, output_format: str) -> List[List[str]]:
    """Format one video's segments as CSV rows for the given --output_format."""
    if output_format == "segments":
        return [
            [vid, f"{seg.start_s:.3f}", f"{seg.end_s:.3f}", label, f"{seg.score:.6f}"]
            for seg in segments
        ]
    # packed_per_video: a single row repeating (video, start_t, end_t, label);
    # a video without segments produces no row at all.
    packed: List[str] = []
    for seg in segments:
        packed.extend([vid, f"{seg.start_s:.3f}", f"{seg.end_s:.3f}", label])
    return [packed] if packed else []


def main() -> None:
    """CLI entry point: scan --input_dir, detect high-motion segments, write --output_csv.

    Invalid arguments terminate through ``argparse`` (exit status 2). Finding no
    videos raises ``SystemExit`` with a message. A video that fails to process
    is reported on stderr, even with --quiet, and skipped; the others are
    still written.
    """
    import sys  # only needed for stderr reporting

    ap = _build_arg_parser()
    args = ap.parse_args()

    # Fail fast on settings that would otherwise make every single video fail
    # (or be fully decoded before failing) while the run still "succeeds".
    if args.window_frames < 2:
        ap.error("--window_frames must be >= 2")
    if args.frame_stride < 1:
        ap.error("--frame_stride must be >= 1")
    if args.max_sampled_frames is not None and args.max_sampled_frames < 1:
        ap.error("--max_sampled_frames must be >= 1")
    if args.threshold is None and not (0.0 < args.threshold_percentile < 100.0):
        ap.error("--threshold_percentile must be in (0, 100)")
    try:
        _make_flow_fn(args.flow_method)  # e.g. TV-L1 requested without opencv-contrib
    except RuntimeError as e:
        ap.error(str(e))

    input_root = Path(args.input_dir).expanduser().resolve()
    output_csv = Path(args.output_csv).expanduser().resolve()
    exts = [ext.strip().lower() for ext in args.exts.split(",") if ext.strip()]
    exts = [ext if ext.startswith(".") else f".{ext}" for ext in exts]
    videos = iter_videos(input_root, args.recursive, exts)
    if not videos:
        raise SystemExit(f"No videos found under: {input_root}")
    output_csv.parent.mkdir(parents=True, exist_ok=True)
    output_format = args.output_format  # already normalized and validated by argparse

    total_segments = 0
    failed = 0
    # Explicit UTF-8: video paths may contain non-ASCII characters that the
    # platform default encoding (e.g. on Windows) cannot represent.
    with open(output_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        if not args.no_header and output_format == "segments":
            w.writerow(["video", "start_t", "end_t", "label", "score"])
        for vp in videos:
            vid = video_id_for_output(vp, input_root, args.path_mode)
            try:
                fps, sampled_idx, _flow_scores, window_scores = compute_window_scores(
                    vp,
                    window_frames=args.window_frames,
                    frame_stride=args.frame_stride,
                    resize_short_side=args.resize_short_side,
                    flow_method=args.flow_method,
                    per_flow_stat=args.per_flow_stat,
                    max_frames=args.max_sampled_frames,
                )
                thr = choose_threshold(window_scores, args.threshold, args.threshold_percentile)
                segments = windows_to_segments(
                    sampled_idx,
                    window_scores,
                    fps=fps,
                    window_frames=args.window_frames,
                    threshold=thr,
                    min_duration_s=args.min_duration_s,
                    merge_gap_s=args.merge_gap_s,
                )
            except Exception as e:
                # Per-video boundary: one unreadable/corrupt file must not abort
                # the batch, but it must not disappear silently either.
                failed += 1
                print(f"[WARN] failed on {vp}: {e}", file=sys.stderr)
                continue

            w.writerows(_rows_for_video(vid, segments, args.label, output_format))
            total_segments += len(segments)

            if not args.quiet:
                if window_scores:
                    ws = np.asarray(window_scores, dtype=np.float32)
                    print(
                        f"[{vid}] fps={fps:.3f} sampled={len(sampled_idx)} "
                        f"windows={len(window_scores)} thr={thr:.6f} "
                        f"segments={len(segments)} "
                        f"(win_score p50={float(np.percentile(ws,50)):.6f}, p95={float(np.percentile(ws,95)):.6f})"
                    )
                else:
                    print(f"[{vid}] too short for window_frames={args.window_frames}; skipped.")

    if not args.quiet:
        print(f"\nWrote {total_segments} segments to: {output_csv}")
    if failed:
        print(f"[WARN] {failed} of {len(videos)} video(s) failed to process.", file=sys.stderr)


if __name__ == "__main__":
    main()