#!/usr/bin/env python3 """ Train a video classifier using timestamp/segment labels (PyTorchVideo + X3D). Why this script: - If "scared" only happens in a short window, video-level single labels dilute training. - This script trains on *timestamped segments* so clips are sampled inside labeled windows. Segments CSV format (whitespace-separated; last token is label int): relative/video.mp4 start_sec end_sec label Example: scared_underwater/foo.mp4 12.4 16.2 3 scared_underwater/foo.mp4 30.0 32.0 3 normal_underwater/bar.mp4 0.0 60.0 1 Run: python train_pytorchvideo_x3d_segments.py \ --segments_csv /home/ubuntu/data/fish/fish_action_videos/train_segments.csv \ --val_csv /home/ubuntu/data/fish/fish_action_videos/val.csv \ --path_prefix /home/ubuntu/data/fish/fish_action_videos \ --model x3d_m --pretrained \ --num_frames 16 --sampling_rate 5 --target_fps 30 \ --batch_size 4 --epochs 30 --num_workers 4 --amp \ --output_dir /home/ubuntu/projects/FishAction/checkpoints/ptv_x3d_m_segments """ from __future__ import annotations import argparse import json import math import os import random import time from dataclasses import asdict, dataclass from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset from pytorchvideo.data.encoded_video import EncodedVideo from train_pytorchvideo_x3d import ( VideoTransform, _parse_path_label_line, build_pretrained_x3d, replace_last_linear, set_seed, ) def _parse_segment_line(line: str) -> Optional[Tuple[str, float, float, int]]: """ Parse: path start end label Supports paths with spaces (label is last token). """ line = line.strip() if not line or line.startswith("#"): return None parts = line.split() if len(parts) < 4: return None try: label = int(parts[-1]) end_sec = float(parts[-2]) start_sec = float(parts[-3]) except ValueError: return None path = " ".join(parts[:-3]) return path, start_sec, end_sec, label @dataclass class SegmentRow: rel_path: str start_sec: float end_sec: float label: int class SegmentClipDataset(Dataset): """ Each item is a labeled time segment inside a video. We sample one clip from within that segment (random for train, center for eval). """ def __init__( self, segments: List[SegmentRow], path_prefix: str, clip_duration: float, num_frames: int, train: bool, decoder: str = "pyav", ): self.segments = segments self.path_prefix = path_prefix self.clip_duration = float(clip_duration) self.train = train self.decoder = decoder # Reuse our existing transform pipeline (expects dict with "video" + "label") self.transform = VideoTransform(num_frames=num_frames, train=train) def __len__(self) -> int: return len(self.segments) def __getitem__(self, idx: int): seg = self.segments[idx] full_path = os.path.join(self.path_prefix, seg.rel_path) # Choose clip start inside [start_sec, end_sec - clip_duration] start = float(seg.start_sec) end = float(seg.end_sec) if end <= start: end = start + self.clip_duration latest_start = max(start, end - self.clip_duration) if self.train: clip_start = random.uniform(start, latest_start) if latest_start > start else start else: clip_start = (start + latest_start) / 2.0 clip_end = clip_start + self.clip_duration video = EncodedVideo.from_path(full_path, decode_audio=False, decoder=self.decoder) clip = video.get_clip(clip_start, clip_end) frames = clip.get("video", None) if frames is None: raise RuntimeError(f"Failed to decode clip from {full_path} [{clip_start},{clip_end}]") sample = {"video": frames, "label": seg.label} x, y = self.transform(sample) return x, y def infer_num_classes_from_segments(rows: List[SegmentRow]) -> int: m = -1 for r in rows: m = max(m, int(r.label)) return m + 1 @dataclass class TrainConfig: segments_csv: str val_csv: str path_prefix: str model: str pretrained: bool num_classes: int num_frames: int sampling_rate: int target_fps: int clip_duration: float batch_size: int epochs: int lr: float weight_decay: float num_workers: int seed: int output_dir: str amp: bool device: str log_interval: int def accuracy_top1(logits: torch.Tensor, labels: torch.Tensor) -> float: preds = torch.argmax(logits, dim=1) return (preds == labels).float().mean().item() def read_val_csv(csv_path: str, path_prefix: str) -> List[Tuple[str, int]]: """ Returns list of (full_path, label). Uses robust path/label parser from the main script. """ out = [] with open(csv_path, "r") as f: for line in f: parsed = _parse_path_label_line(line) if parsed is None: continue rel_path, label = parsed out.append((os.path.join(path_prefix, rel_path), int(label))) return out def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--segments_csv", type=str, required=True, help="train segments csv (path start end label)") parser.add_argument("--val_csv", type=str, required=True, help="val csv (path label)") parser.add_argument("--path_prefix", type=str, required=True) parser.add_argument("--model", type=str, default="x3d_m", choices=["x3d_xs", "x3d_s", "x3d_m", "x3d_l"]) parser.add_argument("--pretrained", action="store_true") parser.add_argument("--num_classes", type=int, default=0, help="0=auto from segments") parser.add_argument("--num_frames", type=int, default=16) parser.add_argument("--sampling_rate", type=int, default=5) parser.add_argument("--target_fps", type=int, default=30) parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--epochs", type=int, default=30) parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--weight_decay", type=float, default=1e-4) parser.add_argument("--num_workers", type=int, default=4) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--output_dir", type=str, required=True) parser.add_argument("--amp", action="store_true") parser.add_argument("--log_interval", type=int, default=20) args = parser.parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" amp = bool(args.amp and device == "cuda") segments_csv = os.path.abspath(os.path.expanduser(args.segments_csv)) val_csv = os.path.abspath(os.path.expanduser(args.val_csv)) path_prefix = os.path.abspath(os.path.expanduser(args.path_prefix)) output_dir = os.path.abspath(os.path.expanduser(args.output_dir)) os.makedirs(output_dir, exist_ok=True) set_seed(args.seed) clip_duration = args.num_frames * args.sampling_rate / float(args.target_fps) # Load segments segments: List[SegmentRow] = [] with open(segments_csv, "r") as f: for line in f: parsed = _parse_segment_line(line) if parsed is None: continue rel, s, e, lab = parsed segments.append(SegmentRow(rel_path=rel, start_sec=float(s), end_sec=float(e), label=int(lab))) if not segments: raise ValueError(f"No segments loaded from {segments_csv}") num_classes = args.num_classes if args.num_classes > 0 else infer_num_classes_from_segments(segments) cfg = TrainConfig( segments_csv=segments_csv, val_csv=val_csv, path_prefix=path_prefix, model=args.model, pretrained=bool(args.pretrained), num_classes=num_classes, num_frames=args.num_frames, sampling_rate=args.sampling_rate, target_fps=args.target_fps, clip_duration=clip_duration, batch_size=args.batch_size, epochs=args.epochs, lr=args.lr, weight_decay=args.weight_decay, num_workers=args.num_workers, seed=args.seed, output_dir=output_dir, amp=amp, device=device, log_interval=args.log_interval, ) with open(os.path.join(output_dir, "config.json"), "w") as f: json.dump(asdict(cfg), f, indent=2) train_ds = SegmentClipDataset( segments=segments, path_prefix=path_prefix, clip_duration=clip_duration, num_frames=args.num_frames, train=True, decoder="pyav", ) train_loader = DataLoader( train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=(device == "cuda"), persistent_workers=bool(args.num_workers > 0), ) # Simple val: still uses full-video sampling from a single random clip. # (You can also create val segments if you want segment-level validation.) val_items = read_val_csv(val_csv, path_prefix) # Reuse LabeledVideoDataset for val to sample a uniform clip from each video. from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler val_labeled = [(p, {"label": int(l)}) for p, l in val_items] val_ds = LabeledVideoDataset( labeled_video_paths=val_labeled, clip_sampler=make_clip_sampler("uniform", clip_duration), decode_audio=False, decoder="pyav", transform=VideoTransform(num_frames=args.num_frames, train=False), ) val_loader = DataLoader( val_ds, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=(device == "cuda"), persistent_workers=bool(args.num_workers > 0), ) model = build_pretrained_x3d(cfg.model, cfg.pretrained) replace_last_linear(model, cfg.num_classes) model = model.to(device) criterion = nn.CrossEntropyLoss() optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) scaler = torch.amp.GradScaler("cuda", enabled=cfg.amp) if device == "cuda" else torch.amp.GradScaler(enabled=False) best_val = -1.0 for epoch in range(1, cfg.epochs + 1): model.train() t0 = time.time() tr_loss = 0.0 tr_acc = 0.0 n = 0 last_t = time.time() for step, (x, y) in enumerate(train_loader, start=1): x = x.to(device, non_blocking=True) y = y.to(device, non_blocking=True) optimizer.zero_grad(set_to_none=True) if device == "cuda": with torch.amp.autocast("cuda", enabled=cfg.amp): logits = model(x) loss = criterion(logits, y) else: logits = model(x) loss = criterion(logits, y) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() tr_loss += loss.item() tr_acc += accuracy_top1(logits.detach(), y) n += 1 if cfg.log_interval > 0 and (step % cfg.log_interval) == 0: dt = time.time() - last_t last_t = time.time() print( f"[epoch {epoch:03d}] step {step:05d} " f"loss {(tr_loss/n):.4f} acc {(tr_acc/n)*100:.2f}% " f"({dt:.1f}s/{cfg.log_interval} steps)" ) tr_loss /= max(n, 1) tr_acc /= max(n, 1) # Validation model.eval() va_loss = 0.0 va_acc = 0.0 m = 0 with torch.no_grad(): for x, y in val_loader: x = x.to(device, non_blocking=True) y = y.to(device, non_blocking=True) logits = model(x) loss = criterion(logits, y) va_loss += loss.item() va_acc += accuracy_top1(logits, y) m += 1 va_loss /= max(m, 1) va_acc /= max(m, 1) dt_epoch = time.time() - t0 print( f"Epoch {epoch:03d}/{cfg.epochs} | " f"train loss {tr_loss:.4f} acc {tr_acc*100:.2f}% | " f"val loss {va_loss:.4f} acc {va_acc*100:.2f}% | " f"time {dt_epoch:.1f}s" ) # Save last/best torch.save( {"epoch": epoch, "model": model.state_dict(), "optimizer": optimizer.state_dict(), "scaler": scaler.state_dict(), "cfg": asdict(cfg)}, os.path.join(output_dir, "checkpoint_last.pt"), ) if va_acc > best_val: best_val = va_acc torch.save( {"epoch": epoch, "model": model.state_dict(), "optimizer": optimizer.state_dict(), "scaler": scaler.state_dict(), "cfg": asdict(cfg), "best_val_acc": best_val}, os.path.join(output_dir, "checkpoint_best.pt"), ) print(f"Done. Best val acc: {best_val*100:.2f}%") if __name__ == "__main__": main()