Files
FishServer/FishMeasure/generate_video_with_labels.py

393 lines
14 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
"""Generate labeled preview video from SVO + weight prediction JSON.

Per-frame labeling: each detection box shows the weight/length predicted
for that specific frame's PLY (from ``per_cloud`` / ``per_file`` in the
DGCNN JSON). Frames without a corresponding PLY carry forward the last
known value. The final aggregated result is shown at top-right.

Called by ``predict_weigth_from_svo2.py`` after DGCNN completes.
Replaces any existing preview video so the final published file has labels.
"""
from __future__ import annotations
import argparse
import json
import math
2026-04-14 22:05:52 +08:00
import re
import sys
from pathlib import Path
2026-04-14 22:05:52 +08:00
from typing import Any, Dict, List, Optional, Tuple
import cv2
import numpy as np
try:
import pyzed.sl as sl
ZED_AVAILABLE = True
except ImportError:
ZED_AVAILABLE = False
def _open_video_writer(path: Path, fps: float, size: Tuple[int, int]) -> cv2.VideoWriter:
    """Return a ``cv2.VideoWriter`` for *path*, trying hardware H.264 first.

    A GStreamer pipeline ending in ``nvv4l2h264enc`` (NVENC on Jetson) is
    attempted when this OpenCV build exposes ``CAP_GSTREAMER``. If that writer
    does not open, or constructing it raises, a plain ``mp4v`` writer is
    returned instead. The fallback writer is returned without checking
    ``isOpened()``; that is left to the caller.

    Args:
        path: Destination video file.
        fps: Frame rate passed to the writer.
        size: ``(width, height)`` of the frames that will be written.
    """
    width, height = size
    dims = (width, height)

    try:
        if hasattr(cv2, "CAP_GSTREAMER"):
            # Escape double quotes so the path survives inside location="...".
            escaped = str(path).replace('"', '\\"')
            stages = [
                "appsrc",
                "videoconvert",
                "video/x-raw,format=BGRx",
                "nvvidconv",
                "video/x-raw(memory:NVMM)",
                "nvv4l2h264enc bitrate=4000000",
                "h264parse",
                "mp4mux",
                f'filesink location="{escaped}"',
            ]
            pipeline = " ! ".join(stages)
            hw_writer = cv2.VideoWriter(pipeline, cv2.CAP_GSTREAMER, 0, fps, dims)
            if hw_writer.isOpened():
                return hw_writer
            hw_writer.release()
    except Exception:
        # Best effort only: any failure here falls through to the software writer.
        pass

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    return cv2.VideoWriter(str(path), fourcc, fps, dims)
2026-04-14 22:05:52 +08:00
def _parse_weight_json(weight_json: Path) -> Tuple[
    Dict[int, Tuple[float, float]],
    Optional[float],
    Optional[float],
    bool,
]:
    """Parse weight JSON → per-frame map + summary + confidence.

    The file is read as UTF-8 JSON. Malformed *structure* (a top-level value
    that is not an object, a summary that is not an object, per-cloud entries
    that are not objects) is treated as "no data" rather than raising, so the
    caller can still render a video with ``--`` placeholders. Invalid JSON
    syntax and unreadable files still raise from ``read_text`` / ``json.loads``.

    Returns:
        per_frame: {frame_number: (weight_g, length_mm)} from per_cloud/per_file.
            ``length_mm`` is NaN when the entry has no usable ``length_input``.
            Entries without a usable ``predicted_weight_g`` or without
            ``frame_<digits>`` in the PLY file stem are skipped.
        summary_weight_g: final aggregated weight, or None
        summary_length_mm: final aggregated length, or None
        is_confident: True when ``*`` should be shown (avg > 440g OR length band fraction >= 25%)
    """
    data = json.loads(weight_json.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        data = {}

    # First non-empty summary object wins; non-dict values are ignored.
    summary: Dict[str, Any] = next(
        (
            candidate
            for candidate in (
                data.get(key) for key in ("dgcnn_summary", "weight_summary", "summary")
            )
            if isinstance(candidate, dict) and candidate
        ),
        {},
    )

    def _first_finite(*candidates):
        """Return the first candidate that converts to a finite float, else None."""
        for c in candidates:
            if c is None:
                continue
            try:
                v = float(c)
            except (TypeError, ValueError, OverflowError):
                # OverflowError: JSON integers too large for a float.
                continue
            if math.isfinite(v):
                return v
        return None

    summary_wg = _first_finite(
        summary.get("pred_weight_g"),
        summary.get("avg_predicted_weight_g"),
        data.get("pred_weight_g"),
        data.get("avg_predicted_weight_g"),
    )
    summary_lmm = _first_finite(
        summary.get("avg_length_input_topk"),
        summary.get("avg_length_input"),
        data.get("avg_length_input"),
    )

    CONFIDENT_AVG_G = 440.0
    MIN_FRAC_LARGEST_LENGTH_GROUP = 0.25

    mean_g = _first_finite(
        summary.get("mean_all_pred_g_after_filters"),
        summary.get("avg_predicted_weight_g"),
    )
    frac = _first_finite(summary.get("fraction_in_near_max_length_band"))
    is_confident = (mean_g is not None and mean_g > CONFIDENT_AVG_G) or (
        frac is not None and frac >= MIN_FRAC_LARGEST_LENGTH_GROUP
    )

    entries = data.get("per_cloud") or data.get("per_file") or []
    if not isinstance(entries, list):
        entries = []

    frame_pattern = re.compile(r"frame_(\d+)")
    per_frame: Dict[int, Tuple[float, float]] = {}
    for item in entries:
        if not isinstance(item, dict):
            continue
        m = frame_pattern.search(Path(str(item.get("ply", ""))).stem)
        if not m:
            continue
        wg = _first_finite(item.get("predicted_weight_g"))
        if wg is None:
            continue
        lmm = _first_finite(item.get("length_input"))
        # A later entry for the same frame number overrides an earlier one.
        per_frame[int(m.group(1))] = (wg, lmm if lmm is not None else float("nan"))

    return per_frame, summary_wg, summary_lmm, is_confident
def _draw_label_on_box(
    image: np.ndarray,
    box: np.ndarray,
    tid: int,
    class_name: str,
    weight_g: Optional[float],
    length_mm: Optional[float],
) -> None:
    """Draw one detection box and its text label onto *image* in place.

    The label reads ``ID:<tid> <class_name> weight: <w> len: <l>``; a value
    that is ``None`` or non-finite is rendered as ``--g`` / ``--mm``.

    The label normally sits directly above the box. When the box is so close
    to the top edge that the label would start above row 0 (and therefore be
    clipped), it is drawn just inside the top of the box instead.

    Args:
        image: Image to draw on (modified in place).
        box: Four values ``x1, y1, x2, y2``; each is truncated to ``int``.
        tid: Track id shown in the label.
        class_name: Class name shown in the label.
        weight_g: Weight to display, or None.
        length_mm: Length to display, or None.
    """
    x1, y1, x2, y2 = map(int, box)
    cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)

    w_str = f"{weight_g:.0f}g" if weight_g is not None and math.isfinite(weight_g) else "--g"
    l_str = f"{length_mm:.0f}mm" if length_mm is not None and math.isfinite(length_mm) else "--mm"
    label = f"ID:{tid} {class_name} weight: {w_str} len: {l_str}"

    (tw, th), bl = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1)
    label_h = th + bl + 5

    # Default placement: background strip ends at the box's top edge.
    strip_top = y1 - label_h
    if strip_top < 0:
        # Not enough room above the box; move the strip inside the box.
        strip_top = y1
    strip_bottom = strip_top + label_h

    cv2.rectangle(image, (x1, strip_top), (x1 + tw, strip_bottom), (0, 255, 0), -1)
    cv2.putText(image, label, (x1, strip_bottom - bl - 2),
                cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 1, cv2.LINE_AA)
