#!/usr/bin/env python3
r"""
Train a video classifier using PyTorchVideo (X3D) on a CSV dataset.

This script is intentionally independent from SlowFast, so it works even if
SlowFast's optional deps/imports are broken.

Expected CSV format (whitespace-separated, like SlowFast/Kinetics), one video per line:
  relative/path/to/video.mp4 label

Example:
  python train_pytorchvideo_x3d.py \
    --csv_dir /home/ubuntu/projects/FishAction/data/fish/fish_action_training_dataset \
    --path_prefix /home/ubuntu/data/fish/fish_action_videos \
    --model x3d_m \
    --pretrained \
    --batch_size 4 \
    --epochs 30 \
    --num_workers 4 \
    --output_dir /home/ubuntu/projects/FishAction/checkpoints/ptv_x3d_m
"""

from __future__ import annotations

import argparse
import contextlib
import json
import os
import random
import time
from math import cos, pi
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# The pure helpers below (CSV parsing, tensor transforms, scheduler) do not need
# the video backend, so a missing/broken pytorchvideo install is only reported
# once main() actually needs it.
try:
    from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler

    _PYTORCHVIDEO_IMPORT_ERROR = None
except ImportError as import_error:
    LabeledVideoDataset = None
    make_clip_sampler = None
    _PYTORCHVIDEO_IMPORT_ERROR = import_error


KINETICS_MEAN = (0.45, 0.45, 0.45)
KINETICS_STD = (0.225, 0.225, 0.225)


def set_seed(seed: int) -> None:
    """Seed Python's `random` and torch (CPU and all CUDA devices)."""
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _parse_path_label_line(line: str) -> Tuple[str, int] | None:
    """
    Robust parser for CSV lines.

    Expected (whitespace-separated):
      path label

    If the path contains spaces, the last token is taken as the label:
      /some/path/with spaces/video.mp4 3

    Returns:
        (path, label), or None for blank lines, `#` comments, lines with fewer
        than two tokens, and lines whose last token is not an integer
        (e.g. a "path label" header row).
    """
    line = line.strip()
    if not line or line.startswith("#"):
        return None
    # Split only once from the right so whitespace inside the path is kept verbatim.
    parts = line.rsplit(maxsplit=1)
    if len(parts) < 2:
        return None
    path, label_str = parts
    try:
        label = int(label_str)
    except ValueError:
        # likely a header row, e.g. "path label"
        return None
    return path, label


def read_csv_list(csv_path: str, path_prefix: str) -> List[Tuple[str, Dict]]:
    """
    Read a "path label" list file.

    Returns:
        A list of (os.path.join(path_prefix, path), {"label": int}) tuples, in
        file order. Lines rejected by `_parse_path_label_line` are skipped.
    """
    items: List[Tuple[str, Dict]] = []
    with open(csv_path, "r", encoding="utf-8") as f:
        for line in f:
            parsed = _parse_path_label_line(line)
            if parsed is None:
                continue
            rel_path, label = parsed
            items.append((os.path.join(path_prefix, rel_path), {"label": int(label)}))
    return items


def infer_num_classes(csv_path: str) -> int:
    """Return (max label + 1) over all parsable rows of `csv_path`; 0 if there are none."""
    max_label = -1
    with open(csv_path, "r", encoding="utf-8") as f:
        for line in f:
            parsed = _parse_path_label_line(line)
            if parsed is None:
                continue
            max_label = max(max_label, int(parsed[1]))
    return max_label + 1


def uniform_temporal_subsample_cthw(video: torch.Tensor, num_frames: int) -> torch.Tensor:
    """
    Pick `num_frames` evenly spaced frames along dim 1 of a C T H W tensor.

    Frames are repeated when the clip has fewer than `num_frames` frames.

    Raises:
        ValueError: if `num_frames` < 1, the clip has no frames, or `video` is not 4D.
    """
    _, t, _, _ = video.shape
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")
    if t < 1:
        raise ValueError(f"Invalid temporal dimension t={t}")
    if t == num_frames:
        return video
    idx = torch.linspace(0, t - 1, steps=num_frames, device=video.device).long()
    return torch.index_select(video, dim=1, index=idx)


