797 lines
28 KiB
Python
Executable File
797 lines
28 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
Train a video classifier using PyTorchVideo (X3D) on a CSV dataset.

This script is intentionally independent of SlowFast, so it works even if
SlowFast's optional dependencies or imports are broken.

Expected CSV format (whitespace-separated, like SlowFast/Kinetics lists):

    relative/path/to/video.mp4 <label_int>

The label is the last token on each line, so paths may contain spaces.
Blank lines, lines starting with '#', and header rows whose last token is
not an integer are skipped. Unless --num_classes is given, the number of
classes is inferred as (max label in the train list) + 1.

Example:

    python train_pytorchvideo_x3d.py \
        --csv_dir /home/ubuntu/projects/FishAction/data/fish/fish_action_training_dataset \
        --path_prefix /home/ubuntu/data/fish/fish_action_videos \
        --model x3d_m \
        --pretrained \
        --batch_size 4 \
        --epochs 30 \
        --num_workers 4 \
        --output_dir /home/ubuntu/projects/FishAction/checkpoints/ptv_x3d_m
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import contextlib
|
|
import json
|
|
import os
|
|
import random
|
|
import time
|
|
from math import cos, pi
|
|
from dataclasses import asdict, dataclass
|
|
from pathlib import Path
|
|
from typing import Dict, List, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import DataLoader
|
|
|
|
from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler
|
|
|
|
|
|
# Per-channel (R, G, B) statistics used by normalize_cthw by default.
# NOTE(review): the name suggests these are the Kinetics pretraining stats the
# hub X3D weights expect — confirm before changing them.
KINETICS_MEAN = (0.45,) * 3
KINETICS_STD = (0.225,) * 3


