Initial commit: FishServer monorepo (FishAction, FishMeasure, fish_api)
Made-with: Cursor
This commit is contained in:
463
FishAction/predict_video_x3d_3class.py
Executable file
463
FishAction/predict_video_x3d_3class.py
Executable file
@@ -0,0 +1,463 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Video-level prediction by aggregating multiple clip predictions (PyTorchVideo X3D).
|
||||
|
||||
This script supports:
|
||||
- **Native 3-class checkpoints** (no mapping/merging required)
|
||||
- **Legacy 5-class checkpoints** evaluated as 3-class via merging:
|
||||
- feeding -> feeding
|
||||
- normal_underwater + normal_upperwater -> normal
|
||||
- scared_underwater + scared_upperwater -> scared
|
||||
|
||||
Instead of printing per-clip labels, this script:
|
||||
- samples N clips per video
|
||||
- runs inference per clip
|
||||
- aggregates clip predictions per video (`--aggregate`: vote by default; mean, max or any_scared)
|
||||
- outputs one label per video
|
||||
- prints progress: clips processed, videos finished
|
||||
|
||||
Example:
|
||||
python predict_video_x3d_3class.py \
|
||||
--checkpoint /home/ubuntu/projects/FishAction/checkpoints/ptv_x3d_m/checkpoint_best.pt \
|
||||
--csv /home/ubuntu/data/fish/fish_action_videos/val.csv \
|
||||
--path_prefix /home/ubuntu/data/fish/fish_action_videos \
|
||||
--clips_per_video 10 \
|
||||
--batch_size 8 \
|
||||
--num_workers 4
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorchvideo.data import LabeledVideoDataset, make_clip_sampler
|
||||
|
||||
from train_pytorchvideo_x3d import VideoTransform, _parse_path_label_line, build_pretrained_x3d, replace_last_linear
|
||||
|
||||
|
||||
# Canonical 3-class label names; the list position is the 3-class index
# (0=feeding, 1=normal, 2=scared) used when mapping indices to names.
THREE_CLASS_NAMES = ["feeding", "normal", "scared"]
|
||||
|
||||
|
||||
def read_label_map(label_map_path: str) -> Dict[str, int]:
    """Parse a label-map text file into a ``{class_name: class_index}`` dict.

    Every non-blank line must look like ``<index> <name>``: the index is the
    first whitespace-delimited token and the name is the rest of the line
    (so names may contain spaces). Blank lines are skipped. If a name occurs
    more than once, the last occurrence wins.

    Args:
        label_map_path: Path to the label map file (read as UTF-8).

    Returns:
        Mapping from class name to integer class index.

    Raises:
        ValueError: if a line has no name part or its index is not an integer.
            The message includes the file path and 1-based line number.
        OSError: if the file cannot be opened.
    """
    mapping: Dict[str, int] = {}
    # Explicit encoding so parsing does not depend on the process locale.
    with open(label_map_path, "r", encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, start=1):
            entry = raw_line.strip()
            if not entry:
                continue
            parts = entry.split(maxsplit=1)
            if len(parts) != 2:
                raise ValueError(
                    f"{label_map_path}:{line_no}: expected '<index> <name>', got {entry!r}"
                )
            idx_str, name = parts
            try:
                mapping[name] = int(idx_str)
            except ValueError as err:
                raise ValueError(
                    f"{label_map_path}:{line_no}: class index {idx_str!r} is not an integer"
                ) from err
    return mapping
|
||||
|
||||
|
||||
def build_merge_indices(label_map_dir: str) -> Dict[str, List[int]]:
    """Return, for each 3-class name, the source class indices that feed it.

    Looks for ``label_map.txt`` inside ``label_map_dir``:

    * File missing: fall back to the fixed 5-class layout
      ``0 feeding, 1 normal_underwater, 2 normal_upperwater,
      3 scared_underwater, 4 scared_upperwater``.
    * File defines ``feeding``, ``normal`` and ``scared``: identity mapping
      (one source index per class), for native 3-class data.
    * Otherwise: the under/upper-water variants of ``normal`` and ``scared``
      are merged; a missing name raises ``KeyError``.
    """
    candidate = os.path.join(label_map_dir, "label_map.txt")
    if not os.path.isfile(candidate):
        return {"feeding": [0], "normal": [1, 2], "scared": [3, 4]}

    name_to_idx = read_label_map(candidate)
    three_class = ("feeding", "normal", "scared")

    # A label map that already speaks the 3-class vocabulary needs no merging.
    if all(cls in name_to_idx for cls in three_class):
        return {cls: [name_to_idx[cls]] for cls in three_class}

    return {
        "feeding": [name_to_idx["feeding"]],
        "normal": [name_to_idx[k] for k in ("normal_underwater", "normal_upperwater")],
        "scared": [name_to_idx[k] for k in ("scared_underwater", "scared_upperwater")],
    }
