Files
FishServer/FishMeasure/weight_estimator/test_pointnet_weight_estimator.py

975 lines
38 KiB
Python

#!/usr/bin/env python3
"""
Run PointNet++ weight regression on a folder of PLY point clouds.
For each PLY:
- load points
- scale XYZ by 0.001
- center to origin (subtract centroid)
- sample fixed N points (default 768)
- predict weight (kg)
Final prediction = mean of the top-K highest per-PLY predictions (see --top-k).
Example:
python weight_estimator/test_pointnet_weight_estimator.py \
--checkpoint weight_estimator/runs/20260124_123456/best.pt \
--ply-folder /path/to/xxxx/cloud
"""
from __future__ import annotations
import argparse
import json
import re
import sys
import zlib
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import open3d as o3d
import torch
# Put the directory two levels above this file at the front of sys.path so the
# training-side model definition can be imported when run as a script.
REPO_ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(REPO_ROOT)) == 0:
    sys.path.insert(0, str(REPO_ROOT))
def load_points_from_ply(ply_path: Path) -> np.ndarray:
    """
    Read the coordinates stored in a PLY file as a float32 array.

    The file is first read as a point cloud. If that yields no points it is
    re-read as a triangle mesh and the mesh vertices are returned instead.

    Args:
        ply_path: Path of the PLY file to read.

    Returns:
        float32 array built from the point-cloud points or mesh vertices.

    Raises:
        ValueError: If neither reader finds any points/vertices.
    """
    path_str = str(ply_path)
    cloud = o3d.io.read_point_cloud(path_str)
    if len(cloud.points) == 0:
        # No point data: fall back to the vertex positions of a mesh.
        tri_mesh = o3d.io.read_triangle_mesh(path_str)
        if len(tri_mesh.vertices) == 0:
            raise ValueError(f"No points/vertices found in: {ply_path}")
        return np.asarray(tri_mesh.vertices, dtype=np.float32)
    return np.asarray(cloud.points, dtype=np.float32)
def sample_points_deterministic(points: np.ndarray, num_points: int, seed: int) -> np.ndarray:
    """
    Pick exactly ``num_points`` rows from ``points`` using a seeded generator.

    Rows are drawn without replacement when the cloud has at least
    ``num_points`` rows, and with replacement otherwise.

    Args:
        points: Array whose first axis indexes the points.
        num_points: Number of rows to return.
        seed: Seed for ``np.random.default_rng``; the same seed gives the same rows.

    Returns:
        ``points`` indexed by the sampled row indices.

    Raises:
        ValueError: If ``points`` has no rows.
    """
    total = points.shape[0]
    if total <= 0:
        raise ValueError("Empty point cloud")
    # Duplicates are only allowed when the cloud is smaller than the target size.
    need_replacement = total < num_points
    picks = np.random.default_rng(seed).choice(total, size=num_points, replace=need_replacement)
    return points[picks]
def estimate_length_major_axis(points: np.ndarray) -> float:
    """
    Measure the extent of a point cloud along its first principal axis.

    The first three columns are centred, the principal axis is taken from an
    SVD, and the result is the spread (max - min) of the projections onto it.
    The unit is that of the input coordinates.

    Args:
        points: 2-D array with at least 3 rows and 3 columns.

    Returns:
        The extent along the major axis, or NaN when the input is ``None``,
        has the wrong shape, has too few points, or the SVD fails.
    """
    nan = float("nan")
    if points is None:
        return nan
    if points.ndim != 2 or points.shape[0] < 3 or points.shape[1] < 3:
        return nan
    xyz = points[:, :3].astype(np.float32, copy=False)
    centered = xyz - xyz.mean(axis=0, keepdims=True)
    try:
        # Rows of the right singular vectors are the PCA axes; row 0 is the major one.
        right_vectors = np.linalg.svd(centered, full_matrices=False)[2]
        along_axis = centered @ right_vectors[0]
        return float(np.max(along_axis) - np.min(along_axis))
    except Exception:
        return nan
def filter_outliers_iqr(
    per_file: List[Dict],
    field: str = "length_input",
    iqr_factor: float = 1.5,
) -> Tuple[List[Dict], List[Dict], Dict]:
    """
    Split per-file records into inliers and outliers with the IQR rule.

    A record is an outlier when its ``field`` value is finite and lies outside
    ``[Q1 - iqr_factor * IQR, Q3 + iqr_factor * IQR]``. Records whose value is
    missing or non-finite are never filtered out.

    Args:
        per_file: List of per-file prediction dicts.
        field: Key whose value drives the outlier decision.
        iqr_factor: Multiplier applied to the IQR to form the bounds.

    Returns:
        (filtered_list, outliers_list, stats_dict). With empty input the stats
        carry an "error" entry; with fewer than 4 finite values every record is
        kept and the stats carry a "note" entry.
    """
    if not per_file:
        return [], [], {"error": "empty input"}

    nan = float("nan")
    readings = [float(record.get(field, nan)) for record in per_file]
    finite_vals = [r for r in readings if np.isfinite(r)]
    if len(finite_vals) < 4:
        # Quartiles are meaningless on so few samples: keep everything.
        return per_file, [], {"note": "too few samples for IQR", "num_samples": len(finite_vals)}

    arr = np.array(finite_vals, dtype=np.float64)
    q1 = float(np.percentile(arr, 25))
    q3 = float(np.percentile(arr, 75))
    iqr = q3 - q1
    lower_bound = q1 - iqr_factor * iqr
    upper_bound = q3 + iqr_factor * iqr

    filtered: List[Dict] = []
    outliers: List[Dict] = []
    for record, reading in zip(per_file, readings):
        out_of_range = np.isfinite(reading) and not (lower_bound <= reading <= upper_bound)
        if out_of_range:
            outliers.append(record)
        else:
            filtered.append(record)

    stats = {
        "field": field,
        "iqr_factor": float(iqr_factor),
        "q1": q1,
        "q3": q3,
        "iqr": iqr,
        "lower_bound": lower_bound,
        "upper_bound": upper_bound,
        "num_input": len(per_file),
        "num_filtered": len(filtered),
        "num_outliers": len(outliers),
    }
    return filtered, outliers, stats
