#!/usr/bin/env python3 """ SVO2 → point clouds via **fish_video_weight_evaluation.py** (subprocess), then weight via **weight_estimator/test_dgcnn_weight_estimator.py** (subprocess). Default ``--conf`` is 0.8 (same as ``run_predict_from_svo2_fish9.sh``); ``run_fish_evaluation_simple.sh`` uses 0.5 — override with ``--conf``. FishServer passes ``--conf`` from ``MEASURE_YOLO_CONF`` (default 0.8). Outputs: ``//cloud/*.ply`` — from fish_video_weight_evaluation ``//weight_prediction.json`` — merged SVO info + DGCNN JSON (single-SVO mode) ``/weight_prediction.json`` — when ``--batch-svo-folder`` has **multiple** ``*.svo2``: one DGCNN run over **all** PLYs from every ``/cloud/`` (top-K applies to the union) ``//dgcnn_test_output_top.json`` — raw test_dgcnn output (per-SVO or combined under ``/``) With ``--reuse-existing-clouds`` (default), skips the fish subprocess when ``cloud/*.ply`` already exists; only runs test_dgcnn. """ from __future__ import annotations import argparse import json import subprocess import sys import math from pathlib import Path if not hasattr(argparse, "BooleanOptionalAction"): class _BooleanOptionalAction(argparse.Action): def __init__(self, option_strings, dest, default=None, type=None, choices=None, required=False, help=None, metavar=None): _option_strings = [] for opt in option_strings: _option_strings.append(opt) if opt.startswith("--"): _option_strings.append("--no-" + opt[2:]) super().__init__(option_strings=_option_strings, dest=dest, nargs=0, default=default, type=type, choices=choices, required=required, help=help, metavar=metavar) def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, not option_string.startswith("--no-")) argparse.BooleanOptionalAction = _BooleanOptionalAction from typing import Any, Dict, List, Optional, Tuple import torch REPO_ROOT = Path(__file__).resolve().parent if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) def _fish_video_script() -> Path: return REPO_ROOT / "fish_video_weight_evaluation.py" def _dgcnn_test_script() -> Path: return REPO_ROOT / "weight_estimator" / "test_dgcnn_weight_estimator.py" def _torch_device_to_test_arg(device: torch.device) -> str: return "cpu" if device.type == "cpu" else "cuda" def _collect_existing_clouds(cloud_dir: Path) -> List[Path]: return sorted(cloud_dir.glob("*.ply")) def has_cached_clouds(output_base: Path, svo_path: Path) -> bool: svo_path = svo_path.expanduser().resolve() cloud_dir = output_base / svo_path.stem / "cloud" return len(_collect_existing_clouds(cloud_dir)) > 0 def _run_fish_video_evaluation_subprocess(args: argparse.Namespace, *, batch_folder: Optional[Path]) -> None: fvs = _fish_video_script() if not fvs.is_file(): raise FileNotFoundError(f"Missing: {fvs}") out_parent = Path(args.save_output).expanduser().resolve() cmd: List[str] = [ sys.executable, str(fvs), "--save-output", str(out_parent), "--yolo-model", args.yolo_model, "--conf", str(args.conf), "--imgsz", str(args.imgsz), "--sam-device", args.sam_device, "--max-frames", str(args.max_frames), "--frame-stride", str(args.frame_stride), ] if batch_folder is not None: cmd.extend(["--batch-svo-folder", str(batch_folder.resolve())]) else: cmd.extend(["--svo", str(Path(args.svo).expanduser().resolve())]) if args.filter_pointcloud: cmd.append("--filter-pointcloud") if args.use_density_filter: cmd.append("--use-density-filter") if args.use_clustering_filter: cmd.append("--use-clustering-filter") if args.use_pointcloud_classifier: cmd.append("--use-pointcloud-classifier") if args.pointcloud_classifier: cmd.extend(["--pointcloud-classifier", args.pointcloud_classifier]) cmd.extend( ["--pointcloud-classifier-threshold", str(args.pointcloud_classifier_threshold)] ) if args.use_flatness_filter: cmd.append("--use-flatness-filter") cmd.extend(["--flatness-threshold", str(args.flatness_threshold)]) if getattr(args, "show_large_labels_at_top_right", False): cmd.append("--show-large-labels-at-top-right") print(f"Invoking fish_video_weight_evaluation.py:\n {' '.join(cmd)}") proc = subprocess.run(cmd, cwd=str(REPO_ROOT)) if proc.returncode != 0: raise RuntimeError(f"fish_video_weight_evaluation.py exited with code {proc.returncode}") def _run_generate_video_with_labels_subprocess( *, args: argparse.Namespace, svo_path: Path, output_dir: Path, weight_json: Path, ) -> None: """Run generate_video_with_labels.py to create preview video with weight/length labels. Called after DGCNN completes to (re)generate the preview video with actual weight/length overlaid on detection boxes. """ script = REPO_ROOT / "generate_video_with_labels.py" if not script.is_file(): print(f"Warning: generate_video_with_labels.py not found at {script}, skipping video generation") return cmd = [ sys.executable, str(script), "--svo", str(svo_path.expanduser().resolve()), "--save-output", str(output_dir.expanduser().resolve()), "--yolo-model", args.yolo_model, "--conf", str(args.conf), "--imgsz", str(args.imgsz), "--frame-stride", str(args.frame_stride), "--sam-device", str(args.sam_device), "--weight-json", str(weight_json.expanduser().resolve()), ] if getattr(args, "show_large_labels_at_top_right", False): cmd.append("--show-large-labels-at-top-right") if getattr(args, "summary_star", False): cmd.append("--summary-star") print(f"Invoking generate_video_with_labels.py:\n {' '.join(cmd)}") proc = subprocess.run(cmd, cwd=str(REPO_ROOT)) if proc.returncode != 0: print(f"Warning: generate_video_with_labels.py exited with code {proc.returncode}") def _run_test_dgcnn_weight_estimator_subprocess( *, cloud_dir: Optional[Path] = None, ply_list_file: Optional[Path] = None, checkpoint: Path, device: torch.device, out_dir: Path, num_points: int, xyz_scale: float, top_k: int, top_by_length: bool, length_switch_mm: float, topk_length: Optional[int], remove_outliers: bool, outlier_method: str, labels_json: Optional[str], weight_max_length_mm: float = 400.0, weight_min_length_width_ratio: float = 1.5, weight_length_quality_cv_threshold_pct: float = 15.0, weight_length_quality_max_span_mm: float = 130.0, weight_average_all_after_filter: bool = False, weight_average_all_fallback_max_if_mean_over_g: float = 400.0, weight_mean_pool_fallback_max_if_over_g: float = 440.0, ) -> Tuple[Path, Dict[str, Any]]: if (cloud_dir is None) == (ply_list_file is None): raise ValueError("Exactly one of cloud_dir or ply_list_file must be set") script = _dgcnn_test_script() if not script.is_file(): raise FileNotFoundError(f"Missing: {script}") json_base = out_dir / "dgcnn_test_output.json" labels_path = labels_json if labels_json else str(out_dir / "__no_labels__.json") cmd: List[str] = [ sys.executable, str(script), "--checkpoint", str(checkpoint.expanduser().resolve()), "--device", _torch_device_to_test_arg(device), "--num-points", str(num_points), "--xyz-scale", str(xyz_scale), "--top-k", str(top_k), "--output-json", str(json_base), "--labels-json", labels_path, ] if ply_list_file is not None: cmd.extend(["--ply-list-file", str(ply_list_file.resolve())]) else: cmd.extend(["--ply-folder", str(cloud_dir.resolve())]) if top_by_length: cmd.append("--top-by-length") else: cmd.append("--no-top-by-length") cmd.extend(["--length-switch-mm", str(length_switch_mm)]) if topk_length is not None: cmd.extend(["--topk-length", str(topk_length)]) if remove_outliers: cmd.append("--remove-outliers") cmd.extend(["--outlier-method", outlier_method]) cmd.extend(["--max-length-mm", str(weight_max_length_mm)]) cmd.extend(["--min-length-width-ratio", str(weight_min_length_width_ratio)]) cmd.extend(["--length-quality-cv-threshold-pct", str(weight_length_quality_cv_threshold_pct)]) cmd.extend(["--length-quality-max-span-mm", str(weight_length_quality_max_span_mm)]) if weight_average_all_after_filter: cmd.append("--average-all-after-filter") fb_thr = weight_average_all_fallback_max_if_mean_over_g if fb_thr is not None and float(fb_thr) > 0: cmd.extend( ["--average-all-fallback-max-if-mean-over-g", str(float(fb_thr))] ) mp_fb = weight_mean_pool_fallback_max_if_over_g if mp_fb is not None and float(mp_fb) > 0: cmd.extend( ["--mean-pool-fallback-max-if-over-g", str(float(mp_fb))] ) print(f"Invoking test_dgcnn_weight_estimator.py:\n {' '.join(cmd)}") proc = subprocess.run(cmd, cwd=str(REPO_ROOT)) if proc.returncode != 0: raise RuntimeError(f"test_dgcnn_weight_estimator.py exited with code {proc.returncode}") suffix = f"_top{top_k}_by_length" if top_by_length else f"_top{top_k}" written = json_base.parent / f"{json_base.stem}{suffix}{json_base.suffix}" if not written.is_file(): raise FileNotFoundError(f"Expected DGCNN output JSON at {written}") data = json.loads(written.read_text(encoding="utf-8")) return written, data def _merge_weight_prediction_json( *, svo_path: Path, svo_name: str, out_dir: Path, cloud_dir: Path, dgcnn_json_path: Path, dgcnn_data: Dict[str, Any], ) -> Dict[str, Any]: summary = dgcnn_data.get("summary") or {} skipped = bool(summary.get("skipped")) avg_kg = summary.get("avg_predicted_weight_kg") avg_g = summary.get("avg_predicted_weight_g") pred_kg = summary.get("pred_weight_kg") pred_g = summary.get("pred_weight_g") pred_rule = summary.get("pred_weight_rule") return { "svo": str(svo_path), "svo_name": svo_name, "output_dir": str(out_dir), "cloud_dir": str(cloud_dir), "num_ply_on_disk": len(_collect_existing_clouds(cloud_dir)), "dgcnn_output_json": str(dgcnn_json_path), "skipped_weight": skipped, "skip_reason": summary.get("skip_reason"), "num_clouds_used": summary.get("num_files_used_for_avg"), "num_ply_predicted": summary.get("num_files_predicted"), "avg_predicted_weight_kg": None if skipped else avg_kg, "avg_predicted_weight_g": None if skipped else avg_g, "pred_weight_g": None if skipped else pred_g, "pred_weight_kg": None if skipped else pred_kg, "pred_weight_rule": None if skipped else pred_rule, "dgcnn_meta": dgcnn_data.get("meta"), "dgcnn_summary": summary, "weight_summary": summary, "per_cloud": dgcnn_data.get("per_file") or [], "comparison": dgcnn_data.get("comparison"), } def _merge_weight_prediction_json_combined( *, svo_paths: List[Path], output_base: Path, ply_list_path: Path, dgcnn_json_path: Path, dgcnn_data: Dict[str, Any], ) -> Dict[str, Any]: summary = dgcnn_data.get("summary") or {} skipped = bool(summary.get("skipped")) avg_kg = summary.get("avg_predicted_weight_kg") avg_g = summary.get("avg_predicted_weight_g") pred_kg = summary.get("pred_weight_kg") pred_g = summary.get("pred_weight_g") pred_rule = summary.get("pred_weight_rule") cloud_dirs = [output_base / p.stem / "cloud" for p in svo_paths] n_ply = sum(len(_collect_existing_clouds(d)) for d in cloud_dirs) return { "combined": True, "svo_paths": [str(p) for p in svo_paths], "svo_names": [p.stem for p in svo_paths], "output_dir": str(output_base), "cloud_dirs": [str(d) for d in cloud_dirs], "ply_list_file": str(ply_list_path), "num_ply_on_disk": n_ply, "dgcnn_output_json": str(dgcnn_json_path), "skipped_weight": skipped, "skip_reason": summary.get("skip_reason"), "num_clouds_used": summary.get("num_files_used_for_avg"), "num_ply_predicted": summary.get("num_files_predicted"), "avg_predicted_weight_kg": None if skipped else avg_kg, "avg_predicted_weight_g": None if skipped else avg_g, "pred_weight_g": None if skipped else pred_g, "pred_weight_kg": None if skipped else pred_kg, "pred_weight_rule": None if skipped else pred_rule, "dgcnn_meta": dgcnn_data.get("meta"), "dgcnn_summary": summary, "weight_summary": summary, "per_cloud": dgcnn_data.get("per_file") or [], "comparison": dgcnn_data.get("comparison"), } def _sanitize_for_json(obj: Any) -> Any: if isinstance(obj, dict): return {k: _sanitize_for_json(v) for k, v in obj.items()} if isinstance(obj, list): return [_sanitize_for_json(v) for v in obj] if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)): return None return obj def run_weight_prediction_for_svo( svo_path: Path, output_base: Path, weight_checkpoint: Path, weight_device: torch.device, num_points_model: int, weight_top_k: int, weight_top_by_length: bool, weight_length_switch_mm: float, weight_topk_length: Optional[int], weight_remove_outliers: bool, weight_outlier_method: str, weight_xyz_scale: float, weight_labels_json: Optional[str], weight_max_length_mm: float = 400.0, weight_min_length_width_ratio: float = 1.5, weight_length_quality_cv_threshold_pct: float = 15.0, weight_length_quality_max_span_mm: float = 130.0, weight_average_all_after_filter: bool = False, weight_average_all_fallback_max_if_mean_over_g: float = 400.0, weight_mean_pool_fallback_max_if_over_g: float = 440.0, ) -> Dict[str, Any]: svo_path = svo_path.expanduser().resolve() if not svo_path.exists(): raise FileNotFoundError(f"svo not found: {svo_path}") svo_name = svo_path.stem out_dir = output_base / svo_name cloud_dir = out_dir / "cloud" plys = _collect_existing_clouds(cloud_dir) if not plys: raise RuntimeError( f"No .ply files in {cloud_dir}. Run without --reuse-existing-clouds or run " f"fish_video_weight_evaluation.py first." ) dgcnn_path, dgcnn_data = _run_test_dgcnn_weight_estimator_subprocess( cloud_dir=cloud_dir, ply_list_file=None, checkpoint=weight_checkpoint, device=weight_device, out_dir=out_dir, num_points=num_points_model, xyz_scale=weight_xyz_scale, top_k=weight_top_k, top_by_length=weight_top_by_length, length_switch_mm=weight_length_switch_mm, topk_length=weight_topk_length, remove_outliers=weight_remove_outliers, outlier_method=weight_outlier_method, labels_json=weight_labels_json, weight_max_length_mm=weight_max_length_mm, weight_min_length_width_ratio=weight_min_length_width_ratio, weight_length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct, weight_length_quality_max_span_mm=weight_length_quality_max_span_mm, weight_average_all_after_filter=weight_average_all_after_filter, weight_average_all_fallback_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g, weight_mean_pool_fallback_max_if_over_g=weight_mean_pool_fallback_max_if_over_g, ) result = _merge_weight_prediction_json( svo_path=svo_path, svo_name=svo_name, out_dir=out_dir, cloud_dir=cloud_dir, dgcnn_json_path=dgcnn_path, dgcnn_data=dgcnn_data, ) (out_dir / "weight_prediction.json").write_text( json.dumps(_sanitize_for_json(result), indent=2, ensure_ascii=False), encoding="utf-8", ) return result def run_weight_prediction_combined_svos( svo_paths: List[Path], output_base: Path, weight_checkpoint: Path, weight_device: torch.device, num_points_model: int, weight_top_k: int, weight_top_by_length: bool, weight_length_switch_mm: float, weight_topk_length: Optional[int], weight_remove_outliers: bool, weight_outlier_method: str, weight_xyz_scale: float, weight_labels_json: Optional[str], weight_max_length_mm: float = 400.0, weight_min_length_width_ratio: float = 1.5, weight_length_quality_cv_threshold_pct: float = 15.0, weight_length_quality_max_span_mm: float = 130.0, weight_average_all_after_filter: bool = False, weight_average_all_fallback_max_if_mean_over_g: float = 400.0, weight_mean_pool_fallback_max_if_over_g: float = 440.0, ) -> Dict[str, Any]: """One DGCNN run over all ``//cloud/*.ply`` (top-K / by-length applies to the union).""" output_base = output_base.expanduser().resolve() svo_paths = [p.expanduser().resolve() for p in svo_paths] all_plys: List[Path] = [] for svo in svo_paths: all_plys.extend(_collect_existing_clouds(output_base / svo.stem / "cloud")) if not all_plys: raise RuntimeError( f"No .ply in any cloud folder under {output_base} for SVOs: {[p.stem for p in svo_paths]}" ) out_dir = output_base list_path = out_dir / "_combined_ply_list.txt" list_path.write_text("\n".join(str(p.resolve()) for p in sorted(all_plys)) + "\n", encoding="utf-8") dgcnn_path, dgcnn_data = _run_test_dgcnn_weight_estimator_subprocess( cloud_dir=None, ply_list_file=list_path, checkpoint=weight_checkpoint, device=weight_device, out_dir=out_dir, num_points=num_points_model, xyz_scale=weight_xyz_scale, top_k=weight_top_k, top_by_length=weight_top_by_length, length_switch_mm=weight_length_switch_mm, topk_length=weight_topk_length, remove_outliers=weight_remove_outliers, outlier_method=weight_outlier_method, labels_json=weight_labels_json, weight_max_length_mm=weight_max_length_mm, weight_min_length_width_ratio=weight_min_length_width_ratio, weight_length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct, weight_length_quality_max_span_mm=weight_length_quality_max_span_mm, weight_average_all_after_filter=weight_average_all_after_filter, weight_average_all_fallback_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g, weight_mean_pool_fallback_max_if_over_g=weight_mean_pool_fallback_max_if_over_g, ) result = _merge_weight_prediction_json_combined( svo_paths=svo_paths, output_base=output_base, ply_list_path=list_path, dgcnn_json_path=dgcnn_path, dgcnn_data=dgcnn_data, ) (out_dir / "weight_prediction.json").write_text( json.dumps(_sanitize_for_json(result), indent=2, ensure_ascii=False), encoding="utf-8", ) return result def main() -> None: parser = argparse.ArgumentParser( description="SVO2 → fish_video_weight_evaluation.py → test_dgcnn_weight_estimator.py" ) src = parser.add_mutually_exclusive_group(required=True) src.add_argument("--svo", type=str, default="", help="Path to a single .svo2 file") src.add_argument("--batch-svo-folder", type=str, default="", help="Folder containing *.svo2 files") parser.add_argument( "--save-output", type=str, default="output_weight_estimator", help="Base output directory (same as fish_video --save-output parent)", ) # fish_video_weight_evaluation (aligned with run_fish_evaluation_simple.sh) parser.add_argument( "--yolo-model", type=str, default=str(REPO_ROOT / "runs/train/fish_detection_20251127_104658/weights/best.pt"), ) parser.add_argument( "--conf", type=float, default=0.8, help="YOLO confidence (fish9 script uses 0.8; simple.sh uses 0.5)", ) parser.add_argument("--imgsz", type=int, default=640) parser.add_argument( "--sam-device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="SAM device for fish_video; also used as --device for test_dgcnn (cuda|cpu)", ) parser.add_argument("--max-frames", type=int, default=0, help="0 = all frames") parser.add_argument( "--frame-stride", type=int, default=1, metavar="N", help="Passed to fish_video: only YOLO/SAM/PLY every N-th frame (default: 1)", ) parser.add_argument("--filter-pointcloud", action="store_true") parser.add_argument("--use-density-filter", action="store_true") parser.add_argument("--use-clustering-filter", action="store_true") parser.add_argument("--use-pointcloud-classifier", action="store_true") parser.add_argument("--pointcloud-classifier", type=str, default=None) parser.add_argument("--pointcloud-classifier-threshold", type=float, default=0.7) parser.add_argument("--use-flatness-filter", action="store_true") parser.add_argument("--flatness-threshold", type=float, default=70.0) # test_dgcnn_weight_estimator parser.add_argument( "--weight-checkpoint", type=str, default=str(REPO_ROOT / "weight_estimator/runs/dgcnn_20260312_171043/best.pt"), ) parser.add_argument("--num-points-model", type=int, default=768) parser.add_argument("--weight-top-k", type=int, default=5) parser.add_argument( "--weight-top-by-length", action=argparse.BooleanOptionalAction, default=True, help="Length-first top-K (default: on). Use --no-weight-top-by-length for weight-only order.", ) parser.add_argument( "--weight-length-switch-mm", type=float, default=319.0, help="If mean top-K length (mm) exceeds this, use top-K by predicted weight (default: 320).", ) parser.add_argument("--weight-topk-length", type=int, default=None) parser.add_argument("--weight-remove-outliers", action="store_true") parser.add_argument( "--weight-outlier-method", type=str, default="iqr", choices=["iqr", "zscore"], ) parser.add_argument("--weight-xyz-scale", type=float, default=0.001) parser.add_argument("--weight-labels-json", type=str, default=None) parser.add_argument( "--weight-max-length-mm", type=float, default=400.0, help="Passed to test_dgcnn --max-length-mm: exclude length > this from aggregation (0 = off).", ) parser.add_argument( "--weight-min-length-width-ratio", type=float, default=1.5, help="Passed to test_dgcnn --min-length-width-ratio; 0 disables.", ) parser.add_argument( "--weight-average-all-after-filter", action=argparse.BooleanOptionalAction, default=False, help="Passed to test_dgcnn --average-all-after-filter: mean all PLYs after filters (no top-K).", ) parser.add_argument( "--weight-average-all-fallback-max-if-mean-over-g", type=float, default=400.0, help="Passed to test_dgcnn --average-all-fallback-max-if-mean-over-g; 0 disables (default: 400).", ) parser.add_argument( "--weight-mean-pool-fallback-max-if-over-g", type=float, default=440.0, help="Passed to test_dgcnn --mean-pool-fallback-max-if-over-g; 0 disables (default: 440).", ) parser.add_argument( "--weight-length-quality-cv-threshold-pct", type=float, default=15.0, help="Passed to test_dgcnn --length-quality-cv-threshold-pct (length variance hint).", ) parser.add_argument( "--weight-length-quality-max-span-mm", type=float, default=130.0, help="Passed to test_dgcnn --length-quality-max-span-mm; 0 disables span rule.", ) parser.add_argument( "--reuse-existing-clouds", action="store_true", default=True, help="Skip fish_video if //cloud/*.ply already exists", ) parser.add_argument("--no-reuse-existing-clouds", action="store_false", dest="reuse_existing_clouds") parser.add_argument( "--show-large-labels-at-top-right", action="store_true", help="Show large weight/length labels (10x font) at top right corner for real/camera generated videos.", ) parser.add_argument( "--summary-star", action="store_true", default=False, help="Pass to generate_video_with_labels: whether the summary line should draw *.", ) args = parser.parse_args() if args.frame_stride < 1: raise SystemExit("--frame-stride must be >= 1") output_base = Path(args.save_output).expanduser().resolve() output_base.mkdir(parents=True, exist_ok=True) weight_ckpt = Path(args.weight_checkpoint).expanduser().resolve() if not weight_ckpt.is_file(): raise SystemExit(f"weight checkpoint not found: {weight_ckpt}") if args.svo: svo_paths = [Path(args.svo).expanduser().resolve()] batch_folder: Optional[Path] = None else: batch_folder = Path(args.batch_svo_folder).expanduser().resolve() svo_paths = sorted(batch_folder.glob("*.svo2")) need_fish = True if svo_paths and args.reuse_existing_clouds: need_fish = not all(has_cached_clouds(output_base, p) for p in svo_paths) if need_fish: if args.svo: _run_fish_video_evaluation_subprocess(args, batch_folder=None) else: if not batch_folder or not batch_folder.is_dir(): raise SystemExit(f"batch folder not found: {args.batch_svo_folder}") if not svo_paths: raise SystemExit(f"No .svo2 files in: {batch_folder}") _run_fish_video_evaluation_subprocess(args, batch_folder=batch_folder) else: print("All targets already have cloud/*.ply — skipping fish_video_weight_evaluation.py.") device = torch.device(args.sam_device) results: List[Dict[str, Any]] = [] if len(svo_paths) > 1: names = ", ".join(p.name for p in svo_paths) print(f"\n=== Weight prediction: combined ({len(svo_paths)} SVOs) → top-{args.weight_top_k} over all PLYs ===") print(f" {names}") try: results.append( run_weight_prediction_combined_svos( svo_paths=svo_paths, output_base=output_base, weight_checkpoint=weight_ckpt, weight_device=device, num_points_model=args.num_points_model, weight_top_k=args.weight_top_k, weight_top_by_length=args.weight_top_by_length, weight_length_switch_mm=args.weight_length_switch_mm, weight_topk_length=args.weight_topk_length, weight_remove_outliers=args.weight_remove_outliers, weight_outlier_method=args.weight_outlier_method, weight_xyz_scale=args.weight_xyz_scale, weight_labels_json=args.weight_labels_json, weight_max_length_mm=args.weight_max_length_mm, weight_min_length_width_ratio=args.weight_min_length_width_ratio, weight_length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct, weight_length_quality_max_span_mm=args.weight_length_quality_max_span_mm, weight_average_all_after_filter=args.weight_average_all_after_filter, weight_average_all_fallback_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g, weight_mean_pool_fallback_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g, ) ) except Exception as e: results.append( {"combined": True, "svo_paths": [str(p) for p in svo_paths], "error": str(e)} ) else: for svo in svo_paths: label = svo.name print(f"\n=== Weight prediction: {label} ===") try: results.append( run_weight_prediction_for_svo( svo_path=svo, output_base=output_base, weight_checkpoint=weight_ckpt, weight_device=device, num_points_model=args.num_points_model, weight_top_k=args.weight_top_k, weight_top_by_length=args.weight_top_by_length, weight_length_switch_mm=args.weight_length_switch_mm, weight_topk_length=args.weight_topk_length, weight_remove_outliers=args.weight_remove_outliers, weight_outlier_method=args.weight_outlier_method, weight_xyz_scale=args.weight_xyz_scale, weight_labels_json=args.weight_labels_json, weight_max_length_mm=args.weight_max_length_mm, weight_min_length_width_ratio=args.weight_min_length_width_ratio, weight_length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct, weight_length_quality_max_span_mm=args.weight_length_quality_max_span_mm, weight_average_all_after_filter=args.weight_average_all_after_filter, weight_average_all_fallback_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g, weight_mean_pool_fallback_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g, ) ) except Exception as e: results.append({"svo": str(svo), "error": str(e)}) # Generate preview video with weight/length labels (after DGCNN, for all runs) print("\n=== Generating preview video with weight/length labels ===") combined_json = output_base / "weight_prediction.json" for svo in svo_paths: svo_name = svo.stem svo_output_dir = output_base / svo_name weight_json = svo_output_dir / "weight_prediction.json" if not weight_json.exists(): weight_json = svo_output_dir / "weight_estimation" / "weight_estimation_results.json" if not weight_json.exists() and combined_json.exists(): weight_json = combined_json if weight_json.exists(): try: _run_generate_video_with_labels_subprocess( args=args, svo_path=svo, output_dir=svo_output_dir, weight_json=weight_json, ) except Exception as e: print(f" Warning: Failed to generate video for {svo.name}: {e}") else: print(f" Warning: No weight JSON found for {svo.name}, skipping video generation") summary_path = output_base / "weight_predictions_summary.json" summary_path.write_text( json.dumps(_sanitize_for_json(results), indent=2, ensure_ascii=False), encoding="utf-8", ) print(f"\nSaved summary: {summary_path}") if len(results) == 1 and "error" not in results[0]: r0 = results[0] avg_g = r0.get("avg_predicted_weight_g") pred_g = r0.get("pred_weight_g") out_g = pred_g if pred_g is not None else avg_g if out_g is not None: label = "pred_weight" if pred_g is not None else "avg_predicted_weight" print(f"Final predicted weight (test_dgcnn, {label}): {float(out_g):.2f} g") else: print("Final predicted weight: N/A (skipped or no valid point clouds)") if __name__ == "__main__": main()