|
||||
|
||||
|
||||
def collapse_logits_5_to_3(logits_5: torch.Tensor, merge: Dict[str, List[int]]) -> torch.Tensor:
    """Turn per-clip logits into 3-class probabilities by summing merged classes.

    Softmax is taken over dim 1 of ``logits_5``; the probabilities of the
    source indices listed in ``merge`` for ``feeding``, ``normal`` and
    ``scared`` are then summed, in that column order. The result is clamped
    from below at 1e-12.

    Returns:
        Tensor with one row per input row and three columns of probabilities
        (not logits, despite the function name).
    """
    softmaxed = F.softmax(logits_5, dim=1)
    merged_columns = [
        softmaxed[:, merge[cls]].sum(dim=1)
        for cls in ("feeding", "normal", "scared")
    ]
    return torch.stack(merged_columns, dim=1).clamp_min(1e-12)
|
||||
|
||||
|
||||
def map_label_5_to_3(labels_5: torch.Tensor, merge: Dict[str, List[int]]) -> torch.Tensor:
    """Remap source class labels to 3-class labels using the ``merge`` table.

    The merge table is inverted (source index -> position of its class in
    ``THREE_CLASS_NAMES``) and applied to every element of ``labels_5``.
    A label that appears in no merge list raises ``KeyError``.

    Returns:
        A new ``torch.long`` tensor on the same device as ``labels_5``.
    """
    old_to_new: Dict[int, int] = {
        int(old_idx): new_idx
        for new_idx, cls in enumerate(THREE_CLASS_NAMES)
        for old_idx in merge[cls]
    }
    remapped = [old_to_new[int(value)] for value in labels_5.tolist()]
    return torch.tensor(remapped, device=labels_5.device, dtype=torch.long)
|
||||
|
||||
|
||||
class EvalTransform:
    """
    Evaluation-time transform that keeps per-clip bookkeeping fields.

    Delegates to ``VideoTransform(train=False)`` for the video/label pair and
    returns a dict that also carries ``video_name``, ``video_index`` and
    ``clip_index`` copied from the incoming sample, so the default DataLoader
    collation can batch it and clips can later be grouped by video.
    """

    def __init__(self, num_frames: int):
        self._vt = VideoTransform(num_frames=num_frames, train=False)

    def __call__(self, sample: Dict):
        # Run the wrapped transform first; metadata is read from the sample afterwards.
        transformed = self._vt(sample)
        video, label = transformed

        out = {"video": video, "label": label}
        # Missing metadata falls back to "" / -1 rather than raising.
        out["video_name"] = sample.get("video_name", "")
        out["video_index"] = int(sample.get("video_index", -1))
        out["clip_index"] = int(sample.get("clip_index", -1))
        return out
|
||||
|
||||
|
||||
def confusion_matrix_3(preds: List[int], labels: List[int]) -> List[List[int]]:
    """Build a 3x3 confusion matrix with rows = true class, columns = predicted class.

    Pairs are taken positionally from ``labels`` and ``preds`` (extra items in
    the longer list are ignored). A negative true label is treated as
    "unknown" and the pair is skipped, so it cannot wrap around via negative
    indexing and be counted in the last row.

    Args:
        preds: Predicted class indices, each expected in ``0..2``.
        labels: True class indices in ``0..2``, or a negative value for unknown.

    Returns:
        ``cm`` where ``cm[t][p]`` counts samples with true class ``t``
        predicted as ``p``.

    Raises:
        IndexError: if a known label is > 2 or a prediction is outside ``0..2``.
    """
    n_classes = 3
    cm = [[0] * n_classes for _ in range(n_classes)]
    for true_idx, pred_idx in zip(labels, preds):
        if true_idx < 0:
            # Unknown ground truth: nothing meaningful to count.
            continue
        if true_idx >= n_classes or not 0 <= pred_idx < n_classes:
            raise IndexError(
                f"class index out of range for 3-class confusion matrix: "
                f"true={true_idx}, pred={pred_idx}"
            )
        cm[true_idx][pred_idx] += 1
    return cm