def filter_outliers_zscore(
    per_file: List[Dict],
    field: str = "length_input",
    zscore_threshold: float = 2.5,
) -> Tuple[List[Dict], List[Dict], Dict]:
    """
    Split per-file records into inliers and outliers by absolute Z-score.

    Mean and (population) standard deviation are computed over the finite
    ``field`` values. A record is an outlier when its finite value has
    ``|v - mean| / std`` above ``zscore_threshold``. Records with a missing or
    non-finite value are always kept.

    Args:
        per_file: List of per-file prediction dicts.
        field: Key whose value drives the outlier decision.
        zscore_threshold: Largest absolute Z-score still treated as an inlier.

    Returns:
        (filtered_list, outliers_list, stats_dict). With empty input the stats
        carry an "error" entry; with fewer than 3 finite values, or a standard
        deviation below 1e-9, every record is kept and the stats carry a "note".
    """
    if not per_file:
        return [], [], {"error": "empty input"}

    nan = float("nan")
    readings = [float(record.get(field, nan)) for record in per_file]
    finite_vals = [r for r in readings if np.isfinite(r)]
    if len(finite_vals) < 3:
        return per_file, [], {"note": "too few samples for z-score", "num_samples": len(finite_vals)}

    arr = np.array(finite_vals, dtype=np.float64)
    mean_val = float(np.mean(arr))
    std_val = float(np.std(arr))
    if std_val < 1e-9:
        # All values are (numerically) identical, so there is nothing to reject.
        return per_file, [], {"note": "zero std", "mean": mean_val}

    filtered: List[Dict] = []
    outliers: List[Dict] = []
    for record, reading in zip(per_file, readings):
        if np.isfinite(reading) and abs(reading - mean_val) / std_val > zscore_threshold:
            outliers.append(record)
        else:
            filtered.append(record)

    stats = {
        "field": field,
        "zscore_threshold": float(zscore_threshold),
        "mean": mean_val,
        "std": std_val,
        "num_input": len(per_file),
        "num_filtered": len(filtered),
        "num_outliers": len(outliers),
    }
    return filtered, outliers, stats
@torch.no_grad()
def predict_folder(
    model: torch.nn.Module,
    ply_files: List[Path],
    device: torch.device,
    num_points: int = 768,
    xyz_scale: float = 0.001,
) -> Tuple[List[Dict], Dict]:
    """
    Predict weights for a list of PLY files with the default aggregation settings.

    Thin wrapper around ``_predict_folder_impl``: outlier removal is off and the
    implementation's default ``top_k`` is used.

    Args:
        model: Regressor model; called on a (1, C, N) float32 tensor per PLY.
        ply_files: List of PLY file paths.
        device: torch device the input tensors are moved to.
        num_points: Number of points sampled per PLY.
        xyz_scale: Factor applied to the coordinates before centering.

    Returns:
        (per_file_list, summary_dict)
    """
    return _predict_folder_impl(
        model=model,
        ply_files=ply_files,
        device=device,
        num_points=num_points,
        xyz_scale=xyz_scale,
        topk_length=None,
    )


