whole process

This commit is contained in:
zaiun xu
2026-04-14 22:05:52 +08:00
parent af67f61b63
commit 940d426a37
7 changed files with 161 additions and 123 deletions

View File

@@ -1,8 +1,10 @@
#!/usr/bin/env python3
"""Generate labeled preview video from SVO + weight prediction JSON.
Reads the SVO with YOLO tracking, overlays DGCNN weight/length on each
detection box, and writes ``<svo_name>_preview.mp4`` into ``--save-output``.
Per-frame labeling: each detection box shows the weight/length predicted
for that specific frame's PLY (from ``per_cloud`` / ``per_file`` in the
DGCNN JSON). Frames without a corresponding PLY carry forward the last
known value. The final aggregated result is shown at top-right.
Called by ``predict_weigth_from_svo2.py`` after DGCNN completes.
Replaces any existing preview video so the final published file has labels.
@@ -13,9 +15,10 @@ from __future__ import annotations
import argparse
import json
import math
import re
import sys
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import cv2
import numpy as np
@@ -27,45 +30,74 @@ except ImportError:
ZED_AVAILABLE = False
def _parse_weight_json(weight_json: Path) -> Tuple[
    Dict[int, Tuple[float, float]],
    Optional[float],
    Optional[float],
    bool,
]:
    """Parse a weight-prediction JSON into per-frame values, a summary and a flag.

    Args:
        weight_json: Path to the prediction JSON; it is read as UTF-8.

    Returns:
        A 4-tuple ``(per_frame, summary_weight_g, summary_length_mm, is_confident)``:

        per_frame: ``{frame_number: (weight_g, length_mm)}`` built from the
            ``per_cloud`` list (or ``per_file`` when ``per_cloud`` is missing
            or empty). The frame number is the first ``frame_<digits>`` match
            in the stem of each entry's ``ply`` value. Entries that are not
            mappings, have no such match, or have no finite
            ``predicted_weight_g`` are skipped. A missing or non-finite
            ``length_input`` is stored as NaN. Later entries overwrite earlier
            ones with the same frame number.
        summary_weight_g: first finite value among the summary's
            ``pred_weight_g`` / ``avg_predicted_weight_g`` and then the
            top-level keys of the same names; ``None`` if there is none.
        summary_length_mm: first finite value among the summary's
            ``avg_length_input_topk`` / ``avg_length_input`` and then the
            top-level ``avg_length_input``; ``None`` if there is none.
        is_confident: True when the summary mean weight is above 440 g, or
            when ``fraction_in_near_max_length_band`` is at least 0.25
            (the condition under which ``*`` should be shown).

    Raises:
        OSError: if the file cannot be read.
        ValueError: if the file is not valid JSON (``json.JSONDecodeError``).
    """
    confident_avg_g = 440.0
    min_frac_largest_length_group = 0.25

    def _first_finite(*candidates: Any) -> Optional[float]:
        """Return the first candidate that converts to a finite float, else None."""
        for c in candidates:
            if c is None:
                continue
            try:
                v = float(c)
            except (TypeError, ValueError, OverflowError):
                continue
            if math.isfinite(v):
                return v
        return None

    data = json.loads(weight_json.read_text(encoding="utf-8"))
    # A malformed file (e.g. a top-level list) is treated as "no data" so the
    # caller can still render a video with placeholder labels.
    if not isinstance(data, dict):
        data = {}

    summary = data.get("dgcnn_summary") or data.get("weight_summary") or data.get("summary") or {}
    if not isinstance(summary, dict):
        summary = {}

    summary_wg = _first_finite(
        summary.get("pred_weight_g"),
        summary.get("avg_predicted_weight_g"),
        data.get("pred_weight_g"),
        data.get("avg_predicted_weight_g"),
    )
    summary_lmm = _first_finite(
        summary.get("avg_length_input_topk"),
        summary.get("avg_length_input"),
        data.get("avg_length_input"),
    )

    mean_g = _first_finite(
        summary.get("mean_all_pred_g_after_filters"),
        summary.get("avg_predicted_weight_g"),
    )
    frac = _first_finite(summary.get("fraction_in_near_max_length_band"))
    is_confident = (
        (mean_g is not None and mean_g > confident_avg_g)
        or (frac is not None and frac >= min_frac_largest_length_group)
    )

    entries = data.get("per_cloud") or data.get("per_file") or []
    if not isinstance(entries, (list, tuple)):
        entries = []

    frame_re = re.compile(r"frame_(\d+)")
    per_frame: Dict[int, Tuple[float, float]] = {}
    for item in entries:
        if not isinstance(item, dict):
            continue
        m = frame_re.search(Path(str(item.get("ply", ""))).stem)
        if not m:
            continue
        wg = _first_finite(item.get("predicted_weight_g"))
        if wg is None:
            continue
        lmm = _first_finite(item.get("length_input"))
        # NaN marks "weight known, length unknown"; consumers test it with
        # math.isfinite.
        per_frame[int(m.group(1))] = (wg, lmm if lmm is not None else float("nan"))

    return per_frame, summary_wg, summary_lmm, is_confident
def _draw_label_on_box(
@@ -79,8 +111,8 @@ def _draw_label_on_box(
x1, y1, x2, y2 = map(int, box)
cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
w_str = f"{weight_g:.0f}g" if weight_g is not None else "--g"
l_str = f"{length_mm:.0f}mm" if length_mm is not None else "--mm"
w_str = f"{weight_g:.0f}g" if weight_g is not None and math.isfinite(weight_g) else "--g"
l_str = f"{length_mm:.0f}mm" if length_mm is not None and math.isfinite(length_mm) else "--mm"
label = f"ID:{tid} {class_name} weight: {w_str} len: {l_str}"
(tw, th), bl = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1)
@@ -93,13 +125,15 @@ def _draw_large_summary(
image: np.ndarray,
weight_g: Optional[float],
length_mm: Optional[float],
is_confident: bool = False,
) -> None:
star = " *" if is_confident else ""
lines = []
lines.append(f"Weight: {weight_g:.0f}g" if weight_g is not None else "Weight: --g")
lines.append(f"Final: {weight_g:.0f}g{star}" if weight_g is not None else f"Final: --g")
lines.append(f"Length: {length_mm:.0f}mm" if length_mm is not None else "Length: --mm")
font = cv2.FONT_HERSHEY_SIMPLEX
scale = 2.75
scale = 2.0
thickness = 2
pad = 10
h, w = image.shape[:2]
@@ -138,11 +172,13 @@ def generate_video(
print("ERROR: pyzed not available, cannot generate labeled video")
return None
weight_g, length_mm = _extract_weight_length(weight_json)
print(f" Labeling with weight={weight_g}g, length={length_mm}mm from {weight_json.name}")
per_frame, summary_wg, summary_lmm, is_confident = _parse_weight_json(weight_json)
star_s = " *" if is_confident else ""
print(f" Per-frame predictions: {len(per_frame)} PLYs mapped")
print(f" Summary: weight={summary_wg}g, length={summary_lmm}mm{star_s}")
if weight_g is None and length_mm is None:
print(" WARNING: No valid weight/length in JSON, video will show '--'")
if not per_frame and summary_wg is None:
print(" WARNING: No weight data in JSON, video will show '--'")
from ultralytics import YOLO
yolo = YOLO(yolo_model_path)
@@ -162,8 +198,10 @@ def generate_video(
images_dir = output_dir / "images"
images_dir.mkdir(parents=True, exist_ok=True)
frames = []
frames: List[np.ndarray] = []
idx = 0
last_wg: Optional[float] = None
last_lmm: Optional[float] = None
try:
while True:
@@ -179,6 +217,15 @@ def generate_video(
idx += 1
continue
frame_number = idx + 1
if frame_number in per_frame:
cur_wg, cur_lmm = per_frame[frame_number]
last_wg = cur_wg
last_lmm = cur_lmm if math.isfinite(cur_lmm) else last_lmm
else:
cur_wg = last_wg
cur_lmm = last_lmm
results = yolo.track(img, conf=conf, imgsz=imgsz, verbose=False, persist=True)[0]
num_dets = len(results.boxes) if results.boxes is not None else 0
@@ -196,13 +243,13 @@ def generate_video(
tid = int(tids[i]) if i < len(tids) else 0
cid = int(cls_ids[i]) if i < len(cls_ids) else 0
cname = class_names.get(cid, "fish")
_draw_label_on_box(left_disp, box, tid, cname, weight_g, length_mm)
_draw_label_on_box(left_disp, box, tid, cname, cur_wg, cur_lmm)
if show_large:
_draw_large_summary(left_disp, weight_g, length_mm)
if show_large or summary_wg is not None:
_draw_large_summary(left_disp, summary_wg, summary_lmm, is_confident)
frame_name = f"frame_{idx + 1:06d}"
info = f"[{idx + 1}] {frame_name} | Detections: {num_dets}"
frame_name = f"frame_{frame_number:06d}"
info = f"[{frame_number}] {frame_name} | Detections: {num_dets}"
cv2.putText(left_disp, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
cv2.putText(left_disp, "Detection", (10, left_disp.shape[0] - 20),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
@@ -216,7 +263,9 @@ def generate_video(
frames.append(combined)
if idx % 30 == 0:
print(f" [{idx + 1}] {frame_name} dets={num_dets} frames_collected={len(frames)}")
w_s = f"{cur_wg:.0f}g" if cur_wg is not None else "--"
l_s = f"{cur_lmm:.0f}mm" if cur_lmm is not None else "--"
print(f" [{frame_number}] {frame_name} dets={num_dets} w={w_s} l={l_s} collected={len(frames)}")
idx += 1
finally:
@@ -232,7 +281,7 @@ def generate_video(
for f in frames:
writer.write(f)
writer.release()
print(f" ✓ Labeled video: {video_path.name} ({len(frames)} frames, weight={weight_g}g len={length_mm}mm)")
print(f" ✓ Labeled video: {video_path.name} ({len(frames)} frames, {len(per_frame)} PLY labels)")
return video_path

View File

@@ -10,7 +10,7 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"
SESSION_ROOT="/home/ubuntu/data/fish/2016-1-22-last"
FISH_NAME="fish9"
FISH_NAME="fish1"
fish_dir="${SESSION_ROOT}/${FISH_NAME}/"
OUT_PARENT="output_weight_estimator"
save_out="${OUT_PARENT}/${FISH_NAME}"