whole process
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate labeled preview video from SVO + weight prediction JSON.
|
||||
|
||||
Reads the SVO with YOLO tracking, overlays DGCNN weight/length on each
|
||||
detection box, and writes ``<svo_name>_preview.mp4`` into ``--save-output``.
|
||||
Per-frame labeling: each detection box shows the weight/length predicted
|
||||
for that specific frame's PLY (from ``per_cloud`` / ``per_file`` in the
|
||||
DGCNN JSON). Frames without a corresponding PLY carry forward the last
|
||||
known value. The final aggregated result is shown at top-right.
|
||||
|
||||
Called by ``predict_weigth_from_svo2.py`` after DGCNN completes.
|
||||
Replaces any existing preview video so the final published file has labels.
|
||||
@@ -13,9 +15,10 @@ from __future__ import annotations
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@@ -27,45 +30,74 @@ except ImportError:
|
||||
ZED_AVAILABLE = False
|
||||
|
||||
|
||||
# Matches the frame index embedded in PLY file stems such as ``frame_000123``.
_FRAME_NUM_RE = re.compile(r"frame_(\d+)")

# Thresholds that decide whether the summary is flagged as confident (``*``).
_CONFIDENT_AVG_G = 440.0
_MIN_FRAC_LARGEST_LENGTH_GROUP = 0.25


def _first_finite(*candidates: Any) -> Optional[float]:
    """Return the first candidate that converts to a finite float, else None.

    ``None`` and booleans are skipped (a JSON ``true`` must not turn into a
    1 g weight). Values that cannot be converted, overflow ``float``, or are
    NaN / infinite are skipped as well.
    """
    for cand in candidates:
        if cand is None or isinstance(cand, bool):
            continue
        try:
            value = float(cand)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: integers too large for a float (e.g. 10**400).
            continue
        if math.isfinite(value):
            return value
    return None