@torch.no_grad()
def _predict_folder_impl(
    model: torch.nn.Module,
    ply_files: List[Path],
    device: torch.device,
    num_points: int = 768,
    xyz_scale: float = 0.001,
    topk_length: Optional[int] = None,
    topk_predictions: int = 10,
    top_k: int = 5,
    remove_outliers: bool = False,
    outlier_method: str = "iqr",
    outlier_field: str = "length_input",
    iqr_factor: float = 1.5,
    zscore_threshold: float = 2.5,
) -> Tuple[List[Dict], Dict]:
    """
    Predict a weight per PLY file and aggregate them into one estimate.

    Each PLY is loaded, scaled, centred on its centroid, sampled to a fixed
    number of points (seeded from the file path, so results are repeatable) and
    passed through the model. The final estimate is the mean of the ``top_k``
    highest per-file predictions among the candidates left after optional
    outlier removal; when fewer than 5 candidates remain, the maximum is used
    instead of the mean.

    Args:
        model: Regressor model; its scalar output is treated as kilograms.
        ply_files: List of PLY file paths.
        device: torch device the input tensors are moved to.
        num_points: Number of points to sample per PLY.
        xyz_scale: XYZ scaling factor (0.001 = mm to m).
        topk_length: Accepted for interface compatibility; not used by the
            aggregation below.
        topk_predictions: Accepted for interface compatibility; not used by the
            aggregation below.
        top_k: Number of highest predictions to aggregate. Values below 0 are
            treated as 0, in which case the mean over all predictions is reported.
        remove_outliers: If True, remove outliers before aggregating.
        outlier_method: "iqr" or "zscore"; any other value keeps all files and
            records an error in the outlier stats.
        outlier_field: Field to use for outlier detection ("length_input" or
            "predicted_weight_g").
        iqr_factor: IQR multiplier for outlier bounds.
        zscore_threshold: Z-score threshold for outlier detection.

    Returns:
        (per_file_list, summary_dict). When fewer than 5 PLYs are given nothing
        is predicted: the list is empty and the summary has ``skipped=True``.
    """
    model.eval()

    # Folders with very few PLYs are treated as too unreliable to aggregate.
    min_plys_required = 5
    if len(ply_files) < min_plys_required:
        summary_skipped = {
            "num_files": len(ply_files),
            # Nothing is run through the model on this path.
            "num_files_predicted": 0,
            "skipped": True,
            "skip_reason": f"fewer than {min_plys_required} PLYs (inaccurate)",
            "avg_predicted_weight_kg": float("nan"),
            "avg_predicted_weight_g": float("nan"),
        }
        return [], summary_skipped

    per_file: List[Dict] = []
    preds_kg: List[float] = []
    scale = float(xyz_scale)
    for ply in ply_files:
        pts = load_points_from_ply(ply)
        # Length is measured on the raw (unscaled) coordinates.
        length_input = estimate_length_major_axis(pts)
        # scale + center
        pts = pts * scale
        pts = pts - pts.mean(axis=0, keepdims=True)
        # Seed from the file path so the same file always yields the same sample.
        seed = int(zlib.adler32(str(ply).encode("utf-8")) & 0xFFFFFFFF)
        pts = sample_points_deterministic(pts, num_points=num_points, seed=seed)
        x = torch.from_numpy(pts).to(device=device, dtype=torch.float32)
        # (N, C) -> (1, N, C) -> (1, C, N)
        x = x.unsqueeze(0).transpose(1, 2).contiguous()
        pred_kg = float(model(x).item())
        preds_kg.append(pred_kg)
        per_file.append(
            {
                "ply": str(ply),
                "predicted_weight_kg": pred_kg,
                "predicted_weight_g": pred_kg * 1000.0,
                "length_input": float(length_input),
                "length_after_scale": float(length_input * scale) if np.isfinite(length_input) else float("nan"),
                "is_outlier": False,  # Updated below if outlier filtering flags this file
            }
        )

    # Apply outlier removal if enabled
    outlier_stats: Optional[Dict] = None
    candidates_for_avg = per_file
    num_outliers_removed = 0
    if remove_outliers and per_file:
        if outlier_method == "iqr":
            filtered, outliers, outlier_stats = filter_outliers_iqr(
                per_file, field=outlier_field, iqr_factor=iqr_factor
            )
        elif outlier_method == "zscore":
            filtered, outliers, outlier_stats = filter_outliers_zscore(
                per_file, field=outlier_field, zscore_threshold=zscore_threshold
            )
        else:
            filtered, outliers = per_file, []
            outlier_stats = {"error": f"unknown method: {outlier_method}"}
        # The filters return the same dict objects, so flagging them here also
        # updates the entries of per_file.
        for it in outliers:
            it["is_outlier"] = True
        candidates_for_avg = filtered
        num_outliers_removed = len(outliers)

    # mean over all predictions (before outlier removal)
    avg_kg_all = float(np.mean(preds_kg)) if preds_kg else float("nan")
    avg_g_all = avg_kg_all * 1000.0 if preds_kg else float("nan")

    # Aggregate the top-K highest predictions among the remaining candidates.
    sorted_by_pred = sorted(
        candidates_for_avg,
        key=lambda d: float(d.get("predicted_weight_g", float("-inf"))),
        reverse=True,
    )
    # Clamp so a negative top_k cannot turn into a "drop the last k" slice.
    selected_topk = sorted_by_pred[: max(0, int(top_k))]
    preds_g_topk = [float(it["predicted_weight_g"]) for it in selected_topk]
    use_max_instead_of_mean = len(candidates_for_avg) < 5
    if preds_g_topk:
        avg_g_topk = float(np.max(preds_g_topk)) if use_max_instead_of_mean else float(np.mean(preds_g_topk))
    else:
        avg_g_topk = float(avg_g_all)
    plys_used_for_prediction = [Path(it["ply"]).name for it in selected_topk]

    lengths = [
        float(it["length_input"])
        for it in candidates_for_avg
        if np.isfinite(float(it.get("length_input", float("nan"))))
    ]
    avg_len_input = float(np.mean(lengths)) if lengths else float("nan")
    std_len_input = float(np.std(lengths)) if len(lengths) > 1 else 0.0
    cv_length = float(std_len_input / avg_len_input * 100.0) if avg_len_input > 0 else float("nan")

    # Final estimate: top-K aggregate, falling back to the overall mean.
    avg_g = float(avg_g_topk) if np.isfinite(float(avg_g_topk)) else float(avg_g_all)
    avg_kg = avg_g / 1000.0

    # Mark the files that contributed to the final estimate. Match on the full
    # path: bare file names can repeat across folders.
    used_paths = {it["ply"] for it in selected_topk}
    for it in per_file:
        it["used_for_prediction"] = it["ply"] in used_paths

    summary = {
        "num_files": len(ply_files),
        "num_files_predicted": len(per_file),
        "num_outliers_removed": num_outliers_removed,
        "num_files_after_outlier_removal": len(candidates_for_avg),
        "num_files_used_for_avg": len(selected_topk),
        "prediction_aggregate": "max" if use_max_instead_of_mean else "mean",
        "top_k": top_k,
        "plys_used_for_prediction": plys_used_for_prediction,
        "avg_predicted_weight_kg": avg_kg,
        "avg_predicted_weight_g": avg_g,
        "avg_predicted_weight_kg_all": avg_kg_all,
        "avg_predicted_weight_g_all": avg_g_all,
        "avg_length_input": avg_len_input,
        "std_length_input": std_len_input,
        "cv_length_pct": cv_length,
        "outlier_removal": {
            "enabled": remove_outliers,
            "method": outlier_method if remove_outliers else None,
            "field": outlier_field if remove_outliers else None,
            "stats": outlier_stats,
        } if remove_outliers else None,
    }
    return per_file, summary
