#!/usr/bin/env python3 """Generate labeled preview video from SVO + weight prediction JSON. Per-frame labeling: each detection box shows the weight/length predicted for that specific frame's PLY (from ``per_cloud`` / ``per_file`` in the DGCNN JSON). Frames without a corresponding PLY carry forward the last known value. The final aggregated result is shown at top-right. Called by ``predict_weigth_from_svo2.py`` after DGCNN completes. Replaces any existing preview video so the final published file has labels. """ from __future__ import annotations import argparse import json import math import re import sys from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import cv2 import numpy as np try: import pyzed.sl as sl ZED_AVAILABLE = True except ImportError: ZED_AVAILABLE = False def _open_video_writer(path: Path, fps: float, size: Tuple[int, int]) -> cv2.VideoWriter: """Open a VideoWriter, preferring GStreamer NVENC on Jetson for hardware H.264.""" w, h = size try: if hasattr(cv2, "CAP_GSTREAMER"): loc = str(path).replace('"', '\\"') gst_pipe = ( f'appsrc ! videoconvert ! video/x-raw,format=BGRx ! ' f'nvvidconv ! video/x-raw(memory:NVMM) ! ' f'nvv4l2h264enc bitrate=4000000 ! h264parse ! ' f'mp4mux ! filesink location="{loc}"' ) writer = cv2.VideoWriter(gst_pipe, cv2.CAP_GSTREAMER, 0, fps, (w, h)) if writer.isOpened(): return writer writer.release() except Exception: pass return cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) def _parse_weight_json(weight_json: Path) -> Tuple[ Dict[int, Tuple[float, float]], Optional[float], Optional[float], bool, ]: """Parse weight JSON → per-frame map + summary + confidence. Returns: per_frame: {frame_number: (weight_g, length_mm)} from per_cloud/per_file summary_weight_g: final aggregated weight summary_length_mm: final aggregated length is_confident: True when ``*`` should be shown (avg > 440g OR length band fraction >= 25%) """ data = json.loads(weight_json.read_text(encoding="utf-8")) summary = data.get("dgcnn_summary") or data.get("weight_summary") or data.get("summary") or {} def _first_finite(*candidates): for c in candidates: if c is not None: try: v = float(c) if math.isfinite(v): return v except (TypeError, ValueError): pass return None summary_wg = _first_finite( summary.get("pred_weight_g"), summary.get("avg_predicted_weight_g"), data.get("pred_weight_g"), data.get("avg_predicted_weight_g"), ) summary_lmm = _first_finite( summary.get("avg_length_input_topk"), summary.get("avg_length_input"), data.get("avg_length_input"), ) CONFIDENT_AVG_G = 440.0 MIN_FRAC_LARGEST_LENGTH_GROUP = 0.25 mean_g = _first_finite( summary.get("mean_all_pred_g_after_filters"), summary.get("avg_predicted_weight_g"), ) frac = _first_finite(summary.get("fraction_in_near_max_length_band")) is_confident = False if mean_g is not None and mean_g > CONFIDENT_AVG_G: is_confident = True elif frac is not None and frac >= MIN_FRAC_LARGEST_LENGTH_GROUP: is_confident = True per_frame: Dict[int, Tuple[float, float]] = {} for item in data.get("per_cloud") or data.get("per_file") or []: ply = item.get("ply", "") m = re.search(r"frame_(\d+)", Path(str(ply)).stem) if not m: continue fnum = int(m.group(1)) wg = _first_finite(item.get("predicted_weight_g")) lmm = _first_finite(item.get("length_input")) if wg is not None: per_frame[fnum] = (wg, lmm if lmm is not None else float("nan")) return per_frame, summary_wg, summary_lmm, is_confident def _draw_label_on_box( image: np.ndarray, box: np.ndarray, tid: int, class_name: str, weight_g: Optional[float], length_mm: Optional[float], ) -> None: x1, y1, x2, y2 = map(int, box) cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2) w_str = f"{weight_g:.0f}g" if weight_g is not None and math.isfinite(weight_g) else "--g" l_str = f"{length_mm:.0f}mm" if length_mm is not None and math.isfinite(length_mm) else "--mm" label = f"ID:{tid} {class_name} weight: {w_str} len: {l_str}" (tw, th), bl = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1) cv2.rectangle(image, (x1, y1 - th - bl - 5), (x1 + tw, y1), (0, 255, 0), -1) cv2.putText(image, label, (x1, y1 - bl - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 1, cv2.LINE_AA) def _draw_large_summary( image: np.ndarray, weight_g: Optional[float], length_mm: Optional[float], is_confident: bool = False, ) -> None: star = " *" if is_confident else "" lines = [] lines.append(f"Final: {weight_g:.0f}g{star}" if weight_g is not None else f"Final: --g") lines.append(f"Length: {length_mm:.0f}mm" if length_mm is not None else "Length: --mm") font = cv2.FONT_HERSHEY_SIMPLEX scale = 2.0 thickness = 2 pad = 10 h, w = image.shape[:2] margin = 20 sizes = [cv2.getTextSize(ln, font, scale, thickness) for ln in lines] max_tw = max(s[0][0] for s in sizes) total_h = sum(s[0][1] + s[1] + pad for s in sizes) x0 = w - max_tw - margin y0 = 30 overlay = image.copy() cv2.rectangle(overlay, (x0 - pad, y0 - pad), (w - margin + pad, y0 + total_h + pad), (0, 0, 0), -1) cv2.addWeighted(overlay, 0.7, image, 0.3, 0, image) y = y0 for i, ln in enumerate(lines): (tw, th), bl = sizes[i] cv2.putText(image, ln, (x0, y + th), font, scale, (0, 255, 255), thickness, cv2.LINE_AA) y += th + bl + pad def generate_video( svo_path: Path, output_dir: Path, weight_json: Path, yolo_model_path: str, conf: float = 0.25, imgsz: int = 640, frame_stride: int = 1, show_large: bool = False, summary_weight_g: Optional[float] = None, summary_length_mm: Optional[float] = None, summary_star: bool = False, output_video_name: Optional[str] = None, sam_device: str = "cuda", ) -> Optional[Path]: if not ZED_AVAILABLE: print("ERROR: pyzed not available, cannot generate labeled video") return None per_frame, parsed_summary_wg, parsed_summary_lmm, raw_confident = _parse_weight_json(weight_json) if summary_weight_g is None: summary_weight_g = parsed_summary_wg if summary_length_mm is None: summary_length_mm = parsed_summary_lmm star_s = " *" if summary_star else "" print(f" Per-frame predictions: {len(per_frame)} PLYs mapped") print( f" Summary: weight={summary_weight_g}g, length={summary_length_mm}mm{star_s} " f"(raw_confident={raw_confident})" ) if not per_frame and summary_weight_g is None: print(" WARNING: No weight data in JSON, video will show '--'") from ultralytics import YOLO yolo = YOLO(yolo_model_path) class_names = yolo.names if hasattr(yolo, "names") else {} from fish_video_weight_evaluation import ( create_segmentation_overlay, load_sam_predictor_with_fallback, segment_with_sam, ) sam_predictor, eff_sam_device = load_sam_predictor_with_fallback(sam_device) sam_torch_device = eff_sam_device from dataset.zed_reader import ZEDReader zed_reader = ZEDReader(svo_path=str(svo_path), camera_mode=False, use_yolo_detector=False) if not zed_reader.open(): print(f" ERROR: Failed to open {svo_path.name}") return None runtime = sl.RuntimeParameters() left_mat = sl.Mat() svo_name = svo_path.stem output_dir = Path(output_dir) images_dir = output_dir / "images" images_dir.mkdir(parents=True, exist_ok=True) frames: List[np.ndarray] = [] idx = 0 last_wg: Optional[float] = None last_lmm: Optional[float] = None try: while True: err = zed_reader.zed.grab(runtime) if err != sl.ERROR_CODE.SUCCESS: break zed_reader.zed.retrieve_image(left_mat, sl.VIEW.LEFT) left_np = left_mat.get_data() img = left_np[:, :, :3].copy() if left_np.shape[2] > 3 else left_np.copy() if frame_stride > 1 and (idx % frame_stride) != 0: idx += 1 continue frame_number = idx + 1 frame_name = f"frame_{frame_number:06d}" if frame_number in per_frame: cur_wg, cur_lmm = per_frame[frame_number] last_wg = cur_wg last_lmm = cur_lmm if math.isfinite(cur_lmm) else last_lmm else: cur_wg = last_wg cur_lmm = last_lmm results = yolo.track(img, conf=conf, imgsz=imgsz, verbose=False, persist=True)[0] num_dets = len(results.boxes) if results.boxes is not None else 0 left_disp = img.copy() right_disp = img.copy() if num_dets > 0: boxes = results.boxes.xyxy.cpu().numpy() tids = (results.boxes.id.cpu().numpy().astype(int) if hasattr(results.boxes, "id") and results.boxes.id is not None else np.zeros(len(boxes), dtype=int)) cls_ids = (results.boxes.cls.cpu().numpy().astype(int) if results.boxes.cls is not None else np.zeros(len(boxes), dtype=int)) for i, box in enumerate(boxes): tid = int(tids[i]) if i < len(tids) else 0 cid = int(cls_ids[i]) if i < len(cls_ids) else 0 cname = class_names.get(cid, "fish") _draw_label_on_box(left_disp, box, tid, cname, cur_wg, cur_lmm) try: masks = segment_with_sam(sam_predictor, img, boxes, sam_torch_device) except Exception as e: print(f" WARNING: SAM segmentation failed on {frame_name}: {e}") masks = [] if masks: right_disp = create_segmentation_overlay(img.copy(), masks) cv2.putText(right_disp, "Segmentation", (10, right_disp.shape[0] - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA) else: cv2.putText(right_disp, "Segmentation (failed)", (10, right_disp.shape[0] - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2, cv2.LINE_AA) else: cv2.putText(right_disp, "No detections", (10, right_disp.shape[0] - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (128, 128, 128), 2, cv2.LINE_AA) if show_large or summary_weight_g is not None: _draw_large_summary(left_disp, summary_weight_g, summary_length_mm, summary_star) info = f"[{frame_number}] {frame_name} | Detections: {num_dets}" cv2.putText(left_disp, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA) cv2.putText(left_disp, "Detection", (10, left_disp.shape[0] - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA) combined = np.hstack([left_disp, right_disp]) if num_dets > 0: frames.append(combined) if idx % 30 == 0: w_s = f"{cur_wg:.0f}g" if cur_wg is not None else "--" l_s = f"{cur_lmm:.0f}mm" if cur_lmm is not None else "--" print(f" [{frame_number}] {frame_name} dets={num_dets} w={w_s} l={l_s} collected={len(frames)}") idx += 1 finally: zed_reader.close() if not frames: print(f" WARNING: No detection frames collected from {svo_name}") return None video_path = images_dir / (output_video_name or f"{svo_name}_preview.mp4") h, w = frames[0].shape[:2] writer = _open_video_writer(video_path, 10.0, (w, h)) for f in frames: writer.write(f) writer.release() print(f" ✓ Labeled video: {video_path.name} ({len(frames)} frames, {len(per_frame)} PLY labels)") return video_path def main(): parser = argparse.ArgumentParser(description="Generate labeled preview video from SVO + weight JSON") parser.add_argument("--svo", required=True, help="Path to .svo2 file") parser.add_argument("--save-output", required=True, help="Output directory for this SVO") parser.add_argument("--weight-json", required=True, help="Path to weight_prediction.json or DGCNN output JSON") parser.add_argument("--yolo-model", default="/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt") parser.add_argument("--conf", type=float, default=0.25) parser.add_argument("--imgsz", type=int, default=640) parser.add_argument("--frame-stride", type=int, default=1) parser.add_argument("--sam-device", type=str, default="cuda") parser.add_argument("--show-large-labels-at-top-right", action="store_true") parser.add_argument( "--summary-star", action="store_true", default=False, help="Whether to draw * on the Final summary line; caller/DB is the source of truth.", ) parser.add_argument("--summary-weight-g", type=float, default=None) parser.add_argument("--summary-length-mm", type=float, default=None) parser.add_argument("--output-video-name", type=str, default=None) args = parser.parse_args() svo = Path(args.svo).expanduser().resolve() output_dir = Path(args.save_output).expanduser().resolve() wjson = Path(args.weight_json).expanduser().resolve() if not svo.exists(): raise SystemExit(f"SVO not found: {svo}") if not wjson.exists(): raise SystemExit(f"Weight JSON not found: {wjson}") generate_video( svo_path=svo, output_dir=output_dir, weight_json=wjson, yolo_model_path=args.yolo_model, conf=args.conf, imgsz=args.imgsz, frame_stride=args.frame_stride, sam_device=args.sam_device, show_large=args.show_large_labels_at_top_right, summary_weight_g=args.summary_weight_g, summary_length_mm=args.summary_length_mm, summary_star=bool(args.summary_star), output_video_name=args.output_video_name, ) if __name__ == "__main__": main()