1113 lines
45 KiB
Python
Executable File
1113 lines
45 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Run DGCNN weight regression on a folder of PLY point clouds.

For each PLY:
  - load points
  - scale XYZ by --xyz-scale (default 0.001)
  - center to origin (subtract the centroid)
  - sample a fixed number of points N (--num-points, default 768)
  - predict the weight (kg)

The final prediction aggregates the top-K clouds, which by default are the
longest ones: it is the mean of their predictions, or the max when fewer than
5 PLYs remain. If the mean length of the top-K longest clouds exceeds
--length-switch-mm (default 319), the top-K are taken by predicted weight
instead. With --no-top-by-length, selection is always by predicted weight.

Uses the same preprocessing and aggregation logic as test_pointnet_weight_estimator.py.
DGCNN accepts (B, N, 3) input and transposes it internally.

Example:
    python weight_estimator/test_dgcnn_weight_estimator.py \\
        --checkpoint weight_estimator/runs/dgcnn_20260312_143020/best.pt \\
        --ply-folder /path/to/xxxx/cloud
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import re
|
|
import sys
|
|
import zlib
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
import open3d as o3d
|
|
import torch
|
|
|
|
# Directory holding this script, and the repository root one level above it.
WEIGHT_EST_DIR = Path(__file__).resolve().parent
REPO_ROOT = WEIGHT_EST_DIR.parent

# Prepend both directories so repo-level and sibling-module imports resolve when
# the script is run directly. The repo root is inserted first, so this script's
# own directory ends up ahead of it on sys.path.
for _p in map(str, (REPO_ROOT, WEIGHT_EST_DIR)):
    if _p not in sys.path:
        sys.path.insert(0, _p)
|
|
|
|
|
|
def _format_length_input_display(value: Optional[float], xyz_scale: float) -> str:
|
|
if value is None or not np.isfinite(float(value)):
|
|
return "nan"
|
|
v = float(value)
|
|
if abs(float(xyz_scale) - 0.001) < 1e-12:
|
|
return f"{v:.1f} mm"
|
|
return f"{v:.4f} units"
|
|
|
|
|
|
def _summary_top_label(summary: Dict[str, Any], top_k: int, fallback_by_length: bool) -> str:
|
|
eff = summary.get("effective_top_by_length")
|
|
if eff is None:
|
|
eff = fallback_by_length
|
|
return f"top{top_k} by length" if eff else f"top{top_k} by pred"
|
|
|
|
|
|
def load_points_from_ply(ply_path: Path) -> np.ndarray:
    """Load XYZ coordinates from a PLY file as a float32 array.

    The file is first read as a point cloud. If that yields no points, it is
    read as a triangle mesh and the mesh vertices are returned instead.

    Raises:
        ValueError: if neither reader finds any points/vertices.
    """
    path_str = str(ply_path)
    cloud = o3d.io.read_point_cloud(path_str)
    if len(cloud.points) == 0:
        # Fall back to mesh vertices for PLYs that the point-cloud reader returns empty.
        mesh = o3d.io.read_triangle_mesh(path_str)
        if len(mesh.vertices) == 0:
            raise ValueError(f"No points/vertices found in: {ply_path}")
        return np.asarray(mesh.vertices, dtype=np.float32)
    return np.asarray(cloud.points, dtype=np.float32)
|
|
|
|
|
|
def sample_points_deterministic(points: np.ndarray, num_points: int, seed: int) -> np.ndarray:
    """Return exactly ``num_points`` rows of ``points`` chosen by a seeded RNG.

    Rows are drawn without replacement when the cloud has at least
    ``num_points`` rows, and with replacement (upsampling) otherwise. The same
    ``seed`` always yields the same selection.

    Raises:
        ValueError: if ``points`` has no rows.
    """
    total = points.shape[0]
    if total <= 0:
        raise ValueError("Empty point cloud")
    generator = np.random.default_rng(seed)
    chosen = generator.choice(total, size=num_points, replace=total < num_points)
    return points[chosen]
|
|
|
|
|
|
def estimate_length_major_axis(points: np.ndarray) -> float:
    """Return the extent of ``points`` along their first principal axis.

    The first three columns are centered and decomposed with SVD. The length
    is ``max - min`` of the projections onto the leading right-singular
    vector, in the same units as the input coordinates.

    Returns NaN for ``None``, non-2-D input, fewer than 3 rows or fewer than
    3 columns, and whenever the SVD/projection step raises.
    """
    usable = (
        points is not None
        and points.ndim == 2
        and points.shape[0] >= 3
        and points.shape[1] >= 3
    )
    if not usable:
        return float("nan")
    xyz = points[:, :3].astype(np.float32, copy=False)
    centered = xyz - xyz.mean(axis=0, keepdims=True)
    try:
        principal_axis = np.linalg.svd(centered, full_matrices=False)[2][0]
        projection = centered @ principal_axis
        return float(projection.max() - projection.min())
    except Exception:  # best-effort: any failure here is reported as NaN
        return float("nan")
|
|
|
|
|
|
def filter_outliers_iqr(
    per_file: List[Dict],
    field: str = "length_input",
    iqr_factor: float = 1.5,
) -> Tuple[List[Dict], List[Dict], Dict]:
    """Split ``per_file`` into (kept, outliers, stats) with Tukey's IQR fences.

    Values are read as ``float(item[field])`` (missing -> NaN). Fences are
    ``[Q1 - iqr_factor*IQR, Q3 + iqr_factor*IQR]``, computed over the finite
    values only. Items with a non-finite value are always kept.

    Edge cases:
      * empty input -> ``([], [], {"error": "empty input"})``
      * fewer than 4 finite values -> ``(per_file, [], note)``, where
        ``per_file`` is the same list object that was passed in.
    """
    if not per_file:
        return [], [], {"error": "empty input"}

    def read(item: Dict) -> float:
        return float(item.get(field, float("nan")))

    finite_values = [v for v in map(read, per_file) if np.isfinite(v)]
    if len(finite_values) < 4:
        return per_file, [], {"note": "too few samples for IQR", "num_samples": len(finite_values)}

    sample = np.array(finite_values, dtype=np.float64)
    q1 = float(np.percentile(sample, 25))
    q3 = float(np.percentile(sample, 75))
    spread = q3 - q1
    low_fence = q1 - iqr_factor * spread
    high_fence = q3 + iqr_factor * spread

    kept: List[Dict] = []
    rejected: List[Dict] = []
    for item in per_file:
        v = read(item)
        inside = (not np.isfinite(v)) or (low_fence <= v <= high_fence)
        if inside:
            kept.append(item)
        else:
            rejected.append(item)

    return kept, rejected, {
        "field": field,
        "iqr_factor": float(iqr_factor),
        "q1": q1,
        "q3": q3,
        "iqr": spread,
        "lower_bound": low_fence,
        "upper_bound": high_fence,
        "num_input": len(per_file),
        "num_filtered": len(kept),
        "num_outliers": len(rejected),
    }
