Files
FishServer/FishAction/train_pytorchvideo_x3d.py
2026-04-08 19:32:23 +08:00

797 lines
28 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Train a video classifier using PyTorchVideo (X3D) on a CSV dataset.
This script is intentionally independent from SlowFast, so it works even if
SlowFast's optional deps/imports are broken.
Expected CSV format (space-separated, like SlowFast/Kinetics):
relative/path/to/video.mp4 <label_int>
Example:
python train_pytorchvideo_x3d.py \
--csv_dir /home/ubuntu/projects/FishAction/data/fish/fish_action_training_dataset \
--path_prefix /home/ubuntu/data/fish/fish_action_videos \
--model x3d_m \
--pretrained \
--batch_size 4 \
--epochs 30 \
--num_workers 4 \
--output_dir /home/ubuntu/projects/FishAction/checkpoints/ptv_x3d_m
"""
from __future__ import annotations
import argparse
import contextlib
import json
import os
import random
import time
from math import cos, pi
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler
# Per-channel (3 values each) mean and standard deviation used as the default
# normalization statistics by normalize_cthw. Presumably the usual Kinetics
# RGB statistics for inputs scaled to [0, 1] -- confirm against the pretrained
# weights being used.
KINETICS_MEAN = (0.45, 0.45, 0.45)
KINETICS_STD = (0.225, 0.225, 0.225)
def set_seed(seed: int) -> None:
    """Seed the Python and PyTorch random number generators.

    Args:
        seed: Value handed, in this order, to ``random.seed``,
            ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.
    """
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
def _parse_path_label_line(line: str) -> Tuple[str, int] | None:
    """Parse one ``<path> <label>`` line of a dataset CSV.

    The line is split once, from the right, on whitespace: the last token is
    the integer label and everything before it is the path. Splitting only
    once keeps the path byte-for-byte intact, so paths containing spaces,
    tabs, or runs of several spaces still match the file on disk:

        /some/path/with  two spaces/video.mp4 3

    Args:
        line: Raw line read from the CSV (may include the trailing newline).

    Returns:
        ``(path, label)``, or ``None`` when the line is blank, starts with
        ``#``, has fewer than two tokens, or its last token is not an integer
        (for example a ``path label`` header row).
    """
    line = line.strip()
    if not line or line.startswith("#"):
        return None
    # maxsplit=1 from the right: only the separator before the label is consumed.
    parts = line.rsplit(None, 1)
    if len(parts) < 2:
        return None
    path, label_str = parts
    try:
        label = int(label_str)
    except ValueError:
        # likely a header row, e.g. "path label"
        return None
    return path, label
def read_csv_list(csv_path: str, path_prefix: str) -> List[Tuple[str, Dict]]:
    """Load ``(video_path, {"label": int})`` pairs from a dataset CSV.

    Args:
        csv_path: File to read; each line goes through
            ``_parse_path_label_line``.
        path_prefix: Joined in front of every path with ``os.path.join``.

    Returns:
        One entry per line that parsed successfully, in file order. Lines the
        parser rejects are skipped.
    """
    with open(csv_path, "r") as handle:
        parsed_lines = [_parse_path_label_line(raw) for raw in handle]
    return [
        (os.path.join(path_prefix, entry[0]), {"label": int(entry[1])})
        for entry in parsed_lines
        if entry is not None
    ]
def infer_num_classes(csv_path: str) -> int:
    """Return the largest label found in ``csv_path`` plus one.

    Lines rejected by ``_parse_path_label_line`` are ignored. The running
    maximum starts at -1, so a file without any usable line yields 0.
    """
    with open(csv_path, "r") as handle:
        labels = [
            int(entry[1])
            for entry in map(_parse_path_label_line, handle)
            if entry is not None
        ]
    # -1 is the floor of the maximum, exactly like the initial running value.
    return max((-1, *labels)) + 1
def uniform_temporal_subsample_cthw(video: torch.Tensor, num_frames: int) -> torch.Tensor:
    """Pick ``num_frames`` evenly spaced frames along dim 1 of a 4-D tensor.

    Args:
        video: Tensor unpacked as (C, T, H, W).
        num_frames: Number of frames in the result.

    Returns:
        ``video`` itself when it already has ``num_frames`` frames, otherwise
        a new tensor gathered at the truncated positions of
        ``linspace(0, T - 1, num_frames)`` (frames repeat when T is smaller).

    Raises:
        ValueError: If the temporal dimension is smaller than 1.
    """
    _, n_in, _, _ = video.shape
    if n_in == num_frames:
        return video
    if n_in < 1:
        raise ValueError(f"Invalid temporal dimension t={n_in}")
    positions = torch.linspace(0, n_in - 1, steps=num_frames, device=video.device)
    return video.index_select(1, positions.long())
def to_cthw(video: torch.Tensor) -> torch.Tensor:
    """
    Bring a 4-D video tensor into C T H W layout.

    The channel axis is guessed by looking for a dimension of size 3 (RGB),
    checked in this order:
    - dim 0  -> already C T H W, returned unchanged
    - dim -1 -> T H W C
    - dim 1  -> T C H W

    Raises:
        ValueError: If the tensor is not 4-D or none of those dims has size 3.
    """
    if video.ndim != 4:
        raise ValueError(f"Expected 4D video tensor, got shape {tuple(video.shape)}")
    dims = tuple(video.shape)
    # (axis to test, permutation that moves it to the front); None = no-op.
    candidates = (
        (0, None),
        (-1, (3, 0, 1, 2)),
        (1, (1, 0, 2, 3)),
    )
    for axis, order in candidates:
        if dims[axis] == 3:
            return video if order is None else video.permute(*order)
    raise ValueError(f"Could not infer channel dim for video shape {dims}")
def resize_short_side_tchw(frames_tchw: torch.Tensor, short_side: int) -> torch.Tensor:
    """Bilinearly resize T C H W frames so the shorter spatial side is ``short_side``.

    The longer side is scaled by the same ratio and rounded to an integer.
    The input is returned unchanged when its shorter side already matches.
    """
    _, _, height, width = frames_tchw.shape
    if min(height, width) == short_side:
        return frames_tchw
    if height < width:
        target = (short_side, int(round(width * short_side / height)))
    else:
        target = (int(round(height * short_side / width)), short_side)
    return F.interpolate(frames_tchw, size=target, mode="bilinear", align_corners=False)
def random_crop_tchw(frames_tchw: torch.Tensor, crop_size: int) -> torch.Tensor:
    """Cut a random ``crop_size`` x ``crop_size`` window out of T C H W frames.

    If either spatial side is smaller than ``crop_size`` the frames are first
    bilinearly enlarged so that both sides are at least ``crop_size``. The
    window offset is drawn with Python's ``random`` module: the row offset
    first, then the column offset, and only for sides larger than the crop.
    """
    _, _, height, width = frames_tchw.shape
    if height < crop_size or width < crop_size:
        # Enlarge only the deficient side(s); the other side keeps its size.
        enlarged = (max(height, crop_size), max(width, crop_size))
        frames_tchw = F.interpolate(frames_tchw, size=enlarged, mode="bilinear", align_corners=False)
        _, _, height, width = frames_tchw.shape
    top = 0
    if height > crop_size:
        top = random.randint(0, height - crop_size)
    left = 0
    if width > crop_size:
        left = random.randint(0, width - crop_size)
    rows = slice(top, top + crop_size)
    cols = slice(left, left + crop_size)
    return frames_tchw[:, :, rows, cols]
def center_crop_tchw(frames_tchw: torch.Tensor, crop_size: int) -> torch.Tensor:
    """Cut the central ``crop_size`` x ``crop_size`` window out of T C H W frames.

    If either spatial side is smaller than ``crop_size`` the frames are first
    bilinearly enlarged so that both sides are at least ``crop_size``.
    """
    _, _, height, width = frames_tchw.shape
    if height < crop_size or width < crop_size:
        enlarged = (max(height, crop_size), max(width, crop_size))
        frames_tchw = F.interpolate(frames_tchw, size=enlarged, mode="bilinear", align_corners=False)
        _, _, height, width = frames_tchw.shape
    row0 = max((height - crop_size) // 2, 0)
    col0 = max((width - crop_size) // 2, 0)
    rows = slice(row0, row0 + crop_size)
    cols = slice(col0, col0 + crop_size)
    return frames_tchw[:, :, rows, cols]
def normalize_cthw(video_cthw: torch.Tensor, mean=KINETICS_MEAN, std=KINETICS_STD) -> torch.Tensor:
    """Return ``(video - mean) / std`` with per-channel statistics.

    ``mean`` and ``std`` must each hold three values; they are turned into
    tensors with the video's device and dtype and shaped (3, 1, 1, 1) so they
    broadcast over the remaining dimensions of ``video_cthw``.
    """

    def _per_channel(values) -> torch.Tensor:
        stats = torch.tensor(values, device=video_cthw.device, dtype=video_cthw.dtype)
        return stats.view(3, 1, 1, 1)

    centered = video_cthw - _per_channel(mean)
    return centered / _per_channel(std)
class VideoTransform:
    """Callable that turns a decoded sample dict into ``(video, label)``.

    Pipeline applied by ``__call__``:
      1. convert the clip to C T H W and to float (uint8 input is scaled by
         1/255),
      2. uniformly subsample to ``num_frames`` frames,
      3. resize the short side to ``short_side``,
      4. training: random crop plus optional horizontal flip;
         evaluation: center crop,
      5. normalize with the default statistics of ``normalize_cthw``.
    """

    def __init__(
        self,
        num_frames: int,
        train: bool,
        short_side: int = 256,
        crop_size: int = 224,
        random_flip: bool = True,
    ):
        """Store the transform settings.

        Args:
            num_frames: Frames kept per clip.
            train: Selects random (True) or deterministic (False) cropping.
            short_side: Target size of the shorter spatial side before cropping.
            crop_size: Side length of the square crop.
            random_flip: Allow a 50% horizontal flip; only used when ``train``.
        """
        self.num_frames = num_frames
        self.train = train
        self.short_side = short_side
        self.crop_size = crop_size
        self.random_flip = random_flip

    def __call__(self, sample: Dict):
        """Transform one sample.

        Args:
            sample: Mapping with a ``"video"`` tensor and an optional
                ``"label"`` that is either int-like or a dict holding a
                ``"label"`` entry; a missing label is treated as 0.

        Returns:
            Tuple of the normalized C T H W float tensor and a scalar
            ``torch.long`` label tensor.
        """
        clip = sample["video"]
        raw_label = sample.get("label", 0)
        label = int(raw_label.get("label", 0)) if isinstance(raw_label, dict) else int(raw_label)

        clip = to_cthw(clip)
        # Only uint8 input is rescaled to [0, 1]; other dtypes are just cast.
        clip = clip.float() / 255.0 if clip.dtype == torch.uint8 else clip.float()
        clip = uniform_temporal_subsample_cthw(clip, self.num_frames)  # C T H W

        # The spatial helpers work on T C H W.
        frames = resize_short_side_tchw(clip.permute(1, 0, 2, 3), self.short_side)
        if not self.train:
            frames = center_crop_tchw(frames, self.crop_size)
        else:
            frames = random_crop_tchw(frames, self.crop_size)
            # Flip along the width axis (dim 3) half of the time.
            if self.random_flip and random.random() < 0.5:
                frames = frames.flip(dims=[3])

        normalized = normalize_cthw(frames.permute(1, 0, 2, 3))  # back to C T H W
        return normalized, torch.tensor(label, dtype=torch.long)
def build_pretrained_x3d(model_name: str, pretrained: bool) -> nn.Module:
    """Build an X3D video classifier with a 400-class head.

    Args:
        model_name: One of ``x3d_xs``, ``x3d_s``, ``x3d_m``, ``x3d_l``.
        pretrained: When True, load the model and its weights through
            ``torch.hub`` (downloads the pytorchvideo repo and checkpoint on
            first use). When False, build the same variant with random
            weights from the installed ``pytorchvideo`` package.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: If ``model_name`` is not a supported X3D variant.
    """
    supported = ("x3d_xs", "x3d_s", "x3d_m", "x3d_l")
    if model_name not in supported:
        raise ValueError(f"Unknown X3D variant {model_name!r}; expected one of {supported}")

    if pretrained:
        # Uses torch hub; downloads pytorchvideo main repo + weights once into ~/.cache/torch/hub/
        return torch.hub.load("facebookresearch/pytorchvideo", model_name, pretrained=True)

    # No pretrained weights: use the per-variant builder shipped with the
    # package so that model_name still selects the architecture (clip length,
    # crop size, depth). Calling create_x3d() with defaults would build one
    # fixed configuration regardless of model_name.
    from pytorchvideo.models import hub as ptv_hub

    builder = getattr(ptv_hub, model_name)
    return builder(pretrained=False)
def replace_last_linear(model: nn.Module, num_classes: int) -> None:
    """Swap the last ``nn.Linear`` of ``model`` for a fresh ``num_classes``-way layer.

    "Last" means last in ``model.named_modules()`` order (for the X3D hub
    model this is expected to be the ``blocks.5.proj`` head). The replacement
    keeps the old layer's ``in_features`` and is assigned in place on the
    module that owns it.

    Raises:
        RuntimeError: If the model contains no ``nn.Linear`` at all.
    """
    linear_layers = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
    ]
    if not linear_layers:
        raise RuntimeError("Could not find a Linear layer to replace for classification head.")
    target_name, old_head = linear_layers[-1]
    new_head = nn.Linear(old_head.in_features, num_classes)

    # Walk the dotted module path down to the parent, then rebind the leaf.
    *parent_path, leaf = target_name.split(".")
    owner = model
    for attr in parent_path:
        owner = getattr(owner, attr)
    setattr(owner, leaf, new_head)
@dataclass
class TrainConfig:
    """Resolved settings of one training run.

    ``main`` fills this from the command-line arguments (after path expansion
    and device detection), writes it to ``config.json`` and embeds it in every
    checkpoint via ``asdict``.
    """

    # --- dataset location ---
    csv_dir: str
    path_prefix: str
    train_csv: str
    val_csv: str

    # --- model ---
    model: str
    pretrained: bool
    num_classes: int

    # --- clip geometry (sampling_rate / target_fps only feed clip_duration) ---
    num_frames: int
    sampling_rate: int
    target_fps: int

    # --- optimization ---
    batch_size: int
    epochs: int
    lr: float
    weight_decay: float

    # --- LR schedule ---
    scheduler: str
    warmup_epochs: int
    min_lr: float
    step_size: int
    gamma: float
    plateau_factor: float
    plateau_patience: int

    # --- data loading ---
    num_workers: int
    prefetch_factor: int
    persistent_workers: bool

    # --- run bookkeeping / runtime ---
    seed: int
    output_dir: str
    amp: bool
    device: str
    decoder: str

    # --- clip sampling ---
    train_sampling: str
    train_clips_per_video: int
    val_sampling: str
    val_clips_per_video: int

    tf32: bool
def accuracy_top1(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the fraction of rows whose arg-max over dim 1 equals the label."""
    hits = logits.argmax(dim=1).eq(labels)
    return hits.sum().item() / labels.numel()
