#!/usr/bin/env python3
"""
Run PointNet++ weight regression on a folder of PLY point clouds.

For each PLY:
  - load points
  - scale XYZ by --xyz-scale (default 0.001)
  - center to origin (subtract centroid)
  - sample fixed N points (default 768)
  - predict weight (kg)

Final prediction = mean of the top-K highest per-PLY predictions (the maximum
is used instead when fewer than 5 candidate PLYs remain after optional
outlier removal).

Example:
  python weight_estimator/test_pointnet_weight_estimator.py \\
    --checkpoint weight_estimator/runs/20260124_123456/best.pt \\
    --ply-folder /path/to/xxxx/cloud
"""

from __future__ import annotations

import argparse
import json
import re
import sys
import zlib
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

try:
    import open3d as o3d
except ImportError:
    # Open3D is only needed to read PLY files. Keeping the import optional lets
    # `--help` and the pure helpers below work without it; `main` and
    # `load_points_from_ply` fail with an explicit message when it is missing.
    o3d = None

# Ensure repo root is on sys.path so we can import the training model definition
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# Folders with fewer PLYs than this are skipped (considered inaccurate).
_MIN_PLYS_REQUIRED = 5
# With fewer candidates than this, the max prediction is used instead of the mean.
_MIN_CANDIDATES_FOR_MEAN = 5


def load_points_from_ply(ply_path: Path) -> np.ndarray:
    """
    Load point coordinates from a PLY file as a float32 array.

    The file is first read as a point cloud; if that yields no points it is
    read as a triangle mesh and its vertices are returned instead.

    Raises:
        ImportError: if Open3D is not installed.
        ValueError: if neither points nor vertices are found.
    """
    if o3d is None:
        raise ImportError("open3d is required to read PLY files (pip install open3d)")

    pcd = o3d.io.read_point_cloud(str(ply_path))
    if len(pcd.points) > 0:
        return np.asarray(pcd.points, dtype=np.float32)

    mesh = o3d.io.read_triangle_mesh(str(ply_path))
    if len(mesh.vertices) > 0:
        return np.asarray(mesh.vertices, dtype=np.float32)

    raise ValueError(f"No points/vertices found in: {ply_path}")


def sample_points_deterministic(points: np.ndarray, num_points: int, seed: int) -> np.ndarray:
    """
    Pick exactly `num_points` rows of `points` using a seeded RNG.

    Rows are drawn without replacement when enough are available, otherwise
    with replacement (so some rows repeat). The same seed always yields the
    same selection.

    Raises:
        ValueError: if `points` has no rows.
    """
    n = points.shape[0]
    if n <= 0:
        raise ValueError("Empty point cloud")
    rng = np.random.default_rng(seed)
    idx = rng.choice(n, size=num_points, replace=n < num_points)
    return points[idx]


def estimate_length_major_axis(points: np.ndarray) -> float:
    """
    Estimate point cloud "length" as the extent along the 1st PCA axis.

    Unit is the same as the input points (e.g., mm if the PLY is in mm).
    Returns NaN when the input is not a 2-D array with at least 3 rows and
    3 columns, or when the SVD fails.
    """
    if points is None or points.ndim != 2 or points.shape[0] < 3 or points.shape[1] < 3:
        return float("nan")

    pts = points[:, :3].astype(np.float32, copy=False)
    pts = pts - pts.mean(axis=0, keepdims=True)
    try:
        # SVD of the centered points: the first right-singular vector is the
        # first principal axis.
        _u, _s, vt = np.linalg.svd(pts, full_matrices=False)
    except (np.linalg.LinAlgError, ValueError):
        return float("nan")
    proj = pts @ vt[0]
    return float(np.max(proj) - np.min(proj))


def _field_value(item: Dict, field: str) -> float:
    """Return `item[field]` as a float, or NaN when missing or not numeric."""
    try:
        return float(item.get(field, float("nan")))
    except (TypeError, ValueError):
        return float("nan")


def filter_outliers_iqr(
    per_file: List[Dict],
    field: str = "length_input",
    iqr_factor: float = 1.5,
) -> Tuple[List[Dict], List[Dict], Dict]:
    """
    Filter outliers using the IQR (Interquartile Range) method.

    Items whose `field` value is missing, non-numeric or non-finite are never
    treated as outliers; they are kept in the filtered list.

    Args:
        per_file: List of per-file prediction dicts
        field: Field to use for outlier detection (default: "length_input")
        iqr_factor: IQR multiplier for outlier bounds (default: 1.5)

    Returns:
        (filtered_list, outliers_list, stats_dict). With fewer than 4 finite
        values the input list is returned unchanged with a note in the stats.
    """
    if not per_file:
        return [], [], {"error": "empty input"}

    item_values = [_field_value(it, field) for it in per_file]
    values = [v for v in item_values if np.isfinite(v)]

    if len(values) < 4:
        # Not enough data for IQR, return all as valid
        return per_file, [], {"note": "too few samples for IQR", "num_samples": len(values)}

    values_arr = np.array(values, dtype=np.float64)
    q1 = float(np.percentile(values_arr, 25))
    q3 = float(np.percentile(values_arr, 75))
    iqr = q3 - q1
    lower_bound = q1 - iqr_factor * iqr
    upper_bound = q3 + iqr_factor * iqr

    filtered: List[Dict] = []
    outliers: List[Dict] = []
    for it, v in zip(per_file, item_values):
        # Keep items without a valid field value (don't filter them)
        if not np.isfinite(v) or lower_bound <= v <= upper_bound:
            filtered.append(it)
        else:
            outliers.append(it)

    stats = {
        "field": field,
        "iqr_factor": float(iqr_factor),
        "q1": q1,
        "q3": q3,
        "iqr": iqr,
        "lower_bound": lower_bound,
        "upper_bound": upper_bound,
        "num_input": len(per_file),
        "num_filtered": len(filtered),
        "num_outliers": len(outliers),
    }
    return filtered, outliers, stats


