399 lines
13 KiB
Python
Executable File
399 lines
13 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Train a video classifier using timestamp/segment labels (PyTorchVideo + X3D).
|
|
|
|
Why this script:
|
|
- If "scared" only happens in a short window, video-level single labels dilute training.
|
|
- This script trains on *timestamped segments* so clips are sampled inside labeled windows.
|
|
|
|
Segments CSV format (whitespace-separated; last token is label int):
|
|
relative/video.mp4 start_sec end_sec label
|
|
|
|
Example:
|
|
scared_underwater/foo.mp4 12.4 16.2 3
|
|
scared_underwater/foo.mp4 30.0 32.0 3
|
|
normal_underwater/bar.mp4 0.0 60.0 1
|
|
|
|
Run:
|
|
python train_pytorchvideo_x3d_segments.py \
|
|
--segments_csv /home/ubuntu/data/fish/fish_action_videos/train_segments.csv \
|
|
--val_csv /home/ubuntu/data/fish/fish_action_videos/val.csv \
|
|
--path_prefix /home/ubuntu/data/fish/fish_action_videos \
|
|
--model x3d_m --pretrained \
|
|
--num_frames 16 --sampling_rate 5 --target_fps 30 \
|
|
--batch_size 4 --epochs 30 --num_workers 4 --amp \
|
|
--output_dir /home/ubuntu/projects/FishAction/checkpoints/ptv_x3d_m_segments
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import math
|
|
import os
|
|
import random
|
|
import time
|
|
from dataclasses import asdict, dataclass
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
|
from pytorchvideo.data.encoded_video import EncodedVideo
|
|
|
|
from train_pytorchvideo_x3d import (
|
|
VideoTransform,
|
|
_parse_path_label_line,
|
|
build_pretrained_x3d,
|
|
replace_last_linear,
|
|
set_seed,
|
|
)
|
|
|
|
|
|
def _parse_segment_line(line: str) -> Optional[Tuple[str, float, float, int]]:
|
|
"""
|
|
Parse: path start end label
|
|
Supports paths with spaces (label is last token).
|
|
"""
|
|
line = line.strip()
|
|
if not line or line.startswith("#"):
|
|
return None
|
|
parts = line.split()
|
|
if len(parts) < 4:
|
|
return None
|
|
try:
|
|
label = int(parts[-1])
|
|
end_sec = float(parts[-2])
|
|
start_sec = float(parts[-3])
|
|
except ValueError:
|
|
return None
|
|
path = " ".join(parts[:-3])
|
|
return path, start_sec, end_sec, label
|
|
|
|
|
|
@dataclass
class SegmentRow:
    """One labeled time window of a video, as listed in the segments CSV."""

    # Video path as written in the CSV row.
    rel_path: str
    # Window start, in seconds.
    start_sec: float
    # Window end, in seconds.
    end_sec: float
    # Integer class id attached to the window.
    label: int
