Files
FishServer/FishMeasure/generate_video_with_labels.py

275 lines
9.4 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""Generate labeled preview video from SVO + weight prediction JSON.
Reads the SVO with YOLO tracking, overlays DGCNN weight/length on each
detection box, and writes ``<svo_name>_preview.mp4`` into ``<--save-output>/images``.
Called by ``predict_weigth_from_svo2.py`` after DGCNN completes.
Replaces any existing preview video so the final published file has labels.
"""
from __future__ import annotations
import argparse
import json
import math
import sys
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
import cv2
import numpy as np
# The ZED SDK bindings are optional at import time: record whether they loaded
# so generate_video() can report the problem instead of failing on import.
try:
    import pyzed.sl as sl
    ZED_AVAILABLE = True
except ImportError:
    ZED_AVAILABLE = False
def _first_finite(candidates) -> Optional[float]:
    """Return the first candidate that converts to a finite float, else ``None``."""
    for candidate in candidates:
        if candidate is None:
            continue
        try:
            value = float(candidate)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: JSON integers too large for a float.
            continue
        if math.isfinite(value):
            return value
    return None


def _extract_weight_length(weight_json: Path) -> Tuple[Optional[float], Optional[float]]:
    """Return ``(weight_g, length_mm)`` from a weight prediction JSON.

    The summary section is the first non-empty mapping found under
    ``dgcnn_summary``, ``weight_summary`` or ``summary``; top-level keys are
    used as a fallback.  For each quantity the first candidate that parses to
    a finite float wins; ``None`` is returned for a quantity when no candidate
    qualifies.

    Args:
        weight_json: Path to the UTF-8 encoded JSON file.

    Returns:
        Tuple ``(weight_g, length_mm)``; either element may be ``None``.

    Raises:
        OSError: If the file cannot be read.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    data = json.loads(weight_json.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        # A list/scalar root carries none of the expected keys.
        data = {}

    # First truthy *mapping* wins; a stray string/list under one of these keys
    # must not break the lookup of the others.
    summary = next(
        (
            section
            for section in (
                data.get("dgcnn_summary"),
                data.get("weight_summary"),
                data.get("summary"),
            )
            if isinstance(section, dict) and section
        ),
        {},
    )

    weight_g = _first_finite(
        [
            summary.get("pred_weight_g"),
            summary.get("avg_predicted_weight_g"),
            data.get("pred_weight_g"),
            data.get("avg_predicted_weight_g"),
        ]
    )
    length_mm = _first_finite(
        [
            summary.get("avg_length_input_topk"),
            summary.get("avg_length_input"),
            data.get("avg_length_input"),
        ]
    )
    return weight_g, length_mm
def _draw_label_on_box(
    image: np.ndarray,
    box: np.ndarray,
    tid: int,
    class_name: str,
    weight_g: Optional[float],
    length_mm: Optional[float],
) -> None:
    """Draw a detection rectangle and its text tag onto ``image`` in place.

    The tag reads ``ID:<tid> <class_name> weight: <w> len: <l>`` with ``--``
    shown for a missing weight or length.  It normally sits directly above the
    box; when that position would fall outside the frame (box touching the
    top edge, or tag running past the right edge) it is moved just inside the
    frame so the text stays readable.

    Args:
        image: Frame to draw on; modified in place.
        box: Four values ``x1, y1, x2, y2``; each is truncated to ``int``.
        tid: Track id shown in the tag.
        class_name: Class name shown in the tag.
        weight_g: Weight to display, or ``None``.
        length_mm: Length to display, or ``None``.
    """
    x1, y1, x2, y2 = map(int, box)
    cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)

    w_str = f"{weight_g:.0f}g" if weight_g is not None else "--g"
    l_str = f"{length_mm:.0f}mm" if length_mm is not None else "--mm"
    label = f"ID:{tid} {class_name} weight: {w_str} len: {l_str}"
    (tw, th), bl = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1)

    tag_height = th + bl + 5
    frame_width = image.shape[1]

    # Preferred position: directly above the box.  If that leaves the frame,
    # hang the tag from the box's top edge (inside the box) instead.
    tag_top = y1 - tag_height
    if tag_top < 0:
        tag_top = max(y1, 0)
    tag_bottom = tag_top + tag_height
    # Slide left when the tag would run past the right edge; never below x=0.
    tag_left = max(0, min(x1, frame_width - tw))

    cv2.rectangle(image, (tag_left, tag_top), (tag_left + tw, tag_bottom), (0, 255, 0), -1)
    cv2.putText(image, label, (tag_left, tag_bottom - bl - 2),
                cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 1, cv2.LINE_AA)