def filter_outliers_zscore(
    per_file: List[Dict],
    field: str = "length_input",
    zscore_threshold: float = 2.5,
) -> Tuple[List[Dict], List[Dict], Dict]:
    """
    Filter outliers using Z-score method (for normally distributed data).

    Items whose `field` value is missing, non-numeric or non-finite are never
    treated as outliers; they are kept in the filtered list.

    Args:
        per_file: List of per-file prediction dicts
        field: Field to use for outlier detection (default: "length_input")
        zscore_threshold: Z-score threshold (default: 2.5, ~1% tails)

    Returns:
        (filtered_list, outliers_list, stats_dict). With fewer than 3 finite
        values, or (near-)zero standard deviation, the input list is returned
        unchanged with a note in the stats.
    """
    if not per_file:
        return [], [], {"error": "empty input"}

    item_values = [_field_value(it, field) for it in per_file]
    values = [v for v in item_values if np.isfinite(v)]

    if len(values) < 3:
        return per_file, [], {"note": "too few samples for z-score", "num_samples": len(values)}

    values_arr = np.array(values, dtype=np.float64)
    mean_val = float(np.mean(values_arr))
    std_val = float(np.std(values_arr))

    if std_val < 1e-9:
        # No variation, keep all
        return per_file, [], {"note": "zero std", "mean": mean_val}

    filtered: List[Dict] = []
    outliers: List[Dict] = []
    for it, v in zip(per_file, item_values):
        if not np.isfinite(v) or abs(v - mean_val) / std_val <= zscore_threshold:
            filtered.append(it)
        else:
            outliers.append(it)

    stats = {
        "field": field,
        "zscore_threshold": float(zscore_threshold),
        "mean": mean_val,
        "std": std_val,
        "num_input": len(per_file),
        "num_filtered": len(filtered),
        "num_outliers": len(outliers),
    }
    return filtered, outliers, stats


