464 lines
19 KiB
Python
464 lines
19 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Video-level prediction by aggregating multiple clip predictions (PyTorchVideo X3D).
|
|
|
|
This script supports:
|
|
- **Native 3-class checkpoints** (no mapping/merging required)
|
|
- **Legacy 5-class checkpoints** evaluated as 3-class via merging:
|
|
- feeding -> feeding
|
|
- normal_underwater + normal_upperwater -> normal
|
|
- scared_underwater + scared_upperwater -> scared
|
|
|
|
Instead of printing per-clip labels, this script:
|
|
- samples N clips per video
|
|
- runs inference per clip
|
|
- aggregates clip predictions per video (majority vote by default; mean, max and any_scared are available via --aggregate)
|
|
- outputs one label per video
|
|
- prints progress: clips processed, videos finished
|
|
|
|
Example:
|
|
python predict_video_x3d_3class.py \
|
|
--checkpoint /home/ubuntu/projects/FishAction/checkpoints/ptv_x3d_m/checkpoint_best.pt \
|
|
--csv /home/ubuntu/data/fish/fish_action_videos/val.csv \
|
|
--path_prefix /home/ubuntu/data/fish/fish_action_videos \
|
|
--clips_per_video 10 \
|
|
--batch_size 8 \
|
|
--num_workers 4
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import time
|
|
import warnings
|
|
from collections import defaultdict
|
|
from typing import Dict, List, Tuple
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import DataLoader
|
|
|
|
from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler
|
|
|
|
from train_pytorchvideo_x3d import VideoTransform, _parse_path_label_line, build_pretrained_x3d, replace_last_linear
|
|
|
|
|
|
# Names of the merged output classes; each name's list position is its 3-class index.
THREE_CLASS_NAMES = ["feeding", "normal", "scared"]
|
|
|
|
|
|
def read_label_map(label_map_path: str) -> Dict[str, int]:
    """Parse a label-map text file into a ``{class_name: index}`` dict.

    Every non-blank line must look like ``<index> <name>``: an integer index,
    whitespace, then the class name (the name may itself contain spaces).
    Blank lines are skipped. If a name occurs more than once, the last
    occurrence wins.

    Args:
        label_map_path: Path of the label-map file to read.

    Returns:
        Mapping from class name to its integer index.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-blank line has no class name or its index is not
            an integer. The message names the file and the offending line.
    """
    out: Dict[str, int] = {}
    # "utf-8-sig" reads UTF-8 with or without a byte-order mark, so a BOM left
    # by an editor does not end up glued to the first index.
    with open(label_map_path, "r", encoding="utf-8-sig") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue
            parts = line.split(maxsplit=1)
            if len(parts) != 2:
                raise ValueError(
                    f"{label_map_path}: line {line_no}: expected '<index> <name>', got {line!r}"
                )
            idx_str, name = parts
            try:
                out[name] = int(idx_str)
            except ValueError as err:
                raise ValueError(
                    f"{label_map_path}: line {line_no}: index {idx_str!r} is not an integer"
                ) from err
    return out
|
|
|
|
|
|
def build_merge_indices(label_map_dir: str) -> Dict[str, List[int]]:
    """Build the mapping from each 3-class name to the source class indices it merges.

    Looks for ``label_map.txt`` inside ``label_map_dir``:

    * File absent: fall back to the default legacy 5-class layout
      (0 feeding, 1 normal_underwater, 2 normal_upperwater,
      3 scared_underwater, 4 scared_upperwater).
    * File defines ``feeding``, ``normal`` and ``scared``: identity mapping
      (native 3-class data), one index per class.
    * Otherwise the file must define the five legacy names, which are grouped
      into the three output classes.

    Args:
        label_map_dir: Directory that may contain ``label_map.txt``.

    Returns:
        Dict with keys ``"feeding"``, ``"normal"``, ``"scared"`` (in that
        order), each mapped to the list of source indices to merge.

    Raises:
        KeyError: If ``label_map.txt`` exists but provides neither the
            3-class names nor all five legacy names.
    """
    label_map_path = os.path.join(label_map_dir, "label_map.txt")
    if not os.path.isfile(label_map_path):
        return {"feeding": [0], "normal": [1, 2], "scared": [3, 4]}

    lm = read_label_map(label_map_path)

    three_class = ("feeding", "normal", "scared")
    if all(name in lm for name in three_class):
        # Native 3-class label map: nothing to merge.
        return {name: [lm[name]] for name in three_class}

    legacy_groups = {
        "feeding": ("feeding",),
        "normal": ("normal_underwater", "normal_upperwater"),
        "scared": ("scared_underwater", "scared_upperwater"),
    }
    missing = [src for sources in legacy_groups.values() for src in sources if src not in lm]
    if missing:
        # Same exception type a plain dict lookup would raise, but it says
        # which file and which names are the problem.
        raise KeyError(
            f"{label_map_path} defines neither the 3-class names {list(three_class)} "
            f"nor all legacy 5-class names; missing: {missing}"
        )
    return {target: [lm[src] for src in sources] for target, sources in legacy_groups.items()}