def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and torch on CPU and every CUDA device.

    Args:
        seed: Value handed unchanged to each generator.
    """
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


def _parse_path_label_line(line: str) -> Tuple[str, int] | None:
|
|
"""
|
|
Robust parser for CSV lines.
|
|
|
|
Expected (whitespace-separated):
|
|
path label
|
|
|
|
If the path contains spaces, we take the last token as label:
|
|
/some/path/with spaces/video.mp4 3
|
|
"""
|
|
line = line.strip()
|
|
if not line or line.startswith("#"):
|
|
return None
|
|
parts = line.split()
|
|
if len(parts) < 2:
|
|
return None
|
|
label_str = parts[-1]
|
|
try:
|
|
label = int(label_str)
|
|
except ValueError:
|
|
# likely a header row, e.g. "path label"
|
|
return None
|
|
path = " ".join(parts[:-1])
|
|
return path, label
|
|
|
|
|
|
def read_csv_list(csv_path: str, path_prefix: str) -> List[Tuple[str, Dict]]:
    """Load ``(video_path, {"label": int})`` pairs from a list file.

    Each row is parsed with ``_parse_path_label_line``. Rows it rejects
    (blank, comment, header) are dropped. Paths are joined onto
    ``path_prefix`` with ``os.path.join``, so an absolute path in the file
    overrides the prefix.

    Args:
        csv_path: Path to the whitespace-separated list file.
        path_prefix: Directory prepended to every relative video path.

    Returns:
        List in file order, shaped as ``LabeledVideoDataset`` expects for
        ``labeled_video_paths``.
    """
    with open(csv_path, "r") as handle:
        return [
            (os.path.join(path_prefix, rel_path), {"label": int(label)})
            for rel_path, label in filter(None, map(_parse_path_label_line, handle))
        ]


def infer_num_classes(csv_path: str) -> int:
    """Infer the class count of a list file as ``max(label) + 1``.

    Labels are treated as 0-based class ids. Gaps are not detected: an id
    that never appears below the maximum still counts as a class.

    Args:
        csv_path: Path to the whitespace-separated ``<path> <label>`` file.

    Returns:
        ``max(label) + 1`` over all parseable rows.

    Raises:
        ValueError: if the file contains no parseable rows (a 0-class head
            cannot be trained) or if any label is negative (not a valid class
            index for ``nn.CrossEntropyLoss``).
    """
    max_label = -1
    with open(csv_path, "r") as f:
        for line in f:
            parsed = _parse_path_label_line(line)
            if parsed is None:
                continue
            _, label = parsed
            if label < 0:
                raise ValueError(f"Negative label {label} in {csv_path}; labels must be >= 0")
            max_label = max(max_label, label)
    if max_label < 0:
        raise ValueError(f"No '<path> <label>' rows found in {csv_path}; cannot infer num_classes")
    return max_label + 1


def uniform_temporal_subsample_cthw(video: torch.Tensor, num_frames: int) -> torch.Tensor:
    """Keep ``num_frames`` evenly spaced frames along the time axis.

    Frame indices are ``linspace(0, T - 1, num_frames)`` cast with ``.long()``
    (truncation), so frames repeat when ``num_frames`` exceeds T.

    Args:
        video: ``(C, T, H, W)`` tensor.
        num_frames: Number of frames to keep.

    Returns:
        ``(C, num_frames, H, W)`` tensor, or ``video`` itself when T already
        equals ``num_frames``.

    Raises:
        ValueError: if the time axis is empty.
    """
    _, n_src, _, _ = video.shape
    if n_src == num_frames:
        return video
    if n_src < 1:
        raise ValueError(f"Invalid temporal dimension t={n_src}")
    picks = torch.linspace(0, n_src - 1, steps=num_frames, device=video.device).long()
    return video.index_select(1, picks)


def to_cthw(video: torch.Tensor) -> torch.Tensor:
    """Return ``video`` in ``(C, T, H, W)`` layout.

    The channel axis is found by looking for a size-3 (RGB) dimension, probed
    in this order: dim 0 (already ``C T H W``), last dim (``T H W C``), then
    dim 1 (``T C H W``). The first match wins, so an ambiguous shape such as a
    3-frame ``T H W C`` clip is treated as ``C T H W``.

    Raises:
        ValueError: if ``video`` is not 4-D, or none of the probed axes has
            size 3.
    """
    if video.ndim != 4:
        raise ValueError(f"Expected 4D video tensor, got shape {tuple(video.shape)}")

    # (axis expected to hold channels, permutation moving it to the front)
    layouts = (
        (0, None),  # C T H W: already canonical
        (-1, (3, 0, 1, 2)),  # T H W C
        (1, (1, 0, 2, 3)),  # T C H W
    )
    for channel_axis, order in layouts:
        if video.shape[channel_axis] == 3:
            return video if order is None else video.permute(*order)
    raise ValueError(f"Could not infer channel dim for video shape {tuple(video.shape)}")


def resize_short_side_tchw(frames_tchw: torch.Tensor, short_side: int) -> torch.Tensor:
    """Bilinearly resize frames so the shorter spatial side equals ``short_side``.

    Aspect ratio is preserved. The other side is rounded to the nearest pixel.
    Frames whose short side already matches are returned unchanged.

    Args:
        frames_tchw: ``(T, C, H, W)`` tensor.
        short_side: Target length of min(H, W).
    """
    _, _, height, width = frames_tchw.shape
    if min(height, width) == short_side:
        return frames_tchw
    if height < width:
        out_hw = (short_side, int(round(width * short_side / height)))
    else:
        out_hw = (int(round(height * short_side / width)), short_side)
    return F.interpolate(frames_tchw, size=out_hw, mode="bilinear", align_corners=False)


def random_crop_tchw(frames_tchw: torch.Tensor, crop_size: int) -> torch.Tensor:
    """Cut one random ``crop_size`` x ``crop_size`` window, shared by all frames.

    If H or W is smaller than ``crop_size``, that side is first bilinearly
    stretched up to ``crop_size`` (aspect ratio is not preserved) so the crop
    always fits. Offsets are drawn from Python's ``random`` module, and only
    for a side that is strictly larger than ``crop_size``.

    Args:
        frames_tchw: ``(T, C, H, W)`` tensor.
        crop_size: Side length of the square crop.

    Returns:
        ``(T, C, crop_size, crop_size)`` tensor.
    """
    _, _, height, width = frames_tchw.shape
    if height < crop_size or width < crop_size:
        frames_tchw = F.interpolate(
            frames_tchw,
            size=(max(height, crop_size), max(width, crop_size)),
            mode="bilinear",
            align_corners=False,
        )
        _, _, height, width = frames_tchw.shape
    top = 0 if height <= crop_size else random.randint(0, height - crop_size)
    left = 0 if width <= crop_size else random.randint(0, width - crop_size)
    return frames_tchw[..., top : top + crop_size, left : left + crop_size]


def center_crop_tchw(frames_tchw: torch.Tensor, crop_size: int) -> torch.Tensor:
    """Cut the centred ``crop_size`` x ``crop_size`` window from every frame.

    Sides shorter than ``crop_size`` are first bilinearly stretched up to
    ``crop_size`` (aspect ratio is not preserved). When the slack is odd, the
    extra pixel is left on the bottom/right.

    Args:
        frames_tchw: ``(T, C, H, W)`` tensor.
        crop_size: Side length of the square crop.

    Returns:
        ``(T, C, crop_size, crop_size)`` tensor.
    """
    _, _, height, width = frames_tchw.shape
    if min(height, width) < crop_size:
        frames_tchw = F.interpolate(
            frames_tchw,
            size=(max(height, crop_size), max(width, crop_size)),
            mode="bilinear",
            align_corners=False,
        )
        height, width = frames_tchw.shape[-2:]
    top, left = (max((extent - crop_size) // 2, 0) for extent in (height, width))
    return frames_tchw[..., top : top + crop_size, left : left + crop_size]


def normalize_cthw(video_cthw: torch.Tensor, mean=KINETICS_MEAN, std=KINETICS_STD) -> torch.Tensor:
    """Standardize each channel: ``(x - mean[c]) / std[c]``.

    Args:
        video_cthw: Channel-first ``(C, T, H, W)`` tensor. The stats are
            reshaped to ``(3, 1, 1, 1)`` and broadcast over T, H and W.
        mean: Per-channel means (3 values).
        std: Per-channel standard deviations (3 values).

    Returns:
        Tensor of the same shape, dtype and device as ``video_cthw``.
    """

    def _per_channel(values) -> torch.Tensor:
        # Match the input's device/dtype so the arithmetic needs no casts.
        return torch.tensor(values, device=video_cthw.device, dtype=video_cthw.dtype).view(3, 1, 1, 1)

    return (video_cthw - _per_channel(mean)) / _per_channel(std)


class VideoTransform:
    """Turn a ``LabeledVideoDataset`` sample dict into a ``(video, label)`` pair.

    ``__call__`` applies, in order:

    1. ``to_cthw`` on ``sample["video"]``.
    2. Conversion to float32; uint8 input is also scaled by 1/255.
    3. Uniform temporal subsampling to ``num_frames`` frames.
    4. Resizing so the short spatial side equals ``short_side``.
    5. Train: a random ``crop_size`` crop, then a horizontal flip with
       probability 0.5 when ``random_flip`` is set. Eval: a centre crop.
    6. Per-channel normalization with ``normalize_cthw``'s default stats.

    The label comes from ``sample["label"]``, which may be an int or a dict
    holding a ``"label"`` key. It defaults to 0 when absent.
    """

    def __init__(
        self,
        num_frames: int,
        train: bool,
        short_side: int = 256,
        crop_size: int = 224,
        random_flip: bool = True,
    ):
        self.num_frames = num_frames
        self.train = train
        self.short_side = short_side
        self.crop_size = crop_size
        self.random_flip = random_flip

    @staticmethod
    def _extract_label(sample: Dict) -> int:
        """Return the integer class id stored under ``sample["label"]`` (default 0)."""
        raw = sample.get("label", 0)
        nested = raw.get("label", 0) if isinstance(raw, dict) else raw
        return int(nested)

    def __call__(self, sample: Dict):
        """Return ``(video, label)``.

        ``video`` is a float ``(C, num_frames, crop_size, crop_size)`` tensor.
        ``label`` is a 0-d ``torch.long`` tensor.
        """
        raw_video = sample["video"]
        label = self._extract_label(sample)

        clip = to_cthw(raw_video)
        clip = clip.float() / 255.0 if clip.dtype == torch.uint8 else clip.float()
        clip = uniform_temporal_subsample_cthw(clip, self.num_frames)

        # Spatial ops work frame-wise on (T, C, H, W).
        frames = resize_short_side_tchw(clip.permute(1, 0, 2, 3), self.short_side)
        if not self.train:
            frames = center_crop_tchw(frames, self.crop_size)
        else:
            frames = random_crop_tchw(frames, self.crop_size)
            if self.random_flip and random.random() < 0.5:
                frames = frames.flip(3)  # horizontal flip along W

        clip = normalize_cthw(frames.permute(1, 0, 2, 3))
        return clip, torch.tensor(label, dtype=torch.long)


def build_pretrained_x3d(model_name: str, pretrained: bool) -> nn.Module:
    """Build an X3D classifier with a 400-way output head.

    Args:
        model_name: One of ``x3d_xs``, ``x3d_s``, ``x3d_m``, ``x3d_l``.
        pretrained: If True, load architecture and weights from the
            ``facebookresearch/pytorchvideo`` torch.hub entry point of the
            same name (downloaded once into ``~/.cache/torch/hub/``). If
            False, build the same architecture with random weights via
            ``pytorchvideo.models.x3d.create_x3d``.

    Raises:
        ValueError: if ``model_name`` is not a known X3D variant.
    """
    # Per-variant hyper-parameters, mirroring PyTorchVideo's hub builders.
    # The head's pooling kernel is derived from clip length and crop size.
    # Calling create_x3d with its defaults (13 frames, 160 px) would ignore
    # model_name and mis-pool e.g. the 16x224 clips x3d_m expects.
    arch_kwargs = {
        "x3d_xs": {"input_clip_length": 4, "input_crop_size": 160, "depth_factor": 2.2},
        "x3d_s": {"input_clip_length": 13, "input_crop_size": 160, "depth_factor": 2.2},
        "x3d_m": {"input_clip_length": 16, "input_crop_size": 224, "depth_factor": 2.2},
        "x3d_l": {"input_clip_length": 16, "input_crop_size": 312, "depth_factor": 5.0},
    }
    if model_name not in arch_kwargs:
        raise ValueError(f"Unknown X3D model {model_name!r}; expected one of {sorted(arch_kwargs)}")

    if pretrained:
        # Uses torch hub; downloads pytorchvideo main repo + weights once into ~/.cache/torch/hub/
        return torch.hub.load("facebookresearch/pytorchvideo", model_name, pretrained=True)

    # No pretrained: build from pytorchvideo package API.
    from pytorchvideo.models.x3d import create_x3d

    return create_x3d(model_num_class=400, **arch_kwargs[model_name])


def replace_last_linear(model: nn.Module, num_classes: int) -> None:
    """Swap the model's last ``nn.Linear`` for a fresh ``num_classes``-way one.

    "Last" means last in ``model.named_modules()`` order. For the X3D hub
    models this is the classification projection (``blocks.5.proj``). The new
    layer keeps the old ``in_features`` and is freshly initialized. The model
    is modified in place.

    Raises:
        RuntimeError: if the model contains no ``nn.Linear``.
    """
    linears = [(name, mod) for name, mod in model.named_modules() if isinstance(mod, nn.Linear)]
    if not linears:
        raise RuntimeError("Could not find a Linear layer to replace for classification head.")

    head_path, old_head = linears[-1]
    replacement = nn.Linear(old_head.in_features, num_classes)

    # Walk the dotted path down to the module that owns the head attribute.
    *owner_path, attr_name = head_path.split(".")
    owner = model
    for hop in owner_path:
        owner = getattr(owner, hop)
    setattr(owner, attr_name, replacement)


@dataclass
class TrainConfig:
    """Resolved run configuration, dumped to ``config.json`` and stored in checkpoints.

    Built in ``main`` from the parsed CLI arguments after paths are made
    absolute and the CUDA-only flags (amp, tf32) are gated on device
    availability. Field order defines the positional constructor.
    """

    # --- data ---
    csv_dir: str  # absolute folder holding the train/val list files
    path_prefix: str  # absolute prefix joined onto each video path in the lists
    train_csv: str  # train list file name, joined with csv_dir
    val_csv: str  # val list file name, joined with csv_dir
    # --- model ---
    model: str  # X3D variant: x3d_xs | x3d_s | x3d_m | x3d_l
    pretrained: bool  # load torch.hub pretrained weights
    num_classes: int  # --num_classes, or max(train label) + 1 when that flag is 0
    # --- clip sampling ---
    num_frames: int  # frames per clip after uniform temporal subsampling
    sampling_rate: int  # only used for clip duration = num_frames * sampling_rate / target_fps
    target_fps: int  # only used for the clip duration above
    # --- optimization ---
    batch_size: int
    epochs: int
    lr: float  # AdamW base learning rate
    weight_decay: float  # AdamW weight decay
    # --- LR schedule (see build_scheduler) ---
    scheduler: str  # none | cosine | step | plateau
    warmup_epochs: int  # linear warmup length (cosine/step only)
    min_lr: float  # cosine floor, absolute LR
    step_size: int  # step schedule: epochs between decays (after warmup)
    gamma: float  # step schedule: multiplicative decay
    plateau_factor: float  # plateau schedule: LR multiplier on plateau
    plateau_patience: int  # plateau schedule: epochs without val-loss improvement
    # --- data loading ---
    num_workers: int
    prefetch_factor: int  # passed to DataLoader only when num_workers > 0
    persistent_workers: bool  # forced False when num_workers == 0
    # --- run ---
    seed: int
    output_dir: str  # absolute directory for config.json and checkpoints
    amp: bool  # mixed precision; True only if requested and CUDA is available
    device: str  # "cuda" or "cpu"
    decoder: str  # decoder backend name passed to LabeledVideoDataset
    train_sampling: str  # random | constant_clips_per_video | uniform
    train_clips_per_video: int  # used only with constant_clips_per_video
    val_sampling: str  # uniform | random | constant_clips_per_video
    val_clips_per_video: int  # used only with constant_clips_per_video
    tf32: bool  # TF32 matmul/conv; True only if requested and CUDA is available


def accuracy_top1(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of rows whose arg-max logit equals the label.

    Args:
        logits: ``(N, num_classes)`` scores.
        labels: ``(N,)`` integer class ids.
    """
    hits = logits.argmax(dim=1).eq(labels)
    return hits.sum().item() / labels.numel()