|
|
|
|
|
|
def filter_outliers_zscore(
    per_file: List[Dict],
    field: str = "length_input",
    zscore_threshold: float = 2.5,
) -> Tuple[List[Dict], List[Dict], Dict]:
    """Split ``per_file`` into (kept, outliers, stats) by absolute z-score.

    Values are read as ``float(item[field])`` (missing -> NaN). Mean and
    population std are computed over the finite values. An item is an outlier
    when its value is finite and ``|v - mean| / std <= zscore_threshold`` does
    not hold. Items with a non-finite value are always kept.

    Edge cases:
      * empty input -> ``([], [], {"error": "empty input"})``
      * fewer than 3 finite values, or std < 1e-9 -> ``(per_file, [], note)``
    """
    if not per_file:
        return [], [], {"error": "empty input"}

    def read(item: Dict) -> float:
        return float(item.get(field, float("nan")))

    finite_values = [v for v in map(read, per_file) if np.isfinite(v)]
    if len(finite_values) < 3:
        return per_file, [], {"note": "too few samples for z-score", "num_samples": len(finite_values)}

    sample = np.array(finite_values, dtype=np.float64)
    mu = float(np.mean(sample))
    sigma = float(np.std(sample))
    if sigma < 1e-9:
        return per_file, [], {"note": "zero std", "mean": mu}

    kept: List[Dict] = []
    rejected: List[Dict] = []
    for item in per_file:
        v = read(item)
        if np.isfinite(v) and not (abs(v - mu) / sigma <= zscore_threshold):
            rejected.append(item)
        else:
            kept.append(item)

    return kept, rejected, {
        "field": field,
        "zscore_threshold": float(zscore_threshold),
        "mean": mu,
        "std": sigma,
        "num_input": len(per_file),
        "num_filtered": len(kept),
        "num_outliers": len(rejected),
    }
|
|
|
|
|
|
def _nan_last_sort_value(item: Dict, field: str) -> float:
    """Return ``item[field]`` as a float sort key, with NaN mapped to -inf.

    ``sorted`` gives an inconsistent order when keys contain NaN (every
    comparison with NaN is False). NaN entries, such as a length that
    ``estimate_length_major_axis`` could not compute, are therefore pushed
    to the end of a descending sort. Missing keys also sort last.
    """
    value = float(item.get(field, float("-inf")))
    return float("-inf") if np.isnan(value) else value


