Files
FishServer/FishAction/test_pytorchvideo_x3d_3class.py

184 lines
6.5 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
Evaluate a trained PyTorchVideo X3D checkpoint as a 3-class classifier.
This script loads a `checkpoint_best.pt` produced by `train_pytorchvideo_x3d.py` and runs
inference on a CSV split (val.csv or test.csv), then reports 3-class accuracy and a
3x3 confusion matrix.
CSV format (whitespace-separated):
    relative/path/to/video.mp4 <label_int>

Example:
    python test_pytorchvideo_x3d_3class.py \
        --checkpoint /home/ubuntu/projects/FishAction/checkpoints/ptv_x3d_m/checkpoint_best.pt \
        --csv_dir /home/ubuntu/data/fish/fish_action_videos \
        --path_prefix /home/ubuntu/data/fish/fish_action_videos \
        --split val.csv \
        --batch_size 4 \
        --num_workers 4
"""
from __future__ import annotations
import argparse
import os
import time
import warnings
from typing import List, Tuple
import torch
from torch.utils.data import DataLoader
from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler
from train_pytorchvideo_x3d import (
VideoTransform,
_parse_path_label_line,
build_pretrained_x3d,
replace_last_linear,
)
# Display names of the three classes, in the same order as the label list that main()
# prints next to the confusion matrix.
# NOTE(review): presumably list index i corresponds to integer label i in the split CSVs;
# confirm against the label map used for training.
THREE_CLASS_NAMES = ["feeding", "normal", "scared"]
def read_csv_list(csv_path: str, path_prefix: str) -> List[Tuple[str, dict]]:
    """
    Read a split CSV and build the (video_path, info) list consumed by the dataset.

    Each line of the file is passed to `_parse_path_label_line`. Lines for which that
    helper returns None are skipped; for every other line the returned
    (relative_path, label) pair is turned into
    (os.path.join(path_prefix, relative_path), {"label": int(label)}).

    Args:
        csv_path: Path of the split file to read.
        path_prefix: Directory joined in front of every relative video path.

    Returns:
        One (full_video_path, {"label": <int>}) tuple per parsed line, in file order.

    Raises:
        OSError: If `csv_path` cannot be opened.
        ValueError: If a parsed label cannot be converted with int().
    """
    # The builtin `dict` is used in the annotations because this file only imports
    # `List` and `Tuple` from typing; the previous `Dict` was an undefined name.
    items: List[Tuple[str, dict]] = []
    with open(csv_path, "r") as f:
        for line in f:
            parsed = _parse_path_label_line(line)
            if parsed is None:
                # The helper signals "nothing usable on this line" with None.
                continue
            rel_path, label = parsed
            full = os.path.join(path_prefix, rel_path)
            items.append((full, {"label": int(label)}))
    return items
def confusion_matrix_3(preds: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """
    Build the 3x3 confusion matrix for a set of predictions.

    Args:
        preds: Integer class indices predicted by the model, values in [0, 2].
        labels: Integer ground-truth class indices, values in [0, 2], holding the
            same number of elements as `preds`.

    Returns:
        A (3, 3) torch.long tensor on the CPU where entry [t, p] counts the samples
        whose true class is t and whose predicted class is p (rows=true, cols=pred).
        Empty inputs give an all-zero matrix.

    Raises:
        ValueError: If `preds` and `labels` hold a different number of elements.
        IndexError: If any value lies outside [0, 2]. Negative values are rejected
            explicitly: plain tensor indexing would wrap them around and silently
            count them in the wrong row/column.
    """
    num_classes = 3
    true_idx = labels.detach().reshape(-1).to(device="cpu", dtype=torch.long)
    pred_idx = preds.detach().reshape(-1).to(device="cpu", dtype=torch.long)

    if true_idx.numel() != pred_idx.numel():
        raise ValueError(
            f"preds and labels must have the same number of elements, "
            f"got {pred_idx.numel()} and {true_idx.numel()}"
        )

    invalid = (
        (true_idx < 0)
        | (true_idx >= num_classes)
        | (pred_idx < 0)
        | (pred_idx >= num_classes)
    )
    if bool(invalid.any()):
        raise IndexError(
            f"class indices must be in [0, {num_classes - 1}] for the "
            f"{num_classes}x{num_classes} confusion matrix"
        )

    # Encode every (true, pred) pair as one bin index and count all pairs in a single call.
    pair_index = true_idx * num_classes + pred_idx
    counts = torch.bincount(pair_index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)
def main() -> None:
    """
    Command-line entry point: evaluate a checkpoint on one CSV split.

    Steps:
      1. Parse the command-line arguments and pick the device.
      2. Load the checkpoint on the CPU and read the clip settings from its "cfg" entry
         (model name, num_frames, sampling_rate, target_fps, num_classes), using the
         defaults below for any missing key.
      3. Build a LabeledVideoDataset with a uniform clip sampler over the split CSV and
         wrap it in a DataLoader.
      4. Rebuild the model, load the weights from the checkpoint's "model" entry, and run
         inference without gradients.
      5. Print the accuracy and the 3x3 confusion matrix (rows=true, cols=pred).

    Raises:
        FileNotFoundError: If the split CSV does not exist inside --csv_dir.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to checkpoint_best.pt")
    parser.add_argument("--csv_dir", type=str, required=True, help="Folder containing split CSVs (and label_map.txt if available)")
    parser.add_argument("--path_prefix", type=str, required=True, help="Absolute path prefix for videos")
    parser.add_argument("--split", type=str, default="val.csv", help="CSV filename inside csv_dir (e.g. val.csv or test.csv)")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument(
        "--log_interval",
        type=int,
        default=20,
        help="Print progress every N batches (default: 20). Set 0 to disable.",
    )
    parser.add_argument(
        "--max_batches",
        type=int,
        default=0,
        help="If >0, stop after this many batches (useful for quick sanity checks).",
    )
    parser.add_argument("--device", type=str, default="", help="cuda|cpu (auto if empty)")
    args = parser.parse_args()

    device = args.device.strip() or ("cuda" if torch.cuda.is_available() else "cpu")
    # Indexed devices such as "cuda:0" are CUDA devices too, so look at the parsed
    # device type instead of comparing the raw string with "cuda".
    use_cuda = torch.device(device).type == "cuda"

    csv_dir = os.path.abspath(os.path.expanduser(args.csv_dir))
    path_prefix = os.path.abspath(os.path.expanduser(args.path_prefix))
    split_path = os.path.join(csv_dir, args.split)
    if not os.path.isfile(split_path):
        raise FileNotFoundError(f"Split CSV not found: {split_path}")

    ckpt_path = os.path.abspath(os.path.expanduser(args.checkpoint))
    # Torch 2.5+ warns about pickle by default. This is your own checkpoint, so we silence it.
    # NOTE: weights_only=False unpickles arbitrary objects; only load checkpoints you trust.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`.*")
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    cfg = ckpt.get("cfg", {})
    model_name = cfg.get("model", "x3d_m")
    num_frames = int(cfg.get("num_frames", 16))
    sampling_rate = int(cfg.get("sampling_rate", 5))
    target_fps = int(cfg.get("target_fps", 30))
    num_classes = int(cfg.get("num_classes", 3))
    # Seconds of video covered by one clip of num_frames frames taken every sampling_rate frames.
    clip_duration = num_frames * sampling_rate / float(target_fps)

    # Build dataset
    items = read_csv_list(split_path, path_prefix)
    print(f"Loaded {len(items)} samples from {split_path}")
    print(
        f"Decoding clips with pyav (clip_duration={clip_duration:.2f}s, num_frames={num_frames}). "
        f"If this looks stuck, try --num_workers 0."
    )
    ds = LabeledVideoDataset(
        labeled_video_paths=items,
        clip_sampler=make_clip_sampler("uniform", clip_duration),
        decode_audio=False,
        decoder="pyav",
        transform=VideoTransform(num_frames=num_frames, train=False),
    )
    loader = DataLoader(
        ds,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=use_cuda,
        # Worker-only options are passed only when worker processes are actually used.
        persistent_workers=bool(args.num_workers > 0),
        prefetch_factor=2 if args.num_workers > 0 else None,
    )

    # Build model architecture and load weights.
    model = build_pretrained_x3d(model_name=model_name, pretrained=False)
    # Ensure head shape matches checkpoint state_dict (trained as 3-class).
    replace_last_linear(model, num_classes=num_classes)
    model.load_state_dict(ckpt["model"], strict=True)
    model = model.to(device)
    model.eval()

    total = 0
    correct3 = 0
    cm = torch.zeros((3, 3), dtype=torch.long)
    with torch.no_grad():
        # perf_counter is monotonic, so the reported interval cannot be skewed by
        # wall-clock adjustments.
        t_last = time.perf_counter()
        for batch_idx, (videos, labels3) in enumerate(loader, start=1):
            videos = videos.to(device, non_blocking=True)
            labels3 = labels3.to(device, non_blocking=True)
            logits3 = model(videos)
            preds3 = torch.argmax(logits3, dim=1)

            total += labels3.numel()
            correct3 += (preds3 == labels3).sum().item()
            cm += confusion_matrix_3(preds3.cpu(), labels3.cpu())

            if args.log_interval > 0 and (batch_idx % args.log_interval) == 0:
                now = time.perf_counter()
                dt = now - t_last
                t_last = now
                print(
                    f"Processed batch {batch_idx:05d} "
                    f"({total} clips) | acc so far {(correct3/max(total,1))*100:.2f}% "
                    f"({dt:.1f}s/{args.log_interval} batches)"
                )
            if args.max_batches and batch_idx >= args.max_batches:
                break

    # max(total, 1) keeps the division defined when no clip was processed.
    acc3 = correct3 / max(total, 1)
    print(f"3-class accuracy: {acc3*100:.2f}% ({correct3}/{total})")
    print("3-class confusion matrix (rows=true, cols=pred) for [feeding, normal, scared]:")
    print(cm)
# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()