Files
FishServer/FishMeasure/weight_estimator/test_dgcnn_weight_estimator.py
2026-04-08 19:32:23 +08:00

1113 lines
45 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Run DGCNN weight regression on a folder of PLY point clouds.
For each PLY:
- load points
- scale XYZ by 0.001
- center to origin (subtract centroid)
- sample fixed N points (default 768)
- predict weight (kg)
Final prediction = mean of the predictions of the top-K selected PLYs (max instead of mean when fewer
than 5 PLYs remain after optional outlier removal). Top-K is selected by length by default;
if the mean length of that top-K exceeds --length-switch-mm (default 319), top-K is taken by predicted weight instead.
With --no-top-by-length, selection is always by predicted weight.
Uses the same preprocessing and aggregation logic as test_pointnet_weight_estimator.py.
DGCNN accepts (B, N, 3) and transposes internally.
Example:
python weight_estimator/test_dgcnn_weight_estimator.py \\
--checkpoint weight_estimator/runs/dgcnn_20260312_143020/best.pt \\
--ply-folder /path/to/xxxx/cloud
"""
from __future__ import annotations
import argparse
import json
import re
import sys
import zlib
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import open3d as o3d
import torch
# Directory one level above the folder that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Folder that contains this script.
WEIGHT_EST_DIR = Path(__file__).resolve().parent
# Prepend both directories to sys.path (skipping entries already present).
# NOTE(review): presumably this is what lets main() import
# `weight_estimator.train_dgcnn_weight_estimator` when run as a script — confirm.
for _p in (str(REPO_ROOT), str(WEIGHT_EST_DIR)):
    if _p not in sys.path:
        sys.path.insert(0, _p)
def _format_length_input_display(value: Optional[float], xyz_scale: float) -> str:
    """Render a raw (unscaled) length for console output.

    Returns ``"nan"`` for ``None`` or non-finite values. When ``xyz_scale`` is
    0.001 the value is printed with one decimal and an ``mm`` suffix; for any
    other scale it is printed with four decimals and a generic ``units`` suffix.
    """
    if value is None:
        return "nan"
    length = float(value)
    if not np.isfinite(length):
        return "nan"
    scale_is_default = abs(float(xyz_scale) - 0.001) < 1e-12
    if scale_is_default:
        return f"{length:.1f} mm"
    return f"{length:.4f} units"
def _summary_top_label(summary: Dict[str, Any], top_k: int, fallback_by_length: bool) -> str:
    """Build the human-readable label describing how the top-K PLYs were chosen.

    The summary's ``effective_top_by_length`` entry decides the wording; when it
    is missing (or ``None``) the caller-supplied ``fallback_by_length`` is used.
    """
    by_length = summary.get("effective_top_by_length")
    if by_length is None:
        by_length = fallback_by_length
    if by_length:
        return f"top{top_k} by length"
    return f"top{top_k} by pred"
def load_points_from_ply(ply_path: Path) -> np.ndarray:
    """Read the XYZ coordinates stored in a PLY file as a float32 array.

    The file is first read as a point cloud; if that yields no points it is
    read again as a triangle mesh and the mesh vertices are returned instead.

    Raises:
        ValueError: if neither reading produces any point/vertex.
    """
    source = str(ply_path)

    cloud = o3d.io.read_point_cloud(source)
    if len(cloud.points) > 0:
        return np.asarray(cloud.points, dtype=np.float32)

    # Fallback: the file may describe a mesh rather than a bare point cloud.
    mesh = o3d.io.read_triangle_mesh(source)
    if len(mesh.vertices) == 0:
        raise ValueError(f"No points/vertices found in: {ply_path}")
    return np.asarray(mesh.vertices, dtype=np.float32)
def sample_points_deterministic(points: np.ndarray, num_points: int, seed: int) -> np.ndarray:
    """Pick exactly ``num_points`` rows of ``points`` using a seeded generator.

    Rows are drawn without replacement when the cloud has at least
    ``num_points`` rows, and with replacement otherwise. The same ``seed``
    always yields the same selection.

    Raises:
        ValueError: if ``points`` has no rows.
    """
    total = points.shape[0]
    if total <= 0:
        raise ValueError("Empty point cloud")

    generator = np.random.default_rng(seed)
    # Too few points to fill the sample: allow duplicates.
    oversample = total < num_points
    chosen = generator.choice(total, size=num_points, replace=oversample)
    return points[chosen]
def estimate_length_major_axis(points: np.ndarray) -> float:
    """Estimate the extent of a point cloud along its dominant direction.

    The first three columns are centered, the first right-singular vector of
    the centered matrix is taken as the major axis, and the spread
    (max - min) of the points projected on that axis is returned.

    Returns NaN when ``points`` is ``None``, is not 2-D, has fewer than three
    rows or columns, or when the decomposition raises.
    """
    not_available = float("nan")
    if points is None:
        return not_available
    if points.ndim != 2:
        return not_available
    if points.shape[0] < 3 or points.shape[1] < 3:
        return not_available

    centered = points[:, :3].astype(np.float32, copy=False)
    centered = centered - centered.mean(axis=0, keepdims=True)
    try:
        right_vectors = np.linalg.svd(centered, full_matrices=False)[2]
        along_axis = centered @ right_vectors[0]
        return float(np.max(along_axis) - np.min(along_axis))
    except Exception:
        return not_available
def filter_outliers_iqr(
    per_file: List[Dict],
    field: str = "length_input",
    iqr_factor: float = 1.5,
) -> Tuple[List[Dict], List[Dict], Dict]:
    """Split entries into inliers and outliers using the interquartile-range rule.

    An entry is an outlier when its finite ``field`` value lies outside
    ``[q1 - iqr_factor * iqr, q3 + iqr_factor * iqr]``. Entries whose value is
    missing or non-finite are always kept. With fewer than four finite values
    nothing is filtered and the input list itself is returned.

    Returns:
        ``(kept, outliers, stats)`` where ``stats`` describes the bounds and
        counts (or carries an ``error`` / ``note`` entry for the early exits).
    """
    if not per_file:
        return [], [], {"error": "empty input"}

    def _read(entry: Dict) -> float:
        return float(entry.get(field, float("nan")))

    finite_values = [v for v in map(_read, per_file) if np.isfinite(v)]
    if len(finite_values) < 4:
        return per_file, [], {"note": "too few samples for IQR", "num_samples": len(finite_values)}

    sample = np.array(finite_values, dtype=np.float64)
    q1 = float(np.percentile(sample, 25))
    q3 = float(np.percentile(sample, 75))
    iqr = q3 - q1
    lower_bound = q1 - iqr_factor * iqr
    upper_bound = q3 + iqr_factor * iqr

    filtered: List[Dict] = []
    outliers: List[Dict] = []
    for entry in per_file:
        v = _read(entry)
        # Only finite values can be rejected; NaN/inf entries stay in.
        rejected = np.isfinite(v) and not (lower_bound <= v <= upper_bound)
        if rejected:
            outliers.append(entry)
        else:
            filtered.append(entry)

    stats = {
        "field": field,
        "iqr_factor": float(iqr_factor),
        "q1": q1,
        "q3": q3,
        "iqr": iqr,
        "lower_bound": lower_bound,
        "upper_bound": upper_bound,
        "num_input": len(per_file),
        "num_filtered": len(filtered),
        "num_outliers": len(outliers),
    }
    return filtered, outliers, stats
def filter_outliers_zscore(
    per_file: List[Dict],
    field: str = "length_input",
    zscore_threshold: float = 2.5,
) -> Tuple[List[Dict], List[Dict], Dict]:
    """Split entries into inliers and outliers using a z-score cut-off.

    An entry is an outlier when ``|value - mean| / std`` of its finite
    ``field`` value exceeds ``zscore_threshold`` (mean/std are computed over the
    finite values only). Entries with a missing or non-finite value are always
    kept. With fewer than three finite values, or a (near-)zero standard
    deviation, nothing is filtered and the input list itself is returned.

    Returns:
        ``(kept, outliers, stats)``.
    """
    if not per_file:
        return [], [], {"error": "empty input"}

    def _read(entry: Dict) -> float:
        return float(entry.get(field, float("nan")))

    finite_values = [v for v in map(_read, per_file) if np.isfinite(v)]
    if len(finite_values) < 3:
        return per_file, [], {"note": "too few samples for z-score", "num_samples": len(finite_values)}

    sample = np.array(finite_values, dtype=np.float64)
    mean_val = float(np.mean(sample))
    std_val = float(np.std(sample))
    if std_val < 1e-9:
        return per_file, [], {"note": "zero std", "mean": mean_val}

    filtered: List[Dict] = []
    outliers: List[Dict] = []
    for entry in per_file:
        v = _read(entry)
        if np.isfinite(v) and abs(v - mean_val) / std_val > zscore_threshold:
            outliers.append(entry)
        else:
            filtered.append(entry)

    stats = {
        "field": field,
        "zscore_threshold": float(zscore_threshold),
        "mean": mean_val,
        "std": std_val,
        "num_input": len(per_file),
        "num_filtered": len(filtered),
        "num_outliers": len(outliers),
    }
    return filtered, outliers, stats
@torch.no_grad()
def _predict_folder_impl(
    model: torch.nn.Module,
    ply_files: List[Path],
    device: torch.device,
    num_points: int = 768,
    xyz_scale: float = 0.001,
    topk_length: Optional[int] = None,
    topk_predictions: int = 10,
    top_k: int = 5,
    top_by_length: bool = True,
    length_switch_to_weight_mm: float = 319.0,
    remove_outliers: bool = False,
    outlier_method: str = "iqr",
    outlier_field: str = "length_input",
    iqr_factor: float = 1.5,
    zscore_threshold: float = 2.5,
) -> Tuple[List[Dict], Dict]:
    """
    Predict a weight for every PLY in ``ply_files`` and aggregate them into one estimate.

    Per file: load the points, estimate the major-axis length on the raw
    coordinates, scale by ``xyz_scale``, center on the centroid, sample
    ``num_points`` points with a seed derived from the file path, and run the
    model on a ``(1, N, 3)`` tensor (DGCNN transposes internally).

    Aggregation: optionally drop outliers, pick the top ``top_k`` candidates
    (by length, or by predicted weight when ``top_by_length`` is false or the
    mean length of the top-K-by-length exceeds ``length_switch_to_weight_mm``),
    then take the mean of their predictions, or the max when fewer than five
    candidates remain.

    Args:
        model: regression model; its output is read with ``.item()`` as kilograms.
        ply_files: PLY paths to evaluate; an empty list yields a "skipped" summary.
        device: device the input tensor is moved to.
        num_points: number of points sampled per cloud.
        xyz_scale: factor applied to the coordinates before inference.
        topk_length: accepted for call compatibility; not used by this function.
        topk_predictions: accepted for call compatibility; not used by this function.
        top_k: number of candidates aggregated (values below zero are treated as zero).
        top_by_length: select the top-K by length instead of by predicted weight.
        length_switch_to_weight_mm: length threshold that switches the selection
            to predicted weight.
        remove_outliers: run outlier filtering before the selection.
        outlier_method: ``"iqr"`` or ``"zscore"``; anything else filters nothing
            and records an error in the outlier stats.
        outlier_field: per-file field the outlier filter looks at.
        iqr_factor: multiplier for the IQR filter.
        zscore_threshold: cut-off for the z-score filter.

    Returns:
        ``(per_file, summary)``: one dict per PLY (prediction, length, outlier
        flag, rank, ``used_for_prediction``) and a dict with the aggregated result.

    Raises:
        Whatever loading or sampling a PLY raises (e.g. ``ValueError`` for an
        empty cloud); a single unreadable file aborts the whole call.
    """
    model.eval()
    per_file: List[Dict] = []
    if not ply_files:
        summary_skipped = {
            "num_files": 0,
            "num_files_predicted": 0,
            "skipped": True,
            "skip_reason": "no .ply files",
            "avg_predicted_weight_kg": float("nan"),
            "avg_predicted_weight_g": float("nan"),
            "avg_length_input_topk": float("nan"),
        }
        return [], summary_skipped

    def _rank_value(entry: Dict, key: str) -> float:
        """Sort key for ``entry[key]``; missing/NaN values become -inf so they rank last."""
        value = float(entry.get(key, float("-inf")))
        # NaN compares false with everything and would make the sort order arbitrary.
        return float("-inf") if np.isnan(value) else value

    preds_kg: List[float] = []
    for ply in ply_files:
        pts = load_points_from_ply(ply)
        # Length is measured on the raw coordinates, before scaling.
        length_input = estimate_length_major_axis(pts)
        pts = pts * float(xyz_scale)
        pts = pts - pts.mean(axis=0, keepdims=True)
        # Seed from the path so the same file is always sampled the same way.
        seed = int(zlib.adler32(str(ply).encode("utf-8")) & 0xFFFFFFFF)
        pts = sample_points_deterministic(pts, num_points=num_points, seed=seed)
        # DGCNN expects (B, N, 3) — transposes internally
        x = torch.from_numpy(pts).to(device=device, dtype=torch.float32)  # (N, 3)
        x = x.unsqueeze(0)  # (1, N, 3)
        pred_kg = float(model(x).item())
        pred_g = pred_kg * 1000.0
        preds_kg.append(pred_kg)
        per_file.append({
            "ply": str(ply),
            "predicted_weight_kg": pred_kg,
            "predicted_weight_g": pred_g,
            "length_input": float(length_input),
            "length_after_scale": (
                float(length_input * float(xyz_scale)) if np.isfinite(length_input) else float("nan")
            ),
            "is_outlier": False,
        })

    outlier_stats: Optional[Dict] = None
    candidates_for_avg = per_file
    num_outliers_removed = 0
    if remove_outliers and per_file:
        if outlier_method == "iqr":
            filtered, outliers, outlier_stats = filter_outliers_iqr(
                per_file, field=outlier_field, iqr_factor=iqr_factor
            )
        elif outlier_method == "zscore":
            filtered, outliers, outlier_stats = filter_outliers_zscore(
                per_file, field=outlier_field, zscore_threshold=zscore_threshold
            )
        else:
            filtered, outliers = per_file, []
            outlier_stats = {"error": f"unknown method: {outlier_method}"}
        outlier_plys = {it["ply"] for it in outliers}
        for it in per_file:
            if it["ply"] in outlier_plys:
                it["is_outlier"] = True
        candidates_for_avg = filtered
        num_outliers_removed = len(outliers)

    avg_kg_all = float(np.mean(preds_kg)) if preds_kg else float("nan")
    avg_g_all = avg_kg_all * 1000.0 if preds_kg else float("nan")

    # Clamp so that a negative top_k cannot slice from the wrong end of the list.
    keep_count = max(0, min(int(top_k), len(candidates_for_avg)))

    sorted_by_length = sorted(
        candidates_for_avg,
        key=lambda d: _rank_value(d, "length_input"),
        reverse=True,
    )
    selected_topk_by_length = sorted_by_length[:keep_count]
    lengths_of_topk_by_length = [
        float(it["length_input"])
        for it in selected_topk_by_length
        if np.isfinite(float(it.get("length_input", float("nan"))))
    ]
    avg_length_input_topk_by_length = (
        float(np.mean(lengths_of_topk_by_length)) if lengths_of_topk_by_length else float("nan")
    )

    sorted_by_weight = sorted(
        candidates_for_avg,
        key=lambda d: _rank_value(d, "predicted_weight_g"),
        reverse=True,
    )
    selected_topk_by_weight = sorted_by_weight[:keep_count]

    switched_to_weight_due_to_long_length = False
    if top_by_length:
        switch = (
            np.isfinite(avg_length_input_topk_by_length)
            and avg_length_input_topk_by_length > float(length_switch_to_weight_mm)
        )
        if switch:
            effective_top_by_length = False
            selected_topk = selected_topk_by_weight
            switched_to_weight_due_to_long_length = True
        else:
            effective_top_by_length = True
            selected_topk = selected_topk_by_length
    else:
        effective_top_by_length = False
        selected_topk = selected_topk_by_weight

    preds_g_topk = [float(it["predicted_weight_g"]) for it in selected_topk]
    # With very few candidates the maximum is reported instead of the mean.
    use_max_instead_of_mean = len(candidates_for_avg) < 5
    if preds_g_topk:
        avg_g_topk = (
            float(np.max(preds_g_topk))
            if use_max_instead_of_mean
            else float(np.mean(preds_g_topk))
        )
    else:
        avg_g_topk = float(avg_g_all)
    num_used_for_avg = len(selected_topk)
    plys_used_for_prediction = [Path(it["ply"]).name for it in selected_topk]

    lengths_topk = [
        float(it["length_input"])
        for it in selected_topk
        if np.isfinite(float(it.get("length_input", float("nan"))))
    ]
    avg_length_input_topk = float(np.mean(lengths_topk)) if lengths_topk else float("nan")

    lengths = [
        float(it["length_input"])
        for it in candidates_for_avg
        if np.isfinite(float(it.get("length_input", float("nan"))))
    ]
    avg_len_input = float(np.mean(lengths)) if lengths else float("nan")
    std_len_input = float(np.std(lengths)) if len(lengths) > 1 else 0.0
    cv_length = float(std_len_input / avg_len_input * 100.0) if avg_len_input > 0 else float("nan")

    avg_g = float(avg_g_topk) if np.isfinite(float(avg_g_topk)) else float(avg_g_all)
    avg_kg = avg_g / 1000.0

    # Match on the full path: files from different folders may share a basename,
    # and matching by name would flag unselected files as used.
    kept_ply_paths = {it["ply"] for it in selected_topk}

    # rank 1..N for all candidates, using the ordering that was actually applied
    full_sorted = sorted_by_length if effective_top_by_length else sorted_by_weight
    rank_by_ply = {it["ply"]: (i + 1) for i, it in enumerate(full_sorted)}
    for it in per_file:
        it["used_for_prediction"] = it["ply"] in kept_ply_paths
        it["rank_by_selection"] = rank_by_ply.get(it["ply"], 0)  # 0 = removed as outlier
        it["top_by_length"] = effective_top_by_length

    summary = {
        "num_files": len(ply_files),
        "num_files_predicted": len(per_file),
        "num_outliers_removed": num_outliers_removed,
        "num_files_after_outlier_removal": len(candidates_for_avg),
        "num_files_used_for_avg": int(num_used_for_avg),
        "prediction_aggregate": "max" if use_max_instead_of_mean else "mean",
        "top_k": top_k,
        "top_by_length": top_by_length,
        "effective_top_by_length": effective_top_by_length,
        "switched_to_weight_due_to_long_length": switched_to_weight_due_to_long_length,
        "length_switch_threshold_mm": float(length_switch_to_weight_mm),
        "avg_length_input_topk_by_length": avg_length_input_topk_by_length
        if top_by_length
        else float("nan"),
        "plys_used_for_prediction": plys_used_for_prediction,
        "avg_predicted_weight_kg": avg_kg,
        "avg_predicted_weight_g": avg_g,
        "avg_predicted_weight_kg_all": avg_kg_all,
        "avg_predicted_weight_g_all": avg_g_all,
        "avg_length_input_topk": avg_length_input_topk,
        "avg_length_input": avg_len_input,
        "std_length_input": std_len_input,
        "cv_length_pct": cv_length,
        "outlier_removal": {
            "enabled": remove_outliers,
            "method": outlier_method if remove_outliers else None,
            "field": outlier_field if remove_outliers else None,
            "stats": outlier_stats,
        } if remove_outliers else None,
    }
    return per_file, summary
def predict_folder(
    model: torch.nn.Module,
    ply_files: List[Path],
    device: torch.device,
    num_points: int = 768,
    xyz_scale: float = 0.001,
) -> Tuple[List[Dict], Dict]:
    """Predict weights for ``ply_files`` with the default selection settings.

    Convenience wrapper around ``_predict_folder_impl`` that forwards only the
    model, files, device, sampling size and scale, and leaves every
    aggregation option at its default.

    Returns:
        The ``(per_file, summary)`` pair produced by ``_predict_folder_impl``.
    """
    forwarded = {
        "model": model,
        "ply_files": ply_files,
        "device": device,
        "num_points": num_points,
        "xyz_scale": xyz_scale,
        "topk_length": None,
    }
    return _predict_folder_impl(**forwarded)
def collect_cloud_folders(batch_root: Path) -> List[Path]:
    """Find every directory named ``cloud`` under ``batch_root`` that holds PLY files.

    ``batch_root`` is expanded and resolved first; the search is recursive and
    only directories containing at least one ``*.ply`` directly inside them are
    returned, in sorted order.

    Raises:
        ValueError: if ``batch_root`` does not exist or is not a directory.
    """
    batch_root = batch_root.expanduser().resolve()
    if not (batch_root.exists() and batch_root.is_dir()):
        raise ValueError(f"batch_root not found/dir: {batch_root}")

    found = [
        candidate
        for candidate in batch_root.rglob("cloud")
        if candidate.is_dir() and any(candidate.glob("*.ply"))
    ]
    found.sort()
    return found
def group_cloud_folders_by_fish(batch_root: Path) -> Dict[str, List[Path]]:
    """Group the ``cloud`` folders found under ``batch_root`` by fish identifier.

    The fish key is parsed from the folder's parent path relative to
    ``batch_root`` and, failing that, from the full folder path. Folders
    without a recognisable fish key are grouped under their relative parent
    path with ``/`` replaced by ``_``, or under ``"unknown"`` when the cloud
    folder sits directly inside ``batch_root``.

    Args:
        batch_root: directory to scan; it is expanded and resolved here so it
            compares correctly with the resolved paths returned by the scan.

    Returns:
        Mapping of group key to the cloud folders belonging to it.

    Raises:
        ValueError: propagated from ``collect_cloud_folders`` when
            ``batch_root`` is missing or not a directory.
    """
    # The scan yields resolved paths; resolve the root too, otherwise a relative
    # or symlinked root never matches and keys degrade to absolute paths.
    batch_root = batch_root.expanduser().resolve()

    by_fish: Dict[str, List[Path]] = {}
    for cloud_dir in collect_cloud_folders(batch_root):
        parent = cloud_dir.parent
        try:
            rel = str(parent.relative_to(batch_root))
        except ValueError:
            rel = str(parent)

        key = extract_fish_key_from_text(rel) or extract_fish_key_from_text(str(cloud_dir))
        if not key:
            flattened = rel.replace("/", "_")
            # A cloud folder directly under the root has the relative parent ".".
            key = flattened if flattened not in ("", ".") else "unknown"
        by_fish.setdefault(key, []).append(cloud_dir)
    return by_fish
def predict_cloud_folder(
    model: torch.nn.Module,
    cloud_dir: Path,
    device: torch.device,
    num_points: int,
    xyz_scale: float,
    recursive: bool = False,
    topk_length: Optional[int] = None,
    topk_predictions: int = 10,
    top_k: int = 5,
    top_by_length: bool = True,
    length_switch_to_weight_mm: float = 319.0,
    remove_outliers: bool = False,
    outlier_method: str = "iqr",
    outlier_field: str = "length_input",
    iqr_factor: float = 1.5,
    zscore_threshold: float = 2.5,
    ply_files: Optional[List[Path]] = None,
) -> Tuple[List[Dict], Dict]:
    """Resolve the PLY files to evaluate and delegate to ``_predict_folder_impl``.

    When ``ply_files`` is given, the paths are expanded, resolved, de-duplicated
    and sorted, and each must be an existing file. Otherwise the ``*.ply`` files
    of ``cloud_dir`` are used, searched recursively when ``recursive`` is true.
    All remaining arguments are forwarded unchanged.

    Raises:
        FileNotFoundError: if an explicitly listed PLY is not a file.
    """
    if ply_files is None:
        finder = cloud_dir.rglob if recursive else cloud_dir.glob
        ply_files = sorted(finder("*.ply"))
    else:
        ply_files = sorted({Path(item).expanduser().resolve() for item in ply_files})
        for candidate in ply_files:
            if not candidate.is_file():
                raise FileNotFoundError(f"PLY not found: {candidate}")

    selection_options = {
        "topk_length": topk_length,
        "topk_predictions": topk_predictions,
        "top_k": top_k,
        "top_by_length": top_by_length,
        "length_switch_to_weight_mm": length_switch_to_weight_mm,
    }
    outlier_options = {
        "remove_outliers": remove_outliers,
        "outlier_method": outlier_method,
        "outlier_field": outlier_field,
        "iqr_factor": iqr_factor,
        "zscore_threshold": zscore_threshold,
    }
    return _predict_folder_impl(
        model=model,
        ply_files=ply_files,
        device=device,
        num_points=num_points,
        xyz_scale=xyz_scale,
        **selection_options,
        **outlier_options,
    )
def load_label_weights_json(labels_path: Path) -> Dict[str, float]:
    """Load a ``key -> weight`` mapping from a JSON file.

    Each value may be anything ``float()`` accepts, or a dict carrying a
    ``"weight"`` entry. Keys are stringified and stripped; entries whose weight
    cannot be converted are skipped. A missing file or a JSON document that is
    not an object yields an empty mapping.
    """
    labels_path = labels_path.expanduser().resolve()
    if not labels_path.exists():
        return {}
    data = json.loads(labels_path.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        return {}

    def _to_weight(raw: Any) -> Optional[float]:
        """Return the weight encoded by ``raw`` or ``None`` when it has none."""
        try:
            return float(raw)
        except (TypeError, ValueError):
            pass
        if isinstance(raw, dict) and "weight" in raw:
            try:
                return float(raw["weight"])
            except (TypeError, ValueError):
                return None
        return None

    weights: Dict[str, float] = {}
    for name, raw in data.items():
        weight = _to_weight(raw)
        if weight is not None:
            weights[str(name).strip()] = weight
    return weights
def load_label_lengths_json(labels_path: Optional[Path] = None) -> Dict[str, float]:
    """Load ``fish_id -> length_mm`` from a labels JSON whose values are dicts.

    Only entries whose value is a dict with a ``"length_mm"`` item are read;
    everything else (including unconvertible lengths) is ignored. ``None``, a
    missing file or a non-object JSON document yields an empty mapping.
    """
    if labels_path is None:
        return {}
    labels_path = labels_path.expanduser().resolve()
    if not labels_path.exists():
        return {}
    data = json.loads(labels_path.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        return {}

    lengths: Dict[str, float] = {}
    for name, entry in data.items():
        if not (isinstance(entry, dict) and "length_mm" in entry):
            continue
        try:
            lengths[str(name).strip()] = float(entry["length_mm"])
        except (TypeError, ValueError):
            continue
    return lengths
def load_label_lengths_from_file(path: Path) -> Dict[str, float]:
    """Load ``fish_id -> length_mm`` from a JSON file.

    Values may be anything ``float()`` accepts, or a dict with a
    ``"length_mm"`` item. Keys are stringified and stripped; entries that
    cannot be converted are skipped. A missing file or a non-object JSON
    document yields an empty mapping.
    """
    path = path.expanduser().resolve()
    if not path.exists():
        return {}
    data = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        return {}

    def _to_length(raw: Any) -> Optional[float]:
        """Return the length encoded by ``raw`` or ``None`` when it has none."""
        try:
            return float(raw)
        except (TypeError, ValueError):
            pass
        if isinstance(raw, dict) and "length_mm" in raw:
            try:
                return float(raw["length_mm"])
            except (TypeError, ValueError):
                return None
        return None

    lengths: Dict[str, float] = {}
    for name, raw in data.items():
        fish_id = str(name).strip()
        length = _to_length(raw)
        if length is not None:
            lengths[fish_id] = length
    return lengths
def extract_fish_key_from_text(text: str) -> Optional[str]:
    """Parse a normalized fish identifier such as ``"fish12"`` out of ``text``.

    A delimited pattern (``fish`` not preceded by a letter/digit, optional
    ``_``/``-``/space separators, up to four digits) is tried first; a looser
    ``fish<digits>`` pattern is the fallback. Matching is case-insensitive and
    leading zeros are dropped. Returns ``None`` when nothing matches or the
    number is not positive.
    """
    found = re.search(
        r"(?:^|[^a-z0-9])fish\s*[_\- ]*\s*0*([0-9]{1,4})(?:[^0-9]|$)", text, flags=re.IGNORECASE
    ) or re.search(r"fish\s*0*([0-9]{1,4})", text, flags=re.IGNORECASE)
    if found is None:
        return None
    try:
        number = int(found.group(1))
    except Exception:
        return None
    return f"fish{number}" if number > 0 else None
def compare_pred_vs_actual(pred_g: float, actual_g: float) -> Dict[str, float]:
    """Compare a predicted weight with the labelled weight (both in grams).

    Returns a dict with both inputs plus the signed difference, the signed
    percentage difference relative to ``actual_g`` and its absolute value.
    The three difference fields are NaN when either input is non-finite or
    ``actual_g`` is zero.
    """
    comparable = np.isfinite(pred_g) and np.isfinite(actual_g) and actual_g != 0.0
    if comparable:
        diff_g = float(pred_g - actual_g)
        diff_pct = float(diff_g / actual_g * 100.0)
        abs_diff_pct = abs(diff_pct)
    else:
        diff_g = float("nan")
        diff_pct = float("nan")
        abs_diff_pct = float("nan")
    return {
        "predicted_weight_g": float(pred_g),
        "actual_weight_g": float(actual_g),
        "diff_g": diff_g,
        "diff_pct": diff_pct,
        "abs_diff_pct": abs_diff_pct,
    }
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser (options and defaults of the inference script)."""
    parser = argparse.ArgumentParser("DGCNN weight estimator (folder inference)")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to best.pt/last.pt checkpoint")
    # Exactly one input source must be given.
    src = parser.add_mutually_exclusive_group(required=True)
    src.add_argument("--ply-folder", type=str, default="", help="Folder containing .ply files")
    src.add_argument(
        "--ply-list-file",
        type=str,
        default="",
        help="Text file: one absolute or relative .ply path per line (combine clouds from multiple runs)",
    )
    src.add_argument(
        "--batch-root",
        type=str,
        default="",
        help="Batch mode: recursively find */cloud folders under this root and predict each",
    )
    parser.add_argument(
        "--combine-by-fish",
        action="store_true",
        default=True,
        help="Batch mode: combine all folders per fish (default: True)",
    )
    parser.add_argument(
        "--no-combine-by-fish",
        action="store_false",
        dest="combine_by_fish",
        help="Batch mode: predict each folder separately",
    )
    parser.add_argument("--recursive", action="store_true", help="Recursively search for .ply files (single folder mode)")
    parser.add_argument("--num-points", type=int, default=768, help="Number of points per PLY (default: 768)")
    parser.add_argument("--xyz-scale", type=float, default=0.001, help="XYZ scaling factor (default: 0.001)")
    parser.add_argument("--device", type=str, default="auto", help="auto|cpu|cuda")
    parser.add_argument("--output-json", type=str, default=None, help="Optional path to save results JSON")
    parser.add_argument(
        "--labels-json",
        type=str,
        default="/home/ubuntu/label20_weights.json",
        help="Optional label mapping JSON. Keys like 'fish12', values in grams.",
    )
    parser.add_argument(
        "--labels-length-json",
        type=str,
        default=None,
        help="Optional JSON with fish_id -> length_mm. Or use labels-json with dict values: {\"weight\": N, \"length_mm\": M}.",
    )
    parser.add_argument(
        "--topk-length", "--topk_length",
        type=int,
        default=None,
        dest="topk_length",
        help="If set, compute folder average as length-weighted average over top-K longest PLYs",
    )
    parser.add_argument(
        "--topk-predictions",
        type=int,
        default=10,
        help="When > this many PLYs, remove length outliers and average the rest (default: 10)",
    )
    parser.add_argument(
        "--top-k", "--top_k",
        type=int,
        default=10,
        dest="top_k",
        help="Number of top predictions to average (default: 10)",
    )
    parser.add_argument(
        "--top-by-length",
        action=argparse.BooleanOptionalAction,
        default=True,
        dest="top_by_length",
        help="Select top-K by cloud length (default: on). Use --no-top-by-length for predicted-weight order.",
    )
    parser.add_argument(
        "--length-switch-mm",
        type=float,
        default=319.0,
        dest="length_switch_mm",
        help="With --top-by-length: if mean top-K length (mm) exceeds this, use top-K by predicted weight instead (default: 319).",
    )
    parser.add_argument("--print-per-ply", action="store_true", help="Print per-PLY predicted weight (batch mode)")
    parser.add_argument("--remove-outliers", action="store_true", help="Remove outliers before computing average")
    parser.add_argument("--outlier-method", type=str, default="iqr", choices=["iqr", "zscore"])
    parser.add_argument(
        "--outlier-field",
        type=str,
        default="length_input",
        choices=["length_input", "predicted_weight_g"],
    )
    parser.add_argument("--iqr-factor", type=float, default=1.5)
    parser.add_argument("--zscore-threshold", type=float, default=2.5)
    parser.add_argument(
        "--max-cv",
        type=float,
        default=None,
        help="Maximum CV%% for length. Skip folders with CV > this value (e.g. 15.0).",
    )
    return parser


