Files
FishServer/FishAction/dataset/prepare_fish_dataset.py
2026-05-06 15:59:38 +08:00

264 lines
8.8 KiB
Python

#!/usr/bin/env python3
"""
Script to prepare fish action video dataset for SlowFast training.
This script:
1. Scans the video directory structure (each subfolder = class)
2. Creates train/val/test splits
3. Generates CSV files in the format expected by SlowFast
4. Creates a label mapping file
IMPORTANT: SlowFast does NOT require pre-split video clips. It automatically
samples short clips (typically 2-4 seconds) from full-length videos during
training. You can use your full-length videos as-is.
Usage:
python prepare_fish_dataset.py --video_dir ~/data/fish/fish_action_videos --output_dir ./dataset
"""
import argparse
import os
import random
from collections import defaultdict
from pathlib import Path
def get_video_files(directory, extensions=None):
    """Recursively collect the video files found under ``directory``.

    Args:
        directory: Root directory to walk; all nested subdirectories are
            searched as well.
        extensions: Optional iterable of filename suffixes to accept
            (e.g. ``{'.mp4', '.avi'}``). Matching is case-insensitive on
            both the file name and the suffix. When ``None``, a built-in
            set of common video suffixes is used.

    Returns:
        A sorted list of file paths, each built by joining the walked
        directory with the matching file name.
    """
    if extensions is None:
        extensions = {'.mp4', '.avi', '.mov', '.mkv', '.webm', '.flv', '.m4v'}
    # File names are lower-cased before matching, so the suffixes must be
    # lower-cased too; otherwise a caller passing '.MP4' would never match.
    # str.endswith accepts a tuple, which replaces the per-suffix any() loop.
    suffixes = tuple(ext.lower() for ext in extensions)
    video_files = []
    for root, _dirs, files in os.walk(directory):
        for file in files:
            if file.lower().endswith(suffixes):
                video_files.append(os.path.join(root, file))
    return sorted(video_files)
def _canonicalize_class_name(class_name: str) -> str:
    """
    Collapse a raw class-folder name onto its canonical class name.

    The folder name is compared after stripping surrounding whitespace and
    lower-casing it. Two folder layouts are recognised:
    - 3-folder layout: feeding/, normal/, scared/
    - 5-folder layout: feeding/, normal_underwater/, normal_upperwater/,
      scared_underwater/, scared_upperwater/

    Returns "feeding", "normal" or "scared" for a recognised folder name.
    An unrecognised name is returned exactly as it was passed in, so the
    caller can detect and report it.
    """
    aliases = {
        "feeding": "feeding",
        "normal": "normal",
        "normal_underwater": "normal",
        "normal_upperwater": "normal",
        "scared": "scared",
        "scared_underwater": "scared",
        "scared_upperwater": "scared",
    }
    key = class_name.strip().lower()
    # Fall back to the untouched input (not the normalised key) when unknown.
    return aliases.get(key, class_name)