def _parse_weight_json(weight_json: Path) -> Tuple[
    Dict[int, Tuple[float, float]],
    Optional[float],
    Optional[float],
    bool,
]:
    """Parse a weight prediction JSON into per-frame values and a summary.

    Malformed parts of the document (non-dict top level, non-dict summary,
    non-list or non-dict per-cloud entries) are ignored instead of raising,
    so a partially broken JSON still yields whatever data is usable.

    Args:
        weight_json: Path to the JSON file, read as UTF-8.

    Returns:
        per_frame: ``{frame_number: (weight_g, length_mm)}`` built from
            ``per_cloud`` (or ``per_file``); the frame number comes from the
            ``frame_<digits>`` part of each entry's ``ply`` stem.  Entries
            without a finite weight are dropped; a missing length is stored
            as NaN.
        summary_weight_g: final aggregated weight, or None.
        summary_length_mm: final aggregated length, or None.
        is_confident: True when the mean weight exceeds 440 g OR the
            near-max length band fraction is at least 25%.

    Raises:
        OSError: if the file cannot be read.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    data = json.loads(weight_json.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        data = {}

    summary = data.get("dgcnn_summary") or data.get("weight_summary") or data.get("summary") or {}
    if not isinstance(summary, dict):
        summary = {}

    summary_wg = _first_finite(
        summary.get("pred_weight_g"),
        summary.get("avg_predicted_weight_g"),
        data.get("pred_weight_g"),
        data.get("avg_predicted_weight_g"),
    )
    summary_lmm = _first_finite(
        summary.get("avg_length_input_topk"),
        summary.get("avg_length_input"),
        data.get("avg_length_input"),
    )

    mean_g = _first_finite(
        summary.get("mean_all_pred_g_after_filters"),
        summary.get("avg_predicted_weight_g"),
    )
    frac = _first_finite(summary.get("fraction_in_near_max_length_band"))
    is_confident = (mean_g is not None and mean_g > _CONFIDENT_AVG_G) or (
        frac is not None and frac >= _MIN_FRAC_LARGEST_LENGTH_GROUP
    )

    entries = data.get("per_cloud") or data.get("per_file") or []
    if not isinstance(entries, list):
        entries = []

    per_frame: Dict[int, Tuple[float, float]] = {}
    for item in entries:
        if not isinstance(item, dict):
            continue
        match = _FRAME_NUM_RE.search(Path(str(item.get("ply", ""))).stem)
        if match is None:
            continue
        wg = _first_finite(item.get("predicted_weight_g"))
        if wg is None:
            continue
        lmm = _first_finite(item.get("length_input"))
        per_frame[int(match.group(1))] = (wg, lmm if lmm is not None else math.nan)

    return per_frame, summary_wg, summary_lmm, is_confident


def _extract_weight_length(weight_json: Path) -> Tuple[Optional[float], Optional[float]]:
    """Return (weight_g, length_mm) from a weight prediction JSON.

    Kept for older call sites; delegates to ``_parse_weight_json`` and
    returns only the summary weight and length.
    """
    _, weight_g, length_mm, _ = _parse_weight_json(weight_json)
    return weight_g, length_mm
|
||||
|
||||
|
||||
def _draw_label_on_box(
|
||||
@@ -79,8 +111,8 @@ def _draw_label_on_box(
|
||||
x1, y1, x2, y2 = map(int, box)
|
||||
cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
|
||||
w_str = f"{weight_g:.0f}g" if weight_g is not None else "--g"
|
||||
l_str = f"{length_mm:.0f}mm" if length_mm is not None else "--mm"
|
||||
w_str = f"{weight_g:.0f}g" if weight_g is not None and math.isfinite(weight_g) else "--g"
|
||||
l_str = f"{length_mm:.0f}mm" if length_mm is not None and math.isfinite(length_mm) else "--mm"
|
||||
label = f"ID:{tid} {class_name} weight: {w_str} len: {l_str}"
|
||||
|
||||
(tw, th), bl = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1)
|
||||
@@ -93,13 +125,15 @@ def _draw_large_summary(
|
||||
image: np.ndarray,
|
||||
weight_g: Optional[float],
|
||||
length_mm: Optional[float],
|
||||
is_confident: bool = False,
|
||||
) -> None:
|
||||
star = " *" if is_confident else ""
|
||||
lines = []
|
||||
lines.append(f"Weight: {weight_g:.0f}g" if weight_g is not None else "Weight: --g")
|
||||
lines.append(f"Final: {weight_g:.0f}g{star}" if weight_g is not None else f"Final: --g")
|
||||
lines.append(f"Length: {length_mm:.0f}mm" if length_mm is not None else "Length: --mm")
|
||||
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
scale = 2.75
|
||||
scale = 2.0
|
||||
thickness = 2
|
||||
pad = 10
|
||||
h, w = image.shape[:2]
|
||||
@@ -138,11 +172,13 @@ def generate_video(
|
||||
print("ERROR: pyzed not available, cannot generate labeled video")
|
||||
return None
|
||||
|
||||
weight_g, length_mm = _extract_weight_length(weight_json)
|
||||
print(f" Labeling with weight={weight_g}g, length={length_mm}mm from {weight_json.name}")
|
||||
per_frame, summary_wg, summary_lmm, is_confident = _parse_weight_json(weight_json)
|
||||
star_s = " *" if is_confident else ""
|
||||
print(f" Per-frame predictions: {len(per_frame)} PLYs mapped")
|
||||
print(f" Summary: weight={summary_wg}g, length={summary_lmm}mm{star_s}")
|
||||
|
||||
if weight_g is None and length_mm is None:
|
||||
print(" WARNING: No valid weight/length in JSON, video will show '--'")
|
||||
if not per_frame and summary_wg is None:
|
||||
print(" WARNING: No weight data in JSON, video will show '--'")
|
||||
|
||||
from ultralytics import YOLO
|
||||
yolo = YOLO(yolo_model_path)
|
||||
@@ -162,8 +198,10 @@ def generate_video(
|
||||
images_dir = output_dir / "images"
|
||||
images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
frames = []
|
||||
frames: List[np.ndarray] = []
|
||||
idx = 0
|
||||
last_wg: Optional[float] = None
|
||||
last_lmm: Optional[float] = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
@@ -179,6 +217,15 @@ def generate_video(
|
||||
idx += 1
|
||||
continue
|
||||
|
||||
frame_number = idx + 1
|
||||
if frame_number in per_frame:
|
||||
cur_wg, cur_lmm = per_frame[frame_number]
|
||||
last_wg = cur_wg
|
||||
last_lmm = cur_lmm if math.isfinite(cur_lmm) else last_lmm
|
||||
else:
|
||||
cur_wg = last_wg
|
||||
cur_lmm = last_lmm
|
||||
|
||||
results = yolo.track(img, conf=conf, imgsz=imgsz, verbose=False, persist=True)[0]
|
||||
num_dets = len(results.boxes) if results.boxes is not None else 0
|
||||
|
||||
@@ -196,13 +243,13 @@ def generate_video(
|
||||
tid = int(tids[i]) if i < len(tids) else 0
|
||||
cid = int(cls_ids[i]) if i < len(cls_ids) else 0
|
||||
cname = class_names.get(cid, "fish")
|
||||
_draw_label_on_box(left_disp, box, tid, cname, weight_g, length_mm)
|
||||
_draw_label_on_box(left_disp, box, tid, cname, cur_wg, cur_lmm)
|
||||
|
||||
if show_large:
|
||||
_draw_large_summary(left_disp, weight_g, length_mm)
|
||||
if show_large or summary_wg is not None:
|
||||
_draw_large_summary(left_disp, summary_wg, summary_lmm, is_confident)
|
||||
|
||||
frame_name = f"frame_{idx + 1:06d}"
|
||||
info = f"[{idx + 1}] {frame_name} | Detections: {num_dets}"
|
||||
frame_name = f"frame_{frame_number:06d}"
|
||||
info = f"[{frame_number}] {frame_name} | Detections: {num_dets}"
|
||||
cv2.putText(left_disp, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
|
||||
cv2.putText(left_disp, "Detection", (10, left_disp.shape[0] - 20),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
|
||||
@@ -216,7 +263,9 @@ def generate_video(
|
||||
frames.append(combined)
|
||||
|
||||
if idx % 30 == 0:
|
||||
print(f" [{idx + 1}] {frame_name} dets={num_dets} frames_collected={len(frames)}")
|
||||
w_s = f"{cur_wg:.0f}g" if cur_wg is not None else "--"
|
||||
l_s = f"{cur_lmm:.0f}mm" if cur_lmm is not None else "--"
|
||||
print(f" [{frame_number}] {frame_name} dets={num_dets} w={w_s} l={l_s} collected={len(frames)}")
|
||||
|
||||
idx += 1
|
||||
finally:
|
||||
@@ -232,7 +281,7 @@ def generate_video(
|
||||
for f in frames:
|
||||
writer.write(f)
|
||||
writer.release()
|
||||
print(f" ✓ Labeled video: {video_path.name} ({len(frames)} frames, weight={weight_g}g len={length_mm}mm)")
|
||||
print(f" ✓ Labeled video: {video_path.name} ({len(frames)} frames, {len(per_frame)} PLY labels)")
|
||||
return video_path
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
SESSION_ROOT="/home/ubuntu/data/fish/2016-1-22-last"
|
||||
FISH_NAME="fish9"
|
||||
FISH_NAME="fish1"
|
||||
fish_dir="${SESSION_ROOT}/${FISH_NAME}/"
|
||||
OUT_PARENT="output_weight_estimator"
|
||||
save_out="${OUT_PARENT}/${FISH_NAME}"
|
||||
|
||||
Reference in New Issue
Block a user