Files
FishServer/FishMeasure/generate_video_with_labels.py

372 lines
14 KiB
Python

#!/usr/bin/env python3
"""Generate labeled preview video from SVO + weight prediction JSON.
Per-frame labeling: each detection box shows the weight/length predicted
for that specific frame's PLY (from ``per_cloud`` / ``per_file`` in the
DGCNN JSON). Frames without a corresponding PLY carry forward the last
known value. The final aggregated result is shown at top-right.
Called by ``predict_weigth_from_svo2.py`` after DGCNN completes.
Replaces any existing preview video so the final published file has labels.
"""
from __future__ import annotations
import argparse
import json
import math
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import cv2
import numpy as np
try:
import pyzed.sl as sl
ZED_AVAILABLE = True
except ImportError:
ZED_AVAILABLE = False
def _parse_weight_json(weight_json: Path) -> Tuple[
    Dict[int, Tuple[float, float]],
    Optional[float],
    Optional[float],
    bool,
]:
    """Parse weight JSON → per-frame map + summary + confidence.

    The summary section is the first non-empty dict among ``dgcnn_summary``,
    ``weight_summary`` and ``summary``. Per-frame rows come from ``per_cloud``
    (or, if that is missing/empty, ``per_file``). Malformed pieces are skipped
    rather than raising, so a partially valid JSON still yields labels. That
    covers a non-dict summary, a non-list row container, non-dict rows, and
    values that are missing, non-numeric, non-finite or too large for a float.

    Args:
        weight_json: Path to the DGCNN / weight-prediction JSON (read as UTF-8).

    Returns:
        per_frame: {frame_number: (weight_g, length_mm)} from per_cloud/per_file.
            frame_number is the integer after ``frame_`` in the row's ``ply``
            file stem. Rows without a usable weight are dropped. length_mm is
            NaN when the row has no usable ``length_input``.
        summary_weight_g: final aggregated weight, or None if absent.
        summary_length_mm: final aggregated length, or None if absent.
        is_confident: True when ``*`` should be shown (avg > 440g OR length band fraction >= 25%).

    Raises:
        OSError / json.JSONDecodeError: if the file cannot be read or parsed.
    """
    CONFIDENT_AVG_G = 440.0
    MIN_FRAC_LARGEST_LENGTH_GROUP = 0.25
    frame_re = re.compile(r"frame_(\d+)")

    def _first_finite(*candidates: Any) -> Optional[float]:
        """Return the first candidate convertible to a finite float, else None."""
        for c in candidates:
            if c is None:
                continue
            try:
                v = float(c)
            except (TypeError, ValueError, OverflowError):
                # OverflowError: a huge JSON integer cannot be represented as float.
                continue
            if math.isfinite(v):
                return v
        return None

    data = json.loads(weight_json.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        data = {}
    # Same precedence as ``a or b or c or {}``, but a non-dict section is skipped
    # instead of later crashing on ``.get``.
    summary: Dict[str, Any] = next(
        (
            s
            for s in (data.get("dgcnn_summary"), data.get("weight_summary"), data.get("summary"))
            if isinstance(s, dict) and s
        ),
        {},
    )

    summary_wg = _first_finite(
        summary.get("pred_weight_g"),
        summary.get("avg_predicted_weight_g"),
        data.get("pred_weight_g"),
        data.get("avg_predicted_weight_g"),
    )
    summary_lmm = _first_finite(
        summary.get("avg_length_input_topk"),
        summary.get("avg_length_input"),
        data.get("avg_length_input"),
    )

    mean_g = _first_finite(
        summary.get("mean_all_pred_g_after_filters"),
        summary.get("avg_predicted_weight_g"),
    )
    frac = _first_finite(summary.get("fraction_in_near_max_length_band"))
    is_confident = (mean_g is not None and mean_g > CONFIDENT_AVG_G) or (
        frac is not None and frac >= MIN_FRAC_LARGEST_LENGTH_GROUP
    )

    rows = data.get("per_cloud") or data.get("per_file") or []
    if not isinstance(rows, list):
        rows = []
    per_frame: Dict[int, Tuple[float, float]] = {}
    for item in rows:
        if not isinstance(item, dict):
            continue
        m = frame_re.search(Path(str(item.get("ply", ""))).stem)
        if not m:
            continue
        wg = _first_finite(item.get("predicted_weight_g"))
        if wg is None:
            continue
        lmm = _first_finite(item.get("length_input"))
        # Later rows for the same frame overwrite earlier ones.
        per_frame[int(m.group(1))] = (wg, lmm if lmm is not None else float("nan"))
    return per_frame, summary_wg, summary_lmm, is_confident
def _draw_label_on_box(
    image: np.ndarray,
    box: np.ndarray,
    tid: int,
    class_name: str,
    weight_g: Optional[float],
    length_mm: Optional[float],
) -> None:
    """Draw one detection box and its text label onto ``image`` in place.

    The label reads ``ID:<tid> <class_name> weight: <w>g len: <l>mm``. A
    missing or non-finite weight/length is rendered as ``--``. The label sits
    on a filled strip directly above the box. When the box is too close to the
    top edge for that, the strip goes just inside the box. It is also shifted
    left if it would run past the right image edge, so the label never ends up
    outside the image.

    Args:
        image: Image to draw on; modified in place.
        box: ``(x1, y1, x2, y2)`` pixel corners (converted with ``int``).
        tid: Tracker ID shown in the label.
        class_name: Class name shown in the label.
        weight_g: Weight in grams, or None.
        length_mm: Length in millimetres, or None.
    """
    x1, y1, x2, y2 = map(int, box)
    cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
    w_str = f"{weight_g:.0f}g" if weight_g is not None and math.isfinite(weight_g) else "--g"
    l_str = f"{length_mm:.0f}mm" if length_mm is not None and math.isfinite(length_mm) else "--mm"
    label = f"ID:{tid} {class_name} weight: {w_str} len: {l_str}"

    font, scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1
    (tw, th), bl = cv2.getTextSize(label, font, scale, thickness)
    strip_h = th + bl + 5  # height of the filled label background
    img_w = image.shape[1]
    # Above the box when there is room; otherwise inside it (never above row 0).
    top = y1 - strip_h if y1 >= strip_h else max(y1, 0)
    # Keep the label horizontally within the image.
    left = max(0, min(x1, img_w - tw))
    cv2.rectangle(image, (left, top), (left + tw, top + strip_h), (0, 255, 0), -1)
    cv2.putText(image, label, (left, top + strip_h - bl - 2),
                font, scale, (0, 0, 0), thickness, cv2.LINE_AA)
def _draw_large_summary(
    image: np.ndarray,
    weight_g: Optional[float],
    length_mm: Optional[float],
    is_confident: bool = False,
) -> None:
    """Draw the final aggregated weight/length in large text at the top-right.

    Two lines, ``Final: <w>g[ *]`` and ``Length: <l>mm``, are drawn on a
    darkened panel. As with the per-box labels, a missing or non-finite value
    is shown as ``--`` rather than ``nan``.

    Args:
        image: Image to draw on; modified in place.
        weight_g: Final weight in grams, or None.
        length_mm: Final length in millimetres, or None.
        is_confident: Append `` *`` to the weight line.
    """
    star = " *" if is_confident else ""
    has_w = weight_g is not None and math.isfinite(weight_g)
    has_l = length_mm is not None and math.isfinite(length_mm)
    lines = [
        f"Final: {weight_g:.0f}g{star}" if has_w else "Final: --g",
        f"Length: {length_mm:.0f}mm" if has_l else "Length: --mm",
    ]
    font = cv2.FONT_HERSHEY_SIMPLEX
    scale = 2.0
    thickness = 2
    pad = 10
    margin = 20
    h, w = image.shape[:2]
    sizes = [cv2.getTextSize(ln, font, scale, thickness) for ln in lines]
    max_tw = max(s[0][0] for s in sizes)
    total_h = sum(s[0][1] + s[1] + pad for s in sizes)
    # Right-align against the margin, but never start the text left of the
    # panel padding (only matters on images narrower than the text).
    x0 = max(pad, w - max_tw - margin)
    y0 = 30
    right = max(w - margin, x0 + max_tw) + pad
    overlay = image.copy()
    cv2.rectangle(overlay, (x0 - pad, y0 - pad), (right, y0 + total_h + pad), (0, 0, 0), -1)
    # Inside the panel the result is 0.3 * original (darkened). Elsewhere the
    # overlay equals the image, so those pixels are effectively unchanged.
    cv2.addWeighted(overlay, 0.7, image, 0.3, 0, image)
    y = y0
    for ln, ((_tw, th), bl) in zip(lines, sizes):
        cv2.putText(image, ln, (x0, y + th), font, scale, (0, 255, 255), thickness, cv2.LINE_AA)
        y += th + bl + pad
def generate_video(
    svo_path: Path,
    output_dir: Path,
    weight_json: Path,
    yolo_model_path: str,
    conf: float = 0.25,
    imgsz: int = 640,
    frame_stride: int = 1,
    show_large: bool = False,
    summary_weight_g: Optional[float] = None,
    summary_length_mm: Optional[float] = None,
    summary_star: bool = False,
    output_video_name: Optional[str] = None,
    sam_device: str = "cuda",
) -> Optional[Path]:
    """Render a side-by-side (detection | segmentation) labeled preview video.

    Each grabbed SVO frame goes through the same steps:
    - run YOLO tracking and SAM segmentation on it;
    - label every detection with the weight/length of that frame's PLY;
    - for frames without a PLY, carry forward the last known value.

    That value includes PLYs that fall on frames skipped by ``frame_stride``.
    Only frames with at least one detection are written. Output is 10 fps
    ``mp4v`` at ``<output_dir>/images/<output_video_name or
    "<svo stem>_preview.mp4">``. Frames are streamed to a hidden temp file
    that replaces the target only after the whole SVO was processed. An
    existing preview is therefore never overwritten by a truncated video.

    Args:
        svo_path: ZED SVO recording to read.
        output_dir: Per-SVO output directory; the video goes in ``images/``.
        weight_json: Weight/DGCNN JSON parsed by ``_parse_weight_json``.
        yolo_model_path: YOLO weights used for detection + tracking.
        conf: YOLO confidence threshold.
        imgsz: YOLO inference size.
        frame_stride: Render only frames whose 0-based index is a multiple of
            this (values <= 1 render every frame).
        show_large: Draw the top-right summary even when no summary weight exists.
        summary_weight_g: Overrides the summary weight parsed from the JSON.
        summary_length_mm: Overrides the summary length parsed from the JSON.
        summary_star: Append ``*`` to the summary weight line.
        output_video_name: Video file name (default ``<svo stem>_preview.mp4``).
        sam_device: Requested SAM device; ``load_sam_predictor_with_fallback``
            returns the device actually used.

    Returns:
        The video path. Returns None instead when any of these occurs:
        - pyzed is unavailable;
        - the SVO cannot be opened;
        - the video writer cannot be opened;
        - no frame had a detection.
    """
    if not ZED_AVAILABLE:
        print("ERROR: pyzed not available, cannot generate labeled video")
        return None
    per_frame, parsed_summary_wg, parsed_summary_lmm, raw_confident = _parse_weight_json(weight_json)
    if summary_weight_g is None:
        summary_weight_g = parsed_summary_wg
    if summary_length_mm is None:
        summary_length_mm = parsed_summary_lmm
    star_s = " *" if summary_star else ""
    print(f" Per-frame predictions: {len(per_frame)} PLYs mapped")
    print(
        f" Summary: weight={summary_weight_g}g, length={summary_length_mm}mm{star_s} "
        f"(raw_confident={raw_confident})"
    )
    if not per_frame and summary_weight_g is None:
        print(" WARNING: No weight data in JSON, video will show '--'")

    # Heavy / project-local dependencies are imported lazily, only when a
    # video is actually generated.
    from ultralytics import YOLO
    yolo = YOLO(yolo_model_path)
    class_names = yolo.names if hasattr(yolo, "names") else {}
    from fish_video_weight_evaluation import (
        create_segmentation_overlay,
        load_sam_predictor_with_fallback,
        segment_with_sam,
    )
    sam_predictor, sam_torch_device = load_sam_predictor_with_fallback(sam_device)
    from dataset.zed_reader import ZEDReader
    zed_reader = ZEDReader(svo_path=str(svo_path), camera_mode=False, use_yolo_detector=False)
    if not zed_reader.open():
        print(f" ERROR: Failed to open {svo_path.name}")
        return None

    runtime = sl.RuntimeParameters()
    left_mat = sl.Mat()
    svo_name = svo_path.stem
    images_dir = Path(output_dir) / "images"
    images_dir.mkdir(parents=True, exist_ok=True)
    video_path = images_dir / (output_video_name or f"{svo_name}_preview.mp4")
    # Same suffix as the target so OpenCV picks the same container.
    tmp_path = video_path.with_name(f".partial_{video_path.name}")

    writer = None  # cv2.VideoWriter, opened lazily on the first frame with detections
    n_written = 0
    completed = False
    idx = -1  # 0-based index of the most recently grabbed frame
    last_wg: Optional[float] = None
    last_lmm: Optional[float] = None
    try:
        while True:
            if zed_reader.zed.grab(runtime) != sl.ERROR_CODE.SUCCESS:
                break
            idx += 1
            frame_number = idx + 1
            ply_value = per_frame.get(frame_number)
            # Update the carry-forward state on every grabbed frame, including
            # frames skipped by ``frame_stride``. Otherwise PLY values that
            # land on skipped frames would never be shown.
            if ply_value is not None:
                last_wg = ply_value[0]
                if math.isfinite(ply_value[1]):
                    last_lmm = ply_value[1]
            if frame_stride > 1 and idx % frame_stride != 0:
                continue  # skipped before the (costly) image retrieval/copy
            cur_wg, cur_lmm = ply_value if ply_value is not None else (last_wg, last_lmm)
            frame_name = f"frame_{frame_number:06d}"

            zed_reader.zed.retrieve_image(left_mat, sl.VIEW.LEFT)
            left_np = left_mat.get_data()
            # Drop a 4th (alpha) channel if present.
            img = left_np[:, :, :3].copy() if left_np.shape[2] > 3 else left_np.copy()

            results = yolo.track(img, conf=conf, imgsz=imgsz, verbose=False, persist=True)[0]
            num_dets = len(results.boxes) if results.boxes is not None else 0
            left_disp = img.copy()
            right_disp = img.copy()
            if num_dets > 0:
                boxes = results.boxes.xyxy.cpu().numpy()
                tids = (results.boxes.id.cpu().numpy().astype(int)
                        if hasattr(results.boxes, "id") and results.boxes.id is not None
                        else np.zeros(len(boxes), dtype=int))
                cls_ids = (results.boxes.cls.cpu().numpy().astype(int)
                           if results.boxes.cls is not None
                           else np.zeros(len(boxes), dtype=int))
                for i, box in enumerate(boxes):
                    tid = int(tids[i]) if i < len(tids) else 0
                    cid = int(cls_ids[i]) if i < len(cls_ids) else 0
                    cname = class_names.get(cid, "fish")
                    _draw_label_on_box(left_disp, box, tid, cname, cur_wg, cur_lmm)
                try:
                    masks = segment_with_sam(sam_predictor, img, boxes, sam_torch_device)
                except Exception as e:
                    # Best effort: a SAM failure only degrades the right panel.
                    print(f" WARNING: SAM segmentation failed on {frame_name}: {e}")
                    masks = []
                # ``len`` instead of truthiness so this also works if the helper
                # returns a NumPy array (whose truth value is ambiguous).
                if masks is not None and len(masks) > 0:
                    right_disp = create_segmentation_overlay(img.copy(), masks)
                    cv2.putText(right_disp, "Segmentation", (10, right_disp.shape[0] - 20),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
                else:
                    cv2.putText(right_disp, "Segmentation (failed)", (10, right_disp.shape[0] - 20),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2, cv2.LINE_AA)
            else:
                cv2.putText(right_disp, "No detections", (10, right_disp.shape[0] - 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (128, 128, 128), 2, cv2.LINE_AA)

            if show_large or summary_weight_g is not None:
                _draw_large_summary(left_disp, summary_weight_g, summary_length_mm, summary_star)
            info = f"[{frame_number}] {frame_name} | Detections: {num_dets}"
            cv2.putText(left_disp, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
            cv2.putText(left_disp, "Detection", (10, left_disp.shape[0] - 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
            combined = np.hstack([left_disp, right_disp])

            if num_dets > 0:
                # Stream frames straight to disk instead of buffering the whole
                # video in memory.
                if writer is None:
                    fh, fw = combined.shape[:2]
                    writer = cv2.VideoWriter(str(tmp_path), cv2.VideoWriter_fourcc(*"mp4v"), 10.0, (fw, fh))
                    if not writer.isOpened():
                        print(f" ERROR: Could not open video writer for {tmp_path}")
                        return None
                writer.write(combined)
                n_written += 1

            if idx % 30 == 0:
                w_s = f"{cur_wg:.0f}g" if cur_wg is not None and math.isfinite(cur_wg) else "--"
                l_s = f"{cur_lmm:.0f}mm" if cur_lmm is not None and math.isfinite(cur_lmm) else "--"
                print(f" [{frame_number}] {frame_name} dets={num_dets} w={w_s} l={l_s} collected={n_written}")
        completed = True
    finally:
        zed_reader.close()
        if writer is not None:
            writer.release()
        if not completed:
            # Error or early return: discard the partial file, keep any old preview.
            tmp_path.unlink(missing_ok=True)

    if n_written == 0:
        print(f" WARNING: No detection frames collected from {svo_name}")
        return None
    tmp_path.replace(video_path)  # atomic swap; replaces an existing preview
    print(f" ✓ Labeled video: {video_path.name} ({n_written} frames, {len(per_frame)} PLY labels)")
    return video_path
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Generate labeled preview video from SVO + weight JSON")
    add = parser.add_argument
    add("--svo", required=True, help="Path to .svo2 file")
    add("--save-output", required=True, help="Output directory for this SVO")
    add("--weight-json", required=True, help="Path to weight_prediction.json or DGCNN output JSON")
    add("--yolo-model",
        default="/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt")
    add("--conf", type=float, default=0.25)
    add("--imgsz", type=int, default=640)
    add("--frame-stride", type=int, default=1)
    add("--sam-device", type=str, default="cuda")
    add("--show-large-labels-at-top-right", action="store_true")
    add(
        "--summary-star",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Whether to draw * on the Final summary line; caller/DB is the source of truth.",
    )
    add("--summary-weight-g", type=float, default=None)
    add("--summary-length-mm", type=float, default=None)
    add("--output-video-name", type=str, default=None)
    return parser


def main():
    """CLI entry point: validate the input paths, then render the labeled video.

    Exits with a message via ``SystemExit`` when the SVO or weight JSON is
    missing. Otherwise it delegates to ``generate_video`` and ignores its return
    value.
    """
    args = _build_arg_parser().parse_args()
    svo, output_dir, wjson = (
        Path(raw).expanduser().resolve()
        for raw in (args.svo, args.save_output, args.weight_json)
    )
    if not svo.exists():
        raise SystemExit(f"SVO not found: {svo}")
    if not wjson.exists():
        raise SystemExit(f"Weight JSON not found: {wjson}")
    generate_video(
        svo_path=svo,
        output_dir=output_dir,
        weight_json=wjson,
        yolo_model_path=args.yolo_model,
        conf=args.conf,
        imgsz=args.imgsz,
        frame_stride=args.frame_stride,
        sam_device=args.sam_device,
        show_large=args.show_large_labels_at_top_right,
        summary_weight_g=args.summary_weight_g,
        summary_length_mm=args.summary_length_mm,
        summary_star=bool(args.summary_star),
        output_video_name=args.output_video_name,
    )


if __name__ == "__main__":
    main()