275 lines
9.4 KiB
Python
275 lines
9.4 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""Generate labeled preview video from SVO + weight prediction JSON.
|
||
|
|
|
||
|
|
Reads the SVO with YOLO tracking, overlays DGCNN weight/length on each
|
||
|
|
detection box, and writes ``<svo_name>_preview.mp4`` into the ``images/`` subdirectory of ``--save-output``.
|
||
|
|
|
||
|
|
Called by ``predict_weigth_from_svo2.py`` after DGCNN completes.
|
||
|
|
Replaces any existing preview video so the final published file has labels.
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import json
|
||
|
|
import math
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any, Dict, Optional, Tuple
|
||
|
|
|
||
|
|
import cv2
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
# pyzed is treated as optional at import time so this module can still be
# imported without it; generate_video() checks ZED_AVAILABLE and returns None
# when the import failed.
try:
    import pyzed.sl as sl
    ZED_AVAILABLE = True
except ImportError:
    ZED_AVAILABLE = False
|
||
|
|
|
||
|
|
|
||
|
|
def _first_finite(candidates: Tuple[Any, ...]) -> Optional[float]:
    """Return the first candidate that converts to a finite float, else ``None``."""
    for candidate in candidates:
        # JSON booleans would convert to 0.0 / 1.0 but are never a measurement.
        if candidate is None or isinstance(candidate, bool):
            continue
        try:
            value = float(candidate)
        except (TypeError, ValueError, OverflowError):
            continue
        if math.isfinite(value):
            return value
    return None


