feat/ sync weight fitting logic and confidence level marker *

This commit is contained in:
zaiun xu
2026-04-14 12:17:10 +08:00
parent c1aafc69bf
commit 8497d0eb1d
7 changed files with 1283 additions and 614 deletions

773
FishMeasure/fish_video_weight_evaluation.py Executable file → Normal file

File diff suppressed because it is too large Load Diff

188
FishMeasure/predict_weigth_from_svo2.py Executable file → Normal file
View File

@@ -102,26 +102,6 @@ def _run_fish_video_evaluation_subprocess(args: argparse.Namespace, *, batch_fol
cmd.append("--use-flatness-filter")
cmd.extend(["--flatness-threshold", str(args.flatness_threshold)])
# 始终在 fish 内跑 DGCNN生成 weight_estimation_results.json预览视频才能叠加 weight/length
# predict 后续会合并该 JSON避免重复跑 test_dgcnn。
wck = Path(args.weight_checkpoint).expanduser().resolve()
if wck.is_file():
cmd.extend(
[
"--run-weight-estimation",
"--weight-estimator-checkpoint",
str(wck),
]
)
if getattr(args, "fish_video_weight_overlay", False):
cmd.extend(
[
"--weight-overlay-video",
"--minute-interval-sec",
str(getattr(args, "minute_interval_sec", 60.0)),
]
)
print(f"Invoking fish_video_weight_evaluation.py:\n {' '.join(cmd)}")
proc = subprocess.run(cmd, cwd=str(REPO_ROOT))
if proc.returncode != 0:
@@ -144,6 +124,13 @@ def _run_test_dgcnn_weight_estimator_subprocess(
remove_outliers: bool,
outlier_method: str,
labels_json: Optional[str],
weight_max_length_mm: float = 400.0,
weight_min_length_width_ratio: float = 1.5,
weight_length_quality_cv_threshold_pct: float = 15.0,
weight_length_quality_max_span_mm: float = 130.0,
weight_average_all_after_filter: bool = False,
weight_average_all_fallback_max_if_mean_over_g: float = 400.0,
weight_mean_pool_fallback_max_if_over_g: float = 440.0,
) -> Tuple[Path, Dict[str, Any]]:
if (cloud_dir is None) == (ply_list_file is None):
raise ValueError("Exactly one of cloud_dir or ply_list_file must be set")
@@ -187,6 +174,22 @@ def _run_test_dgcnn_weight_estimator_subprocess(
if remove_outliers:
cmd.append("--remove-outliers")
cmd.extend(["--outlier-method", outlier_method])
cmd.extend(["--max-length-mm", str(weight_max_length_mm)])
cmd.extend(["--min-length-width-ratio", str(weight_min_length_width_ratio)])
cmd.extend(["--length-quality-cv-threshold-pct", str(weight_length_quality_cv_threshold_pct)])
cmd.extend(["--length-quality-max-span-mm", str(weight_length_quality_max_span_mm)])
if weight_average_all_after_filter:
cmd.append("--average-all-after-filter")
fb_thr = weight_average_all_fallback_max_if_mean_over_g
if fb_thr is not None and float(fb_thr) > 0:
cmd.extend(
["--average-all-fallback-max-if-mean-over-g", str(float(fb_thr))]
)
mp_fb = weight_mean_pool_fallback_max_if_over_g
if mp_fb is not None and float(mp_fb) > 0:
cmd.extend(
["--mean-pool-fallback-max-if-over-g", str(float(mp_fb))]
)
print(f"Invoking test_dgcnn_weight_estimator.py:\n {' '.join(cmd)}")
proc = subprocess.run(cmd, cwd=str(REPO_ROOT))
@@ -215,6 +218,9 @@ def _merge_weight_prediction_json(
skipped = bool(summary.get("skipped"))
avg_kg = summary.get("avg_predicted_weight_kg")
avg_g = summary.get("avg_predicted_weight_g")
pred_kg = summary.get("pred_weight_kg")
pred_g = summary.get("pred_weight_g")
pred_rule = summary.get("pred_weight_rule")
return {
"svo": str(svo_path),
"svo_name": svo_name,
@@ -228,6 +234,9 @@ def _merge_weight_prediction_json(
"num_ply_predicted": summary.get("num_files_predicted"),
"avg_predicted_weight_kg": None if skipped else avg_kg,
"avg_predicted_weight_g": None if skipped else avg_g,
"pred_weight_g": None if skipped else pred_g,
"pred_weight_kg": None if skipped else pred_kg,
"pred_weight_rule": None if skipped else pred_rule,
"dgcnn_meta": dgcnn_data.get("meta"),
"dgcnn_summary": summary,
"weight_summary": summary,
@@ -248,6 +257,9 @@ def _merge_weight_prediction_json_combined(
skipped = bool(summary.get("skipped"))
avg_kg = summary.get("avg_predicted_weight_kg")
avg_g = summary.get("avg_predicted_weight_g")
pred_kg = summary.get("pred_weight_kg")
pred_g = summary.get("pred_weight_g")
pred_rule = summary.get("pred_weight_rule")
cloud_dirs = [output_base / p.stem / "cloud" for p in svo_paths]
n_ply = sum(len(_collect_existing_clouds(d)) for d in cloud_dirs)
return {
@@ -265,6 +277,9 @@ def _merge_weight_prediction_json_combined(
"num_ply_predicted": summary.get("num_files_predicted"),
"avg_predicted_weight_kg": None if skipped else avg_kg,
"avg_predicted_weight_g": None if skipped else avg_g,
"pred_weight_g": None if skipped else pred_g,
"pred_weight_kg": None if skipped else pred_kg,
"pred_weight_rule": None if skipped else pred_rule,
"dgcnn_meta": dgcnn_data.get("meta"),
"dgcnn_summary": summary,
"weight_summary": summary,
@@ -297,7 +312,13 @@ def run_weight_prediction_for_svo(
weight_outlier_method: str,
weight_xyz_scale: float,
weight_labels_json: Optional[str],
force_dgcnn_subprocess: bool = False,
weight_max_length_mm: float = 400.0,
weight_min_length_width_ratio: float = 1.5,
weight_length_quality_cv_threshold_pct: float = 15.0,
weight_length_quality_max_span_mm: float = 130.0,
weight_average_all_after_filter: bool = False,
weight_average_all_fallback_max_if_mean_over_g: float = 400.0,
weight_mean_pool_fallback_max_if_over_g: float = 440.0,
) -> Dict[str, Any]:
svo_path = svo_path.expanduser().resolve()
if not svo_path.exists():
@@ -314,28 +335,6 @@ def run_weight_prediction_for_svo(
f"fish_video_weight_evaluation.py first."
)
fish_wj = out_dir / "weight_estimation" / "weight_estimation_results.json"
if not force_dgcnn_subprocess and fish_wj.is_file():
try:
dgcnn_data = json.loads(fish_wj.read_text(encoding="utf-8"))
if dgcnn_data.get("summary") is not None or dgcnn_data.get("per_file"):
print(f"Using existing DGCNN results from fish_video: {fish_wj}")
result = _merge_weight_prediction_json(
svo_path=svo_path,
svo_name=svo_name,
out_dir=out_dir,
cloud_dir=cloud_dir,
dgcnn_json_path=fish_wj,
dgcnn_data=dgcnn_data,
)
(out_dir / "weight_prediction.json").write_text(
json.dumps(_sanitize_for_json(result), indent=2, ensure_ascii=False),
encoding="utf-8",
)
return result
except Exception as e:
print(f"WARNING: Could not merge {fish_wj}, falling back to test_dgcnn subprocess: {e}")
dgcnn_path, dgcnn_data = _run_test_dgcnn_weight_estimator_subprocess(
cloud_dir=cloud_dir,
ply_list_file=None,
@@ -351,6 +350,13 @@ def run_weight_prediction_for_svo(
remove_outliers=weight_remove_outliers,
outlier_method=weight_outlier_method,
labels_json=weight_labels_json,
weight_max_length_mm=weight_max_length_mm,
weight_min_length_width_ratio=weight_min_length_width_ratio,
weight_length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct,
weight_length_quality_max_span_mm=weight_length_quality_max_span_mm,
weight_average_all_after_filter=weight_average_all_after_filter,
weight_average_all_fallback_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g,
weight_mean_pool_fallback_max_if_over_g=weight_mean_pool_fallback_max_if_over_g,
)
result = _merge_weight_prediction_json(
svo_path=svo_path,
@@ -381,6 +387,13 @@ def run_weight_prediction_combined_svos(
weight_outlier_method: str,
weight_xyz_scale: float,
weight_labels_json: Optional[str],
weight_max_length_mm: float = 400.0,
weight_min_length_width_ratio: float = 1.5,
weight_length_quality_cv_threshold_pct: float = 15.0,
weight_length_quality_max_span_mm: float = 130.0,
weight_average_all_after_filter: bool = False,
weight_average_all_fallback_max_if_mean_over_g: float = 400.0,
weight_mean_pool_fallback_max_if_over_g: float = 440.0,
) -> Dict[str, Any]:
"""One DGCNN run over all ``<output_base>/<svo_stem>/cloud/*.ply`` (top-K / by-length applies to the union)."""
output_base = output_base.expanduser().resolve()
@@ -411,6 +424,13 @@ def run_weight_prediction_combined_svos(
remove_outliers=weight_remove_outliers,
outlier_method=weight_outlier_method,
labels_json=weight_labels_json,
weight_max_length_mm=weight_max_length_mm,
weight_min_length_width_ratio=weight_min_length_width_ratio,
weight_length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct,
weight_length_quality_max_span_mm=weight_length_quality_max_span_mm,
weight_average_all_after_filter=weight_average_all_after_filter,
weight_average_all_fallback_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g,
weight_mean_pool_fallback_max_if_over_g=weight_mean_pool_fallback_max_if_over_g,
)
result = _merge_weight_prediction_json_combined(
svo_paths=svo_paths,
@@ -502,6 +522,48 @@ def main() -> None:
)
parser.add_argument("--weight-xyz-scale", type=float, default=0.001)
parser.add_argument("--weight-labels-json", type=str, default=None)
parser.add_argument(
"--weight-max-length-mm",
type=float,
default=400.0,
help="Passed to test_dgcnn --max-length-mm: exclude length > this from aggregation (0 = off).",
)
parser.add_argument(
"--weight-min-length-width-ratio",
type=float,
default=1.5,
help="Passed to test_dgcnn --min-length-width-ratio; 0 disables.",
)
parser.add_argument(
"--weight-average-all-after-filter",
action=argparse.BooleanOptionalAction,
default=False,
help="Passed to test_dgcnn --average-all-after-filter: mean all PLYs after filters (no top-K).",
)
parser.add_argument(
"--weight-average-all-fallback-max-if-mean-over-g",
type=float,
default=400.0,
help="Passed to test_dgcnn --average-all-fallback-max-if-mean-over-g; 0 disables (default: 400).",
)
parser.add_argument(
"--weight-mean-pool-fallback-max-if-over-g",
type=float,
default=440.0,
help="Passed to test_dgcnn --mean-pool-fallback-max-if-over-g; 0 disables (default: 440).",
)
parser.add_argument(
"--weight-length-quality-cv-threshold-pct",
type=float,
default=15.0,
help="Passed to test_dgcnn --length-quality-cv-threshold-pct (length variance hint).",
)
parser.add_argument(
"--weight-length-quality-max-span-mm",
type=float,
default=130.0,
help="Passed to test_dgcnn --length-quality-max-span-mm; 0 disables span rule.",
)
parser.add_argument(
"--reuse-existing-clouds",
@@ -511,24 +573,6 @@ def main() -> None:
)
parser.add_argument("--no-reuse-existing-clouds", action="store_false", dest="reuse_existing_clouds")
parser.add_argument(
"--fish-video-weight-overlay",
action="store_true",
help="Extra on-video header lines (top-5 / per-minute bucket). "
"DGCNN in fish + preview weight/length labels are already enabled when weight checkpoint exists.",
)
parser.add_argument(
"--minute-interval-sec",
type=float,
default=60.0,
help="Passed to fish_video --minute-interval-sec for on-video bucket stats (default: 60).",
)
parser.add_argument(
"--force-dgcnn-subprocess",
action="store_true",
help="Always run test_dgcnn_weight_estimator.py subprocess even if fish_video left weight_estimation_results.json.",
)
args = parser.parse_args()
if args.frame_stride < 1:
raise SystemExit("--frame-stride must be >= 1")
@@ -586,6 +630,13 @@ def main() -> None:
weight_outlier_method=args.weight_outlier_method,
weight_xyz_scale=args.weight_xyz_scale,
weight_labels_json=args.weight_labels_json,
weight_max_length_mm=args.weight_max_length_mm,
weight_min_length_width_ratio=args.weight_min_length_width_ratio,
weight_length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct,
weight_length_quality_max_span_mm=args.weight_length_quality_max_span_mm,
weight_average_all_after_filter=args.weight_average_all_after_filter,
weight_average_all_fallback_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g,
weight_mean_pool_fallback_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g,
)
)
except Exception as e:
@@ -612,7 +663,13 @@ def main() -> None:
weight_outlier_method=args.weight_outlier_method,
weight_xyz_scale=args.weight_xyz_scale,
weight_labels_json=args.weight_labels_json,
force_dgcnn_subprocess=args.force_dgcnn_subprocess,
weight_max_length_mm=args.weight_max_length_mm,
weight_min_length_width_ratio=args.weight_min_length_width_ratio,
weight_length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct,
weight_length_quality_max_span_mm=args.weight_length_quality_max_span_mm,
weight_average_all_after_filter=args.weight_average_all_after_filter,
weight_average_all_fallback_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g,
weight_mean_pool_fallback_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g,
)
)
except Exception as e:
@@ -628,8 +685,11 @@ def main() -> None:
if len(results) == 1 and "error" not in results[0]:
r0 = results[0]
avg_g = r0.get("avg_predicted_weight_g")
if avg_g is not None:
print(f"Final predicted weight (test_dgcnn): {avg_g:.2f} g")
pred_g = r0.get("pred_weight_g")
out_g = pred_g if pred_g is not None else avg_g
if out_g is not None:
label = "pred_weight" if pred_g is not None else "avg_predicted_weight"
print(f"Final predicted weight (test_dgcnn, {label}): {float(out_g):.2f} g")
else:
print("Final predicted weight: N/A (skipped or no valid point clouds)")

View File

@@ -10,7 +10,7 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"
SESSION_ROOT="/home/ubuntu/data/fish/2016-1-22-last"
FISH_NAME="fish17"
FISH_NAME="fish9"
fish_dir="${SESSION_ROOT}/${FISH_NAME}/"
OUT_PARENT="output_weight_estimator"
save_out="${OUT_PARENT}/${FISH_NAME}"
@@ -37,7 +37,7 @@ python3 predict_weigth_from_svo2.py \
--weight-checkpoint weight_estimator/runs/dgcnn_20260312_171043/best.pt \
--save-output "$save_out" \
--yolo-model "/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt" \
--conf 0.5 \
--conf 0.8 \
--imgsz 640 \
--sam-device cuda \
--max-frames 0 \
@@ -50,4 +50,5 @@ python3 predict_weigth_from_svo2.py \
--flatness-threshold 55.0 \
--frame-stride 1 \
--weight-top-k 5 \
--weight-top-by-length
--weight-top-by-length
# Optional: append --no-weight-top-by-length if you want top-K by predicted weight only.

5
FishMeasure/weight_estimator/test_dgcc.sh Executable file → Normal file
View File

@@ -1,5 +1,6 @@
python test_dgcnn_weight_estimator.py --checkpoint runs/dgcnn_20260312_171043/best.pt\
--batch-root '/home/ubuntu/projects/FishMeasure/output_weight_estimator' --top-k=5
python test_dgcnn_weight_estimator.py --checkpoint runs/dgcnn_20260312_171043/best.pt \
--batch-root '/home/ubuntu/projects/FishMeasure/output_weight_estimator' --top-k=5 --top-by-length #--average-all-after-filter

File diff suppressed because it is too large Load Diff

View File

@@ -63,6 +63,8 @@ def _predict_weigth_from_svo2_extra_args(settings: Settings) -> List[str]:
str(settings.predict_minute_interval_sec),
]
)
if settings.predict_show_large_labels_at_top_right:
out.append("--show-large-labels-at-top-right")
if not settings.measure_reuse_existing_clouds:
out.append("--no-reuse-existing-clouds")
return out

View File

@@ -95,6 +95,8 @@ class Settings(BaseSettings):
#: 为 True 时 fish_video 内联 DGCNN + 预览叠加(更重;需 fish_video 已支持)
predict_fish_video_weight_overlay: bool = False
predict_minute_interval_sec: float = 60.0
#: 为 True 时在视频右上角显示大型 weight/length 标签10倍字体便于查看真实/相机生成视频的标签数据
predict_show_large_labels_at_top_right: bool = False
action_checkpoint: Optional[str] = None
action_clips_per_video: int = 8