def _get_lr(optimizer: torch.optim.Optimizer) -> float:
|
|
return float(optimizer.param_groups[0]["lr"])
|
|
|
|
|
|
def build_scheduler(
    optimizer: torch.optim.Optimizer, cfg: TrainConfig
) -> tuple[torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.ReduceLROnPlateau | None, bool]:
    """Create the epoch-level LR scheduler selected by ``cfg.scheduler``.

    * ``none``: returns ``(None, False)``.
    * ``plateau``: ``ReduceLROnPlateau`` in ``min`` mode. The caller must
      step it with a validation metric, hence ``needs_val_metric=True``.
    * ``cosine`` / ``step``: a single ``LambdaLR``, meant to be stepped once
      per epoch, combining a linear warmup with a decay:
        - warmup: factor ``(e + 1) / warmup_epochs`` for 0-based epoch
          ``e < warmup_epochs``;
        - step: ``gamma ** ((e - warmup) // step_size)``, or a constant 1 if
          ``step_size <= 0``;
        - cosine: anneals from 1 toward ``min_lr / lr`` (clamped to [0, 1])
          over ``epochs - warmup_epochs`` epochs.

    Returns:
        ``(scheduler, needs_val_metric)``.
    """
    if cfg.scheduler == "none":
        return None, False

    if cfg.scheduler == "plateau":
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=cfg.plateau_factor,
            patience=cfg.plateau_patience,
        )
        return plateau, True

    warmup_epochs = max(int(cfg.warmup_epochs), 0)
    total_epochs = max(int(cfg.epochs), 1)
    floor = float(cfg.min_lr / cfg.lr) if cfg.lr > 0 else 0.0
    floor = max(0.0, min(1.0, floor))

    def _step_decay(after_warmup: int) -> float:
        if cfg.step_size <= 0:
            return 1.0
        decays = max(after_warmup // int(cfg.step_size), 0)
        return float(cfg.gamma) ** float(decays)

    def _cosine_decay(after_warmup: int) -> float:
        span = max(total_epochs - warmup_epochs, 1)
        progress = min(max(after_warmup / float(span), 0.0), 1.0)
        return floor + (1.0 - floor) * (0.5 * (1.0 + cos(pi * progress)))

    def lr_lambda(epoch_idx0: int) -> float:
        # LambdaLR passes the 0-based epoch counter.
        epoch = int(epoch_idx0)
        if warmup_epochs > 0 and epoch < warmup_epochs:
            return float(epoch + 1) / float(warmup_epochs)
        if cfg.scheduler == "step":
            return _step_decay(epoch - warmup_epochs)
        if cfg.scheduler == "cosine":
            return _cosine_decay(epoch - warmup_epochs)
        return 1.0

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda), False