|
|
|
|
|
|
class SegmentClipDataset(Dataset):
    """
    Map-style dataset in which each item is a labeled time segment of a video.

    For every item one clip of ``clip_duration`` seconds is decoded from the
    segment: the clip start is drawn at random inside the segment when
    ``train`` is true and is centered in the segment otherwise.

    Args:
        segments: Labeled segments; each needs ``rel_path``, ``start_sec``,
            ``end_sec`` and ``label``.
        path_prefix: Directory joined in front of every ``rel_path``.
        clip_duration: Length of each sampled clip, in seconds.
        num_frames: Forwarded to ``VideoTransform``.
        train: Selects random (True) or centered (False) clip placement and
            is forwarded to ``VideoTransform``.
        decoder: Decoder name handed to ``EncodedVideo.from_path``.
    """

    def __init__(
        self,
        segments: List[SegmentRow],
        path_prefix: str,
        clip_duration: float,
        num_frames: int,
        train: bool,
        decoder: str = "pyav",
    ):
        self.segments = segments
        self.path_prefix = path_prefix
        self.clip_duration = float(clip_duration)
        self.train = train
        self.decoder = decoder
        # Reuse our existing transform pipeline (expects dict with "video" + "label")
        self.transform = VideoTransform(num_frames=num_frames, train=train)

    def __len__(self) -> int:
        """Return the number of labeled segments."""
        return len(self.segments)

    def _clip_window(self, seg) -> Tuple[float, float]:
        """
        Return ``(clip_start, clip_end)`` in seconds for one segment.

        The start is chosen inside ``[start_sec, end_sec - clip_duration]``.
        When the segment is shorter than ``clip_duration`` the clip begins at
        the segment start and therefore runs past the segment end.
        """
        start = float(seg.start_sec)
        end = float(seg.end_sec)
        if end <= start:
            # Empty or inverted segment: treat it as exactly one clip long.
            end = start + self.clip_duration

        latest_start = max(start, end - self.clip_duration)
        if self.train:
            clip_start = random.uniform(start, latest_start) if latest_start > start else start
        else:
            clip_start = (start + latest_start) / 2.0
        return clip_start, clip_start + self.clip_duration

    def __getitem__(self, idx: int):
        """
        Decode one clip from segment ``idx`` and run it through the transform.

        Returns:
            The ``(x, y)`` pair produced by ``self.transform``.

        Raises:
            RuntimeError: If the decoder returns no video frames for the clip.
        """
        seg = self.segments[idx]
        full_path = os.path.join(self.path_prefix, seg.rel_path)
        clip_start, clip_end = self._clip_window(seg)

        video = EncodedVideo.from_path(full_path, decode_audio=False, decoder=self.decoder)
        try:
            clip = video.get_clip(clip_start, clip_end)
        finally:
            # A video is opened for every sample; close it explicitly so
            # decoder/file handles do not pile up in long-lived workers.
            video.close()

        frames = clip.get("video", None)
        if frames is None:
            raise RuntimeError(f"Failed to decode clip from {full_path} [{clip_start},{clip_end}]")

        sample = {"video": frames, "label": seg.label}
        x, y = self.transform(sample)
        return x, y
|
|
|
|
|
|
def infer_num_classes_from_segments(rows: List[SegmentRow]) -> int:
    """
    Derive the class count as ``largest label + 1``.

    The running maximum is floored at -1, so an empty list (or one holding
    only labels below -1) yields 0.
    """
    # Seed the candidates with -1 to keep the same floor as a running max.
    candidates = [-1] + [int(row.label) for row in rows]
    return max(candidates) + 1
|
|
|
|
|
|
@dataclass
class TrainConfig:
    """Flat record of the resolved settings for one training run."""

    # --- data locations ---
    segments_csv: str
    val_csv: str
    path_prefix: str

    # --- model ---
    model: str
    pretrained: bool
    num_classes: int

    # --- clip sampling ---
    num_frames: int
    sampling_rate: int
    target_fps: int
    # Seconds; main() derives it as num_frames * sampling_rate / target_fps.
    clip_duration: float

    # --- optimisation ---
    batch_size: int
    epochs: int
    lr: float
    weight_decay: float

    # --- runtime ---
    num_workers: int
    seed: int
    output_dir: str
    amp: bool
    device: str
    log_interval: int