def summarize_predictions(
    per_file: List[Dict],
    num_files: Optional[int] = None,
    top_k: int = 5,
    remove_outliers: bool = False,
    outlier_method: str = "iqr",
    outlier_field: str = "length_input",
    iqr_factor: float = 1.5,
    zscore_threshold: float = 2.5,
) -> Dict:
    """
    Aggregate per-PLY predictions into a summary dict.

    Steps: optional outlier removal, then rank the remaining candidates by
    predicted weight and aggregate the `top_k` highest (mean, or max when
    fewer than 5 candidates remain). Falls back to the mean over all
    predictions when nothing is selected or the aggregate is not finite.

    Side effects on the dicts in `per_file`: sets "is_outlier" to True for
    removed outliers and sets "used_for_prediction" on every item.

    Args:
        per_file: Per-PLY dicts with at least "ply", "predicted_weight_kg",
            "predicted_weight_g" and (optionally) "length_input".
        num_files: Number of input files to report (default: len(per_file)).
        top_k: Number of highest predictions to aggregate; values below 0 are
            treated as 0.
        remove_outliers: If True, remove outliers before aggregating.
        outlier_method: "iqr" or "zscore"; anything else removes nothing and
            records an error in the outlier stats.
        outlier_field: Field used for outlier detection.
        iqr_factor: IQR multiplier for outlier bounds.
        zscore_threshold: Z-score threshold for outlier detection.

    Returns:
        Summary dict (see keys below).
    """
    outlier_stats: Optional[Dict] = None
    candidates = per_file
    num_outliers_removed = 0

    if remove_outliers and per_file:
        if outlier_method == "iqr":
            kept, outliers, outlier_stats = filter_outliers_iqr(
                per_file, field=outlier_field, iqr_factor=iqr_factor
            )
        elif outlier_method == "zscore":
            kept, outliers, outlier_stats = filter_outliers_zscore(
                per_file, field=outlier_field, zscore_threshold=zscore_threshold
            )
        else:
            kept, outliers = per_file, []
            outlier_stats = {"error": f"unknown method: {outlier_method}"}

        # Mark outliers in the original per_file list
        outlier_paths = {it["ply"] for it in outliers}
        for it in per_file:
            if it["ply"] in outlier_paths:
                it["is_outlier"] = True

        candidates = kept
        num_outliers_removed = len(outliers)

    # mean over all predictions (before outlier removal)
    preds_kg = [float(it["predicted_weight_kg"]) for it in per_file]
    avg_kg_all = float(np.mean(preds_kg)) if preds_kg else float("nan")
    avg_g_all = avg_kg_all * 1000.0 if preds_kg else float("nan")

    # Simple strategy: use top-K highest predictions (by predicted weight)
    ranked = sorted(
        candidates,
        key=lambda d: float(d.get("predicted_weight_g", float("-inf"))),
        reverse=True,
    )
    selected = ranked[: max(int(top_k), 0)]
    preds_g_topk = [float(it["predicted_weight_g"]) for it in selected]
    use_max_instead_of_mean = len(candidates) < _MIN_CANDIDATES_FOR_MEAN
    if preds_g_topk:
        avg_g_topk = float(np.max(preds_g_topk)) if use_max_instead_of_mean else float(np.mean(preds_g_topk))
    else:
        avg_g_topk = float(avg_g_all)

    lengths = [v for v in (_field_value(it, "length_input") for it in candidates) if np.isfinite(v)]
    avg_len_input = float(np.mean(lengths)) if lengths else float("nan")
    std_len_input = float(np.std(lengths)) if len(lengths) > 1 else 0.0
    # NaN average compares False here, so the CV is NaN when no length is valid.
    cv_length = float(std_len_input / avg_len_input * 100.0) if avg_len_input > 0 else float("nan")

    # final value: top-K aggregate, falling back to the overall mean
    avg_g = float(avg_g_topk) if np.isfinite(float(avg_g_topk)) else float(avg_g_all)
    avg_kg = avg_g / 1000.0

    # Mark which items were used for the final prediction. Match on the full
    # path: files from different folders may share the same basename.
    used_paths = {it["ply"] for it in selected}
    for it in per_file:
        it["used_for_prediction"] = it["ply"] in used_paths

    return {
        "num_files": len(per_file) if num_files is None else int(num_files),
        "num_files_predicted": len(per_file),
        "num_outliers_removed": num_outliers_removed,
        "num_files_after_outlier_removal": len(candidates),
        "num_files_used_for_avg": len(selected),
        "prediction_aggregate": "max" if use_max_instead_of_mean else "mean",
        "top_k": top_k,
        "plys_used_for_prediction": [Path(it["ply"]).name for it in selected],
        "avg_predicted_weight_kg": avg_kg,
        "avg_predicted_weight_g": avg_g,
        "avg_predicted_weight_kg_all": avg_kg_all,
        "avg_predicted_weight_g_all": avg_g_all,
        "avg_length_input": avg_len_input,
        "std_length_input": std_len_input,
        "cv_length_pct": cv_length,
        "outlier_removal": {
            "enabled": remove_outliers,
            "method": outlier_method if remove_outliers else None,
            "field": outlier_field if remove_outliers else None,
            "stats": outlier_stats,
        }
        if remove_outliers
        else None,
    }


@torch.no_grad()
def predict_folder(
    model: torch.nn.Module,
    ply_files: List[Path],
    device: torch.device,
    num_points: int = 768,
    xyz_scale: float = 0.001,
) -> Tuple[List[Dict], Dict]:
    """Predict weights for `ply_files` with default aggregation settings."""
    return _predict_folder_impl(
        model=model,
        ply_files=ply_files,
        device=device,
        num_points=num_points,
        xyz_scale=xyz_scale,
        topk_length=None,
    )