def create_dataset_splits(video_dir, output_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=42):
    """
    Create stratified train/val/test splits from a class-per-folder video tree.

    Every immediate subdirectory of ``video_dir`` is treated as a class
    folder and mapped to one of the canonical classes feeding / normal /
    scared. Videos are shuffled and split per canonical class, then written
    to ``train.csv``, ``val.csv`` and ``test.csv`` (one
    ``<absolute video path> <label>`` line per video) together with
    ``label_map.txt`` (one ``<label> <class name>`` line per class), all
    inside ``output_dir``. Progress and a summary are printed to stdout.

    Args:
        video_dir: Root directory containing class subfolders ("~" is expanded)
        output_dir: Directory to save CSV files ("~" is expanded; created if missing)
        train_ratio: Ratio of videos for training
        val_ratio: Ratio of videos for validation
        test_ratio: Ratio of videos for testing
        seed: Random seed for reproducibility (seeds the global ``random`` module)

    Raises:
        ValueError: If the ratios do not sum to 1.0, if ``video_dir`` does not
            exist or has no subdirectories, or if a subdirectory name cannot
            be mapped to one of the three canonical classes.
    """
    # Validate ratios with a real exception: ``assert`` statements are
    # stripped when Python runs with -O, which would silently skip the check.
    if not abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6:
        raise ValueError("Ratios must sum to 1.0")
    random.seed(seed)
    # Expand "~" in the output location as well; without this the CLI default
    # ("~/data/...") would create a literal "~" directory under the CWD.
    output_dir = os.path.expanduser(output_dir)
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    # Get all class directories
    root_path = Path(video_dir).expanduser()
    if not root_path.exists():
        raise ValueError(f"Video directory does not exist: {video_dir}")
    # Find all class subdirectories (sorted for a deterministic scan order)
    class_dirs = sorted(d for d in root_path.iterdir() if d.is_dir())
    if not class_dirs:
        raise ValueError(f"No subdirectories found in {video_dir}")
    print(f"Found {len(class_dirs)} classes:")
    for class_dir in class_dirs:
        print(f" - {class_dir.name}")
    # Canonicalize classes into 3-class setup.
    desired_order = ["feeding", "normal", "scared"]
    canonical_by_dir = [(d, _canonicalize_class_name(d.name)) for d in class_dirs]
    unknown_classes = [d.name for d, canon in canonical_by_dir if canon not in desired_order]
    if unknown_classes:
        raise ValueError(
            "Unknown class folder(s) found that cannot be mapped to 3 classes: "
            f"{unknown_classes}. Expected folders among: "
            "feeding, normal, scared, normal_underwater, normal_upperwater, "
            "scared_underwater, scared_upperwater."
        )
    # Ensure stable label ordering: labels follow desired_order, restricted to
    # the canonical classes that actually have a folder.
    present = {canon for _, canon in canonical_by_dir}
    canonical_classes = [c for c in desired_order if c in present]
    # Create canonical class->label mapping (3 classes).
    class_to_label = {name: idx for idx, name in enumerate(canonical_classes)}
    label_to_class = {idx: name for name, idx in class_to_label.items()}
    # Collect all videos with their classes
    all_videos = []
    for class_dir, class_name in canonical_by_dir:
        label = class_to_label[class_name]
        videos = get_video_files(str(class_dir))
        if not videos:
            print(f"Warning: No videos found in {class_dir}")
            continue
        for video_file in videos:
            # Absolute path so training scripts can use path_prefix="" or point CSV paths directly.
            abs_path = os.path.abspath(video_file)
            all_videos.append((abs_path, label, class_name))
    print(f"\nTotal videos found: {len(all_videos)}")
    # Group videos by class for stratified splitting
    videos_by_class = defaultdict(list)
    for abs_path, label, class_name in all_videos:
        videos_by_class[label].append((abs_path, label, class_name))
    # Create stratified splits
    train_videos = []
    val_videos = []
    test_videos = []
    for label, videos in videos_by_class.items():
        random.shuffle(videos)
        n_total = len(videos)
        # int() truncates, so the test split receives the rounding remainder;
        # a class with very few videos may end up entirely in the test split.
        n_train = int(n_total * train_ratio)
        n_val = int(n_total * val_ratio)
        train_part = videos[:n_train]
        val_part = videos[n_train:n_train + n_val]
        test_part = videos[n_train + n_val:]
        train_videos.extend(train_part)
        val_videos.extend(val_part)
        test_videos.extend(test_part)
        print(f"\nClass {label_to_class[label]} (label {label}):")
        print(f" Total: {n_total}")
        print(f" Train: {len(train_part)}")
        print(f" Val: {len(val_part)}")
        print(f" Test: {len(test_part)}")
    # Shuffle splits so classes are interleaved within each CSV
    random.shuffle(train_videos)
    random.shuffle(val_videos)
    random.shuffle(test_videos)

    # Write CSV files
    def write_csv(filename, videos):
        csv_path = os.path.join(output_dir, filename)
        # Explicit UTF-8 so non-ASCII video paths are written identically on
        # every platform instead of depending on the locale encoding.
        with open(csv_path, 'w', encoding='utf-8') as f:
            for video_file, label, _ in videos:
                f.write(f"{video_file} {label}\n")
        print(f"\nCreated {csv_path} with {len(videos)} videos")

    write_csv("train.csv", train_videos)
    write_csv("val.csv", val_videos)
    write_csv("test.csv", test_videos)
    # Write label mapping file
    label_map_path = os.path.join(output_dir, "label_map.txt")
    with open(label_map_path, 'w', encoding='utf-8') as f:
        for idx in sorted(label_to_class):
            f.write(f"{idx} {label_to_class[idx]}\n")
    print(f"\nCreated label mapping: {label_map_path}")
    # Print summary
    print("\n" + "="*60)
    print("Dataset Preparation Summary")
    print("="*60)
    print(f"Total classes: {len(class_to_label)}")
    print(f"Total videos: {len(all_videos)}")
    print(f"Train videos: {len(train_videos)}")
    print(f"Val videos: {len(val_videos)}")
    print(f"Test videos: {len(test_videos)}")
    print(f"\nOutput directory: {output_dir}")
    print(f"Video directory: {video_dir}")
    print("\nNext steps:")
    print(f"1. Set DATA.PATH_TO_DATA_DIR to: {os.path.abspath(output_dir)}")
    print("2. CSVs list absolute video paths; for PyTorchVideo-style scripts, --path_prefix can be any path")
    print(" (join ignores prefix when the CSV path is already absolute).")
    print("="*60)
def main():
    """Parse the command-line options and generate the dataset split files.

    The parsed option names match the keyword parameters of
    ``create_dataset_splits`` one-to-one, so they are forwarded as-is.
    """
    parser = argparse.ArgumentParser(
        description="Prepare fish action video dataset for training (generates 3-class CSVs: feeding/normal/scared)"
    )
    parser.add_argument(
        "--video_dir",
        type=str,
        default="~/data/fish/fish_action_videos/",
        help="Root directory containing class subfolders with videos",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="~/data/fish/fish_action_videos",
        help="Output directory for CSV files",
    )
    # The three ratio options differ only in name, purpose and default value.
    ratio_options = (
        ("train", "training", 0.7),
        ("val", "validation", 0.15),
        ("test", "testing", 0.15),
    )
    for split_name, purpose, default_ratio in ratio_options:
        parser.add_argument(
            f"--{split_name}_ratio",
            type=float,
            default=default_ratio,
            help=f"Ratio of videos for {purpose} (default: {default_ratio})",
        )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility (default: 42)",
    )
    options = parser.parse_args()
    # Namespace attributes: video_dir, output_dir, train_ratio, val_ratio,
    # test_ratio, seed -- exactly the keyword arguments expected below.
    create_dataset_splits(**vars(options))


if __name__ == "__main__":
    main()