def _draw_large_summary(
    image: np.ndarray,
    weight_g: Optional[float],
    length_mm: Optional[float],
    is_confident: bool = False,
) -> None:
    """Draw the two-line final summary at the top-right of *image* in place.

    Line 1 is ``Final: <weight>g`` (followed by `` *`` when *is_confident*),
    line 2 is ``Length: <length>mm``. A value that is ``None`` or non-finite
    is rendered as ``--``, matching how ``_draw_label_on_box`` treats such
    values; no star is shown when the weight is missing.

    Args:
        image: Image to draw on (modified in place).
        weight_g: Final weight to display, or None.
        length_mm: Final length to display, or None.
        is_confident: Whether to append `` *`` to the weight line.
    """
    star = " *" if is_confident else ""
    has_weight = weight_g is not None and math.isfinite(weight_g)
    has_length = length_mm is not None and math.isfinite(length_mm)
    lines = [
        f"Final: {weight_g:.0f}g{star}" if has_weight else "Final: --g",
        f"Length: {length_mm:.0f}mm" if has_length else "Length: --mm",
    ]

    font = cv2.FONT_HERSHEY_SIMPLEX
    scale = 2.0
    thickness = 2
    pad = 10
    margin = 20
    h, w = image.shape[:2]

    sizes = [cv2.getTextSize(ln, font, scale, thickness) for ln in lines]
    max_tw = max(s[0][0] for s in sizes)
    total_h = sum(s[0][1] + s[1] + pad for s in sizes)
    x0 = w - max_tw - margin
    y0 = 30

    # Blend a black panel behind the text so it stays readable on any frame.
    overlay = image.copy()
    cv2.rectangle(overlay, (x0 - pad, y0 - pad),
                  (w - margin + pad, y0 + total_h + pad), (0, 0, 0), -1)
    cv2.addWeighted(overlay, 0.7, image, 0.3, 0, image)

    y = y0
    for ln, ((_tw, th), bl) in zip(lines, sizes):
        cv2.putText(image, ln, (x0, y + th), font, scale, (0, 255, 255), thickness, cv2.LINE_AA)
        y += th + bl + pad
