Files
FishServer/FishAction/predict_video_x3d_3class.py
2026-05-06 15:59:38 +08:00

464 lines
19 KiB
Python

#!/usr/bin/env python3
"""
Video-level prediction by aggregating multiple clip predictions (PyTorchVideo X3D).
This script supports:
- **Native 3-class checkpoints** (no mapping/merging required)
- **Legacy 5-class checkpoints** evaluated as 3-class via merging:
- feeding -> feeding
- normal_underwater + normal_upperwater -> normal
- scared_underwater + scared_upperwater -> scared
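Instead of printing per-clip labels (unless --print_clip_preds is given), this script:
- samples N clips per video
- runs inference per clip
- aggregates the clip predictions into one video-level prediction
  (strategy chosen with --aggregate; default: majority vote, ties broken by mean probability)
- outputs one label per video
- prints progress: clips processed, videos finished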
Example:
python predict_video_x3d_3class.py \
--checkpoint /home/ubuntu/projects/FishAction/checkpoints/ptv_x3d_m/checkpoint_best.pt \
--csv /home/ubuntu/data/fish/fish_action_videos/val.csv \
--path_prefix /home/ubuntu/data/fish/fish_action_videos \
--clips_per_video 10 \
--batch_size 8 \
--num_workers 4
"""
from __future__ import annotations
import argparse
import json
import os
import time
import warnings
from collections import defaultdict
from typing import Dict, List, Tuple
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler
from train_pytorchvideo_x3d import VideoTransform, _parse_path_label_line, build_pretrained_x3d, replace_last_linear
# Target 3-class names. Their list order defines the 3-class index used throughout this
# file: probability columns are stacked as (feeding, normal, scared) and 3-class labels
# are positions in this list, so do not reorder it.
THREE_CLASS_NAMES = ["feeding", "normal", "scared"]
def read_label_map(label_map_path: str) -> Dict[str, int]:
    """Parse a ``label_map.txt`` file into a ``{class_name: class_index}`` dict.

    Each non-blank line has the form ``<index> <name>``. The name is everything
    after the first run of whitespace, so names may themselves contain spaces.
    Blank lines are ignored.

    Args:
        label_map_path: Path to the label map text file.

    Returns:
        Mapping from class name to integer class index. If a name occurs more
        than once, the last occurrence wins.

    Raises:
        ValueError: If a non-blank line is not ``<index> <name>`` or its index
            is not an integer. The message includes the file path and line number.
    """
    out: Dict[str, int] = {}
    # Explicit UTF-8 so decoding does not depend on the machine's locale.
    with open(label_map_path, "r", encoding="utf-8") as f:
        for lineno, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue
            parts = line.split(maxsplit=1)
            if len(parts) != 2:
                raise ValueError(
                    f"{label_map_path}:{lineno}: expected '<index> <name>', got {line!r}"
                )
            idx_str, name = parts
            try:
                out[name] = int(idx_str)
            except ValueError as err:
                raise ValueError(
                    f"{label_map_path}:{lineno}: class index {idx_str!r} is not an integer"
                ) from err
    return out
def build_merge_indices(label_map_dir: str) -> Dict[str, List[int]]:
    """
    Return the 5->3 class merge map as ``{three_class_name: [source_indices]}``.

    If ``<label_map_dir>/label_map.txt`` exists, indices are read from it. If that
    file already names the three target classes, the merge is the identity (native
    3-class data). If no label map file exists, the default 5-class layout is used:
        0 feeding
        1 normal_underwater
        2 normal_upperwater
        3 scared_underwater
        4 scared_upperwater
    A label map that has neither the 3-class names nor all five legacy names
    raises KeyError.
    """
    map_file = os.path.join(label_map_dir, "label_map.txt")
    if not os.path.isfile(map_file):
        return {"feeding": [0], "normal": [1, 2], "scared": [3, 4]}

    name_to_idx = read_label_map(map_file)
    native_names = ("feeding", "normal", "scared")
    if all(name in name_to_idx for name in native_names):
        # Native 3-class dataset/checkpoint: every class merges only with itself.
        return {name: [name_to_idx[name]] for name in native_names}

    return {
        "feeding": [name_to_idx["feeding"]],
        "normal": [name_to_idx["normal_underwater"], name_to_idx["normal_upperwater"]],
        "scared": [name_to_idx["scared_underwater"], name_to_idx["scared_upperwater"]],
    }
def collapse_logits_5_to_3(logits_5: torch.Tensor, merge: Dict[str, List[int]]) -> torch.Tensor:
    """
    Turn source-class logits into 3-class probabilities.

    Softmax is taken over dim 1 of ``logits_5``. For each of feeding / normal /
    scared, the probabilities of the source columns listed in ``merge`` are summed.
    The three sums are stacked as columns in that order. Values are floored at
    1e-12 so a class can never have exactly zero probability.
    """
    softmaxed = F.softmax(logits_5, dim=1)
    per_class = [
        softmaxed[:, merge[target]].sum(dim=1)
        for target in ("feeding", "normal", "scared")
    ]
    return torch.stack(per_class, dim=1).clamp_min(1e-12)
def map_label_5_to_3(labels_5: torch.Tensor, merge: Dict[str, List[int]]) -> torch.Tensor:
    """
    Translate source-class label indices into 3-class label indices.

    The 3-class index of a target name is its position in THREE_CLASS_NAMES. Every
    source index listed under that name in ``merge`` maps to it. A label that
    appears nowhere in ``merge`` raises KeyError. The result is a long tensor on
    the same device as ``labels_5``.
    """
    old_to_new: Dict[int, int] = {
        int(src_idx): new_idx
        for new_idx, target in enumerate(THREE_CLASS_NAMES)
        for src_idx in merge[target]
    }
    remapped = [old_to_new[int(value)] for value in labels_5.tolist()]
    return torch.tensor(remapped, device=labels_5.device, dtype=torch.long)
class EvalTransform:
    """
    Eval-time clip transform that keeps the metadata used for video-level aggregation.

    Delegates video/label processing to ``VideoTransform(num_frames=..., train=False)``.
    It then returns a flat dict so the DataLoader's default collate can batch it.
    """

    def __init__(self, num_frames: int):
        self._vt = VideoTransform(num_frames=num_frames, train=False)

    def __call__(self, sample: Dict):
        # Per VideoTransform: video is expected as (C,T,H,W), label as an int tensor.
        video, label = self._vt(sample)
        record = {"video": video, "label": label, "video_name": sample.get("video_name", "")}
        # Missing indices fall back to -1 so downstream int conversion always works.
        for key in ("video_index", "clip_index"):
            record[key] = int(sample.get(key, -1))
        return record
def confusion_matrix_3(preds: List[int], labels: List[int], num_classes: int = 3) -> List[List[int]]:
    """
    Build a ``num_classes x num_classes`` confusion matrix (rows=true, cols=pred).

    Only pairs whose true label is known (``0 <= label < num_classes``) are
    counted. This matches the accuracy computation, which ignores labels < 0.
    Without this check Python's negative indexing would silently add unknown
    (-1) labels to the last row.

    Args:
        preds: Predicted class index per sample.
        labels: True class index per sample; values outside the valid range are skipped.
        num_classes: Matrix size (defaults to the 3-class setup).

    Returns:
        Nested lists where ``cm[true][pred]`` is the count of that pair.

    Raises:
        ValueError: If a counted pair has a prediction outside ``[0, num_classes)``.
    """
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(labels, preds):
        if not 0 <= t < num_classes:
            continue  # unknown / unlabeled sample
        if not 0 <= p < num_classes:
            raise ValueError(f"prediction {p} outside [0, {num_classes})")
        cm[t][p] += 1
    return cm
def main() -> None:
    """Command-line entry point: video-level 3-class evaluation of an X3D checkpoint.

    Steps:
      1. Parse CLI args and load the checkpoint (its ``cfg`` dict and ``model`` state dict).
      2. Build a LabeledVideoDataset over the CSV split, sampling ``--clips_per_video``
         clips per video, and a DataLoader over it.
      3. Run the model on every clip and convert outputs to 3-class probabilities
         (softmax for native 3-class checkpoints, probability merging otherwise).
      4. Aggregate clips per video with ``--aggregate``, print accuracy and a confusion
         matrix, and optionally write per-video / mismatch-only JSON files.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--csv", type=str, required=True, help="CSV split file (e.g. val.csv)")
    parser.add_argument("--path_prefix", type=str, required=True)
    parser.add_argument("--clips_per_video", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--log_interval", type=int, default=20, help="Log every N batches. 0 disables.")
    parser.add_argument(
        "--aggregate",
        type=str,
        default="vote",
        choices=["mean", "max", "vote", "any_scared"],
        help=(
            "Video-level aggregation strategy over clip predictions: "
            "`mean`=mean probability, "
            "`max`=max probability pooling per class, "
            "`vote`=majority vote over per-clip argmax (ties broken by mean prob; default), "
            "`any_scared`=predict scared if ANY clip is scared (argmax==scared OR max scared prob >= threshold)."
        ),
    )
    parser.add_argument(
        "--scared_threshold",
        type=float,
        default=0.0,
        help=(
            "Used only for --aggregate any_scared. If >0, predict scared when the maximum "
            "per-clip scared probability >= this threshold. If 0, only uses argmax==scared."
        ),
    )
    parser.add_argument(
        "--any_scared_fallback",
        type=str,
        default="mean",
        choices=["mean", "max"],
        help="Used only for --aggregate any_scared when scared is not triggered.",
    )
    parser.add_argument(
        "--print_clip_preds",
        action="store_true",
        help="Print per-video per-clip predicted classes to stdout (can be verbose).",
    )
    parser.add_argument("--output_json", type=str, default="", help="Optional path to write per-video predictions JSON.")
    parser.add_argument(
        "--output_json_mismatches",
        type=str,
        default="wrong.json",
        help=(
            "Path to write a JSON containing only videos where pred != true (and true label is known). "
            "Defaults to wrong.json; pass an empty string to disable."
        ),
    )
    parser.add_argument("--device", type=str, default="", help="cuda|cpu (auto if empty)")
    args = parser.parse_args()

    device = args.device.strip() or ("cuda" if torch.cuda.is_available() else "cpu")
    # Match plain "cuda" as well as explicit devices such as "cuda:1".
    use_cuda = device.startswith("cuda")
    csv_path = os.path.abspath(os.path.expanduser(args.csv))
    csv_dir = os.path.dirname(csv_path)
    path_prefix = os.path.abspath(os.path.expanduser(args.path_prefix))
    if not os.path.isfile(csv_path):
        raise FileNotFoundError(f"CSV not found: {csv_path}")

    # NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects from
    # the file. Only load checkpoints from a trusted source.
    ckpt_path = os.path.abspath(os.path.expanduser(args.checkpoint))
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`.*")
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    cfg = ckpt.get("cfg", {})
    model_name = cfg.get("model", "x3d_m")
    num_frames = int(cfg.get("num_frames", 16))
    sampling_rate = int(cfg.get("sampling_rate", 5))
    target_fps = int(cfg.get("target_fps", 30))
    num_classes = int(cfg.get("num_classes", 5))
    # Clip duration covering num_frames frames at stride sampling_rate
    # (in seconds, assuming target_fps is frames per second).
    clip_duration = num_frames * sampling_rate / float(target_fps)
    merge = build_merge_indices(csv_dir) if num_classes != 3 else {"feeding": [0], "normal": [1], "scared": [2]}

    # Read CSV -> list[(path, {"label": int})]
    labeled_paths: List[Tuple[str, Dict]] = []
    with open(csv_path, "r") as f:
        for line in f:
            parsed = _parse_path_label_line(line)
            if parsed is None:
                continue
            rel_path, label = parsed
            labeled_paths.append((os.path.join(path_prefix, rel_path), {"label": int(label)}))
    num_videos = len(labeled_paths)
    total_clips_expected = num_videos * args.clips_per_video
    print(
        f"Videos: {num_videos} | clips_per_video: {args.clips_per_video} "
        f"(expected clips: {total_clips_expected})"
    )

    ds = LabeledVideoDataset(
        labeled_video_paths=labeled_paths,
        clip_sampler=make_clip_sampler("constant_clips_per_video", clip_duration, args.clips_per_video),
        decode_audio=False,
        decoder="pyav",
        transform=EvalTransform(num_frames=num_frames),
    )
    loader_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "pin_memory": use_cuda,
    }
    if args.num_workers > 0:
        # These options only apply to worker processes. Passing them explicitly with
        # num_workers=0 (even prefetch_factor=None on torch<2.0) makes DataLoader raise.
        loader_kwargs["persistent_workers"] = True
        loader_kwargs["prefetch_factor"] = 2
    loader = DataLoader(ds, **loader_kwargs)

    # Build model architecture + load weights.
    model = build_pretrained_x3d(model_name=model_name, pretrained=False)
    replace_last_linear(model, num_classes=num_classes)
    model.load_state_dict(ckpt["model"], strict=True)
    model = model.to(device)
    model.eval()

    # Per-video accumulators, keyed by the dataset's video_index.
    prob_sum = defaultdict(lambda: torch.zeros(3, dtype=torch.float32))
    prob_max = defaultdict(lambda: torch.zeros(3, dtype=torch.float32))
    clip_count = defaultdict(int)
    video_name_by_idx: Dict[int, str] = {}
    true_label3_by_idx: Dict[int, int] = {}
    clip_pred3_by_vidx: Dict[int, List[int]] = defaultdict(list)
    clip_idx_by_vidx: Dict[int, List[int]] = defaultdict(list)
    clip_probs3_by_vidx: Dict[int, List[List[float]]] = defaultdict(list)
    printed_videos = set()
    processed_clips = 0
    finished_videos = set()
    t_last = time.time()

    def _video_pred_for(vidx: int) -> Tuple[int, List[float]]:
        """
        Returns (pred_idx_3class, probs_3class_list) for one video under --aggregate.
        """
        mean_probs = prob_sum[vidx] / max(clip_count[vidx], 1)
        if args.aggregate == "mean":
            return int(torch.argmax(mean_probs).item()), mean_probs.tolist()
        if args.aggregate == "max":
            probs = prob_max[vidx]
            return int(torch.argmax(probs).item()), probs.tolist()
        if args.aggregate == "any_scared":
            scared_idx = 2
            max_scared_prob = float(prob_max[vidx][scared_idx].item())
            any_argmax_scared = any(p == scared_idx for p in clip_pred3_by_vidx.get(vidx, []))
            if any_argmax_scared or (args.scared_threshold > 0.0 and max_scared_prob >= float(args.scared_threshold)):
                # Report max-pooled probs for interpretability when triggered by "any" behavior.
                return scared_idx, prob_max[vidx].tolist()
            # Otherwise fall back to a smoother aggregate.
            if args.any_scared_fallback == "max":
                probs = prob_max[vidx]
                return int(torch.argmax(probs).item()), probs.tolist()
            return int(torch.argmax(mean_probs).item()), mean_probs.tolist()
        # Majority vote over per-clip argmax; ties broken by mean probability.
        counts = [0, 0, 0]
        for p in clip_pred3_by_vidx.get(vidx, []):
            if 0 <= p <= 2:
                counts[p] += 1
        max_count = max(counts)
        winners = [i for i, c in enumerate(counts) if c == max_count]
        if len(winners) == 1:
            pred = winners[0]
        else:
            pred = winners[int(torch.argmax(mean_probs[winners]).item())]
        # Mean probs are reported for consistency with the other strategies.
        return pred, mean_probs.tolist()

    with torch.no_grad():
        for batch_idx, batch in enumerate(loader, start=1):
            videos = batch["video"].to(device, non_blocking=True)
            labels_raw = batch["label"].to(device, non_blocking=True)
            vinds = batch["video_index"]
            vnames = batch["video_name"]
            cinds = batch["clip_index"]
            logits = model(videos)
            if num_classes == 3:
                probs3 = F.softmax(logits, dim=1)  # (B,3)
                labels3 = labels_raw  # already 0..2
            else:
                # Legacy: 5-class -> merge to 3-class.
                probs3 = collapse_logits_5_to_3(logits, merge)  # (B,3) probs
                labels3 = map_label_5_to_3(labels_raw, merge)  # (B,)
            # Move to host once per batch instead of once per clip.
            probs3_cpu = probs3.detach().cpu()
            labels3_host = labels3.tolist()
            for i in range(probs3_cpu.shape[0]):
                vidx = int(vinds[i])
                video_name_by_idx[vidx] = str(vnames[i])
                true_label3_by_idx[vidx] = int(labels3_host[i])
                p3 = probs3_cpu[i]
                prob_sum[vidx] += p3
                prob_max[vidx] = torch.maximum(prob_max[vidx], p3)
                clip_count[vidx] += 1
                clip_pred3_by_vidx[vidx].append(int(torch.argmax(p3).item()))
                # clip_index may not be sequential depending on the sampler; kept for debugging.
                clip_idx_by_vidx[vidx].append(int(cinds[i]))
                clip_probs3_by_vidx[vidx].append(p3.tolist())
                if clip_count[vidx] >= args.clips_per_video:
                    finished_videos.add(vidx)
                    # Print as soon as this video completes, while the run is still progressing.
                    if args.print_clip_preds and vidx not in printed_videos:
                        printed_videos.add(vidx)
                        pred_vid, _ = _video_pred_for(vidx)
                        true_vid = int(true_label3_by_idx.get(vidx, -1))
                        paired_sorted = sorted(
                            zip(clip_idx_by_vidx[vidx], clip_pred3_by_vidx[vidx], clip_probs3_by_vidx[vidx]),
                            key=lambda x: x[0],
                        )
                        clip_str = ", ".join(f"{ci}:{THREE_CLASS_NAMES[pi]}" for ci, pi, _ in paired_sorted)
                        print(f"\n[VIDEO DONE {len(printed_videos)}/{num_videos}] {video_name_by_idx.get(vidx,'')}")
                        print(f" true: {THREE_CLASS_NAMES[true_vid] if true_vid >= 0 else 'unknown'}")
                        print(f" pred(video): {THREE_CLASS_NAMES[pred_vid]} (agg={args.aggregate})")
                        print(f" clip preds: {clip_str}")
                        print(" clip probs (feeding, normal, scared):")
                        for ci, pi, p in paired_sorted:
                            print(
                                f" {ci:02d}:{THREE_CLASS_NAMES[pi]} "
                                f"f={p[0]:.3f} n={p[1]:.3f} s={p[2]:.3f}"
                            )
            processed_clips += probs3_cpu.shape[0]
            if args.log_interval > 0 and (batch_idx % args.log_interval) == 0:
                dt = time.time() - t_last
                t_last = time.time()
                print(
                    f"Processed {processed_clips}/{total_clips_expected} clips | "
                    f"videos done {len(finished_videos)}/{num_videos} | "
                    f"({dt:.1f}s/{args.log_interval} batches)"
                )

    # Final per-video predictions.
    preds3: List[int] = []
    labels3_list: List[int] = []
    results = []
    for vidx in sorted(video_name_by_idx):
        pred, probs3_list = _video_pred_for(vidx)
        true = int(true_label3_by_idx.get(vidx, -1))
        preds3.append(pred)
        labels3_list.append(true)
        clip_preds = clip_pred3_by_vidx.get(vidx, [])
        clip_preds_named = [THREE_CLASS_NAMES[i] for i in clip_preds]
        clip_indices = clip_idx_by_vidx.get(vidx, [])
        clip_probs = clip_probs3_by_vidx.get(vidx, [])
        paired_sorted = sorted(zip(clip_indices, clip_preds, clip_probs), key=lambda x: x[0])
        mean_probs = (prob_sum[vidx] / max(clip_count[vidx], 1)).tolist()
        # Videos that reached clips_per_video were already printed when they finished.
        # Only the ones that never completed (e.g. some clips failed to decode) are printed here.
        if args.print_clip_preds and vidx not in printed_videos:
            print(
                f"\nVideo {vidx}: {video_name_by_idx[vidx]} "
                f"(incomplete: {clip_count[vidx]}/{args.clips_per_video} clips)"
            )
            print(f" true: {THREE_CLASS_NAMES[true] if true >= 0 else 'unknown'}")
            print(f" pred(video): {THREE_CLASS_NAMES[pred]} (agg={args.aggregate})")
            print(f" clip preds ({len(clip_preds_named)}): {clip_preds_named}")
            print(" clip probs (feeding, normal, scared):")
            for ci, pi, p in paired_sorted:
                print(f" {ci:02d}:{THREE_CLASS_NAMES[pi]} f={p[0]:.3f} n={p[1]:.3f} s={p[2]:.3f}")
        results.append(
            {
                "video_index": vidx,
                "video_name": video_name_by_idx[vidx],
                "pred_3class": THREE_CLASS_NAMES[pred],
                "pred_3class_idx": pred,
                "true_3class_idx": true,
                "aggregate": args.aggregate,
                "probs_3class_mean": {
                    "feeding": mean_probs[0],
                    "normal": mean_probs[1],
                    "scared": mean_probs[2],
                },
                "probs_3class_used": {
                    "feeding": probs3_list[0],
                    "normal": probs3_list[1],
                    "scared": probs3_list[2],
                },
                "clip_preds_3class_idx": clip_preds,
                "clip_preds_3class": clip_preds_named,
                "clip_indices": clip_indices,
                "clips": [
                    {
                        "clip_index": int(ci),
                        "pred_3class": THREE_CLASS_NAMES[int(pi)],
                        "pred_3class_idx": int(pi),
                        "probs_3class": {"feeding": float(p[0]), "normal": float(p[1]), "scared": float(p[2])},
                    }
                    for ci, pi, p in paired_sorted
                ],
            }
        )

    if len(video_name_by_idx) < num_videos:
        print(
            f"Warning: only {len(video_name_by_idx)}/{num_videos} videos produced any clips; "
            f"the rest are excluded from the metrics below."
        )

    # Metrics only over videos whose true label is known (label >= 0), for both
    # accuracy and the confusion matrix.
    known = [(p, t) for p, t in zip(preds3, labels3_list) if t >= 0]
    correct = sum(int(p == t) for p, t in known)
    total = len(known)
    acc = correct / max(total, 1)
    cm = confusion_matrix_3([p for p, _ in known], [t for _, t in known])
    print(f"Video-level 3-class accuracy: {acc*100:.2f}% ({correct}/{total})")
    print("Confusion matrix (rows=true, cols=pred) for [feeding, normal, scared]:")
    print(cm)

    # JSON is written as UTF-8 because ensure_ascii=False keeps non-ASCII video names verbatim.
    if args.output_json:
        out_path = os.path.abspath(os.path.expanduser(args.output_json))
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"Wrote per-video predictions to: {out_path}")
    if args.output_json_mismatches:
        mismatches = [
            r
            for r in results
            if r["true_3class_idx"] >= 0 and r["pred_3class_idx"] != r["true_3class_idx"]
        ]
        out_path = os.path.abspath(os.path.expanduser(args.output_json_mismatches))
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(mismatches, f, indent=2, ensure_ascii=False)
        print(f"Wrote mismatches-only predictions to: {out_path} ({len(mismatches)}/{len(results)})")
# Run the CLI only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()