@torch.no_grad()
def _predict_folder_impl(
    model: torch.nn.Module,
    ply_files: List[Path],
    device: torch.device,
    num_points: int = 768,
    xyz_scale: float = 0.001,
    topk_length: Optional[int] = None,
    topk_predictions: int = 10,
    top_k: int = 5,
    remove_outliers: bool = False,
    outlier_method: str = "iqr",
    outlier_field: str = "length_input",
    iqr_factor: float = 1.5,
    zscore_threshold: float = 2.5,
) -> Tuple[List[Dict], Dict]:
    """
    Predict weights for a list of PLY files and aggregate them.

    Args:
        model: Regressor called with a (1, 3, N) float32 tensor; its output
            must be a single-element tensor interpreted as kilograms.
        ply_files: List of PLY file paths
        device: torch device
        num_points: Number of points to sample per PLY
        xyz_scale: XYZ scaling factor applied before centering
        topk_length: Accepted for backward compatibility; currently unused.
        topk_predictions: Accepted for backward compatibility; currently unused.
        top_k: Number of highest predictions to aggregate
        remove_outliers: If True, remove outliers before aggregating
        outlier_method: "iqr" or "zscore"
        outlier_field: Field to use for outlier detection ("length_input" or "predicted_weight_g")
        iqr_factor: IQR multiplier for outlier bounds (default: 1.5)
        zscore_threshold: Z-score threshold for outlier detection (default: 2.5)

    Returns:
        (per_file_list, summary_dict). When fewer than 5 PLYs are given, no
        prediction is run and the summary has "skipped": True.
    """
    model.eval()

    # Skip folders with < 5 PLYs (inaccurate)
    if len(ply_files) < _MIN_PLYS_REQUIRED:
        summary_skipped = {
            "num_files": len(ply_files),
            # Nothing is predicted for a skipped folder.
            "num_files_predicted": 0,
            "skipped": True,
            "skip_reason": f"fewer than {_MIN_PLYS_REQUIRED} PLYs (inaccurate)",
            "avg_predicted_weight_kg": float("nan"),
            "avg_predicted_weight_g": float("nan"),
        }
        return [], summary_skipped

    per_file: List[Dict] = []
    for ply in ply_files:
        pts = load_points_from_ply(ply)
        # Length is measured on the raw (unscaled) points.
        length_input = estimate_length_major_axis(pts)

        # scale + center
        pts = pts * float(xyz_scale)
        pts = pts - pts.mean(axis=0, keepdims=True)

        # deterministic sampling per file path
        seed = int(zlib.adler32(str(ply).encode("utf-8")) & 0xFFFFFFFF)
        pts = sample_points_deterministic(pts, num_points=num_points, seed=seed)

        x = torch.from_numpy(pts).to(device=device, dtype=torch.float32)  # (N,3)
        x = x.unsqueeze(0).transpose(1, 2).contiguous()  # (1,3,N)
        pred_kg = float(model(x).item())

        per_file.append(
            {
                "ply": str(ply),
                "predicted_weight_kg": pred_kg,
                "predicted_weight_g": pred_kg * 1000.0,
                "length_input": float(length_input),
                "length_after_scale": float(length_input * float(xyz_scale))
                if np.isfinite(length_input)
                else float("nan"),
                "is_outlier": False,  # Will be updated if outlier filtering is applied
            }
        )

    summary = summarize_predictions(
        per_file,
        num_files=len(ply_files),
        top_k=top_k,
        remove_outliers=remove_outliers,
        outlier_method=outlier_method,
        outlier_field=outlier_field,
        iqr_factor=iqr_factor,
        zscore_threshold=zscore_threshold,
    )
    return per_file, summary


def collect_cloud_folders(batch_root: Path) -> List[Path]:
    """
    Recursively collect folders named 'cloud' under batch_root that contain at least one .ply.

    Typical layout: /fishxx/xxx/cloud/*.ply

    Raises:
        ValueError: if batch_root does not exist or is not a directory.
    """
    batch_root = batch_root.expanduser().resolve()
    if not batch_root.exists() or not batch_root.is_dir():
        raise ValueError(f"batch_root not found/dir: {batch_root}")

    return sorted(p for p in batch_root.rglob("cloud") if p.is_dir() and any(p.glob("*.ply")))


def _parent_relative_to(cloud_dir: Path, batch_root: Path) -> str:
    """Return cloud_dir's parent relative to batch_root, or absolute if outside it."""
    parent = cloud_dir.parent
    if parent.is_relative_to(batch_root):
        return str(parent.relative_to(batch_root))
    return str(parent)


def group_cloud_folders_by_fish(batch_root: Path) -> Dict[str, List[Path]]:
    """
    Group cloud folders by fish_id. Returns {fish_id: [cloud_dir1, cloud_dir2, ...]}.

    Typical layout: /fish1/HD1080_xxx/cloud, /fish1/HD1080_yyy/cloud
    Folders without a recognizable fish id are keyed by their relative path
    with "/" replaced by "_".
    """
    # Resolve first so relative paths are computed against the same root that
    # collect_cloud_folders() scans.
    batch_root = batch_root.expanduser().resolve()
    by_fish: Dict[str, List[Path]] = {}
    for cloud_dir in collect_cloud_folders(batch_root):
        rel = _parent_relative_to(cloud_dir, batch_root)
        key = extract_fish_key_from_text(rel) or extract_fish_key_from_text(str(cloud_dir))
        if not key:
            # No fish id found, use folder path as key
            key = rel.replace("/", "_") or "unknown"
        by_fish.setdefault(key, []).append(cloud_dir)
    return by_fish


def predict_cloud_folder(
    model: torch.nn.Module,
    cloud_dir: Path,
    device: torch.device,
    num_points: int,
    xyz_scale: float,
    recursive: bool = False,
    topk_length: Optional[int] = None,
    topk_predictions: int = 10,
    top_k: int = 5,
    remove_outliers: bool = False,
    outlier_method: str = "iqr",
    outlier_field: str = "length_input",
    iqr_factor: float = 1.5,
    zscore_threshold: float = 2.5,
) -> Tuple[List[Dict], Dict]:
    """
    Predict weights for every *.ply in `cloud_dir` (optionally recursive).

    All remaining arguments are forwarded unchanged to the folder predictor;
    returns its (per_file_list, summary_dict).
    """
    finder = cloud_dir.rglob if recursive else cloud_dir.glob
    ply_files = sorted(finder("*.ply"))
    return _predict_folder_impl(
        model=model,
        ply_files=ply_files,
        device=device,
        num_points=num_points,
        xyz_scale=xyz_scale,
        topk_length=topk_length,
        topk_predictions=topk_predictions,
        top_k=top_k,
        remove_outliers=remove_outliers,
        outlier_method=outlier_method,
        outlier_field=outlier_field,
        iqr_factor=iqr_factor,
        zscore_threshold=zscore_threshold,
    )