def collect_cloud_folders(batch_root: Path) -> List[Path]:
    """
    Find every directory named 'cloud' below ``batch_root`` that holds a .ply file.

    The search is recursive; only ``*.ply`` files directly inside a 'cloud'
    directory count. Typical layout:
    <batch_root>/fishxx/xxx/cloud/*.ply

    Args:
        batch_root: Root directory to search (``~`` is expanded, then resolved).

    Returns:
        Sorted list of matching 'cloud' directories.

    Raises:
        ValueError: If ``batch_root`` does not exist or is not a directory.
    """
    root = batch_root.expanduser().resolve()
    if not (root.exists() and root.is_dir()):
        raise ValueError(f"batch_root not found/dir: {root}")
    hits: List[Path] = [
        candidate
        for candidate in root.rglob("cloud")
        if candidate.is_dir() and any(candidate.glob("*.ply"))
    ]
    hits.sort()
    return hits
def group_cloud_folders_by_fish(batch_root: Path) -> Dict[str, List[Path]]:
    """
    Group cloud folders by fish id. Returns {fish_id: [cloud_dir1, cloud_dir2, ...]}.

    Typical layout: <batch_root>/fish1/HD1080_xxx/cloud, <batch_root>/fish1/HD1080_yyy/cloud

    The fish id is parsed from the cloud folder's parent path relative to
    ``batch_root`` first, and from the full cloud folder path as a fallback.
    Folders without a recognisable fish id are keyed by their relative parent
    path with separators replaced by underscores ("unknown" for the root itself).

    Args:
        batch_root: Root directory to search.

    Returns:
        Mapping from fish key to the cloud directories belonging to it.

    Raises:
        ValueError: Propagated from ``collect_cloud_folders`` when the root is
            missing or not a directory.
    """
    # collect_cloud_folders returns resolved paths, so the root has to be
    # resolved the same way for the relative-path computation to succeed.
    root = batch_root.expanduser().resolve()
    by_fish: Dict[str, List[Path]] = {}
    for cloud_dir in collect_cloud_folders(root):
        parent = cloud_dir.parent
        try:
            rel = str(parent.relative_to(root))
        except ValueError:
            rel = str(parent)
        fish_key = extract_fish_key_from_text(rel) or extract_fish_key_from_text(str(cloud_dir))
        if not fish_key:
            # No fish id found, use the folder path as key. "." means the cloud
            # folder sits directly under the root, so there is no path to use.
            flat = "" if rel == "." else rel.replace("\\", "/").replace("/", "_")
            fish_key = flat or "unknown"
        by_fish.setdefault(fish_key, []).append(cloud_dir)
    return by_fish
def predict_cloud_folder(
    model: torch.nn.Module,
    cloud_dir: Path,
    device: torch.device,
    num_points: int,
    xyz_scale: float,
    recursive: bool = False,
    topk_length: Optional[int] = None,
    topk_predictions: int = 10,
    top_k: int = 5,
    remove_outliers: bool = False,
    outlier_method: str = "iqr",
    outlier_field: str = "length_input",
    iqr_factor: float = 1.5,
    zscore_threshold: float = 2.5,
) -> Tuple[List[Dict], Dict]:
    """
    Predict weights for all ``*.ply`` files found in one directory.

    Args:
        model: Regressor model handed to ``_predict_folder_impl``.
        cloud_dir: Directory to search for PLY files.
        device: torch device.
        num_points: Number of points sampled per PLY.
        xyz_scale: XYZ scaling factor.
        recursive: If True, search sub-directories as well.
        topk_length, topk_predictions, top_k, remove_outliers, outlier_method,
        outlier_field, iqr_factor, zscore_threshold: Forwarded unchanged to
            ``_predict_folder_impl``.

    Returns:
        (per_file_list, summary_dict) as produced by ``_predict_folder_impl``.
    """
    finder = cloud_dir.rglob if recursive else cloud_dir.glob
    found = sorted(finder("*.ply"))
    aggregation_options = dict(
        topk_length=topk_length,
        topk_predictions=topk_predictions,
        top_k=top_k,
        remove_outliers=remove_outliers,
        outlier_method=outlier_method,
        outlier_field=outlier_field,
        iqr_factor=iqr_factor,
        zscore_threshold=zscore_threshold,
    )
    return _predict_folder_impl(
        model=model,
        ply_files=found,
        device=device,
        num_points=num_points,
        xyz_scale=xyz_scale,
        **aggregation_options,
    )