def to_cthw(video: torch.Tensor) -> torch.Tensor:
    """
    Convert common video tensor layouts to C T H W.

    Supported:
      - T H W C
      - T C H W
      - C T H W

    The channel axis is located by looking for a dimension of size 3 (RGB);
    the checks run in the order C T H W, T H W C, T C H W.

    Raises:
        ValueError: if the tensor is not 4D or no candidate axis has size 3.
    """
    if video.ndim != 4:
        raise ValueError(f"Expected 4D video tensor, got shape {tuple(video.shape)}")

    # Heuristics for channel location (3 channels RGB).
    if video.shape[0] == 3:
        # C T H W
        return video
    if video.shape[-1] == 3:
        # T H W C -> C T H W
        return video.permute(3, 0, 1, 2)
    if video.shape[1] == 3:
        # T C H W -> C T H W
        return video.permute(1, 0, 2, 3)
    raise ValueError(f"Could not infer channel dim for video shape {tuple(video.shape)}")


def resize_short_side_tchw(frames_tchw: torch.Tensor, short_side: int) -> torch.Tensor:
    """Bilinearly resize T C H W frames so the shorter spatial side equals `short_side`."""
    _, _, h, w = frames_tchw.shape
    if min(h, w) == short_side:
        return frames_tchw
    if h < w:
        new_h = short_side
        new_w = int(round(w * short_side / h))
    else:
        new_w = short_side
        new_h = int(round(h * short_side / w))
    return F.interpolate(frames_tchw, size=(new_h, new_w), mode="bilinear", align_corners=False)


def _upscale_to_min_size_tchw(frames_tchw: torch.Tensor, min_size: int) -> torch.Tensor:
    """Stretch any spatial side smaller than `min_size` up to `min_size` (no-op otherwise)."""
    _, _, h, w = frames_tchw.shape
    if h >= min_size and w >= min_size:
        return frames_tchw
    return F.interpolate(
        frames_tchw,
        size=(max(h, min_size), max(w, min_size)),
        mode="bilinear",
        align_corners=False,
    )


def random_crop_tchw(frames_tchw: torch.Tensor, crop_size: int) -> torch.Tensor:
    """
    Take a random `crop_size` x `crop_size` spatial crop from T C H W frames.

    The same window is used for every frame. Frames smaller than the crop are
    first stretched so that each side is at least `crop_size`.
    """
    frames_tchw = _upscale_to_min_size_tchw(frames_tchw, crop_size)
    _, _, h, w = frames_tchw.shape
    top = random.randint(0, h - crop_size) if h > crop_size else 0
    left = random.randint(0, w - crop_size) if w > crop_size else 0
    return frames_tchw[:, :, top : top + crop_size, left : left + crop_size]


