393 lines
14 KiB
Python
393 lines
14 KiB
Python
#!/usr/bin/env python3
|
|
"""Generate labeled preview video from SVO + weight prediction JSON.
|
|
|
|
Per-frame labeling: each detection box shows the weight/length predicted
|
|
for that specific frame's PLY (from ``per_cloud`` / ``per_file`` in the
|
|
DGCNN JSON). Frames without a corresponding PLY carry forward the last
|
|
known value. The final aggregated result is shown at top-right.
|
|
|
|
Called by ``predict_weigth_from_svo2.py`` after DGCNN completes.
|
|
Replaces any existing preview video so the final published file has labels.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import math
|
|
import re
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import cv2
|
|
import numpy as np
|
|
|
|
try:
|
|
import pyzed.sl as sl
|
|
ZED_AVAILABLE = True
|
|
except ImportError:
|
|
ZED_AVAILABLE = False
|
|
|
|
|
|
def _open_video_writer(path: Path, fps: float, size: Tuple[int, int]) -> cv2.VideoWriter:
    """Open a VideoWriter, preferring GStreamer NVENC on Jetson for hardware H.264.

    Args:
        path: Destination video file.
        fps: Frame rate handed to the writer.
        size: ``(width, height)`` of the frames that will be written.

    Returns:
        The GStreamer-backed writer when that pipeline opens successfully,
        otherwise a plain ``mp4v`` writer. The fallback writer is returned
        without checking ``isOpened()``; callers that need to know whether
        the file can actually be written must check it themselves.
    """
    w, h = size
    if hasattr(cv2, "CAP_GSTREAMER"):
        # Escape double quotes so the path cannot terminate the quoted
        # ``location`` property of the pipeline description early.
        loc = str(path).replace('"', '\\"')
        gst_pipe = (
            'appsrc ! videoconvert ! video/x-raw,format=BGRx ! '
            'nvvidconv ! video/x-raw(memory:NVMM) ! '
            'nvv4l2h264enc bitrate=4000000 ! h264parse ! '
            f'mp4mux ! filesink location="{loc}"'
        )
        try:
            writer = cv2.VideoWriter(gst_pipe, cv2.CAP_GSTREAMER, 0, fps, (w, h))
            if writer.isOpened():
                return writer
            # Pipeline could not be built (e.g. no NVENC elements): free it
            # and fall through to the software encoder.
            writer.release()
        except Exception as exc:
            # Best-effort path: any backend failure just means "use the
            # fallback", but say why instead of hiding the error.
            print(f"  WARNING: GStreamer NVENC writer failed ({exc}); falling back to mp4v")
    return cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
|
|
|
|
|
|
def _parse_weight_json(weight_json: Path) -> Tuple[
    Dict[int, Tuple[float, float]],
    Optional[float],
    Optional[float],
    bool,
]:
    """Parse weight JSON → per-frame map + summary + confidence.

    Malformed parts of the document are skipped rather than raising: a
    top-level value that is not an object, a summary that is not an object,
    and ``per_cloud`` / ``per_file`` entries that are not objects are all
    treated as "no data".

    Args:
        weight_json: Path to the UTF-8 encoded prediction JSON.

    Returns:
        per_frame: {frame_number: (weight_g, length_mm)} from per_cloud/per_file.
            ``length_mm`` is NaN when the entry has no usable ``length_input``.
            Entries without a ``frame_<digits>`` PLY stem or without a finite
            ``predicted_weight_g`` are omitted.
        summary_weight_g: final aggregated weight, or None when absent
        summary_length_mm: final aggregated length, or None when absent
        is_confident: True when ``*`` should be shown (avg > 440g OR length band fraction >= 25%)

    Raises:
        OSError: if the file cannot be read.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    data = json.loads(weight_json.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        data = {}

    # First non-empty summary object wins, in the same priority order as the
    # key names below; non-dict values are ignored instead of crashing on .get().
    summary: Dict[str, Any] = {}
    for key in ("dgcnn_summary", "weight_summary", "summary"):
        candidate = data.get(key)
        if isinstance(candidate, dict) and candidate:
            summary = candidate
            break

    def _first_finite(*candidates):
        """Return the first candidate convertible to a finite float, else None."""
        for c in candidates:
            if c is None:
                continue
            try:
                v = float(c)
            except (TypeError, ValueError):
                continue
            if math.isfinite(v):
                return v
        return None

    summary_wg = _first_finite(
        summary.get("pred_weight_g"),
        summary.get("avg_predicted_weight_g"),
        data.get("pred_weight_g"),
        data.get("avg_predicted_weight_g"),
    )
    summary_lmm = _first_finite(
        summary.get("avg_length_input_topk"),
        summary.get("avg_length_input"),
        data.get("avg_length_input"),
    )

    CONFIDENT_AVG_G = 440.0
    MIN_FRAC_LARGEST_LENGTH_GROUP = 0.25

    mean_g = _first_finite(
        summary.get("mean_all_pred_g_after_filters"),
        summary.get("avg_predicted_weight_g"),
    )
    frac = _first_finite(summary.get("fraction_in_near_max_length_band"))

    is_confident = (
        (mean_g is not None and mean_g > CONFIDENT_AVG_G)
        or (frac is not None and frac >= MIN_FRAC_LARGEST_LENGTH_GROUP)
    )

    entries = data.get("per_cloud") or data.get("per_file") or []
    if not isinstance(entries, (list, tuple)):
        entries = []

    per_frame: Dict[int, Tuple[float, float]] = {}
    for item in entries:
        if not isinstance(item, dict):
            # e.g. null or a bare string: nothing to read a prediction from.
            continue
        ply = item.get("ply", "")
        m = re.search(r"frame_(\d+)", Path(str(ply)).stem)
        if not m:
            continue
        fnum = int(m.group(1))
        wg = _first_finite(item.get("predicted_weight_g"))
        lmm = _first_finite(item.get("length_input"))
        if wg is not None:
            per_frame[fnum] = (wg, lmm if lmm is not None else float("nan"))

    return per_frame, summary_wg, summary_lmm, is_confident
|
|
|
|
|
|
def _draw_label_on_box(
    image: np.ndarray,
    box: np.ndarray,
    tid: int,
    class_name: str,
    weight_g: Optional[float],
    length_mm: Optional[float],
) -> None:
    """Draw one detection box with its ID / class / weight / length label.

    The image is modified in place. The label sits on a filled strip directly
    above the box; when the box is so close to the top of the frame that the
    strip would start above row 0, the strip is placed just inside the box
    instead so the text stays visible.

    Args:
        image: Frame to draw on (mutated).
        box: Four values ``x1, y1, x2, y2``; each is truncated to ``int``.
        tid: Track ID shown in the label.
        class_name: Class name shown in the label.
        weight_g: Weight to print; ``None`` or non-finite is shown as ``--g``.
        length_mm: Length to print; ``None`` or non-finite is shown as ``--mm``.
    """
    x1, y1, x2, y2 = map(int, box)
    cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)

    w_str = f"{weight_g:.0f}g" if weight_g is not None and math.isfinite(weight_g) else "--g"
    l_str = f"{length_mm:.0f}mm" if length_mm is not None and math.isfinite(length_mm) else "--mm"
    label = f"ID:{tid} {class_name} weight: {w_str} len: {l_str}"

    (tw, th), bl = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1)

    strip_h = th + bl + 5
    strip_top = y1 - strip_h
    if strip_top < 0:
        # Box touches the top edge: a strip above it would be clipped away,
        # so anchor the strip at the box's top edge (inside the box).
        strip_top = y1
    strip_bottom = strip_top + strip_h

    cv2.rectangle(image, (x1, strip_top), (x1 + tw, strip_bottom), (0, 255, 0), -1)
    # Text baseline sits ``bl + 2`` px above the strip's bottom edge.
    cv2.putText(image, label, (x1, strip_bottom - bl - 2),
                cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 1, cv2.LINE_AA)