def load_label_weights_json(labels_path: Path) -> Dict[str, float]:
    """
    Load a simple label mapping like:
    {"fish1": 380.25, "fish2": 487.45, ...}
    Values are expected in grams (float).

    Args:
        labels_path: Path to the JSON file (``~`` is expanded, then resolved).

    Returns:
        Mapping from stripped key to float weight. Entries whose value cannot
        be converted to float are skipped. An empty dict is returned when the
        path is not an existing regular file or the top-level JSON value is not
        an object.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    labels_path = labels_path.expanduser().resolve()
    # is_file() also rejects directories, which exists() would let through.
    if not labels_path.is_file():
        return {}
    data = json.loads(labels_path.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        return {}
    out: Dict[str, float] = {}
    for key, value in data.items():
        try:
            out[str(key).strip()] = float(value)
        except (TypeError, ValueError, OverflowError):
            # Non-numeric label (null, list, free text, ...): ignore this entry.
            continue
    return out
def extract_fish_key_from_text(text: str) -> Optional[str]:
    """
    Parse a fish id out of a path or file name and normalise it.

    Examples: fish1 -> "fish1", fish01 -> "fish1", Fish_12 -> "fish12".
    Matching is case-insensitive and leading zeros are dropped.

    Args:
        text: Path or file name to search.

    Returns:
        "fish<N>" with N > 0, or None when no id is found or the id is 0.
    """
    # Preferred form: "fish" as its own token, optional separators, 1-4 digits.
    strict_pattern = r"(?:^|[^a-z0-9])fish\s*[_\- ]*\s*0*([0-9]{1,4})(?:[^0-9]|$)"
    # Fallback: "fish" directly followed by digits anywhere in the text.
    loose_pattern = r"fish\s*0*([0-9]{1,4})"
    found = re.search(strict_pattern, text, flags=re.IGNORECASE)
    if found is None:
        found = re.search(loose_pattern, text, flags=re.IGNORECASE)
    if found is None:
        return None
    try:
        number = int(found.group(1))
    except Exception:
        return None
    return f"fish{number}" if number > 0 else None
def compare_pred_vs_actual(pred_g: float, actual_g: float) -> Dict[str, float]:
    """
    Compare a predicted weight with a reference weight.

    Args:
        pred_g: Predicted weight.
        actual_g: Reference weight.

    Returns:
        Dict with both inputs plus "diff_g" (pred - actual), "diff_pct"
        (difference relative to actual, in percent) and "abs_diff_pct". The
        three difference fields are NaN when either input is non-finite or the
        reference weight is 0.
    """
    report: Dict[str, float] = {
        "predicted_weight_g": float(pred_g),
        "actual_weight_g": float(actual_g),
    }
    comparable = np.isfinite(pred_g) and np.isfinite(actual_g) and actual_g != 0.0
    if comparable:
        delta_g = float(pred_g - actual_g)
        delta_pct = float(delta_g / actual_g * 100.0)
        report["diff_g"] = delta_g
        report["diff_pct"] = delta_pct
        report["abs_diff_pct"] = abs(delta_pct)
    else:
        # A relative error cannot be formed from these inputs.
        for name in ("diff_g", "diff_pct", "abs_diff_pct"):
            report[name] = float("nan")
    return report
def _cli_build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for single-folder and batch inference."""
    parser = argparse.ArgumentParser("PointNet++ weight estimator (folder inference)")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to best.pt/last.pt checkpoint")
    src = parser.add_mutually_exclusive_group(required=True)
    src.add_argument("--ply-folder", type=str, default="", help="Folder containing .ply files")
    src.add_argument(
        "--batch-root",
        type=str,
        default="",
        help="Batch mode: recursively find */cloud folders under this root and predict each",
    )
    parser.add_argument(
        "--combine-by-fish",
        action="store_true",
        default=True,
        help="Batch mode: combine all folders per fish, use top 10 for average (default: True)",
    )
    parser.add_argument(
        "--no-combine-by-fish",
        action="store_false",
        dest="combine_by_fish",
        help="Batch mode: predict each folder separately (top-K per folder, see --top-k)",
    )
    parser.add_argument("--recursive", action="store_true", help="Recursively search for .ply files (single folder mode)")
    parser.add_argument("--num-points", type=int, default=768, help="Number of points per PLY (default: 768)")
    parser.add_argument("--xyz-scale", type=float, default=0.001, help="XYZ scaling factor (default: 0.001)")
    parser.add_argument("--device", type=str, default="auto", help="auto|cpu|cuda")
    parser.add_argument("--output-json", type=str, default=None, help="Optional path to save results JSON")
    parser.add_argument(
        "--labels-json",
        type=str,
        default="/home/ubuntu/label20_weights.json",
        help="Optional label mapping JSON (e.g. /home/ubuntu/label20_weights.json). Keys like 'fish12' and values are grams.",
    )
    parser.add_argument(
        "--topk-length", "--topk_length",
        type=int,
        default=None,
        dest="topk_length",
        help="If set (e.g. 10), compute folder average as length-weighted average over the top-K longest PLYs "
        "(accepted for compatibility; currently has no effect on the result)",
    )
    parser.add_argument(
        "--topk-predictions",
        type=int,
        default=10,
        help="When > this many PLYs, remove length outliers (IQR) and average the rest (default: 10) "
        "(accepted for compatibility; currently has no effect on the result)",
    )
    parser.add_argument(
        "--top-k", "--top_k",
        type=int,
        default=10,
        dest="top_k",
        help="Number of top predictions to average (default: 10). Use e.g. --top-k 20 or --top_k=20.",
    )
    parser.add_argument(
        "--print-per-ply",
        action="store_true",
        help="Print per-PLY predicted weight (single folder mode default prints per-PLY regardless; batch mode uses this flag)",
    )
    # Outlier removal arguments
    parser.add_argument(
        "--remove-outliers",
        action="store_true",
        help="Remove outliers before computing average (improves robustness)",
    )
    parser.add_argument(
        "--outlier-method",
        type=str,
        default="iqr",
        choices=["iqr", "zscore"],
        help="Outlier detection method: 'iqr' (default) or 'zscore'",
    )
    parser.add_argument(
        "--outlier-field",
        type=str,
        default="length_input",
        choices=["length_input", "predicted_weight_g"],
        help="Field to use for outlier detection (default: length_input)",
    )
    parser.add_argument(
        "--iqr-factor",
        type=float,
        default=1.5,
        help="IQR multiplier for outlier bounds (default: 1.5). Lower = more aggressive filtering.",
    )
    parser.add_argument(
        "--zscore-threshold",
        type=float,
        default=2.5,
        help="Z-score threshold for outlier detection (default: 2.5). Lower = more aggressive filtering.",
    )
    parser.add_argument(
        "--max-cv",
        type=float,
        default=None,
        help="Maximum coefficient of variation (CV%%) for length. Skip folders with CV > this value (e.g. 15.0). "
        "Lower CV means more consistent point clouds. If not set, no CV filtering is applied.",
    )
    return parser


