#!/usr/bin/env python3 """ Evaluate a trained PyTorchVideo X3D checkpoint as a 3-class classifier. This script loads a `checkpoint_best.pt` produced by `train_pytorchvideo_x3d.py` and runs inference on a CSV split (val.csv or test.csv), then reports 3-class accuracy and a 3x3 confusion matrix. CSV format (whitespace-separated): relative/path/to/video.mp4 Example: python test_pytorchvideo_x3d_3class.py \ --checkpoint /home/ubuntu/projects/FishAction/checkpoints/ptv_x3d_m/checkpoint_best.pt \ --csv_dir /home/ubuntu/data/fish/fish_action_videos \ --path_prefix /home/ubuntu/data/fish/fish_action_videos \ --split val.csv \ --batch_size 4 \ --num_workers 4 """ from __future__ import annotations import argparse import os import time import warnings from typing import List, Tuple import torch from torch.utils.data import DataLoader from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler from train_pytorchvideo_x3d import ( VideoTransform, _parse_path_label_line, build_pretrained_x3d, replace_last_linear, ) THREE_CLASS_NAMES = ["feeding", "normal", "scared"] def read_csv_list(csv_path: str, path_prefix: str) -> List[Tuple[str, Dict]]: items: List[Tuple[str, Dict]] = [] with open(csv_path, "r") as f: for line in f: parsed = _parse_path_label_line(line) if parsed is None: continue rel_path, label = parsed full = os.path.join(path_prefix, rel_path) items.append((full, {"label": int(label)})) return items def confusion_matrix_3(preds: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: """ preds/labels: int tensors of shape (N,) with values in [0,2] returns: (3,3) where rows=true, cols=pred """ cm = torch.zeros((3, 3), dtype=torch.long) for t, p in zip(labels.tolist(), preds.tolist()): cm[t][p] += 1 return cm def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--checkpoint", type=str, required=True, help="Path to checkpoint_best.pt") parser.add_argument("--csv_dir", type=str, required=True, help="Folder containing split CSVs (and label_map.txt if available)") parser.add_argument("--path_prefix", type=str, required=True, help="Absolute path prefix for videos") parser.add_argument("--split", type=str, default="val.csv", help="CSV filename inside csv_dir (e.g. val.csv or test.csv)") parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--num_workers", type=int, default=4) parser.add_argument( "--log_interval", type=int, default=20, help="Print progress every N batches (default: 20). Set 0 to disable.", ) parser.add_argument( "--max_batches", type=int, default=0, help="If >0, stop after this many batches (useful for quick sanity checks).", ) parser.add_argument("--device", type=str, default="", help="cuda|cpu (auto if empty)") args = parser.parse_args() device = args.device.strip() or ("cuda" if torch.cuda.is_available() else "cpu") csv_dir = os.path.abspath(os.path.expanduser(args.csv_dir)) path_prefix = os.path.abspath(os.path.expanduser(args.path_prefix)) split_path = os.path.join(csv_dir, args.split) if not os.path.isfile(split_path): raise FileNotFoundError(f"Split CSV not found: {split_path}") ckpt_path = os.path.abspath(os.path.expanduser(args.checkpoint)) # Torch 2.5+ warns about pickle by default. This is your own checkpoint, so we silence it. with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`.*") ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) cfg = ckpt.get("cfg", {}) model_name = cfg.get("model", "x3d_m") num_frames = int(cfg.get("num_frames", 16)) sampling_rate = int(cfg.get("sampling_rate", 5)) target_fps = int(cfg.get("target_fps", 30)) num_classes = int(cfg.get("num_classes", 3)) clip_duration = num_frames * sampling_rate / float(target_fps) # Build dataset items = read_csv_list(split_path, path_prefix) print(f"Loaded {len(items)} samples from {split_path}") print( f"Decoding clips with pyav (clip_duration={clip_duration:.2f}s, num_frames={num_frames}). " f"If this looks stuck, try --num_workers 0." ) ds = LabeledVideoDataset( labeled_video_paths=items, clip_sampler=make_clip_sampler("uniform", clip_duration), decode_audio=False, decoder="pyav", transform=VideoTransform(num_frames=num_frames, train=False), ) loader = DataLoader( ds, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=(device == "cuda"), persistent_workers=bool(args.num_workers > 0), prefetch_factor=2 if args.num_workers > 0 else None, ) # Build model architecture and load weights. model = build_pretrained_x3d(model_name=model_name, pretrained=False) # Ensure head shape matches checkpoint state_dict (trained as 3-class). replace_last_linear(model, num_classes=num_classes) model.load_state_dict(ckpt["model"], strict=True) model = model.to(device) model.eval() total = 0 correct3 = 0 cm = torch.zeros((3, 3), dtype=torch.long) with torch.no_grad(): t_last = time.time() for batch_idx, (videos, labels3) in enumerate(loader, start=1): videos = videos.to(device, non_blocking=True) labels3 = labels3.to(device, non_blocking=True) logits3 = model(videos) preds3 = torch.argmax(logits3, dim=1) total += labels3.numel() correct3 += (preds3 == labels3).sum().item() cm += confusion_matrix_3(preds3.cpu(), labels3.cpu()) if args.log_interval > 0 and (batch_idx % args.log_interval) == 0: dt = time.time() - t_last t_last = time.time() print( f"Processed batch {batch_idx:05d} " f"({total} clips) | acc so far {(correct3/max(total,1))*100:.2f}% " f"({dt:.1f}s/{args.log_interval} batches)" ) if args.max_batches and batch_idx >= args.max_batches: break acc3 = correct3 / max(total, 1) print(f"3-class accuracy: {acc3*100:.2f}% ({correct3}/{total})") print("3-class confusion matrix (rows=true, cols=pred) for [feeding, normal, scared]:") print(cm) if __name__ == "__main__": main()