@torch.no_grad()
def _predict_folder_impl(
    model: torch.nn.Module,
    ply_files: List[Path],
    device: torch.device,
    num_points: int = 768,
    xyz_scale: float = 0.001,
    topk_length: Optional[int] = None,
    topk_predictions: int = 10,
    top_k: int = 5,
    top_by_length: bool = True,
    length_switch_to_weight_mm: float = 319.0,
    remove_outliers: bool = False,
    outlier_method: str = "iqr",
    outlier_field: str = "length_input",
    iqr_factor: float = 1.5,
    zscore_threshold: float = 2.5,
) -> Tuple[List[Dict], Dict]:
    """Predict weights for a list of PLY files with DGCNN and aggregate them.

    Per file:
      * load the points and measure the major-axis length on the raw coordinates;
      * scale by ``xyz_scale``, center, and deterministically resample to
        ``num_points`` points (the seed is derived from the file path);
      * run ``model`` on a ``(1, N, 3)`` tensor (DGCNN transposes internally).

    Aggregation:
      * optionally drop outliers (``remove_outliers``, ``outlier_method`` is
        ``"iqr"`` or ``"zscore"``, applied to ``outlier_field``);
      * select the top ``top_k`` remaining files by length (``top_by_length``)
        or by predicted weight. In length mode, if the mean length of that
        selection exceeds ``length_switch_to_weight_mm``, the selection
        switches to top-K by predicted weight. NOTE(review): the lengths are in
        raw input units, so this compares as mm only if the PLYs are in mm;
      * the result is the max of the selected predictions when fewer than
        5 candidates remain, otherwise their mean.

    ``topk_length`` and ``topk_predictions`` are accepted for interface
    compatibility but are not read by this function.

    Returns:
        ``(per_file, summary)``. ``per_file`` has one dict per PLY (predictions,
        lengths, ``is_outlier``, ``used_for_prediction``, ``rank_by_selection``).
        ``summary`` holds the aggregate values. An empty ``ply_files`` returns
        ``([], summary)`` with ``summary["skipped"] = True``.
    """
    model.eval()
    per_file: List[Dict] = []

    if not ply_files:
        summary_skipped = {
            "num_files": 0,
            "num_files_predicted": 0,
            "skipped": True,
            "skip_reason": "no .ply files",
            "avg_predicted_weight_kg": float("nan"),
            "avg_predicted_weight_g": float("nan"),
            "avg_length_input_topk": float("nan"),
        }
        return [], summary_skipped

    preds_kg: List[float] = []

    for ply in ply_files:
        pts = load_points_from_ply(ply)
        # Length is measured on the raw (unscaled) coordinates.
        length_input = estimate_length_major_axis(pts)

        pts = pts * float(xyz_scale)
        pts = pts - pts.mean(axis=0, keepdims=True)

        # Path-derived seed: the same file always gets the same subsample.
        seed = int(zlib.adler32(str(ply).encode("utf-8")) & 0xFFFFFFFF)
        pts = sample_points_deterministic(pts, num_points=num_points, seed=seed)

        # DGCNN expects (B, N, 3) and transposes internally.
        x = torch.from_numpy(pts).to(device=device, dtype=torch.float32)  # (N, 3)
        x = x.unsqueeze(0)  # (1, N, 3)

        pred_kg = float(model(x).item())
        pred_g = pred_kg * 1000.0

        preds_kg.append(pred_kg)
        per_file.append({
            "ply": str(ply),
            "predicted_weight_kg": pred_kg,
            "predicted_weight_g": pred_g,
            "length_input": float(length_input),
            "length_after_scale": (
                float(length_input * float(xyz_scale)) if np.isfinite(length_input) else float("nan")
            ),
            "is_outlier": False,
        })

    # Optional outlier removal: flags outliers in per_file and restricts candidates.
    outlier_stats: Optional[Dict] = None
    candidates_for_avg = per_file
    num_outliers_removed = 0

    if remove_outliers and per_file:
        if outlier_method == "iqr":
            filtered, outliers, outlier_stats = filter_outliers_iqr(
                per_file, field=outlier_field, iqr_factor=iqr_factor
            )
        elif outlier_method == "zscore":
            filtered, outliers, outlier_stats = filter_outliers_zscore(
                per_file, field=outlier_field, zscore_threshold=zscore_threshold
            )
        else:
            filtered, outliers = per_file, []
            outlier_stats = {"error": f"unknown method: {outlier_method}"}
        outlier_plys = {it["ply"] for it in outliers}
        for it in per_file:
            if it["ply"] in outlier_plys:
                it["is_outlier"] = True
        candidates_for_avg = filtered
        num_outliers_removed = len(outliers)

    avg_kg_all = float(np.mean(preds_kg)) if preds_kg else float("nan")
    avg_g_all = avg_kg_all * 1000.0 if preds_kg else float("nan")

    def _length_key(d: Dict) -> float:
        return _nan_last_sort_value(d, "length_input")

    def _weight_key(d: Dict) -> float:
        return _nan_last_sort_value(d, "predicted_weight_g")

    # Candidate top-K by length, and the mean length of that selection.
    sorted_by_length = sorted(candidates_for_avg, key=_length_key, reverse=True)
    selected_topk_by_length = sorted_by_length[:top_k]
    lengths_of_topk_by_length = [
        float(it["length_input"])
        for it in selected_topk_by_length
        if np.isfinite(float(it.get("length_input", float("nan"))))
    ]
    avg_length_input_topk_by_length = (
        float(np.mean(lengths_of_topk_by_length)) if lengths_of_topk_by_length else float("nan")
    )

    # Candidate top-K by predicted weight.
    sorted_by_weight = sorted(candidates_for_avg, key=_weight_key, reverse=True)
    selected_topk_by_weight = sorted_by_weight[:top_k]

    switched_to_weight_due_to_long_length = False
    if top_by_length:
        switch = (
            np.isfinite(avg_length_input_topk_by_length)
            and avg_length_input_topk_by_length > float(length_switch_to_weight_mm)
        )
        if switch:
            effective_top_by_length = False
            selected_topk = selected_topk_by_weight
            switched_to_weight_due_to_long_length = True
        else:
            effective_top_by_length = True
            selected_topk = selected_topk_by_length
    else:
        effective_top_by_length = False
        selected_topk = selected_topk_by_weight

    # Small candidate sets use the max of the selection, larger ones the mean.
    preds_g_topk = [float(it["predicted_weight_g"]) for it in selected_topk]
    use_max_instead_of_mean = len(candidates_for_avg) < 5
    if preds_g_topk:
        avg_g_topk = (
            float(np.max(preds_g_topk))
            if use_max_instead_of_mean
            else float(np.mean(preds_g_topk))
        )
    else:
        avg_g_topk = float(avg_g_all)
    num_used_for_avg = len(selected_topk)
    plys_used_for_prediction = [Path(it["ply"]).name for it in selected_topk]

    lengths_topk = [
        float(it["length_input"])
        for it in selected_topk
        if np.isfinite(float(it.get("length_input", float("nan"))))
    ]
    avg_length_input_topk = float(np.mean(lengths_topk)) if lengths_topk else float("nan")

    lengths = [
        float(it["length_input"])
        for it in candidates_for_avg
        if np.isfinite(float(it.get("length_input", float("nan"))))
    ]
    avg_len_input = float(np.mean(lengths)) if lengths else float("nan")
    std_len_input = float(np.std(lengths)) if len(lengths) > 1 else 0.0
    cv_length = float(std_len_input / avg_len_input * 100.0) if avg_len_input > 0 else float("nan")

    avg_g = float(avg_g_topk) if np.isfinite(float(avg_g_topk)) else float(avg_g_all)
    avg_kg = avg_g / 1000.0

    # Match selected files by full path. File names are not unique when clouds
    # from several folders are combined (ply list file, batch combine-by-fish).
    kept_ply_paths = {it["ply"] for it in selected_topk}
    # Rank 1..N over all candidates in the effective selection order.
    full_sorted = sorted_by_length if effective_top_by_length else sorted_by_weight
    rank_by_ply = {it["ply"]: i + 1 for i, it in enumerate(full_sorted)}
    for it in per_file:
        it["used_for_prediction"] = it["ply"] in kept_ply_paths
        it["rank_by_selection"] = rank_by_ply.get(it["ply"], 0)
        it["top_by_length"] = effective_top_by_length

    summary = {
        "num_files": len(ply_files),
        "num_files_predicted": len(per_file),
        "num_outliers_removed": num_outliers_removed,
        "num_files_after_outlier_removal": len(candidates_for_avg),
        "num_files_used_for_avg": int(num_used_for_avg),
        "prediction_aggregate": "max" if use_max_instead_of_mean else "mean",
        "top_k": top_k,
        "top_by_length": top_by_length,
        "effective_top_by_length": effective_top_by_length,
        "switched_to_weight_due_to_long_length": switched_to_weight_due_to_long_length,
        "length_switch_threshold_mm": float(length_switch_to_weight_mm),
        "avg_length_input_topk_by_length": avg_length_input_topk_by_length
        if top_by_length
        else float("nan"),
        "plys_used_for_prediction": plys_used_for_prediction,
        "avg_predicted_weight_kg": avg_kg,
        "avg_predicted_weight_g": avg_g,
        "avg_predicted_weight_kg_all": avg_kg_all,
        "avg_predicted_weight_g_all": avg_g_all,
        "avg_length_input_topk": avg_length_input_topk,
        "avg_length_input": avg_len_input,
        "std_length_input": std_len_input,
        "cv_length_pct": cv_length,
        "outlier_removal": {
            "enabled": remove_outliers,
            "method": outlier_method if remove_outliers else None,
            "field": outlier_field if remove_outliers else None,
            "stats": outlier_stats,
        } if remove_outliers else None,
    }

    return per_file, summary
|
|
|
|
|
|
def predict_folder(
    model: torch.nn.Module,
    ply_files: List[Path],
    device: torch.device,
    num_points: int = 768,
    xyz_scale: float = 0.001,
) -> Tuple[List[Dict], Dict]:
    """Predict weights for ``ply_files`` with the default aggregation settings.

    A thin wrapper around ``_predict_folder_impl``. Only the sample size and
    the XYZ scale are configurable; every other option keeps its default.
    """
    return _predict_folder_impl(
        model,
        ply_files,
        device,
        num_points=num_points,
        xyz_scale=xyz_scale,
        topk_length=None,
    )
|
|
|
|
|
|
def collect_cloud_folders(batch_root: Path) -> List[Path]:
    """Return every directory named ``cloud`` under ``batch_root`` that holds a ``*.ply``.

    ``batch_root`` is expanded and resolved first. The search is recursive and
    the result is sorted. Only ``*.ply`` files directly inside each ``cloud``
    directory are considered.

    Raises:
        ValueError: if ``batch_root`` is not an existing directory.
    """
    root = batch_root.expanduser().resolve()
    if not root.is_dir():
        raise ValueError(f"batch_root not found/dir: {root}")
    found = [
        candidate
        for candidate in root.rglob("cloud")
        if candidate.is_dir() and next(candidate.glob("*.ply"), None) is not None
    ]
    return sorted(found)