|
|
|
|
|
|
def _draw_large_summary(
    image: np.ndarray,
    weight_g: Optional[float],
    length_mm: Optional[float],
    is_confident: bool = False,
) -> None:
    """Draw the two-line "Final / Length" summary panel at the top-right.

    The image is modified in place: a semi-transparent black panel is blended
    in and the two text lines are drawn on top of it.

    Args:
        image: Frame to draw on (mutated).
        weight_g: Final weight; ``None`` or non-finite is shown as ``--g``.
        length_mm: Final length; ``None`` or non-finite is shown as ``--mm``.
        is_confident: When True, `` *`` is appended to the weight. The star
            is only shown together with an actual weight value.
    """
    star = " *" if is_confident else ""
    # Same finite-value rule as the per-box labels, so NaN never renders as "nang".
    has_weight = weight_g is not None and math.isfinite(weight_g)
    has_length = length_mm is not None and math.isfinite(length_mm)
    lines = [
        f"Final: {weight_g:.0f}g{star}" if has_weight else "Final: --g",
        f"Length: {length_mm:.0f}mm" if has_length else "Length: --mm",
    ]

    font = cv2.FONT_HERSHEY_SIMPLEX
    scale = 2.0
    thickness = 2
    pad = 10
    h, w = image.shape[:2]
    margin = 20

    # Each entry is ((text_width, text_height), baseline).
    sizes = [cv2.getTextSize(ln, font, scale, thickness) for ln in lines]
    max_tw = max(s[0][0] for s in sizes)
    total_h = sum(s[0][1] + s[1] + pad for s in sizes)

    # Right-align the panel against the frame edge, ``margin`` px in.
    x0 = w - max_tw - margin
    y0 = 30

    # Paint the panel on a copy, then blend it back for a 70 % opaque backdrop.
    overlay = image.copy()
    cv2.rectangle(overlay, (x0 - pad, y0 - pad),
                  (w - margin + pad, y0 + total_h + pad), (0, 0, 0), -1)
    cv2.addWeighted(overlay, 0.7, image, 0.3, 0, image)

    y = y0
    for ln, ((tw, th), bl) in zip(lines, sizes):
        cv2.putText(image, ln, (x0, y + th), font, scale, (0, 255, 255), thickness, cv2.LINE_AA)
        y += th + bl + pad