def generate_video(
    svo_path: Path,
    output_dir: Path,
    weight_json: Path,
    yolo_model_path: str,
    conf: float = 0.25,
    imgsz: int = 640,
    frame_stride: int = 1,
    show_large: bool = False,
    summary_weight_g: Optional[float] = None,
    summary_length_mm: Optional[float] = None,
    summary_star: bool = False,
    output_video_name: Optional[str] = None,
    sam_device: str = "cuda",
) -> Optional[Path]:
    """Render a side-by-side detection / segmentation preview video for one SVO.

    Every processed frame is run through YOLO tracking; frames with at least
    one detection are composed as ``[labeled detections | SAM overlay]`` and
    written to ``<output_dir>/images/<output_video_name>`` at 10 fps. Box
    labels use the per-frame prediction from *weight_json* when one exists for
    that frame number (1-based), otherwise the last known values.

    Args:
        svo_path: SVO recording to read.
        output_dir: Directory for this SVO; the video goes in its ``images`` subfolder.
        weight_json: JSON parsed by ``_parse_weight_json``.
        yolo_model_path: Weights file passed to ``ultralytics.YOLO``.
        conf: Detection confidence threshold passed to ``yolo.track``.
        imgsz: Inference image size passed to ``yolo.track``.
        frame_stride: When > 1, only every ``frame_stride``-th frame is processed.
        show_large: Draw the top-right summary even without a summary weight.
        summary_weight_g: Overrides the summary weight parsed from the JSON.
        summary_length_mm: Overrides the summary length parsed from the JSON.
        summary_star: Whether the summary line carries `` *``.
        output_video_name: File name of the video; defaults to ``<svo stem>_preview.mp4``.
        sam_device: Device requested for the SAM predictor.

    Returns:
        Path of the written video, or None when pyzed is unavailable, the SVO
        cannot be opened, no frame had detections, or the video writer could
        not be opened.
    """
    if not ZED_AVAILABLE:
        print("ERROR: pyzed not available, cannot generate labeled video")
        return None

    per_frame, parsed_summary_wg, parsed_summary_lmm, raw_confident = _parse_weight_json(weight_json)
    if summary_weight_g is None:
        summary_weight_g = parsed_summary_wg
    if summary_length_mm is None:
        summary_length_mm = parsed_summary_lmm
    star_s = " *" if summary_star else ""

    print(f" Per-frame predictions: {len(per_frame)} PLYs mapped")
    print(
        f" Summary: weight={summary_weight_g}g, length={summary_length_mm}mm{star_s} "
        f"(raw_confident={raw_confident})"
    )
    if not per_frame and summary_weight_g is None:
        print(" WARNING: No weight data in JSON, video will show '--'")

    # Heavy / project-local dependencies are imported lazily.
    from ultralytics import YOLO
    yolo = YOLO(yolo_model_path)
    class_names = yolo.names if hasattr(yolo, "names") else {}

    from fish_video_weight_evaluation import (
        create_segmentation_overlay,
        load_sam_predictor_with_fallback,
        segment_with_sam,
    )
    sam_predictor, sam_torch_device = load_sam_predictor_with_fallback(sam_device)

    from dataset.zed_reader import ZEDReader
    zed_reader = ZEDReader(svo_path=str(svo_path), camera_mode=False, use_yolo_detector=False)
    if not zed_reader.open():
        print(f" ERROR: Failed to open {svo_path.name}")
        return None

    svo_name = svo_path.stem
    images_dir = Path(output_dir) / "images"

    # NOTE(review): all composed frames are buffered in memory before encoding;
    # for long recordings this may be large — consider streaming to the writer.
    frames: List[np.ndarray] = []
    idx = 0
    last_wg: Optional[float] = None
    last_lmm: Optional[float] = None

    # Everything after a successful open() runs under try/finally so the
    # reader is closed even if setup or per-frame processing raises.
    try:
        runtime = sl.RuntimeParameters()
        left_mat = sl.Mat()
        images_dir.mkdir(parents=True, exist_ok=True)

        while True:
            err = zed_reader.zed.grab(runtime)
            if err != sl.ERROR_CODE.SUCCESS:
                break

            # Skip stride-excluded frames before retrieving/copying the image.
            if frame_stride > 1 and (idx % frame_stride) != 0:
                idx += 1
                continue

            zed_reader.zed.retrieve_image(left_mat, sl.VIEW.LEFT)
            left_np = left_mat.get_data()
            # Drop any channel beyond the first three.
            img = left_np[:, :, :3].copy() if left_np.shape[2] > 3 else left_np.copy()

            frame_number = idx + 1
            frame_name = f"frame_{frame_number:06d}"

            if frame_number in per_frame:
                cur_wg, cur_lmm = per_frame[frame_number]
                last_wg = cur_wg
                # A NaN length must not overwrite the last known good one.
                last_lmm = cur_lmm if math.isfinite(cur_lmm) else last_lmm
            else:
                cur_wg = last_wg
                cur_lmm = last_lmm

            results = yolo.track(img, conf=conf, imgsz=imgsz, verbose=False, persist=True)[0]
            num_dets = len(results.boxes) if results.boxes is not None else 0

            left_disp = img.copy()
            right_disp = img.copy()

            if num_dets > 0:
                boxes = results.boxes.xyxy.cpu().numpy()
                tids = (results.boxes.id.cpu().numpy().astype(int)
                        if hasattr(results.boxes, "id") and results.boxes.id is not None
                        else np.zeros(len(boxes), dtype=int))
                cls_ids = (results.boxes.cls.cpu().numpy().astype(int)
                           if results.boxes.cls is not None
                           else np.zeros(len(boxes), dtype=int))

                for i, box in enumerate(boxes):
                    tid = int(tids[i]) if i < len(tids) else 0
                    cid = int(cls_ids[i]) if i < len(cls_ids) else 0
                    cname = class_names.get(cid, "fish")
                    _draw_label_on_box(left_disp, box, tid, cname, cur_wg, cur_lmm)

                try:
                    masks = segment_with_sam(sam_predictor, img, boxes, sam_torch_device)
                except Exception as e:
                    # Segmentation is best effort; the detection pane is still useful.
                    print(f" WARNING: SAM segmentation failed on {frame_name}: {e}")
                    masks = []

                if masks:
                    right_disp = create_segmentation_overlay(img.copy(), masks)
                    cv2.putText(right_disp, "Segmentation", (10, right_disp.shape[0] - 20),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
                else:
                    cv2.putText(right_disp, "Segmentation (failed)", (10, right_disp.shape[0] - 20),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2, cv2.LINE_AA)
            else:
                cv2.putText(right_disp, "No detections", (10, right_disp.shape[0] - 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (128, 128, 128), 2, cv2.LINE_AA)

            if show_large or summary_weight_g is not None:
                _draw_large_summary(left_disp, summary_weight_g, summary_length_mm, summary_star)

            info = f"[{frame_number}] {frame_name} | Detections: {num_dets}"
            cv2.putText(left_disp, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
            cv2.putText(left_disp, "Detection", (10, left_disp.shape[0] - 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)

            combined = np.hstack([left_disp, right_disp])
            # Only frames with at least one detection go into the video.
            if num_dets > 0:
                frames.append(combined)

            if idx % 30 == 0:
                # Non-finite values are logged as "--" rather than "nan".
                w_s = f"{cur_wg:.0f}g" if cur_wg is not None and math.isfinite(cur_wg) else "--"
                l_s = f"{cur_lmm:.0f}mm" if cur_lmm is not None and math.isfinite(cur_lmm) else "--"
                print(f" [{frame_number}] {frame_name} dets={num_dets} w={w_s} l={l_s} collected={len(frames)}")

            idx += 1
    finally:
        zed_reader.close()

    if not frames:
        print(f" WARNING: No detection frames collected from {svo_name}")
        return None

    video_path = images_dir / (output_video_name or f"{svo_name}_preview.mp4")
    h, w = frames[0].shape[:2]
    writer = _open_video_writer(video_path, 10.0, (w, h))
    if not writer.isOpened():
        # Without this check a failed writer would silently drop every frame
        # and the function would still report success.
        writer.release()
        print(f" ERROR: Failed to open video writer for {video_path.name}")
        return None
    try:
        for f in frames:
            writer.write(f)
    finally:
        writer.release()

    print(f" ✓ Labeled video: {video_path.name} ({len(frames)} frames, {len(per_frame)} PLY labels)")
    return video_path
def main():
    """Command-line entry point: parse arguments, validate inputs, build the video.

    Exits via ``SystemExit`` with a message when the SVO file or the weight
    JSON does not exist. The return value of ``generate_video`` is ignored.
    """
    cli = argparse.ArgumentParser(description="Generate labeled preview video from SVO + weight JSON")
    add = cli.add_argument

    # Required inputs / outputs.
    add("--svo", required=True, help="Path to .svo2 file")
    add("--save-output", required=True, help="Output directory for this SVO")
    add("--weight-json", required=True, help="Path to weight_prediction.json or DGCNN output JSON")

    # Detection / segmentation settings.
    add("--yolo-model",
        default="/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt")
    add("--conf", type=float, default=0.25)
    add("--imgsz", type=int, default=640)
    add("--frame-stride", type=int, default=1)
    add("--sam-device", type=str, default="cuda")

    # Summary overlay settings.
    add("--show-large-labels-at-top-right", action="store_true")
    add(
        "--summary-star",
        action="store_true",
        help="Whether to draw * on the Final summary line; caller/DB is the source of truth.",
    )
    add("--summary-weight-g", type=float, default=None)
    add("--summary-length-mm", type=float, default=None)
    add("--output-video-name", type=str, default=None)

    opts = cli.parse_args()

    def _absolute(raw: str) -> Path:
        return Path(raw).expanduser().resolve()

    svo_file = _absolute(opts.svo)
    target_dir = _absolute(opts.save_output)
    json_file = _absolute(opts.weight_json)

    if not svo_file.exists():
        raise SystemExit(f"SVO not found: {svo_file}")
    if not json_file.exists():
        raise SystemExit(f"Weight JSON not found: {json_file}")

    generate_video(
        svo_path=svo_file,
        output_dir=target_dir,
        weight_json=json_file,
        yolo_model_path=opts.yolo_model,
        conf=opts.conf,
        imgsz=opts.imgsz,
        frame_stride=opts.frame_stride,
        sam_device=opts.sam_device,
        show_large=opts.show_large_labels_at_top_right,
        summary_weight_g=opts.summary_weight_g,
        summary_length_mm=opts.summary_length_mm,
        summary_star=bool(opts.summary_star),
        output_video_name=opts.output_video_name,
    )


if __name__ == "__main__":
    main()