diff --git a/FishMeasure/fish_video_weight_evaluation.py b/FishMeasure/fish_video_weight_evaluation.py old mode 100755 new mode 100644 index e83a6c2..e749a27 --- a/FishMeasure/fish_video_weight_evaluation.py +++ b/FishMeasure/fish_video_weight_evaluation.py @@ -5,7 +5,6 @@ Pure OpenCV + YOLO + SAM for viewing images from a folder. """ import argparse -import re import cv2 import json import numpy as np @@ -41,37 +40,6 @@ from utils.keep_largest_cluster import keep_largest_cluster_with_colors from utils.correct_tail_rotation import correct_tail_rotation_array -def get_h264_fourcc(): - """返回最佳的 H.264 FourCC 编码器代码。 - - 尝试顺序:avc1 (最兼容), X264, H264 - 如果都不可用,回退到 mp4v (MPEG-4,兼容性较差但通用) - """ - # 尝试的 H.264 FourCC 代码(按浏览器兼容性排序) - h264_candidates = ["avc1", "X264", "H264"] - for codec in h264_candidates: - fourcc = cv2.VideoWriter_fourcc(*codec) - # 测试是否可用(创建临时视频) - try: - import tempfile - import os - with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f: - tmp_path = f.name - test_writer = cv2.VideoWriter(tmp_path, fourcc, 10.0, (64, 64)) - if test_writer.isOpened(): - test_writer.release() - os.unlink(tmp_path) - print(f"[VideoEncoder] Using H.264 codec: {codec}") - return fourcc - test_writer.release() - os.unlink(tmp_path) - except Exception: - pass - # 回退到 mp4v(MPEG-4 Part 2) - print("[VideoEncoder] H.264 not available, falling back to mp4v (MPEG-4)") - return cv2.VideoWriter_fourcc(*"mp4v") - - def estimate_pointcloud_length_pca(points: np.ndarray) -> float: """ Estimate point cloud "length" as the extent along the 1st PCA axis. 
@@ -300,7 +268,13 @@ def run_weight_estimation( top_k: int = 5, top_by_length: bool = True, length_switch_to_weight_mm: float = 319.0, - ply_files: Optional[List[Path]] = None, + max_length_mm: float = 400.0, + length_quality_cv_threshold_pct: float = 15.0, + length_quality_max_span_mm: float = 130.0, + min_length_width_ratio: float = 1.5, + average_all_after_filter: bool = False, + average_all_fallback_to_max_if_mean_over_g: float = 400.0, + mean_pool_fallback_to_max_if_over_g: float = 440.0, ) -> Optional[Dict]: """ Run weight estimation on point clouds in a folder using DGCNN (test_dgcnn_weight_estimator). @@ -320,6 +294,13 @@ def run_weight_estimation( num_points: Number of points to sample per PLY xyz_scale: XYZ scaling factor (0.001 = mm to m) verbose: Print detailed output + max_length_mm: Exclude PLYs with length_input > this (mm) from aggregation; 0 disables. + length_quality_cv_threshold_pct / length_quality_max_span_mm: Heuristic good/bad hint from + frame-to-frame length spread (CV and max−min span); 0 for max_span disables the span rule. + min_length_width_ratio: Exclude PLYs with PCA length/width below this; 0 disables (default 1.5). + average_all_after_filter: If True, final weight = mean over all PLYs after filters (no top-K). + average_all_fallback_to_max_if_mean_over_g: If mean (g) exceeds this with average-all, use max pred; 0 disables. + mean_pool_fallback_to_max_if_over_g: If mean over all filtered candidates (g) exceeds this, use max pred; 0 disables. 
Returns: Results dict or None if failed @@ -334,19 +315,43 @@ def run_weight_estimation( output_dir = Path(output_dir).expanduser().resolve() output_dir.mkdir(parents=True, exist_ok=True) - # Find PLY files (optional explicit list for subsets / per-minute windows) - if ply_files is not None: - ply_list = sorted({Path(p).expanduser().resolve() for p in ply_files}) - ply_list = [p for p in ply_list if p.is_file()] - else: - ply_list = sorted(cloud_folder.glob("*.ply")) - if not ply_list: + # Find PLY files + ply_files = sorted(cloud_folder.glob("*.ply")) + if not ply_files: print(f" No PLY files found in: {cloud_folder}") return None + max_len_eff = None if max_length_mm <= 0 else float(max_length_mm) + span_lq_eff = None if length_quality_max_span_mm <= 0 else float(length_quality_max_span_mm) + min_lw_eff = None if min_length_width_ratio <= 0 else float(min_length_width_ratio) + avg_all_fb_eff = ( + None + if average_all_fallback_to_max_if_mean_over_g <= 0 + else float(average_all_fallback_to_max_if_mean_over_g) + ) + mean_pool_fb_eff = ( + None + if mean_pool_fallback_to_max_if_over_g <= 0 + else float(mean_pool_fallback_to_max_if_over_g) + ) + if verbose: print(f"\n{'='*60}") - print(f"Running DGCNN weight estimation on {len(ply_list)} point clouds...") + print(f"Running DGCNN weight estimation on {len(ply_files)} point clouds...") + if max_len_eff is not None: + print(f" Length cap: exclude length > {max_len_eff:.0f} mm from aggregation") + if min_lw_eff is not None: + print(f" Length/width: exclude PCA length/width < {min_lw_eff:.2f}") + if average_all_after_filter: + print(" Aggregation: mean weight over all PLYs after filters (no top-K)") + if avg_all_fb_eff is not None: + print( + f" If mean > {avg_all_fb_eff:.0f} g → use max pred after filter (see pred_weight_* in JSON)" + ) + if mean_pool_fb_eff is not None: + print( + f" If mean over all filtered PLYs > {mean_pool_fb_eff:.0f} g → use max pred after filter" + ) print(f"{'='*60}") try: @@ -354,7 +359,7 @@ 
def run_weight_estimation( if str(weight_estimator_dir) not in sys.path: sys.path.insert(0, str(weight_estimator_dir)) - from test_dgcnn_weight_estimator import predict_cloud_folder + from test_dgcnn_weight_estimator import predict_cloud_folder, _print_max_weight_after_filter per_file, summary = predict_cloud_folder( model=_weight_estimator_model, @@ -373,7 +378,13 @@ def run_weight_estimation( outlier_field="length_input", iqr_factor=1.5, zscore_threshold=2.5, - ply_files=ply_list, + max_length_mm=max_len_eff, + min_length_width_ratio=min_lw_eff, + length_quality_cv_threshold_pct=float(length_quality_cv_threshold_pct), + length_quality_max_span_mm=span_lq_eff, + average_all_after_filter=bool(average_all_after_filter), + average_all_fallback_to_max_if_mean_over_g=avg_all_fb_eff, + mean_pool_fallback_to_max_if_over_g=mean_pool_fb_eff, ) # Check CV threshold @@ -383,21 +394,54 @@ def run_weight_estimation( print(f" WARNING: High CV detected ({cv_pct:.1f}% > {max_cv_length}%) - results may be unreliable") summary["cv_warning"] = True - # Print summary (skip per-file details to reduce log noise) + # Print results if verbose: - print(f"\n Files processed: {summary.get('num_files_predicted', len(ply_list))}") + for it in per_file: + ply_name = Path(it["ply"]).name + g = float(it["predicted_weight_g"]) + length = float(it.get("length_input", float("nan"))) + is_outlier = it.get("is_outlier", False) + if it.get("filtered_by_max_length"): + cap_s = f"{max_len_eff:.0f}" if max_len_eff is not None else "?" + outlier_marker = f" [FILTERED length>{cap_s}mm]" + elif it.get("filtered_by_length_width_ratio"): + thr_s = f"{min_lw_eff:.2f}" if min_lw_eff is not None else "?" 
+ outlier_marker = f" [FILTERED L/W<{thr_s}]" + else: + outlier_marker = " [OUTLIER]" if is_outlier else "" + length_str = f"{length:.1f}mm" if np.isfinite(length) else "nan" + print(f" {ply_name}: len={length_str} | {g:.2f}g{outlier_marker}") + + print(f"\n Files processed: {summary.get('num_files_predicted', len(ply_files))}") + n_cap = summary.get("num_files_filtered_by_max_length", 0) or 0 + if n_cap > 0 and max_len_eff is not None: + print( + f" Filtered by length cap (>{max_len_eff:.0f} mm): {n_cap} " + f"(excluded from aggregation)" + ) + n_lw = summary.get("num_files_filtered_by_length_width_ratio", 0) or 0 + if n_lw > 0 and min_lw_eff is not None: + print( + f" Filtered by length/width ratio (<{min_lw_eff:.2f}): {n_lw} " + f"(excluded from aggregation)" + ) if summary.get('num_outliers_removed', 0) > 0: print(f" Outliers removed: {summary['num_outliers_removed']}") - if cv_pct is not None and np.isfinite(cv_pct): - print(f" Length CV: {cv_pct:.1f}%") - - avg_g = summary.get("avg_predicted_weight_g", 0) - eff_len = summary.get("effective_top_by_length", top_by_length) - mode = f"top-{top_k} by length" if eff_len else f"top-{top_k} by pred" + pw = summary.get("pred_weight_g") + avg_g = float(pw if pw is not None else summary.get("avg_predicted_weight_g", 0)) + if summary.get("average_all_after_filter"): + if summary.get("used_max_instead_of_mean_all_high_mean"): + mode = "final pred — mean-all exceeded threshold, using max" + else: + mode = "mean of all after filters" + else: + eff_len = summary.get("effective_top_by_length", top_by_length) + mode = f"top-{top_k} by length" if eff_len else f"top-{top_k} by pred" if topk_length is not None: print(f" Estimated weight ({mode}, topk_length={topk_length}): {avg_g:.2f}g") else: print(f" Estimated weight ({mode}): {avg_g:.2f}g") + _print_max_weight_after_filter(summary, prefix=" ") # Save results results = { @@ -415,6 +459,13 @@ def run_weight_estimation( "max_cv_length": max_cv_length, "num_points": num_points, 
"xyz_scale": xyz_scale, + "max_length_mm": max_len_eff, + "min_length_width_ratio": min_lw_eff, + "length_quality_cv_threshold_pct": float(length_quality_cv_threshold_pct), + "length_quality_max_span_mm": span_lq_eff, + "average_all_after_filter": bool(average_all_after_filter), + "average_all_fallback_to_max_if_mean_over_g": avg_all_fb_eff, + "mean_pool_fallback_to_max_if_over_g": mean_pool_fb_eff, } } @@ -493,222 +544,6 @@ def draw_detections(image, results, class_names, depth_stats_list=None): return image -def draw_fish_boxes_from_arrays( - image: np.ndarray, - boxes: np.ndarray, - class_ids: np.ndarray, - track_ids: Optional[np.ndarray], - class_names: Dict[int, str], - weights_by_track: Optional[Dict[int, float]] = None, - lengths_by_track_mm: Optional[Dict[int, float]] = None, -) -> np.ndarray: - """Draw boxes from numpy arrays; labels show DGCNN weight (g) and length (mm), not YOLO conf or depth.""" - if boxes is None or len(boxes) == 0: - return image - for i, box in enumerate(boxes): - x1, y1, x2, y2 = map(int, box) - cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2) - cls_id = int(class_ids[i]) if i < len(class_ids) else 0 - cname = class_names.get(cls_id, "fish") if class_names else "fish" - tid = int(track_ids[i]) if track_ids is not None and i < len(track_ids) else -1 - wg: Optional[float] = None - if weights_by_track is not None and tid >= 0 and tid in weights_by_track: - wg = weights_by_track[tid] - ln_mm: Optional[float] = None - if lengths_by_track_mm is not None and tid >= 0 and tid in lengths_by_track_mm: - ln_mm = lengths_by_track_mm[tid] - if wg is not None and np.isfinite(wg): - extra = "" - if ln_mm is not None and np.isfinite(ln_mm): - extra = f" len:{ln_mm:.0f}mm" - label = f"ID:{tid} {cname} weight: {wg:.0f} g{extra}" - else: - extra = "" - if ln_mm is not None and np.isfinite(ln_mm): - extra = f" len:{ln_mm:.0f}mm" - label = f"ID:{tid} {cname} weight: -- g{extra}" - (text_w, text_h), baseline = cv2.getTextSize(label, 
cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1) - cv2.rectangle(image, (x1, y1 - text_h - baseline - 5), (x1 + text_w, y1), (0, 255, 0), -1) - cv2.putText( - image, label, (x1, y1 - baseline - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 1, cv2.LINE_AA - ) - return image - - -def draw_fish_boxes_with_weight( - image: np.ndarray, - results, - class_names: Dict[int, str], - weights_by_track: Optional[Dict[int, float]] = None, - lengths_by_track_mm: Optional[Dict[int, float]] = None, -) -> np.ndarray: - """Draw YOLO boxes with fish weight (g) per track; no confidence, no depth.""" - if results is None or results.boxes is None: - return image - boxes = results.boxes.xyxy.cpu().numpy() - class_ids = ( - results.boxes.cls.cpu().numpy().astype(int) - if results.boxes.cls is not None - else np.zeros(len(boxes), dtype=int) - ) - tid = None - if hasattr(results.boxes, "id") and results.boxes.id is not None: - tid = results.boxes.id.cpu().numpy().astype(int) - return draw_fish_boxes_from_arrays( - image, boxes, class_ids, tid, class_names, weights_by_track, lengths_by_track_mm - ) - - -def draw_overlay_header(image: np.ndarray, lines: List[str], start_y: int = 22) -> None: - y = start_y - for line in lines: - if not line: - continue - cv2.putText(image, line, (10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 255, 255), 2, cv2.LINE_AA) - y += 22 - - -def _parse_tid_from_ply_name(name: str) -> Optional[int]: - m = re.search(r"_tid(\d+)", name) - if not m: - return None - return int(m.group(1)) - - -def _parse_frame_num_from_ply_name(name: str) -> Optional[int]: - m = re.search(r"frame_(\d+)", name) - if not m: - return None - return int(m.group(1)) - - -def build_track_weights_minute_top5( - per_file: List[Dict[str, Any]], - *, - fps: float, - minute_interval_sec: float, -) -> Tuple[Dict[int, float], Dict[int, float], Dict[int, float], List[float]]: - """track_id -> max weight g; track_id -> length_input (mm) at that max-weight PLY; minute_bucket -> mean g; top-5 weights (desc).""" - tid_max: 
Dict[int, float] = {} - tid_len_mm: Dict[int, float] = {} - bucket_vals: Dict[int, List[float]] = {} - for it in per_file: - ply = str(it.get("ply", "")) - g = float(it.get("predicted_weight_g", float("nan"))) - if not np.isfinite(g): - continue - tid = _parse_tid_from_ply_name(Path(ply).name) - if tid is not None: - prev = tid_max.get(tid) - if prev is None or g > prev: - tid_max[tid] = g - ln = float(it.get("length_input", float("nan"))) - if np.isfinite(ln): - tid_len_mm[tid] = ln - fn = _parse_frame_num_from_ply_name(Path(ply).name) - if fn is None or fps <= 0: - continue - t_sec = (fn - 1) / fps - b = int(t_sec // minute_interval_sec) - bucket_vals.setdefault(b, []).append(g) - minute_avg: Dict[int, float] = {} - for b, vals in bucket_vals.items(): - if vals: - minute_avg[b] = float(np.mean(vals)) - top5 = sorted(tid_max.values(), reverse=True)[:5] - return tid_max, tid_len_mm, minute_avg, top5 - - -def _build_per_frame_weight_lookup( - per_file: List[Dict[str, Any]], -) -> Dict[Tuple[int, int], Tuple[float, float]]: - """Build (frame_idx, track_id) -> (weight_g, length_mm) from DGCNN per_file results. - - frame_idx is 0-based (matches the idx used in process_single_svo2). 
- """ - lookup: Dict[Tuple[int, int], Tuple[float, float]] = {} - for it in per_file: - ply = str(it.get("ply", "")) - name = Path(ply).name - tid = _parse_tid_from_ply_name(name) - fn = _parse_frame_num_from_ply_name(name) - if tid is None or fn is None: - continue - frame_idx = fn - 1 - wg = float(it.get("predicted_weight_g", float("nan"))) - ln = float(it.get("length_input", float("nan"))) - if np.isfinite(wg): - lookup[(frame_idx, tid)] = (wg, ln if np.isfinite(ln) else float("nan")) - return lookup - - -def finalize_preview_video_with_weights( - video_buffer: List[Dict[str, Any]], - *, - fps_video: float, - per_frame_lookup: Dict[Tuple[int, int], Tuple[float, float]], - class_names: Dict[int, str], - svo_name: str, - output_images_folder: Path, -) -> None: - """Redraw buffered frames with per-frame, per-fish DGCNN weight/length labels. - - Only (frame_idx, track_id) keys present in per_frame_lookup get numeric labels; - missing keys show weight as ``--`` (see draw_fish_boxes_from_arrays). 
- """ - if not video_buffer: - return - out_frames: List[np.ndarray] = [] - for entry in video_buffer: - left_raw = entry["left_raw"] - right = entry["right"] - frame_idx = int(entry["frame_idx"]) - frame_name = entry.get("frame_name", f"frame_{frame_idx + 1:06d}") - num_dets = int(entry.get("num_dets", 0)) - boxes = np.asarray(entry["boxes"], dtype=np.float32) - cls_ids = np.asarray(entry["class_ids"], dtype=np.int64) - _tid = entry.get("track_ids") - if _tid is None: - tids = np.zeros(len(boxes), dtype=np.int64) - else: - tids = np.asarray(_tid, dtype=np.int64) - - frame_wdict: Dict[int, float] = {} - frame_ldict: Dict[int, float] = {} - for t in tids: - t = int(t) - key = (frame_idx, t) - if key in per_frame_lookup: - wg, ln = per_frame_lookup[key] - frame_wdict[t] = wg - if np.isfinite(ln): - frame_ldict[t] = ln - - left_disp = draw_fish_boxes_from_arrays( - left_raw.copy(), - boxes, - cls_ids, - tids, - class_names, - weights_by_track=frame_wdict, - lengths_by_track_mm=frame_ldict, - ) - info = f"[{frame_idx + 1}] {frame_name} | Detections: {num_dets}" - cv2.putText(left_disp, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA) - - cv2.putText(left_disp, "Detection", (10, left_disp.shape[0] - 20), - cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA) - combined = np.hstack([left_disp, right]) - out_frames.append(combined) - video_path = output_images_folder / f"{svo_name}_preview.mp4" - h, w = out_frames[0].shape[:2] - fourcc = get_h264_fourcc() - vw = cv2.VideoWriter(str(video_path), fourcc, float(fps_video), (w, h)) - for fr in out_frames: - vw.write(fr) - vw.release() - - def segment_with_sam(sam_predictor, image_bgr, boxes_xyxy, device): """Segment fish using SAM with YOLO detection boxes. 
Returns list of individual masks for each detection.""" @@ -1299,12 +1134,17 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de weight_top_k: int = 5, weight_top_by_length: bool = True, weight_length_switch_mm: float = 319.0, weight_remove_outliers=True, weight_outlier_method="iqr", + weight_max_length_mm: float = 400.0, + weight_min_length_width_ratio: float = 1.5, + weight_length_quality_cv_threshold_pct: float = 15.0, + weight_length_quality_max_span_mm: float = 130.0, + weight_average_all_after_filter: bool = False, + weight_average_all_fallback_max_if_mean_over_g: float = 400.0, + weight_mean_pool_fallback_max_if_over_g: float = 440.0, save_raw_pointclouds=False, correct_tail_rotation=False, tail_rotation_distance_threshold=5.0, tail_rotation_min_tail_ratio=0.7, tail_rotation_min_angle=5.0, - max_cv_length=None, output_dir_stem=None, - weight_overlay_video: bool = False, - minute_interval_sec: float = 60.0): + max_cv_length=None, output_dir_stem=None): """Process a single SVO2 file with pre-loaded YOLO and SAM models. 
Args: @@ -1337,10 +1177,6 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de svo_name = svo_path.stem class_names = yolo_model.names if hasattr(yolo_model, 'names') else {} - try: - class_names = {int(k): v for k, v in class_names.items()} - except (TypeError, ValueError): - class_names = {} out_stem = output_dir_stem if output_dir_stem else svo_name # Setup output folders @@ -1353,6 +1189,51 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de if save_raw_pointclouds and output_raw_pc_folder: output_raw_pc_folder.mkdir(parents=True, exist_ok=True) + print(f"Reading from SVO2 file: {svo_path.name}") + print(f"Output folder: {output_base.resolve()}") + + # Check if output folder already exists and contains point clouds + # If so, skip data generation and directly run weight estimation + if output_base.exists() and output_cloud_folder.exists(): + # Check if there are point cloud files + point_cloud_files = list(output_cloud_folder.glob("*.ply")) + if point_cloud_files and do_weight_estimation: + print(f"\n{'='*60}") + print(f"Output folder already exists with {len(point_cloud_files)} point cloud files") + print(f"Skipping data generation, directly running weight estimation...") + print(f"{'='*60}") + + # Run weight estimation directly + weight_output_dir = output_base / "weight_estimation" + results = run_weight_estimation( + cloud_folder=output_cloud_folder, + output_dir=weight_output_dir, + topk_length=weight_topk_length, + remove_outliers=weight_remove_outliers, + outlier_method=weight_outlier_method, + max_cv_length=max_cv_length, + verbose=True, + top_k=weight_top_k, + top_by_length=weight_top_by_length, + length_switch_to_weight_mm=weight_length_switch_mm, + max_length_mm=weight_max_length_mm, + min_length_width_ratio=weight_min_length_width_ratio, + length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct, + length_quality_max_span_mm=weight_length_quality_max_span_mm, + 
average_all_after_filter=weight_average_all_after_filter, + average_all_fallback_to_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g, + mean_pool_fallback_to_max_if_over_g=weight_mean_pool_fallback_max_if_over_g, + ) + return results is not None + elif point_cloud_files and not do_weight_estimation: + print(f"\n{'='*60}") + print(f"Output folder already exists with {len(point_cloud_files)} point cloud files") + print(f"Weight estimation not requested (--run-weight-estimation not set)") + print(f"Skipping processing...") + print(f"{'='*60}") + return True + # If folder exists but no point clouds, continue with normal processing + # Initialize ZED reader zed_reader = ZEDReader(svo_path=str(svo_path), camera_mode=False, use_yolo_detector=False) if not zed_reader.open(): @@ -1380,11 +1261,8 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de next_track_id = 0 STATIONARY_THRESHOLD = 10 MOVEMENT_THRESHOLD = 5.0 - - # 只要跑过 fish 内 DGCNN,就延迟写预览,以便叠加真实 weight/length(不再依赖 --weight-overlay-video) - defer_video = bool(do_weight_estimation and not save_images) - video_frames: List[np.ndarray] = [] - video_defer_buffer: List[Dict[str, Any]] = [] + + video_frames = [] idx = 0 # List to track point clouds that passed PointNet++ classifier (if enabled) @@ -1420,10 +1298,10 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de if frame_stride > 1 and (idx % frame_stride) != 0: idx += 1 continue - - defer_left_payload: Optional[Dict[str, Any]] = None frame_name = f"frame_{idx+1:06d}" + if idx % 30 == 0: + print(f"[{idx + 1}] {frame_name}") # Run YOLO tracking (fish) results = yolo_model.track(img, conf=conf, imgsz=imgsz, verbose=False, persist=True)[0] @@ -1484,18 +1362,6 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de depth_stats_list = [] else: active_boxes_array = np.array(active_boxes) - if defer_video and len(active_boxes) > 0: - cls_all_np = ( - 
results.boxes.cls.cpu().numpy().astype(int) - if results.boxes.cls is not None - else np.zeros(len(current_boxes), dtype=int) - ) - ac = np.array(active_detections, dtype=int) - defer_left_payload = { - "boxes": active_boxes_array.astype(np.float32), - "class_ids": cls_all_np[ac], - "track_ids": np.array(active_track_ids, dtype=np.int64), - } all_masks = segment_with_sam(sam_predictor, img, active_boxes_array, sam_device) individual_masks = all_masks if all_masks else [] @@ -1546,25 +1412,8 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de previous_boxes = None depth_stats_list = [] - # Left panel: 延迟模式用与 PLY 文件名一致的 active track_id,否则 DGCNN 字典对不上会全是 "--" - if defer_video and defer_left_payload is not None: - video_defer_buffer.append( - { - "left_raw": img.copy(), - "right": right_display.copy(), - "frame_idx": idx, - "frame_name": frame_name, - "num_dets": num_dets, - **defer_left_payload, - } - ) - left_display = img.copy() - elif not defer_video: - left_display = draw_fish_boxes_with_weight( - img.copy(), results, class_names, weights_by_track=None - ) - else: - left_display = img.copy() + # Draw detections with depth info + left_display = draw_detections(img.copy(), results, class_names, depth_stats_list) info = f"[{idx + 1}] {frame_name} | Detections: {num_dets}" cv2.putText(left_display, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA) cv2.putText(left_display, "Detection", (10, left_display.shape[0] - 20), @@ -1581,8 +1430,8 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de cv2.imwrite(str(image_path), combined_display) if idx % 30 == 0: print(f" Saved image: {image_path.name}") - elif not defer_video: - # Collect for video (immediate composite; no deferred weight overlay) + else: + # Collect for video video_frames.append(combined_display.copy()) # Save point clouds @@ -1687,7 +1536,8 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, 
sam_de print(f" Point cloud {fish_idx + 1}: PREDICTED={prediction.upper()} (class_id={class_id}, confidence={confidence:.3f}) - SKIPPED (BAD quality)") continue else: - pass + # Point cloud is good - proceed to save + print(f" Point cloud {fish_idx + 1}: PREDICTED={prediction.upper()} (class_id={class_id}, confidence={confidence:.3f}) - SAVING (GOOD quality, confidence >= {pointcloud_classifier_threshold:.3f})") else: # Classifier requested but not available print(f" Point cloud {fish_idx + 1}: WARNING - Classifier requested but not loaded, saving without classification") @@ -1697,7 +1547,7 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de if use_pointcloud_classifier: print(f" Point cloud {fish_idx + 1}: ERROR - Classifier flag was set but classifier is None. Check startup logs for loading errors.") else: - pass + print(f" Point cloud {fish_idx + 1}: Classifier not enabled") # Evaluate flatness if enabled if use_flatness_filter: @@ -1713,7 +1563,7 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de print(f" Point cloud {fish_idx + 1}: Flatness score {flatness_score:.2f}% < threshold {flatness_threshold:.2f}% - SKIPPED (not flat enough)") continue else: - pass + print(f" Point cloud {fish_idx + 1}: Flatness score {flatness_score:.2f}% >= threshold {flatness_threshold:.2f}% - PASSED") except Exception as e: print(f" Point cloud {fish_idx + 1}: WARNING - Flatness evaluation failed: {e}") # If flatness check is required and fails, skip saving @@ -1738,14 +1588,10 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de print(f" Point cloud {fish_idx + 1}: WARNING - Tail rotation correction failed: {e}") # Continue with original points if correction fails - # Save point cloud (passed all checks); include track id for DGCNN→video weight mapping + # Save point cloud (passed all checks) filtered_count = len(points) postfix = f"_{fish_idx + 1}" if len(individual_masks) > 1 else "" 
- track_id = int(active_track_ids[fish_idx]) - ply_path = ( - output_cloud_folder - / f"cloud_{idx+1:04d}_{frame_name}_tid{track_id}{postfix}.ply" - ) + ply_path = output_cloud_folder / f"cloud_{idx+1:04d}_{frame_name}{postfix}.ply" write_ply_file(ply_path, points, colors) # Track point clouds that passed PointNet++ classifier (if enabled) @@ -1758,56 +1604,26 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de # Classifier not enabled, track all saved point clouds kept_pointclouds.append(str(ply_path)) - pass + if filter_pointcloud: + print(f" Saved point cloud {fish_idx + 1}: {ply_path.name} ({original_count} -> {filtered_count} points)") + else: + print(f" Saved point cloud {fish_idx + 1}: {ply_path.name} ({filtered_count} points)") idx += 1 - # DGCNN on saved clouds first (needed for deferred weight overlay on video) - if do_weight_estimation and kept_pointclouds: - weight_output_dir = output_base / "weight_estimation" - weight_output_dir.mkdir(parents=True, exist_ok=True) - wres = run_weight_estimation( - cloud_folder=output_cloud_folder, - output_dir=weight_output_dir, - topk_length=weight_topk_length, - remove_outliers=weight_remove_outliers, - outlier_method=weight_outlier_method, - max_cv_length=max_cv_length, - verbose=False, - top_k=weight_top_k, - top_by_length=weight_top_by_length, - length_switch_to_weight_mm=weight_length_switch_mm, - ) - - # Preview video after weights (so labels show mass in g, not depth mm) - if not save_images: - if defer_video and video_defer_buffer: - per_frame_lookup: Dict[Tuple[int, int], Tuple[float, float]] = {} - wjson = output_base / "weight_estimation" / "weight_estimation_results.json" - if wjson.is_file(): - try: - wr = json.loads(wjson.read_text(encoding="utf-8")) - per_file_list = wr.get("per_file") or [] - per_frame_lookup = _build_per_frame_weight_lookup(per_file_list) - except Exception: - per_frame_lookup = {} - finalize_preview_video_with_weights( - video_defer_buffer, - 
fps_video=10.0, - per_frame_lookup=per_frame_lookup, - class_names=class_names, - svo_name=svo_name, - output_images_folder=output_images_folder, - ) - elif video_frames: - video_path = output_images_folder / f"{svo_name}_preview.mp4" - h, w = video_frames[0].shape[:2] - fps_v = 10.0 - fourcc = get_h264_fourcc() - video_writer = cv2.VideoWriter(str(video_path), fourcc, fps_v, (w, h)) - for frame in video_frames: - video_writer.write(frame) - video_writer.release() + # Create video (only if not saving individual images) + if not save_images and video_frames: + video_path = output_images_folder / f"{svo_name}_preview.mp4" + h, w = video_frames[0].shape[:2] + fps = 10.0 + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + video_writer = cv2.VideoWriter(str(video_path), fourcc, fps, (w, h)) + for frame in video_frames: + video_writer.write(frame) + video_writer.release() + print(f"✓ Saved video: {video_path.name} ({len(video_frames)} frames)") + elif save_images: + print(f"✓ Saved {idx} frames as individual images") # Save tracking stats if fish_tracks: @@ -1832,6 +1648,7 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de } with open(stats_path, 'w', encoding='utf-8') as f: json.dump(tracking_data, f, indent=2) + print(f"✓ Saved tracking stats: {stats_path.name}") # Save list of point clouds kept for template matching if kept_pointclouds: @@ -1839,7 +1656,52 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de with open(pointcloud_list_path, 'w', encoding='utf-8') as f: for ply_path in kept_pointclouds: f.write(f"{ply_path}\n") + if use_pointcloud_classifier: + print(f"✓ Saved list of {len(kept_pointclouds)} point clouds that passed PointNet++ classifier: {pointcloud_list_path.name}") + else: + print(f"✓ Saved list of {len(kept_pointclouds)} point clouds for template matching: {pointcloud_list_path.name}") + else: + if use_pointcloud_classifier: + print(f" Note: PointNet++ classifier was enabled but no point 
clouds passed the filter") + else: + print(f" Note: No point clouds were saved") + # Run weight estimation if requested + if do_weight_estimation and kept_pointclouds: + # Create output directory for weight estimation + weight_output_dir = output_base / "weight_estimation" + weight_output_dir.mkdir(parents=True, exist_ok=True) + + # Run weight estimation + results = run_weight_estimation( + cloud_folder=output_cloud_folder, + output_dir=weight_output_dir, + topk_length=weight_topk_length, + remove_outliers=weight_remove_outliers, + outlier_method=weight_outlier_method, + max_cv_length=max_cv_length, + verbose=True, + top_k=weight_top_k, + top_by_length=weight_top_by_length, + length_switch_to_weight_mm=weight_length_switch_mm, + max_length_mm=weight_max_length_mm, + min_length_width_ratio=weight_min_length_width_ratio, + length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct, + length_quality_max_span_mm=weight_length_quality_max_span_mm, + average_all_after_filter=weight_average_all_after_filter, + average_all_fallback_to_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g, + mean_pool_fallback_to_max_if_over_g=weight_mean_pool_fallback_max_if_over_g, + ) + + if results is None: + print(f" WARNING: Weight estimation failed") + elif do_weight_estimation and not kept_pointclouds: + if use_pointcloud_classifier: + print(f" Warning: Weight estimation requested but no point clouds passed PointNet++ classifier") + else: + print(f" Warning: Weight estimation requested but no point clouds were saved") + + print(f"✓ Processed {idx} frames from {svo_path.name}") return True except Exception as e: @@ -1862,12 +1724,17 @@ def process_batch_svo2_folder(svo_folder, output_base, yolo_model, sam_predictor weight_top_k: int = 5, weight_top_by_length: bool = True, weight_length_switch_mm: float = 319.0, weight_remove_outliers=True, weight_outlier_method="iqr", + weight_max_length_mm: float = 400.0, + weight_min_length_width_ratio: float = 1.5, + 
weight_length_quality_cv_threshold_pct: float = 15.0, + weight_length_quality_max_span_mm: float = 130.0, + weight_average_all_after_filter: bool = False, + weight_average_all_fallback_max_if_mean_over_g: float = 400.0, + weight_mean_pool_fallback_max_if_over_g: float = 440.0, save_raw_pointclouds=False, correct_tail_rotation=False, tail_rotation_distance_threshold=5.0, tail_rotation_min_tail_ratio=0.7, tail_rotation_min_angle=5.0, - max_cv_length=None, batch_svo_recursive=False, - weight_overlay_video: bool = False, - minute_interval_sec: float = 60.0): + max_cv_length=None, batch_svo_recursive=False): """Process all SVO2 files in a folder with pre-loaded YOLO and SAM models. Args: @@ -1973,6 +1840,13 @@ def process_batch_svo2_folder(svo_folder, output_base, yolo_model, sam_predictor weight_length_switch_mm=weight_length_switch_mm, weight_remove_outliers=weight_remove_outliers, weight_outlier_method=weight_outlier_method, + weight_max_length_mm=weight_max_length_mm, + weight_min_length_width_ratio=weight_min_length_width_ratio, + weight_length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct, + weight_length_quality_max_span_mm=weight_length_quality_max_span_mm, + weight_average_all_after_filter=weight_average_all_after_filter, + weight_average_all_fallback_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g, + weight_mean_pool_fallback_max_if_over_g=weight_mean_pool_fallback_max_if_over_g, save_raw_pointclouds=save_raw_pointclouds, correct_tail_rotation=correct_tail_rotation, tail_rotation_distance_threshold=tail_rotation_distance_threshold, @@ -1980,8 +1854,6 @@ def process_batch_svo2_folder(svo_folder, output_base, yolo_model, sam_predictor tail_rotation_min_angle=tail_rotation_min_angle, max_cv_length=max_cv_length, output_dir_stem=output_dir_stem, - weight_overlay_video=weight_overlay_video, - minute_interval_sec=minute_interval_sec, ) if success: @@ -2109,16 +1981,49 @@ def main(): parser.add_argument("--weight-outlier-method", 
type=str, default="iqr", choices=["iqr", "zscore"], help="Outlier detection method for weight estimation (default: iqr)") parser.add_argument( - "--weight-overlay-video", - action="store_true", - help="With --run-weight-estimation: defer preview mp4 until DGCNN runs, then burn in fish weight (g) per " - "track, top-5 masses, and per-window average (no YOLO confidence, no depth median on the preview).", + "--weight-max-length-mm", + type=float, + default=400.0, + help="Exclude PLYs with estimated length > this (mm) from final weight aggregation; " + "marked filtered_by_max_length in JSON. Use 0 to disable (default: 400).", ) parser.add_argument( - "--minute-interval-sec", + "--weight-min-length-width-ratio", type=float, - default=60.0, - help="Length of each time bucket (seconds) for the on-video per-window average line (default: 60).", + default=1.5, + help="Exclude PLYs with PCA length/width below this (non-fish blobs). 0 disables (default: 1.5).", + ) + parser.add_argument( + "--weight-length-quality-cv-threshold-pct", + type=float, + default=15.0, + help="Length variance hint: mark BAD if CV%% exceeds this (default: 15).", + ) + parser.add_argument( + "--weight-length-quality-max-span-mm", + type=float, + default=130.0, + help="Length variance hint: mark BAD if max(length)−min(length) exceeds this (mm). 0 disables (default: 130).", + ) + parser.add_argument( + "--weight-average-all-after-filter", + action=argparse.BooleanOptionalAction, + default=False, + help="Mean predicted weight over all PLYs after geometry filters and outlier removal (no top-K).", + ) + parser.add_argument( + "--weight-average-all-fallback-max-if-mean-over-g", + type=float, + default=400.0, + help="With average-all: if mean predicted weight (g) exceeds this, use max pred after filters instead; " + "0 disables (default: 400). 
Matches test_dgcnn --average-all-fallback-max-if-mean-over-g.", + ) + parser.add_argument( + "--weight-mean-pool-fallback-max-if-over-g", + type=float, + default=440.0, + help="If mean predicted weight over all filtered PLYs (g) exceeds this, use max pred after filters; " + "0 disables (default: 440). Matches test_dgcnn --mean-pool-fallback-max-if-over-g.", ) parser.add_argument("--save-raw-pointclouds", action="store_true", help="Save point clouds to raw_pc folder before passing to PointNet++ classifier. Useful for debugging why some point clouds fail classification.") @@ -2230,6 +2135,13 @@ def main(): weight_length_switch_mm=args.weight_length_switch_mm, weight_remove_outliers=args.weight_remove_outliers, weight_outlier_method=args.weight_outlier_method, + weight_max_length_mm=args.weight_max_length_mm, + weight_min_length_width_ratio=args.weight_min_length_width_ratio, + weight_length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct, + weight_length_quality_max_span_mm=args.weight_length_quality_max_span_mm, + weight_average_all_after_filter=args.weight_average_all_after_filter, + weight_average_all_fallback_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g, + weight_mean_pool_fallback_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g, save_raw_pointclouds=args.save_raw_pointclouds, correct_tail_rotation=args.correct_tail_rotation, tail_rotation_distance_threshold=args.tail_rotation_distance_threshold, @@ -2237,8 +2149,6 @@ def main(): tail_rotation_min_angle=args.tail_rotation_min_angle, max_cv_length=args.max_cv_length, batch_svo_recursive=args.batch_svo_recursive, - weight_overlay_video=args.weight_overlay_video, - minute_interval_sec=args.minute_interval_sec, ) return @@ -2354,44 +2264,54 @@ def main(): if args.save_raw_pointclouds: print(f"✓ Output raw point cloud folder (before classifier): {output_raw_pc_folder.resolve()}") - if use_svo: - process_single_svo2( - str(svo_path), - 
str(Path(args.save_output).expanduser().resolve()), - yolo_model, - sam_predictor, - sam_device, - conf=args.conf, - imgsz=args.imgsz, - max_frames=args.max_frames, - frame_stride=args.frame_stride, - save_images=args.save_images, - filter_pointcloud=args.filter_pointcloud, - use_clustering_filter=args.use_clustering_filter, - use_density_filter=args.use_density_filter, - pointcloud_classifier=pointcloud_classifier, - use_pointcloud_classifier=args.use_pointcloud_classifier, - pointcloud_classifier_threshold=args.pointcloud_classifier_threshold, - flatness_threshold=args.flatness_threshold, - use_flatness_filter=args.use_flatness_filter, - do_weight_estimation=args.run_weight_estimation, - weight_topk_length=args.weight_topk_length, - weight_top_k=args.weight_top_k, - weight_top_by_length=args.weight_top_by_length, - weight_length_switch_mm=args.weight_length_switch_mm, - weight_remove_outliers=args.weight_remove_outliers, - weight_outlier_method=args.weight_outlier_method, - save_raw_pointclouds=args.save_raw_pointclouds, - correct_tail_rotation=args.correct_tail_rotation, - tail_rotation_distance_threshold=args.tail_rotation_distance_threshold, - tail_rotation_min_tail_ratio=args.tail_rotation_min_tail_ratio, - tail_rotation_min_angle=args.tail_rotation_min_angle, - max_cv_length=args.max_cv_length, - output_dir_stem=None, - weight_overlay_video=args.weight_overlay_video, - minute_interval_sec=args.minute_interval_sec, - ) - return + # Check if output folder already exists and contains point clouds + # If so, skip data generation and directly run weight estimation + if output_base.exists() and output_cloud_folder.exists(): + # Check if there are point cloud files + point_cloud_files = list(output_cloud_folder.glob("*.ply")) + if point_cloud_files and args.run_weight_estimation: + print(f"\n{'='*60}") + print(f"Output folder already exists with {len(point_cloud_files)} point cloud files") + print(f"Skipping data generation, directly running weight estimation...") 
+ print(f"{'='*60}") + + # Load weight estimator if not already loaded + if _weight_estimator_model is None: + ckpt_path = Path(args.weight_estimator_checkpoint).expanduser().resolve() + if not load_weight_estimator(str(ckpt_path), device=args.sam_device): + print("ERROR: Failed to load weight estimator.") + return + + # Run weight estimation directly + weight_output_dir = output_base / "weight_estimation" + results = run_weight_estimation( + cloud_folder=output_cloud_folder, + output_dir=weight_output_dir, + topk_length=args.weight_topk_length, + remove_outliers=args.weight_remove_outliers, + outlier_method=args.weight_outlier_method, + max_cv_length=args.max_cv_length, + verbose=True, + top_k=args.weight_top_k, + top_by_length=args.weight_top_by_length, + length_switch_to_weight_mm=args.weight_length_switch_mm, + max_length_mm=args.weight_max_length_mm, + min_length_width_ratio=args.weight_min_length_width_ratio, + length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct, + length_quality_max_span_mm=args.weight_length_quality_max_span_mm, + average_all_after_filter=args.weight_average_all_after_filter, + average_all_fallback_to_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g, + mean_pool_fallback_to_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g, + ) + return + elif point_cloud_files and not args.run_weight_estimation: + print(f"\n{'='*60}") + print(f"Output folder already exists with {len(point_cloud_files)} point cloud files") + print(f"Weight estimation not requested (--run-weight-estimation not set)") + print(f"Skipping processing...") + print(f"{'='*60}") + return + # If folder exists but no point clouds, continue with normal processing else: window_name = "Fish Detection & Segmentation Preview" cv2.namedWindow(window_name, cv2.WINDOW_NORMAL) @@ -2944,7 +2864,7 @@ def main(): fps = 10.0 # Frames per second # Create video writer - fourcc = get_h264_fourcc() + fourcc = cv2.VideoWriter_fourcc(*'mp4v') 
video_writer = cv2.VideoWriter(str(video_path), fourcc, fps, (w, h)) for frame in video_frames: @@ -3016,6 +2936,13 @@ def main(): top_k=args.weight_top_k, top_by_length=args.weight_top_by_length, length_switch_to_weight_mm=args.weight_length_switch_mm, + max_length_mm=args.weight_max_length_mm, + min_length_width_ratio=args.weight_min_length_width_ratio, + length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct, + length_quality_max_span_mm=args.weight_length_quality_max_span_mm, + average_all_after_filter=args.weight_average_all_after_filter, + average_all_fallback_to_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g, + mean_pool_fallback_to_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g, ) if results is None: diff --git a/FishMeasure/predict_weigth_from_svo2.py b/FishMeasure/predict_weigth_from_svo2.py old mode 100755 new mode 100644 index 4cbf8c5..1f53aa1 --- a/FishMeasure/predict_weigth_from_svo2.py +++ b/FishMeasure/predict_weigth_from_svo2.py @@ -102,26 +102,6 @@ def _run_fish_video_evaluation_subprocess(args: argparse.Namespace, *, batch_fol cmd.append("--use-flatness-filter") cmd.extend(["--flatness-threshold", str(args.flatness_threshold)]) - # 始终在 fish 内跑 DGCNN,生成 weight_estimation_results.json,预览视频才能叠加 weight/length; - # predict 后续会合并该 JSON,避免重复跑 test_dgcnn。 - wck = Path(args.weight_checkpoint).expanduser().resolve() - if wck.is_file(): - cmd.extend( - [ - "--run-weight-estimation", - "--weight-estimator-checkpoint", - str(wck), - ] - ) - if getattr(args, "fish_video_weight_overlay", False): - cmd.extend( - [ - "--weight-overlay-video", - "--minute-interval-sec", - str(getattr(args, "minute_interval_sec", 60.0)), - ] - ) - print(f"Invoking fish_video_weight_evaluation.py:\n {' '.join(cmd)}") proc = subprocess.run(cmd, cwd=str(REPO_ROOT)) if proc.returncode != 0: @@ -144,6 +124,13 @@ def _run_test_dgcnn_weight_estimator_subprocess( remove_outliers: bool, outlier_method: str, labels_json: Optional[str], 
+ weight_max_length_mm: float = 400.0, + weight_min_length_width_ratio: float = 1.5, + weight_length_quality_cv_threshold_pct: float = 15.0, + weight_length_quality_max_span_mm: float = 130.0, + weight_average_all_after_filter: bool = False, + weight_average_all_fallback_max_if_mean_over_g: float = 400.0, + weight_mean_pool_fallback_max_if_over_g: float = 440.0, ) -> Tuple[Path, Dict[str, Any]]: if (cloud_dir is None) == (ply_list_file is None): raise ValueError("Exactly one of cloud_dir or ply_list_file must be set") @@ -187,6 +174,22 @@ def _run_test_dgcnn_weight_estimator_subprocess( if remove_outliers: cmd.append("--remove-outliers") cmd.extend(["--outlier-method", outlier_method]) + cmd.extend(["--max-length-mm", str(weight_max_length_mm)]) + cmd.extend(["--min-length-width-ratio", str(weight_min_length_width_ratio)]) + cmd.extend(["--length-quality-cv-threshold-pct", str(weight_length_quality_cv_threshold_pct)]) + cmd.extend(["--length-quality-max-span-mm", str(weight_length_quality_max_span_mm)]) + if weight_average_all_after_filter: + cmd.append("--average-all-after-filter") + fb_thr = weight_average_all_fallback_max_if_mean_over_g + if fb_thr is not None and float(fb_thr) > 0: + cmd.extend( + ["--average-all-fallback-max-if-mean-over-g", str(float(fb_thr))] + ) + mp_fb = weight_mean_pool_fallback_max_if_over_g + if mp_fb is not None and float(mp_fb) > 0: + cmd.extend( + ["--mean-pool-fallback-max-if-over-g", str(float(mp_fb))] + ) print(f"Invoking test_dgcnn_weight_estimator.py:\n {' '.join(cmd)}") proc = subprocess.run(cmd, cwd=str(REPO_ROOT)) @@ -215,6 +218,9 @@ def _merge_weight_prediction_json( skipped = bool(summary.get("skipped")) avg_kg = summary.get("avg_predicted_weight_kg") avg_g = summary.get("avg_predicted_weight_g") + pred_kg = summary.get("pred_weight_kg") + pred_g = summary.get("pred_weight_g") + pred_rule = summary.get("pred_weight_rule") return { "svo": str(svo_path), "svo_name": svo_name, @@ -228,6 +234,9 @@ def 
_merge_weight_prediction_json( "num_ply_predicted": summary.get("num_files_predicted"), "avg_predicted_weight_kg": None if skipped else avg_kg, "avg_predicted_weight_g": None if skipped else avg_g, + "pred_weight_g": None if skipped else pred_g, + "pred_weight_kg": None if skipped else pred_kg, + "pred_weight_rule": None if skipped else pred_rule, "dgcnn_meta": dgcnn_data.get("meta"), "dgcnn_summary": summary, "weight_summary": summary, @@ -248,6 +257,9 @@ def _merge_weight_prediction_json_combined( skipped = bool(summary.get("skipped")) avg_kg = summary.get("avg_predicted_weight_kg") avg_g = summary.get("avg_predicted_weight_g") + pred_kg = summary.get("pred_weight_kg") + pred_g = summary.get("pred_weight_g") + pred_rule = summary.get("pred_weight_rule") cloud_dirs = [output_base / p.stem / "cloud" for p in svo_paths] n_ply = sum(len(_collect_existing_clouds(d)) for d in cloud_dirs) return { @@ -265,6 +277,9 @@ def _merge_weight_prediction_json_combined( "num_ply_predicted": summary.get("num_files_predicted"), "avg_predicted_weight_kg": None if skipped else avg_kg, "avg_predicted_weight_g": None if skipped else avg_g, + "pred_weight_g": None if skipped else pred_g, + "pred_weight_kg": None if skipped else pred_kg, + "pred_weight_rule": None if skipped else pred_rule, "dgcnn_meta": dgcnn_data.get("meta"), "dgcnn_summary": summary, "weight_summary": summary, @@ -297,7 +312,13 @@ def run_weight_prediction_for_svo( weight_outlier_method: str, weight_xyz_scale: float, weight_labels_json: Optional[str], - force_dgcnn_subprocess: bool = False, + weight_max_length_mm: float = 400.0, + weight_min_length_width_ratio: float = 1.5, + weight_length_quality_cv_threshold_pct: float = 15.0, + weight_length_quality_max_span_mm: float = 130.0, + weight_average_all_after_filter: bool = False, + weight_average_all_fallback_max_if_mean_over_g: float = 400.0, + weight_mean_pool_fallback_max_if_over_g: float = 440.0, ) -> Dict[str, Any]: svo_path = svo_path.expanduser().resolve() if not 
svo_path.exists(): @@ -314,28 +335,6 @@ def run_weight_prediction_for_svo( f"fish_video_weight_evaluation.py first." ) - fish_wj = out_dir / "weight_estimation" / "weight_estimation_results.json" - if not force_dgcnn_subprocess and fish_wj.is_file(): - try: - dgcnn_data = json.loads(fish_wj.read_text(encoding="utf-8")) - if dgcnn_data.get("summary") is not None or dgcnn_data.get("per_file"): - print(f"Using existing DGCNN results from fish_video: {fish_wj}") - result = _merge_weight_prediction_json( - svo_path=svo_path, - svo_name=svo_name, - out_dir=out_dir, - cloud_dir=cloud_dir, - dgcnn_json_path=fish_wj, - dgcnn_data=dgcnn_data, - ) - (out_dir / "weight_prediction.json").write_text( - json.dumps(_sanitize_for_json(result), indent=2, ensure_ascii=False), - encoding="utf-8", - ) - return result - except Exception as e: - print(f"WARNING: Could not merge {fish_wj}, falling back to test_dgcnn subprocess: {e}") - dgcnn_path, dgcnn_data = _run_test_dgcnn_weight_estimator_subprocess( cloud_dir=cloud_dir, ply_list_file=None, @@ -351,6 +350,13 @@ def run_weight_prediction_for_svo( remove_outliers=weight_remove_outliers, outlier_method=weight_outlier_method, labels_json=weight_labels_json, + weight_max_length_mm=weight_max_length_mm, + weight_min_length_width_ratio=weight_min_length_width_ratio, + weight_length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct, + weight_length_quality_max_span_mm=weight_length_quality_max_span_mm, + weight_average_all_after_filter=weight_average_all_after_filter, + weight_average_all_fallback_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g, + weight_mean_pool_fallback_max_if_over_g=weight_mean_pool_fallback_max_if_over_g, ) result = _merge_weight_prediction_json( svo_path=svo_path, @@ -381,6 +387,13 @@ def run_weight_prediction_combined_svos( weight_outlier_method: str, weight_xyz_scale: float, weight_labels_json: Optional[str], + weight_max_length_mm: float = 400.0, + weight_min_length_width_ratio: float 
= 1.5, + weight_length_quality_cv_threshold_pct: float = 15.0, + weight_length_quality_max_span_mm: float = 130.0, + weight_average_all_after_filter: bool = False, + weight_average_all_fallback_max_if_mean_over_g: float = 400.0, + weight_mean_pool_fallback_max_if_over_g: float = 440.0, ) -> Dict[str, Any]: """One DGCNN run over all ``//cloud/*.ply`` (top-K / by-length applies to the union).""" output_base = output_base.expanduser().resolve() @@ -411,6 +424,13 @@ def run_weight_prediction_combined_svos( remove_outliers=weight_remove_outliers, outlier_method=weight_outlier_method, labels_json=weight_labels_json, + weight_max_length_mm=weight_max_length_mm, + weight_min_length_width_ratio=weight_min_length_width_ratio, + weight_length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct, + weight_length_quality_max_span_mm=weight_length_quality_max_span_mm, + weight_average_all_after_filter=weight_average_all_after_filter, + weight_average_all_fallback_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g, + weight_mean_pool_fallback_max_if_over_g=weight_mean_pool_fallback_max_if_over_g, ) result = _merge_weight_prediction_json_combined( svo_paths=svo_paths, @@ -502,6 +522,48 @@ def main() -> None: ) parser.add_argument("--weight-xyz-scale", type=float, default=0.001) parser.add_argument("--weight-labels-json", type=str, default=None) + parser.add_argument( + "--weight-max-length-mm", + type=float, + default=400.0, + help="Passed to test_dgcnn --max-length-mm: exclude length > this from aggregation (0 = off).", + ) + parser.add_argument( + "--weight-min-length-width-ratio", + type=float, + default=1.5, + help="Passed to test_dgcnn --min-length-width-ratio; 0 disables.", + ) + parser.add_argument( + "--weight-average-all-after-filter", + action=argparse.BooleanOptionalAction, + default=False, + help="Passed to test_dgcnn --average-all-after-filter: mean all PLYs after filters (no top-K).", + ) + parser.add_argument( + 
"--weight-average-all-fallback-max-if-mean-over-g", + type=float, + default=400.0, + help="Passed to test_dgcnn --average-all-fallback-max-if-mean-over-g; 0 disables (default: 400).", + ) + parser.add_argument( + "--weight-mean-pool-fallback-max-if-over-g", + type=float, + default=440.0, + help="Passed to test_dgcnn --mean-pool-fallback-max-if-over-g; 0 disables (default: 440).", + ) + parser.add_argument( + "--weight-length-quality-cv-threshold-pct", + type=float, + default=15.0, + help="Passed to test_dgcnn --length-quality-cv-threshold-pct (length variance hint).", + ) + parser.add_argument( + "--weight-length-quality-max-span-mm", + type=float, + default=130.0, + help="Passed to test_dgcnn --length-quality-max-span-mm; 0 disables span rule.", + ) parser.add_argument( "--reuse-existing-clouds", @@ -511,24 +573,6 @@ def main() -> None: ) parser.add_argument("--no-reuse-existing-clouds", action="store_false", dest="reuse_existing_clouds") - parser.add_argument( - "--fish-video-weight-overlay", - action="store_true", - help="Extra on-video header lines (top-5 / per-minute bucket). 
" - "DGCNN in fish + preview weight/length labels are already enabled when weight checkpoint exists.", - ) - parser.add_argument( - "--minute-interval-sec", - type=float, - default=60.0, - help="Passed to fish_video --minute-interval-sec for on-video bucket stats (default: 60).", - ) - parser.add_argument( - "--force-dgcnn-subprocess", - action="store_true", - help="Always run test_dgcnn_weight_estimator.py subprocess even if fish_video left weight_estimation_results.json.", - ) - args = parser.parse_args() if args.frame_stride < 1: raise SystemExit("--frame-stride must be >= 1") @@ -586,6 +630,13 @@ def main() -> None: weight_outlier_method=args.weight_outlier_method, weight_xyz_scale=args.weight_xyz_scale, weight_labels_json=args.weight_labels_json, + weight_max_length_mm=args.weight_max_length_mm, + weight_min_length_width_ratio=args.weight_min_length_width_ratio, + weight_length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct, + weight_length_quality_max_span_mm=args.weight_length_quality_max_span_mm, + weight_average_all_after_filter=args.weight_average_all_after_filter, + weight_average_all_fallback_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g, + weight_mean_pool_fallback_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g, ) ) except Exception as e: @@ -612,7 +663,13 @@ def main() -> None: weight_outlier_method=args.weight_outlier_method, weight_xyz_scale=args.weight_xyz_scale, weight_labels_json=args.weight_labels_json, - force_dgcnn_subprocess=args.force_dgcnn_subprocess, + weight_max_length_mm=args.weight_max_length_mm, + weight_min_length_width_ratio=args.weight_min_length_width_ratio, + weight_length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct, + weight_length_quality_max_span_mm=args.weight_length_quality_max_span_mm, + weight_average_all_after_filter=args.weight_average_all_after_filter, + 
weight_average_all_fallback_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g, + weight_mean_pool_fallback_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g, ) ) except Exception as e: @@ -628,8 +685,11 @@ def main() -> None: if len(results) == 1 and "error" not in results[0]: r0 = results[0] avg_g = r0.get("avg_predicted_weight_g") - if avg_g is not None: - print(f"Final predicted weight (test_dgcnn): {avg_g:.2f} g") + pred_g = r0.get("pred_weight_g") + out_g = pred_g if pred_g is not None else avg_g + if out_g is not None: + label = "pred_weight" if pred_g is not None else "avg_predicted_weight" + print(f"Final predicted weight (test_dgcnn, {label}): {float(out_g):.2f} g") else: print("Final predicted weight: N/A (skipped or no valid point clouds)") diff --git a/FishMeasure/run_predict_from_svo2_fish9.sh b/FishMeasure/run_predict_from_svo2_fish9.sh index 89ab6b2..ed11636 100755 --- a/FishMeasure/run_predict_from_svo2_fish9.sh +++ b/FishMeasure/run_predict_from_svo2_fish9.sh @@ -10,7 +10,7 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" cd "$SCRIPT_DIR" SESSION_ROOT="/home/ubuntu/data/fish/2016-1-22-last" -FISH_NAME="fish17" +FISH_NAME="fish9" fish_dir="${SESSION_ROOT}/${FISH_NAME}/" OUT_PARENT="output_weight_estimator" save_out="${OUT_PARENT}/${FISH_NAME}" @@ -37,7 +37,7 @@ python3 predict_weigth_from_svo2.py \ --weight-checkpoint weight_estimator/runs/dgcnn_20260312_171043/best.pt \ --save-output "$save_out" \ --yolo-model "/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt" \ - --conf 0.5 \ + --conf 0.8 \ --imgsz 640 \ --sam-device cuda \ --max-frames 0 \ @@ -50,4 +50,5 @@ python3 predict_weigth_from_svo2.py \ --flatness-threshold 55.0 \ --frame-stride 1 \ --weight-top-k 5 \ - --weight-top-by-length + --weight-top-by-length +# Optional: append --no-weight-top-by-length if you want top-K by predicted weight only. 
diff --git a/FishMeasure/weight_estimator/test_dgcc.sh b/FishMeasure/weight_estimator/test_dgcc.sh old mode 100755 new mode 100644 index 644c8c2..c3c87f9 --- a/FishMeasure/weight_estimator/test_dgcc.sh +++ b/FishMeasure/weight_estimator/test_dgcc.sh @@ -1,5 +1,6 @@ -python test_dgcnn_weight_estimator.py --checkpoint runs/dgcnn_20260312_171043/best.pt\ - --batch-root '/home/ubuntu/projects/FishMeasure/output_weight_estimator' --top-k=5 +python test_dgcnn_weight_estimator.py --checkpoint runs/dgcnn_20260312_171043/best.pt \ + --batch-root '/home/ubuntu/projects/FishMeasure/output_weight_estimator' --top-k=5 --top-by-length #--average-all-after-filter + diff --git a/FishMeasure/weight_estimator/test_dgcnn_weight_estimator.py b/FishMeasure/weight_estimator/test_dgcnn_weight_estimator.py old mode 100755 new mode 100644 index 448beb3..437ff07 --- a/FishMeasure/weight_estimator/test_dgcnn_weight_estimator.py +++ b/FishMeasure/weight_estimator/test_dgcnn_weight_estimator.py @@ -43,6 +43,90 @@ for _p in (str(REPO_ROOT), str(WEIGHT_EST_DIR)): sys.path.insert(0, _p) +def _mean_length_top_k_longest(candidates: List[Dict], top_k: int) -> float: + """Mean length_input over the K longest clouds (by length) among candidates; independent of weight aggregation.""" + if not candidates: + return float("nan") + sorted_by_len = sorted( + candidates, + key=lambda d: float(d.get("length_input", float("-inf"))), + reverse=True, + ) + take = min(int(top_k), len(sorted_by_len)) + sub = sorted_by_len[:take] + lengths = [ + float(it["length_input"]) + for it in sub + if np.isfinite(float(it.get("length_input", float("nan")))) + ] + return float(np.mean(lengths)) if lengths else float("nan") + + +def _final_pred_g_for_compare(summary: Dict[str, Any]) -> float: + """Adopted final prediction (fallback rules); aligns with summary pred_weight_g.""" + pg = summary.get("pred_weight_g") + if pg is not None and np.isfinite(float(pg)): + return float(pg) + return float(summary["avg_predicted_weight_g"]) 
+ + +def _mean_column_g_for_log(summary: Dict[str, Any]) -> float: + """ + First column in batch/single-folder logs: always the arithmetic mean over **all** post-filter, + post-outlier candidates (``mean_all_pred_g_after_filters``), regardless of whether the final + ``pred`` uses top-K, mean-all, or max-fallback. Falls back to ``avg_predicted_weight_g`` if missing. + """ + v = summary.get("mean_all_pred_g_after_filters") + if v is not None and np.isfinite(float(v)): + return float(v) + ag = summary.get("avg_predicted_weight_g") + if ag is not None and np.isfinite(float(ag)): + return float(ag) + return float("nan") + + +def _batch_avg_leader(summary: Dict[str, Any], top_k: int, fallback_by_length: bool) -> str: + """First column prefix: always ``avg=`` (mean of all filtered candidates); the arguments are currently ignored.""" + return "avg=" + + +def _batch_aggregate_log_suffix(summary: Dict[str, Any]) -> str: + """Batch log: PLY counts, top-K means, max_pred_g, pred.""" + parts: List[str] = [] + nu = summary.get("num_files_used_for_avg") + nt = summary.get("num_files_predicted") + if nu is not None and nt is not None: + try: + parts.append(f"n_ply={int(nu)}/{int(nt)}") + except (TypeError, ValueError): + pass + tk = int(summary.get("top_k") or 5) + topk_g = summary.get("avg_topk_mean_pred_g") + topk_sel = summary.get("avg_topk_mean_pred_selection") or "pred" + if topk_g is not None and np.isfinite(float(topk_g)): + by = "length" if topk_sel == "by_length" else "pred" + parts.append(f"top{tk}_avg={float(topk_g):.2f} g by {by}") + + wavg = summary.get("avg_topk_mean_pred_g_top_by_weight") + if wavg is not None and np.isfinite(float(wavg)): + parts.append(f"top{tk}_w_avg={float(wavg):.2f} g") + + mx = summary.get("max_predicted_weight_g_after_filter") + if mx is not None and np.isfinite(float(mx)): + parts.append(f"max_pred_g={float(mx):.2f} g") + + parts.append(f"pred={_final_pred_g_for_compare(summary):.2f} g") + return " | " + " | ".join(parts) if parts else "" + + +def 
_print_max_weight_after_filter(summary: Dict[str, Any], *, prefix: str = "") -> None: + """Print max predicted weight over candidates after geometry filters and optional outlier removal.""" + mx = summary.get("max_predicted_weight_g_after_filter") + if mx is not None and np.isfinite(float(mx)): + mk = float(mx) / 1000.0 + print(f"{prefix}Max predicted weight (after filters): {float(mx):.2f} g ({mk:.4f} kg)") + + def _format_length_input_display(value: Optional[float], xyz_scale: float) -> str: if value is None or not np.isfinite(float(value)): return "nan" @@ -53,12 +137,73 @@ def _format_length_input_display(value: Optional[float], xyz_scale: float) -> st def _summary_top_label(summary: Dict[str, Any], top_k: int, fallback_by_length: bool) -> str: + if summary.get("average_all_after_filter"): + return "mean of all after filters" eff = summary.get("effective_top_by_length") if eff is None: eff = fallback_by_length return f"top{top_k} by length" if eff else f"top{top_k} by pred" +def _topk_mean_prediction_g_for_display( + candidates_for_avg: List[Dict], + top_k: int, + top_by_length: bool, + length_switch_to_weight_mm: float, +) -> Tuple[float, str]: + """ + Mean predicted weight over the top-K cloud subset (length vs weight ranking + length switch), + always the arithmetic mean of up to K values. For comparison when aggregation is mean-all. + + Note: top-K **by length** is NOT top-K by weight — long thin clouds can have lower pred mass + than shorter heavier ones, so this mean can be **below** mean(all); see also + avg_topk_mean_pred_g_top_by_weight in summary. 
+ """ + if not candidates_for_avg: + return float("nan"), "by_pred" + sorted_by_length = sorted( + candidates_for_avg, + key=lambda d: float(d.get("length_input", float("-inf"))), + reverse=True, + ) + selected_topk_by_length = sorted_by_length[: min(top_k, len(sorted_by_length))] + lengths_of_topk_by_length = [ + float(it["length_input"]) + for it in selected_topk_by_length + if np.isfinite(float(it.get("length_input", float("nan")))) + ] + avg_length_input_topk_by_length = ( + float(np.mean(lengths_of_topk_by_length)) if lengths_of_topk_by_length else float("nan") + ) + sorted_by_weight = sorted( + candidates_for_avg, + key=lambda d: float(d.get("predicted_weight_g", float("-inf"))), + reverse=True, + ) + selected_topk_by_weight = sorted_by_weight[: min(top_k, len(sorted_by_weight))] + + selection: str + if top_by_length: + switch = ( + np.isfinite(avg_length_input_topk_by_length) + and avg_length_input_topk_by_length > float(length_switch_to_weight_mm) + ) + if switch: + selected = selected_topk_by_weight + selection = "by_pred" + else: + selected = selected_topk_by_length + selection = "by_length" + else: + selected = selected_topk_by_weight + selection = "by_pred" + + preds = [float(it["predicted_weight_g"]) for it in selected] + if not preds: + return float("nan"), selection + return float(np.mean(preds)), selection + + def load_points_from_ply(ply_path: Path) -> np.ndarray: pcd = o3d.io.read_point_cloud(str(ply_path)) if len(pcd.points) > 0: @@ -83,18 +228,30 @@ def sample_points_deterministic(points: np.ndarray, num_points: int, seed: int) return points[idx] -def estimate_length_major_axis(points: np.ndarray) -> float: +def estimate_pca_extents_mm(points: np.ndarray) -> Tuple[float, float, float]: + """ + PCA axis extents (same units as input coordinates, e.g. mm). + length = 1st principal axis span, width = 2nd, height = 3rd. 
+ """ if points is None or points.ndim != 2 or points.shape[0] < 3 or points.shape[1] < 3: - return float("nan") + return float("nan"), float("nan"), float("nan") pts = points[:, :3].astype(np.float32, copy=False) pts = pts - pts.mean(axis=0, keepdims=True) try: _u, _s, vt = np.linalg.svd(pts, full_matrices=False) - axis = vt[0] - proj = pts @ axis - return float(np.max(proj) - np.min(proj)) + extents: List[float] = [] + for i in range(3): + axis = vt[i] + proj = pts @ axis + extents.append(float(np.max(proj) - np.min(proj))) + return extents[0], extents[1], extents[2] except Exception: - return float("nan") + return float("nan"), float("nan"), float("nan") + + +def estimate_length_major_axis(points: np.ndarray) -> float: + L, _w, _h = estimate_pca_extents_mm(points) + return L def filter_outliers_iqr( @@ -177,6 +334,73 @@ def filter_outliers_zscore( return filtered, outliers, stats +def analyze_length_variance_quality( + lengths_mm: List[float], + cv_length_pct: float, + *, + cv_threshold_pct: float = 15.0, + max_span_mm: Optional[float] = 130.0, + min_samples: int = 3, +) -> Tuple[str, List[str], str, Dict[str, float]]: + """ + Classify length spread among frames (after max-length filter & optional outlier removal). + + Returns: + quality: "good" | "bad" | "unknown" + bad_reasons: human-readable reasons when bad + hint: one-line summary for logs / JSON + meta: length_min_mm, length_max_mm, length_span_mm, thresholds used + """ + meta: Dict[str, float] = { + "length_quality_cv_threshold_pct": float(cv_threshold_pct), + } + fin = [float(x) for x in lengths_mm if np.isfinite(float(x))] + n = len(fin) + if n < min_samples: + hint = ( + f"Length variance: unknown (only {n} sample(s); need ≥{min_samples} for a quality hint)." 
+ ) + return "unknown", [], hint, meta + + mn = float(np.min(fin)) + mx = float(np.max(fin)) + span = mx - mn + mean_l = float(np.mean(fin)) + meta["length_min_mm"] = mn + meta["length_max_mm"] = mx + meta["length_span_mm"] = span + meta["length_mean_mm"] = mean_l + + bad_reasons: List[str] = [] + if max_span_mm is not None and float(max_span_mm) > 0: + meta["length_quality_max_span_mm_threshold"] = float(max_span_mm) + + if np.isfinite(cv_length_pct): + meta["cv_length_pct"] = float(cv_length_pct) + if cv_length_pct > cv_threshold_pct: + bad_reasons.append( + f"CV {cv_length_pct:.1f}% > {cv_threshold_pct:.1f}% (length estimates vary a lot across frames)" + ) + if max_span_mm is not None and float(max_span_mm) > 0: + if span > float(max_span_mm): + bad_reasons.append( + f"span {span:.1f} mm (max−min) > {float(max_span_mm):.1f} mm (poses/scales differ strongly)" + ) + + if bad_reasons: + hint = ( + "Length variance: BAD — estimated weight may be unreliable. " + " ".join(bad_reasons) + ) + return "bad", bad_reasons, hint, meta + + cv_s = f"{float(cv_length_pct):.1f}%" if np.isfinite(cv_length_pct) else "n/a" + hint = ( + f"Length variance: GOOD — spread is moderate (CV {cv_s}, span {span:.1f} mm, " + f"n={n})." + ) + return "good", [], hint, meta + + @torch.no_grad() def _predict_folder_impl( model: torch.nn.Module, @@ -194,14 +418,44 @@ def _predict_folder_impl( outlier_field: str = "length_input", iqr_factor: float = 1.5, zscore_threshold: float = 2.5, + max_length_mm: Optional[float] = None, + min_length_width_ratio: Optional[float] = 1.5, + length_quality_cv_threshold_pct: float = 15.0, + length_quality_max_span_mm: Optional[float] = 130.0, + average_all_after_filter: bool = False, + average_all_fallback_to_max_if_mean_over_g: Optional[float] = 400.0, + mean_pool_fallback_to_max_if_over_g: Optional[float] = 440.0, ) -> Tuple[List[Dict], Dict]: """ Predict weights for a folder of PLY files using DGCNN. Input format: (B, N, 3) — DGCNN transposes internally. 
+ + If max_length_mm is set (>0), PLYs with finite length_input greater than this + (mm) are excluded from aggregation and marked filtered_by_max_length. + + If min_length_width_ratio is set (>0), PLYs with PCA length/width below this + threshold are excluded (filtered_by_length_width_ratio). Width = 2nd PCA axis extent. + + If average_all_after_filter is True, the final weight is normally the mean predicted weight + over all remaining PLYs after geometry filters and optional outlier removal (no top-K). + If that mean exceeds average_all_fallback_to_max_if_mean_over_g (default 400 g), the final + weight uses max_predicted_weight_g_after_filter instead (see pred_weight_rule). + + Independently, if the mean predicted weight over **all** post-filter candidates + (``avg_g_filtered``) exceeds mean_pool_fallback_to_max_if_over_g (default 440 g), + the final prediction uses max after filter — applies to top-K and mean-all modes. Use 0 to disable. + + length_quality_*: heuristic good/bad hint from CV and (max−min) span on the + post-filter length pool (see analyze_length_variance_quality). Use + length_quality_max_span_mm=0 to disable the span rule. 
""" model.eval() per_file: List[Dict] = [] + span_cap_q: Optional[float] = None if ( + length_quality_max_span_mm is not None and float(length_quality_max_span_mm) <= 0 + ) else length_quality_max_span_mm + if not ply_files: summary_skipped = { "num_files": 0, @@ -210,15 +464,35 @@ def _predict_folder_impl( "skip_reason": "no .ply files", "avg_predicted_weight_kg": float("nan"), "avg_predicted_weight_g": float("nan"), + "mean_all_pred_g_after_filters": float("nan"), + "avg_topk_mean_pred_g": float("nan"), + "avg_topk_mean_pred_selection": None, + "avg_topk_mean_pred_g_top_by_weight": float("nan"), "avg_length_input_topk": float("nan"), + "length_variance_quality": "unknown", + "length_variance_bad_reasons": [], + "length_variance_hint": "Length variance: unknown (no .ply files).", + "length_quality_cv_threshold_pct": float(length_quality_cv_threshold_pct), + "near_max_length_band_mm": 30.0, + "max_length_mm_among_filtered": float("nan"), + "count_in_near_max_length_band": 0, + "num_filtered_for_length_band_stats": 0, + "fraction_in_near_max_length_band": float("nan"), } + if span_cap_q is not None: + summary_skipped["length_quality_max_span_mm_threshold"] = float(span_cap_q) return [], summary_skipped preds_kg: List[float] = [] for ply in ply_files: pts = load_points_from_ply(ply) - length_input = estimate_length_major_axis(pts) + length_input, width_input, height_input = estimate_pca_extents_mm(pts) + w_safe = float(width_input) + if np.isfinite(w_safe) and w_safe > 1e-9: + length_width_ratio = float(length_input) / w_safe + else: + length_width_ratio = float("nan") pts = pts * float(xyz_scale) pts = pts - pts.mean(axis=0, keepdims=True) @@ -239,25 +513,87 @@ def _predict_folder_impl( "predicted_weight_kg": pred_kg, "predicted_weight_g": pred_g, "length_input": float(length_input), + "width_input": float(width_input), + "height_input": float(height_input), + "length_width_ratio": float(length_width_ratio), "length_after_scale": float(length_input * 
float(xyz_scale)) if np.isfinite(length_input) else float("nan"), "is_outlier": False, + "filtered_by_max_length": False, + "filtered_by_length_width_ratio": False, }) + length_cap: Optional[float] = None + if max_length_mm is not None and float(max_length_mm) > 0: + length_cap = float(max_length_mm) + for it in per_file: + L = float(it.get("length_input", float("nan"))) + if np.isfinite(L) and L > length_cap: + it["filtered_by_max_length"] = True + + lw_min: Optional[float] = None + if min_length_width_ratio is not None and float(min_length_width_ratio) > 0: + lw_min = float(min_length_width_ratio) + for it in per_file: + if it.get("filtered_by_max_length"): + continue + r = float(it.get("length_width_ratio", float("nan"))) + if not np.isfinite(r) or r < lw_min: + it["filtered_by_length_width_ratio"] = True + + pool = [ + it + for it in per_file + if not it.get("filtered_by_max_length", False) and not it.get("filtered_by_length_width_ratio", False) + ] + num_filtered_by_max_length = sum(1 for it in per_file if it.get("filtered_by_max_length")) + num_filtered_by_length_width_ratio = sum(1 for it in per_file if it.get("filtered_by_length_width_ratio")) + + if not pool: + summary_skipped = { + "num_files": len(ply_files), + "num_files_predicted": len(per_file), + "skipped": True, + "skip_reason": "no PLYs after length / length-width filters", + "num_files_filtered_by_max_length": int(num_filtered_by_max_length), + "num_files_filtered_by_length_width_ratio": int(num_filtered_by_length_width_ratio), + "max_length_mm": length_cap, + "min_length_width_ratio": lw_min, + "avg_predicted_weight_kg": float("nan"), + "avg_predicted_weight_g": float("nan"), + "mean_all_pred_g_after_filters": float("nan"), + "avg_topk_mean_pred_g": float("nan"), + "avg_topk_mean_pred_selection": None, + "avg_topk_mean_pred_g_top_by_weight": float("nan"), + "avg_length_input_topk": float("nan"), + "length_variance_quality": "unknown", + "length_variance_bad_reasons": [], + "length_variance_hint": 
"Length variance: unknown (no PLYs left after length / length-width filters).", + "length_quality_cv_threshold_pct": float(length_quality_cv_threshold_pct), + "near_max_length_band_mm": 30.0, + "max_length_mm_among_filtered": float("nan"), + "count_in_near_max_length_band": 0, + "num_filtered_for_length_band_stats": 0, + "fraction_in_near_max_length_band": float("nan"), + } + if span_cap_q is not None: + summary_skipped["length_quality_max_span_mm_threshold"] = float(span_cap_q) + return per_file, summary_skipped + outlier_stats: Optional[Dict] = None - candidates_for_avg = per_file + candidates_for_avg = pool num_outliers_removed = 0 - if remove_outliers and per_file: + if remove_outliers and pool: if outlier_method == "iqr": filtered, outliers, outlier_stats = filter_outliers_iqr( - per_file, field=outlier_field, iqr_factor=iqr_factor + pool, field=outlier_field, iqr_factor=iqr_factor ) elif outlier_method == "zscore": filtered, outliers, outlier_stats = filter_outliers_zscore( - per_file, field=outlier_field, zscore_threshold=zscore_threshold + pool, field=outlier_field, zscore_threshold=zscore_threshold ) else: - filtered, outliers = per_file, [] + filtered, outliers = pool, [] outlier_stats = {"error": f"unknown method: {outlier_method}"} outlier_plys = {it["ply"] for it in outliers} for it in per_file: @@ -266,63 +602,106 @@ def _predict_folder_impl( candidates_for_avg = filtered num_outliers_removed = len(outliers) + avg_topk_mean_disp_g, avg_topk_mean_disp_sel = _topk_mean_prediction_g_for_display( + candidates_for_avg, + top_k, + top_by_length, + length_switch_to_weight_mm, + ) + + avg_topk_mean_top_by_weight_g = float("nan") + if candidates_for_avg: + sorted_w = sorted( + candidates_for_avg, + key=lambda d: float(d.get("predicted_weight_g", float("-inf"))), + reverse=True, + ) + sel_w = sorted_w[: min(top_k, len(sorted_w))] + if sel_w: + avg_topk_mean_top_by_weight_g = float( + np.mean([float(x["predicted_weight_g"]) for x in sel_w]) + ) + + # Mean length 
of the K longest clouds (by length), always; differs from mean(all) when N > K. + avg_length_mean_top_k_longest = _mean_length_top_k_longest(candidates_for_avg, top_k) + avg_kg_all = float(np.mean(preds_kg)) if preds_kg else float("nan") avg_g_all = avg_kg_all * 1000.0 if preds_kg else float("nan") filtered_preds_g = [float(it["predicted_weight_g"]) for it in candidates_for_avg] + max_g_after_filter = float(np.max(filtered_preds_g)) if filtered_preds_g else float("nan") + max_kg_after_filter = max_g_after_filter / 1000.0 if filtered_preds_g else float("nan") avg_kg_filtered = float(np.mean(filtered_preds_g)) / 1000.0 if filtered_preds_g else float("nan") avg_g_filtered = float(np.mean(filtered_preds_g)) if filtered_preds_g else float("nan") - sorted_by_length = sorted( - candidates_for_avg, - key=lambda d: float(d.get("length_input", float("-inf"))), - reverse=True, - ) - selected_topk_by_length = sorted_by_length[: min(top_k, len(sorted_by_length))] - lengths_of_topk_by_length = [ - float(it["length_input"]) - for it in selected_topk_by_length - if np.isfinite(float(it.get("length_input", float("nan")))) - ] - avg_length_input_topk_by_length = ( - float(np.mean(lengths_of_topk_by_length)) if lengths_of_topk_by_length else float("nan") - ) - - sorted_by_weight = sorted( - candidates_for_avg, - key=lambda d: float(d.get("predicted_weight_g", float("-inf"))), - reverse=True, - ) - selected_topk_by_weight = sorted_by_weight[: min(top_k, len(sorted_by_weight))] - - switched_to_weight_due_to_long_length = False - if top_by_length: - switch = ( - np.isfinite(avg_length_input_topk_by_length) - and avg_length_input_topk_by_length > float(length_switch_to_weight_mm) + if average_all_after_filter: + selected_topk = list(candidates_for_avg) + switched_to_weight_due_to_long_length = False + effective_top_by_length = False + lengths_for_switch = [ + float(it["length_input"]) + for it in candidates_for_avg + if np.isfinite(float(it.get("length_input", float("nan")))) + ] + 
avg_length_input_topk_by_length = ( + float(np.mean(lengths_for_switch)) if lengths_for_switch else float("nan") ) - if switch: + preds_g_topk = [float(it["predicted_weight_g"]) for it in selected_topk] + use_max_instead_of_mean = False + avg_g_topk = float(np.mean(preds_g_topk)) if preds_g_topk else float(avg_g_all) + num_used_for_avg = len(selected_topk) + plys_used_for_prediction = [Path(it["ply"]).name for it in selected_topk] + else: + sorted_by_length = sorted( + candidates_for_avg, + key=lambda d: float(d.get("length_input", float("-inf"))), + reverse=True, + ) + selected_topk_by_length = sorted_by_length[: min(top_k, len(sorted_by_length))] + lengths_of_topk_by_length = [ + float(it["length_input"]) + for it in selected_topk_by_length + if np.isfinite(float(it.get("length_input", float("nan")))) + ] + avg_length_input_topk_by_length = ( + float(np.mean(lengths_of_topk_by_length)) if lengths_of_topk_by_length else float("nan") + ) + + sorted_by_weight = sorted( + candidates_for_avg, + key=lambda d: float(d.get("predicted_weight_g", float("-inf"))), + reverse=True, + ) + selected_topk_by_weight = sorted_by_weight[: min(top_k, len(sorted_by_weight))] + + switched_to_weight_due_to_long_length = False + if top_by_length: + switch = ( + np.isfinite(avg_length_input_topk_by_length) + and avg_length_input_topk_by_length > float(length_switch_to_weight_mm) + ) + if switch: + effective_top_by_length = False + selected_topk = selected_topk_by_weight + switched_to_weight_due_to_long_length = True + else: + effective_top_by_length = True + selected_topk = selected_topk_by_length + else: effective_top_by_length = False selected_topk = selected_topk_by_weight - switched_to_weight_due_to_long_length = True + preds_g_topk = [float(it["predicted_weight_g"]) for it in selected_topk] + use_max_instead_of_mean = len(candidates_for_avg) < 5 + if preds_g_topk: + avg_g_topk = ( + float(np.max(preds_g_topk)) + if use_max_instead_of_mean + else float(np.mean(preds_g_topk)) + ) else: 
- effective_top_by_length = True - selected_topk = selected_topk_by_length - else: - effective_top_by_length = False - selected_topk = selected_topk_by_weight - preds_g_topk = [float(it["predicted_weight_g"]) for it in selected_topk] - use_max_instead_of_mean = len(candidates_for_avg) < 5 - if preds_g_topk: - avg_g_topk = ( - float(np.max(preds_g_topk)) - if use_max_instead_of_mean - else float(np.mean(preds_g_topk)) - ) - else: - avg_g_topk = float(avg_g_all) - num_used_for_avg = len(selected_topk) - plys_used_for_prediction = [Path(it["ply"]).name for it in selected_topk] + avg_g_topk = float(avg_g_all) + num_used_for_avg = len(selected_topk) + plys_used_for_prediction = [Path(it["ply"]).name for it in selected_topk] lengths_topk = [ float(it["length_input"]) @@ -336,12 +715,78 @@ def _predict_folder_impl( for it in candidates_for_avg if np.isfinite(float(it.get("length_input", float("nan")))) ] + # Largest-length group: PLYs within 3 cm (30 mm) of max length among post-filter candidates. 
+ near_max_length_band_mm = 30.0 + max_len_among_filtered = float(max(lengths)) if lengths else float("nan") + if lengths: + thr_len = max_len_among_filtered - near_max_length_band_mm + count_in_near_max_length_band = sum(1 for L in lengths if L >= thr_len) + fraction_in_near_max_length_band = float(count_in_near_max_length_band) / float(len(lengths)) + else: + count_in_near_max_length_band = 0 + fraction_in_near_max_length_band = float("nan") + avg_len_input = float(np.mean(lengths)) if lengths else float("nan") std_len_input = float(np.std(lengths)) if len(lengths) > 1 else 0.0 cv_length = float(std_len_input / avg_len_input * 100.0) if avg_len_input > 0 else float("nan") - avg_g = float(avg_g_topk) if np.isfinite(float(avg_g_topk)) else float(avg_g_all) - avg_kg = avg_g / 1000.0 + lq_quality, lq_bad_reasons, lq_hint, lq_meta = analyze_length_variance_quality( + lengths, + cv_length, + cv_threshold_pct=float(length_quality_cv_threshold_pct), + max_span_mm=span_cap_q, + ) + + base_g = float(avg_g_topk) if np.isfinite(float(avg_g_topk)) else float(avg_g_all) + mean_all_candidates_g = float("nan") + if average_all_after_filter: + mean_all_candidates_g = base_g + + fb_thr = average_all_fallback_to_max_if_mean_over_g + use_max_fallback = ( + average_all_after_filter + and fb_thr is not None + and float(fb_thr) > 0.0 + and np.isfinite(base_g) + and base_g > float(fb_thr) + and np.isfinite(max_g_after_filter) + ) + if use_max_fallback: + avg_g = float(max_g_after_filter) + avg_kg = avg_g / 1000.0 + prediction_aggregate_eff = "max_after_filter_high_mean_all" + elif average_all_after_filter: + avg_g = base_g + avg_kg = avg_g / 1000.0 + prediction_aggregate_eff = "mean_all_filtered" + else: + avg_g = base_g + avg_kg = avg_g / 1000.0 + prediction_aggregate_eff = "max" if use_max_instead_of_mean else "mean" + + mp_thr = mean_pool_fallback_to_max_if_over_g + use_mean_pool_max = ( + mp_thr is not None + and float(mp_thr) > 0.0 + and np.isfinite(avg_g_filtered) + and 
float(avg_g_filtered) > float(mp_thr) + and np.isfinite(max_g_after_filter) + ) + if use_mean_pool_max: + avg_g = float(max_g_after_filter) + avg_kg = avg_g / 1000.0 + prediction_aggregate_eff = "max_after_filter_high_mean_pool" + + pred_weight_g = float(avg_g) + pred_weight_kg = float(avg_kg) + if use_mean_pool_max: + pred_weight_rule = "max_after_filter_high_mean_pool_over_g" + elif use_max_fallback: + pred_weight_rule = "max_after_filter_high_mean_all" + elif average_all_after_filter: + pred_weight_rule = "mean_all_filtered" + else: + pred_weight_rule = "top_k_aggregate" kept_ply_names = set(plys_used_for_prediction) kept_ply_paths = {it["ply"] for it in per_file if Path(it["ply"]).name in kept_ply_names} @@ -363,27 +808,66 @@ def _predict_folder_impl( summary = { "num_files": len(ply_files), "num_files_predicted": len(per_file), + "num_files_filtered_by_max_length": int(num_filtered_by_max_length), + "num_files_filtered_by_length_width_ratio": int(num_filtered_by_length_width_ratio), + "max_length_mm": length_cap, + "min_length_width_ratio": lw_min, "num_outliers_removed": num_outliers_removed, "num_files_after_outlier_removal": len(candidates_for_avg), "num_files_used_for_avg": int(num_used_for_avg), - "prediction_aggregate": "max" if use_max_instead_of_mean else "mean", + "average_all_after_filter": bool(average_all_after_filter), + "average_all_fallback_to_max_if_mean_over_g": ( + float(fb_thr) if fb_thr is not None and float(fb_thr) > 0 else None + ), + "used_max_instead_of_mean_all_high_mean": bool(use_max_fallback), + "mean_pool_fallback_to_max_if_over_g": ( + float(mp_thr) if mp_thr is not None and float(mp_thr) > 0 else None + ), + "used_max_instead_of_high_mean_pool": bool(use_mean_pool_max), + "mean_all_candidates_g_before_max_fallback": mean_all_candidates_g, + "prediction_aggregate": prediction_aggregate_eff, + "pred_weight_g": pred_weight_g, + "pred_weight_kg": pred_weight_kg, + "pred_weight_rule": pred_weight_rule, "top_k": top_k, "top_by_length": 
top_by_length, "effective_top_by_length": effective_top_by_length, "switched_to_weight_due_to_long_length": switched_to_weight_due_to_long_length, "length_switch_threshold_mm": float(length_switch_to_weight_mm), "avg_length_input_topk_by_length": avg_length_input_topk_by_length - if top_by_length + if (top_by_length or average_all_after_filter) else float("nan"), "plys_used_for_prediction": plys_used_for_prediction, "avg_predicted_weight_kg": avg_kg, "avg_predicted_weight_g": avg_g, + "mean_all_pred_g_after_filters": float(avg_g_filtered), "avg_predicted_weight_kg_all": avg_kg_all, "avg_predicted_weight_g_all": avg_g_all, + "max_predicted_weight_g_after_filter": max_g_after_filter, + "max_predicted_weight_kg_after_filter": max_kg_after_filter, + "avg_topk_mean_pred_g": avg_topk_mean_disp_g, + "avg_topk_mean_pred_selection": avg_topk_mean_disp_sel, + "avg_topk_mean_pred_g_top_by_weight": avg_topk_mean_top_by_weight_g, "avg_length_input_topk": avg_length_input_topk, "avg_length_input": avg_len_input, + "avg_length_mean_top_k_longest": avg_length_mean_top_k_longest, + "length_averages_mm": { + "mean_over_all_candidates_after_filters": avg_len_input, + "mean_over_top_k_longest_by_length": avg_length_mean_top_k_longest, + "mean_over_top_k_used_for_weight": avg_length_input_topk, + "top_k": int(top_k), + }, "std_length_input": std_len_input, "cv_length_pct": cv_length, + "length_variance_quality": lq_quality, + "length_variance_bad_reasons": lq_bad_reasons, + "length_variance_hint": lq_hint, + "near_max_length_band_mm": float(near_max_length_band_mm), + "max_length_mm_among_filtered": float(max_len_among_filtered), + "count_in_near_max_length_band": int(count_in_near_max_length_band), + "num_filtered_for_length_band_stats": int(len(lengths)), + "fraction_in_near_max_length_band": float(fraction_in_near_max_length_band), + **{k: v for k, v in lq_meta.items()}, "outlier_removal": { "enabled": remove_outliers, "method": outlier_method if remove_outliers else None, @@ -401,6 
+885,7 @@ def predict_folder( device: torch.device, num_points: int = 768, xyz_scale: float = 0.001, + max_length_mm: Optional[float] = None, ) -> Tuple[List[Dict], Dict]: return _predict_folder_impl( model=model, @@ -409,6 +894,7 @@ def predict_folder( num_points=num_points, xyz_scale=xyz_scale, topk_length=None, + max_length_mm=max_length_mm, ) @@ -461,6 +947,13 @@ def predict_cloud_folder( iqr_factor: float = 1.5, zscore_threshold: float = 2.5, ply_files: Optional[List[Path]] = None, + max_length_mm: Optional[float] = None, + min_length_width_ratio: Optional[float] = 1.5, + length_quality_cv_threshold_pct: float = 15.0, + length_quality_max_span_mm: Optional[float] = 130.0, + average_all_after_filter: bool = False, + average_all_fallback_to_max_if_mean_over_g: Optional[float] = 400.0, + mean_pool_fallback_to_max_if_over_g: Optional[float] = 440.0, ) -> Tuple[List[Dict], Dict]: if ply_files is not None: ply_files = sorted({Path(p).expanduser().resolve() for p in ply_files}) @@ -487,6 +980,13 @@ def predict_cloud_folder( outlier_field=outlier_field, iqr_factor=iqr_factor, zscore_threshold=zscore_threshold, + max_length_mm=max_length_mm, + min_length_width_ratio=min_length_width_ratio, + length_quality_cv_threshold_pct=length_quality_cv_threshold_pct, + length_quality_max_span_mm=length_quality_max_span_mm, + average_all_after_filter=average_all_after_filter, + average_all_fallback_to_max_if_mean_over_g=average_all_fallback_to_max_if_mean_over_g, + mean_pool_fallback_to_max_if_over_g=mean_pool_fallback_to_max_if_over_g, ) @@ -588,6 +1088,60 @@ def compare_pred_vs_actual(pred_g: float, actual_g: float) -> Dict[str, float]: } +def _batch_topk_avg_better_than_pred_suffix( + summary: Dict[str, Any], + pred_g: float, + actual_g: float, +) -> str: + """ + When top-K mean predicted weight is clearly closer to actual than final pred, append a short note + after diff% (e.g. mean-all + max fallback vs top-K aggregate). 
+ """ + tk = int(summary.get("top_k") or 5) + topk_g = summary.get("avg_topk_mean_pred_g") + if topk_g is None or not np.isfinite(float(topk_g)): + return "" + topk_g = float(topk_g) + if not np.isfinite(pred_g) or not np.isfinite(actual_g) or abs(float(actual_g)) < 1e-9: + return "" + err_pred = abs(float(pred_g) - float(actual_g)) + err_topk = abs(topk_g - float(actual_g)) + if err_topk >= err_pred: + return "" + improvement = err_pred - err_topk + actual_abs = abs(float(actual_g)) + # Meaningful absolute gain (avoid noise when pred≈topK) + min_improve_g = max(5.0, 0.02 * actual_abs) + if improvement < min_improve_g: + return "" + # Also require some relative gain vs pred error (5% of pred's absolute error), not the old + # err_topk <= 0.75*err_pred rule — that missed cases like fish5: top5 closer but only ~15% better. + if improvement < 0.05 * err_pred: + return "" + return f" | note: top{tk}_avg closer to actual than pred" + + +def _batch_topk_max_stable_star(summary: Dict[str, Any]) -> str: + """ + Append `` *`` when we treat the run as confident: + - first-column avg (mean over all post-filter candidates) > 440 g, OR + - share of candidates whose length is within 30 mm of the max length among filtered PLYs + is at least 25 %. + """ + confident_avg_g = 440.0 + min_fraction_largest_length_group = 0.25 + + mean_g = _mean_column_g_for_log(summary) + if mean_g is not None and np.isfinite(float(mean_g)) and float(mean_g) > confident_avg_g: + return " *" + + frac = summary.get("fraction_in_near_max_length_band") + if frac is not None and np.isfinite(float(frac)) and float(frac) >= min_fraction_largest_length_group: + return " *" + + return "" + + def main() -> None: parser = argparse.ArgumentParser("DGCNN weight estimator (folder inference)") parser.add_argument("--checkpoint", type=str, required=True, help="Path to best.pt/last.pt checkpoint") @@ -685,8 +1239,77 @@ def main() -> None: default=None, help="Maximum CV%% for length. 
Skip folders with CV > this value (e.g. 15.0).", ) + parser.add_argument( + "--max-length-mm", + type=float, + default=400.0, + dest="max_length_mm", + help="Exclude PLYs with length_input > this (mm) from final aggregation; marked filtered_by_max_length. " + "Use 0 to disable (default: 400).", + ) + parser.add_argument( + "--min-length-width-ratio", + type=float, + default=1.5, + dest="min_length_width_ratio", + help="Exclude PLYs with PCA length/width < this (width = 2nd PCA axis). Catches non-fish blobs. " + "Use 0 to disable (default: 1.5).", + ) + parser.add_argument( + "--length-quality-cv-threshold-pct", + type=float, + default=15.0, + dest="length_quality_cv_threshold_pct", + help="Length variance hint: mark BAD if CV%% exceeds this (default: 15).", + ) + parser.add_argument( + "--length-quality-max-span-mm", + type=float, + default=130.0, + dest="length_quality_max_span_mm", + help="Length variance hint: mark BAD if max(length)−min(length) exceeds this (mm). Use 0 to disable (default: 130).", + ) + parser.add_argument( + "--average-all-after-filter", + action=argparse.BooleanOptionalAction, + default=False, + dest="average_all_after_filter", + help="Final weight = mean over all PLYs after geometry filters and optional outlier removal (no top-K).", + ) + parser.add_argument( + "--average-all-fallback-max-if-mean-over-g", + type=float, + default=400.0, + dest="average_all_fallback_max_if_mean_over_g", + help="With --average-all-after-filter: if mean predicted weight (g) exceeds this, use max after filter " + "instead. 0 disables (default: 400).", + ) + parser.add_argument( + "--mean-pool-fallback-max-if-over-g", + type=float, + default=440.0, + dest="mean_pool_fallback_max_if_over_g", + help="If mean predicted weight over all post-filter candidates (g) exceeds this, use max after filter " + "as final pred (top-K or mean-all). 
0 disables (default: 440).", + ) args = parser.parse_args() + max_length_eff: Optional[float] = None if args.max_length_mm <= 0 else float(args.max_length_mm) + min_lw_eff: Optional[float] = None if args.min_length_width_ratio <= 0 else float(args.min_length_width_ratio) + length_span_eff: Optional[float] = ( + None if args.length_quality_max_span_mm <= 0 else float(args.length_quality_max_span_mm) + ) + avg_all_fb_eff: Optional[float] = ( + None + if args.average_all_fallback_max_if_mean_over_g <= 0 + else float(args.average_all_fallback_max_if_mean_over_g) + ) + mean_pool_fb_eff: Optional[float] = ( + None + if args.mean_pool_fallback_max_if_over_g <= 0 + else float(args.mean_pool_fallback_max_if_over_g) + ) + ckpt_path = Path(args.checkpoint).expanduser().resolve() if not ckpt_path.exists(): raise SystemExit(f"checkpoint not found: {ckpt_path}") @@ -744,6 +1367,13 @@ def main() -> None: iqr_factor=args.iqr_factor, zscore_threshold=args.zscore_threshold, ply_files=ply_paths, + max_length_mm=max_length_eff, + min_length_width_ratio=min_lw_eff, + length_quality_cv_threshold_pct=float(args.length_quality_cv_threshold_pct), + length_quality_max_span_mm=length_span_eff, + average_all_after_filter=bool(args.average_all_after_filter), + average_all_fallback_to_max_if_mean_over_g=avg_all_fb_eff, + mean_pool_fallback_to_max_if_over_g=mean_pool_fb_eff, ) elif args.ply_folder: ply_folder = Path(args.ply_folder).expanduser().resolve() @@ -767,6 +1397,13 @@ def main() -> None: outlier_field=args.outlier_field, iqr_factor=args.iqr_factor, zscore_threshold=args.zscore_threshold, + max_length_mm=max_length_eff, + min_length_width_ratio=min_lw_eff, + length_quality_cv_threshold_pct=float(args.length_quality_cv_threshold_pct), + length_quality_max_span_mm=length_span_eff, + average_all_after_filter=bool(args.average_all_after_filter), + average_all_fallback_to_max_if_mean_over_g=avg_all_fb_eff, + mean_pool_fallback_to_max_if_over_g=mean_pool_fb_eff, ) if args.ply_list_file or 
args.ply_folder: @@ -784,6 +1421,13 @@ def main() -> None: "device": str(device), "labels_json": str(labels_path) if labels_path else None, "length_switch_mm": float(args.length_switch_mm), + "max_length_mm": max_length_eff, + "min_length_width_ratio": min_lw_eff, + "length_quality_cv_threshold_pct": float(args.length_quality_cv_threshold_pct), + "length_quality_max_span_mm": length_span_eff, + "average_all_after_filter": bool(args.average_all_after_filter), + "average_all_fallback_max_if_mean_over_g": avg_all_fb_eff, + "mean_pool_fallback_max_if_over_g": mean_pool_fb_eff, }, "summary": summary, "comparison": None, @@ -808,7 +1452,13 @@ def main() -> None: used = it.get("used_for_prediction", True) rank = it.get("rank_by_selection", 0) n_total = summary.get("num_files_predicted", len(per_file)) - if rank > 0: + if it.get("filtered_by_max_length"): + cap_s = f"{max_length_eff:.0f}" if max_length_eff is not None else "?" + tag = f" [FILTERED length>{cap_s}mm]" + elif it.get("filtered_by_length_width_ratio"): + thr_s = f"{min_lw_eff:.2f}" if min_lw_eff is not None else "?" 
+ tag = f" [FILTERED L/W<{thr_s}]" + elif rank > 0: tag = f" (kept, rank {rank}/{n_total})" if used else f" (filtered, rank {rank}/{n_total})" else: tag = " (kept)" if used else " (filtered)" @@ -819,28 +1469,42 @@ def main() -> None: print(f"{ply}: len={length_str} | {g:.2f} g ({kg:.4f} kg){tag}") print(f"Files: {summary.get('num_files_predicted', summary['num_files'])}") + n_cap = summary.get("num_files_filtered_by_max_length", 0) or 0 + if n_cap > 0 and max_length_eff is not None: + print( + f"Filtered by length cap (>{max_length_eff:.0f} mm): {n_cap} " + f"(excluded from aggregation)" + ) + n_lw = summary.get("num_files_filtered_by_length_width_ratio", 0) or 0 + if n_lw > 0 and min_lw_eff is not None: + print( + f"Filtered by length/width ratio (<{min_lw_eff:.2f}): {n_lw} " + f"(excluded from aggregation)" + ) if args.remove_outliers and summary.get("num_outliers_removed", 0) > 0: print(f"Outliers removed: {summary['num_outliers_removed']} (method={args.outlier_method}, field={args.outlier_field})") - if summary.get("cv_length_pct") is not None: - print(f"Length CV: {summary['cv_length_pct']:.1f}%") top_label = _summary_top_label(summary, args.top_k, args.top_by_length) + tk = int(summary.get("top_k") or args.top_k) + topk_g = summary.get("avg_topk_mean_pred_g") + topk_sel = summary.get("avg_topk_mean_pred_selection") or "pred" + topk_extra = "" + if topk_g is not None and np.isfinite(float(topk_g)): + by = "length" if topk_sel == "by_length" else "pred" + topk_extra = f" | top{tk}_avg={float(topk_g):.2f} g by {by}" + mcol = _mean_column_g_for_log(summary) print( - f"Average predicted weight ({top_label}): " - f"{summary['avg_predicted_weight_g']:.2f} g ({summary['avg_predicted_weight_kg']:.4f} kg) " - f"(used={summary.get('num_files_used_for_avg', 'n/a')}/{summary.get('num_files_predicted', 'n/a')})" + f"Average predicted weight {top_label}: " + f"{mcol:.2f} g, {mcol / 1000.0:.4f} kg{topk_extra}" ) - avg_len_topk = summary.get("avg_length_input_topk") - if 
avg_len_topk is not None and np.isfinite(avg_len_topk): - print( - f"Average length ({top_label}): " - f"{_format_length_input_display(avg_len_topk, args.xyz_scale)}" - ) + _print_max_weight_after_filter(summary) fish_key = extract_fish_key_from_text(str(ply_folder)) comparison = None if fish_key and fish_key in labels: actual_g = float(labels[fish_key]) - comparison = compare_pred_vs_actual(pred_g=float(summary["avg_predicted_weight_g"]), actual_g=actual_g) + comparison = compare_pred_vs_actual( + pred_g=_final_pred_g_for_compare(summary), actual_g=actual_g + ) print( f"Actual weight ({fish_key}): {actual_g:.2f} g | " f"Diff: {comparison['diff_g']:.2f} g | " @@ -861,6 +1525,13 @@ def main() -> None: "device": str(device), "labels_json": str(labels_path) if labels_path else None, "length_switch_mm": float(args.length_switch_mm), + "max_length_mm": max_length_eff, + "min_length_width_ratio": min_lw_eff, + "length_quality_cv_threshold_pct": float(args.length_quality_cv_threshold_pct), + "length_quality_max_span_mm": length_span_eff, + "average_all_after_filter": bool(args.average_all_after_filter), + "average_all_fallback_max_if_mean_over_g": avg_all_fb_eff, + "mean_pool_fallback_max_if_over_g": mean_pool_fb_eff, }, "summary": summary, "comparison": comparison, @@ -904,6 +1575,13 @@ def main() -> None: outlier_field=args.outlier_field, iqr_factor=args.iqr_factor, zscore_threshold=args.zscore_threshold, + max_length_mm=max_length_eff, + min_length_width_ratio=min_lw_eff, + length_quality_cv_threshold_pct=float(args.length_quality_cv_threshold_pct), + length_quality_max_span_mm=length_span_eff, + average_all_after_filter=bool(args.average_all_after_filter), + average_all_fallback_to_max_if_mean_over_g=avg_all_fb_eff, + mean_pool_fallback_to_max_if_over_g=mean_pool_fb_eff, ) if summary.get("skipped"): print(f"{fish_key}: SKIPPED ({summary.get('skip_reason', '<5 PLYs')})") @@ -917,36 +1595,29 @@ def main() -> None: }) continue - num_used = 
summary.get("num_files_used_for_avg", "n/a") - num_total = summary.get("num_files_predicted", "n/a") - used_info = f"used={num_used}/{num_total} (from {len(cloud_dirs)} folders)" - - top_label = _summary_top_label(summary, args.top_k, args.top_by_length) - avg_len_topk = summary.get("avg_length_input_topk") - len_part = "" - if avg_len_topk is not None and np.isfinite(avg_len_topk): - len_part = ( - f" | avg_len(top{args.top_k} selected)=" - f"{_format_length_input_display(avg_len_topk, args.xyz_scale)}" - ) + avg_leader = _batch_avg_leader(summary, args.top_k, args.top_by_length) + weight_part = _batch_aggregate_log_suffix(summary) if fish_key in labels: actual_g = float(labels[fish_key]) - comparison = compare_pred_vs_actual( - pred_g=float(summary["avg_predicted_weight_g"]), actual_g=actual_g + pred_final = _final_pred_g_for_compare(summary) + comparison = compare_pred_vs_actual(pred_g=pred_final, actual_g=actual_g) + topk_note = _batch_topk_avg_better_than_pred_suffix( + summary, pred_final, actual_g ) + star = _batch_topk_max_stable_star(summary) print( - f"{fish_key}: avg({top_label})={summary['avg_predicted_weight_g']:.2f} g ({used_info}){len_part} | " - f"actual={actual_g:.2f} g | diff%={comparison['diff_pct']:.2f}%" + f"{fish_key}: {avg_leader}{_mean_column_g_for_log(summary):.2f} g{weight_part} | " + f"actual={actual_g:.2f} g | diff%={comparison['diff_pct']:.2f}%{topk_note}{star}" ) else: + star = _batch_topk_max_stable_star(summary) print( - f"{fish_key}: avg({top_label})={summary['avg_predicted_weight_g']:.2f} g ({used_info}){len_part}" + f"{fish_key}: {avg_leader}{_mean_column_g_for_log(summary):.2f} g{weight_part}{star}" ) - comparison = None if fish_key in labels: comparison = compare_pred_vs_actual( - pred_g=float(summary["avg_predicted_weight_g"]), + pred_g=_final_pred_g_for_compare(summary), actual_g=float(labels[fish_key]), ) results.append({ @@ -990,6 +1661,13 @@ def main() -> None: outlier_field=args.outlier_field, iqr_factor=args.iqr_factor, 
zscore_threshold=args.zscore_threshold, + max_length_mm=max_length_eff, + min_length_width_ratio=min_lw_eff, + length_quality_cv_threshold_pct=float(args.length_quality_cv_threshold_pct), + length_quality_max_span_mm=length_span_eff, + average_all_after_filter=bool(args.average_all_after_filter), + average_all_fallback_to_max_if_mean_over_g=avg_all_fb_eff, + mean_pool_fallback_to_max_if_over_g=mean_pool_fb_eff, ) if summary.get("skipped"): print(f"{rel}: SKIPPED ({summary.get('skip_reason', '<5 PLYs')})") @@ -1022,7 +1700,14 @@ def main() -> None: kg = float(it["predicted_weight_kg"]) length_input = float(it.get("length_input", float("nan"))) used = it.get("used_for_prediction", True) - tag = " (kept)" if used else " (filtered)" + if it.get("filtered_by_max_length"): + cap_s = f"{max_length_eff:.0f}" if max_length_eff is not None else "?" + tag = f" [FILTERED length>{cap_s}mm]" + elif it.get("filtered_by_length_width_ratio"): + thr_s = f"{min_lw_eff:.2f}" if min_lw_eff is not None else "?" 
+ tag = f" [FILTERED L/W<{thr_s}]" + else: + tag = " (kept)" if used else " (filtered)" if abs(float(args.xyz_scale) - 0.001) < 1e-12: length_str = f"{length_input:.1f} mm" if np.isfinite(length_input) else "nan" else: @@ -1031,42 +1716,26 @@ def main() -> None: fish_key = extract_fish_key_from_text(rel) or extract_fish_key_from_text(str(cloud_dir)) comparison = None - num_after_outlier = summary.get( - "num_files_after_outlier_removal", summary.get("num_files_predicted", "n/a") - ) - num_used = summary.get("num_files_used_for_avg", "n/a") - num_total = summary.get("num_files_predicted", "n/a") - num_outliers = summary.get("num_outliers_removed", 0) - if args.remove_outliers and num_outliers > 0: - used_info = f"used={num_used}/{num_after_outlier}, outliers={num_outliers}" - else: - used_info = f"used={num_used}/{num_total}" - if cv_pct is not None and cv_pct > 10.0: - used_info += f", cv={cv_pct:.1f}%" - - top_label = _summary_top_label(summary, args.top_k, args.top_by_length) - avg_len_topk = summary.get("avg_length_input_topk") - len_part = "" - if avg_len_topk is not None and np.isfinite(avg_len_topk): - len_part = ( - f" | avg_len(top{args.top_k} selected)=" - f"{_format_length_input_display(avg_len_topk, args.xyz_scale)}" - ) + avg_leader = _batch_avg_leader(summary, args.top_k, args.top_by_length) + weight_part = _batch_aggregate_log_suffix(summary) if fish_key and fish_key in labels: actual_g = float(labels[fish_key]) - comparison = compare_pred_vs_actual( - pred_g=float(summary["avg_predicted_weight_g"]), actual_g=actual_g + pred_final = _final_pred_g_for_compare(summary) + comparison = compare_pred_vs_actual(pred_g=pred_final, actual_g=actual_g) + topk_note = _batch_topk_avg_better_than_pred_suffix( + summary, pred_final, actual_g ) + star = _batch_topk_max_stable_star(summary) print( - f"{rel}: avg({top_label})={summary['avg_predicted_weight_g']:.2f} g ({used_info}){len_part} | " - f"actual={actual_g:.2f} g | diff%={comparison['diff_pct']:.2f}%" + 
f"{rel}: {avg_leader}{_mean_column_g_for_log(summary):.2f} g{weight_part} | " + f"actual={actual_g:.2f} g | diff%={comparison['diff_pct']:.2f}%{topk_note}{star}" ) else: + star = _batch_topk_max_stable_star(summary) print( - f"{rel}: avg({top_label})={summary['avg_predicted_weight_g']:.2f} g ({used_info}){len_part}" + f"{rel}: {avg_leader}{_mean_column_g_for_log(summary):.2f} g{weight_part}{star}" ) - results.append({ "id": rel, "cloud_dir": str(cloud_dir), @@ -1090,6 +1759,13 @@ def main() -> None: "top_k": args.top_k, "top_by_length": args.top_by_length, "length_switch_mm": float(args.length_switch_mm), + "max_length_mm": max_length_eff, + "min_length_width_ratio": min_lw_eff, + "length_quality_cv_threshold_pct": float(args.length_quality_cv_threshold_pct), + "length_quality_max_span_mm": length_span_eff, + "average_all_after_filter": bool(args.average_all_after_filter), + "average_all_fallback_max_if_mean_over_g": avg_all_fb_eff, + "mean_pool_fallback_max_if_over_g": mean_pool_fb_eff, "cloud_folders": meta_cloud_folders, "labels_json": str(labels_path) if labels_path else None, }, diff --git a/fish_api/app/services/measure.py b/fish_api/app/services/measure.py index b07005d..eef7bf0 100644 --- a/fish_api/app/services/measure.py +++ b/fish_api/app/services/measure.py @@ -63,6 +63,8 @@ def _predict_weigth_from_svo2_extra_args(settings: Settings) -> List[str]: str(settings.predict_minute_interval_sec), ] ) + if settings.predict_show_large_labels_at_top_right: + out.append("--show-large-labels-at-top-right") if not settings.measure_reuse_existing_clouds: out.append("--no-reuse-existing-clouds") return out diff --git a/fish_api/app/settings.py b/fish_api/app/settings.py index 22692ac..1d4bad1 100644 --- a/fish_api/app/settings.py +++ b/fish_api/app/settings.py @@ -95,6 +95,8 @@ class Settings(BaseSettings): #: 为 True 时 fish_video 内联 DGCNN + 预览叠加(更重;需 fish_video 已支持) predict_fish_video_weight_overlay: bool = False predict_minute_interval_sec: float = 60.0 + #: 为 True 
时在视频右上角显示大型 weight/length 标签(10倍字体),便于查看真实/相机生成视频的标签数据 + predict_show_large_labels_at_top_right: bool = False action_checkpoint: Optional[str] = None action_clips_per_video: int = 8