184 lines
6.5 KiB
Python
Executable File
184 lines
6.5 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Evaluate a trained PyTorchVideo X3D checkpoint as a 3-class classifier.
|
|
|
|
This script loads a `checkpoint_best.pt` produced by `train_pytorchvideo_x3d.py` and runs
|
|
inference on a CSV split (val.csv or test.csv), then reports 3-class accuracy and a
|
|
3x3 confusion matrix.
|
|
|
|
CSV format (whitespace-separated):
|
|
relative/path/to/video.mp4 <label_int>
|
|
|
|
Example:
|
|
python test_pytorchvideo_x3d_3class.py \
|
|
--checkpoint /home/ubuntu/projects/FishAction/checkpoints/ptv_x3d_m/checkpoint_best.pt \
|
|
--csv_dir /home/ubuntu/data/fish/fish_action_videos \
|
|
--path_prefix /home/ubuntu/data/fish/fish_action_videos \
|
|
--split val.csv \
|
|
--batch_size 4 \
|
|
--num_workers 4
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import os
|
|
import time
|
|
import warnings
|
|
from typing import List, Tuple
|
|
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
|
|
from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler
|
|
|
|
from train_pytorchvideo_x3d import (
|
|
VideoTransform,
|
|
_parse_path_label_line,
|
|
build_pretrained_x3d,
|
|
replace_last_linear,
|
|
)
|
|
|
|
|
|
# Human-readable names of the three classes.
# NOTE(review): presumably list index i corresponds to integer label i in the
# split CSVs — confirm against the label map used for training.
THREE_CLASS_NAMES = ["feeding", "normal", "scared"]
|
|
|
|
|
|
def read_csv_list(csv_path: str, path_prefix: str) -> List[Tuple[str, dict]]:
    """Read a split file into ``(video_path, {"label": int})`` pairs.

    Each line is handed to ``_parse_path_label_line``; lines for which it
    returns ``None`` are skipped (presumably blank or malformed lines — see
    that helper for the exact rule). For every other line the parsed
    relative path is joined onto ``path_prefix`` and the label is coerced
    to ``int``.

    Args:
        csv_path: Path of the split file to read.
        path_prefix: Directory prepended to every relative video path. Note
            that ``os.path.join`` ignores the prefix when the parsed path is
            already absolute.

    Returns:
        A list of ``(full_path, {"label": label})`` tuples in file order;
        empty when the file contains no usable lines.

    Raises:
        OSError: If ``csv_path`` cannot be opened.
        ValueError: If a parsed label cannot be converted to ``int``.
    """
    # The builtin ``dict`` is used in the annotations because ``typing.Dict``
    # is not among the names this file imports.
    items: List[Tuple[str, dict]] = []
    with open(csv_path, "r") as f:
        for line in f:
            parsed = _parse_path_label_line(line)
            if parsed is None:
                continue
            rel_path, label = parsed
            full = os.path.join(path_prefix, rel_path)
            items.append((full, {"label": int(label)}))
    return items