|
|
|
|
|
|
def group_cloud_folders_by_fish(batch_root: Path) -> Dict[str, List[Path]]:
    """Group the ``cloud`` folders under ``batch_root`` by fish id.

    Folders come from ``collect_cloud_folders``. Each folder's key is the id
    that ``extract_fish_key_from_text`` parses from the folder's parent path
    relative to ``batch_root``, falling back to the full cloud path. When no
    id can be parsed, the key is the relative parent path with ``/`` replaced
    by ``_``, or ``"unknown"`` for a ``cloud`` folder directly under
    ``batch_root``.

    Raises:
        ValueError: propagated from ``collect_cloud_folders`` if ``batch_root``
            is not an existing directory.
    """
    # collect_cloud_folders returns paths under the *resolved* root. Resolve
    # here too, so relative_to() works for "~", relative or symlinked roots.
    batch_root = batch_root.expanduser().resolve()
    by_fish: Dict[str, List[Path]] = {}
    for cloud_dir in collect_cloud_folders(batch_root):
        parent = cloud_dir.parent
        rel = (
            str(parent.relative_to(batch_root))
            if parent.is_relative_to(batch_root)
            else str(parent)
        )
        fish_key = extract_fish_key_from_text(rel) or extract_fish_key_from_text(str(cloud_dir))
        if not fish_key:
            # relative_to() of the root itself gives "."; do not use "." as a key.
            fallback = "" if rel == "." else rel
            fish_key = fallback.replace("/", "_") or "unknown"
        by_fish.setdefault(fish_key, []).append(cloud_dir)
    return by_fish
|
|
|
|
|
|
def predict_cloud_folder(
    model: torch.nn.Module,
    cloud_dir: Path,
    device: torch.device,
    num_points: int,
    xyz_scale: float,
    recursive: bool = False,
    topk_length: Optional[int] = None,
    topk_predictions: int = 10,
    top_k: int = 5,
    top_by_length: bool = True,
    length_switch_to_weight_mm: float = 319.0,
    remove_outliers: bool = False,
    outlier_method: str = "iqr",
    outlier_field: str = "length_input",
    iqr_factor: float = 1.5,
    zscore_threshold: float = 2.5,
    ply_files: Optional[List[Path]] = None,
) -> Tuple[List[Dict], Dict]:
    """Predict weights for the PLYs of one cloud folder, or for an explicit PLY list.

    File selection:
      * ``ply_files`` given: those paths, expanded and resolved, de-duplicated
        and sorted. ``cloud_dir`` and ``recursive`` are then ignored.
      * otherwise: ``*.ply`` in ``cloud_dir`` (searched recursively when
        ``recursive``), sorted.

    All other arguments are forwarded unchanged to ``_predict_folder_impl``.

    Raises:
        FileNotFoundError: if an explicitly listed PLY is not an existing file
            (the first such path in sorted order is reported).
    """
    if ply_files is None:
        finder = cloud_dir.rglob if recursive else cloud_dir.glob
        ply_files = sorted(finder("*.ply"))
    else:
        ply_files = sorted({Path(p).expanduser().resolve() for p in ply_files})
        missing = next((p for p in ply_files if not p.is_file()), None)
        if missing is not None:
            raise FileNotFoundError(f"PLY not found: {missing}")

    return _predict_folder_impl(
        model,
        ply_files,
        device,
        num_points=num_points,
        xyz_scale=xyz_scale,
        topk_length=topk_length,
        topk_predictions=topk_predictions,
        top_k=top_k,
        top_by_length=top_by_length,
        length_switch_to_weight_mm=length_switch_to_weight_mm,
        remove_outliers=remove_outliers,
        outlier_method=outlier_method,
        outlier_field=outlier_field,
        iqr_factor=iqr_factor,
        zscore_threshold=zscore_threshold,
    )
|
|
|
|
|
|
def load_label_weights_json(labels_path: Path) -> Dict[str, float]:
    """Load a ``{fish_id: weight}`` mapping from a JSON object file.

    Each value may be a number (or numeric string), or a dict with a
    ``"weight"`` entry. Keys are stripped of surrounding whitespace, and
    entries that cannot be converted to float are skipped. Returns ``{}``
    when the file does not exist or its top level is not a JSON object.
    """
    resolved = labels_path.expanduser().resolve()
    if not resolved.exists():
        return {}
    payload = json.loads(resolved.read_text(encoding="utf-8"))
    if not isinstance(payload, dict):
        return {}

    weights: Dict[str, float] = {}
    for raw_key, raw_value in payload.items():
        key = str(raw_key).strip()
        try:
            weights[key] = float(raw_value)
            continue
        except (TypeError, ValueError):
            pass
        # Structured form: {"weight": N, ...}
        if isinstance(raw_value, dict) and "weight" in raw_value:
            try:
                weights[key] = float(raw_value["weight"])
            except (TypeError, ValueError):
                pass
    return weights
|
|
|
|
|
|
def load_label_lengths_json(labels_path: Optional[Path] = None) -> Dict[str, float]:
    """Load ``{fish_id: length_mm}`` from a labels JSON whose values are dicts.

    Only entries whose value is a dict containing ``"length_mm"`` are used.
    Keys are stripped, and non-numeric lengths are skipped. Returns ``{}`` for
    ``None``, a missing file, or a non-object top level.
    """
    if labels_path is None:
        return {}
    resolved = labels_path.expanduser().resolve()
    if not resolved.exists():
        return {}
    payload = json.loads(resolved.read_text(encoding="utf-8"))
    if not isinstance(payload, dict):
        return {}

    lengths: Dict[str, float] = {}
    for raw_key, entry in payload.items():
        if not (isinstance(entry, dict) and "length_mm" in entry):
            continue
        try:
            lengths[str(raw_key).strip()] = float(entry["length_mm"])
        except (TypeError, ValueError):
            continue
    return lengths
|
|
|
|
|
|
def load_label_lengths_from_file(path: Path) -> Dict[str, float]:
    """Load ``{fish_id: length_mm}`` from a JSON object file.

    Keys look like ``'fish12'``. Each value may be a number (or numeric
    string), or a dict with a ``"length_mm"`` entry. Keys are stripped, and
    unconvertible entries are skipped. Returns ``{}`` for a missing file or a
    non-object top level.
    """
    resolved = path.expanduser().resolve()
    if not resolved.exists():
        return {}
    payload = json.loads(resolved.read_text(encoding="utf-8"))
    if not isinstance(payload, dict):
        return {}

    lengths: Dict[str, float] = {}
    for raw_key, raw_value in payload.items():
        key = str(raw_key).strip()
        try:
            lengths[key] = float(raw_value)
            continue
        except (TypeError, ValueError):
            pass
        # Structured form: {"length_mm": M, ...}
        if isinstance(raw_value, dict) and "length_mm" in raw_value:
            try:
                lengths[key] = float(raw_value["length_mm"])
            except (TypeError, ValueError):
                pass
    return lengths