|
|
|
|
|
|
def generate_video(
    svo_path: Path,
    output_dir: Path,
    weight_json: Path,
    yolo_model_path: str,
    conf: float = 0.25,
    imgsz: int = 640,
    frame_stride: int = 1,
    show_large: bool = False,
    summary_weight_g: Optional[float] = None,
    summary_length_mm: Optional[float] = None,
    summary_star: bool = False,
    output_video_name: Optional[str] = None,
    sam_device: str = "cuda",
) -> Optional[Path]:
    """Render a side-by-side (detection | segmentation) preview video for one SVO.

    Every processed frame is run through YOLO tracking; frames with at least
    one detection are labeled with the per-frame weight/length (carrying the
    last known value forward when the frame has no prediction), paired with a
    SAM segmentation overlay, and collected. The collected frames are written
    at 10 fps to ``<output_dir>/images/``.

    Args:
        svo_path: SVO recording to read.
        output_dir: Per-SVO output directory; ``images/`` is created inside it.
        weight_json: Prediction JSON parsed by ``_parse_weight_json``.
        yolo_model_path: Weights file handed to ``ultralytics.YOLO``.
        conf: Detection confidence threshold passed to ``yolo.track``.
        imgsz: Inference image size passed to ``yolo.track``.
        frame_stride: When > 1, only frames whose 0-based index is a multiple
            of the stride are processed.
        show_large: Force drawing of the top-right summary panel. The panel is
            also drawn whenever a summary weight is available.
        summary_weight_g: Overrides the summary weight parsed from the JSON.
        summary_length_mm: Overrides the summary length parsed from the JSON.
        summary_star: Whether the summary panel shows `` *``. The confidence
            flag derived from the JSON is only logged, never used for drawing.
        output_video_name: File name of the video; defaults to
            ``<svo stem>_preview.mp4``.
        sam_device: Preferred device for the SAM predictor.

    Returns:
        Path of the written video, or ``None`` when pyzed is unavailable, the
        SVO cannot be opened, no frame contained a detection, or the video
        writer could not be opened.
    """
    if not ZED_AVAILABLE:
        print("ERROR: pyzed not available, cannot generate labeled video")
        return None

    per_frame, parsed_summary_wg, parsed_summary_lmm, raw_confident = _parse_weight_json(weight_json)
    if summary_weight_g is None:
        summary_weight_g = parsed_summary_wg
    if summary_length_mm is None:
        summary_length_mm = parsed_summary_lmm
    star_s = " *" if summary_star else ""
    print(f"  Per-frame predictions: {len(per_frame)} PLYs mapped")
    print(
        f"  Summary: weight={summary_weight_g}g, length={summary_length_mm}mm{star_s} "
        f"(raw_confident={raw_confident})"
    )

    if not per_frame and summary_weight_g is None:
        print("  WARNING: No weight data in JSON, video will show '--'")

    # Heavy / project-local dependencies are imported lazily so the module can
    # be imported without them.
    from ultralytics import YOLO
    yolo = YOLO(yolo_model_path)
    class_names = yolo.names if hasattr(yolo, "names") else {}
    from fish_video_weight_evaluation import (
        create_segmentation_overlay,
        load_sam_predictor_with_fallback,
        segment_with_sam,
    )
    sam_predictor, eff_sam_device = load_sam_predictor_with_fallback(sam_device)
    sam_torch_device = eff_sam_device

    from dataset.zed_reader import ZEDReader
    zed_reader = ZEDReader(svo_path=str(svo_path), camera_mode=False, use_yolo_detector=False)
    if not zed_reader.open():
        print(f"  ERROR: Failed to open {svo_path.name}")
        return None

    runtime = sl.RuntimeParameters()
    left_mat = sl.Mat()
    svo_name = svo_path.stem

    output_dir = Path(output_dir)
    images_dir = output_dir / "images"
    images_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): every kept frame is held in memory until the end; long
    # recordings with many detections can make this list very large.
    frames: List[np.ndarray] = []
    idx = 0
    last_wg: Optional[float] = None
    last_lmm: Optional[float] = None

    try:
        while True:
            err = zed_reader.zed.grab(runtime)
            if err != sl.ERROR_CODE.SUCCESS:
                break

            # Decide on the stride before retrieving the image so skipped
            # frames cost only the grab, not a retrieve + copy.
            if frame_stride > 1 and (idx % frame_stride) != 0:
                idx += 1
                continue

            zed_reader.zed.retrieve_image(left_mat, sl.VIEW.LEFT)
            left_np = left_mat.get_data()
            # Drop any channel beyond the first three (e.g. alpha).
            img = left_np[:, :, :3].copy() if left_np.shape[2] > 3 else left_np.copy()

            # Frame numbers are 1-based; presumably this matches the
            # ``frame_<n>`` numbering of the PLY files — verify against the exporter.
            frame_number = idx + 1
            frame_name = f"frame_{frame_number:06d}"
            if frame_number in per_frame:
                cur_wg, cur_lmm = per_frame[frame_number]
                last_wg = cur_wg
                # A NaN length must not overwrite the last usable length.
                last_lmm = cur_lmm if math.isfinite(cur_lmm) else last_lmm
            else:
                cur_wg = last_wg
                cur_lmm = last_lmm

            results = yolo.track(img, conf=conf, imgsz=imgsz, verbose=False, persist=True)[0]
            num_dets = len(results.boxes) if results.boxes is not None else 0

            left_disp = img.copy()
            right_disp = img.copy()
            if num_dets > 0:
                boxes = results.boxes.xyxy.cpu().numpy()
                # Track IDs / classes may be missing; fall back to zeros.
                tids = (results.boxes.id.cpu().numpy().astype(int)
                        if hasattr(results.boxes, "id") and results.boxes.id is not None
                        else np.zeros(len(boxes), dtype=int))
                cls_ids = (results.boxes.cls.cpu().numpy().astype(int)
                           if results.boxes.cls is not None
                           else np.zeros(len(boxes), dtype=int))

                for i, box in enumerate(boxes):
                    tid = int(tids[i]) if i < len(tids) else 0
                    cid = int(cls_ids[i]) if i < len(cls_ids) else 0
                    cname = class_names.get(cid, "fish")
                    _draw_label_on_box(left_disp, box, tid, cname, cur_wg, cur_lmm)

                # Segmentation is best-effort: a failure only affects the right pane.
                try:
                    masks = segment_with_sam(sam_predictor, img, boxes, sam_torch_device)
                except Exception as e:
                    print(f"  WARNING: SAM segmentation failed on {frame_name}: {e}")
                    masks = []

                if masks:
                    right_disp = create_segmentation_overlay(img.copy(), masks)
                    cv2.putText(right_disp, "Segmentation", (10, right_disp.shape[0] - 20),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
                else:
                    cv2.putText(right_disp, "Segmentation (failed)", (10, right_disp.shape[0] - 20),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2, cv2.LINE_AA)
            else:
                cv2.putText(right_disp, "No detections", (10, right_disp.shape[0] - 20),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (128, 128, 128), 2, cv2.LINE_AA)

            if show_large or summary_weight_g is not None:
                _draw_large_summary(left_disp, summary_weight_g, summary_length_mm, summary_star)

            info = f"[{frame_number}] {frame_name} | Detections: {num_dets}"
            cv2.putText(left_disp, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
            cv2.putText(left_disp, "Detection", (10, left_disp.shape[0] - 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)

            combined = np.hstack([left_disp, right_disp])
            # Only frames that actually show a detection go into the video.
            if num_dets > 0:
                frames.append(combined)

            if idx % 30 == 0:
                # Non-finite values (a NaN length) are logged as "--", not "nanmm".
                w_s = f"{cur_wg:.0f}g" if cur_wg is not None and math.isfinite(cur_wg) else "--"
                l_s = f"{cur_lmm:.0f}mm" if cur_lmm is not None and math.isfinite(cur_lmm) else "--"
                print(f"    [{frame_number}] {frame_name} dets={num_dets} w={w_s} l={l_s} collected={len(frames)}")

            idx += 1
    finally:
        zed_reader.close()

    if not frames:
        print(f"  WARNING: No detection frames collected from {svo_name}")
        return None

    video_path = images_dir / (output_video_name or f"{svo_name}_preview.mp4")
    h, w = frames[0].shape[:2]
    writer = _open_video_writer(video_path, 10.0, (w, h))
    if not writer.isOpened():
        # Without this check write() is a silent no-op and a path to a
        # missing/empty file would be reported as success.
        writer.release()
        print(f"  ERROR: Could not open video writer for {video_path}")
        return None
    try:
        for frame in frames:
            writer.write(frame)
    finally:
        writer.release()
    print(f"  ✓ Labeled video: {video_path.name} ({len(frames)} frames, {len(per_frame)} PLY labels)")
    return video_path
|
|
|
|
|
|
def main():
    """Command-line entry point: parse options, validate inputs, render the video.

    Exits via ``SystemExit`` with a message when the SVO file or the weight
    JSON does not exist; otherwise delegates to ``generate_video``.
    """
    cli = argparse.ArgumentParser(description="Generate labeled preview video from SVO + weight JSON")

    # Required inputs.
    cli.add_argument("--svo", required=True, help="Path to .svo2 file")
    cli.add_argument("--save-output", required=True, help="Output directory for this SVO")
    cli.add_argument("--weight-json", required=True, help="Path to weight_prediction.json or DGCNN output JSON")

    # Detection / segmentation settings.
    cli.add_argument(
        "--yolo-model",
        default="/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt",
    )
    for flag, caster, fallback in (
        ("--conf", float, 0.25),
        ("--imgsz", int, 640),
        ("--frame-stride", int, 1),
        ("--sam-device", str, "cuda"),
    ):
        cli.add_argument(flag, type=caster, default=fallback)

    # Summary panel settings.
    cli.add_argument("--show-large-labels-at-top-right", action="store_true")
    cli.add_argument(
        "--summary-star",
        action="store_true",
        default=False,
        help="Whether to draw * on the Final summary line; caller/DB is the source of truth.",
    )
    for flag, caster in (
        ("--summary-weight-g", float),
        ("--summary-length-mm", float),
        ("--output-video-name", str),
    ):
        cli.add_argument(flag, type=caster, default=None)

    opts = cli.parse_args()

    def _absolute(raw: str) -> Path:
        # Expand ``~`` and make the path absolute.
        return Path(raw).expanduser().resolve()

    svo_file = _absolute(opts.svo)
    out_dir = _absolute(opts.save_output)
    json_file = _absolute(opts.weight_json)

    if not svo_file.exists():
        raise SystemExit(f"SVO not found: {svo_file}")
    if not json_file.exists():
        raise SystemExit(f"Weight JSON not found: {json_file}")

    generate_video(
        svo_path=svo_file,
        output_dir=out_dir,
        weight_json=json_file,
        yolo_model_path=opts.yolo_model,
        conf=opts.conf,
        imgsz=opts.imgsz,
        frame_stride=opts.frame_stride,
        sam_device=opts.sam_device,
        show_large=opts.show_large_labels_at_top_right,
        summary_weight_g=opts.summary_weight_g,
        summary_length_mm=opts.summary_length_mm,
        summary_star=bool(opts.summary_star),
        output_video_name=opts.output_video_name,
    )
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|