def _build_arg_parser() -> argparse.ArgumentParser:
    """Define this script's command-line interface."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--csv_dir", type=str, required=True, help="Folder containing train.csv/val.csv")
    parser.add_argument("--path_prefix", type=str, required=True, help="Absolute path prefix for videos")
    parser.add_argument("--train_csv", type=str, default="train.csv")
    parser.add_argument("--val_csv", type=str, default="val.csv")
    parser.add_argument("--model", type=str, default="x3d_m", choices=["x3d_xs", "x3d_s", "x3d_m", "x3d_l"])
    parser.add_argument("--pretrained", action="store_true", help="Use torch.hub pretrained weights")
    parser.add_argument("--num_classes", type=int, default=0, help="If 0, infer from train.csv max label + 1")
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--sampling_rate", type=int, default=5, help="Used only for clip_duration calculation")
    parser.add_argument("--target_fps", type=int, default=30, help="Used only for clip_duration calculation")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument(
        "--scheduler",
        type=str,
        default="none",
        choices=["none", "cosine", "step", "plateau"],
        help=(
            "LR scheduler. cosine/step are stepped once per epoch. "
            "plateau monitors val loss and reduces LR when it stops improving."
        ),
    )
    parser.add_argument(
        "--warmup_epochs",
        type=int,
        default=0,
        help="Warm up LR linearly for N epochs (cosine/step).",
    )
    parser.add_argument("--min_lr", type=float, default=1e-5, help="Cosine schedule minimum LR (absolute).")
    parser.add_argument("--step_size", type=int, default=10, help="Step schedule: decay every N epochs (after warmup).")
    parser.add_argument("--gamma", type=float, default=0.1, help="Step schedule: multiplicative decay factor.")
    parser.add_argument(
        "--plateau_factor",
        type=float,
        default=0.1,
        help="Plateau schedule: LR *= factor when val loss plateaus.",
    )
    parser.add_argument(
        "--plateau_patience",
        type=int,
        default=2,
        help="Plateau schedule: epochs to wait for val loss improvement before decaying.",
    )
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument(
        "--prefetch_factor",
        type=int,
        default=2,
        help="DataLoader prefetch factor per worker (only used when num_workers > 0).",
    )
    parser.add_argument(
        "--persistent_workers",
        action="store_true",
        help="Keep DataLoader workers alive across epochs (faster, uses more RAM; only if num_workers > 0).",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--amp", action="store_true", help="Use mixed precision (CUDA only)")
    parser.add_argument(
        "--tf32",
        action="store_true",
        help="Allow TF32 matmul/conv on Ampere+ (CUDA only). Can speed up training a bit.",
    )
    parser.add_argument(
        "--decoder",
        type=str,
        default="pyav",
        help="Video decoder backend for PyTorchVideo LabeledVideoDataset (common: pyav, decord if installed).",
    )
    parser.add_argument(
        "--train_sampling",
        type=str,
        default="random",
        choices=["random", "constant_clips_per_video", "uniform"],
        help=(
            "Training clip sampling. "
            "'random' uses 1 random clip per video per epoch. "
            "'constant_clips_per_video' uses N evenly spaced clips per video per epoch. "
            "'uniform' splits videos into sequential clips (long videos yield more clips)."
        ),
    )
    parser.add_argument(
        "--train_clips_per_video",
        type=int,
        default=1,
        help="Used only when --train_sampling=constant_clips_per_video.",
    )
    parser.add_argument(
        "--val_sampling",
        type=str,
        default="uniform",
        choices=["uniform", "random", "constant_clips_per_video"],
        help=(
            "Validation clip sampling. 'uniform' may produce many clips for long videos. "
            "'random' uses 1 random clip per video. "
            "'constant_clips_per_video' uses N evenly spaced clips per video."
        ),
    )
    parser.add_argument(
        "--val_clips_per_video",
        type=int,
        default=10,
        help="Used only when --val_sampling=constant_clips_per_video.",
    )
    parser.add_argument(
        "--log_interval",
        type=int,
        default=20,
        help="Print training/val progress every N batches (default: 20). Use 1 for every batch, or --log_every_batch.",
    )
    parser.add_argument(
        "--log_every_batch",
        action="store_true",
        help="Print one line after every training and validation batch (includes batch/total when known).",
    )
    return parser


def _make_clip_sampler(mode: str, clip_duration: float, clips_per_video: int, split: str):
    """Return the PyTorchVideo clip sampler for a ``--<split>_sampling`` mode.

    ``clips_per_video`` is only used by ``constant_clips_per_video``.

    Raises:
        ValueError: for an unknown mode (argparse ``choices`` normally prevents this).
    """
    if mode in ("random", "uniform"):
        return make_clip_sampler(mode, clip_duration)
    if mode == "constant_clips_per_video":
        return make_clip_sampler(mode, clip_duration, clips_per_video)
    raise ValueError(f"Unknown {split}_sampling: {mode}")


def _batch_suffix(cur: int, total: int | None) -> str:
    """Format a batch counter as ``cur/total`` when the total is known, else ``cur``."""
    return f"{cur}/{total}" if total is not None else str(cur)


def _loader_len(loader: DataLoader) -> int | None:
    """Return ``len(loader)``, or None when the underlying dataset has no length."""
    try:
        return len(loader)
    except TypeError:
        return None


def _should_log(step: int, log_every_batch: bool, log_interval: int) -> bool:
    """True when 1-based batch number ``step`` should print a progress line."""
    return log_every_batch or (log_interval > 0 and step % log_interval == 0)


def _save_checkpoint(
    path: str,
    epoch: int,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler,
    scaler,
    cfg: TrainConfig,
    **extra,
) -> None:
    """Write a training checkpoint to ``path`` atomically.

    The payload is written to ``<path>.tmp`` and then moved over ``path`` with
    ``os.replace``, so an interrupted save cannot leave a truncated checkpoint
    in place of the previous one. ``extra`` entries (e.g. ``best_val_acc``)
    are appended to the payload.
    """
    payload = {
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "scaler": scaler.state_dict(),
        "cfg": asdict(cfg),
        **extra,
    }
    tmp_path = f"{path}.tmp"
    torch.save(payload, tmp_path)
    os.replace(tmp_path, path)


def _train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler,
    cfg: TrainConfig,
    args: argparse.Namespace,
    epoch: int,
    n_total: int | None,
) -> tuple[float, float]:
    """Run one training epoch and return ``(mean_loss, accuracy)``.

    Loss is weighted by batch size and accuracy is correct/seen over samples.
    This assumes ``criterion`` returns a per-batch mean, as
    ``nn.CrossEntropyLoss`` does by default.
    """
    model.train()
    device = cfg.device
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    n_batches = 0
    last_log_t = time.time()
    last_batch_end_t = time.time()
    avg_data_s = 0.0
    avg_iter_s = 0.0

    for videos, labels in loader:
        data_s = time.time() - last_batch_end_t
        iter_t0 = time.time()
        videos = videos.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        autocast_ctx = (
            torch.amp.autocast("cuda", enabled=cfg.amp) if device == "cuda" else contextlib.nullcontext()
        )
        with autocast_ctx:
            logits = model(videos)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        batch_size = labels.numel()
        loss_sum += batch_loss * batch_size
        n_correct += int((logits.detach().argmax(dim=1) == labels).sum().item())
        n_seen += batch_size
        n_batches += 1

        iter_s = time.time() - iter_t0
        last_batch_end_t = time.time()
        # Exponential moving average (alpha = 0.1) for stable timing logs.
        if n_batches > 1:
            avg_data_s = 0.9 * avg_data_s + 0.1 * data_s
            avg_iter_s = 0.9 * avg_iter_s + 0.1 * iter_s
        else:
            avg_data_s, avg_iter_s = data_s, iter_s

        if not _should_log(n_batches, args.log_every_batch, args.log_interval):
            continue
        dt_log = time.time() - last_log_t
        last_log_t = time.time()
        run_loss = loss_sum / n_seen
        run_acc = n_correct / n_seen
        if args.log_every_batch:
            print(
                f" [epoch {epoch:03d}/{cfg.epochs}] train batch {_batch_suffix(n_batches, n_total)} | "
                f"loss {batch_loss:.4f} | "
                f"run_avg_loss {run_loss:.4f} run_avg_acc {run_acc*100:.2f}% | "
                f"data {avg_data_s*1000:.0f}ms iter {avg_iter_s*1000:.0f}ms"
            )
        else:
            print(
                f" [epoch {epoch:03d}] step {n_batches:05d} "
                f"loss {run_loss:.4f} "
                f"acc {run_acc*100:.2f}% "
                f"data {avg_data_s*1000:.0f}ms iter {avg_iter_s*1000:.0f}ms "
                f"({dt_log:.1f}s/{args.log_interval} steps)"
            )

    return loss_sum / max(n_seen, 1), n_correct / max(n_seen, 1)


def _evaluate(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    cfg: TrainConfig,
    args: argparse.Namespace,
    epoch: int,
    n_total: int | None,
) -> tuple[float, float]:
    """Evaluate on ``loader`` and return sample-weighted ``(mean_loss, accuracy)``.

    The validation loader keeps its short final batch (``drop_last=False``),
    so batch losses are weighted by batch size and accuracy is total correct
    over total clips. A mean of per-batch values would over-weight that last
    batch. Assumes ``criterion`` returns a per-batch mean.
    """
    model.eval()
    device = cfg.device
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    n_batches = 0
    last_log_t = time.time()

    with torch.no_grad():
        for videos, labels in loader:
            videos = videos.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            logits = model(videos)
            loss = criterion(logits, labels)

            batch_loss = loss.item()
            batch_size = labels.numel()
            loss_sum += batch_loss * batch_size
            n_correct += int((logits.argmax(dim=1) == labels).sum().item())
            n_seen += batch_size
            n_batches += 1

            if not _should_log(n_batches, args.log_every_batch, args.log_interval):
                continue
            dt_log = time.time() - last_log_t
            last_log_t = time.time()
            run_loss = loss_sum / n_seen
            run_acc = n_correct / n_seen
            if args.log_every_batch:
                print(
                    f" [epoch {epoch:03d}/{cfg.epochs}] val batch {_batch_suffix(n_batches, n_total)} | "
                    f"loss {batch_loss:.4f} | "
                    f"run_avg_loss {run_loss:.4f} run_avg_acc {run_acc*100:.2f}%"
                )
            else:
                print(
                    f" [epoch {epoch:03d}] val step {n_batches:05d} "
                    f"loss {run_loss:.4f} "
                    f"acc {run_acc*100:.2f}% "
                    f"({dt_log:.1f}s/{args.log_interval} steps)"
                )

    return loss_sum / max(n_seen, 1), n_correct / max(n_seen, 1)


def main() -> None:
    """Parse CLI arguments, train an X3D classifier, and write checkpoints.

    Writes ``config.json``, ``checkpoint_last.pt`` (every epoch) and
    ``checkpoint_best.pt`` (whenever validation accuracy improves) into
    ``--output_dir``. Train and validation metrics are sample-weighted.
    """
    args = _build_arg_parser().parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    amp = bool(args.amp and device == "cuda")
    tf32 = bool(args.tf32 and device == "cuda")
    if tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        # Best-effort speed tweak; must never abort training.
        with contextlib.suppress(Exception):
            torch.set_float32_matmul_precision("high")

    csv_dir = os.path.abspath(os.path.expanduser(args.csv_dir))
    path_prefix = os.path.abspath(os.path.expanduser(args.path_prefix))
    train_csv_path = os.path.join(csv_dir, args.train_csv)
    val_csv_path = os.path.join(csv_dir, args.val_csv)

    if not os.path.isfile(train_csv_path):
        raise FileNotFoundError(
            f"train csv not found: {train_csv_path}\n"
            f"Tip: --csv_dir should be the folder that contains train.csv/val.csv, "
            f"NOT the raw video folder."
        )
    if not os.path.isfile(val_csv_path):
        raise FileNotFoundError(f"val csv not found: {val_csv_path}")

    if args.num_classes and args.num_classes > 0:
        num_classes = args.num_classes
    else:
        num_classes = infer_num_classes(train_csv_path)

    cfg = TrainConfig(
        csv_dir=csv_dir,
        path_prefix=path_prefix,
        train_csv=args.train_csv,
        val_csv=args.val_csv,
        model=args.model,
        pretrained=bool(args.pretrained),
        num_classes=num_classes,
        num_frames=args.num_frames,
        sampling_rate=args.sampling_rate,
        target_fps=args.target_fps,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        scheduler=str(args.scheduler),
        warmup_epochs=int(args.warmup_epochs),
        min_lr=float(args.min_lr),
        step_size=int(args.step_size),
        gamma=float(args.gamma),
        plateau_factor=float(args.plateau_factor),
        plateau_patience=int(args.plateau_patience),
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch_factor,
        persistent_workers=bool(args.persistent_workers and args.num_workers > 0),
        seed=args.seed,
        output_dir=os.path.abspath(os.path.expanduser(args.output_dir)),
        amp=amp,
        device=device,
        decoder=str(args.decoder),
        train_sampling=str(args.train_sampling),
        train_clips_per_video=int(args.train_clips_per_video),
        val_sampling=str(args.val_sampling),
        val_clips_per_video=int(args.val_clips_per_video),
        tf32=tf32,
    )

    os.makedirs(cfg.output_dir, exist_ok=True)
    with open(os.path.join(cfg.output_dir, "config.json"), "w") as f:
        json.dump(asdict(cfg), f, indent=2)

    set_seed(cfg.seed)
    torch.backends.cudnn.benchmark = True

    # Clip duration in seconds, for the clip sampler.
    clip_duration = cfg.num_frames * cfg.sampling_rate / float(cfg.target_fps)

    train_items = read_csv_list(train_csv_path, cfg.path_prefix)
    val_items = read_csv_list(val_csv_path, cfg.path_prefix)

    train_ds = LabeledVideoDataset(
        labeled_video_paths=train_items,
        clip_sampler=_make_clip_sampler(cfg.train_sampling, clip_duration, cfg.train_clips_per_video, "train"),
        decode_audio=False,
        decoder=cfg.decoder,
        transform=VideoTransform(num_frames=cfg.num_frames, train=True),
    )
    val_ds = LabeledVideoDataset(
        labeled_video_paths=val_items,
        clip_sampler=_make_clip_sampler(cfg.val_sampling, clip_duration, cfg.val_clips_per_video, "val"),
        decode_audio=False,
        decoder=cfg.decoder,
        transform=VideoTransform(num_frames=cfg.num_frames, train=False),
    )

    loader_kwargs = dict(
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        pin_memory=(device == "cuda"),
        persistent_workers=cfg.persistent_workers,
        # prefetch_factor must be None when loading in the main process.
        prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
    )
    train_loader = DataLoader(train_ds, drop_last=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, drop_last=False, **loader_kwargs)

    model = build_pretrained_x3d(cfg.model, cfg.pretrained)
    replace_last_linear(model, cfg.num_classes)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler, scheduler_needs_val = build_scheduler(optimizer, cfg)
    # torch.cuda.amp is deprecated in newer torch versions.
    scaler = torch.amp.GradScaler("cuda", enabled=cfg.amp) if device == "cuda" else torch.amp.GradScaler(enabled=False)

    n_train_batches_total = _loader_len(train_loader)
    n_val_batches_total = _loader_len(val_loader)
    best_val = -1.0

    for epoch in range(1, cfg.epochs + 1):
        t0 = time.time()
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, cfg, args, epoch, n_train_batches_total
        )
        val_loss, val_acc = _evaluate(model, val_loader, criterion, cfg, args, epoch, n_val_batches_total)
        dt = time.time() - t0

        print(
            f"Epoch {epoch:03d}/{cfg.epochs} | "
            f"train loss {train_loss:.4f} acc {train_acc*100:.2f}% | "
            f"val loss {val_loss:.4f} acc {val_acc*100:.2f}% | "
            f"lr {_get_lr(optimizer):.6g} | "
            f"time {dt:.1f}s"
        )

        if scheduler is not None:
            if scheduler_needs_val:
                scheduler.step(val_loss)
            else:
                scheduler.step()

        _save_checkpoint(
            os.path.join(cfg.output_dir, "checkpoint_last.pt"),
            epoch, model, optimizer, scheduler, scaler, cfg,
        )
        if val_acc > best_val:
            best_val = val_acc
            _save_checkpoint(
                os.path.join(cfg.output_dir, "checkpoint_best.pt"),
                epoch, model, optimizer, scheduler, scaler, cfg,
                best_val_acc=best_val,
            )

    print(f"Done. Best val acc: {best_val*100:.2f}%")


if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()