def load_label_weights_json(labels_path: Path) -> Dict[str, float]:
    """
    Load a simple mapping like:
      {"fish1": 380.25, "fish2": 487.45, ...}
    Values are expected in grams (float).

    Returns an empty dict when the file does not exist or its top level is not
    a JSON object. Entries whose value cannot be converted to float are
    skipped. Malformed JSON raises json.JSONDecodeError.
    """
    labels_path = labels_path.expanduser().resolve()
    if not labels_path.exists():
        return {}
    data = json.loads(labels_path.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        return {}

    out: Dict[str, float] = {}
    for k, v in data.items():
        try:
            out[str(k).strip()] = float(v)
        except (TypeError, ValueError, OverflowError):
            continue
    return out


def extract_fish_key_from_text(text: str) -> Optional[str]:
    """
    Extract fish id from a path/filename like:
      fish1, fish01, Fish_12, ...  -> "fish12"

    Returns None when no id is found or the id is 0.
    """
    m = re.search(
        r"(?:^|[^a-z0-9])fish\s*[_\- ]*\s*0*([0-9]{1,4})(?:[^0-9]|$)",
        text,
        flags=re.IGNORECASE,
    )
    if not m:
        # fallback: handle ".../fish12/..." without separators (common)
        m = re.search(r"fish\s*0*([0-9]{1,4})", text, flags=re.IGNORECASE)
    if not m:
        return None
    n = int(m.group(1))
    if n <= 0:
        return None
    return f"fish{n}"


def compare_pred_vs_actual(pred_g: float, actual_g: float) -> Dict[str, float]:
    """
    Compare a predicted weight with the labelled weight (both in grams).

    Returns the two inputs plus the signed difference in grams, the signed
    difference in percent of the actual weight, and its absolute value. The
    three difference fields are NaN when either input is not finite or the
    actual weight is 0.
    """
    result = {
        "predicted_weight_g": float(pred_g),
        "actual_weight_g": float(actual_g),
        "diff_g": float("nan"),
        "diff_pct": float("nan"),
        "abs_diff_pct": float("nan"),
    }
    if not np.isfinite(pred_g) or not np.isfinite(actual_g) or actual_g == 0.0:
        return result

    diff_g = float(pred_g - actual_g)
    diff_pct = float(diff_g / actual_g * 100.0)
    result.update(diff_g=diff_g, diff_pct=diff_pct, abs_diff_pct=abs(diff_pct))
    return result


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser("PointNet++ weight estimator (folder inference)")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to best.pt/last.pt checkpoint")
    src = parser.add_mutually_exclusive_group(required=True)
    src.add_argument("--ply-folder", type=str, default="", help="Folder containing .ply files")
    src.add_argument(
        "--batch-root",
        type=str,
        default="",
        help="Batch mode: recursively find */cloud folders under this root and predict each",
    )
    parser.add_argument(
        "--combine-by-fish",
        action="store_true",
        default=True,
        help="Batch mode: combine all folders per fish, use top 10 for average (default: True)",
    )
    parser.add_argument(
        "--no-combine-by-fish",
        action="store_false",
        dest="combine_by_fish",
        help="Batch mode: predict each folder separately (top 5 per folder)",
    )
    parser.add_argument("--recursive", action="store_true", help="Recursively search for .ply files (single folder mode)")
    parser.add_argument("--num-points", type=int, default=768, help="Number of points per PLY (default: 768)")
    parser.add_argument("--xyz-scale", type=float, default=0.001, help="XYZ scaling factor (default: 0.001)")
    parser.add_argument("--device", type=str, default="auto", help="auto|cpu|cuda")
    parser.add_argument("--output-json", type=str, default=None, help="Optional path to save results JSON")
    parser.add_argument(
        "--labels-json",
        type=str,
        default="/home/ubuntu/label20_weights.json",
        help="Optional label mapping JSON (e.g. /home/ubuntu/label20_weights.json). Keys like 'fish12' and values are grams.",
    )
    # NOTE: --topk-length and --topk-predictions are parsed and forwarded, but
    # the aggregation code does not currently read them.
    parser.add_argument(
        "--topk-length",
        "--topk_length",
        type=int,
        default=None,
        dest="topk_length",
        help="If set (e.g. 10), compute folder average as length-weighted average over the top-K longest PLYs",
    )
    parser.add_argument(
        "--topk-predictions",
        type=int,
        default=10,
        help="When > this many PLYs, remove length outliers (IQR) and average the rest (default: 10)",
    )
    parser.add_argument(
        "--top-k",
        "--top_k",
        type=int,
        default=10,
        dest="top_k",
        help="Number of top predictions to average (default: 10). Use e.g. --top-k 20 or --top_k=20.",
    )
    parser.add_argument(
        "--print-per-ply",
        action="store_true",
        help="Print per-PLY predicted weight (single folder mode default prints per-PLY regardless; batch mode uses this flag)",
    )
    # Outlier removal arguments
    parser.add_argument(
        "--remove-outliers",
        action="store_true",
        help="Remove outliers before computing average (improves robustness)",
    )
    parser.add_argument(
        "--outlier-method",
        type=str,
        default="iqr",
        choices=["iqr", "zscore"],
        help="Outlier detection method: 'iqr' (default) or 'zscore'",
    )
    parser.add_argument(
        "--outlier-field",
        type=str,
        default="length_input",
        choices=["length_input", "predicted_weight_g"],
        help="Field to use for outlier detection (default: length_input)",
    )
    parser.add_argument(
        "--iqr-factor",
        type=float,
        default=1.5,
        help="IQR multiplier for outlier bounds (default: 1.5). Lower = more aggressive filtering.",
    )
    parser.add_argument(
        "--zscore-threshold",
        type=float,
        default=2.5,
        help="Z-score threshold for outlier detection (default: 2.5). Lower = more aggressive filtering.",
    )
    parser.add_argument(
        "--max-cv",
        type=float,
        default=None,
        help="Maximum coefficient of variation (CV%%) for length. Skip folders with CV > this value (e.g. 15.0). "
        "Lower CV means more consistent point clouds. If not set, no CV filtering is applied.",
    )
    return parser


def _outlier_kwargs(args: argparse.Namespace) -> Dict:
    """Collect the outlier-removal options shared by every prediction call."""
    return {
        "remove_outliers": args.remove_outliers,
        "outlier_method": args.outlier_method,
        "outlier_field": args.outlier_field,
        "iqr_factor": args.iqr_factor,
        "zscore_threshold": args.zscore_threshold,
    }


def _format_per_ply_line(item: Dict, xyz_scale: float, prefix: str = "") -> str:
    """Format one per-PLY result line; lengths are labelled mm when xyz_scale is 0.001."""
    name = Path(item["ply"]).name
    grams = float(item["predicted_weight_g"])
    kilos = float(item["predicted_weight_kg"])
    length_input = float(item.get("length_input", float("nan")))
    tag = " (kept)" if item.get("used_for_prediction", True) else " (filtered)"
    if not np.isfinite(length_input):
        length_str = "nan"
    elif abs(float(xyz_scale) - 0.001) < 1e-12:
        length_str = f"{length_input:.1f} mm"
    else:
        length_str = f"{length_input:.4f} units"
    return f"{prefix}{name}: len={length_str} | {grams:.2f} g ({kilos:.4f} kg){tag}"


def _write_results_json(payload: Dict, out_path: Path, top_k: int) -> Path:
    """Write `payload` as JSON next to `out_path`, with `_top{top_k}` appended to the stem."""
    out_path = out_path.expanduser().resolve()
    out_path = out_path.parent / f"{out_path.stem}_top{top_k}{out_path.suffix}"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"Saved: {out_path}")
    return out_path


