feat/ sync weight fitting logic and confidence level marker *
This commit is contained in:
188
FishMeasure/predict_weigth_from_svo2.py
Executable file → Normal file
188
FishMeasure/predict_weigth_from_svo2.py
Executable file → Normal file
@@ -102,26 +102,6 @@ def _run_fish_video_evaluation_subprocess(args: argparse.Namespace, *, batch_fol
|
||||
cmd.append("--use-flatness-filter")
|
||||
cmd.extend(["--flatness-threshold", str(args.flatness_threshold)])
|
||||
|
||||
# 始终在 fish 内跑 DGCNN,生成 weight_estimation_results.json,预览视频才能叠加 weight/length;
|
||||
# predict 后续会合并该 JSON,避免重复跑 test_dgcnn。
|
||||
wck = Path(args.weight_checkpoint).expanduser().resolve()
|
||||
if wck.is_file():
|
||||
cmd.extend(
|
||||
[
|
||||
"--run-weight-estimation",
|
||||
"--weight-estimator-checkpoint",
|
||||
str(wck),
|
||||
]
|
||||
)
|
||||
if getattr(args, "fish_video_weight_overlay", False):
|
||||
cmd.extend(
|
||||
[
|
||||
"--weight-overlay-video",
|
||||
"--minute-interval-sec",
|
||||
str(getattr(args, "minute_interval_sec", 60.0)),
|
||||
]
|
||||
)
|
||||
|
||||
print(f"Invoking fish_video_weight_evaluation.py:\n {' '.join(cmd)}")
|
||||
proc = subprocess.run(cmd, cwd=str(REPO_ROOT))
|
||||
if proc.returncode != 0:
|
||||
@@ -144,6 +124,13 @@ def _run_test_dgcnn_weight_estimator_subprocess(
|
||||
remove_outliers: bool,
|
||||
outlier_method: str,
|
||||
labels_json: Optional[str],
|
||||
weight_max_length_mm: float = 400.0,
|
||||
weight_min_length_width_ratio: float = 1.5,
|
||||
weight_length_quality_cv_threshold_pct: float = 15.0,
|
||||
weight_length_quality_max_span_mm: float = 130.0,
|
||||
weight_average_all_after_filter: bool = False,
|
||||
weight_average_all_fallback_max_if_mean_over_g: float = 400.0,
|
||||
weight_mean_pool_fallback_max_if_over_g: float = 440.0,
|
||||
) -> Tuple[Path, Dict[str, Any]]:
|
||||
if (cloud_dir is None) == (ply_list_file is None):
|
||||
raise ValueError("Exactly one of cloud_dir or ply_list_file must be set")
|
||||
@@ -187,6 +174,22 @@ def _run_test_dgcnn_weight_estimator_subprocess(
|
||||
if remove_outliers:
|
||||
cmd.append("--remove-outliers")
|
||||
cmd.extend(["--outlier-method", outlier_method])
|
||||
cmd.extend(["--max-length-mm", str(weight_max_length_mm)])
|
||||
cmd.extend(["--min-length-width-ratio", str(weight_min_length_width_ratio)])
|
||||
cmd.extend(["--length-quality-cv-threshold-pct", str(weight_length_quality_cv_threshold_pct)])
|
||||
cmd.extend(["--length-quality-max-span-mm", str(weight_length_quality_max_span_mm)])
|
||||
if weight_average_all_after_filter:
|
||||
cmd.append("--average-all-after-filter")
|
||||
fb_thr = weight_average_all_fallback_max_if_mean_over_g
|
||||
if fb_thr is not None and float(fb_thr) > 0:
|
||||
cmd.extend(
|
||||
["--average-all-fallback-max-if-mean-over-g", str(float(fb_thr))]
|
||||
)
|
||||
mp_fb = weight_mean_pool_fallback_max_if_over_g
|
||||
if mp_fb is not None and float(mp_fb) > 0:
|
||||
cmd.extend(
|
||||
["--mean-pool-fallback-max-if-over-g", str(float(mp_fb))]
|
||||
)
|
||||
|
||||
print(f"Invoking test_dgcnn_weight_estimator.py:\n {' '.join(cmd)}")
|
||||
proc = subprocess.run(cmd, cwd=str(REPO_ROOT))
|
||||
@@ -215,6 +218,9 @@ def _merge_weight_prediction_json(
|
||||
skipped = bool(summary.get("skipped"))
|
||||
avg_kg = summary.get("avg_predicted_weight_kg")
|
||||
avg_g = summary.get("avg_predicted_weight_g")
|
||||
pred_kg = summary.get("pred_weight_kg")
|
||||
pred_g = summary.get("pred_weight_g")
|
||||
pred_rule = summary.get("pred_weight_rule")
|
||||
return {
|
||||
"svo": str(svo_path),
|
||||
"svo_name": svo_name,
|
||||
@@ -228,6 +234,9 @@ def _merge_weight_prediction_json(
|
||||
"num_ply_predicted": summary.get("num_files_predicted"),
|
||||
"avg_predicted_weight_kg": None if skipped else avg_kg,
|
||||
"avg_predicted_weight_g": None if skipped else avg_g,
|
||||
"pred_weight_g": None if skipped else pred_g,
|
||||
"pred_weight_kg": None if skipped else pred_kg,
|
||||
"pred_weight_rule": None if skipped else pred_rule,
|
||||
"dgcnn_meta": dgcnn_data.get("meta"),
|
||||
"dgcnn_summary": summary,
|
||||
"weight_summary": summary,
|
||||
@@ -248,6 +257,9 @@ def _merge_weight_prediction_json_combined(
|
||||
skipped = bool(summary.get("skipped"))
|
||||
avg_kg = summary.get("avg_predicted_weight_kg")
|
||||
avg_g = summary.get("avg_predicted_weight_g")
|
||||
pred_kg = summary.get("pred_weight_kg")
|
||||
pred_g = summary.get("pred_weight_g")
|
||||
pred_rule = summary.get("pred_weight_rule")
|
||||
cloud_dirs = [output_base / p.stem / "cloud" for p in svo_paths]
|
||||
n_ply = sum(len(_collect_existing_clouds(d)) for d in cloud_dirs)
|
||||
return {
|
||||
@@ -265,6 +277,9 @@ def _merge_weight_prediction_json_combined(
|
||||
"num_ply_predicted": summary.get("num_files_predicted"),
|
||||
"avg_predicted_weight_kg": None if skipped else avg_kg,
|
||||
"avg_predicted_weight_g": None if skipped else avg_g,
|
||||
"pred_weight_g": None if skipped else pred_g,
|
||||
"pred_weight_kg": None if skipped else pred_kg,
|
||||
"pred_weight_rule": None if skipped else pred_rule,
|
||||
"dgcnn_meta": dgcnn_data.get("meta"),
|
||||
"dgcnn_summary": summary,
|
||||
"weight_summary": summary,
|
||||
@@ -297,7 +312,13 @@ def run_weight_prediction_for_svo(
|
||||
weight_outlier_method: str,
|
||||
weight_xyz_scale: float,
|
||||
weight_labels_json: Optional[str],
|
||||
force_dgcnn_subprocess: bool = False,
|
||||
weight_max_length_mm: float = 400.0,
|
||||
weight_min_length_width_ratio: float = 1.5,
|
||||
weight_length_quality_cv_threshold_pct: float = 15.0,
|
||||
weight_length_quality_max_span_mm: float = 130.0,
|
||||
weight_average_all_after_filter: bool = False,
|
||||
weight_average_all_fallback_max_if_mean_over_g: float = 400.0,
|
||||
weight_mean_pool_fallback_max_if_over_g: float = 440.0,
|
||||
) -> Dict[str, Any]:
|
||||
svo_path = svo_path.expanduser().resolve()
|
||||
if not svo_path.exists():
|
||||
@@ -314,28 +335,6 @@ def run_weight_prediction_for_svo(
|
||||
f"fish_video_weight_evaluation.py first."
|
||||
)
|
||||
|
||||
fish_wj = out_dir / "weight_estimation" / "weight_estimation_results.json"
|
||||
if not force_dgcnn_subprocess and fish_wj.is_file():
|
||||
try:
|
||||
dgcnn_data = json.loads(fish_wj.read_text(encoding="utf-8"))
|
||||
if dgcnn_data.get("summary") is not None or dgcnn_data.get("per_file"):
|
||||
print(f"Using existing DGCNN results from fish_video: {fish_wj}")
|
||||
result = _merge_weight_prediction_json(
|
||||
svo_path=svo_path,
|
||||
svo_name=svo_name,
|
||||
out_dir=out_dir,
|
||||
cloud_dir=cloud_dir,
|
||||
dgcnn_json_path=fish_wj,
|
||||
dgcnn_data=dgcnn_data,
|
||||
)
|
||||
(out_dir / "weight_prediction.json").write_text(
|
||||
json.dumps(_sanitize_for_json(result), indent=2, ensure_ascii=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"WARNING: Could not merge {fish_wj}, falling back to test_dgcnn subprocess: {e}")
|
||||
|
||||
dgcnn_path, dgcnn_data = _run_test_dgcnn_weight_estimator_subprocess(
|
||||
cloud_dir=cloud_dir,
|
||||
ply_list_file=None,
|
||||
@@ -351,6 +350,13 @@ def run_weight_prediction_for_svo(
|
||||
remove_outliers=weight_remove_outliers,
|
||||
outlier_method=weight_outlier_method,
|
||||
labels_json=weight_labels_json,
|
||||
weight_max_length_mm=weight_max_length_mm,
|
||||
weight_min_length_width_ratio=weight_min_length_width_ratio,
|
||||
weight_length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct,
|
||||
weight_length_quality_max_span_mm=weight_length_quality_max_span_mm,
|
||||
weight_average_all_after_filter=weight_average_all_after_filter,
|
||||
weight_average_all_fallback_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g,
|
||||
weight_mean_pool_fallback_max_if_over_g=weight_mean_pool_fallback_max_if_over_g,
|
||||
)
|
||||
result = _merge_weight_prediction_json(
|
||||
svo_path=svo_path,
|
||||
@@ -381,6 +387,13 @@ def run_weight_prediction_combined_svos(
|
||||
weight_outlier_method: str,
|
||||
weight_xyz_scale: float,
|
||||
weight_labels_json: Optional[str],
|
||||
weight_max_length_mm: float = 400.0,
|
||||
weight_min_length_width_ratio: float = 1.5,
|
||||
weight_length_quality_cv_threshold_pct: float = 15.0,
|
||||
weight_length_quality_max_span_mm: float = 130.0,
|
||||
weight_average_all_after_filter: bool = False,
|
||||
weight_average_all_fallback_max_if_mean_over_g: float = 400.0,
|
||||
weight_mean_pool_fallback_max_if_over_g: float = 440.0,
|
||||
) -> Dict[str, Any]:
|
||||
"""One DGCNN run over all ``<output_base>/<svo_stem>/cloud/*.ply`` (top-K / by-length applies to the union)."""
|
||||
output_base = output_base.expanduser().resolve()
|
||||
@@ -411,6 +424,13 @@ def run_weight_prediction_combined_svos(
|
||||
remove_outliers=weight_remove_outliers,
|
||||
outlier_method=weight_outlier_method,
|
||||
labels_json=weight_labels_json,
|
||||
weight_max_length_mm=weight_max_length_mm,
|
||||
weight_min_length_width_ratio=weight_min_length_width_ratio,
|
||||
weight_length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct,
|
||||
weight_length_quality_max_span_mm=weight_length_quality_max_span_mm,
|
||||
weight_average_all_after_filter=weight_average_all_after_filter,
|
||||
weight_average_all_fallback_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g,
|
||||
weight_mean_pool_fallback_max_if_over_g=weight_mean_pool_fallback_max_if_over_g,
|
||||
)
|
||||
result = _merge_weight_prediction_json_combined(
|
||||
svo_paths=svo_paths,
|
||||
@@ -502,6 +522,48 @@ def main() -> None:
|
||||
)
|
||||
parser.add_argument("--weight-xyz-scale", type=float, default=0.001)
|
||||
parser.add_argument("--weight-labels-json", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--weight-max-length-mm",
|
||||
type=float,
|
||||
default=400.0,
|
||||
help="Passed to test_dgcnn --max-length-mm: exclude length > this from aggregation (0 = off).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weight-min-length-width-ratio",
|
||||
type=float,
|
||||
default=1.5,
|
||||
help="Passed to test_dgcnn --min-length-width-ratio; 0 disables.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weight-average-all-after-filter",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=False,
|
||||
help="Passed to test_dgcnn --average-all-after-filter: mean all PLYs after filters (no top-K).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weight-average-all-fallback-max-if-mean-over-g",
|
||||
type=float,
|
||||
default=400.0,
|
||||
help="Passed to test_dgcnn --average-all-fallback-max-if-mean-over-g; 0 disables (default: 400).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weight-mean-pool-fallback-max-if-over-g",
|
||||
type=float,
|
||||
default=440.0,
|
||||
help="Passed to test_dgcnn --mean-pool-fallback-max-if-over-g; 0 disables (default: 440).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weight-length-quality-cv-threshold-pct",
|
||||
type=float,
|
||||
default=15.0,
|
||||
help="Passed to test_dgcnn --length-quality-cv-threshold-pct (length variance hint).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weight-length-quality-max-span-mm",
|
||||
type=float,
|
||||
default=130.0,
|
||||
help="Passed to test_dgcnn --length-quality-max-span-mm; 0 disables span rule.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--reuse-existing-clouds",
|
||||
@@ -511,24 +573,6 @@ def main() -> None:
|
||||
)
|
||||
parser.add_argument("--no-reuse-existing-clouds", action="store_false", dest="reuse_existing_clouds")
|
||||
|
||||
parser.add_argument(
|
||||
"--fish-video-weight-overlay",
|
||||
action="store_true",
|
||||
help="Extra on-video header lines (top-5 / per-minute bucket). "
|
||||
"DGCNN in fish + preview weight/length labels are already enabled when weight checkpoint exists.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--minute-interval-sec",
|
||||
type=float,
|
||||
default=60.0,
|
||||
help="Passed to fish_video --minute-interval-sec for on-video bucket stats (default: 60).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-dgcnn-subprocess",
|
||||
action="store_true",
|
||||
help="Always run test_dgcnn_weight_estimator.py subprocess even if fish_video left weight_estimation_results.json.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.frame_stride < 1:
|
||||
raise SystemExit("--frame-stride must be >= 1")
|
||||
@@ -586,6 +630,13 @@ def main() -> None:
|
||||
weight_outlier_method=args.weight_outlier_method,
|
||||
weight_xyz_scale=args.weight_xyz_scale,
|
||||
weight_labels_json=args.weight_labels_json,
|
||||
weight_max_length_mm=args.weight_max_length_mm,
|
||||
weight_min_length_width_ratio=args.weight_min_length_width_ratio,
|
||||
weight_length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct,
|
||||
weight_length_quality_max_span_mm=args.weight_length_quality_max_span_mm,
|
||||
weight_average_all_after_filter=args.weight_average_all_after_filter,
|
||||
weight_average_all_fallback_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g,
|
||||
weight_mean_pool_fallback_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -612,7 +663,13 @@ def main() -> None:
|
||||
weight_outlier_method=args.weight_outlier_method,
|
||||
weight_xyz_scale=args.weight_xyz_scale,
|
||||
weight_labels_json=args.weight_labels_json,
|
||||
force_dgcnn_subprocess=args.force_dgcnn_subprocess,
|
||||
weight_max_length_mm=args.weight_max_length_mm,
|
||||
weight_min_length_width_ratio=args.weight_min_length_width_ratio,
|
||||
weight_length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct,
|
||||
weight_length_quality_max_span_mm=args.weight_length_quality_max_span_mm,
|
||||
weight_average_all_after_filter=args.weight_average_all_after_filter,
|
||||
weight_average_all_fallback_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g,
|
||||
weight_mean_pool_fallback_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -628,8 +685,11 @@ def main() -> None:
|
||||
if len(results) == 1 and "error" not in results[0]:
|
||||
r0 = results[0]
|
||||
avg_g = r0.get("avg_predicted_weight_g")
|
||||
if avg_g is not None:
|
||||
print(f"Final predicted weight (test_dgcnn): {avg_g:.2f} g")
|
||||
pred_g = r0.get("pred_weight_g")
|
||||
out_g = pred_g if pred_g is not None else avg_g
|
||||
if out_g is not None:
|
||||
label = "pred_weight" if pred_g is not None else "avg_predicted_weight"
|
||||
print(f"Final predicted weight (test_dgcnn, {label}): {float(out_g):.2f} g")
|
||||
else:
|
||||
print("Final predicted weight: N/A (skipped or no valid point clouds)")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user