|
|
|
|
|
|
def collapse_logits_5_to_3(logits_5: torch.Tensor, merge: Dict[str, List[int]]) -> torch.Tensor:
|
|
probs = F.softmax(logits_5, dim=1)
|
|
p0 = probs[:, merge["feeding"]].sum(dim=1)
|
|
p1 = probs[:, merge["normal"]].sum(dim=1)
|
|
p2 = probs[:, merge["scared"]].sum(dim=1)
|
|
probs3 = torch.stack([p0, p1, p2], dim=1).clamp_min(1e-12)
|
|
return probs3
|
|
|
|
|
|
def map_label_5_to_3(labels_5: torch.Tensor, merge: Dict[str, List[int]]) -> torch.Tensor:
    """Translate source-class label indices into merged 3-class indices.

    The reverse lookup is derived from ``merge``: every source index listed
    under the i-th name of ``THREE_CLASS_NAMES`` maps to ``i``.

    Args:
        labels_5: Tensor of integer source labels.
        merge: Maps each 3-class name to the source indices it covers.

    Returns:
        ``torch.long`` tensor on the same device as ``labels_5`` holding the
        remapped labels.

    Raises:
        KeyError: If a label is not listed anywhere in ``merge``.
    """
    old_to_new: Dict[int, int] = {
        int(old_idx): new_idx
        for new_idx, class_name in enumerate(THREE_CLASS_NAMES)
        for old_idx in merge[class_name]
    }
    remapped = [old_to_new[int(raw)] for raw in labels_5.tolist()]
    return torch.tensor(remapped, device=labels_5.device, dtype=torch.long)
|
|
|
|
|
|
class EvalTransform:
    """Evaluation-time transform that keeps the metadata needed to group clips by video.

    Delegates the actual clip processing to ``VideoTransform`` (built with
    ``train=False``) and returns a dict, so the default DataLoader collation
    can batch the video, the label and the bookkeeping fields together.
    """

    def __init__(self, num_frames: int):
        self._vt = VideoTransform(num_frames=num_frames, train=False)

    def __call__(self, sample: Dict):
        """Transform one sample and attach its name and indices.

        Missing ``video_name`` defaults to ``""``; missing ``video_index`` or
        ``clip_index`` default to ``-1``.
        """
        # Presumably a (C, T, H, W) clip tensor and an integer label tensor —
        # that is decided by VideoTransform, which is defined outside this file.
        clip, target = self._vt(sample)
        out = {"video": clip, "label": target}
        out["video_name"] = sample.get("video_name", "")
        out["video_index"] = int(sample.get("video_index", -1))
        out["clip_index"] = int(sample.get("clip_index", -1))
        return out