def _extract_weight_length(weight_json: Path) -> Tuple[Optional[float], Optional[float]]:
    """Return ``(weight_g, length_mm)`` read from a weight prediction JSON.

    The summary object is looked up under ``dgcnn_summary``, ``weight_summary``
    and ``summary`` (first non-empty object wins). Values are then searched in
    priority order, summary first and top-level second:

    * weight: ``pred_weight_g``, ``avg_predicted_weight_g``
    * length: ``avg_length_input_topk`` (summary only), ``avg_length_input``

    Args:
        weight_json: Path to a UTF-8 encoded JSON file.

    Returns:
        A ``(weight, length)`` tuple. Each element is ``None`` when no
        candidate is present or none converts to a finite float.

    Raises:
        OSError: If the file cannot be read.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    data = json.loads(weight_json.read_text(encoding="utf-8"))
    # A non-object root (list, string, number) carries no usable keys.
    if not isinstance(data, dict):
        data = {}

    # Skip summary entries that are empty or not objects instead of crashing
    # on ``.get`` further down.
    summary = next(
        (
            section
            for section in (
                data.get("dgcnn_summary"),
                data.get("weight_summary"),
                data.get("summary"),
            )
            if isinstance(section, dict) and section
        ),
        {},
    )

    weight_g = _first_finite((
        summary.get("pred_weight_g"),
        summary.get("avg_predicted_weight_g"),
        data.get("pred_weight_g"),
        data.get("avg_predicted_weight_g"),
    ))
    length_mm = _first_finite((
        summary.get("avg_length_input_topk"),
        summary.get("avg_length_input"),
        data.get("avg_length_input"),
    ))
    return weight_g, length_mm
|
||
|
|
|
||
|
|
|
||
|
|
def _draw_label_on_box(
    image: np.ndarray,
    box: np.ndarray,
    tid: int,
    class_name: str,
    weight_g: Optional[float],
    length_mm: Optional[float],
) -> None:
    """Draw a detection rectangle and its caption on ``image`` in place.

    The caption reads ``ID:<tid> <class_name> weight: <w>g len: <l>mm`` with
    ``--`` substituted for a missing weight or length. It is drawn as black
    text on a filled green band directly above the box; when there is not
    enough room above the box the band is moved just below the box's top edge,
    and it is shifted horizontally so it stays inside the image.

    Args:
        image: Image to draw on (modified in place).
        box: Four values ``x1, y1, x2, y2``; each is truncated to ``int``.
        tid: Track ID shown in the caption.
        class_name: Class name shown in the caption.
        weight_g: Weight to show, or ``None``.
        length_mm: Length to show, or ``None``.
    """
    x1, y1, x2, y2 = map(int, box)
    cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)

    w_str = f"{weight_g:.0f}g" if weight_g is not None else "--g"
    l_str = f"{length_mm:.0f}mm" if length_mm is not None else "--mm"
    label = f"ID:{tid} {class_name} weight: {w_str} len: {l_str}"

    (tw, th), bl = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1)
    band_h = th + bl + 5
    img_w = image.shape[1]

    # Default placement is directly above the box. A box touching the top of
    # the frame would push the caption off-image, so drop it inside instead.
    band_bottom = y1 if y1 - band_h >= 0 else max(y1, 0) + band_h
    # Keep the caption inside the left/right borders when the image is wide
    # enough to hold it.
    band_left = min(max(x1, 0), max(img_w - tw, 0))

    cv2.rectangle(image, (band_left, band_bottom - band_h),
                  (band_left + tw, band_bottom), (0, 255, 0), -1)
    cv2.putText(image, label, (band_left, band_bottom - bl - 2),
                cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 1, cv2.LINE_AA)
|
||
|
|
|
||
|
|
|
||
|
|
def _draw_large_summary(
    image: np.ndarray,
    weight_g: Optional[float],
    length_mm: Optional[float],
) -> None:
    """Overlay a large two-line weight/length summary in the top-right corner.

    A black rectangle is blended over ``image`` (70 % rectangle, 30 % original)
    behind the text block, then ``Weight: ...`` and ``Length: ...`` are drawn
    on top of it. ``--`` is shown for a value that is ``None``. ``image`` is
    modified in place.
    """
    weight_text = "Weight: --g" if weight_g is None else f"Weight: {weight_g:.0f}g"
    length_text = "Length: --mm" if length_mm is None else f"Length: {length_mm:.0f}mm"
    texts = (weight_text, length_text)

    font_face = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 2.75
    stroke = 2
    padding = 10
    edge_gap = 20
    img_w = image.shape[1]

    # Measure each line once; every row occupies glyph height + baseline + padding.
    metrics = [cv2.getTextSize(text, font_face, font_scale, stroke) for text in texts]
    widest = max(size[0] for size, _ in metrics)
    row_heights = [size[1] + baseline + padding for size, baseline in metrics]
    block_h = sum(row_heights)

    left = img_w - widest - edge_gap
    top = 30

    # Blend a filled backdrop so the text stays readable over any content.
    backdrop = image.copy()
    cv2.rectangle(
        backdrop,
        (left - padding, top - padding),
        (img_w - edge_gap + padding, top + block_h + padding),
        (0, 0, 0),
        -1,
    )
    cv2.addWeighted(backdrop, 0.7, image, 0.3, 0, image)

    cursor = top
    for text, (size, _), row_h in zip(texts, metrics, row_heights):
        cv2.putText(image, text, (left, cursor + size[1]), font_face, font_scale,
                    (0, 255, 255), stroke, cv2.LINE_AA)
        cursor += row_h
|
||
|
|
|
||
|
|
|
||
|
|
def _compose_preview_frame(
    img: np.ndarray,
    results: Any,
    class_names: Dict[int, str],
    weight_g: Optional[float],
    length_mm: Optional[float],
    frame_number: int,
    show_large: bool,
) -> np.ndarray:
    """Return the side-by-side "Detection | Original" frame for one image."""
    det_boxes = results.boxes
    num_dets = len(det_boxes) if det_boxes is not None else 0

    left_disp = img.copy()
    if num_dets > 0:
        boxes = det_boxes.xyxy.cpu().numpy()
        # Track IDs can be absent; every box then gets ID 0.
        track_ids = getattr(det_boxes, "id", None)
        tids = (track_ids.cpu().numpy().astype(int)
                if track_ids is not None
                else np.zeros(len(boxes), dtype=int))
        cls_ids = (det_boxes.cls.cpu().numpy().astype(int)
                   if det_boxes.cls is not None
                   else np.zeros(len(boxes), dtype=int))

        for i, box in enumerate(boxes):
            tid = int(tids[i]) if i < len(tids) else 0
            cid = int(cls_ids[i]) if i < len(cls_ids) else 0
            cname = class_names.get(cid, "fish")
            _draw_label_on_box(left_disp, box, tid, cname, weight_g, length_mm)

    if show_large:
        _draw_large_summary(left_disp, weight_g, length_mm)

    frame_name = f"frame_{frame_number:06d}"
    info = f"[{frame_number}] {frame_name} | Detections: {num_dets}"
    cv2.putText(left_disp, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
    cv2.putText(left_disp, "Detection", (10, left_disp.shape[0] - 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)

    right_disp = img.copy()
    cv2.putText(right_disp, "Original", (10, right_disp.shape[0] - 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)

    return np.hstack([left_disp, right_disp])


def generate_video(
    svo_path: Path,
    output_dir: Path,
    weight_json: Path,
    yolo_model_path: str,
    conf: float = 0.25,
    imgsz: int = 640,
    frame_stride: int = 1,
    show_large: bool = False,
) -> Optional[Path]:
    """Write a labeled side-by-side preview video for one SVO recording.

    Each processed frame is run through YOLO tracking. Frames with at least
    one detection are rendered as "Detection | Original" and appended to
    ``<output_dir>/images/<svo_stem>_preview.mp4`` (``mp4v``, 10 fps). The
    same weight/length pair, read from ``weight_json``, is shown on every
    detection box.

    Frames are encoded as they are produced rather than buffered in memory.
    The encoder writes to a temporary ``*.partial.mp4`` file that is moved
    over the final path only after the whole recording was processed, so an
    existing preview is not replaced by a truncated one.

    Args:
        svo_path: Recording to read.
        output_dir: Directory under which ``images/`` is created.
        weight_json: JSON file passed to ``_extract_weight_length``.
        yolo_model_path: Model path passed to ``ultralytics.YOLO``.
        conf: Confidence threshold forwarded to ``yolo.track``.
        imgsz: Inference size forwarded to ``yolo.track``.
        frame_stride: Process every N-th frame; values <= 1 process all.
        show_large: Also draw the large top-right summary on each frame.

    Returns:
        Path of the written video, or ``None`` when pyzed is unavailable, the
        recording or the video writer cannot be opened, or no frame contained
        a detection.
    """
    if not ZED_AVAILABLE:
        print("ERROR: pyzed not available, cannot generate labeled video")
        return None

    weight_g, length_mm = _extract_weight_length(weight_json)
    print(f" Labeling with weight={weight_g}g, length={length_mm}mm from {weight_json.name}")

    if weight_g is None and length_mm is None:
        print(" WARNING: No valid weight/length in JSON, video will show '--'")

    from ultralytics import YOLO
    yolo = YOLO(yolo_model_path)
    class_names = getattr(yolo, "names", None) or {}
    if not isinstance(class_names, dict):
        # Tolerate a plain sequence of names so ``.get`` lookups keep working.
        class_names = dict(enumerate(class_names))

    from dataset.zed_reader import ZEDReader
    zed_reader = ZEDReader(svo_path=str(svo_path), camera_mode=False, use_yolo_detector=False)
    if not zed_reader.open():
        print(f" ERROR: Failed to open {svo_path.name}")
        return None

    svo_name = svo_path.stem
    images_dir = Path(output_dir) / "images"
    video_path = images_dir / f"{svo_name}_preview.mp4"
    # Keep the ``.mp4`` suffix so the container format is unchanged.
    partial_path = images_dir / f"{svo_name}_preview.partial.mp4"

    writer = None
    frames_written = 0
    idx = 0
    finished = False

    try:
        images_dir.mkdir(parents=True, exist_ok=True)
        runtime = sl.RuntimeParameters()
        left_mat = sl.Mat()

        while True:
            err = zed_reader.zed.grab(runtime)
            if err != sl.ERROR_CODE.SUCCESS:
                break

            # Decide on skipping before retrieving/copying the image.
            if frame_stride > 1 and (idx % frame_stride) != 0:
                idx += 1
                continue

            zed_reader.zed.retrieve_image(left_mat, sl.VIEW.LEFT)
            left_np = left_mat.get_data()
            # Keep only the first three channels when more are present.
            img = left_np[:, :, :3].copy() if left_np.shape[2] > 3 else left_np.copy()

            results = yolo.track(img, conf=conf, imgsz=imgsz, verbose=False, persist=True)[0]
            num_dets = len(results.boxes) if results.boxes is not None else 0
            frame_name = f"frame_{idx + 1:06d}"

            # Only frames with detections end up in the video, so only those
            # are rendered.
            if num_dets > 0:
                combined = _compose_preview_frame(
                    img, results, class_names, weight_g, length_mm, idx + 1, show_large
                )
                if writer is None:
                    h, w = combined.shape[:2]
                    writer = cv2.VideoWriter(
                        str(partial_path), cv2.VideoWriter_fourcc(*"mp4v"), 10.0, (w, h)
                    )
                    if not writer.isOpened():
                        print(f" ERROR: Failed to open video writer for {video_path.name}")
                        return None
                writer.write(combined)
                frames_written += 1

            if idx % 30 == 0:
                print(f" [{idx + 1}] {frame_name} dets={num_dets} frames_collected={frames_written}")

            idx += 1

        finished = True
    finally:
        zed_reader.close()
        if writer is not None:
            writer.release()
        # Never leave a half-written file behind after an error or early exit.
        if not finished and partial_path.exists():
            partial_path.unlink()

    if frames_written == 0:
        print(f" WARNING: No detection frames collected from {svo_name}")
        return None

    partial_path.replace(video_path)
    print(f" ✓ Labeled video: {video_path.name} ({frames_written} frames, weight={weight_g}g len={length_mm}mm)")
    return video_path
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Command-line entry point: parse options, validate inputs, build the video.

    Exits via ``SystemExit`` with a message when the SVO file or the weight
    JSON does not exist; otherwise forwards all options to ``generate_video``.
    """
    cli = argparse.ArgumentParser(description="Generate labeled preview video from SVO + weight JSON")
    cli.add_argument("--svo", required=True, help="Path to .svo2 file")
    cli.add_argument("--save-output", required=True, help="Output directory for this SVO")
    cli.add_argument("--weight-json", required=True, help="Path to weight_prediction.json or DGCNN output JSON")
    cli.add_argument(
        "--yolo-model",
        default="/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt",
    )
    cli.add_argument("--conf", type=float, default=0.25)
    cli.add_argument("--imgsz", type=int, default=640)
    cli.add_argument("--frame-stride", type=int, default=1)
    cli.add_argument("--show-large-labels-at-top-right", action="store_true")
    opts = cli.parse_args()

    # Normalize every user-supplied path to an absolute, ~-expanded form.
    svo_file = Path(opts.svo).expanduser().resolve()
    out_dir = Path(opts.save_output).expanduser().resolve()
    json_file = Path(opts.weight_json).expanduser().resolve()

    # Fail early, SVO first, when a required input is missing.
    for label, required in (("SVO", svo_file), ("Weight JSON", json_file)):
        if not required.exists():
            raise SystemExit(f"{label} not found: {required}")

    generate_video(
        svo_path=svo_file,
        output_dir=out_dir,
        weight_json=json_file,
        yolo_model_path=opts.yolo_model,
        conf=opts.conf,
        imgsz=opts.imgsz,
        frame_stride=opts.frame_stride,
        show_large=opts.show_large_labels_at_top_right,
    )
|
||
|
|
|
||
|
|
|
||
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|