|
|
|
|
|
|
def extract_fish_key_from_text(text: str) -> Optional[str]:
    """Extract a normalized ``"fishN"`` id from free text such as a path.

    First tries a delimited match (``fish`` not preceded by an alphanumeric,
    optional ``_``/``-``/space separators, up to 4 digits not followed by
    another digit). If that fails, it falls back to a looser ``fish<digits>``
    search. Matching is case-insensitive and leading zeros are dropped.
    Returns ``None`` when nothing matches or the number is 0.
    """
    patterns = (
        r"(?:^|[^a-z0-9])fish\s*[_\- ]*\s*0*([0-9]{1,4})(?:[^0-9]|$)",
        r"fish\s*0*([0-9]{1,4})",
    )
    for pattern in patterns:
        match = re.search(pattern, text, flags=re.IGNORECASE)
        if match:
            break
    else:
        return None
    number = int(match.group(1))
    return f"fish{number}" if number > 0 else None
|
|
|
|
|
|
def compare_pred_vs_actual(pred_g: float, actual_g: float) -> Dict[str, float]:
    """Compare a predicted weight against the label, both in grams.

    Returns the two inputs plus ``diff_g = pred - actual``,
    ``diff_pct = diff_g / actual * 100`` and its absolute value. The diff
    fields are NaN when either input is non-finite or ``actual_g`` is 0.
    """
    result: Dict[str, float] = {
        "predicted_weight_g": float(pred_g),
        "actual_weight_g": float(actual_g),
    }
    comparable = np.isfinite(pred_g) and np.isfinite(actual_g) and actual_g != 0.0
    if comparable:
        diff_g = float(pred_g - actual_g)
        diff_pct = float(diff_g / actual_g * 100.0)
        result["diff_g"] = diff_g
        result["diff_pct"] = diff_pct
        result["abs_diff_pct"] = abs(diff_pct)
    else:
        missing = float("nan")
        result["diff_g"] = missing
        result["diff_pct"] = missing
        result["abs_diff_pct"] = missing
    return result