def _draw_large_summary(
    image: np.ndarray,
    weight_g: Optional[float],
    length_mm: Optional[float],
) -> None:
    """Paint a large weight/length panel in the top-right corner of ``image``.

    Two text rows are rendered over a darkened backing rectangle.  A missing
    value is shown as ``--``.  ``image`` is modified in place.
    """
    weight_text = "Weight: --g" if weight_g is None else f"Weight: {weight_g:.0f}g"
    length_text = "Length: --mm" if length_mm is None else f"Length: {length_mm:.0f}mm"
    rows = [weight_text, length_text]

    font = cv2.FONT_HERSHEY_SIMPLEX
    scale = 2.75
    thickness = 2
    pad = 10
    margin = 20
    _, frame_w = image.shape[:2]

    # Each entry is ((text_width, text_height), baseline).
    metrics = [cv2.getTextSize(row, font, scale, thickness) for row in rows]
    widest = max(size[0] for size, _ in metrics)
    panel_h = sum(size[1] + baseline + pad for size, baseline in metrics)

    left = frame_w - widest - margin
    top = 30

    # Blend a copy carrying a filled black rectangle back into the frame to
    # obtain a semi-transparent backing panel.
    backdrop = image.copy()
    cv2.rectangle(
        backdrop,
        (left - pad, top - pad),
        (frame_w - margin + pad, top + panel_h + pad),
        (0, 0, 0),
        -1,
    )
    cv2.addWeighted(backdrop, 0.7, image, 0.3, 0, image)

    cursor_y = top
    for row, ((_, text_h), baseline) in zip(rows, metrics):
        cv2.putText(image, row, (left, cursor_y + text_h), font, scale,
                    (0, 255, 255), thickness, cv2.LINE_AA)
        cursor_y += text_h + baseline + pad