def _suffixed_output_path(raw_path: Path, top_k: int, top_by_length: bool) -> Path:
    """Resolve ``raw_path`` and insert the ``_top<K>[_by_length]`` marker before its extension."""
    out_path = Path(raw_path).expanduser().resolve()
    suffix = f"_top{top_k}_by_length" if top_by_length else f"_top{top_k}"
    return out_path.parent / f"{out_path.stem}{suffix}{out_path.suffix}"


def _write_json_report(payload: Dict[str, Any], out_path: Path) -> None:
    """Write ``payload`` as indented UTF-8 JSON to ``out_path`` (creating parent folders) and report it."""
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"Saved: {out_path}")


def main() -> None:
    """Command-line entry point: run DGCNN weight inference and report the results.

    Three input modes are supported (exactly one is required):

    * ``--ply-folder``: predict a single folder of PLY files.
    * ``--ply-list-file``: predict the PLY paths listed in a text file.
    * ``--batch-root``: find every ``cloud`` folder under a root and predict
      either per fish (default) or per folder (``--no-combine-by-fish``).

    Results are printed and, for the single-source modes only when
    ``--output-json`` is given, written as JSON; batch mode always writes a
    report (next to the batch root by default). The output file name always
    gets a ``_top<K>[_by_length]`` marker inserted before the extension.

    Raises:
        SystemExit: for a missing checkpoint, a missing/empty PLY list, a
            missing PLY folder, or a batch root without cloud folders.
    """
    args = _build_arg_parser().parse_args()

    ckpt_path = Path(args.checkpoint).expanduser().resolve()
    if not ckpt_path.exists():
        raise SystemExit(f"checkpoint not found: {ckpt_path}")
    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)

    # Imported here rather than at module level, so the training module is only
    # needed when the script is actually run.
    from weight_estimator.train_dgcnn_weight_estimator import load_model_from_checkpoint

    model = load_model_from_checkpoint(ckpt_path, device=device)

    labels_path = Path(args.labels_json).expanduser().resolve() if args.labels_json else None
    labels = load_label_weights_json(labels_path) if labels_path is not None else {}
    # NOTE: the length labels are loaded (so malformed files still fail early)
    # but no report below consumes them yet.
    label_lengths: Dict[str, float] = load_label_lengths_json(labels_path) if labels_path else {}
    if args.labels_length_json:
        length_path = Path(args.labels_length_json).expanduser().resolve()
        label_lengths = {**label_lengths, **load_label_lengths_from_file(length_path)}

    # Options shared by every prediction call.
    predict_options = {
        "num_points": args.num_points,
        "xyz_scale": args.xyz_scale,
        "topk_length": args.topk_length,
        "topk_predictions": args.topk_predictions,
        "top_k": args.top_k,
        "top_by_length": args.top_by_length,
        "length_switch_to_weight_mm": args.length_switch_mm,
        "remove_outliers": args.remove_outliers,
        "outlier_method": args.outlier_method,
        "outlier_field": args.outlier_field,
        "iqr_factor": args.iqr_factor,
        "zscore_threshold": args.zscore_threshold,
    }

    # ------------------------------------------------------------------
    # Single-source modes: --ply-list-file / --ply-folder
    # ------------------------------------------------------------------
    if args.ply_list_file or args.ply_folder:
        ply_list_path_for_meta: Optional[Path] = None
        if args.ply_list_file:
            list_path = Path(args.ply_list_file).expanduser().resolve()
            if not list_path.is_file():
                raise SystemExit(f"ply-list-file not found: {list_path}")
            # Blank lines and lines starting with '#' are ignored.
            lines = [
                ln.strip()
                for ln in list_path.read_text(encoding="utf-8").splitlines()
                if ln.strip() and not ln.strip().startswith("#")
            ]
            ply_paths = [Path(entry).expanduser().resolve() for entry in lines]
            if not ply_paths:
                raise SystemExit("ply-list-file is empty (no PLY paths)")
            for p in ply_paths:
                if not p.is_file():
                    raise SystemExit(f"PLY in list not found: {p}")
            ply_list_path_for_meta = list_path
            # The first file's folder stands in as "the folder" for metadata and label lookup.
            ply_folder = ply_paths[0].parent
            per_file, summary = predict_cloud_folder(
                model=model,
                cloud_dir=ply_folder,
                device=device,
                recursive=False,
                ply_files=ply_paths,
                **predict_options,
            )
        else:
            ply_folder = Path(args.ply_folder).expanduser().resolve()
            if not ply_folder.exists() or not ply_folder.is_dir():
                raise SystemExit(f"ply-folder not found/dir: {ply_folder}")
            per_file, summary = predict_cloud_folder(
                model=model,
                cloud_dir=ply_folder,
                device=device,
                recursive=args.recursive,
                **predict_options,
            )

        meta = {
            "checkpoint": str(ckpt_path),
            "model": "dgcnn",
            "ply_folder": str(ply_folder),
            "ply_list_file": str(ply_list_path_for_meta) if ply_list_path_for_meta else None,
            "recursive": bool(args.recursive),
            "num_points": int(args.num_points),
            "xyz_scale": float(args.xyz_scale),
            "device": str(device),
            "labels_json": str(labels_path) if labels_path else None,
            "length_switch_mm": float(args.length_switch_mm),
        }

        if summary.get("skipped"):
            print(f"Skipped: {summary.get('skip_reason', 'fewer than 5 PLYs')} (files={summary.get('num_files', 0)})")
            if args.output_json:
                out = {"meta": meta, "summary": summary, "comparison": None, "per_file": per_file}
                _write_json_report(out, _suffixed_output_path(args.output_json, args.top_k, args.top_by_length))
            return

        top_label = _summary_top_label(summary, args.top_k, args.top_by_length)
        print(f"[Kept = {top_label}]")
        # Ranks are assigned among the candidates left after outlier removal,
        # so that count (not the number of predicted files) is the denominator.
        n_total = summary.get(
            "num_files_after_outlier_removal", summary.get("num_files_predicted", len(per_file))
        )
        for it in per_file:
            ply = Path(it["ply"]).name
            g = float(it["predicted_weight_g"])
            kg = float(it["predicted_weight_kg"])
            length_input = float(it.get("length_input", float("nan")))
            used = it.get("used_for_prediction", True)
            rank = it.get("rank_by_selection", 0)
            if rank > 0:
                tag = f" (kept, rank {rank}/{n_total})" if used else f" (filtered, rank {rank}/{n_total})"
            else:
                tag = " (kept)" if used else " (filtered)"
            length_str = _format_length_input_display(length_input, args.xyz_scale)
            print(f"{ply}: len={length_str} | {g:.2f} g ({kg:.4f} kg){tag}")

        print(f"Files: {summary.get('num_files_predicted', summary['num_files'])}")
        if args.remove_outliers and summary.get("num_outliers_removed", 0) > 0:
            print(f"Outliers removed: {summary['num_outliers_removed']} (method={args.outlier_method}, field={args.outlier_field})")
        if summary.get("cv_length_pct") is not None:
            print(f"Length CV: {summary['cv_length_pct']:.1f}%")
        print(
            f"Average predicted weight ({top_label}): "
            f"{summary['avg_predicted_weight_g']:.2f} g ({summary['avg_predicted_weight_kg']:.4f} kg) "
            f"(used={summary.get('num_files_used_for_avg', 'n/a')}/{summary.get('num_files_predicted', 'n/a')})"
        )
        avg_len_topk = summary.get("avg_length_input_topk")
        if avg_len_topk is not None and np.isfinite(avg_len_topk):
            print(
                f"Average length ({top_label}): "
                f"{_format_length_input_display(avg_len_topk, args.xyz_scale)}"
            )

        fish_key = extract_fish_key_from_text(str(ply_folder))
        comparison = None
        if fish_key and fish_key in labels:
            actual_g = float(labels[fish_key])
            comparison = compare_pred_vs_actual(pred_g=float(summary["avg_predicted_weight_g"]), actual_g=actual_g)
            print(
                f"Actual weight ({fish_key}): {actual_g:.2f} g | "
                f"Diff: {comparison['diff_g']:.2f} g | "
                f"Diff%: {comparison['diff_pct']:.2f}% (abs {comparison['abs_diff_pct']:.2f}%)"
            )
        elif labels:
            print(f"[warn] No label match for folder: {ply_folder} (parsed key={fish_key})")

        if args.output_json:
            out = {"meta": meta, "summary": summary, "comparison": comparison, "per_file": per_file}
            _write_json_report(out, _suffixed_output_path(args.output_json, args.top_k, args.top_by_length))
        return

    # ------------------------------------------------------------------
    # Batch mode: --batch-root
    # ------------------------------------------------------------------
    batch_root = Path(args.batch_root).expanduser().resolve()
    results: List[Dict] = []

    if args.combine_by_fish:
        by_fish = group_cloud_folders_by_fish(batch_root)
        if not by_fish:
            raise SystemExit(f"No cloud folders found under: {batch_root}")
        # Every cloud folder lands in exactly one group.
        meta_cloud_folders = sum(len(dirs) for dirs in by_fish.values())

        for fish_key, cloud_dirs in sorted(by_fish.items()):
            ply_files = sorted({p for cloud_dir in cloud_dirs for p in cloud_dir.glob("*.ply")})
            try:
                per_file, summary = _predict_folder_impl(
                    model=model,
                    ply_files=ply_files,
                    device=device,
                    **predict_options,
                )
                if summary.get("skipped"):
                    print(f"{fish_key}: SKIPPED ({summary.get('skip_reason', '<5 PLYs')})")
                    results.append({
                        "id": fish_key,
                        "fish_key": fish_key,
                        "skipped": True,
                        "skip_reason": summary.get("skip_reason", "fewer than 5 PLYs"),
                        "num_files": summary.get("num_files", 0),
                        "cloud_folders": len(cloud_dirs),
                    })
                    continue

                num_used = summary.get("num_files_used_for_avg", "n/a")
                num_total = summary.get("num_files_predicted", "n/a")
                used_info = f"used={num_used}/{num_total} (from {len(cloud_dirs)} folders)"
                top_label = _summary_top_label(summary, args.top_k, args.top_by_length)
                avg_len_topk = summary.get("avg_length_input_topk")
                len_part = ""
                if avg_len_topk is not None and np.isfinite(avg_len_topk):
                    len_part = (
                        f" | avg_len(top{args.top_k} selected)="
                        f"{_format_length_input_display(avg_len_topk, args.xyz_scale)}"
                    )

                comparison = None
                if fish_key in labels:
                    actual_g = float(labels[fish_key])
                    comparison = compare_pred_vs_actual(
                        pred_g=float(summary["avg_predicted_weight_g"]), actual_g=actual_g
                    )
                    print(
                        f"{fish_key}: avg({top_label})={summary['avg_predicted_weight_g']:.2f} g ({used_info}){len_part} | "
                        f"actual={actual_g:.2f} g | diff%={comparison['diff_pct']:.2f}%"
                    )
                else:
                    print(
                        f"{fish_key}: avg({top_label})={summary['avg_predicted_weight_g']:.2f} g ({used_info}){len_part}"
                    )
                results.append({
                    "id": fish_key,
                    "fish_key": fish_key,
                    "cloud_folders": len(cloud_dirs),
                    "summary": summary,
                    "comparison": comparison,
                    "per_file": per_file,
                })
            except Exception as e:
                # Best effort: keep going with the other fish, but make the failure visible.
                print(f"{fish_key}: ERROR ({e})", file=sys.stderr)
                results.append({"id": fish_key, "fish_key": fish_key, "error": str(e)})
    else:
        cloud_folders = collect_cloud_folders(batch_root)
        if not cloud_folders:
            raise SystemExit(f"No cloud folders found under: {batch_root}")
        meta_cloud_folders = len(cloud_folders)

        for cloud_dir in cloud_folders:
            rel = (
                str(cloud_dir.parent.relative_to(batch_root))
                if cloud_dir.parent.is_relative_to(batch_root)
                else str(cloud_dir.parent)
            )
            try:
                per_file, summary = predict_cloud_folder(
                    model=model,
                    cloud_dir=cloud_dir,
                    device=device,
                    recursive=False,
                    **predict_options,
                )
                if summary.get("skipped"):
                    print(f"{rel}: SKIPPED ({summary.get('skip_reason', '<5 PLYs')})")
                    results.append({
                        "id": rel,
                        "cloud_dir": str(cloud_dir),
                        "skipped": True,
                        "skip_reason": summary.get("skip_reason", "fewer than 5 PLYs"),
                        "num_files": summary.get("num_files", 0),
                    })
                    continue

                # Optional quality gate on the length spread of the folder.
                cv_pct = summary.get("cv_length_pct", None)
                if args.max_cv is not None and cv_pct is not None and np.isfinite(cv_pct):
                    if cv_pct > args.max_cv:
                        print(f"{rel}: SKIPPED (cv={cv_pct:.1f}% > max_cv={args.max_cv}%)")
                        results.append({
                            "id": rel,
                            "cloud_dir": str(cloud_dir),
                            "skipped": True,
                            "skip_reason": f"cv_length={cv_pct:.1f}% exceeds max_cv={args.max_cv}%",
                            "cv_length_pct": cv_pct,
                        })
                        continue

                if args.print_per_ply:
                    for it in per_file:
                        ply = Path(it["ply"]).name
                        g = float(it["predicted_weight_g"])
                        kg = float(it["predicted_weight_kg"])
                        length_input = float(it.get("length_input", float("nan")))
                        used = it.get("used_for_prediction", True)
                        tag = " (kept)" if used else " (filtered)"
                        length_str = _format_length_input_display(length_input, args.xyz_scale)
                        print(f"{rel}/{ply}: len={length_str} | {g:.2f} g ({kg:.4f} kg){tag}")

                fish_key = extract_fish_key_from_text(rel) or extract_fish_key_from_text(str(cloud_dir))
                comparison = None
                num_after_outlier = summary.get(
                    "num_files_after_outlier_removal", summary.get("num_files_predicted", "n/a")
                )
                num_used = summary.get("num_files_used_for_avg", "n/a")
                num_total = summary.get("num_files_predicted", "n/a")
                num_outliers = summary.get("num_outliers_removed", 0)
                if args.remove_outliers and num_outliers > 0:
                    used_info = f"used={num_used}/{num_after_outlier}, outliers={num_outliers}"
                else:
                    used_info = f"used={num_used}/{num_total}"
                if cv_pct is not None and cv_pct > 10.0:
                    used_info += f", cv={cv_pct:.1f}%"
                top_label = _summary_top_label(summary, args.top_k, args.top_by_length)
                avg_len_topk = summary.get("avg_length_input_topk")
                len_part = ""
                if avg_len_topk is not None and np.isfinite(avg_len_topk):
                    len_part = (
                        f" | avg_len(top{args.top_k} selected)="
                        f"{_format_length_input_display(avg_len_topk, args.xyz_scale)}"
                    )
                if fish_key and fish_key in labels:
                    actual_g = float(labels[fish_key])
                    comparison = compare_pred_vs_actual(
                        pred_g=float(summary["avg_predicted_weight_g"]), actual_g=actual_g
                    )
                    print(
                        f"{rel}: avg({top_label})={summary['avg_predicted_weight_g']:.2f} g ({used_info}){len_part} | "
                        f"actual={actual_g:.2f} g | diff%={comparison['diff_pct']:.2f}%"
                    )
                else:
                    print(
                        f"{rel}: avg({top_label})={summary['avg_predicted_weight_g']:.2f} g ({used_info}){len_part}"
                    )
                results.append({
                    "id": rel,
                    "cloud_dir": str(cloud_dir),
                    "summary": summary,
                    "fish_key": fish_key,
                    "comparison": comparison,
                })
            except Exception as e:
                # Best effort: keep going with the other folders, but make the failure visible.
                print(f"{rel}: ERROR ({e})", file=sys.stderr)
                results.append({"id": rel, "cloud_dir": str(cloud_dir), "error": str(e)})

    out = {
        "meta": {
            "checkpoint": str(ckpt_path),
            "model": "dgcnn",
            "batch_root": str(batch_root),
            "num_points": int(args.num_points),
            "xyz_scale": float(args.xyz_scale),
            "device": str(device),
            "combine_by_fish": args.combine_by_fish,
            "top_k": args.top_k,
            "top_by_length": args.top_by_length,
            "length_switch_mm": float(args.length_switch_mm),
            "cloud_folders": meta_cloud_folders,
            "labels_json": str(labels_path) if labels_path else None,
        },
        "results": results,
    }
    # Batch mode always writes a report; default location is inside the batch root.
    raw_out_path = Path(args.output_json) if args.output_json else (batch_root / "batch_weight_predictions.json")
    _write_json_report(out, _suffixed_output_path(raw_out_path, args.top_k, args.top_by_length))
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()