def _run_single_folder(
    args: argparse.Namespace,
    model: torch.nn.Module,
    device: torch.device,
    labels: Dict[str, float],
    labels_path: Optional[Path],
    ckpt_path: Path,
) -> None:
    """Single-folder mode: predict one folder, print results, optionally save JSON."""
    ply_folder = Path(args.ply_folder).expanduser().resolve()
    if not ply_folder.exists() or not ply_folder.is_dir():
        raise SystemExit(f"ply-folder not found/dir: {ply_folder}")

    per_file, summary = predict_cloud_folder(
        model=model,
        cloud_dir=ply_folder,
        device=device,
        num_points=args.num_points,
        xyz_scale=args.xyz_scale,
        recursive=args.recursive,
        topk_length=args.topk_length,
        topk_predictions=args.topk_predictions,
        top_k=args.top_k,
        **_outlier_kwargs(args),
    )

    comparison = None
    if summary.get("skipped"):
        print(f"Skipped: {summary.get('skip_reason', 'fewer than 5 PLYs')} (files={summary.get('num_files', 0)})")
    else:
        # Print per-file predictions (default behavior)
        for it in per_file:
            print(_format_per_ply_line(it, args.xyz_scale))

        # Print summary
        print(f"Files: {summary.get('num_files_predicted', summary['num_files'])}")
        if args.remove_outliers and summary.get("num_outliers_removed", 0) > 0:
            print(
                f"Outliers removed: {summary['num_outliers_removed']} "
                f"(method={args.outlier_method}, field={args.outlier_field})"
            )
        if summary.get("cv_length_pct") is not None:
            print(f"Length CV: {summary['cv_length_pct']:.1f}%")
        print(
            f"Average predicted weight (top {args.top_k}): "
            f"{summary['avg_predicted_weight_g']:.2f} g ({summary['avg_predicted_weight_kg']:.4f} kg) "
            f"(used={summary.get('num_files_used_for_avg', 'n/a')}/{summary.get('num_files_predicted', 'n/a')})"
        )

        # Compare with labels (if matchable)
        fish_key = extract_fish_key_from_text(str(ply_folder))
        if fish_key and fish_key in labels:
            actual_g = float(labels[fish_key])
            comparison = compare_pred_vs_actual(pred_g=float(summary["avg_predicted_weight_g"]), actual_g=actual_g)
            print(
                f"Actual weight ({fish_key}): {actual_g:.2f} g | "
                f"Diff: {comparison['diff_g']:.2f} g | "
                f"Diff%: {comparison['diff_pct']:.2f}% (abs {comparison['abs_diff_pct']:.2f}%)"
            )
        elif labels:
            print(f"[warn] No label match for folder: {ply_folder} (parsed key={fish_key})")

    if args.output_json:
        out = {
            "meta": {
                "checkpoint": str(ckpt_path),
                "ply_folder": str(ply_folder),
                "recursive": bool(args.recursive),
                "num_points": int(args.num_points),
                "xyz_scale": float(args.xyz_scale),
                "device": str(device),
                "labels_json": str(labels_path) if labels_path is not None else None,
            },
            "summary": summary,
            "comparison": comparison,
            "per_file": per_file,
        }
        _write_results_json(out, Path(args.output_json), args.top_k)