def generate_video(
    svo_path: Path,
    output_dir: Path,
    weight_json: Path,
    yolo_model_path: str,
    conf: float = 0.25,
    imgsz: int = 640,
    frame_stride: int = 1,
    show_large: bool = False,
) -> Optional[Path]:
    """Render the labeled side-by-side preview video for one SVO recording.

    Every processed frame is run through YOLO tracking; each detection is
    boxed and tagged with the weight/length read from ``weight_json``.  The
    annotated frame is stacked to the left of the untouched frame, and only
    frames with at least one detection are written.  Frames are streamed to
    the encoder as they are produced, so memory use does not grow with the
    length of the recording.  The video is encoded under a temporary name and
    moved over ``<output_dir>/images/<svo_stem>_preview.mp4`` only after the
    whole recording was processed, so an existing preview is never replaced
    by a truncated file.

    Args:
        svo_path: Recording to read.
        output_dir: Output directory; the video goes into its ``images``
            sub-directory, which is created when missing.
        weight_json: JSON file holding the predicted weight/length.
        yolo_model_path: Weights file handed to ``ultralytics.YOLO``.
        conf: Confidence threshold passed to ``yolo.track``.
        imgsz: Inference size passed to ``yolo.track``.
        frame_stride: When greater than 1, only every ``frame_stride``-th
            frame is processed.
        show_large: Also draw the large summary panel on the annotated frame.

    Returns:
        Path of the written video, or ``None`` when pyzed is unavailable, the
        SVO cannot be opened, the video writer cannot be opened, or no frame
        contained a detection.
    """
    if not ZED_AVAILABLE:
        print("ERROR: pyzed not available, cannot generate labeled video")
        return None

    weight_g, length_mm = _extract_weight_length(weight_json)
    print(f" Labeling with weight={weight_g}g, length={length_mm}mm from {weight_json.name}")
    if weight_g is None and length_mm is None:
        print(" WARNING: No valid weight/length in JSON, video will show '--'")

    # Imported lazily so the module can be imported without these packages.
    from ultralytics import YOLO
    yolo = YOLO(yolo_model_path)
    class_names = yolo.names if hasattr(yolo, "names") else {}

    from dataset.zed_reader import ZEDReader
    zed_reader = ZEDReader(svo_path=str(svo_path), camera_mode=False, use_yolo_detector=False)
    if not zed_reader.open():
        print(f" ERROR: Failed to open {svo_path.name}")
        return None

    svo_name = svo_path.stem
    images_dir = Path(output_dir) / "images"
    video_path = images_dir / f"{svo_name}_preview.mp4"
    # Still ends in ".mp4" so OpenCV selects the same container.
    partial_path = images_dir / f"{svo_name}_preview.partial.mp4"

    writer = None
    frames_written = 0
    finished = False
    idx = 0
    try:
        images_dir.mkdir(parents=True, exist_ok=True)
        runtime = sl.RuntimeParameters()
        left_mat = sl.Mat()

        while True:
            err = zed_reader.zed.grab(runtime)
            if err != sl.ERROR_CODE.SUCCESS:
                break

            # Decide about skipping before fetching/copying the image.
            if frame_stride > 1 and (idx % frame_stride) != 0:
                idx += 1
                continue

            zed_reader.zed.retrieve_image(left_mat, sl.VIEW.LEFT)
            left_np = left_mat.get_data()
            # Drop a 4th channel when present so drawing/encoding sees 3 channels.
            img = left_np[:, :, :3].copy() if left_np.shape[2] > 3 else left_np.copy()

            results = yolo.track(img, conf=conf, imgsz=imgsz, verbose=False, persist=True)[0]
            num_dets = len(results.boxes) if results.boxes is not None else 0

            left_disp = img.copy()
            if num_dets > 0:
                boxes = results.boxes.xyxy.cpu().numpy()
                # Track ids may be absent; fall back to id 0 for every box.
                tids = (results.boxes.id.cpu().numpy().astype(int)
                        if hasattr(results.boxes, "id") and results.boxes.id is not None
                        else np.zeros(len(boxes), dtype=int))
                cls_ids = (results.boxes.cls.cpu().numpy().astype(int)
                           if results.boxes.cls is not None
                           else np.zeros(len(boxes), dtype=int))
                for i, box in enumerate(boxes):
                    tid = int(tids[i]) if i < len(tids) else 0
                    cid = int(cls_ids[i]) if i < len(cls_ids) else 0
                    cname = class_names.get(cid, "fish")
                    _draw_label_on_box(left_disp, box, tid, cname, weight_g, length_mm)
                if show_large:
                    _draw_large_summary(left_disp, weight_g, length_mm)

            frame_name = f"frame_{idx + 1:06d}"
            info = f"[{idx + 1}] {frame_name} | Detections: {num_dets}"
            cv2.putText(left_disp, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
            cv2.putText(left_disp, "Detection", (10, left_disp.shape[0] - 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)

            right_disp = img.copy()
            cv2.putText(right_disp, "Original", (10, right_disp.shape[0] - 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)

            combined = np.hstack([left_disp, right_disp])
            if num_dets > 0:
                if writer is None:
                    # Opened lazily: the frame size is only known now, and no
                    # file is created when the recording has no detections.
                    h, w = combined.shape[:2]
                    writer = cv2.VideoWriter(
                        str(partial_path), cv2.VideoWriter_fourcc(*"mp4v"), 10.0, (w, h)
                    )
                    if not writer.isOpened():
                        print(f" ERROR: Could not open video writer for {video_path.name}")
                        return None
                writer.write(combined)
                frames_written += 1

            if idx % 30 == 0:
                print(f" [{idx + 1}] {frame_name} dets={num_dets} frames_collected={frames_written}")
            idx += 1

        finished = True
    finally:
        try:
            if writer is not None:
                writer.release()
            # Do not leave a truncated file behind after an error.
            if not finished and partial_path.exists():
                partial_path.unlink()
        finally:
            zed_reader.close()

    if not frames_written:
        print(f" WARNING: No detection frames collected from {svo_name}")
        return None

    # Publish the finished file, replacing any previous preview.
    partial_path.replace(video_path)
    print(f" ✓ Labeled video: {video_path.name} ({frames_written} frames, weight={weight_g}g len={length_mm}mm)")
    return video_path
def main():
    """Command-line entry point.

    Parses the arguments, turns the three path arguments into absolute paths,
    exits with a message when the SVO or the weight JSON does not exist, and
    otherwise hands everything to ``generate_video``.
    """
    cli = argparse.ArgumentParser(
        description="Generate labeled preview video from SVO + weight JSON"
    )
    cli.add_argument("--svo", required=True, help="Path to .svo2 file")
    cli.add_argument("--save-output", required=True, help="Output directory for this SVO")
    cli.add_argument(
        "--weight-json",
        required=True,
        help="Path to weight_prediction.json or DGCNN output JSON",
    )
    cli.add_argument(
        "--yolo-model",
        default="/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt",
    )
    cli.add_argument("--conf", type=float, default=0.25)
    cli.add_argument("--imgsz", type=int, default=640)
    cli.add_argument("--frame-stride", type=int, default=1)
    cli.add_argument("--show-large-labels-at-top-right", action="store_true")
    opts = cli.parse_args()

    def _absolute(raw: str) -> Path:
        # Expand "~" and normalise to an absolute path.
        return Path(raw).expanduser().resolve()

    svo_file = _absolute(opts.svo)
    save_dir = _absolute(opts.save_output)
    weight_file = _absolute(opts.weight_json)

    # Both inputs must exist before any heavy work starts.
    if not svo_file.exists():
        raise SystemExit(f"SVO not found: {svo_file}")
    if not weight_file.exists():
        raise SystemExit(f"Weight JSON not found: {weight_file}")

    generate_video(
        svo_path=svo_file,
        output_dir=save_dir,
        weight_json=weight_file,
        yolo_model_path=opts.yolo_model,
        conf=opts.conf,
        imgsz=opts.imgsz,
        frame_stride=opts.frame_stride,
        show_large=opts.show_large_labels_at_top_right,
    )


if __name__ == "__main__":
    main()