|
|
|
|
|
|
def confusion_matrix_3(preds: List[int], labels: List[int]) -> List[List[int]]:
    """Count (true, predicted) pairs into a 3x3 confusion matrix.

    Pairs in which either index is negative are skipped: ``-1`` is used in this
    script as the "unknown label" sentinel, and Python's negative indexing
    would otherwise silently add such pairs to the last row or column.

    Args:
        preds: Predicted class indices.
        labels: True class indices, aligned with ``preds``.

    Returns:
        ``cm`` where ``cm[t][p]`` is the number of pairs with true class ``t``
        and predicted class ``p``.

    Raises:
        IndexError: If an index is 3 or larger.
    """
    cm = [[0, 0, 0] for _ in range(3)]
    for t, p in zip(labels, preds):
        if t < 0 or p < 0:
            continue
        cm[t][p] += 1
    return cm
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the video-level prediction script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--csv", type=str, required=True, help="CSV split file (e.g. val.csv)")
    parser.add_argument("--path_prefix", type=str, required=True)
    parser.add_argument("--clips_per_video", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--log_interval", type=int, default=20, help="Log every N batches. 0 disables.")
    parser.add_argument(
        "--aggregate",
        type=str,
        default="vote",
        choices=["mean", "max", "vote", "any_scared"],
        help=(
            "Video-level aggregation strategy over clip predictions: "
            "`mean`=mean probability, "
            "`max`=max probability pooling per class, "
            "`vote`=majority vote over per-clip argmax (ties broken by mean prob; default), "
            "`any_scared`=predict scared if ANY clip is scared (argmax==scared OR max scared prob >= threshold)."
        ),
    )
    parser.add_argument(
        "--scared_threshold",
        type=float,
        default=0.0,
        help=(
            "Used only for --aggregate any_scared. If >0, predict scared when the maximum "
            "per-clip scared probability >= this threshold. If 0, only uses argmax==scared."
        ),
    )
    parser.add_argument(
        "--any_scared_fallback",
        type=str,
        default="mean",
        choices=["mean", "max"],
        help="Used only for --aggregate any_scared when scared is not triggered.",
    )
    parser.add_argument(
        "--print_clip_preds",
        action="store_true",
        help="Print per-video per-clip predicted classes to stdout (can be verbose).",
    )
    parser.add_argument("--output_json", type=str, default="", help="Optional path to write per-video predictions JSON.")
    # NOTE: the default is a non-empty path, so a mismatches file is written to
    # the current directory unless an empty string is passed explicitly.
    parser.add_argument(
        "--output_json_mismatches",
        type=str,
        default="wrong.json",
        help="Optional path to write a JSON containing only videos where pred != true (and true label is known).",
    )
    parser.add_argument("--device", type=str, default="", help="cuda|cpu (auto if empty)")
    return parser


def _print_clip_probs(sorted_clips) -> None:
    """Print one line per clip from (clip_index, pred_idx, probs) tuples."""
    print(" clip probs (feeding, normal, scared):")
    for ci, pi, p in sorted_clips:
        print(f" {ci:02d}:{THREE_CLASS_NAMES[pi]} f={p[0]:.3f} n={p[1]:.3f} s={p[2]:.3f}")


def main() -> None:
    """Run clip-level inference and report one aggregated 3-class prediction per video.

    Steps:
      1. Parse CLI arguments and load the checkpoint (needs a ``"model"`` state
         dict; an optional ``"cfg"`` dict supplies model and sampling settings).
      2. Read (path, label) pairs from the CSV and sample ``--clips_per_video``
         clips from each video.
      3. Run the model on every clip. Checkpoints that are not 3-class are
         collapsed to 3 classes through the merge map.
      4. Aggregate the clip probabilities per video with the ``--aggregate``
         strategy, then print accuracy and the confusion matrix.
      5. Optionally write all per-video results and/or only the mismatches to
         JSON.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
    """
    args = _build_arg_parser().parse_args()

    device = args.device.strip() or ("cuda" if torch.cuda.is_available() else "cpu")
    csv_path = os.path.abspath(os.path.expanduser(args.csv))
    csv_dir = os.path.dirname(csv_path)
    path_prefix = os.path.abspath(os.path.expanduser(args.path_prefix))

    if not os.path.isfile(csv_path):
        raise FileNotFoundError(f"CSV not found: {csv_path}")

    # SECURITY: weights_only=False unpickles arbitrary Python objects and can
    # execute code embedded in the file. Only load checkpoints you created or
    # fully trust.
    ckpt_path = os.path.abspath(os.path.expanduser(args.checkpoint))
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`.*")
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    cfg = ckpt.get("cfg", {})
    model_name = cfg.get("model", "x3d_m")
    num_frames = int(cfg.get("num_frames", 16))
    sampling_rate = int(cfg.get("sampling_rate", 5))
    target_fps = int(cfg.get("target_fps", 30))
    num_classes = int(cfg.get("num_classes", 5))
    # Seconds of video covered by one clip.
    clip_duration = num_frames * sampling_rate / float(target_fps)

    merge = build_merge_indices(csv_dir) if num_classes != 3 else {"feeding": [0], "normal": [1], "scared": [2]}

    # Read CSV -> list[(path, {"label": int})]
    labeled_paths: List[Tuple[str, Dict]] = []
    with open(csv_path, "r") as f:
        for line in f:
            parsed = _parse_path_label_line(line)
            if parsed is None:
                continue
            rel_path, label = parsed
            labeled_paths.append((os.path.join(path_prefix, rel_path), {"label": int(label)}))

    num_videos = len(labeled_paths)
    total_clips_expected = num_videos * args.clips_per_video
    print(
        f"Videos: {num_videos} | clips_per_video: {args.clips_per_video} "
        f"(expected clips: {total_clips_expected})"
    )

    ds = LabeledVideoDataset(
        labeled_video_paths=labeled_paths,
        clip_sampler=make_clip_sampler("constant_clips_per_video", clip_duration, args.clips_per_video),
        decode_audio=False,
        decoder="pyav",
        transform=EvalTransform(num_frames=num_frames),
    )
    loader = DataLoader(
        ds,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        # startswith() so that explicit device strings such as "cuda:1" also
        # get pinned host memory.
        pin_memory=device.startswith("cuda"),
        persistent_workers=bool(args.num_workers > 0),
        prefetch_factor=2 if args.num_workers > 0 else None,
    )

    # Build model architecture + load weights.
    model = build_pretrained_x3d(model_name=model_name, pretrained=False)
    replace_last_linear(model, num_classes=num_classes)
    model.load_state_dict(ckpt["model"], strict=True)
    model = model.to(device)
    model.eval()

    # Per-video accumulators, keyed by video_index.
    prob_sum = defaultdict(lambda: torch.zeros(3, dtype=torch.float32))
    # Zero is a valid identity for the running max because probabilities are >= 0.
    prob_max = defaultdict(lambda: torch.zeros(3, dtype=torch.float32))
    clip_count = defaultdict(int)
    video_name_by_idx: Dict[int, str] = {}
    true_label3_by_idx: Dict[int, int] = {}
    clip_pred3_by_vidx: Dict[int, List[int]] = defaultdict(list)
    clip_idx_by_vidx: Dict[int, List[int]] = defaultdict(list)
    clip_probs3_by_vidx: Dict[int, List[List[float]]] = defaultdict(list)
    printed_videos = set()

    processed_clips = 0
    finished_videos = set()
    t_last = time.time()

    def _sorted_clips(vidx: int) -> List[Tuple[int, int, List[float]]]:
        """Return (clip_index, pred_idx, probs) tuples for a video, ordered by clip_index."""
        paired = zip(
            clip_idx_by_vidx.get(vidx, []),
            clip_pred3_by_vidx.get(vidx, []),
            clip_probs3_by_vidx.get(vidx, []),
        )
        return sorted(paired, key=lambda item: item[0])

    def _video_pred_for(vidx: int) -> Tuple[int, List[float]]:
        """Return (pred_idx_3class, probs_3class_list) for one video under --aggregate."""
        mean_probs = prob_sum[vidx] / max(clip_count[vidx], 1)

        if args.aggregate == "mean":
            return int(torch.argmax(mean_probs).item()), mean_probs.tolist()

        if args.aggregate == "max":
            max_probs = prob_max[vidx]
            return int(torch.argmax(max_probs).item()), max_probs.tolist()

        if args.aggregate == "any_scared":
            scared_idx = 2
            max_scared_prob = float(prob_max[vidx][scared_idx].item())
            any_argmax_scared = any(p == scared_idx for p in clip_pred3_by_vidx.get(vidx, []))
            threshold_hit = args.scared_threshold > 0.0 and max_scared_prob >= float(args.scared_threshold)
            if any_argmax_scared or threshold_hit:
                # Report max-pooled probs so the value that triggered "scared" is visible.
                return scared_idx, prob_max[vidx].tolist()
            # Otherwise fall back to a smoother aggregate.
            fallback = prob_max[vidx] if args.any_scared_fallback == "max" else mean_probs
            return int(torch.argmax(fallback).item()), fallback.tolist()

        # Majority vote over per-clip argmax; ties are broken by mean probability.
        counts = [0, 0, 0]
        for p in clip_pred3_by_vidx.get(vidx, []):
            if 0 <= p <= 2:
                counts[p] += 1
        max_count = max(counts)
        winners = [i for i, c in enumerate(counts) if c == max_count]
        if len(winners) == 1:
            pred = winners[0]
        else:
            pred = winners[int(torch.argmax(mean_probs[winners]).item())]
        # Mean probs are returned for reporting consistency.
        return pred, mean_probs.tolist()

    with torch.no_grad():
        for batch_idx, batch in enumerate(loader, start=1):
            videos = batch["video"].to(device, non_blocking=True)
            labels_raw = batch["label"].to(device, non_blocking=True)
            vinds = batch["video_index"]
            vnames = batch["video_name"]
            # clip_index might not be strictly sequential depending on the
            # sampler, but it is useful for debugging.
            cinds = batch.get("clip_index")

            logits = model(videos)
            if num_classes == 3:
                probs3 = F.softmax(logits, dim=1)  # (B,3)
                labels3 = labels_raw  # already 0..2
            else:
                # Legacy: 5-class -> merge to 3-class.
                probs3 = collapse_logits_5_to_3(logits, merge)  # (B,3) probs
                labels3 = map_label_5_to_3(labels_raw, merge)  # (B,)

            for i in range(probs3.shape[0]):
                vidx = int(vinds[i])
                video_name_by_idx[vidx] = str(vnames[i])
                true_label3_by_idx[vidx] = int(labels3[i].item())

                p3 = probs3[i].detach().cpu()
                prob_sum[vidx] += p3
                prob_max[vidx] = torch.maximum(prob_max[vidx], p3)
                clip_count[vidx] += 1

                clip_pred3_by_vidx[vidx].append(int(torch.argmax(p3).item()))
                clip_idx_by_vidx[vidx].append(int(cinds[i]))
                clip_probs3_by_vidx[vidx].append([float(p3[0].item()), float(p3[1].item()), float(p3[2].item())])

                if clip_count[vidx] < args.clips_per_video:
                    continue

                finished_videos.add(vidx)
                # Print as soon as this video completes, while the run is still progressing.
                if args.print_clip_preds and vidx not in printed_videos:
                    printed_videos.add(vidx)
                    pred_vid, _ = _video_pred_for(vidx)
                    true_vid = int(true_label3_by_idx.get(vidx, -1))
                    clips_sorted = _sorted_clips(vidx)
                    clip_str = ", ".join([f"{ci}:{THREE_CLASS_NAMES[pi]}" for ci, pi, _ in clips_sorted])
                    print(f"\n[VIDEO DONE {len(printed_videos)}/{num_videos}] {video_name_by_idx.get(vidx,'')}")
                    print(f" true: {THREE_CLASS_NAMES[true_vid] if true_vid >= 0 else 'unknown'}")
                    print(f" pred(video): {THREE_CLASS_NAMES[pred_vid]} (agg={args.aggregate})")
                    print(f" clip preds: {clip_str}")
                    _print_clip_probs(clips_sorted)

            processed_clips += probs3.shape[0]

            if args.log_interval > 0 and (batch_idx % args.log_interval) == 0:
                dt = time.time() - t_last
                t_last = time.time()
                print(
                    f"Processed {processed_clips}/{total_clips_expected} clips | "
                    f"videos done {len(finished_videos)}/{num_videos} | "
                    f"({dt:.1f}s/{args.log_interval} batches)"
                )

    # Final per-video predictions.
    preds3: List[int] = []
    labels3_list: List[int] = []
    results = []
    for vidx in sorted(video_name_by_idx):
        pred, probs3_list = _video_pred_for(vidx)
        true = int(true_label3_by_idx.get(vidx, -1))
        preds3.append(pred)
        labels3_list.append(true)

        clip_preds = clip_pred3_by_vidx.get(vidx, [])
        clip_preds_named = [THREE_CLASS_NAMES[i] for i in clip_preds]
        clip_indices = clip_idx_by_vidx.get(vidx, [])
        clips_sorted = _sorted_clips(vidx)
        # Computed once per video instead of once per reported class.
        mean_probs = (prob_sum[vidx] / max(clip_count[vidx], 1)).tolist()

        if args.print_clip_preds:
            print(f"\nVideo {vidx}: {video_name_by_idx[vidx]}")
            print(f" true: {THREE_CLASS_NAMES[true] if true >= 0 else 'unknown'}")
            print(f" pred(video): {THREE_CLASS_NAMES[pred]} (agg={args.aggregate})")
            print(f" clip preds ({len(clip_preds_named)}): {clip_preds_named}")
            _print_clip_probs(clips_sorted)

        results.append(
            {
                "video_index": vidx,
                "video_name": video_name_by_idx[vidx],
                "pred_3class": THREE_CLASS_NAMES[pred],
                "pred_3class_idx": pred,
                "true_3class_idx": true,
                "aggregate": args.aggregate,
                "probs_3class_mean": {
                    "feeding": float(mean_probs[0]),
                    "normal": float(mean_probs[1]),
                    "scared": float(mean_probs[2]),
                },
                "probs_3class_used": {
                    "feeding": probs3_list[0],
                    "normal": probs3_list[1],
                    "scared": probs3_list[2],
                },
                "clip_preds_3class_idx": clip_preds,
                "clip_preds_3class": clip_preds_named,
                "clip_indices": clip_indices,
                "clips": [
                    {
                        "clip_index": int(ci),
                        "pred_3class": THREE_CLASS_NAMES[int(pi)],
                        "pred_3class_idx": int(pi),
                        "probs_3class": {"feeding": float(p[0]), "normal": float(p[1]), "scared": float(p[2])},
                    }
                    for ci, pi, p in clips_sorted
                ],
            }
        )

    # Only videos with a known (non-negative) label count towards the metrics.
    # Filtering here also keeps -1 labels out of the confusion matrix, where a
    # negative row index would otherwise wrap around to the last class.
    known = [(p, t) for p, t in zip(preds3, labels3_list) if t >= 0]
    correct = sum(int(p == t) for p, t in known)
    total = len(known)
    acc = correct / max(total, 1)
    cm = confusion_matrix_3([p for p, _ in known], [t for _, t in known])

    print(f"Video-level 3-class accuracy: {acc*100:.2f}% ({correct}/{total})")
    print("Confusion matrix (rows=true, cols=pred) for [feeding, normal, scared]:")
    print(cm)

    if args.output_json:
        out_path = os.path.abspath(os.path.expanduser(args.output_json))
        # ensure_ascii=False can emit non-ASCII video names, so the file
        # encoding is fixed to UTF-8 rather than left to the locale.
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"Wrote per-video predictions to: {out_path}")

    if args.output_json_mismatches:
        mismatches = [
            r
            for r in results
            if r["true_3class_idx"] >= 0 and r["pred_3class_idx"] != r["true_3class_idx"]
        ]
        out_path = os.path.abspath(os.path.expanduser(args.output_json_mismatches))
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(mismatches, f, indent=2, ensure_ascii=False)
        print(f"Wrote mismatches-only predictions to: {out_path} ({len(mismatches)}/{len(results)})")
|
|
|
|
|
|
# Script entry point: run the CLI only when executed directly, not when imported.
if __name__ == "__main__":
    main()
|
|
|
|
|