def _run_batch_combined(
    args: argparse.Namespace,
    model: torch.nn.Module,
    device: torch.device,
    labels: Dict[str, float],
    batch_root: Path,
) -> Tuple[List[Dict], int]:
    """Batch mode, one prediction per fish over all of its cloud folders.

    Returns (results, number_of_cloud_folders).
    """
    by_fish = group_cloud_folders_by_fish(batch_root)
    if not by_fish:
        raise SystemExit(f"No cloud folders found under: {batch_root}")

    results: List[Dict] = []
    for fish_key, cloud_dirs in sorted(by_fish.items()):
        ply_files = sorted({ply for cloud_dir in cloud_dirs for ply in cloud_dir.glob("*.ply")})
        try:
            per_file, summary = _predict_folder_impl(
                model=model,
                ply_files=ply_files,
                device=device,
                num_points=args.num_points,
                xyz_scale=args.xyz_scale,
                top_k=args.top_k,
                **_outlier_kwargs(args),
            )
            if summary.get("skipped"):
                print(f"{fish_key}: SKIPPED ({summary.get('skip_reason', '<5 PLYs')})")
                results.append(
                    {
                        "id": fish_key,
                        "fish_key": fish_key,
                        "skipped": True,
                        "skip_reason": summary.get("skip_reason", "fewer than 5 PLYs"),
                        "num_files": summary.get("num_files", 0),
                        "cloud_folders": len(cloud_dirs),
                    }
                )
                continue

            num_used = summary.get("num_files_used_for_avg", "n/a")
            num_total = summary.get("num_files_predicted", "n/a")
            used_info = f"used={num_used}/{num_total} (from {len(cloud_dirs)} folders)"
            line = f"{fish_key}: avg(top{args.top_k})={summary['avg_predicted_weight_g']:.2f} g ({used_info})"

            comparison = None
            if fish_key in labels:
                actual_g = float(labels[fish_key])
                comparison = compare_pred_vs_actual(pred_g=float(summary["avg_predicted_weight_g"]), actual_g=actual_g)
                line += f" | actual={actual_g:.2f} g | diff%={comparison['diff_pct']:.2f}%"
            print(line)

            results.append(
                {
                    "id": fish_key,
                    "fish_key": fish_key,
                    "cloud_folders": len(cloud_dirs),
                    "summary": summary,
                    "comparison": comparison,
                    "per_file": per_file,
                }
            )
        except Exception as e:
            # Best effort: one bad fish must not abort the batch, but the
            # failure is reported instead of only being buried in the JSON.
            print(f"[error] {fish_key}: {e}", file=sys.stderr)
            results.append({"id": fish_key, "fish_key": fish_key, "error": str(e)})

    return results, sum(len(dirs) for dirs in by_fish.values())


