feat/ sync weight fitting logic and confidence level marker *

This commit is contained in:
zaiun xu
2026-04-14 12:17:10 +08:00
parent c1aafc69bf
commit 8497d0eb1d
7 changed files with 1283 additions and 614 deletions

188
FishMeasure/predict_weigth_from_svo2.py Executable file → Normal file
View File

@@ -102,26 +102,6 @@ def _run_fish_video_evaluation_subprocess(args: argparse.Namespace, *, batch_fol
cmd.append("--use-flatness-filter")
cmd.extend(["--flatness-threshold", str(args.flatness_threshold)])
# 始终在 fish 内跑 DGCNN生成 weight_estimation_results.json预览视频才能叠加 weight/length
# predict 后续会合并该 JSON避免重复跑 test_dgcnn。
wck = Path(args.weight_checkpoint).expanduser().resolve()
if wck.is_file():
cmd.extend(
[
"--run-weight-estimation",
"--weight-estimator-checkpoint",
str(wck),
]
)
if getattr(args, "fish_video_weight_overlay", False):
cmd.extend(
[
"--weight-overlay-video",
"--minute-interval-sec",
str(getattr(args, "minute_interval_sec", 60.0)),
]
)
print(f"Invoking fish_video_weight_evaluation.py:\n {' '.join(cmd)}")
proc = subprocess.run(cmd, cwd=str(REPO_ROOT))
if proc.returncode != 0:
@@ -144,6 +124,13 @@ def _run_test_dgcnn_weight_estimator_subprocess(
remove_outliers: bool,
outlier_method: str,
labels_json: Optional[str],
weight_max_length_mm: float = 400.0,
weight_min_length_width_ratio: float = 1.5,
weight_length_quality_cv_threshold_pct: float = 15.0,
weight_length_quality_max_span_mm: float = 130.0,
weight_average_all_after_filter: bool = False,
weight_average_all_fallback_max_if_mean_over_g: float = 400.0,
weight_mean_pool_fallback_max_if_over_g: float = 440.0,
) -> Tuple[Path, Dict[str, Any]]:
if (cloud_dir is None) == (ply_list_file is None):
raise ValueError("Exactly one of cloud_dir or ply_list_file must be set")
@@ -187,6 +174,22 @@ def _run_test_dgcnn_weight_estimator_subprocess(
if remove_outliers:
cmd.append("--remove-outliers")
cmd.extend(["--outlier-method", outlier_method])
cmd.extend(["--max-length-mm", str(weight_max_length_mm)])
cmd.extend(["--min-length-width-ratio", str(weight_min_length_width_ratio)])
cmd.extend(["--length-quality-cv-threshold-pct", str(weight_length_quality_cv_threshold_pct)])
cmd.extend(["--length-quality-max-span-mm", str(weight_length_quality_max_span_mm)])
if weight_average_all_after_filter:
cmd.append("--average-all-after-filter")
fb_thr = weight_average_all_fallback_max_if_mean_over_g
if fb_thr is not None and float(fb_thr) > 0:
cmd.extend(
["--average-all-fallback-max-if-mean-over-g", str(float(fb_thr))]
)
mp_fb = weight_mean_pool_fallback_max_if_over_g
if mp_fb is not None and float(mp_fb) > 0:
cmd.extend(
["--mean-pool-fallback-max-if-over-g", str(float(mp_fb))]
)
print(f"Invoking test_dgcnn_weight_estimator.py:\n {' '.join(cmd)}")
proc = subprocess.run(cmd, cwd=str(REPO_ROOT))
@@ -215,6 +218,9 @@ def _merge_weight_prediction_json(
skipped = bool(summary.get("skipped"))
avg_kg = summary.get("avg_predicted_weight_kg")
avg_g = summary.get("avg_predicted_weight_g")
pred_kg = summary.get("pred_weight_kg")
pred_g = summary.get("pred_weight_g")
pred_rule = summary.get("pred_weight_rule")
return {
"svo": str(svo_path),
"svo_name": svo_name,
@@ -228,6 +234,9 @@ def _merge_weight_prediction_json(
"num_ply_predicted": summary.get("num_files_predicted"),
"avg_predicted_weight_kg": None if skipped else avg_kg,
"avg_predicted_weight_g": None if skipped else avg_g,
"pred_weight_g": None if skipped else pred_g,
"pred_weight_kg": None if skipped else pred_kg,
"pred_weight_rule": None if skipped else pred_rule,
"dgcnn_meta": dgcnn_data.get("meta"),
"dgcnn_summary": summary,
"weight_summary": summary,
@@ -248,6 +257,9 @@ def _merge_weight_prediction_json_combined(
skipped = bool(summary.get("skipped"))
avg_kg = summary.get("avg_predicted_weight_kg")
avg_g = summary.get("avg_predicted_weight_g")
pred_kg = summary.get("pred_weight_kg")
pred_g = summary.get("pred_weight_g")
pred_rule = summary.get("pred_weight_rule")
cloud_dirs = [output_base / p.stem / "cloud" for p in svo_paths]
n_ply = sum(len(_collect_existing_clouds(d)) for d in cloud_dirs)
return {
@@ -265,6 +277,9 @@ def _merge_weight_prediction_json_combined(
"num_ply_predicted": summary.get("num_files_predicted"),
"avg_predicted_weight_kg": None if skipped else avg_kg,
"avg_predicted_weight_g": None if skipped else avg_g,
"pred_weight_g": None if skipped else pred_g,
"pred_weight_kg": None if skipped else pred_kg,
"pred_weight_rule": None if skipped else pred_rule,
"dgcnn_meta": dgcnn_data.get("meta"),
"dgcnn_summary": summary,
"weight_summary": summary,
@@ -297,7 +312,13 @@ def run_weight_prediction_for_svo(
weight_outlier_method: str,
weight_xyz_scale: float,
weight_labels_json: Optional[str],
force_dgcnn_subprocess: bool = False,
weight_max_length_mm: float = 400.0,
weight_min_length_width_ratio: float = 1.5,
weight_length_quality_cv_threshold_pct: float = 15.0,
weight_length_quality_max_span_mm: float = 130.0,
weight_average_all_after_filter: bool = False,
weight_average_all_fallback_max_if_mean_over_g: float = 400.0,
weight_mean_pool_fallback_max_if_over_g: float = 440.0,
) -> Dict[str, Any]:
svo_path = svo_path.expanduser().resolve()
if not svo_path.exists():
@@ -314,28 +335,6 @@ def run_weight_prediction_for_svo(
f"fish_video_weight_evaluation.py first."
)
fish_wj = out_dir / "weight_estimation" / "weight_estimation_results.json"
if not force_dgcnn_subprocess and fish_wj.is_file():
try:
dgcnn_data = json.loads(fish_wj.read_text(encoding="utf-8"))
if dgcnn_data.get("summary") is not None or dgcnn_data.get("per_file"):
print(f"Using existing DGCNN results from fish_video: {fish_wj}")
result = _merge_weight_prediction_json(
svo_path=svo_path,
svo_name=svo_name,
out_dir=out_dir,
cloud_dir=cloud_dir,
dgcnn_json_path=fish_wj,
dgcnn_data=dgcnn_data,
)
(out_dir / "weight_prediction.json").write_text(
json.dumps(_sanitize_for_json(result), indent=2, ensure_ascii=False),
encoding="utf-8",
)
return result
except Exception as e:
print(f"WARNING: Could not merge {fish_wj}, falling back to test_dgcnn subprocess: {e}")
dgcnn_path, dgcnn_data = _run_test_dgcnn_weight_estimator_subprocess(
cloud_dir=cloud_dir,
ply_list_file=None,
@@ -351,6 +350,13 @@ def run_weight_prediction_for_svo(
remove_outliers=weight_remove_outliers,
outlier_method=weight_outlier_method,
labels_json=weight_labels_json,
weight_max_length_mm=weight_max_length_mm,
weight_min_length_width_ratio=weight_min_length_width_ratio,
weight_length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct,
weight_length_quality_max_span_mm=weight_length_quality_max_span_mm,
weight_average_all_after_filter=weight_average_all_after_filter,
weight_average_all_fallback_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g,
weight_mean_pool_fallback_max_if_over_g=weight_mean_pool_fallback_max_if_over_g,
)
result = _merge_weight_prediction_json(
svo_path=svo_path,
@@ -381,6 +387,13 @@ def run_weight_prediction_combined_svos(
weight_outlier_method: str,
weight_xyz_scale: float,
weight_labels_json: Optional[str],
weight_max_length_mm: float = 400.0,
weight_min_length_width_ratio: float = 1.5,
weight_length_quality_cv_threshold_pct: float = 15.0,
weight_length_quality_max_span_mm: float = 130.0,
weight_average_all_after_filter: bool = False,
weight_average_all_fallback_max_if_mean_over_g: float = 400.0,
weight_mean_pool_fallback_max_if_over_g: float = 440.0,
) -> Dict[str, Any]:
"""One DGCNN run over all ``<output_base>/<svo_stem>/cloud/*.ply`` (top-K / by-length applies to the union)."""
output_base = output_base.expanduser().resolve()
@@ -411,6 +424,13 @@ def run_weight_prediction_combined_svos(
remove_outliers=weight_remove_outliers,
outlier_method=weight_outlier_method,
labels_json=weight_labels_json,
weight_max_length_mm=weight_max_length_mm,
weight_min_length_width_ratio=weight_min_length_width_ratio,
weight_length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct,
weight_length_quality_max_span_mm=weight_length_quality_max_span_mm,
weight_average_all_after_filter=weight_average_all_after_filter,
weight_average_all_fallback_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g,
weight_mean_pool_fallback_max_if_over_g=weight_mean_pool_fallback_max_if_over_g,
)
result = _merge_weight_prediction_json_combined(
svo_paths=svo_paths,
@@ -502,6 +522,48 @@ def main() -> None:
)
parser.add_argument("--weight-xyz-scale", type=float, default=0.001)
parser.add_argument("--weight-labels-json", type=str, default=None)
parser.add_argument(
"--weight-max-length-mm",
type=float,
default=400.0,
help="Passed to test_dgcnn --max-length-mm: exclude length > this from aggregation (0 = off).",
)
parser.add_argument(
"--weight-min-length-width-ratio",
type=float,
default=1.5,
help="Passed to test_dgcnn --min-length-width-ratio; 0 disables.",
)
parser.add_argument(
"--weight-average-all-after-filter",
action=argparse.BooleanOptionalAction,
default=False,
help="Passed to test_dgcnn --average-all-after-filter: mean all PLYs after filters (no top-K).",
)
parser.add_argument(
"--weight-average-all-fallback-max-if-mean-over-g",
type=float,
default=400.0,
help="Passed to test_dgcnn --average-all-fallback-max-if-mean-over-g; 0 disables (default: 400).",
)
parser.add_argument(
"--weight-mean-pool-fallback-max-if-over-g",
type=float,
default=440.0,
help="Passed to test_dgcnn --mean-pool-fallback-max-if-over-g; 0 disables (default: 440).",
)
parser.add_argument(
"--weight-length-quality-cv-threshold-pct",
type=float,
default=15.0,
help="Passed to test_dgcnn --length-quality-cv-threshold-pct (length variance hint).",
)
parser.add_argument(
"--weight-length-quality-max-span-mm",
type=float,
default=130.0,
help="Passed to test_dgcnn --length-quality-max-span-mm; 0 disables span rule.",
)
parser.add_argument(
"--reuse-existing-clouds",
@@ -511,24 +573,6 @@ def main() -> None:
)
parser.add_argument("--no-reuse-existing-clouds", action="store_false", dest="reuse_existing_clouds")
parser.add_argument(
"--fish-video-weight-overlay",
action="store_true",
help="Extra on-video header lines (top-5 / per-minute bucket). "
"DGCNN in fish + preview weight/length labels are already enabled when weight checkpoint exists.",
)
parser.add_argument(
"--minute-interval-sec",
type=float,
default=60.0,
help="Passed to fish_video --minute-interval-sec for on-video bucket stats (default: 60).",
)
parser.add_argument(
"--force-dgcnn-subprocess",
action="store_true",
help="Always run test_dgcnn_weight_estimator.py subprocess even if fish_video left weight_estimation_results.json.",
)
args = parser.parse_args()
if args.frame_stride < 1:
raise SystemExit("--frame-stride must be >= 1")
@@ -586,6 +630,13 @@ def main() -> None:
weight_outlier_method=args.weight_outlier_method,
weight_xyz_scale=args.weight_xyz_scale,
weight_labels_json=args.weight_labels_json,
weight_max_length_mm=args.weight_max_length_mm,
weight_min_length_width_ratio=args.weight_min_length_width_ratio,
weight_length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct,
weight_length_quality_max_span_mm=args.weight_length_quality_max_span_mm,
weight_average_all_after_filter=args.weight_average_all_after_filter,
weight_average_all_fallback_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g,
weight_mean_pool_fallback_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g,
)
)
except Exception as e:
@@ -612,7 +663,13 @@ def main() -> None:
weight_outlier_method=args.weight_outlier_method,
weight_xyz_scale=args.weight_xyz_scale,
weight_labels_json=args.weight_labels_json,
force_dgcnn_subprocess=args.force_dgcnn_subprocess,
weight_max_length_mm=args.weight_max_length_mm,
weight_min_length_width_ratio=args.weight_min_length_width_ratio,
weight_length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct,
weight_length_quality_max_span_mm=args.weight_length_quality_max_span_mm,
weight_average_all_after_filter=args.weight_average_all_after_filter,
weight_average_all_fallback_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g,
weight_mean_pool_fallback_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g,
)
)
except Exception as e:
@@ -628,8 +685,11 @@ def main() -> None:
if len(results) == 1 and "error" not in results[0]:
r0 = results[0]
avg_g = r0.get("avg_predicted_weight_g")
if avg_g is not None:
print(f"Final predicted weight (test_dgcnn): {avg_g:.2f} g")
pred_g = r0.get("pred_weight_g")
out_g = pred_g if pred_g is not None else avg_g
if out_g is not None:
label = "pred_weight" if pred_g is not None else "avg_predicted_weight"
print(f"Final predicted weight (test_dgcnn, {label}): {float(out_g):.2f} g")
else:
print("Final predicted weight: N/A (skipped or no valid point clouds)")