def center_crop_tchw(frames_tchw: torch.Tensor, crop_size: int) -> torch.Tensor:
    """
    Take the central `crop_size` x `crop_size` spatial crop from T C H W frames.

    Frames smaller than the crop are first stretched so that each side is at
    least `crop_size`.
    """
    frames_tchw = _upscale_to_min_size_tchw(frames_tchw, crop_size)
    _, _, h, w = frames_tchw.shape
    top = max((h - crop_size) // 2, 0)
    left = max((w - crop_size) // 2, 0)
    return frames_tchw[:, :, top : top + crop_size, left : left + crop_size]


def normalize_cthw(video_cthw: torch.Tensor, mean=KINETICS_MEAN, std=KINETICS_STD) -> torch.Tensor:
    """Apply per-channel (x - mean) / std to a 3-channel C T H W tensor."""
    m = torch.tensor(mean, device=video_cthw.device, dtype=video_cthw.dtype).view(3, 1, 1, 1)
    s = torch.tensor(std, device=video_cthw.device, dtype=video_cthw.dtype).view(3, 1, 1, 1)
    return (video_cthw - m) / s


class VideoTransform:
    """
    Per-clip preprocessing: layout fix, scaling to [0, 1], temporal subsample,
    short-side resize, crop (random + optional horizontal flip when training,
    center otherwise) and Kinetics mean/std normalisation.

    Called with a sample dict holding "video" and "label"; returns
    (video as C T H W float tensor, label as a 0-dim long tensor).
    """

    def __init__(
        self,
        num_frames: int,
        train: bool,
        short_side: int = 256,
        crop_size: int = 224,
        random_flip: bool = True,
    ):
        self.num_frames = num_frames
        self.train = train
        self.short_side = short_side
        self.crop_size = crop_size
        self.random_flip = random_flip

    def __call__(self, sample: Dict):
        video = sample["video"]
        label_obj = sample.get("label", 0)
        if isinstance(label_obj, dict):
            label = int(label_obj.get("label", 0))
        else:
            label = int(label_obj)

        video = to_cthw(video)
        if video.dtype == torch.uint8:
            video = video.float() / 255.0
        else:
            video = video.float()
            # Float input may still be in the 0..255 range (PyTorchVideo's decoders
            # emit float32 pixel values without rescaling). The mean/std below
            # assume [0, 1], so rescale whenever the clip exceeds that range.
            if video.numel() > 0 and float(video.max()) > 1.0:
                video = video / 255.0

        video = uniform_temporal_subsample_cthw(video, self.num_frames)  # C T H W

        # Spatial ops on T C H W.
        frames_tchw = video.permute(1, 0, 2, 3)
        frames_tchw = resize_short_side_tchw(frames_tchw, self.short_side)
        if self.train:
            frames_tchw = random_crop_tchw(frames_tchw, self.crop_size)
            if self.random_flip and random.random() < 0.5:
                frames_tchw = torch.flip(frames_tchw, dims=[3])
        else:
            frames_tchw = center_crop_tchw(frames_tchw, self.crop_size)

        video = frames_tchw.permute(1, 0, 2, 3)  # back to C T H W
        video = normalize_cthw(video)
        return video, torch.tensor(label, dtype=torch.long)


def build_pretrained_x3d(model_name: str, pretrained: bool) -> nn.Module:
    """
    Build an X3D network.

    Args:
        model_name: x3d_xs | x3d_s | x3d_m | x3d_l
        pretrained: load pretrained weights through torch.hub when True;
            otherwise build the same architecture with random weights.

    Raises:
        ValueError: if `pretrained` is False and `model_name` is not a builder
            exposed by `pytorchvideo.models.hub`.
    """
    if pretrained:
        # Uses torch hub; downloads pytorchvideo main repo + weights once into ~/.cache/torch/hub/
        return torch.hub.load("facebookresearch/pytorchvideo", model_name, pretrained=True)

    # No pretrained: use the per-variant builder from the installed package so
    # that `model_name` selects the architecture instead of being ignored.
    from pytorchvideo.models import hub as ptv_hub

    builder = getattr(ptv_hub, model_name, None)
    if builder is None:
        raise ValueError(f"Unknown model name: {model_name}")
    return builder(pretrained=False)


def replace_last_linear(model: nn.Module, num_classes: int) -> None:
    """
    Replace, in place, the last `nn.Linear` found by `named_modules()` with a
    fresh `nn.Linear(in_features, num_classes)`.

    Raises:
        RuntimeError: if the model contains no `nn.Linear`.
    """
    # X3D hub model ends with blocks.5.proj (Linear out_features=400).
    # We replace the last Linear module found.
    last_name = None
    last_linear = None
    for name, mod in model.named_modules():
        if isinstance(mod, nn.Linear):
            last_name = name
            last_linear = mod
    if last_name is None or last_linear is None:
        raise RuntimeError("Could not find a Linear layer to replace for classification head.")

    new_linear = nn.Linear(last_linear.in_features, num_classes)

    # Set it by walking the module path.
    parts = last_name.split(".")
    parent = model
    for p in parts[:-1]:
        parent = getattr(parent, p)
    setattr(parent, parts[-1], new_linear)


@dataclass
class TrainConfig:
    # Resolved run configuration; main() dumps it to config.json and stores it in checkpoints.
    csv_dir: str
    path_prefix: str
    train_csv: str
    val_csv: str
    model: str
    pretrained: bool
    num_classes: int
    num_frames: int
    sampling_rate: int
    target_fps: int
    batch_size: int
    epochs: int
    lr: float
    weight_decay: float
    scheduler: str
    warmup_epochs: int
    min_lr: float
    step_size: int
    gamma: float
    plateau_factor: float
    plateau_patience: int
    num_workers: int
    prefetch_factor: int
    persistent_workers: bool
    seed: int
    output_dir: str
    amp: bool
    device: str
    decoder: str
    train_sampling: str
    train_clips_per_video: int
    val_sampling: str
    val_clips_per_video: int
    tf32: bool


def accuracy_top1(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the fraction of rows whose argmax over dim 1 equals the label (0.0 for an empty batch)."""
    total = labels.numel()
    if total == 0:
        return 0.0
    preds = torch.argmax(logits, dim=1)
    correct = (preds == labels).sum().item()
    return correct / total


def _get_lr(optimizer: torch.optim.Optimizer) -> float:
    """Return the learning rate of the optimizer's first param group."""
    return float(optimizer.param_groups[0]["lr"])


def build_scheduler(
    optimizer: torch.optim.Optimizer, cfg: TrainConfig
) -> tuple[torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.ReduceLROnPlateau | None, bool]:
    """
    Build the LR scheduler selected by `cfg.scheduler`.

    Returns:
      (scheduler, needs_val_metric)
      - "none": (None, False)
      - "plateau": ReduceLROnPlateau in "min" mode; step it with the val loss.
      - "cosine"/"step": a LambdaLR with optional linear warmup, stepped once per epoch.
    """
    if cfg.scheduler == "none":
        return None, False

    if cfg.scheduler == "plateau":
        sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=cfg.plateau_factor,
            patience=cfg.plateau_patience,
        )
        return sched, True

    # For cosine/step we implement warmup + schedule as a single LambdaLR, stepped once per epoch.
    warmup = max(int(cfg.warmup_epochs), 0)
    total = max(int(cfg.epochs), 1)
    # LambdaLR works with multiplicative factors, so express min_lr relative to the base lr.
    min_lr_factor = float(cfg.min_lr / cfg.lr) if cfg.lr > 0 else 0.0
    min_lr_factor = max(0.0, min(1.0, min_lr_factor))

    def lr_lambda(epoch_idx0: int) -> float:
        # epoch_idx0 is 0-based.
        e = int(epoch_idx0)
        if warmup > 0 and e < warmup:
            return float(e + 1) / float(warmup)

        if cfg.scheduler == "step":
            e2 = e - warmup
            if cfg.step_size <= 0:
                return 1.0
            n_steps = max(e2 // int(cfg.step_size), 0)
            return float(cfg.gamma) ** float(n_steps)

        if cfg.scheduler == "cosine":
            denom = max(total - warmup, 1)
            t = (e - warmup) / float(denom)  # 0..1-ish
            t = min(max(t, 0.0), 1.0)
            cosine_factor = 0.5 * (1.0 + cos(pi * t))
            return min_lr_factor + (1.0 - min_lr_factor) * cosine_factor

        return 1.0

    sched = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    return sched, False


def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--csv_dir", type=str, required=True, help="Folder containing train.csv/val.csv")
    parser.add_argument("--path_prefix", type=str, required=True, help="Absolute path prefix for videos")
    parser.add_argument("--train_csv", type=str, default="train.csv")
    parser.add_argument("--val_csv", type=str, default="val.csv")

    parser.add_argument("--model", type=str, default="x3d_m", choices=["x3d_xs", "x3d_s", "x3d_m", "x3d_l"])
    parser.add_argument("--pretrained", action="store_true", help="Use torch.hub pretrained weights")
    parser.add_argument("--num_classes", type=int, default=0, help="If 0, infer from train.csv max label + 1")

    parser.add_argument("--num_frames", type=int, default=16)
    parser.add_argument("--sampling_rate", type=int, default=5, help="Used only for clip_duration calculation")
    parser.add_argument("--target_fps", type=int, default=30, help="Used only for clip_duration calculation")

    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument(
        "--scheduler",
        type=str,
        default="none",
        choices=["none", "cosine", "step", "plateau"],
        help=(
            "LR scheduler. cosine/step are stepped once per epoch. "
            "plateau monitors val loss and reduces LR when it stops improving."
        ),
    )
    parser.add_argument(
        "--warmup_epochs",
        type=int,
        default=0,
        help="Warm up LR linearly for N epochs (cosine/step).",
    )
    parser.add_argument("--min_lr", type=float, default=1e-5, help="Cosine schedule minimum LR (absolute).")
    parser.add_argument("--step_size", type=int, default=10, help="Step schedule: decay every N epochs (after warmup).")
    parser.add_argument("--gamma", type=float, default=0.1, help="Step schedule: multiplicative decay factor.")
    parser.add_argument(
        "--plateau_factor",
        type=float,
        default=0.1,
        help="Plateau schedule: LR *= factor when val loss plateaus.",
    )
    parser.add_argument(
        "--plateau_patience",
        type=int,
        default=2,
        help="Plateau schedule: epochs to wait for val loss improvement before decaying.",
    )
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument(
        "--prefetch_factor",
        type=int,
        default=2,
        help="DataLoader prefetch factor per worker (only used when num_workers > 0).",
    )
    parser.add_argument(
        "--persistent_workers",
        action="store_true",
        help="Keep DataLoader workers alive across epochs (faster, uses more RAM; only if num_workers > 0).",
    )
    parser.add_argument("--seed", type=int, default=42)

    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--amp", action="store_true", help="Use mixed precision (CUDA only)")
    parser.add_argument(
        "--tf32",
        action="store_true",
        help="Allow TF32 matmul/conv on Ampere+ (CUDA only). Can speed up training a bit.",
    )
    parser.add_argument(
        "--decoder",
        type=str,
        default="pyav",
        help="Video decoder backend for PyTorchVideo LabeledVideoDataset (common: pyav, decord if installed).",
    )
    parser.add_argument(
        "--train_sampling",
        type=str,
        default="random",
        choices=["random", "constant_clips_per_video", "uniform"],
        help=(
            "Training clip sampling. "
            "'random' uses 1 random clip per video per epoch. "
            "'constant_clips_per_video' uses N evenly spaced clips per video per epoch. "
            "'uniform' splits videos into sequential clips (long videos yield more clips)."
        ),
    )
    parser.add_argument(
        "--train_clips_per_video",
        type=int,
        default=1,
        help="Used only when --train_sampling=constant_clips_per_video.",
    )
    parser.add_argument(
        "--val_sampling",
        type=str,
        default="uniform",
        choices=["uniform", "random", "constant_clips_per_video"],
        help=(
            "Validation clip sampling. 'uniform' may produce many clips for long videos. "
            "'random' uses 1 random clip per video. "
            "'constant_clips_per_video' uses N evenly spaced clips per video."
        ),
    )
    parser.add_argument(
        "--val_clips_per_video",
        type=int,
        default=10,
        help="Used only when --val_sampling=constant_clips_per_video.",
    )
    parser.add_argument(
        "--log_interval",
        type=int,
        default=20,
        help="Print training/val progress every N batches (default: 20). Use 1 for every batch, or --log_every_batch.",
    )
    parser.add_argument(
        "--log_every_batch",
        action="store_true",
        help="Print one line after every training and validation batch (includes batch/total when known).",
    )
    return parser


def _validate_items(items: List[Tuple[str, Dict]], num_classes: int, csv_path: str) -> None:
    """Fail early (with a readable message) on an empty list or labels outside [0, num_classes)."""
    if not items:
        raise ValueError(f"No 'path label' rows found in {csv_path}")
    for video_path, info in items:
        label = info["label"]
        if not 0 <= label < num_classes:
            raise ValueError(
                f"Label {label} for {video_path} in {csv_path} is outside [0, {num_classes - 1}]; "
                f"fix the CSV or pass a matching --num_classes."
            )


def _build_clip_sampler(option_name: str, sampling: str, clip_duration: float, clips_per_video: int):
    """Create the PyTorchVideo clip sampler for one of the --*_sampling choices."""
    if sampling in ("random", "uniform"):
        return make_clip_sampler(sampling, clip_duration)
    if sampling == "constant_clips_per_video":
        return make_clip_sampler("constant_clips_per_video", clip_duration, clips_per_video)
    raise ValueError(f"Unknown {option_name}: {sampling}")


def _build_loader(dataset, cfg: TrainConfig, drop_last: bool) -> DataLoader:
    """Wrap `dataset` in a DataLoader configured from `cfg`."""
    worker_kwargs = {}
    if cfg.num_workers > 0:
        # prefetch_factor is only meaningful (and only accepted by every torch
        # version) when worker processes are used.
        worker_kwargs["prefetch_factor"] = cfg.prefetch_factor
    return DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        pin_memory=(cfg.device == "cuda"),
        drop_last=drop_last,
        persistent_workers=cfg.persistent_workers,
        **worker_kwargs,
    )


def _loader_len(loader: DataLoader) -> int | None:
    """Return len(loader), or None when the underlying dataset has no length."""
    try:
        return len(loader)
    except TypeError:
        return None


def _batch_suffix(cur: int, total: int | None) -> str:
    """Format a batch counter as "cur/total", or just "cur" when the total is unknown."""
    return f"{cur}/{total}" if total is not None else str(cur)


def _should_log(step: int, log_interval: int, log_every_batch: bool) -> bool:
    """Decide whether progress should be printed after batch number `step` (1-based)."""
    if log_every_batch:
        return True
    return log_interval > 0 and (step % log_interval) == 0


def _train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler,
    cfg: TrainConfig,
    epoch: int,
    n_batches_total: int | None,
    log_interval: int,
    log_every_batch: bool,
) -> Tuple[float, float]:
    """
    Run one optimisation pass over `loader`.

    Returns:
        (mean loss, top-1 accuracy), both averaged over samples.
    """
    device = cfg.device
    model.train()

    loss_sum = 0.0
    acc_sum = 0.0
    n_samples = 0
    n_batches = 0
    last_log_t = time.time()
    last_batch_end_t = time.time()
    avg_data_s = 0.0
    avg_iter_s = 0.0

    for videos, labels in loader:
        data_s = time.time() - last_batch_end_t
        iter_t0 = time.time()
        videos = videos.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        autocast_ctx = (
            torch.amp.autocast("cuda", enabled=cfg.amp) if device == "cuda" else contextlib.nullcontext()
        )
        with autocast_ctx:
            logits = model(videos)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight by batch size so the epoch numbers are per-sample averages.
        batch_size = labels.numel()
        batch_loss = loss.item()
        loss_sum += batch_loss * batch_size
        acc_sum += accuracy_top1(logits.detach(), labels) * batch_size
        n_samples += batch_size
        n_batches += 1

        iter_s = time.time() - iter_t0
        last_batch_end_t = time.time()
        # EMA-ish moving average for stable logs.
        avg_data_s = 0.9 * avg_data_s + 0.1 * data_s if n_batches > 1 else data_s
        avg_iter_s = 0.9 * avg_iter_s + 0.1 * iter_s if n_batches > 1 else iter_s

        if _should_log(n_batches, log_interval, log_every_batch):
            dt_log = time.time() - last_log_t
            last_log_t = time.time()
            run_loss = loss_sum / max(n_samples, 1)
            run_acc = acc_sum / max(n_samples, 1)
            if log_every_batch:
                bsuf = _batch_suffix(n_batches, n_batches_total)
                print(
                    f" [epoch {epoch:03d}/{cfg.epochs}] train batch {bsuf} | "
                    f"loss {batch_loss:.4f} | "
                    f"run_avg_loss {run_loss:.4f} run_avg_acc {run_acc*100:.2f}% | "
                    f"data {avg_data_s*1000:.0f}ms iter {avg_iter_s*1000:.0f}ms"
                )
            else:
                print(
                    f" [epoch {epoch:03d}] step {n_batches:05d} "
                    f"loss {run_loss:.4f} "
                    f"acc {run_acc*100:.2f}% "
                    f"data {avg_data_s*1000:.0f}ms iter {avg_iter_s*1000:.0f}ms "
                    f"({dt_log:.1f}s/{log_interval} steps)"
                )

    return loss_sum / max(n_samples, 1), acc_sum / max(n_samples, 1)


def _evaluate(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    cfg: TrainConfig,
    epoch: int,
    n_batches_total: int | None,
    log_interval: int,
    log_every_batch: bool,
) -> Tuple[float, float]:
    """
    Evaluate `model` on `loader` without gradients.

    Returns:
        (mean loss, top-1 accuracy), both averaged over samples so a smaller
        final batch does not get extra weight.
    """
    device = cfg.device
    model.eval()

    loss_sum = 0.0
    acc_sum = 0.0
    n_samples = 0
    n_val = 0
    last_val_log_t = time.time()

    with torch.no_grad():
        for videos, labels in loader:
            videos = videos.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            logits = model(videos)
            loss = criterion(logits, labels)

            batch_size = labels.numel()
            batch_loss = loss.item()
            loss_sum += batch_loss * batch_size
            acc_sum += accuracy_top1(logits, labels) * batch_size
            n_samples += batch_size
            n_val += 1

            if _should_log(n_val, log_interval, log_every_batch):
                dt_log = time.time() - last_val_log_t
                last_val_log_t = time.time()
                run_loss = loss_sum / max(n_samples, 1)
                run_acc = acc_sum / max(n_samples, 1)
                if log_every_batch:
                    vsuf = _batch_suffix(n_val, n_batches_total)
                    print(
                        f" [epoch {epoch:03d}/{cfg.epochs}] val batch {vsuf} | "
                        f"loss {batch_loss:.4f} | "
                        f"run_avg_loss {run_loss:.4f} run_avg_acc {run_acc*100:.2f}%"
                    )
                else:
                    print(
                        f" [epoch {epoch:03d}] val step {n_val:05d} "
                        f"loss {run_loss:.4f} "
                        f"acc {run_acc*100:.2f}% "
                        f"({dt_log:.1f}s/{log_interval} steps)"
                    )

    return loss_sum / max(n_samples, 1), acc_sum / max(n_samples, 1)


def main() -> None:
    """Parse CLI arguments, build data/model/optimizer, then train and checkpoint."""
    args = build_arg_parser().parse_args()

    if LabeledVideoDataset is None or make_clip_sampler is None:
        raise ImportError(
            "pytorchvideo is required for training but could not be imported."
        ) from _PYTORCHVIDEO_IMPORT_ERROR

    device = "cuda" if torch.cuda.is_available() else "cpu"
    amp = bool(args.amp and device == "cuda")
    tf32 = bool(args.tf32 and device == "cuda")
    if tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        # Not available on older torch releases; the flags above are enough there.
        if hasattr(torch, "set_float32_matmul_precision"):
            torch.set_float32_matmul_precision("high")

    csv_dir = os.path.abspath(os.path.expanduser(args.csv_dir))
    path_prefix = os.path.abspath(os.path.expanduser(args.path_prefix))
    train_csv_path = os.path.join(csv_dir, args.train_csv)
    val_csv_path = os.path.join(csv_dir, args.val_csv)
    if not os.path.isfile(train_csv_path):
        raise FileNotFoundError(
            f"train csv not found: {train_csv_path}\n"
            f"Tip: --csv_dir should be the folder that contains train.csv/val.csv, "
            f"NOT the raw video folder."
        )
    if not os.path.isfile(val_csv_path):
        raise FileNotFoundError(f"val csv not found: {val_csv_path}")

    if args.num_classes and args.num_classes > 0:
        num_classes = args.num_classes
    else:
        num_classes = infer_num_classes(train_csv_path)
    if num_classes <= 0:
        raise ValueError(f"Could not infer a positive number of classes from {train_csv_path}")

    cfg = TrainConfig(
        csv_dir=csv_dir,
        path_prefix=path_prefix,
        train_csv=args.train_csv,
        val_csv=args.val_csv,
        model=args.model,
        pretrained=bool(args.pretrained),
        num_classes=num_classes,
        num_frames=args.num_frames,
        sampling_rate=args.sampling_rate,
        target_fps=args.target_fps,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        scheduler=str(args.scheduler),
        warmup_epochs=int(args.warmup_epochs),
        min_lr=float(args.min_lr),
        step_size=int(args.step_size),
        gamma=float(args.gamma),
        plateau_factor=float(args.plateau_factor),
        plateau_patience=int(args.plateau_patience),
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch_factor,
        persistent_workers=bool(args.persistent_workers and args.num_workers > 0),
        seed=args.seed,
        output_dir=os.path.abspath(os.path.expanduser(args.output_dir)),
        amp=amp,
        device=device,
        decoder=str(args.decoder),
        train_sampling=str(args.train_sampling),
        train_clips_per_video=int(args.train_clips_per_video),
        val_sampling=str(args.val_sampling),
        val_clips_per_video=int(args.val_clips_per_video),
        tf32=tf32,
    )

    os.makedirs(cfg.output_dir, exist_ok=True)
    with open(os.path.join(cfg.output_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(asdict(cfg), f, indent=2)

    set_seed(cfg.seed)
    torch.backends.cudnn.benchmark = True

    # Clip duration in seconds, for the clip sampler.
    clip_duration = cfg.num_frames * cfg.sampling_rate / float(cfg.target_fps)

    train_items = read_csv_list(train_csv_path, cfg.path_prefix)
    val_items = read_csv_list(val_csv_path, cfg.path_prefix)
    # Catch bad labels here rather than as an opaque loss/CUDA error mid-epoch.
    _validate_items(train_items, cfg.num_classes, train_csv_path)
    _validate_items(val_items, cfg.num_classes, val_csv_path)

    train_sampler = _build_clip_sampler(
        "train_sampling", cfg.train_sampling, clip_duration, cfg.train_clips_per_video
    )
    train_ds = LabeledVideoDataset(
        labeled_video_paths=train_items,
        clip_sampler=train_sampler,
        decode_audio=False,
        decoder=cfg.decoder,
        transform=VideoTransform(num_frames=cfg.num_frames, train=True),
    )

    val_sampler = _build_clip_sampler(
        "val_sampling", cfg.val_sampling, clip_duration, cfg.val_clips_per_video
    )
    val_ds = LabeledVideoDataset(
        labeled_video_paths=val_items,
        clip_sampler=val_sampler,
        decode_audio=False,
        decoder=cfg.decoder,
        transform=VideoTransform(num_frames=cfg.num_frames, train=False),
    )

    train_loader = _build_loader(train_ds, cfg, drop_last=True)
    val_loader = _build_loader(val_ds, cfg, drop_last=False)

    model = build_pretrained_x3d(cfg.model, cfg.pretrained)
    replace_last_linear(model, cfg.num_classes)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler, scheduler_needs_val = build_scheduler(optimizer, cfg)

    # torch.cuda.amp is deprecated in newer torch versions.
    scaler = torch.amp.GradScaler("cuda", enabled=cfg.amp) if device == "cuda" else torch.amp.GradScaler(enabled=False)

    best_val = -1.0
    n_train_batches_total = _loader_len(train_loader)
    n_val_batches_total = _loader_len(val_loader)

    for epoch in range(1, cfg.epochs + 1):
        t0 = time.time()
        train_loss, train_acc = _train_one_epoch(
            model,
            train_loader,
            criterion,
            optimizer,
            scaler,
            cfg,
            epoch,
            n_train_batches_total,
            args.log_interval,
            args.log_every_batch,
        )
        val_loss, val_acc = _evaluate(
            model,
            val_loader,
            criterion,
            cfg,
            epoch,
            n_val_batches_total,
            args.log_interval,
            args.log_every_batch,
        )

        dt = time.time() - t0
        print(
            f"Epoch {epoch:03d}/{cfg.epochs} | "
            f"train loss {train_loss:.4f} acc {train_acc*100:.2f}% | "
            f"val loss {val_loss:.4f} acc {val_acc*100:.2f}% | "
            f"lr {_get_lr(optimizer):.6g} | "
            f"time {dt:.1f}s"
        )

        if scheduler is not None:
            if scheduler_needs_val:
                scheduler.step(val_loss)
            else:
                scheduler.step()

        checkpoint = {
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict() if scheduler is not None else None,
            "scaler": scaler.state_dict(),
            "cfg": asdict(cfg),
        }

        # Save last
        torch.save(checkpoint, os.path.join(cfg.output_dir, "checkpoint_last.pt"))

        # Save best
        if val_acc > best_val:
            best_val = val_acc
            torch.save(
                {**checkpoint, "best_val_acc": best_val},
                os.path.join(cfg.output_dir, "checkpoint_best.pt"),
            )

    print(f"Done. Best val acc: {best_val*100:.2f}%")


if __name__ == "__main__":
    main()