def _run_batch_per_folder(
    args: argparse.Namespace,
    model: torch.nn.Module,
    device: torch.device,
    labels: Dict[str, float],
    batch_root: Path,
) -> Tuple[List[Dict], int]:
    """Batch mode, one prediction per cloud folder.

    Returns (results, number_of_cloud_folders).
    """
    cloud_folders = collect_cloud_folders(batch_root)
    if not cloud_folders:
        raise SystemExit(f"No cloud folders found under: {batch_root}")

    results: List[Dict] = []
    for cloud_dir in cloud_folders:
        rel = _parent_relative_to(cloud_dir, batch_root)
        try:
            per_file, summary = predict_cloud_folder(
                model=model,
                cloud_dir=cloud_dir,
                device=device,
                num_points=args.num_points,
                xyz_scale=args.xyz_scale,
                recursive=False,
                topk_length=args.topk_length,
                topk_predictions=args.topk_predictions,
                top_k=args.top_k,
                **_outlier_kwargs(args),
            )

            # Skip folders with < 5 PLYs
            if summary.get("skipped"):
                print(f"{rel}: SKIPPED ({summary.get('skip_reason', '<5 PLYs')})")
                results.append(
                    {
                        "id": rel,
                        "cloud_dir": str(cloud_dir),
                        "skipped": True,
                        "skip_reason": summary.get("skip_reason", "fewer than 5 PLYs"),
                        "num_files": summary.get("num_files", 0),
                    }
                )
                continue

            # Check CV filtering - skip this folder if CV exceeds threshold
            cv_pct = summary.get("cv_length_pct", None)
            if args.max_cv is not None and cv_pct is not None and np.isfinite(cv_pct) and cv_pct > args.max_cv:
                print(f"{rel}: SKIPPED (cv={cv_pct:.1f}% > max_cv={args.max_cv}%)")
                results.append(
                    {
                        "id": rel,
                        "cloud_dir": str(cloud_dir),
                        "skipped": True,
                        "skip_reason": f"cv_length={cv_pct:.1f}% exceeds max_cv={args.max_cv}%",
                        "cv_length_pct": cv_pct,
                    }
                )
                continue

            if args.print_per_ply:
                for it in per_file:
                    print(_format_per_ply_line(it, args.xyz_scale, prefix=f"{rel}/"))

            # Compare with labels if possible
            fish_key = extract_fish_key_from_text(rel) or extract_fish_key_from_text(str(cloud_dir))
            comparison = None

            # Build used_info string with outlier info
            num_total = summary.get("num_files_predicted", "n/a")
            num_after_outlier = summary.get("num_files_after_outlier_removal", num_total)
            num_used = summary.get("num_files_used_for_avg", "n/a")
            num_outliers = summary.get("num_outliers_removed", 0)
            if args.remove_outliers and num_outliers > 0:
                used_info = f"used={num_used}/{num_after_outlier}, outliers={num_outliers}"
            else:
                used_info = f"used={num_used}/{num_total}"

            # Add CV info if available and high
            if cv_pct is not None and cv_pct > 10.0:
                used_info += f", cv={cv_pct:.1f}%"

            line = f"{rel}: avg(top{args.top_k})={summary['avg_predicted_weight_g']:.2f} g ({used_info})"
            if fish_key and fish_key in labels:
                actual_g = float(labels[fish_key])
                comparison = compare_pred_vs_actual(pred_g=float(summary["avg_predicted_weight_g"]), actual_g=actual_g)
                line += f" | actual={actual_g:.2f} g | diff%={comparison['diff_pct']:.2f}%"
            print(line)

            results.append(
                {
                    "id": rel,
                    "cloud_dir": str(cloud_dir),
                    "summary": summary,
                    "fish_key": fish_key,
                    "comparison": comparison,
                }
            )
        except Exception as e:
            # Best effort: one bad folder must not abort the batch, but the
            # failure is reported instead of only being buried in the JSON.
            print(f"[error] {rel}: {e}", file=sys.stderr)
            results.append({"id": rel, "cloud_dir": str(cloud_dir), "error": str(e)})

    return results, len(cloud_folders)


def main() -> None:
    """Command-line entry point: single-folder or batch inference."""
    args = _build_arg_parser().parse_args()

    if o3d is None:
        # Fail up front: in batch mode per-item errors are caught, so a missing
        # dependency would otherwise show up as one error per item.
        raise SystemExit("open3d is required to read PLY files (pip install open3d)")

    ckpt_path = Path(args.checkpoint).expanduser().resolve()
    if not ckpt_path.exists():
        raise SystemExit(f"checkpoint not found: {ckpt_path}")

    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)

    # Load model (SSG or MSG based on checkpoint config)
    from weight_estimator.train_pointnet_weigth_estimator import load_model_from_checkpoint  # noqa: E402

    model = load_model_from_checkpoint(ckpt_path, device=device)

    labels_path = Path(args.labels_json).expanduser().resolve() if args.labels_json else None
    labels = load_label_weights_json(labels_path) if labels_path is not None else {}

    if args.ply_folder:
        _run_single_folder(args, model, device, labels, labels_path, ckpt_path)
        return

    batch_root = Path(args.batch_root).expanduser().resolve()
    runner = _run_batch_combined if args.combine_by_fish else _run_batch_per_folder
    results, num_cloud_folders = runner(args, model, device, labels, batch_root)

    out = {
        "meta": {
            "checkpoint": str(ckpt_path),
            "batch_root": str(batch_root),
            "num_points": int(args.num_points),
            "xyz_scale": float(args.xyz_scale),
            "device": str(device),
            "combine_by_fish": args.combine_by_fish,
            "top_k": args.top_k,
            "cloud_folders": num_cloud_folders,
            "labels_json": str(labels_path) if labels_path is not None else None,
        },
        "results": results,
    }
    default_out = batch_root / "batch_weight_predictions.json"
    _write_results_json(out, Path(args.output_json) if args.output_json else default_out, args.top_k)


if __name__ == "__main__":
    main()