Files
FishServer/FishAction/train_pytorchvideo_x3d_segments.py
2026-04-08 19:32:23 +08:00

399 lines
13 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Train a video classifier using timestamp/segment labels (PyTorchVideo + X3D).
Why this script:
- If "scared" only happens in a short window, video-level single labels dilute training.
- This script trains on *timestamped segments* so clips are sampled inside labeled windows.
Segments CSV format (whitespace-separated; last token is label int):
relative/video.mp4 start_sec end_sec label
Example:
scared_underwater/foo.mp4 12.4 16.2 3
scared_underwater/foo.mp4 30.0 32.0 3
normal_underwater/bar.mp4 0.0 60.0 1
Run:
python train_pytorchvideo_x3d_segments.py \
--segments_csv /home/ubuntu/data/fish/fish_action_videos/train_segments.csv \
--val_csv /home/ubuntu/data/fish/fish_action_videos/val.csv \
--path_prefix /home/ubuntu/data/fish/fish_action_videos \
--model x3d_m --pretrained \
--num_frames 16 --sampling_rate 5 --target_fps 30 \
--batch_size 4 --epochs 30 --num_workers 4 --amp \
--output_dir /home/ubuntu/projects/FishAction/checkpoints/ptv_x3d_m_segments
"""
from __future__ import annotations
import argparse
import json
import math
import os
import random
import time
from dataclasses import asdict, dataclass
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from pytorchvideo.data.encoded_video import EncodedVideo
from train_pytorchvideo_x3d import (
VideoTransform,
_parse_path_label_line,
build_pretrained_x3d,
replace_last_linear,
set_seed,
)
def _parse_segment_line(line: str) -> Optional[Tuple[str, float, float, int]]:
"""
Parse: path start end label
Supports paths with spaces (label is last token).
"""
line = line.strip()
if not line or line.startswith("#"):
return None
parts = line.split()
if len(parts) < 4:
return None
try:
label = int(parts[-1])
end_sec = float(parts[-2])
start_sec = float(parts[-3])
except ValueError:
return None
path = " ".join(parts[:-3])
return path, start_sec, end_sec, label
@dataclass
class SegmentRow:
    """One labeled time window of a video, as listed in the segments CSV."""

    # Video path as written in the CSV; the dataset joins it onto its path prefix.
    rel_path: str
    # Window start within the video, in seconds.
    start_sec: float
    # Window end within the video, in seconds.
    end_sec: float
    # Integer class id attached to this window.
    label: int
class SegmentClipDataset(Dataset):
    """
    Map-style dataset in which each item is a labeled time segment of a video.

    One clip of ``clip_duration`` seconds is taken from inside the segment on
    every access: at a random offset when ``train`` is true, centered otherwise.
    """

    def __init__(
        self,
        segments: "List[SegmentRow]",
        path_prefix: str,
        clip_duration: float,
        num_frames: int,
        train: bool,
        decoder: str = "pyav",
    ):
        """
        Args:
            segments: Labeled windows; item ``i`` of the dataset is ``segments[i]``.
            path_prefix: Directory each segment's ``rel_path`` is joined onto.
            clip_duration: Length in seconds of the clip decoded per item.
            num_frames: Forwarded to ``VideoTransform``.
            train: Selects random (True) or centered (False) clip placement and
                is forwarded to ``VideoTransform``.
            decoder: Decoder name handed to ``EncodedVideo.from_path``.
        """
        self.segments = segments
        self.path_prefix = path_prefix
        self.clip_duration = float(clip_duration)
        self.train = train
        self.decoder = decoder
        # Reuse our existing transform pipeline (expects dict with "video" + "label")
        self.transform = VideoTransform(num_frames=num_frames, train=train)

    def __len__(self) -> int:
        """Return the number of labeled segments."""
        return len(self.segments)

    def _clip_window(self, seg) -> Tuple[float, float]:
        """Pick the ``(clip_start, clip_end)`` window, in seconds, for one segment."""
        start = float(seg.start_sec)
        end = float(seg.end_sec)
        if end <= start:
            # Degenerate/inverted segment: treat it as one clip starting at `start`.
            end = start + self.clip_duration
        # Latest start that still keeps the clip inside the segment; segments
        # shorter than one clip collapse to `start` (the clip overruns the end).
        latest_start = max(start, end - self.clip_duration)
        if self.train:
            clip_start = random.uniform(start, latest_start) if latest_start > start else start
        else:
            clip_start = (start + latest_start) / 2.0
        return clip_start, clip_start + self.clip_duration

    def __getitem__(self, idx: int):
        """
        Decode one clip from segment ``idx`` and run it through the transform.

        Returns:
            The ``(x, y)`` pair produced by ``VideoTransform``.

        Raises:
            RuntimeError: If the decoded clip contains no video frames.
        """
        seg = self.segments[idx]
        full_path = os.path.join(self.path_prefix, seg.rel_path)
        clip_start, clip_end = self._clip_window(seg)

        video = EncodedVideo.from_path(full_path, decode_audio=False, decoder=self.decoder)
        try:
            clip = video.get_clip(clip_start, clip_end)
        finally:
            # A new video is opened for every item; release its decoder handle
            # so long-lived DataLoader workers do not accumulate open files.
            video.close()

        frames = clip.get("video", None)
        if frames is None:
            raise RuntimeError(f"Failed to decode clip from {full_path} [{clip_start},{clip_end}]")
        sample = {"video": frames, "label": seg.label}
        x, y = self.transform(sample)
        return x, y
def infer_num_classes_from_segments(rows: "List[SegmentRow]") -> int:
    """
    Derive the class count as ``largest label + 1``.

    The running maximum is floored at -1, so an empty list (or one holding
    only negative labels) yields 0.
    """
    highest = max([-1, *(int(row.label) for row in rows)])
    return highest + 1
@dataclass
class TrainConfig:
    """Resolved settings for one training run; dumped to config.json and stored in checkpoints."""

    # --- data locations ---
    segments_csv: str
    val_csv: str
    path_prefix: str

    # --- model ---
    model: str
    pretrained: bool
    num_classes: int

    # --- clip sampling ---
    num_frames: int
    sampling_rate: int
    target_fps: int
    clip_duration: float  # seconds; main() computes num_frames * sampling_rate / target_fps

    # --- optimization ---
    batch_size: int
    epochs: int
    lr: float
    weight_decay: float

    # --- runtime ---
    num_workers: int
    seed: int
    output_dir: str
    amp: bool
    device: str
    log_interval: int
def accuracy_top1(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the fraction of rows whose argmax over dim 1 equals the matching label."""
    top1 = logits.argmax(dim=1)
    hits = top1.eq(labels)
    return hits.float().mean().item()