def _cli_print_per_ply(per_file: List[Dict], xyz_scale: float, prefix: str = "") -> None:
    """Print one line per PLY with its length, predicted weight and kept/filtered tag."""
    # Lengths are labelled "mm" only for the default mm -> m scale.
    scale_is_mm = abs(float(xyz_scale) - 0.001) < 1e-12
    for it in per_file:
        ply = Path(it["ply"]).name
        g = float(it["predicted_weight_g"])
        kg = float(it["predicted_weight_kg"])
        length_input = float(it.get("length_input", float("nan")))
        if not np.isfinite(length_input):
            length_str = "nan"
        elif scale_is_mm:
            length_str = f"{length_input:.1f} mm"
        else:
            length_str = f"{length_input:.4f} units"
        tag = " (kept)" if it.get("used_for_prediction", True) else " (filtered)"
        print(f"{prefix}{ply}: len={length_str} | {g:.2f} g ({kg:.4f} kg){tag}")


def _cli_write_json(out: Dict, output_json, top_k: int) -> None:
    """Write ``out`` as JSON to ``<stem>_top<k><suffix>`` next to the requested path."""
    out_path = Path(output_json).expanduser().resolve()
    out_path = out_path.parent / f"{out_path.stem}_top{top_k}{out_path.suffix}"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"Saved: {out_path}")


def _cli_run_single_folder(
    args: argparse.Namespace,
    model: torch.nn.Module,
    device: torch.device,
    ckpt_path: Path,
    labels_path: Optional[Path],
    labels: Dict[str, float],
) -> None:
    """Single-folder mode: predict one folder, print the report, optionally save JSON."""
    ply_folder = Path(args.ply_folder).expanduser().resolve()
    if not ply_folder.exists() or not ply_folder.is_dir():
        raise SystemExit(f"ply-folder not found/dir: {ply_folder}")
    per_file, summary = predict_cloud_folder(
        model=model,
        cloud_dir=ply_folder,
        device=device,
        num_points=args.num_points,
        xyz_scale=args.xyz_scale,
        recursive=args.recursive,
        topk_length=args.topk_length,
        topk_predictions=args.topk_predictions,
        top_k=args.top_k,
        remove_outliers=args.remove_outliers,
        outlier_method=args.outlier_method,
        outlier_field=args.outlier_field,
        iqr_factor=args.iqr_factor,
        zscore_threshold=args.zscore_threshold,
    )
    comparison = None
    if summary.get("skipped"):
        print(f"Skipped: {summary.get('skip_reason', 'fewer than 5 PLYs')} (files={summary.get('num_files', 0)})")
    else:
        # Print per-file predictions (default behavior)
        _cli_print_per_ply(per_file, args.xyz_scale)
        # Print summary
        print(f"Files: {summary.get('num_files_predicted', summary['num_files'])}")
        if args.remove_outliers and summary.get('num_outliers_removed', 0) > 0:
            print(f"Outliers removed: {summary['num_outliers_removed']} (method={args.outlier_method}, field={args.outlier_field})")
        if summary.get('cv_length_pct') is not None:
            print(f"Length CV: {summary['cv_length_pct']:.1f}%")
        print(
            f"Average predicted weight (top {args.top_k}): "
            f"{summary['avg_predicted_weight_g']:.2f} g ({summary['avg_predicted_weight_kg']:.4f} kg) "
            f"(used={summary.get('num_files_used_for_avg', 'n/a')}/{summary.get('num_files_predicted', 'n/a')})"
        )
        # Compare with labels (if matchable)
        fish_key = extract_fish_key_from_text(str(ply_folder))
        if fish_key and fish_key in labels:
            actual_g = float(labels[fish_key])
            comparison = compare_pred_vs_actual(pred_g=float(summary["avg_predicted_weight_g"]), actual_g=actual_g)
            print(
                f"Actual weight ({fish_key}): {actual_g:.2f} g | "
                f"Diff: {comparison['diff_g']:.2f} g | "
                f"Diff%: {comparison['diff_pct']:.2f}% (abs {comparison['abs_diff_pct']:.2f}%)"
            )
        elif labels:
            print(f"[warn] No label match for folder: {ply_folder} (parsed key={fish_key})")
    out = {
        "meta": {
            "checkpoint": str(ckpt_path),
            "ply_folder": str(ply_folder),
            "recursive": bool(args.recursive),
            "num_points": int(args.num_points),
            "xyz_scale": float(args.xyz_scale),
            "device": str(device),
            "labels_json": str(labels_path) if labels_path is not None else None,
        },
        "summary": summary,
        "comparison": comparison,
        "per_file": per_file,
    }
    if args.output_json:
        _cli_write_json(out, args.output_json, args.top_k)