|
|
|
|
|
|
def main() -> None:
|
|
parser = argparse.ArgumentParser("DGCNN weight estimator (folder inference)")
|
|
parser.add_argument("--checkpoint", type=str, required=True, help="Path to best.pt/last.pt checkpoint")
|
|
src = parser.add_mutually_exclusive_group(required=True)
|
|
src.add_argument("--ply-folder", type=str, default="", help="Folder containing .ply files")
|
|
src.add_argument(
|
|
"--ply-list-file",
|
|
type=str,
|
|
default="",
|
|
help="Text file: one absolute or relative .ply path per line (combine clouds from multiple runs)",
|
|
)
|
|
src.add_argument(
|
|
"--batch-root",
|
|
type=str,
|
|
default="",
|
|
help="Batch mode: recursively find */cloud folders under this root and predict each",
|
|
)
|
|
parser.add_argument(
|
|
"--combine-by-fish",
|
|
action="store_true",
|
|
default=True,
|
|
help="Batch mode: combine all folders per fish (default: True)",
|
|
)
|
|
parser.add_argument(
|
|
"--no-combine-by-fish",
|
|
action="store_false",
|
|
dest="combine_by_fish",
|
|
help="Batch mode: predict each folder separately",
|
|
)
|
|
parser.add_argument("--recursive", action="store_true", help="Recursively search for .ply files (single folder mode)")
|
|
parser.add_argument("--num-points", type=int, default=768, help="Number of points per PLY (default: 768)")
|
|
parser.add_argument("--xyz-scale", type=float, default=0.001, help="XYZ scaling factor (default: 0.001)")
|
|
parser.add_argument("--device", type=str, default="auto", help="auto|cpu|cuda")
|
|
parser.add_argument("--output-json", type=str, default=None, help="Optional path to save results JSON")
|
|
parser.add_argument(
|
|
"--labels-json",
|
|
type=str,
|
|
default="/home/ubuntu/label20_weights.json",
|
|
help="Optional label mapping JSON. Keys like 'fish12', values in grams.",
|
|
)
|
|
parser.add_argument(
|
|
"--labels-length-json",
|
|
type=str,
|
|
default=None,
|
|
help="Optional JSON with fish_id -> length_mm. Or use labels-json with dict values: {\"weight\": N, \"length_mm\": M}.",
|
|
)
|
|
parser.add_argument(
|
|
"--topk-length", "--topk_length",
|
|
type=int,
|
|
default=None,
|
|
dest="topk_length",
|
|
help="If set, compute folder average as length-weighted average over top-K longest PLYs",
|
|
)
|
|
parser.add_argument(
|
|
"--topk-predictions",
|
|
type=int,
|
|
default=10,
|
|
help="When > this many PLYs, remove length outliers and average the rest (default: 10)",
|
|
)
|
|
parser.add_argument(
|
|
"--top-k", "--top_k",
|
|
type=int,
|
|
default=10,
|
|
dest="top_k",
|
|
help="Number of top predictions to average (default: 10)",
|
|
)
|
|
parser.add_argument(
|
|
"--top-by-length",
|
|
action=argparse.BooleanOptionalAction,
|
|
default=True,
|
|
dest="top_by_length",
|
|
help="Select top-K by cloud length (default: on). Use --no-top-by-length for predicted-weight order.",
|
|
)
|
|
parser.add_argument(
|
|
"--length-switch-mm",
|
|
type=float,
|
|
default=319.0,
|
|
dest="length_switch_mm",
|
|
help="With --top-by-length: if mean top-K length (mm) exceeds this, use top-K by predicted weight instead (default: 319).",
|
|
)
|
|
parser.add_argument("--print-per-ply", action="store_true", help="Print per-PLY predicted weight (batch mode)")
|
|
parser.add_argument("--remove-outliers", action="store_true", help="Remove outliers before computing average")
|
|
parser.add_argument("--outlier-method", type=str, default="iqr", choices=["iqr", "zscore"])
|
|
parser.add_argument(
|
|
"--outlier-field",
|
|
type=str,
|
|
default="length_input",
|
|
choices=["length_input", "predicted_weight_g"],
|
|
)
|
|
parser.add_argument("--iqr-factor", type=float, default=1.5)
|
|
parser.add_argument("--zscore-threshold", type=float, default=2.5)
|
|
parser.add_argument(
|
|
"--max-cv",
|
|
type=float,
|
|
default=None,
|
|
help="Maximum CV%% for length. Skip folders with CV > this value (e.g. 15.0).",
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
ckpt_path = Path(args.checkpoint).expanduser().resolve()
|
|
if not ckpt_path.exists():
|
|
raise SystemExit(f"checkpoint not found: {ckpt_path}")
|
|
|
|
if args.device == "auto":
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
else:
|
|
device = torch.device(args.device)
|
|
|
|
from weight_estimator.train_dgcnn_weight_estimator import load_model_from_checkpoint
|
|
|
|
model = load_model_from_checkpoint(ckpt_path, device=device)
|
|
|
|
labels_path = Path(args.labels_json).expanduser().resolve() if args.labels_json else None
|
|
labels = load_label_weights_json(labels_path) if labels_path is not None else {}
|
|
label_lengths: Dict[str, float] = load_label_lengths_json(labels_path) if labels_path else {}
|
|
if args.labels_length_json:
|
|
length_path = Path(args.labels_length_json).expanduser().resolve()
|
|
extra = load_label_lengths_from_file(length_path)
|
|
label_lengths = {**label_lengths, **extra}
|
|
|
|
ply_list_path_for_meta: Optional[Path] = None
|
|
if args.ply_list_file:
|
|
list_path = Path(args.ply_list_file).expanduser().resolve()
|
|
if not list_path.is_file():
|
|
raise SystemExit(f"ply-list-file not found: {list_path}")
|
|
lines = [
|
|
ln.strip()
|
|
for ln in list_path.read_text(encoding="utf-8").splitlines()
|
|
if ln.strip() and not ln.strip().startswith("#")
|
|
]
|
|
ply_paths = [Path(l).expanduser().resolve() for l in lines]
|
|
if not ply_paths:
|
|
raise SystemExit("ply-list-file is empty (no PLY paths)")
|
|
for p in ply_paths:
|
|
if not p.is_file():
|
|
raise SystemExit(f"PLY in list not found: {p}")
|
|
ply_list_path_for_meta = list_path
|
|
ply_folder = ply_paths[0].parent
|
|
per_file, summary = predict_cloud_folder(
|
|
model=model,
|
|
cloud_dir=ply_folder,
|
|
device=device,
|
|
num_points=args.num_points,
|
|
xyz_scale=args.xyz_scale,
|
|
recursive=False,
|
|
topk_length=args.topk_length,
|
|
topk_predictions=args.topk_predictions,
|
|
top_k=args.top_k,
|
|
top_by_length=args.top_by_length,
|
|
length_switch_to_weight_mm=args.length_switch_mm,
|
|
remove_outliers=args.remove_outliers,
|
|
outlier_method=args.outlier_method,
|
|
outlier_field=args.outlier_field,
|
|
iqr_factor=args.iqr_factor,
|
|
zscore_threshold=args.zscore_threshold,
|
|
ply_files=ply_paths,
|
|
)
|
|
elif args.ply_folder:
|
|
ply_folder = Path(args.ply_folder).expanduser().resolve()
|
|
if not ply_folder.exists() or not ply_folder.is_dir():
|
|
raise SystemExit(f"ply-folder not found/dir: {ply_folder}")
|
|
|
|
per_file, summary = predict_cloud_folder(
|
|
model=model,
|
|
cloud_dir=ply_folder,
|
|
device=device,
|
|
num_points=args.num_points,
|
|
xyz_scale=args.xyz_scale,
|
|
recursive=args.recursive,
|
|
topk_length=args.topk_length,
|
|
topk_predictions=args.topk_predictions,
|
|
top_k=args.top_k,
|
|
top_by_length=args.top_by_length,
|
|
length_switch_to_weight_mm=args.length_switch_mm,
|
|
remove_outliers=args.remove_outliers,
|
|
outlier_method=args.outlier_method,
|
|
outlier_field=args.outlier_field,
|
|
iqr_factor=args.iqr_factor,
|
|
zscore_threshold=args.zscore_threshold,
|
|
)
|
|
|
|
if args.ply_list_file or args.ply_folder:
|
|
if summary.get("skipped"):
|
|
print(f"Skipped: {summary.get('skip_reason', 'fewer than 5 PLYs')} (files={summary.get('num_files', 0)})")
|
|
out = {
|
|
"meta": {
|
|
"checkpoint": str(ckpt_path),
|
|
"model": "dgcnn",
|
|
"ply_folder": str(ply_folder),
|
|
"ply_list_file": str(ply_list_path_for_meta) if ply_list_path_for_meta else None,
|
|
"recursive": bool(args.recursive),
|
|
"num_points": int(args.num_points),
|
|
"xyz_scale": float(args.xyz_scale),
|
|
"device": str(device),
|
|
"labels_json": str(labels_path) if labels_path else None,
|
|
"length_switch_mm": float(args.length_switch_mm),
|
|
},
|
|
"summary": summary,
|
|
"comparison": None,
|
|
"per_file": per_file,
|
|
}
|
|
if args.output_json:
|
|
out_path = Path(args.output_json).expanduser().resolve()
|
|
suffix = f"_top{args.top_k}_by_length" if args.top_by_length else f"_top{args.top_k}"
|
|
out_path = out_path.parent / f"{out_path.stem}{suffix}{out_path.suffix}"
|
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
|
out_path.write_text(json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8")
|
|
print(f"Saved: {out_path}")
|
|
return
|
|
|
|
top_label = _summary_top_label(summary, args.top_k, args.top_by_length)
|
|
print(f"[Kept = {top_label}]")
|
|
for it in per_file:
|
|
ply = Path(it["ply"]).name
|
|
g = float(it["predicted_weight_g"])
|
|
kg = float(it["predicted_weight_kg"])
|
|
length_input = float(it.get("length_input", float("nan")))
|
|
used = it.get("used_for_prediction", True)
|
|
rank = it.get("rank_by_selection", 0)
|
|
n_total = summary.get("num_files_predicted", len(per_file))
|
|
if rank > 0:
|
|
tag = f" (kept, rank {rank}/{n_total})" if used else f" (filtered, rank {rank}/{n_total})"
|
|
else:
|
|
tag = " (kept)" if used else " (filtered)"
|
|
if abs(float(args.xyz_scale) - 0.001) < 1e-12:
|
|
length_str = f"{length_input:.1f} mm" if np.isfinite(length_input) else "nan"
|
|
else:
|
|
length_str = f"{length_input:.4f} units" if np.isfinite(length_input) else "nan"
|
|
print(f"{ply}: len={length_str} | {g:.2f} g ({kg:.4f} kg){tag}")
|
|
|
|
print(f"Files: {summary.get('num_files_predicted', summary['num_files'])}")
|
|
if args.remove_outliers and summary.get("num_outliers_removed", 0) > 0:
|
|
print(f"Outliers removed: {summary['num_outliers_removed']} (method={args.outlier_method}, field={args.outlier_field})")
|
|
if summary.get("cv_length_pct") is not None:
|
|
print(f"Length CV: {summary['cv_length_pct']:.1f}%")
|
|
top_label = _summary_top_label(summary, args.top_k, args.top_by_length)
|
|
print(
|
|
f"Average predicted weight ({top_label}): "
|
|
f"{summary['avg_predicted_weight_g']:.2f} g ({summary['avg_predicted_weight_kg']:.4f} kg) "
|
|
f"(used={summary.get('num_files_used_for_avg', 'n/a')}/{summary.get('num_files_predicted', 'n/a')})"
|
|
)
|
|
avg_len_topk = summary.get("avg_length_input_topk")
|
|
if avg_len_topk is not None and np.isfinite(avg_len_topk):
|
|
print(
|
|
f"Average length ({top_label}): "
|
|
f"{_format_length_input_display(avg_len_topk, args.xyz_scale)}"
|
|
)
|
|
|
|
fish_key = extract_fish_key_from_text(str(ply_folder))
|
|
comparison = None
|
|
if fish_key and fish_key in labels:
|
|
actual_g = float(labels[fish_key])
|
|
comparison = compare_pred_vs_actual(pred_g=float(summary["avg_predicted_weight_g"]), actual_g=actual_g)
|
|
print(
|
|
f"Actual weight ({fish_key}): {actual_g:.2f} g | "
|
|
f"Diff: {comparison['diff_g']:.2f} g | "
|
|
f"Diff%: {comparison['diff_pct']:.2f}% (abs {comparison['abs_diff_pct']:.2f}%)"
|
|
)
|
|
elif labels:
|
|
print(f"[warn] No label match for folder: {ply_folder} (parsed key={fish_key})")
|
|
|
|
out = {
|
|
"meta": {
|
|
"checkpoint": str(ckpt_path),
|
|
"model": "dgcnn",
|
|
"ply_folder": str(ply_folder),
|
|
"ply_list_file": str(ply_list_path_for_meta) if ply_list_path_for_meta else None,
|
|
"recursive": bool(args.recursive),
|
|
"num_points": int(args.num_points),
|
|
"xyz_scale": float(args.xyz_scale),
|
|
"device": str(device),
|
|
"labels_json": str(labels_path) if labels_path else None,
|
|
"length_switch_mm": float(args.length_switch_mm),
|
|
},
|
|
"summary": summary,
|
|
"comparison": comparison,
|
|
"per_file": per_file,
|
|
}
|
|
|
|
if args.output_json:
|
|
out_path = Path(args.output_json).expanduser().resolve()
|
|
suffix = f"_top{args.top_k}_by_length" if args.top_by_length else f"_top{args.top_k}"
|
|
out_path = out_path.parent / f"{out_path.stem}{suffix}{out_path.suffix}"
|
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
|
out_path.write_text(json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8")
|
|
print(f"Saved: {out_path}")
|
|
|
|
else:
|
|
batch_root = Path(args.batch_root).expanduser().resolve()
|
|
if args.combine_by_fish:
|
|
by_fish = group_cloud_folders_by_fish(batch_root)
|
|
if not by_fish:
|
|
raise SystemExit(f"No cloud folders found under: {batch_root}")
|
|
|
|
results: List[Dict] = []
|
|
for fish_key, cloud_dirs in sorted(by_fish.items()):
|
|
ply_files = []
|
|
for cloud_dir in cloud_dirs:
|
|
ply_files.extend(sorted(cloud_dir.glob("*.ply")))
|
|
ply_files = sorted(set(ply_files))
|
|
|
|
try:
|
|
per_file, summary = _predict_folder_impl(
|
|
model=model,
|
|
ply_files=ply_files,
|
|
device=device,
|
|
num_points=args.num_points,
|
|
xyz_scale=args.xyz_scale,
|
|
top_k=args.top_k,
|
|
top_by_length=args.top_by_length,
|
|
length_switch_to_weight_mm=args.length_switch_mm,
|
|
remove_outliers=args.remove_outliers,
|
|
outlier_method=args.outlier_method,
|
|
outlier_field=args.outlier_field,
|
|
iqr_factor=args.iqr_factor,
|
|
zscore_threshold=args.zscore_threshold,
|
|
)
|
|
if summary.get("skipped"):
|
|
print(f"{fish_key}: SKIPPED ({summary.get('skip_reason', '<5 PLYs')})")
|
|
results.append({
|
|
"id": fish_key,
|
|
"fish_key": fish_key,
|
|
"skipped": True,
|
|
"skip_reason": summary.get("skip_reason", "fewer than 5 PLYs"),
|
|
"num_files": summary.get("num_files", 0),
|
|
"cloud_folders": len(cloud_dirs),
|
|
})
|
|
continue
|
|
|
|
num_used = summary.get("num_files_used_for_avg", "n/a")
|
|
num_total = summary.get("num_files_predicted", "n/a")
|
|
used_info = f"used={num_used}/{num_total} (from {len(cloud_dirs)} folders)"
|
|
|
|
top_label = _summary_top_label(summary, args.top_k, args.top_by_length)
|
|
avg_len_topk = summary.get("avg_length_input_topk")
|
|
len_part = ""
|
|
if avg_len_topk is not None and np.isfinite(avg_len_topk):
|
|
len_part = (
|
|
f" | avg_len(top{args.top_k} selected)="
|
|
f"{_format_length_input_display(avg_len_topk, args.xyz_scale)}"
|
|
)
|
|
if fish_key in labels:
|
|
actual_g = float(labels[fish_key])
|
|
comparison = compare_pred_vs_actual(
|
|
pred_g=float(summary["avg_predicted_weight_g"]), actual_g=actual_g
|
|
)
|
|
print(
|
|
f"{fish_key}: avg({top_label})={summary['avg_predicted_weight_g']:.2f} g ({used_info}){len_part} | "
|
|
f"actual={actual_g:.2f} g | diff%={comparison['diff_pct']:.2f}%"
|
|
)
|
|
else:
|
|
print(
|
|
f"{fish_key}: avg({top_label})={summary['avg_predicted_weight_g']:.2f} g ({used_info}){len_part}"
|
|
)
|
|
|
|
comparison = None
|
|
if fish_key in labels:
|
|
comparison = compare_pred_vs_actual(
|
|
pred_g=float(summary["avg_predicted_weight_g"]),
|
|
actual_g=float(labels[fish_key]),
|
|
)
|
|
results.append({
|
|
"id": fish_key,
|
|
"fish_key": fish_key,
|
|
"cloud_folders": len(cloud_dirs),
|
|
"summary": summary,
|
|
"comparison": comparison,
|
|
"per_file": per_file,
|
|
})
|
|
except Exception as e:
|
|
results.append({"id": fish_key, "fish_key": fish_key, "error": str(e)})
|
|
|
|
else:
|
|
cloud_folders = collect_cloud_folders(batch_root)
|
|
if not cloud_folders:
|
|
raise SystemExit(f"No cloud folders found under: {batch_root}")
|
|
|
|
results = []
|
|
for cloud_dir in cloud_folders:
|
|
rel = (
|
|
str(cloud_dir.parent.relative_to(batch_root))
|
|
if cloud_dir.parent.is_relative_to(batch_root)
|
|
else str(cloud_dir.parent)
|
|
)
|
|
try:
|
|
per_file, summary = predict_cloud_folder(
|
|
model=model,
|
|
cloud_dir=cloud_dir,
|
|
device=device,
|
|
num_points=args.num_points,
|
|
xyz_scale=args.xyz_scale,
|
|
recursive=False,
|
|
topk_length=args.topk_length,
|
|
topk_predictions=args.topk_predictions,
|
|
top_k=args.top_k,
|
|
top_by_length=args.top_by_length,
|
|
length_switch_to_weight_mm=args.length_switch_mm,
|
|
remove_outliers=args.remove_outliers,
|
|
outlier_method=args.outlier_method,
|
|
outlier_field=args.outlier_field,
|
|
iqr_factor=args.iqr_factor,
|
|
zscore_threshold=args.zscore_threshold,
|
|
)
|
|
if summary.get("skipped"):
|
|
print(f"{rel}: SKIPPED ({summary.get('skip_reason', '<5 PLYs')})")
|
|
results.append({
|
|
"id": rel,
|
|
"cloud_dir": str(cloud_dir),
|
|
"skipped": True,
|
|
"skip_reason": summary.get("skip_reason", "fewer than 5 PLYs"),
|
|
"num_files": summary.get("num_files", 0),
|
|
})
|
|
continue
|
|
|
|
cv_pct = summary.get("cv_length_pct", None)
|
|
if args.max_cv is not None and cv_pct is not None and np.isfinite(cv_pct):
|
|
if cv_pct > args.max_cv:
|
|
print(f"{rel}: SKIPPED (cv={cv_pct:.1f}% > max_cv={args.max_cv}%)")
|
|
results.append({
|
|
"id": rel,
|
|
"cloud_dir": str(cloud_dir),
|
|
"skipped": True,
|
|
"skip_reason": f"cv_length={cv_pct:.1f}% exceeds max_cv={args.max_cv}%",
|
|
"cv_length_pct": cv_pct,
|
|
})
|
|
continue
|
|
|
|
if args.print_per_ply:
|
|
for it in per_file:
|
|
ply = Path(it["ply"]).name
|
|
g = float(it["predicted_weight_g"])
|
|
kg = float(it["predicted_weight_kg"])
|
|
length_input = float(it.get("length_input", float("nan")))
|
|
used = it.get("used_for_prediction", True)
|
|
tag = " (kept)" if used else " (filtered)"
|
|
if abs(float(args.xyz_scale) - 0.001) < 1e-12:
|
|
length_str = f"{length_input:.1f} mm" if np.isfinite(length_input) else "nan"
|
|
else:
|
|
length_str = f"{length_input:.4f} units" if np.isfinite(length_input) else "nan"
|
|
print(f"{rel}/{ply}: len={length_str} | {g:.2f} g ({kg:.4f} kg){tag}")
|
|
|
|
fish_key = extract_fish_key_from_text(rel) or extract_fish_key_from_text(str(cloud_dir))
|
|
comparison = None
|
|
num_after_outlier = summary.get(
|
|
"num_files_after_outlier_removal", summary.get("num_files_predicted", "n/a")
|
|
)
|
|
num_used = summary.get("num_files_used_for_avg", "n/a")
|
|
num_total = summary.get("num_files_predicted", "n/a")
|
|
num_outliers = summary.get("num_outliers_removed", 0)
|
|
|
|
if args.remove_outliers and num_outliers > 0:
|
|
used_info = f"used={num_used}/{num_after_outlier}, outliers={num_outliers}"
|
|
else:
|
|
used_info = f"used={num_used}/{num_total}"
|
|
if cv_pct is not None and cv_pct > 10.0:
|
|
used_info += f", cv={cv_pct:.1f}%"
|
|
|
|
top_label = _summary_top_label(summary, args.top_k, args.top_by_length)
|
|
avg_len_topk = summary.get("avg_length_input_topk")
|
|
len_part = ""
|
|
if avg_len_topk is not None and np.isfinite(avg_len_topk):
|
|
len_part = (
|
|
f" | avg_len(top{args.top_k} selected)="
|
|
f"{_format_length_input_display(avg_len_topk, args.xyz_scale)}"
|
|
)
|
|
if fish_key and fish_key in labels:
|
|
actual_g = float(labels[fish_key])
|
|
comparison = compare_pred_vs_actual(
|
|
pred_g=float(summary["avg_predicted_weight_g"]), actual_g=actual_g
|
|
)
|
|
print(
|
|
f"{rel}: avg({top_label})={summary['avg_predicted_weight_g']:.2f} g ({used_info}){len_part} | "
|
|
f"actual={actual_g:.2f} g | diff%={comparison['diff_pct']:.2f}%"
|
|
)
|
|
else:
|
|
print(
|
|
f"{rel}: avg({top_label})={summary['avg_predicted_weight_g']:.2f} g ({used_info}){len_part}"
|
|
)
|
|
|
|
results.append({
|
|
"id": rel,
|
|
"cloud_dir": str(cloud_dir),
|
|
"summary": summary,
|
|
"fish_key": fish_key,
|
|
"comparison": comparison,
|
|
})
|
|
except Exception as e:
|
|
results.append({"id": rel, "cloud_dir": str(cloud_dir), "error": str(e)})
|
|
|
|
meta_cloud_folders = len(collect_cloud_folders(batch_root)) if batch_root else 0
|
|
out = {
|
|
"meta": {
|
|
"checkpoint": str(ckpt_path),
|
|
"model": "dgcnn",
|
|
"batch_root": str(batch_root),
|
|
"num_points": int(args.num_points),
|
|
"xyz_scale": float(args.xyz_scale),
|
|
"device": str(device),
|
|
"combine_by_fish": args.combine_by_fish,
|
|
"top_k": args.top_k,
|
|
"top_by_length": args.top_by_length,
|
|
"length_switch_mm": float(args.length_switch_mm),
|
|
"cloud_folders": meta_cloud_folders,
|
|
"labels_json": str(labels_path) if labels_path else None,
|
|
},
|
|
"results": results,
|
|
}
|
|
|
|
out_path = (
|
|
Path(args.output_json).expanduser().resolve()
|
|
if args.output_json
|
|
else (batch_root / "batch_weight_predictions.json")
|
|
)
|
|
suffix = f"_top{args.top_k}_by_length" if args.top_by_length else f"_top{args.top_k}"
|
|
out_path = out_path.parent / f"{out_path.stem}{suffix}{out_path.suffix}"
|
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
|
out_path.write_text(json.dumps(out, indent=2, ensure_ascii=False), encoding="utf-8")
|
|
print(f"Saved: {out_path}")
|
|
|
|
|
|
if "__main__" == __name__:
    # Entry point: run the CLI only when this file is executed as a script,
    # not when it is imported as a module.
    main()
|