|
|
|
|
|
|
def accuracy_top1(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the fraction of rows whose arg-max over dim 1 equals the label."""
    hits = logits.argmax(dim=1).eq(labels)
    return hits.float().mean().item()
|
|
|
|
|
|
def read_val_csv(csv_path: str, path_prefix: str) -> List[Tuple[str, int]]:
    """
    Load the validation list as ``(full_path, label)`` pairs.

    Every line is handed to ``_parse_path_label_line`` from the main script;
    lines for which it returns ``None`` are dropped. The remaining relative
    paths are joined onto ``path_prefix`` and the labels are cast to ``int``.
    """
    with open(csv_path, "r") as handle:
        parsed_rows = [_parse_path_label_line(raw) for raw in handle]

    return [
        (os.path.join(path_prefix, rel), int(lbl))
        for rel, lbl in (row for row in parsed_rows if row is not None)
    ]
|
|
|
|
|
|
def _load_segments(csv_path: str) -> List[SegmentRow]:
    """Read every parseable row of the segments CSV into ``SegmentRow`` objects."""
    rows: List[SegmentRow] = []
    with open(csv_path, "r") as f:
        for line in f:
            parsed = _parse_segment_line(line)
            if parsed is None:
                continue
            rel, start, end, label = parsed
            rows.append(SegmentRow(rel_path=rel, start_sec=float(start), end_sec=float(end), label=int(label)))
    return rows


def _require_labels_in_range(labels: List[int], num_classes: int, source: str) -> None:
    """Raise ``ValueError`` if any label falls outside ``[0, num_classes)``."""
    bad = sorted({int(lab) for lab in labels if not 0 <= int(lab) < num_classes})
    if bad:
        raise ValueError(
            f"{source} contains labels {bad} outside the valid range [0, {num_classes}); "
            f"fix the file or pass a suitable --num_classes"
        )


def _train_one_epoch(model, loader, criterion, optimizer, scaler, cfg: TrainConfig, epoch: int) -> Tuple[float, float]:
    """Run one optimisation pass over ``loader``; return sample-weighted (loss, accuracy)."""
    model.train()
    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0
    last_t = time.time()

    for step, (x, y) in enumerate(loader, start=1):
        x = x.to(cfg.device, non_blocking=True)
        y = y.to(cfg.device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        # cfg.amp is only true on CUDA, so on CPU this context is a no-op.
        with torch.amp.autocast(cfg.device, enabled=cfg.amp):
            logits = model(x)
            loss = criterion(logits, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight by batch size so a smaller final batch is not over-counted.
        batch = y.size(0)
        loss_sum += loss.item() * batch
        acc_sum += accuracy_top1(logits.detach(), y) * batch
        seen += batch

        if cfg.log_interval > 0 and (step % cfg.log_interval) == 0:
            now = time.time()
            dt = now - last_t
            last_t = now
            print(
                f"[epoch {epoch:03d}] step {step:05d} "
                f"loss {(loss_sum/seen):.4f} acc {(acc_sum/seen)*100:.2f}% "
                f"({dt:.1f}s/{cfg.log_interval} steps)"
            )

    denom = max(seen, 1)
    return loss_sum / denom, acc_sum / denom


def _evaluate(model, loader, criterion, device: str) -> Tuple[float, float]:
    """Score ``model`` on ``loader`` without gradients; return sample-weighted (loss, accuracy)."""
    model.eval()
    loss_sum = 0.0
    acc_sum = 0.0
    seen = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            logits = model(x)
            loss = criterion(logits, y)
            batch = y.size(0)
            loss_sum += loss.item() * batch
            acc_sum += accuracy_top1(logits, y) * batch
            seen += batch
    denom = max(seen, 1)
    return loss_sum / denom, acc_sum / denom


def main() -> None:
    """
    Command-line entry point: train X3D on segment-labeled clips.

    Steps: parse arguments, load the train segments and validation list,
    write ``config.json`` to the output directory, build the data loaders
    and model, then train for ``--epochs`` epochs. After every epoch
    ``checkpoint_last.pt`` is rewritten, and ``checkpoint_best.pt`` is
    rewritten whenever validation accuracy improves.

    Raises:
        ValueError: If no segments or no validation items could be loaded,
            or if any label lies outside ``[0, num_classes)``.
        SystemExit: Via ``parser.error`` when a clip-sampling argument is
            not a positive integer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--segments_csv", type=str, required=True, help="train segments csv (path start end label)")
    parser.add_argument("--val_csv", type=str, required=True, help="val csv (path label)")
    parser.add_argument("--path_prefix", type=str, required=True)
    parser.add_argument("--model", type=str, default="x3d_m", choices=["x3d_xs", "x3d_s", "x3d_m", "x3d_l"])
    parser.add_argument("--pretrained", action="store_true")
    parser.add_argument("--num_classes", type=int, default=0, help="0=auto from segments")
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--sampling_rate", type=int, default=5)
    parser.add_argument("--target_fps", type=int, default=30)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--log_interval", type=int, default=20)
    args = parser.parse_args()

    # These three define the clip length; reject values that would give a
    # zero/negative duration or a division by zero below.
    for name in ("num_frames", "sampling_rate", "target_fps"):
        if getattr(args, name) <= 0:
            parser.error(f"--{name} must be a positive integer")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    amp = bool(args.amp and device == "cuda")

    segments_csv = os.path.abspath(os.path.expanduser(args.segments_csv))
    val_csv = os.path.abspath(os.path.expanduser(args.val_csv))
    path_prefix = os.path.abspath(os.path.expanduser(args.path_prefix))
    output_dir = os.path.abspath(os.path.expanduser(args.output_dir))
    os.makedirs(output_dir, exist_ok=True)

    set_seed(args.seed)

    clip_duration = args.num_frames * args.sampling_rate / float(args.target_fps)

    # Load segments
    segments = _load_segments(segments_csv)
    if not segments:
        raise ValueError(f"No segments loaded from {segments_csv}")

    val_items = read_val_csv(val_csv, path_prefix)
    if not val_items:
        # Without validation data every epoch would score 0% and the "best"
        # checkpoint would be meaningless.
        raise ValueError(f"No validation items loaded from {val_csv}")

    num_classes = args.num_classes if args.num_classes > 0 else infer_num_classes_from_segments(segments)

    # Fail early with a readable message instead of inside the loss function.
    _require_labels_in_range([s.label for s in segments], num_classes, segments_csv)
    _require_labels_in_range([lab for _, lab in val_items], num_classes, val_csv)

    cfg = TrainConfig(
        segments_csv=segments_csv,
        val_csv=val_csv,
        path_prefix=path_prefix,
        model=args.model,
        pretrained=bool(args.pretrained),
        num_classes=num_classes,
        num_frames=args.num_frames,
        sampling_rate=args.sampling_rate,
        target_fps=args.target_fps,
        clip_duration=clip_duration,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        num_workers=args.num_workers,
        seed=args.seed,
        output_dir=output_dir,
        amp=amp,
        device=device,
        log_interval=args.log_interval,
    )
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(asdict(cfg), f, indent=2)

    train_ds = SegmentClipDataset(
        segments=segments,
        path_prefix=path_prefix,
        clip_duration=clip_duration,
        num_frames=args.num_frames,
        train=True,
        decoder="pyav",
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=(device == "cuda"),
        persistent_workers=bool(args.num_workers > 0),
    )

    # Simple val: video-level labels, clips taken by the "uniform" clip sampler
    # (see the pytorchvideo make_clip_sampler docs for its exact behavior).
    # (You can also create val segments if you want segment-level validation.)
    from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler

    val_labeled = [(p, {"label": int(l)}) for p, l in val_items]
    val_ds = LabeledVideoDataset(
        labeled_video_paths=val_labeled,
        clip_sampler=make_clip_sampler("uniform", clip_duration),
        decode_audio=False,
        decoder="pyav",
        transform=VideoTransform(num_frames=args.num_frames, train=False),
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=(device == "cuda"),
        persistent_workers=bool(args.num_workers > 0),
    )

    model = build_pretrained_x3d(cfg.model, cfg.pretrained)
    replace_last_linear(model, cfg.num_classes)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scaler = torch.amp.GradScaler("cuda", enabled=cfg.amp) if device == "cuda" else torch.amp.GradScaler(enabled=False)

    last_path = os.path.join(output_dir, "checkpoint_last.pt")
    best_path = os.path.join(output_dir, "checkpoint_best.pt")
    best_val = -1.0

    for epoch in range(1, cfg.epochs + 1):
        t0 = time.time()
        tr_loss, tr_acc = _train_one_epoch(model, train_loader, criterion, optimizer, scaler, cfg, epoch)

        # Validation
        va_loss, va_acc = _evaluate(model, val_loader, criterion, device)

        dt_epoch = time.time() - t0
        print(
            f"Epoch {epoch:03d}/{cfg.epochs} | "
            f"train loss {tr_loss:.4f} acc {tr_acc*100:.2f}% | "
            f"val loss {va_loss:.4f} acc {va_acc*100:.2f}% | "
            f"time {dt_epoch:.1f}s"
        )

        # Save last/best
        state = {
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scaler": scaler.state_dict(),
            "cfg": asdict(cfg),
        }
        torch.save(state, last_path)
        if va_acc > best_val:
            best_val = va_acc
            torch.save({**state, "best_val_acc": best_val}, best_path)

    print(f"Done. Best val acc: {best_val*100:.2f}%")
|
|
|
|
|
|
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|
|
|