def _cli_run_batch_by_fish(
    args: argparse.Namespace,
    model: torch.nn.Module,
    device: torch.device,
    batch_root: Path,
    labels: Dict[str, float],
) -> Tuple[List[Dict], int]:
    """Batch mode, combined: pool all PLYs of each fish and predict once per fish."""
    by_fish = group_cloud_folders_by_fish(batch_root)
    if not by_fish:
        raise SystemExit(f"No cloud folders found under: {batch_root}")
    results: List[Dict] = []
    for fish_key, cloud_dirs in sorted(by_fish.items()):
        ply_files: List[Path] = []
        for cloud_dir in cloud_dirs:
            ply_files.extend(cloud_dir.glob("*.ply"))
        ply_files = sorted(set(ply_files))
        try:
            per_file, summary = _predict_folder_impl(
                model=model,
                ply_files=ply_files,
                device=device,
                num_points=args.num_points,
                xyz_scale=args.xyz_scale,
                top_k=args.top_k,
                remove_outliers=args.remove_outliers,
                outlier_method=args.outlier_method,
                outlier_field=args.outlier_field,
                iqr_factor=args.iqr_factor,
                zscore_threshold=args.zscore_threshold,
            )
        except Exception as e:
            # One bad fish must not abort the batch, but the failure has to be visible.
            print(f"[error] {fish_key}: {e}", file=sys.stderr)
            results.append({"id": fish_key, "fish_key": fish_key, "error": str(e)})
            continue
        if summary.get("skipped"):
            print(f"{fish_key}: SKIPPED ({summary.get('skip_reason', '<5 PLYs')})")
            results.append({
                "id": fish_key,
                "fish_key": fish_key,
                "skipped": True,
                "skip_reason": summary.get("skip_reason", "fewer than 5 PLYs"),
                "num_files": summary.get("num_files", 0),
                "cloud_folders": len(cloud_dirs),
            })
            continue
        num_used = summary.get("num_files_used_for_avg", "n/a")
        num_total = summary.get("num_files_predicted", "n/a")
        used_info = f"used={num_used}/{num_total} (from {len(cloud_dirs)} folders)"
        comparison = None
        if fish_key in labels:
            actual_g = float(labels[fish_key])
            comparison = compare_pred_vs_actual(pred_g=float(summary["avg_predicted_weight_g"]), actual_g=actual_g)
            print(f"{fish_key}: avg(top{args.top_k})={summary['avg_predicted_weight_g']:.2f} g ({used_info}) | actual={actual_g:.2f} g | diff%={comparison['diff_pct']:.2f}%")
        else:
            print(f"{fish_key}: avg(top{args.top_k})={summary['avg_predicted_weight_g']:.2f} g ({used_info})")
        results.append({
            "id": fish_key,
            "fish_key": fish_key,
            "cloud_folders": len(cloud_dirs),
            "summary": summary,
            "comparison": comparison,
            "per_file": per_file,
        })
    # Every cloud folder is filed under exactly one fish key.
    return results, sum(len(dirs) for dirs in by_fish.values())