def _get_lr(optimizer: torch.optim.Optimizer) -> float:
    """Return the learning rate of the optimizer's first parameter group."""
    first_group = optimizer.param_groups[0]
    return float(first_group["lr"])
def build_scheduler(
    optimizer: torch.optim.Optimizer, cfg: TrainConfig
) -> tuple[torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.ReduceLROnPlateau | None, bool]:
    """
    Create the LR scheduler selected by ``cfg.scheduler``.

    - ``"none"``: no scheduler.
    - ``"plateau"``: ``ReduceLROnPlateau`` in "min" mode; the caller must pass
      the monitored metric to ``step``.
    - anything else: a ``LambdaLR`` meant to be stepped once per epoch that
      combines linear warmup with a step or cosine schedule (factor 1.0 for
      any other name).

    Returns: (scheduler, needs_val_metric)
    """
    if cfg.scheduler == "none":
        return None, False
    if cfg.scheduler == "plateau":
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=cfg.plateau_factor,
            patience=cfg.plateau_patience,
        )
        return plateau, True

    n_warmup = max(int(cfg.warmup_epochs), 0)
    n_epochs = max(int(cfg.epochs), 1)
    # Lowest multiplier of the base LR (cosine only), clamped into [0, 1].
    floor = float(cfg.min_lr / cfg.lr) if cfg.lr > 0 else 0.0
    floor = max(0.0, min(1.0, floor))

    def lr_lambda(epoch_idx0: int) -> float:
        # epoch_idx0 is 0-based; the returned value multiplies the base LR.
        epoch = int(epoch_idx0)
        if n_warmup > 0 and epoch < n_warmup:
            # Linear ramp that reaches 1.0 on the last warmup epoch.
            return float(epoch + 1) / float(n_warmup)

        since_warmup = epoch - n_warmup
        if cfg.scheduler == "step":
            if cfg.step_size <= 0:
                return 1.0
            n_decays = max(since_warmup // int(cfg.step_size), 0)
            return float(cfg.gamma) ** float(n_decays)
        if cfg.scheduler == "cosine":
            span = max(n_epochs - n_warmup, 1)
            progress = min(max(since_warmup / float(span), 0.0), 1.0)
            cosine_factor = 0.5 * (1.0 + cos(pi * progress))
            return floor + (1.0 - floor) * cosine_factor
        return 1.0

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda), False
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--csv_dir", type=str, required=True, help="Folder containing train.csv/val.csv")
    parser.add_argument("--path_prefix", type=str, required=True, help="Absolute path prefix for videos")
    parser.add_argument("--train_csv", type=str, default="train.csv")
    parser.add_argument("--val_csv", type=str, default="val.csv")
    parser.add_argument("--model", type=str, default="x3d_m", choices=["x3d_xs", "x3d_s", "x3d_m", "x3d_l"])
    parser.add_argument("--pretrained", action="store_true", help="Use torch.hub pretrained weights")
    parser.add_argument("--num_classes", type=int, default=0, help="If 0, infer from train.csv max label + 1")
    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--sampling_rate", type=int, default=5, help="Used only for clip_duration calculation")
    parser.add_argument("--target_fps", type=int, default=30, help="Used only for clip_duration calculation")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument(
        "--scheduler",
        type=str,
        default="none",
        choices=["none", "cosine", "step", "plateau"],
        help=(
            "LR scheduler. cosine/step are stepped once per epoch. "
            "plateau monitors val loss and reduces LR when it stops improving."
        ),
    )
    parser.add_argument(
        "--warmup_epochs",
        type=int,
        default=0,
        help="Warm up LR linearly for N epochs (cosine/step).",
    )
    parser.add_argument("--min_lr", type=float, default=1e-5, help="Cosine schedule minimum LR (absolute).")
    parser.add_argument("--step_size", type=int, default=10, help="Step schedule: decay every N epochs (after warmup).")
    parser.add_argument("--gamma", type=float, default=0.1, help="Step schedule: multiplicative decay factor.")
    parser.add_argument(
        "--plateau_factor",
        type=float,
        default=0.1,
        help="Plateau schedule: LR *= factor when val loss plateaus.",
    )
    parser.add_argument(
        "--plateau_patience",
        type=int,
        default=2,
        help="Plateau schedule: epochs to wait for val loss improvement before decaying.",
    )
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument(
        "--prefetch_factor",
        type=int,
        default=2,
        help="DataLoader prefetch factor per worker (only used when num_workers > 0).",
    )
    parser.add_argument(
        "--persistent_workers",
        action="store_true",
        help="Keep DataLoader workers alive across epochs (faster, uses more RAM; only if num_workers > 0).",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--amp", action="store_true", help="Use mixed precision (CUDA only)")
    parser.add_argument(
        "--tf32",
        action="store_true",
        help="Allow TF32 matmul/conv on Ampere+ (CUDA only). Can speed up training a bit.",
    )
    parser.add_argument(
        "--decoder",
        type=str,
        default="pyav",
        help="Video decoder backend for PyTorchVideo LabeledVideoDataset (common: pyav, decord if installed).",
    )
    parser.add_argument(
        "--train_sampling",
        type=str,
        default="random",
        choices=["random", "constant_clips_per_video", "uniform"],
        help=(
            "Training clip sampling. "
            "'random' uses 1 random clip per video per epoch. "
            "'constant_clips_per_video' uses N evenly spaced clips per video per epoch. "
            "'uniform' splits videos into sequential clips (long videos yield more clips)."
        ),
    )
    parser.add_argument(
        "--train_clips_per_video",
        type=int,
        default=1,
        help="Used only when --train_sampling=constant_clips_per_video.",
    )
    parser.add_argument(
        "--val_sampling",
        type=str,
        default="uniform",
        choices=["uniform", "random", "constant_clips_per_video"],
        help=(
            "Validation clip sampling. 'uniform' may produce many clips for long videos. "
            "'random' uses 1 random clip per video. "
            "'constant_clips_per_video' uses N evenly spaced clips per video."
        ),
    )
    parser.add_argument(
        "--val_clips_per_video",
        type=int,
        default=10,
        help="Used only when --val_sampling=constant_clips_per_video.",
    )
    parser.add_argument(
        "--log_interval",
        type=int,
        default=20,
        help="Print training/val progress every N batches (default: 20). Use 1 for every batch, or --log_every_batch.",
    )
    parser.add_argument(
        "--log_every_batch",
        action="store_true",
        help="Print one line after every training and validation batch (includes batch/total when known).",
    )
    return parser


def _build_clip_sampler(kind: str, clip_duration: float, clips_per_video: int, option_name: str):
    """Create the PyTorchVideo clip sampler for one dataset split.

    ``clips_per_video`` is only forwarded for ``constant_clips_per_video``;
    ``option_name`` names the config field in the error message.
    """
    if kind in ("random", "uniform"):
        return make_clip_sampler(kind, clip_duration)
    if kind == "constant_clips_per_video":
        return make_clip_sampler(kind, clip_duration, clips_per_video)
    raise ValueError(f"Unknown {option_name}: {kind}")


def _save_checkpoint(path, *, epoch, model, optimizer, scheduler, scaler, cfg, extra=None) -> None:
    """Write model/optimizer/scheduler/scaler state plus the config to ``path``."""
    state = {
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "scaler": scaler.state_dict(),
        "cfg": asdict(cfg),
    }
    if extra:
        state.update(extra)
    torch.save(state, path)


def main() -> None:
    """Command-line entry point: train and validate an X3D classifier.

    Steps: parse arguments, resolve paths and device, write ``config.json``,
    build datasets/loaders, build the model with a new classification head,
    then run ``epochs`` rounds of training and validation. After every epoch
    ``checkpoint_last.pt`` is written, and ``checkpoint_best.pt`` whenever the
    validation accuracy improves.

    Validation loss and accuracy are weighted by the number of samples per
    batch, so a smaller final batch (the val loader keeps it) does not skew
    the epoch metrics used for the plateau scheduler and best-checkpoint
    selection.

    Raises:
        FileNotFoundError: If the train or val CSV does not exist.
        ValueError: If no positive class count can be inferred from the train
            CSV, or a sampling mode is unknown.
    """
    args = _build_arg_parser().parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    amp = bool(args.amp and device == "cuda")
    tf32 = bool(args.tf32 and device == "cuda")
    if tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        try:
            torch.set_float32_matmul_precision("high")
        except AttributeError:
            # torch builds without this setter: the backend flags above still apply.
            pass

    csv_dir = os.path.abspath(os.path.expanduser(args.csv_dir))
    path_prefix = os.path.abspath(os.path.expanduser(args.path_prefix))
    train_csv_path = os.path.join(csv_dir, args.train_csv)
    val_csv_path = os.path.join(csv_dir, args.val_csv)
    if not os.path.isfile(train_csv_path):
        raise FileNotFoundError(
            f"train csv not found: {train_csv_path}\n"
            f"Tip: --csv_dir should be the folder that contains train.csv/val.csv, "
            f"NOT the raw video folder."
        )
    if not os.path.isfile(val_csv_path):
        raise FileNotFoundError(f"val csv not found: {val_csv_path}")

    if args.num_classes and args.num_classes > 0:
        num_classes = args.num_classes
    else:
        num_classes = infer_num_classes(train_csv_path)
        if num_classes <= 0:
            # An empty/unparseable CSV would otherwise yield a 0-output head
            # and fail much later with an unrelated-looking error.
            raise ValueError(
                f"Could not infer the number of classes from {train_csv_path}: "
                f"no valid '<path> <label>' lines found. Fix the CSV or pass --num_classes."
            )

    cfg = TrainConfig(
        csv_dir=csv_dir,
        path_prefix=path_prefix,
        train_csv=args.train_csv,
        val_csv=args.val_csv,
        model=args.model,
        pretrained=bool(args.pretrained),
        num_classes=num_classes,
        num_frames=args.num_frames,
        sampling_rate=args.sampling_rate,
        target_fps=args.target_fps,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        scheduler=str(args.scheduler),
        warmup_epochs=int(args.warmup_epochs),
        min_lr=float(args.min_lr),
        step_size=int(args.step_size),
        gamma=float(args.gamma),
        plateau_factor=float(args.plateau_factor),
        plateau_patience=int(args.plateau_patience),
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch_factor,
        persistent_workers=bool(args.persistent_workers and args.num_workers > 0),
        seed=args.seed,
        output_dir=os.path.abspath(os.path.expanduser(args.output_dir)),
        amp=amp,
        device=device,
        decoder=str(args.decoder),
        train_sampling=str(args.train_sampling),
        train_clips_per_video=int(args.train_clips_per_video),
        val_sampling=str(args.val_sampling),
        val_clips_per_video=int(args.val_clips_per_video),
        tf32=tf32,
    )
    os.makedirs(cfg.output_dir, exist_ok=True)
    with open(os.path.join(cfg.output_dir, "config.json"), "w") as f:
        json.dump(asdict(cfg), f, indent=2)

    set_seed(cfg.seed)
    torch.backends.cudnn.benchmark = True

    # Clip duration in seconds, for the clip sampler.
    clip_duration = cfg.num_frames * cfg.sampling_rate / float(cfg.target_fps)
    train_items = read_csv_list(train_csv_path, cfg.path_prefix)
    val_items = read_csv_list(val_csv_path, cfg.path_prefix)

    train_sampler = _build_clip_sampler(
        cfg.train_sampling, clip_duration, cfg.train_clips_per_video, "train_sampling"
    )
    train_ds = LabeledVideoDataset(
        labeled_video_paths=train_items,
        clip_sampler=train_sampler,
        decode_audio=False,
        decoder=cfg.decoder,
        transform=VideoTransform(num_frames=cfg.num_frames, train=True),
    )
    val_sampler = _build_clip_sampler(
        cfg.val_sampling, clip_duration, cfg.val_clips_per_video, "val_sampling"
    )
    val_ds = LabeledVideoDataset(
        labeled_video_paths=val_items,
        clip_sampler=val_sampler,
        decode_audio=False,
        decoder=cfg.decoder,
        transform=VideoTransform(num_frames=cfg.num_frames, train=False),
    )

    # Settings shared by both loaders; only drop_last differs.
    loader_kwargs = dict(
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        pin_memory=(device == "cuda"),
        persistent_workers=cfg.persistent_workers,
        prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
    )
    train_loader = DataLoader(train_ds, drop_last=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, drop_last=False, **loader_kwargs)

    model = build_pretrained_x3d(cfg.model, cfg.pretrained)
    replace_last_linear(model, cfg.num_classes)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler, scheduler_needs_val = build_scheduler(optimizer, cfg)
    # torch.cuda.amp is deprecated in newer torch versions.
    scaler = torch.amp.GradScaler("cuda", enabled=cfg.amp) if device == "cuda" else torch.amp.GradScaler(enabled=False)

    best_val = -1.0

    # len() raises TypeError for loaders over datasets without a length;
    # in that case progress lines omit the total.
    try:
        n_train_batches_total = len(train_loader)
    except TypeError:
        n_train_batches_total = None
    try:
        n_val_batches_total = len(val_loader)
    except TypeError:
        n_val_batches_total = None

    def _batch_suffix(cur: int, total: int | None) -> str:
        return f"{cur}/{total}" if total is not None else str(cur)

    def _should_log(batch_idx: int) -> bool:
        if args.log_every_batch:
            return True
        return args.log_interval > 0 and (batch_idx % args.log_interval) == 0

    for epoch in range(1, cfg.epochs + 1):
        # ---------------- training ----------------
        model.train()
        t0 = time.time()
        train_loss = 0.0
        train_acc = 0.0
        n_batches = 0
        last_log_t = time.time()
        last_batch_end_t = time.time()
        avg_data_s = 0.0
        avg_iter_s = 0.0
        for videos, labels in train_loader:
            # Time spent waiting for the loader since the previous step ended.
            data_s = time.time() - last_batch_end_t
            iter_t0 = time.time()
            videos = videos.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            autocast_ctx = (
                torch.amp.autocast("cuda", enabled=cfg.amp) if device == "cuda" else contextlib.nullcontext()
            )
            with autocast_ctx:
                logits = model(videos)
                loss = criterion(logits, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Train batches all have the same size (drop_last=True), so a
            # plain per-batch mean equals the per-sample mean here.
            train_loss += loss.item()
            train_acc += accuracy_top1(logits.detach(), labels)
            n_batches += 1
            iter_s = time.time() - iter_t0
            last_batch_end_t = time.time()
            # EMA-ish moving average for stable logs.
            avg_data_s = 0.9 * avg_data_s + 0.1 * data_s if n_batches > 1 else data_s
            avg_iter_s = 0.9 * avg_iter_s + 0.1 * iter_s if n_batches > 1 else iter_s

            if _should_log(n_batches):
                dt_log = time.time() - last_log_t
                last_log_t = time.time()
                bsuf = _batch_suffix(n_batches, n_train_batches_total)
                if args.log_every_batch:
                    print(
                        f" [epoch {epoch:03d}/{cfg.epochs}] train batch {bsuf} | "
                        f"loss {loss.item():.4f} | "
                        f"run_avg_loss {(train_loss/n_batches):.4f} run_avg_acc {(train_acc/n_batches)*100:.2f}% | "
                        f"data {avg_data_s*1000:.0f}ms iter {avg_iter_s*1000:.0f}ms"
                    )
                else:
                    print(
                        f" [epoch {epoch:03d}] step {n_batches:05d} "
                        f"loss {(train_loss/n_batches):.4f} "
                        f"acc {(train_acc/n_batches)*100:.2f}% "
                        f"data {avg_data_s*1000:.0f}ms iter {avg_iter_s*1000:.0f}ms "
                        f"({dt_log:.1f}s/{args.log_interval} steps)"
                    )
        train_loss /= max(n_batches, 1)
        train_acc /= max(n_batches, 1)

        # ---------------- validation ----------------
        model.eval()
        val_loss_sum = 0.0  # sum of per-sample losses
        val_correct = 0  # correctly classified samples
        val_seen = 0  # samples processed
        n_val = 0  # batches processed (drives the log interval)
        last_val_log_t = time.time()
        with torch.no_grad():
            for videos, labels in val_loader:
                videos = videos.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                logits = model(videos)
                loss = criterion(logits, labels)

                # criterion returns the batch mean; scale back to a sum so a
                # short final batch is weighted by its real size.
                batch_size = labels.numel()
                val_loss_sum += loss.item() * batch_size
                val_correct += int((torch.argmax(logits, dim=1) == labels).sum().item())
                val_seen += batch_size
                n_val += 1

                if _should_log(n_val):
                    dt_log = time.time() - last_val_log_t
                    last_val_log_t = time.time()
                    vsuf = _batch_suffix(n_val, n_val_batches_total)
                    run_loss = val_loss_sum / max(val_seen, 1)
                    run_acc = val_correct / max(val_seen, 1)
                    if args.log_every_batch:
                        print(
                            f" [epoch {epoch:03d}/{cfg.epochs}] val batch {vsuf} | "
                            f"loss {loss.item():.4f} | "
                            f"run_avg_loss {run_loss:.4f} run_avg_acc {run_acc*100:.2f}%"
                        )
                    else:
                        print(
                            f" [epoch {epoch:03d}] val step {n_val:05d} "
                            f"loss {run_loss:.4f} "
                            f"acc {run_acc*100:.2f}% "
                            f"({dt_log:.1f}s/{args.log_interval} steps)"
                        )
        val_loss = val_loss_sum / max(val_seen, 1)
        val_acc = val_correct / max(val_seen, 1)

        dt = time.time() - t0
        print(
            f"Epoch {epoch:03d}/{cfg.epochs} | "
            f"train loss {train_loss:.4f} acc {train_acc*100:.2f}% | "
            f"val loss {val_loss:.4f} acc {val_acc*100:.2f}% | "
            f"lr {_get_lr(optimizer):.6g} | "
            f"time {dt:.1f}s"
        )

        # Step after logging so the printed LR is the one used this epoch.
        if scheduler is not None:
            if scheduler_needs_val:
                scheduler.step(val_loss)
            else:
                scheduler.step()

        checkpoint_parts = dict(
            epoch=epoch,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            scaler=scaler,
            cfg=cfg,
        )
        # Save last
        _save_checkpoint(os.path.join(cfg.output_dir, "checkpoint_last.pt"), **checkpoint_parts)
        # Save best
        if val_acc > best_val:
            best_val = val_acc
            _save_checkpoint(
                os.path.join(cfg.output_dir, "checkpoint_best.pt"),
                extra={"best_val_acc": best_val},
                **checkpoint_parts,
            )

    print(f"Done. Best val acc: {best_val*100:.2f}%")
if __name__ == "__main__":
main()