fix video label
This commit is contained in:
@@ -478,8 +478,9 @@ def draw_fish_boxes_from_arrays(
|
||||
track_ids: Optional[np.ndarray],
|
||||
class_names: Dict[int, str],
|
||||
weights_by_track: Optional[Dict[int, float]] = None,
|
||||
lengths_by_track_mm: Optional[Dict[int, float]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Draw boxes from numpy arrays; labels show DGCNN weight (g), not YOLO conf or depth."""
|
||||
"""Draw boxes from numpy arrays; labels show DGCNN weight (g) and length (mm), not YOLO conf or depth."""
|
||||
if boxes is None or len(boxes) == 0:
|
||||
return image
|
||||
for i, box in enumerate(boxes):
|
||||
@@ -491,10 +492,19 @@ def draw_fish_boxes_from_arrays(
|
||||
wg: Optional[float] = None
|
||||
if weights_by_track is not None and tid >= 0 and tid in weights_by_track:
|
||||
wg = weights_by_track[tid]
|
||||
ln_mm: Optional[float] = None
|
||||
if lengths_by_track_mm is not None and tid >= 0 and tid in lengths_by_track_mm:
|
||||
ln_mm = lengths_by_track_mm[tid]
|
||||
if wg is not None and np.isfinite(wg):
|
||||
label = f"ID:{tid} {cname} weight: {wg:.0f} g"
|
||||
extra = ""
|
||||
if ln_mm is not None and np.isfinite(ln_mm):
|
||||
extra = f" len:{ln_mm:.0f}mm"
|
||||
label = f"ID:{tid} {cname} weight: {wg:.0f} g{extra}"
|
||||
else:
|
||||
label = f"ID:{tid} {cname} weight: -- g"
|
||||
extra = ""
|
||||
if ln_mm is not None and np.isfinite(ln_mm):
|
||||
extra = f" len:{ln_mm:.0f}mm"
|
||||
label = f"ID:{tid} {cname} weight: -- g{extra}"
|
||||
(text_w, text_h), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1)
|
||||
cv2.rectangle(image, (x1, y1 - text_h - baseline - 5), (x1 + text_w, y1), (0, 255, 0), -1)
|
||||
cv2.putText(
|
||||
@@ -508,6 +518,7 @@ def draw_fish_boxes_with_weight(
|
||||
results,
|
||||
class_names: Dict[int, str],
|
||||
weights_by_track: Optional[Dict[int, float]] = None,
|
||||
lengths_by_track_mm: Optional[Dict[int, float]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Draw YOLO boxes with fish weight (g) per track; no confidence, no depth."""
|
||||
if results is None or results.boxes is None:
|
||||
@@ -521,11 +532,13 @@ def draw_fish_boxes_with_weight(
|
||||
tid = None
|
||||
if hasattr(results.boxes, "id") and results.boxes.id is not None:
|
||||
tid = results.boxes.id.cpu().numpy().astype(int)
|
||||
return draw_fish_boxes_from_arrays(image, boxes, class_ids, tid, class_names, weights_by_track)
|
||||
return draw_fish_boxes_from_arrays(
|
||||
image, boxes, class_ids, tid, class_names, weights_by_track, lengths_by_track_mm
|
||||
)
|
||||
|
||||
|
||||
def draw_overlay_header(image: np.ndarray, lines: List[str]) -> None:
|
||||
y = 22
|
||||
def draw_overlay_header(image: np.ndarray, lines: List[str], start_y: int = 22) -> None:
|
||||
y = start_y
|
||||
for line in lines:
|
||||
if not line:
|
||||
continue
|
||||
@@ -552,9 +565,10 @@ def build_track_weights_minute_top5(
|
||||
*,
|
||||
fps: float,
|
||||
minute_interval_sec: float,
|
||||
) -> Tuple[Dict[int, float], Dict[int, float], List[float]]:
|
||||
"""track_id -> max weight g; minute_bucket -> mean weight in window; global top-5 weights (desc)."""
|
||||
) -> Tuple[Dict[int, float], Dict[int, float], Dict[int, float], List[float]]:
|
||||
"""track_id -> max weight g; track_id -> length_input (mm) at that max-weight PLY; minute_bucket -> mean g; top-5 weights (desc)."""
|
||||
tid_max: Dict[int, float] = {}
|
||||
tid_len_mm: Dict[int, float] = {}
|
||||
bucket_vals: Dict[int, List[float]] = {}
|
||||
for it in per_file:
|
||||
ply = str(it.get("ply", ""))
|
||||
@@ -563,7 +577,12 @@ def build_track_weights_minute_top5(
|
||||
continue
|
||||
tid = _parse_tid_from_ply_name(Path(ply).name)
|
||||
if tid is not None:
|
||||
tid_max[tid] = max(tid_max.get(tid, float("-inf")), g)
|
||||
prev = tid_max.get(tid)
|
||||
if prev is None or g > prev:
|
||||
tid_max[tid] = g
|
||||
ln = float(it.get("length_input", float("nan")))
|
||||
if np.isfinite(ln):
|
||||
tid_len_mm[tid] = ln
|
||||
fn = _parse_frame_num_from_ply_name(Path(ply).name)
|
||||
if fn is None or fps <= 0:
|
||||
continue
|
||||
@@ -575,23 +594,20 @@ def build_track_weights_minute_top5(
|
||||
if vals:
|
||||
minute_avg[b] = float(np.mean(vals))
|
||||
top5 = sorted(tid_max.values(), reverse=True)[:5]
|
||||
return tid_max, minute_avg, top5
|
||||
return tid_max, tid_len_mm, minute_avg, top5
|
||||
|
||||
|
||||
def finalize_preview_video_with_weights(
|
||||
video_buffer: List[Dict[str, Any]],
|
||||
*,
|
||||
fps_video: float,
|
||||
fps_timeline: float,
|
||||
minute_interval_sec: float,
|
||||
weights_by_track: Dict[int, float],
|
||||
minute_avg: Dict[int, float],
|
||||
top5: List[float],
|
||||
lengths_by_track_mm: Dict[int, float],
|
||||
class_names: Dict[int, str],
|
||||
svo_name: str,
|
||||
output_images_folder: Path,
|
||||
) -> None:
|
||||
"""Redraw buffered frames with DGCNN weights + top-5 / per-minute lines; write side-by-side mp4."""
|
||||
"""Redraw buffered frames with DGCNN weight/length labels; write side-by-side mp4."""
|
||||
if not video_buffer:
|
||||
return
|
||||
out_frames: List[np.ndarray] = []
|
||||
@@ -599,6 +615,8 @@ def finalize_preview_video_with_weights(
|
||||
left_raw = entry["left_raw"]
|
||||
right = entry["right"]
|
||||
frame_idx = int(entry["frame_idx"])
|
||||
frame_name = entry.get("frame_name", f"frame_{frame_idx + 1:06d}")
|
||||
num_dets = int(entry.get("num_dets", 0))
|
||||
boxes = np.asarray(entry["boxes"], dtype=np.float32)
|
||||
cls_ids = np.asarray(entry["class_ids"], dtype=np.int64)
|
||||
_tid = entry.get("track_ids")
|
||||
@@ -613,20 +631,13 @@ def finalize_preview_video_with_weights(
|
||||
tids,
|
||||
class_names,
|
||||
weights_by_track=weights_by_track,
|
||||
lengths_by_track_mm=lengths_by_track_mm,
|
||||
)
|
||||
bucket = int((frame_idx / max(fps_timeline, 1e-6)) // minute_interval_sec)
|
||||
mav = minute_avg.get(bucket)
|
||||
mav_s = f"{mav:.0f} g" if mav is not None and np.isfinite(mav) else "--"
|
||||
top5_s = ", ".join(f"{w:.0f}" for w in top5 if np.isfinite(w)) if top5 else "--"
|
||||
lines = [
|
||||
f"Top-5 weights (g, all fish so far): {top5_s}",
|
||||
f"This {int(minute_interval_sec)}s window ~min {bucket + 1}: avg {mav_s}",
|
||||
]
|
||||
draw_overlay_header(left_disp, lines)
|
||||
info = f"[{frame_idx + 1}] Detections"
|
||||
cv2.putText(
|
||||
left_disp, info, (10, left_disp.shape[0] - 24), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 255, 0), 2, cv2.LINE_AA
|
||||
)
|
||||
info = f"[{frame_idx + 1}] {frame_name} | Detections: {num_dets}"
|
||||
cv2.putText(left_disp, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
|
||||
|
||||
cv2.putText(left_disp, "Detection", (10, left_disp.shape[0] - 20),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
|
||||
combined = np.hstack([left_disp, right])
|
||||
out_frames.append(combined)
|
||||
video_path = output_images_folder / f"{svo_name}_preview.mp4"
|
||||
@@ -1362,7 +1373,8 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de
|
||||
fps_timeline = 30.0
|
||||
print(f" Timeline FPS (for per-minute buckets): {fps_timeline:.2f}")
|
||||
|
||||
defer_video = bool(do_weight_estimation and weight_overlay_video and not save_images)
|
||||
# 只要跑过 fish 内 DGCNN,就延迟写预览,以便叠加真实 weight/length(不再依赖 --weight-overlay-video)
|
||||
defer_video = bool(do_weight_estimation and not save_images)
|
||||
video_frames: List[np.ndarray] = []
|
||||
video_defer_buffer: List[Dict[str, Any]] = []
|
||||
idx = 0
|
||||
@@ -1400,6 +1412,8 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de
|
||||
if frame_stride > 1 and (idx % frame_stride) != 0:
|
||||
idx += 1
|
||||
continue
|
||||
|
||||
defer_left_payload: Optional[Dict[str, Any]] = None
|
||||
|
||||
frame_name = f"frame_{idx+1:06d}"
|
||||
if idx % 30 == 0:
|
||||
@@ -1464,6 +1478,18 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de
|
||||
depth_stats_list = []
|
||||
else:
|
||||
active_boxes_array = np.array(active_boxes)
|
||||
if defer_video and len(active_boxes) > 0:
|
||||
cls_all_np = (
|
||||
results.boxes.cls.cpu().numpy().astype(int)
|
||||
if results.boxes.cls is not None
|
||||
else np.zeros(len(current_boxes), dtype=int)
|
||||
)
|
||||
ac = np.array(active_detections, dtype=int)
|
||||
defer_left_payload = {
|
||||
"boxes": active_boxes_array.astype(np.float32),
|
||||
"class_ids": cls_all_np[ac],
|
||||
"track_ids": np.array(active_track_ids, dtype=np.int64),
|
||||
}
|
||||
all_masks = segment_with_sam(sam_predictor, img, active_boxes_array, sam_device)
|
||||
individual_masks = all_masks if all_masks else []
|
||||
|
||||
@@ -1514,34 +1540,25 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de
|
||||
previous_boxes = None
|
||||
depth_stats_list = []
|
||||
|
||||
# Left panel: DGCNN mass (g) after finalize — not YOLO conf, not depth (depth is mm to camera, not mass)
|
||||
if defer_video and num_dets > 0 and results is not None and results.boxes is not None:
|
||||
bx = results.boxes.xyxy.cpu().numpy()
|
||||
cls_np = (
|
||||
results.boxes.cls.cpu().numpy().astype(int)
|
||||
if results.boxes.cls is not None
|
||||
else np.zeros(len(bx), dtype=int)
|
||||
)
|
||||
tid_np = (
|
||||
results.boxes.id.cpu().numpy().astype(int)
|
||||
if results.boxes.id is not None
|
||||
else np.zeros(len(bx), dtype=int)
|
||||
)
|
||||
# Left panel: 延迟模式用与 PLY 文件名一致的 active track_id,否则 DGCNN 字典对不上会全是 "--"
|
||||
if defer_video and defer_left_payload is not None:
|
||||
video_defer_buffer.append(
|
||||
{
|
||||
"left_raw": img.copy(),
|
||||
"right": right_display.copy(),
|
||||
"frame_idx": idx,
|
||||
"boxes": bx,
|
||||
"class_ids": cls_np,
|
||||
"track_ids": tid_np,
|
||||
"frame_name": frame_name,
|
||||
"num_dets": num_dets,
|
||||
**defer_left_payload,
|
||||
}
|
||||
)
|
||||
left_display = img.copy()
|
||||
else:
|
||||
elif not defer_video:
|
||||
left_display = draw_fish_boxes_with_weight(
|
||||
img.copy(), results, class_names, weights_by_track=None
|
||||
)
|
||||
else:
|
||||
left_display = img.copy()
|
||||
info = f"[{idx + 1}] {frame_name} | Detections: {num_dets}"
|
||||
cv2.putText(left_display, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
|
||||
cv2.putText(left_display, "Detection", (10, left_display.shape[0] - 20),
|
||||
@@ -1771,26 +1788,30 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de
|
||||
if not save_images:
|
||||
if defer_video and video_defer_buffer:
|
||||
wdict: Dict[int, float] = {}
|
||||
mavg: Dict[int, float] = {}
|
||||
top5: List[float] = []
|
||||
ldict: Dict[int, float] = {}
|
||||
wjson = output_base / "weight_estimation" / "weight_estimation_results.json"
|
||||
if wjson.is_file():
|
||||
try:
|
||||
wr = json.loads(wjson.read_text(encoding="utf-8"))
|
||||
per = wr.get("per_file") or []
|
||||
wdict, mavg, top5 = build_track_weights_minute_top5(
|
||||
per, fps=fps_timeline, minute_interval_sec=minute_interval_sec
|
||||
)
|
||||
summary = wr.get("summary") or {}
|
||||
summary_wg = float(summary.get("avg_predicted_weight_g", float("nan")))
|
||||
summary_lmm = float(summary.get("avg_length_input_topk", float("nan")))
|
||||
all_tids: set = set()
|
||||
for _e in video_defer_buffer:
|
||||
_t = _e.get("track_ids")
|
||||
if _t is not None:
|
||||
all_tids.update(int(x) for x in np.asarray(_t).ravel())
|
||||
if np.isfinite(summary_wg):
|
||||
wdict = {tid: summary_wg for tid in all_tids}
|
||||
if np.isfinite(summary_lmm):
|
||||
ldict = {tid: summary_lmm for tid in all_tids}
|
||||
except Exception as e:
|
||||
print(f" WARNING: Could not parse weight results for video overlay: {e}")
|
||||
finalize_preview_video_with_weights(
|
||||
video_defer_buffer,
|
||||
fps_video=10.0,
|
||||
fps_timeline=fps_timeline,
|
||||
minute_interval_sec=minute_interval_sec,
|
||||
weights_by_track=wdict,
|
||||
minute_avg=mavg,
|
||||
top5=top5,
|
||||
lengths_by_track_mm=ldict,
|
||||
class_names=class_names,
|
||||
svo_name=svo_name,
|
||||
output_images_folder=output_images_folder,
|
||||
|
||||
@@ -102,13 +102,20 @@ def _run_fish_video_evaluation_subprocess(args: argparse.Namespace, *, batch_fol
|
||||
cmd.append("--use-flatness-filter")
|
||||
cmd.extend(["--flatness-threshold", str(args.flatness_threshold)])
|
||||
|
||||
if getattr(args, "fish_video_weight_overlay", False):
|
||||
wck = Path(args.weight_checkpoint).expanduser().resolve()
|
||||
# 始终在 fish 内跑 DGCNN,生成 weight_estimation_results.json,预览视频才能叠加 weight/length;
|
||||
# predict 后续会合并该 JSON,避免重复跑 test_dgcnn。
|
||||
wck = Path(args.weight_checkpoint).expanduser().resolve()
|
||||
if wck.is_file():
|
||||
cmd.extend(
|
||||
[
|
||||
"--run-weight-estimation",
|
||||
"--weight-estimator-checkpoint",
|
||||
str(wck),
|
||||
]
|
||||
)
|
||||
if getattr(args, "fish_video_weight_overlay", False):
|
||||
cmd.extend(
|
||||
[
|
||||
"--weight-overlay-video",
|
||||
"--minute-interval-sec",
|
||||
str(getattr(args, "minute_interval_sec", 60.0)),
|
||||
@@ -507,8 +514,8 @@ def main() -> None:
|
||||
parser.add_argument(
|
||||
"--fish-video-weight-overlay",
|
||||
action="store_true",
|
||||
help="Run fish_video with DGCNN + preview video overlay (fish weight g, top-5, per-window avg). "
|
||||
"Avoids a duplicate test_dgcnn pass when weight_estimation_results.json is present.",
|
||||
help="Extra on-video header lines (top-5 / per-minute bucket). "
|
||||
"DGCNN in fish + preview weight/length labels are already enabled when weight checkpoint exists.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--minute-interval-sec",
|
||||
|
||||
Reference in New Issue
Block a user