def read_val_csv(csv_path: str, path_prefix: str) -> List[Tuple[str, int]]:
    """
    Load the validation list as ``(full_path, label)`` pairs.

    Every line is handed to ``_parse_path_label_line`` (imported from
    ``train_pytorchvideo_x3d``); lines for which it returns ``None`` are
    dropped. Paths are joined onto ``path_prefix`` and labels cast to ``int``.
    """
    with open(csv_path, "r") as handle:
        candidates = [_parse_path_label_line(raw) for raw in handle]
    kept = [entry for entry in candidates if entry is not None]
    return [(os.path.join(path_prefix, rel_path), int(label)) for rel_path, label in kept]
def main() -> None:
    """
    Command-line entry point: train an X3D classifier on timestamped segments.

    Parses the CLI flags, loads the training segments and the validation
    list, writes ``config.json`` into the output directory, then runs the
    train/validate loop. ``checkpoint_last.pt`` is written after every epoch
    and ``checkpoint_best.pt`` whenever validation accuracy improves.

    Raises:
        ValueError: If no segments or no validation videos are loaded, or if
            any label falls outside ``[0, num_classes)``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--segments_csv", type=str, required=True, help="train segments csv (path start end label)")
    parser.add_argument("--val_csv", type=str, required=True, help="val csv (path label)")
    parser.add_argument("--path_prefix", type=str, required=True)
    parser.add_argument("--model", type=str, default="x3d_m", choices=["x3d_xs", "x3d_s", "x3d_m", "x3d_l"])
    parser.add_argument("--pretrained", action="store_true")
    parser.add_argument("--num_classes", type=int, default=0, help="0=auto from segments")
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--sampling_rate", type=int, default=5)
    parser.add_argument("--target_fps", type=int, default=30)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--log_interval", type=int, default=20)
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Mixed precision is only honoured when a CUDA device is actually in use.
    amp = bool(args.amp and device == "cuda")

    segments_csv = os.path.abspath(os.path.expanduser(args.segments_csv))
    val_csv = os.path.abspath(os.path.expanduser(args.val_csv))
    path_prefix = os.path.abspath(os.path.expanduser(args.path_prefix))
    output_dir = os.path.abspath(os.path.expanduser(args.output_dir))
    os.makedirs(output_dir, exist_ok=True)
    set_seed(args.seed)

    # Seconds of video spanned by num_frames frames taken every sampling_rate
    # frames at target_fps.
    clip_duration = args.num_frames * args.sampling_rate / float(args.target_fps)

    # Load segments
    segments: List[SegmentRow] = []
    with open(segments_csv, "r") as f:
        for line in f:
            parsed = _parse_segment_line(line)
            if parsed is None:
                continue
            rel, s, e, lab = parsed
            segments.append(SegmentRow(rel_path=rel, start_sec=float(s), end_sec=float(e), label=int(lab)))
    if not segments:
        raise ValueError(f"No segments loaded from {segments_csv}")

    val_items = read_val_csv(val_csv, path_prefix)
    if not val_items:
        raise ValueError(f"No validation videos loaded from {val_csv}")

    num_classes = args.num_classes if args.num_classes > 0 else infer_num_classes_from_segments(segments)

    # Fail early with a readable message: a label outside [0, num_classes)
    # would otherwise only surface later inside the loss computation.
    seen_labels = {seg.label for seg in segments} | {label for _, label in val_items}
    out_of_range = sorted(lab for lab in seen_labels if not 0 <= lab < num_classes)
    if out_of_range:
        raise ValueError(
            f"Labels {out_of_range} fall outside [0, {num_classes}); "
            f"check --num_classes and the labels in {segments_csv} / {val_csv}"
        )

    cfg = TrainConfig(
        segments_csv=segments_csv,
        val_csv=val_csv,
        path_prefix=path_prefix,
        model=args.model,
        pretrained=bool(args.pretrained),
        num_classes=num_classes,
        num_frames=args.num_frames,
        sampling_rate=args.sampling_rate,
        target_fps=args.target_fps,
        clip_duration=clip_duration,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        num_workers=args.num_workers,
        seed=args.seed,
        output_dir=output_dir,
        amp=amp,
        device=device,
        log_interval=args.log_interval,
    )
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(asdict(cfg), f, indent=2)

    train_ds = SegmentClipDataset(
        segments=segments,
        path_prefix=path_prefix,
        clip_duration=clip_duration,
        num_frames=args.num_frames,
        train=True,
        decoder="pyav",
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=(device == "cuda"),
        persistent_workers=bool(args.num_workers > 0),
    )

    # Simple val: video-level labels, with clips drawn from whole videos by
    # PyTorchVideo's "uniform" clip sampler rather than from labeled segments.
    # NOTE(review): that sampler is expected to step through each video in
    # consecutive clip_duration windows, not pick one random clip - confirm.
    # (You can also create val segments if you want segment-level validation.)
    from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler

    val_labeled = [(p, {"label": int(l)}) for p, l in val_items]
    val_ds = LabeledVideoDataset(
        labeled_video_paths=val_labeled,
        clip_sampler=make_clip_sampler("uniform", clip_duration),
        decode_audio=False,
        decoder="pyav",
        transform=VideoTransform(num_frames=args.num_frames, train=False),
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=(device == "cuda"),
        persistent_workers=bool(args.num_workers > 0),
    )

    model = build_pretrained_x3d(cfg.model, cfg.pretrained)
    replace_last_linear(model, cfg.num_classes)
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scaler = torch.amp.GradScaler("cuda", enabled=cfg.amp) if device == "cuda" else torch.amp.GradScaler(enabled=False)

    best_val = -1.0
    for epoch in range(1, cfg.epochs + 1):
        model.train()
        t0 = time.time()
        # Metrics are accumulated per sample (batch mean * batch size) so a
        # smaller final batch does not get the same weight as a full one.
        tr_loss_sum = 0.0
        tr_correct = 0.0
        tr_seen = 0
        last_t = time.time()
        for step, (x, y) in enumerate(train_loader, start=1):
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            if device == "cuda":
                with torch.amp.autocast("cuda", enabled=cfg.amp):
                    logits = model(x)
                    loss = criterion(logits, y)
            else:
                logits = model(x)
                loss = criterion(logits, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            batch_n = y.size(0)
            tr_loss_sum += loss.item() * batch_n
            tr_correct += accuracy_top1(logits.detach(), y) * batch_n
            tr_seen += batch_n

            if cfg.log_interval > 0 and (step % cfg.log_interval) == 0:
                dt = time.time() - last_t
                last_t = time.time()
                print(
                    f"[epoch {epoch:03d}] step {step:05d} "
                    f"loss {(tr_loss_sum / tr_seen):.4f} acc {(tr_correct / tr_seen) * 100:.2f}% "
                    f"({dt:.1f}s/{cfg.log_interval} steps)"
                )
        tr_loss = tr_loss_sum / max(tr_seen, 1)
        tr_acc = tr_correct / max(tr_seen, 1)

        # Validation
        model.eval()
        va_loss_sum = 0.0
        va_correct = 0.0
        va_seen = 0
        with torch.no_grad():
            for x, y in val_loader:
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                logits = model(x)
                loss = criterion(logits, y)
                batch_n = y.size(0)
                va_loss_sum += loss.item() * batch_n
                va_correct += accuracy_top1(logits, y) * batch_n
                va_seen += batch_n
        va_loss = va_loss_sum / max(va_seen, 1)
        va_acc = va_correct / max(va_seen, 1)

        dt_epoch = time.time() - t0
        print(
            f"Epoch {epoch:03d}/{cfg.epochs} | "
            f"train loss {tr_loss:.4f} acc {tr_acc*100:.2f}% | "
            f"val loss {va_loss:.4f} acc {va_acc*100:.2f}% | "
            f"time {dt_epoch:.1f}s"
        )

        # Save last/best
        torch.save(
            {"epoch": epoch, "model": model.state_dict(), "optimizer": optimizer.state_dict(), "scaler": scaler.state_dict(), "cfg": asdict(cfg)},
            os.path.join(output_dir, "checkpoint_last.pt"),
        )
        if va_acc > best_val:
            best_val = va_acc
            torch.save(
                {"epoch": epoch, "model": model.state_dict(), "optimizer": optimizer.state_dict(), "scaler": scaler.state_dict(), "cfg": asdict(cfg), "best_val_acc": best_val},
                os.path.join(output_dir, "checkpoint_best.pt"),
            )
    print(f"Done. Best val acc: {best_val*100:.2f}%")
# Run the training CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()