def _cli_run_batch_per_folder(
    args: argparse.Namespace,
    model: torch.nn.Module,
    device: torch.device,
    batch_root: Path,
    labels: Dict[str, float],
) -> Tuple[List[Dict], int]:
    """Batch mode, per folder: predict every cloud folder on its own."""
    cloud_folders = collect_cloud_folders(batch_root)
    if not cloud_folders:
        raise SystemExit(f"No cloud folders found under: {batch_root}")
    results: List[Dict] = []
    for cloud_dir in cloud_folders:
        rel = (
            str(cloud_dir.parent.relative_to(batch_root))
            if cloud_dir.parent.is_relative_to(batch_root)
            else str(cloud_dir.parent)
        )
        try:
            per_file, summary = predict_cloud_folder(
                model=model,
                cloud_dir=cloud_dir,
                device=device,
                num_points=args.num_points,
                xyz_scale=args.xyz_scale,
                recursive=False,
                topk_length=args.topk_length,
                topk_predictions=args.topk_predictions,
                top_k=args.top_k,
                remove_outliers=args.remove_outliers,
                outlier_method=args.outlier_method,
                outlier_field=args.outlier_field,
                iqr_factor=args.iqr_factor,
                zscore_threshold=args.zscore_threshold,
            )
        except Exception as e:
            # One bad folder must not abort the batch, but the failure has to be visible.
            print(f"[error] {rel}: {e}", file=sys.stderr)
            results.append({"id": rel, "cloud_dir": str(cloud_dir), "error": str(e)})
            continue
        # Skip folders with too few PLYs
        if summary.get("skipped"):
            print(f"{rel}: SKIPPED ({summary.get('skip_reason', '<5 PLYs')})")
            results.append({
                "id": rel,
                "cloud_dir": str(cloud_dir),
                "skipped": True,
                "skip_reason": summary.get("skip_reason", "fewer than 5 PLYs"),
                "num_files": summary.get("num_files", 0),
            })
            continue
        # Check CV filtering - skip this folder if CV exceeds threshold
        cv_pct = summary.get('cv_length_pct', None)
        if args.max_cv is not None and cv_pct is not None and np.isfinite(cv_pct) and cv_pct > args.max_cv:
            print(f"{rel}: SKIPPED (cv={cv_pct:.1f}% > max_cv={args.max_cv}%)")
            results.append({
                "id": rel,
                "cloud_dir": str(cloud_dir),
                "skipped": True,
                "skip_reason": f"cv_length={cv_pct:.1f}% exceeds max_cv={args.max_cv}%",
                "cv_length_pct": cv_pct,
            })
            continue
        if args.print_per_ply:
            _cli_print_per_ply(per_file, args.xyz_scale, prefix=f"{rel}/")
        # Compare with labels if possible
        fish_key = extract_fish_key_from_text(rel) or extract_fish_key_from_text(str(cloud_dir))
        comparison = None
        # Build used_info string with outlier info
        num_after_outlier = summary.get('num_files_after_outlier_removal', summary.get('num_files_predicted', 'n/a'))
        num_used = summary.get('num_files_used_for_avg', 'n/a')
        num_total = summary.get('num_files_predicted', 'n/a')
        num_outliers = summary.get('num_outliers_removed', 0)
        if args.remove_outliers and num_outliers > 0:
            used_info = f"used={num_used}/{num_after_outlier}, outliers={num_outliers}"
        else:
            used_info = f"used={num_used}/{num_total}"
        # Add CV info if available and high
        if cv_pct is not None and cv_pct > 10.0:
            used_info += f", cv={cv_pct:.1f}%"
        if fish_key and fish_key in labels:
            actual_g = float(labels[fish_key])
            comparison = compare_pred_vs_actual(pred_g=float(summary["avg_predicted_weight_g"]), actual_g=actual_g)
            avg_info = f"avg(top{args.top_k})={summary['avg_predicted_weight_g']:.2f} g"
            print(f"{rel}: {avg_info} ({used_info}) | actual={actual_g:.2f} g | diff%={comparison['diff_pct']:.2f}%")
        else:
            print(f"{rel}: avg(top{args.top_k})={summary['avg_predicted_weight_g']:.2f} g ({used_info})")
        results.append(
            {
                "id": rel,
                "cloud_dir": str(cloud_dir),
                "summary": summary,
                "fish_key": fish_key,
                "comparison": comparison,
            }
        )
    return results, len(cloud_folders)


def main() -> None:
    """
    Command-line entry point.

    Loads the checkpoint, then runs either single-folder mode (--ply-folder) or
    batch mode (--batch-root). Batch mode predicts once per fish by default, or
    once per cloud folder with --no-combine-by-fish, and always writes a JSON
    report; single-folder mode writes one only when --output-json is given.

    Raises:
        SystemExit: If the checkpoint, the PLY folder or the batch root is
            missing, or no cloud folders are found in batch mode.
    """
    args = _cli_build_parser().parse_args()
    ckpt_path = Path(args.checkpoint).expanduser().resolve()
    if not ckpt_path.exists():
        raise SystemExit(f"checkpoint not found: {ckpt_path}")
    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    # Load model (SSG or MSG based on checkpoint config)
    from weight_estimator.train_pointnet_weigth_estimator import load_model_from_checkpoint  # noqa: E402
    model = load_model_from_checkpoint(ckpt_path, device=device)
    labels_path = Path(args.labels_json).expanduser().resolve() if args.labels_json else None
    labels = load_label_weights_json(labels_path) if labels_path is not None else {}

    if args.ply_folder:
        _cli_run_single_folder(args, model, device, ckpt_path, labels_path, labels)
        return

    batch_root = Path(args.batch_root).expanduser().resolve()
    # Fail with a clean message (like --ply-folder) instead of a traceback.
    if not batch_root.is_dir():
        raise SystemExit(f"batch-root not found/dir: {batch_root}")
    if args.combine_by_fish:
        results, num_cloud_folders = _cli_run_batch_by_fish(args, model, device, batch_root, labels)
    else:
        results, num_cloud_folders = _cli_run_batch_per_folder(args, model, device, batch_root, labels)
    out = {
        "meta": {
            "checkpoint": str(ckpt_path),
            "batch_root": str(batch_root),
            "num_points": int(args.num_points),
            "xyz_scale": float(args.xyz_scale),
            "device": str(device),
            "combine_by_fish": args.combine_by_fish,
            "top_k": args.top_k,
            "cloud_folders": num_cloud_folders,
            "labels_json": str(labels_path) if labels_path is not None else None,
        },
        "results": results,
    }
    _cli_write_json(out, args.output_json or (batch_root / "batch_weight_predictions.json"), args.top_k)


if __name__ == "__main__":
    main()