|
|
|
|
|
|
def confusion_matrix_3(preds: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Build a 3x3 confusion matrix from predicted and true class indices.

    Args:
        preds: Integer tensor of predicted class indices, values in [0, 2].
        labels: Integer tensor of true class indices, values in [0, 2], with
            the same number of elements as ``preds``.

    Returns:
        A ``(3, 3)`` ``torch.long`` tensor on the CPU where entry
        ``[t, p]`` counts samples whose true class is ``t`` and whose
        predicted class is ``p`` (rows=true, cols=pred).

    Raises:
        ValueError: If ``preds`` and ``labels`` differ in element count.
        IndexError: If any index lies outside [0, 2]. Negative values are
            rejected explicitly; plain indexing would silently wrap them
            around to the last classes.
    """
    pred_idx = preds.detach().reshape(-1).to(device="cpu", dtype=torch.long)
    true_idx = labels.detach().reshape(-1).to(device="cpu", dtype=torch.long)

    if pred_idx.numel() != true_idx.numel():
        raise ValueError(
            f"preds and labels must have the same number of elements, "
            f"got {pred_idx.numel()} and {true_idx.numel()}"
        )

    out_of_range = (pred_idx < 0) | (pred_idx > 2) | (true_idx < 0) | (true_idx > 2)
    if bool(out_of_range.any()):
        raise IndexError("class indices must be in [0, 2]")

    # Encode each (true, pred) pair as a single cell id in [0, 8] and count
    # all cells in one vectorized pass; minlength keeps empty cells at zero
    # (and yields an all-zero matrix for empty input).
    cell_ids = true_idx * 3 + pred_idx
    return torch.bincount(cell_ids, minlength=9).reshape(3, 3)
|
|
|
|
|
|
def main() -> None:
    """Evaluate a checkpoint on one CSV split and print 3-class metrics.

    Parses the command line, loads the checkpoint onto the CPU, reads the
    clip settings from its ``"cfg"`` entry (falling back to the defaults
    below when a key is missing), builds the dataset and loader, restores
    the model weights, and runs inference without gradients. It prints
    periodic progress, the final accuracy, and a 3x3 confusion matrix
    (rows=true, cols=pred).

    Raises:
        FileNotFoundError: If the requested split CSV does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to checkpoint_best.pt")
    parser.add_argument("--csv_dir", type=str, required=True, help="Folder containing split CSVs (and label_map.txt if available)")
    parser.add_argument("--path_prefix", type=str, required=True, help="Absolute path prefix for videos")
    parser.add_argument("--split", type=str, default="val.csv", help="CSV filename inside csv_dir (e.g. val.csv or test.csv)")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument(
        "--log_interval",
        type=int,
        default=20,
        help="Print progress every N batches (default: 20). Set 0 to disable.",
    )
    parser.add_argument(
        "--max_batches",
        type=int,
        default=0,
        help="If >0, stop after this many batches (useful for quick sanity checks).",
    )
    parser.add_argument("--device", type=str, default="", help="cuda|cpu (auto if empty)")
    args = parser.parse_args()

    # An empty --device means: use CUDA when available, otherwise the CPU.
    device = args.device.strip() or ("cuda" if torch.cuda.is_available() else "cpu")
    csv_dir = os.path.abspath(os.path.expanduser(args.csv_dir))
    path_prefix = os.path.abspath(os.path.expanduser(args.path_prefix))
    split_path = os.path.join(csv_dir, args.split)

    if not os.path.isfile(split_path):
        raise FileNotFoundError(f"Split CSV not found: {split_path}")

    ckpt_path = os.path.abspath(os.path.expanduser(args.checkpoint))
    # Torch 2.5+ warns about pickle by default. The checkpoint is expected to
    # be one you produced yourself, so the warning is silenced.
    # NOTE: weights_only=False unpickles arbitrary objects; never point this
    # script at a checkpoint from an untrusted source.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`.*")
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # Settings stored with the checkpoint; each default applies when the key
    # is absent from "cfg".
    cfg = ckpt.get("cfg", {})
    model_name = cfg.get("model", "x3d_m")
    num_frames = int(cfg.get("num_frames", 16))
    sampling_rate = int(cfg.get("sampling_rate", 5))
    target_fps = int(cfg.get("target_fps", 30))
    num_classes = int(cfg.get("num_classes", 3))
    # Seconds of video covered by one clip.
    clip_duration = num_frames * sampling_rate / float(target_fps)

    # Build dataset
    items = read_csv_list(split_path, path_prefix)
    print(f"Loaded {len(items)} samples from {split_path}")
    print(
        f"Decoding clips with pyav (clip_duration={clip_duration:.2f}s, num_frames={num_frames}). "
        f"If this looks stuck, try --num_workers 0."
    )
    ds = LabeledVideoDataset(
        labeled_video_paths=items,
        clip_sampler=make_clip_sampler("uniform", clip_duration),
        decode_audio=False,
        decoder="pyav",
        transform=VideoTransform(num_frames=num_frames, train=False),
    )

    loader_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        # Also covers explicit devices such as "cuda:0", not just "cuda".
        "pin_memory": device.startswith("cuda"),
    }
    if args.num_workers > 0:
        # Worker-only options are passed only when workers exist: PyTorch
        # releases whose prefetch_factor default is 2 raise ValueError when
        # given prefetch_factor=None together with num_workers=0.
        loader_kwargs["persistent_workers"] = True
        loader_kwargs["prefetch_factor"] = 2
    loader = DataLoader(ds, **loader_kwargs)

    # Build model architecture and load weights.
    model = build_pretrained_x3d(model_name=model_name, pretrained=False)
    # Ensure head shape matches checkpoint state_dict (trained as 3-class).
    replace_last_linear(model, num_classes=num_classes)
    model.load_state_dict(ckpt["model"], strict=True)
    model = model.to(device)
    model.eval()

    total = 0
    correct3 = 0
    cm = torch.zeros((3, 3), dtype=torch.long)

    with torch.no_grad():
        t_last = time.time()
        # NOTE(review): assumes each batch unpacks into (videos, labels) —
        # presumably VideoTransform produces that pair; confirm against the
        # training script.
        for batch_idx, (videos, labels3) in enumerate(loader, start=1):
            videos = videos.to(device, non_blocking=True)
            labels3 = labels3.to(device, non_blocking=True)

            logits3 = model(videos)
            preds3 = torch.argmax(logits3, dim=1)

            total += labels3.numel()
            correct3 += (preds3 == labels3).sum().item()
            cm += confusion_matrix_3(preds3.cpu(), labels3.cpu())

            if args.log_interval > 0 and (batch_idx % args.log_interval) == 0:
                dt = time.time() - t_last
                t_last = time.time()
                print(
                    f"Processed batch {batch_idx:05d} "
                    f"({total} clips) | acc so far {(correct3/max(total,1))*100:.2f}% "
                    f"({dt:.1f}s/{args.log_interval} batches)"
                )

            if args.max_batches and batch_idx >= args.max_batches:
                break

    # max(total, 1) keeps the division safe when no clip was processed.
    acc3 = correct3 / max(total, 1)
    print(f"3-class accuracy: {acc3*100:.2f}% ({correct3}/{total})")
    print("3-class confusion matrix (rows=true, cols=pred) for [feeding, normal, scared]:")
    print(cm)
|
|
|
|
|
|
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|
|