|
||||
|
||||
|
||||
def main() -> None:
    """Run clip-level inference over a CSV split and report one 3-class label per video.

    Driven entirely by command-line flags:

    1. Load the checkpoint and read model/sampling settings from its ``cfg``
       dict (fallbacks: ``x3d_m``, 16 frames, sampling rate 5, 30 fps, 5 classes).
    2. Sample ``--clips_per_video`` clips per video and run the model on each.
    3. Convert every clip output to 3-class probabilities: plain softmax when
       the checkpoint has 3 classes, merged softmax otherwise.
    4. Aggregate per video according to ``--aggregate`` and print accuracy and
       a confusion matrix over videos with a known label.
    5. Write all per-video results to ``--output_json`` (if set) and the
       misclassified ones to ``--output_json_mismatches`` (if set; defaults to
       ``wrong.json``).

    Raises:
        FileNotFoundError: if the ``--csv`` file does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--csv", type=str, required=True, help="CSV split file (e.g. val.csv)")
    parser.add_argument("--path_prefix", type=str, required=True)
    parser.add_argument("--clips_per_video", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--log_interval", type=int, default=20, help="Log every N batches. 0 disables.")
    parser.add_argument(
        "--aggregate",
        type=str,
        default="vote",
        choices=["mean", "max", "vote", "any_scared"],
        # The help text marks `vote` as the default to match `default=` above.
        help=(
            "Video-level aggregation strategy over clip predictions: "
            "`mean`=mean probability, "
            "`max`=max probability pooling per class, "
            "`vote`=majority vote over per-clip argmax (default; ties broken by mean prob), "
            "`any_scared`=predict scared if ANY clip is scared (argmax==scared OR max scared prob >= threshold)."
        ),
    )
    parser.add_argument(
        "--scared_threshold",
        type=float,
        default=0.0,
        help=(
            "Used only for --aggregate any_scared. If >0, predict scared when the maximum "
            "per-clip scared probability >= this threshold. If 0, only uses argmax==scared."
        ),
    )
    parser.add_argument(
        "--any_scared_fallback",
        type=str,
        default="mean",
        choices=["mean", "max"],
        help="Used only for --aggregate any_scared when scared is not triggered.",
    )
    parser.add_argument(
        "--print_clip_preds",
        action="store_true",
        help="Print per-video per-clip predicted classes to stdout (can be verbose).",
    )
    parser.add_argument("--output_json", type=str, default="", help="Optional path to write per-video predictions JSON.")
    parser.add_argument(
        "--output_json_mismatches",
        type=str,
        default="wrong.json",
        help="Optional path to write a JSON containing only videos where pred != true (and true label is known).",
    )
    parser.add_argument("--device", type=str, default="", help="cuda|cpu (auto if empty)")
    args = parser.parse_args()

    device = args.device.strip() or ("cuda" if torch.cuda.is_available() else "cpu")
    csv_path = os.path.abspath(os.path.expanduser(args.csv))
    csv_dir = os.path.dirname(csv_path)
    path_prefix = os.path.abspath(os.path.expanduser(args.path_prefix))

    if not os.path.isfile(csv_path):
        raise FileNotFoundError(f"CSV not found: {csv_path}")

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point --checkpoint at files you produced or otherwise trust.
    ckpt_path = os.path.abspath(os.path.expanduser(args.checkpoint))
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="You are using `torch.load` with `weights_only=False`.*")
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    cfg = ckpt.get("cfg", {})
    model_name = cfg.get("model", "x3d_m")
    num_frames = int(cfg.get("num_frames", 16))
    sampling_rate = int(cfg.get("sampling_rate", 5))
    target_fps = int(cfg.get("target_fps", 30))
    num_classes = int(cfg.get("num_classes", 5))
    # Seconds of video covered by one clip.
    clip_duration = num_frames * sampling_rate / float(target_fps)

    # Native 3-class checkpoints need no merging; anything else consults label_map.txt
    # next to the CSV (or the built-in 5-class layout).
    merge = build_merge_indices(csv_dir) if num_classes != 3 else {"feeding": [0], "normal": [1], "scared": [2]}

    # Read CSV -> list[(path, {"label": int})]
    labeled_paths: List[Tuple[str, Dict]] = []
    with open(csv_path, "r", encoding="utf-8") as f:
        for line in f:
            parsed = _parse_path_label_line(line)
            if parsed is None:
                continue
            rel_path, label = parsed
            labeled_paths.append((os.path.join(path_prefix, rel_path), {"label": int(label)}))

    num_videos = len(labeled_paths)
    total_clips_expected = num_videos * args.clips_per_video
    print(
        f"Videos: {num_videos} | clips_per_video: {args.clips_per_video} "
        f"(expected clips: {total_clips_expected})"
    )

    ds = LabeledVideoDataset(
        labeled_video_paths=labeled_paths,
        clip_sampler=make_clip_sampler("constant_clips_per_video", clip_duration, args.clips_per_video),
        decode_audio=False,
        decoder="pyav",
        transform=EvalTransform(num_frames=num_frames),
    )

    # Worker-only options are passed only when workers are used: PyTorch 1.x
    # rejects an explicit prefetch_factor=None when num_workers == 0.
    worker_kwargs = {}
    if args.num_workers > 0:
        worker_kwargs["persistent_workers"] = True
        worker_kwargs["prefetch_factor"] = 2
    loader = DataLoader(
        ds,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=(device == "cuda"),
        **worker_kwargs,
    )

    # Build model architecture + load weights.
    model = build_pretrained_x3d(model_name=model_name, pretrained=False)
    replace_last_linear(model, num_classes=num_classes)
    model.load_state_dict(ckpt["model"], strict=True)
    model = model.to(device)
    model.eval()

    # Per-video accumulators, keyed by video_index.
    prob_sum = defaultdict(lambda: torch.zeros(3, dtype=torch.float32))
    prob_max = defaultdict(lambda: torch.zeros(3, dtype=torch.float32))
    clip_count = defaultdict(int)
    video_name_by_idx: Dict[int, str] = {}
    true_label3_by_idx: Dict[int, int] = {}
    clip_pred3_by_vidx: Dict[int, List[int]] = defaultdict(list)
    clip_idx_by_vidx: Dict[int, List[int]] = defaultdict(list)
    clip_probs3_by_vidx: Dict[int, List[List[float]]] = defaultdict(list)
    printed_videos = set()

    processed_clips = 0
    finished_videos = set()
    t_last = time.time()

    def _video_pred_for(vidx: int) -> Tuple[int, List[float]]:
        """
        Returns (pred_idx_3class, probs_3class_list) for one video under --aggregate.
        """
        mean_probs = prob_sum[vidx] / max(clip_count[vidx], 1)
        if args.aggregate == "mean":
            pred = int(torch.argmax(mean_probs).item())
            return pred, mean_probs.numpy().tolist()
        if args.aggregate == "max":
            probs = prob_max[vidx]
            pred = int(torch.argmax(probs).item())
            return pred, probs.numpy().tolist()
        if args.aggregate == "any_scared":
            scared_idx = 2
            max_scared_prob = float(prob_max[vidx][scared_idx].item())
            any_argmax_scared = any(p == scared_idx for p in clip_pred3_by_vidx.get(vidx, []))
            if any_argmax_scared or (args.scared_threshold > 0.0 and max_scared_prob >= float(args.scared_threshold)):
                # Return max pooled probs for interpretability when we trigger on "any" behavior.
                probs = prob_max[vidx]
                return scared_idx, probs.numpy().tolist()
            # Otherwise fall back to a smoother aggregate.
            probs = prob_max[vidx] if args.any_scared_fallback == "max" else mean_probs
            pred = int(torch.argmax(probs).item())
            return pred, probs.numpy().tolist()

        # majority vote; tie-break by mean prob
        counts = [0, 0, 0]
        for p in clip_pred3_by_vidx.get(vidx, []):
            if 0 <= p <= 2:
                counts[p] += 1
        max_count = max(counts)
        winners = [i for i, c in enumerate(counts) if c == max_count]
        if len(winners) == 1:
            pred = winners[0]
        else:
            # argmax runs over the tied classes only, so map its position back to a class index.
            pred = winners[int(torch.argmax(mean_probs[winners]).item())]
        # Return mean probs for reporting consistency
        return pred, mean_probs.numpy().tolist()

    with torch.no_grad():
        for batch_idx, batch in enumerate(loader, start=1):
            videos = batch["video"].to(device, non_blocking=True)
            labels_raw = batch["label"].to(device, non_blocking=True)
            vinds = batch["video_index"]
            vnames = batch["video_name"]
            cinds = batch["clip_index"]

            logits = model(videos)
            if num_classes == 3:
                probs3 = F.softmax(logits, dim=1)  # (B,3)
                labels3 = labels_raw  # already 0..2
            else:
                # Legacy: 5-class -> merge to 3-class.
                probs3 = collapse_logits_5_to_3(logits, merge)  # (B,3) probs
                labels3 = map_label_5_to_3(labels_raw, merge)  # (B,)

            for i in range(probs3.shape[0]):
                vidx = int(vinds[i])
                video_name_by_idx[vidx] = str(vnames[i])
                true_label3_by_idx[vidx] = int(labels3[i].item())
                p3 = probs3[i].detach().cpu()
                prob_sum[vidx] += p3
                prob_max[vidx] = torch.maximum(prob_max[vidx], p3)
                clip_count[vidx] += 1
                pred3 = int(torch.argmax(p3).item())
                clip_pred3_by_vidx[vidx].append(pred3)
                # clip_index might not be strictly sequential depending on sampler, but is useful for debugging.
                clip_idx_by_vidx[vidx].append(int(cinds[i]))
                clip_probs3_by_vidx[vidx].append([float(p3[0].item()), float(p3[1].item()), float(p3[2].item())])

                if clip_count[vidx] >= args.clips_per_video:
                    finished_videos.add(vidx)
                    # Print as soon as this video completes, while the run is still progressing.
                    if args.print_clip_preds and vidx not in printed_videos:
                        printed_videos.add(vidx)
                        pred_vid, _ = _video_pred_for(vidx)
                        true_vid = int(true_label3_by_idx.get(vidx, -1))
                        # Pair up clip_index with prediction for readability.
                        paired_sorted = sorted(
                            zip(clip_idx_by_vidx[vidx], clip_pred3_by_vidx[vidx], clip_probs3_by_vidx[vidx]),
                            key=lambda x: x[0],
                        )
                        clip_str = ", ".join([f"{ci}:{THREE_CLASS_NAMES[pi]}" for ci, pi, _ in paired_sorted])
                        print(f"\n[VIDEO DONE {len(printed_videos)}/{num_videos}] {video_name_by_idx.get(vidx,'')}")
                        print(f" true: {THREE_CLASS_NAMES[true_vid] if true_vid >= 0 else 'unknown'}")
                        print(f" pred(video): {THREE_CLASS_NAMES[pred_vid]} (agg={args.aggregate})")
                        print(f" clip preds: {clip_str}")
                        print(" clip probs (feeding, normal, scared):")
                        for ci, pi, p in paired_sorted:
                            print(
                                f" {ci:02d}:{THREE_CLASS_NAMES[pi]} "
                                f"f={p[0]:.3f} n={p[1]:.3f} s={p[2]:.3f}"
                            )

            processed_clips += probs3.shape[0]

            if args.log_interval > 0 and (batch_idx % args.log_interval) == 0:
                dt = time.time() - t_last
                t_last = time.time()
                print(
                    f"Processed {processed_clips}/{total_clips_expected} clips | "
                    f"videos done {len(finished_videos)}/{num_videos} | "
                    f"({dt:.1f}s/{args.log_interval} batches)"
                )

    # Final per-video predictions.
    preds3: List[int] = []
    labels3_list: List[int] = []
    results = []
    for vidx in sorted(video_name_by_idx.keys()):
        pred, probs3_list = _video_pred_for(vidx)
        true = int(true_label3_by_idx.get(vidx, -1))
        preds3.append(pred)
        labels3_list.append(true)

        clip_preds = clip_pred3_by_vidx.get(vidx, [])
        clip_preds_named = [THREE_CLASS_NAMES[i] for i in clip_preds]
        clip_indices = clip_idx_by_vidx.get(vidx, [])
        clip_probs = clip_probs3_by_vidx.get(vidx, [])
        # Clips ordered by clip_index, shared by the console dump and the JSON payload.
        paired_sorted = sorted(zip(clip_indices, clip_preds, clip_probs), key=lambda x: x[0])

        if args.print_clip_preds:
            print(f"\nVideo {vidx}: {video_name_by_idx[vidx]}")
            print(f" true: {THREE_CLASS_NAMES[true] if true >= 0 else 'unknown'}")
            print(f" pred(video): {THREE_CLASS_NAMES[pred]} (agg={args.aggregate})")
            print(f" clip preds ({len(clip_preds_named)}): {clip_preds_named}")
            print(" clip probs (feeding, normal, scared):")
            for ci, pi, p in paired_sorted:
                print(f" {ci:02d}:{THREE_CLASS_NAMES[pi]} f={p[0]:.3f} n={p[1]:.3f} s={p[2]:.3f}")

        # Mean probabilities are always reported, independent of the chosen aggregate.
        mean_probs = prob_sum[vidx] / max(clip_count[vidx], 1)
        results.append(
            {
                "video_index": vidx,
                "video_name": video_name_by_idx[vidx],
                "pred_3class": THREE_CLASS_NAMES[pred],
                "pred_3class_idx": pred,
                "true_3class_idx": true,
                "aggregate": args.aggregate,
                "probs_3class_mean": {
                    "feeding": float(mean_probs[0].item()),
                    "normal": float(mean_probs[1].item()),
                    "scared": float(mean_probs[2].item()),
                },
                "probs_3class_used": {
                    "feeding": probs3_list[0],
                    "normal": probs3_list[1],
                    "scared": probs3_list[2],
                },
                "clip_preds_3class_idx": clip_preds,
                "clip_preds_3class": clip_preds_named,
                "clip_indices": clip_indices,
                "clips": [
                    {
                        "clip_index": int(ci),
                        "pred_3class": THREE_CLASS_NAMES[int(pi)],
                        "pred_3class_idx": int(pi),
                        "probs_3class": {"feeding": float(p[0]), "normal": float(p[1]), "scared": float(p[2])},
                    }
                    for ci, pi, p in paired_sorted
                ],
            }
        )

    # Score only videos with a known label; accuracy and the confusion matrix
    # use the same filtered pairs so an unknown (-1) label cannot index a row.
    known = [(p, t) for p, t in zip(preds3, labels3_list) if t >= 0]
    correct = sum(int(p == t) for p, t in known)
    total = len(known)
    acc = correct / max(total, 1)
    cm = confusion_matrix_3([p for p, _ in known], [t for _, t in known])

    print(f"Video-level 3-class accuracy: {acc*100:.2f}% ({correct}/{total})")
    print("Confusion matrix (rows=true, cols=pred) for [feeding, normal, scared]:")
    print(cm)

    if args.output_json:
        out_path = os.path.abspath(os.path.expanduser(args.output_json))
        # ensure_ascii=False emits raw non-ASCII characters, so the file must be UTF-8.
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"Wrote per-video predictions to: {out_path}")

    if args.output_json_mismatches:
        mismatches = [
            r
            for r in results
            if r["true_3class_idx"] >= 0 and r["pred_3class_idx"] != r["true_3class_idx"]
        ]
        out_path = os.path.abspath(os.path.expanduser(args.output_json_mismatches))
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(mismatches, f, indent=2, ensure_ascii=False)
        print(f"Wrote mismatches-only predictions to: {out_path} ({len(mismatches)}/{len(results)})")
|
||||
|
||||
|
||||
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user