264 lines
8.8 KiB
Python
264 lines
8.8 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
Prepare a fish action video dataset for SlowFast training.

This script:
1. Scans the video directory structure (each subfolder is a class folder;
   the *_underwater / *_upperwater variants are merged into their base class,
   giving the three classes feeding / normal / scared)
2. Creates stratified train/val/test splits
3. Generates CSV files in the format expected by SlowFast
   (one "<absolute video path> <label>" entry per line)
4. Creates a label mapping file

IMPORTANT: SlowFast does NOT require pre-split video clips. It automatically
samples short clips (typically 2-4 seconds) from full-length videos during
training. You can use your full-length videos as-is.

Usage:
    python prepare_fish_dataset.py --video_dir ~/data/fish/fish_action_videos --output_dir ./dataset
"""
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import os
|
||
|
|
import random
|
||
|
|
from collections import defaultdict
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
|
||
|
|
def get_video_files(directory, extensions=None):
    """Recursively collect video files below ``directory``.

    Args:
        directory: Root folder to walk; all nested subfolders are searched.
        extensions: Iterable of file suffixes to accept (e.g. ``{".mp4"}``),
            or a single suffix string. Matching is case-insensitive on both
            the file name and the suffix. Defaults to a built-in set of
            common video container suffixes.

    Returns:
        Sorted list of matching paths, each built by joining the walked
        folder with the file name.
    """
    if extensions is None:
        extensions = {'.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.m4v'}
    elif isinstance(extensions, str):
        # A bare string would otherwise be iterated character by character.
        extensions = (extensions,)

    # Lower-case the suffixes as well as the file name so ".MP4" matches too;
    # str.endswith accepts a tuple, which replaces the per-suffix any() loop.
    suffixes = tuple(ext.lower() for ext in extensions)

    video_files = []
    for root, _dirs, files in os.walk(directory):
        video_files.extend(
            os.path.join(root, name)
            for name in files
            if name.lower().endswith(suffixes)
        )
    return sorted(video_files)
|
||
|
|
|
||
|
|
def _canonicalize_class_name(class_name: str) -> str:
    """
    Map a raw class folder name to its canonical class name.

    Supports either:
      - 3-folder layout: feeding/, normal/, scared/
      - 5-folder layout: feeding/, normal_underwater/, normal_upperwater/,
        scared_underwater/, scared_upperwater/

    The lookup ignores surrounding whitespace and letter case.

    Args:
        class_name: Folder name as found on disk.

    Returns:
        "feeding", "normal" or "scared" for a recognised folder name;
        otherwise ``class_name`` unchanged, so the caller can detect and
        report the unknown folder.
    """
    aliases = {
        "feeding": "feeding",
        "normal": "normal",
        "normal_underwater": "normal",
        "normal_upperwater": "normal",
        "scared": "scared",
        "scared_underwater": "scared",
        "scared_upperwater": "scared",
    }
    # Fall back to the raw (un-normalised) name: unknown, handled by caller.
    return aliases.get(class_name.strip().lower(), class_name)
|
||
|
|
|
||
|
|
|
||
|
|
def create_dataset_splits(video_dir, output_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=42):
    """
    Create stratified train/val/test splits from a class-per-folder video tree.

    Each direct subfolder of ``video_dir`` is treated as a class folder and is
    mapped to one of the canonical classes feeding / normal / scared. Videos
    are shuffled and split per class, then written to ``train.csv``,
    ``val.csv`` and ``test.csv`` (one ``"<absolute path> <label>"`` line per
    video) plus ``label_map.txt`` (``"<label> <class name>"`` per line).

    Args:
        video_dir: Root directory containing class subfolders (``~`` is expanded)
        output_dir: Directory to save CSV files (``~`` is expanded; created if missing)
        train_ratio: Ratio of videos for training
        val_ratio: Ratio of videos for validation
        test_ratio: Ratio of videos for testing (the test split receives
            whatever remains after the train and val counts are truncated)
        seed: Random seed for reproducibility (seeds the global ``random`` module)

    Raises:
        ValueError: If the ratios are negative or do not sum to 1.0, if
            ``video_dir`` does not exist or has no subdirectories, or if a
            subfolder cannot be mapped to one of the three classes.
    """
    # Validate ratios with real exceptions: asserts vanish under `python -O`.
    ratios = (train_ratio, val_ratio, test_ratio)
    if min(ratios) < 0:
        raise ValueError("Ratios must be non-negative")
    if abs(sum(ratios) - 1.0) >= 1e-6:
        raise ValueError("Ratios must sum to 1.0")

    random.seed(seed)

    video_root = Path(video_dir).expanduser()
    if not video_root.exists():
        raise ValueError(f"Video directory does not exist: {video_dir}")

    # Expand "~" for the output location too; otherwise the CLI default would
    # create a literal "~" directory below the current working directory.
    output_path = Path(output_dir).expanduser()
    resolved_output = output_path.resolve()

    # Find all class subdirectories. An output directory nested directly in
    # the video root is not a class folder, so it is skipped.
    class_dirs = sorted(
        d for d in video_root.iterdir()
        if d.is_dir() and d.resolve() != resolved_output
    )

    if not class_dirs:
        raise ValueError(f"No subdirectories found in {video_dir}")

    print(f"Found {len(class_dirs)} classes:")
    for class_dir in class_dirs:
        print(f"  - {class_dir.name}")

    # Canonicalize classes into 3-class setup.
    known_classes = ["feeding", "normal", "scared"]
    canon_by_dir = {d: _canonicalize_class_name(d.name) for d in class_dirs}
    unknown_classes = [
        d.name for d, canon in canon_by_dir.items() if canon not in known_classes
    ]

    if unknown_classes:
        raise ValueError(
            "Unknown class folder(s) found that cannot be mapped to 3 classes: "
            f"{unknown_classes}. Expected folders among: "
            "feeding, normal, scared, normal_underwater, normal_upperwater, "
            "scared_underwater, scared_upperwater."
        )

    # Stable label ordering: labels follow known_classes, restricted to the
    # classes that are actually present on disk.
    present = set(canon_by_dir.values())
    canonical_classes = [c for c in known_classes if c in present]

    # Create canonical class->label mapping (3 classes).
    class_to_label = {name: idx for idx, name in enumerate(canonical_classes)}
    label_to_class = {idx: name for name, idx in class_to_label.items()}

    # Collect all videos, grouped by label for stratified splitting.
    videos_by_class = defaultdict(list)
    for class_dir in class_dirs:
        class_name = canon_by_dir[class_dir]
        label = class_to_label[class_name]
        videos = get_video_files(str(class_dir))

        if not videos:
            print(f"Warning: No videos found in {class_dir}")
            continue

        for video_file in videos:
            # Absolute path so training scripts can use path_prefix="" or point CSV paths directly.
            abs_path = os.path.abspath(video_file)
            videos_by_class[label].append((abs_path, label, class_name))

    total_videos = sum(len(items) for items in videos_by_class.values())
    print(f"\nTotal videos found: {total_videos}")

    # Create stratified splits
    train_videos = []
    val_videos = []
    test_videos = []

    for label, videos in videos_by_class.items():
        random.shuffle(videos)
        n_total = len(videos)
        # int() truncates, so the remainder of each class lands in the test split.
        n_train = int(n_total * train_ratio)
        n_val = int(n_total * val_ratio)

        train_part = videos[:n_train]
        val_part = videos[n_train:n_train + n_val]
        test_part = videos[n_train + n_val:]

        train_videos.extend(train_part)
        val_videos.extend(val_part)
        test_videos.extend(test_part)

        print(f"\nClass {label_to_class[label]} (label {label}):")
        print(f"  Total: {n_total}")
        print(f"  Train: {len(train_part)}")
        print(f"  Val: {len(val_part)}")
        print(f"  Test: {len(test_part)}")

        # Small classes can truncate to zero videos in a split; make it visible.
        empty_splits = [
            split_name
            for split_name, part in (("train", train_part), ("val", val_part), ("test", test_part))
            if not part
        ]
        if empty_splits:
            print(f"  Warning: empty split(s) for this class: {', '.join(empty_splits)}")

    # Shuffle splits so classes are interleaved within each CSV.
    random.shuffle(train_videos)
    random.shuffle(val_videos)
    random.shuffle(test_videos)

    # Create the output directory only now, after all validation has passed,
    # so a failed run leaves nothing behind and the class scan never sees it.
    os.makedirs(output_path, exist_ok=True)

    # Write CSV files
    def write_csv(filename, videos):
        csv_path = os.path.join(output_path, filename)
        with open(csv_path, 'w', encoding='utf-8') as f:
            f.writelines(f"{path} {label}\n" for path, label, _ in videos)
        print(f"\nCreated {csv_path} with {len(videos)} videos")

    write_csv("train.csv", train_videos)
    write_csv("val.csv", val_videos)
    write_csv("test.csv", test_videos)

    # Write label mapping file
    label_map_path = os.path.join(output_path, "label_map.txt")
    with open(label_map_path, 'w', encoding='utf-8') as f:
        for idx in sorted(label_to_class):
            f.write(f"{idx} {label_to_class[idx]}\n")
    print(f"\nCreated label mapping: {label_map_path}")

    # Print summary
    print("\n" + "="*60)
    print("Dataset Preparation Summary")
    print("="*60)
    print(f"Total classes: {len(class_to_label)}")
    print(f"Total videos: {total_videos}")
    print(f"Train videos: {len(train_videos)}")
    print(f"Val videos: {len(val_videos)}")
    print(f"Test videos: {len(test_videos)}")
    print(f"\nOutput directory: {output_path}")
    print(f"Video directory: {video_dir}")
    print("\nNext steps:")
    print(f"1. Set DATA.PATH_TO_DATA_DIR to: {os.path.abspath(output_path)}")
    print("2. CSVs list absolute video paths; for PyTorchVideo-style scripts, --path_prefix can be any path")
    print("   (join ignores prefix when the CSV path is already absolute).")
    print("="*60)
|
||
|
|
|
||
|
|
|
||
|
|
def main(argv=None):
    """Parse command-line options and build the dataset split files.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the behaviour of a plain ``main()`` call.
    """
    parser = argparse.ArgumentParser(
        description="Prepare fish action video dataset for training (generates 3-class CSVs: feeding/normal/scared)"
    )
    parser.add_argument(
        "--video_dir",
        type=str,
        default="~/data/fish/fish_action_videos/",
        help="Root directory containing class subfolders with videos",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="~/data/fish/fish_action_videos",
        help="Output directory for CSV files",
    )
    parser.add_argument(
        "--train_ratio",
        type=float,
        default=0.7,
        help="Ratio of videos for training (default: 0.7)",
    )
    parser.add_argument(
        "--val_ratio",
        type=float,
        default=0.15,
        help="Ratio of videos for validation (default: 0.15)",
    )
    parser.add_argument(
        "--test_ratio",
        type=float,
        default=0.15,
        help="Ratio of videos for testing (default: 0.15)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility (default: 42)",
    )

    args = parser.parse_args(argv)

    create_dataset_splits(
        video_dir=args.video_dir,
        output_dir=args.output_dir,
        train_ratio=args.train_ratio,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        seed=args.seed,
    )
|
||
|
|
|
||
|
|
|
||
|
|
# Run the command-line interface only when this file is executed as a script,
# so importing the module does not trigger argument parsing.
if __name__ == "__main__":
    main()
|
||
|
|
|