#!/usr/bin/env python3 """ Video-level prediction by aggregating multiple clip predictions (PyTorchVideo X3D). This script supports: - **Native 3-class checkpoints** (no mapping/merging required) - **Legacy 5-class checkpoints** evaluated as 3-class via merging: - feeding -> feeding - normal_underwater + normal_upperwater -> normal - scared_underwater + scared_upperwater -> scared Instead of printing per-clip labels, this script: - samples N clips per video - runs inference per clip - aggregates probabilities across clips (mean) - outputs one label per video - prints progress: clips processed, videos finished Example: python predict_video_x3d_3class.py \ --checkpoint /home/ubuntu/projects/FishAction/checkpoints/ptv_x3d_m/checkpoint_best.pt \ --csv /home/ubuntu/data/fish/fish_action_videos/val.csv \ --path_prefix /home/ubuntu/data/fish/fish_action_videos \ --clips_per_video 10 \ --batch_size 8 \ --num_workers 4 """ from __future__ import annotations import argparse import json import os import time import warnings from collections import defaultdict from typing import Dict, List, Tuple import torch import torch.nn.functional as F from torch.utils.data import DataLoader from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler from train_pytorchvideo_x3d import VideoTransform, _parse_path_label_line, build_pretrained_x3d, replace_last_linear THREE_CLASS_NAMES = ["feeding", "normal", "scared"] def read_label_map(label_map_path: str) -> Dict[str, int]: out: Dict[str, int] = {} with open(label_map_path, "r") as f: for line in f: line = line.strip() if not line: continue idx_str, name = line.split(maxsplit=1) out[name] = int(idx_str) return out def build_merge_indices(label_map_dir: str) -> Dict[str, List[int]]: """ 5->3 merge map by indices. Uses label_map.txt if present, else defaults: 0 feeding 1 normal_underwater 2 normal_upperwater 3 scared_underwater 4 scared_upperwater """ label_map_path = os.path.join(label_map_dir, "label_map.txt") if os.path.isfile(label_map_path): lm = read_label_map(label_map_path) # If label_map.txt already defines 3-class names, treat as identity merge. if all(k in lm for k in ["feeding", "normal", "scared"]): # This supports native 3-class datasets/checkpoints. return {"feeding": [lm["feeding"]], "normal": [lm["normal"]], "scared": [lm["scared"]]} return { "feeding": [lm["feeding"]], "normal": [lm["normal_underwater"], lm["normal_upperwater"]], "scared": [lm["scared_underwater"], lm["scared_upperwater"]], } return {"feeding": [0], "normal": [1, 2], "scared": [3, 4]} def collapse_logits_5_to_3(logits_5: torch.Tensor, merge: Dict[str, List[int]]) -> torch.Tensor: probs = F.softmax(logits_5, dim=1) p0 = probs[:, merge["feeding"]].sum(dim=1) p1 = probs[:, merge["normal"]].sum(dim=1) p2 = probs[:, merge["scared"]].sum(dim=1) probs3 = torch.stack([p0, p1, p2], dim=1).clamp_min(1e-12) return probs3 def map_label_5_to_3(labels_5: torch.Tensor, merge: Dict[str, List[int]]) -> torch.Tensor: inv: Dict[int, int] = {} for new_idx, name in enumerate(THREE_CLASS_NAMES): for old in merge[name]: inv[int(old)] = new_idx return torch.tensor([inv[int(x)] for x in labels_5.tolist()], device=labels_5.device, dtype=torch.long) class EvalTransform: """ Wrap VideoTransform but keep metadata needed for video-level aggregation. Output is a dict so DataLoader collates it nicely. """ def __init__(self, num_frames: int): self._vt = VideoTransform(num_frames=num_frames, train=False) def __call__(self, sample: Dict): video, label = self._vt(sample) # (C,T,H,W), int tensor return { "video": video, "label": label, "video_name": sample.get("video_name", ""), "video_index": int(sample.get("video_index", -1)), "clip_index": int(sample.get("clip_index", -1)), } def confusion_matrix_3(preds: List[int], labels: List[int]) -> List[List[int]]: cm = [[0, 0, 0] for _ in range(3)] for t, p in zip(labels, preds): cm[t][p] += 1 return cm def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--checkpoint", type=str, required=True) parser.add_argument("--csv", type=str, required=True, help="CSV split file (e.g. val.csv)") parser.add_argument("--path_prefix", type=str, required=True) parser.add_argument("--clips_per_video", type=int, default=10) parser.add_argument("--batch_size", type=int, default=8) parser.add_argument("--num_workers", type=int, default=4) parser.add_argument("--log_interval", type=int, default=20, help="Log every N batches. 0 disables.") parser.add_argument( "--aggregate", type=str, default="vote", choices=["mean", "max", "vote", "any_scared"], help=( "Video-level aggregation strategy over clip predictions: " "`mean`=mean probability (default), " "`max`=max probability pooling per class, " "`vote`=majority vote over per-clip argmax (ties broken by mean prob), " "`any_scared`=predict scared if ANY clip is scared (argmax==scared OR max scared prob >= threshold)." ), ) parser.add_argument( "--scared_threshold", type=float, default=0.0, help=( "Used only for --aggregate any_scared. If >0, predict scared when the maximum " "per-clip scared probability >= this threshold. If 0, only uses argmax==scared." ), ) parser.add_argument( "--any_scared_fallback", type=str, default="mean", choices=["mean", "max"], help="Used only for --aggregate any_scared when scared is not triggered.", ) parser.add_argument( "--print_clip_preds", action="store_true", help="Print per-video per-clip predicted classes to stdout (can be verbose).", ) parser.add_argument("--output_json", type=str, default="", help="Optional path to write per-video predictions JSON.") parser.add_argument( "--output_json_mismatches", type=str, default="wrong.json", help="Optional path to write a JSON containing only videos where pred != true (and true label is known).", ) parser.add_argument("--device", type=str, default="", help="cuda|cpu (auto if empty)") args = parser.parse_args() device = args.device.strip() or ("cuda" if torch.cuda.is_available() else "cpu") csv_path = os.path.abspath(os.path.expanduser(args.csv)) csv_dir = os.path.dirname(csv_path) path_prefix = os.path.abspath(os.path.expanduser(args.path_prefix)) if not os.path.isfile(csv_path): raise FileNotFoundError(f"CSV not found: {csv_path}") # Load checkpoint safely (this is your own file). ckpt_path = os.path.abspath(os.path.expanduser(args.checkpoint)) with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`.*") ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) cfg = ckpt.get("cfg", {}) model_name = cfg.get("model", "x3d_m") num_frames = int(cfg.get("num_frames", 16)) sampling_rate = int(cfg.get("sampling_rate", 5)) target_fps = int(cfg.get("target_fps", 30)) num_classes = int(cfg.get("num_classes", 5)) clip_duration = num_frames * sampling_rate / float(target_fps) merge = build_merge_indices(csv_dir) if num_classes != 3 else {"feeding": [0], "normal": [1], "scared": [2]} # Read CSV -> list[(path, {"label": int})] labeled_paths: List[Tuple[str, Dict]] = [] with open(csv_path, "r") as f: for line in f: parsed = _parse_path_label_line(line) if parsed is None: continue rel_path, label = parsed labeled_paths.append((os.path.join(path_prefix, rel_path), {"label": int(label)})) num_videos = len(labeled_paths) total_clips_expected = num_videos * args.clips_per_video print( f"Videos: {num_videos} | clips_per_video: {args.clips_per_video} " f"(expected clips: {total_clips_expected})" ) ds = LabeledVideoDataset( labeled_video_paths=labeled_paths, clip_sampler=make_clip_sampler("constant_clips_per_video", clip_duration, args.clips_per_video), decode_audio=False, decoder="pyav", transform=EvalTransform(num_frames=num_frames), ) loader = DataLoader( ds, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=(device == "cuda"), persistent_workers=bool(args.num_workers > 0), prefetch_factor=2 if args.num_workers > 0 else None, ) # Build model architecture + load weights. model = build_pretrained_x3d(model_name=model_name, pretrained=False) replace_last_linear(model, num_classes=num_classes) model.load_state_dict(ckpt["model"], strict=True) model = model.to(device) model.eval() # Aggregate per video_index. prob_sum = defaultdict(lambda: torch.zeros(3, dtype=torch.float32)) prob_max = defaultdict(lambda: torch.zeros(3, dtype=torch.float32)) clip_count = defaultdict(int) video_name_by_idx: Dict[int, str] = {} true_label3_by_idx: Dict[int, int] = {} clip_pred3_by_vidx: Dict[int, List[int]] = defaultdict(list) clip_idx_by_vidx: Dict[int, List[int]] = defaultdict(list) clip_probs3_by_vidx: Dict[int, List[List[float]]] = defaultdict(list) printed_videos = set() processed_clips = 0 finished_videos = set() t_last = time.time() def _video_pred_for(vidx: int) -> Tuple[int, List[float]]: """ Returns (pred_idx_3class, probs_3class_list) """ mean_probs = (prob_sum[vidx] / max(clip_count[vidx], 1)) if args.aggregate == "mean": probs = mean_probs pred = int(torch.argmax(probs).item()) return pred, probs.numpy().tolist() if args.aggregate == "max": probs = prob_max[vidx] pred = int(torch.argmax(probs).item()) return pred, probs.numpy().tolist() if args.aggregate == "any_scared": scared_idx = 2 max_scared_prob = float(prob_max[vidx][scared_idx].item()) any_argmax_scared = any(p == scared_idx for p in clip_pred3_by_vidx.get(vidx, [])) if any_argmax_scared or (args.scared_threshold > 0.0 and max_scared_prob >= float(args.scared_threshold)): # Return max pooled probs for interpretability when we trigger on "any" behavior. probs = prob_max[vidx] return scared_idx, probs.numpy().tolist() # Otherwise fall back to a smoother aggregate. if args.any_scared_fallback == "max": probs = prob_max[vidx] pred = int(torch.argmax(probs).item()) return pred, probs.numpy().tolist() probs = mean_probs pred = int(torch.argmax(probs).item()) return pred, probs.numpy().tolist() # majority vote; tie-break by mean prob counts = [0, 0, 0] for p in clip_pred3_by_vidx.get(vidx, []): if 0 <= p <= 2: counts[p] += 1 max_count = max(counts) winners = [i for i, c in enumerate(counts) if c == max_count] if len(winners) == 1: pred = winners[0] else: # tie-break using mean probability pred = int(torch.argmax(mean_probs[winners]).item()) pred = winners[pred] # Return mean probs for reporting consistency return pred, mean_probs.numpy().tolist() with torch.no_grad(): for batch_idx, batch in enumerate(loader, start=1): videos = batch["video"].to(device, non_blocking=True) labels_raw = batch["label"].to(device, non_blocking=True) vinds = batch["video_index"] vnames = batch["video_name"] logits = model(videos) if num_classes == 3: probs3 = F.softmax(logits, dim=1) # (B,3) labels3 = labels_raw # already 0..2 else: # Legacy: 5-class -> merge to 3-class. probs3 = collapse_logits_5_to_3(logits, merge) # (B,3) probs labels3 = map_label_5_to_3(labels_raw, merge) # (B,) for i in range(probs3.shape[0]): vidx = int(vinds[i]) video_name_by_idx[vidx] = str(vnames[i]) true_label3_by_idx[vidx] = int(labels3[i].item()) p3 = probs3[i].detach().cpu() prob_sum[vidx] += p3 prob_max[vidx] = torch.maximum(prob_max[vidx], p3) clip_count[vidx] += 1 pred3 = int(torch.argmax(p3).item()) clip_pred3_by_vidx[vidx].append(pred3) # clip_index might not be strictly sequential depending on sampler, but is useful for debugging. clip_idx = int(batch.get("clip_index")[i]) clip_idx_by_vidx[vidx].append(clip_idx) clip_probs3_by_vidx[vidx].append([float(p3[0].item()), float(p3[1].item()), float(p3[2].item())]) if clip_count[vidx] >= args.clips_per_video: finished_videos.add(vidx) # Print as soon as this video completes, while the run is still progressing. if args.print_clip_preds and vidx not in printed_videos: printed_videos.add(vidx) pred_vid, _ = _video_pred_for(vidx) true_vid = int(true_label3_by_idx.get(vidx, -1)) clip_preds = clip_pred3_by_vidx.get(vidx, []) clip_inds = clip_idx_by_vidx.get(vidx, []) clip_probs = clip_probs3_by_vidx.get(vidx, []) # Pair up clip_index with prediction for readability. paired = list(zip(clip_inds, clip_preds, clip_probs)) paired_sorted = sorted(paired, key=lambda x: x[0]) clip_str = ", ".join([f"{ci}:{THREE_CLASS_NAMES[pi]}" for ci, pi, _ in paired_sorted]) print(f"\n[VIDEO DONE {len(printed_videos)}/{num_videos}] {video_name_by_idx.get(vidx,'')}") print(f" true: {THREE_CLASS_NAMES[true_vid] if true_vid >= 0 else 'unknown'}") print(f" pred(video): {THREE_CLASS_NAMES[pred_vid]} (agg={args.aggregate})") print(f" clip preds: {clip_str}") print(" clip probs (feeding, normal, scared):") for ci, pi, p in paired_sorted: print( f" {ci:02d}:{THREE_CLASS_NAMES[pi]} " f"f={p[0]:.3f} n={p[1]:.3f} s={p[2]:.3f}" ) processed_clips += probs3.shape[0] if args.log_interval > 0 and (batch_idx % args.log_interval) == 0: dt = time.time() - t_last t_last = time.time() print( f"Processed {processed_clips}/{total_clips_expected} clips | " f"videos done {len(finished_videos)}/{num_videos} | " f"({dt:.1f}s/{args.log_interval} batches)" ) # Final per-video predictions. preds3: List[int] = [] labels3_list: List[int] = [] results = [] for vidx in sorted(video_name_by_idx.keys()): pred, probs3_list = _video_pred_for(vidx) true = int(true_label3_by_idx.get(vidx, -1)) preds3.append(pred) labels3_list.append(true) clip_preds = clip_pred3_by_vidx.get(vidx, []) clip_preds_named = [THREE_CLASS_NAMES[i] for i in clip_preds] clip_indices = clip_idx_by_vidx.get(vidx, []) clip_probs = clip_probs3_by_vidx.get(vidx, []) if args.print_clip_preds: print(f"\nVideo {vidx}: {video_name_by_idx[vidx]}") print(f" true: {THREE_CLASS_NAMES[true] if true >= 0 else 'unknown'}") print(f" pred(video): {THREE_CLASS_NAMES[pred]} (agg={args.aggregate})") print(f" clip preds ({len(clip_preds_named)}): {clip_preds_named}") paired = list(zip(clip_indices, clip_preds, clip_probs)) paired_sorted = sorted(paired, key=lambda x: x[0]) print(" clip probs (feeding, normal, scared):") for ci, pi, p in paired_sorted: print(f" {ci:02d}:{THREE_CLASS_NAMES[pi]} f={p[0]:.3f} n={p[1]:.3f} s={p[2]:.3f}") results.append( { "video_index": vidx, "video_name": video_name_by_idx[vidx], "pred_3class": THREE_CLASS_NAMES[pred], "pred_3class_idx": pred, "true_3class_idx": true, "aggregate": args.aggregate, "probs_3class_mean": { "feeding": float((prob_sum[vidx] / max(clip_count[vidx], 1))[0].item()), "normal": float((prob_sum[vidx] / max(clip_count[vidx], 1))[1].item()), "scared": float((prob_sum[vidx] / max(clip_count[vidx], 1))[2].item()), }, "probs_3class_used": { "feeding": probs3_list[0], "normal": probs3_list[1], "scared": probs3_list[2], }, "clip_preds_3class_idx": clip_preds, "clip_preds_3class": clip_preds_named, "clip_indices": clip_indices, "clips": [ { "clip_index": int(ci), "pred_3class": THREE_CLASS_NAMES[int(pi)], "pred_3class_idx": int(pi), "probs_3class": {"feeding": float(p[0]), "normal": float(p[1]), "scared": float(p[2])}, } for ci, pi, p in sorted(list(zip(clip_indices, clip_preds, clip_probs)), key=lambda x: x[0]) ], } ) correct = sum(int(p == t) for p, t in zip(preds3, labels3_list) if t >= 0) total = sum(int(t >= 0) for t in labels3_list) acc = correct / max(total, 1) cm = confusion_matrix_3(preds3, labels3_list) print(f"Video-level 3-class accuracy: {acc*100:.2f}% ({correct}/{total})") print("Confusion matrix (rows=true, cols=pred) for [feeding, normal, scared]:") print(cm) if args.output_json: out_path = os.path.abspath(os.path.expanduser(args.output_json)) with open(out_path, "w") as f: json.dump(results, f, indent=2, ensure_ascii=False) print(f"Wrote per-video predictions to: {out_path}") if args.output_json_mismatches: mismatches = [ r for r in results if int(r.get("true_3class_idx", -1)) >= 0 and int(r.get("pred_3class_idx", -2)) != int(r.get("true_3class_idx", -1)) ] out_path = os.path.abspath(os.path.expanduser(args.output_json_mismatches)) with open(out_path, "w") as f: json.dump(mismatches, f, indent=2, ensure_ascii=False) print(f"Wrote mismatches-only predictions to: {out_path} ({len(mismatches)}/{len(results)})") if __name__ == "__main__": main()