feat/ sync weight fitting logic and confidence level marker *
This commit is contained in:
773
FishMeasure/fish_video_weight_evaluation.py
Executable file → Normal file
773
FishMeasure/fish_video_weight_evaluation.py
Executable file → Normal file
File diff suppressed because it is too large
Load Diff
188
FishMeasure/predict_weigth_from_svo2.py
Executable file → Normal file
188
FishMeasure/predict_weigth_from_svo2.py
Executable file → Normal file
@@ -102,26 +102,6 @@ def _run_fish_video_evaluation_subprocess(args: argparse.Namespace, *, batch_fol
|
||||
cmd.append("--use-flatness-filter")
|
||||
cmd.extend(["--flatness-threshold", str(args.flatness_threshold)])
|
||||
|
||||
# 始终在 fish 内跑 DGCNN,生成 weight_estimation_results.json,预览视频才能叠加 weight/length;
|
||||
# predict 后续会合并该 JSON,避免重复跑 test_dgcnn。
|
||||
wck = Path(args.weight_checkpoint).expanduser().resolve()
|
||||
if wck.is_file():
|
||||
cmd.extend(
|
||||
[
|
||||
"--run-weight-estimation",
|
||||
"--weight-estimator-checkpoint",
|
||||
str(wck),
|
||||
]
|
||||
)
|
||||
if getattr(args, "fish_video_weight_overlay", False):
|
||||
cmd.extend(
|
||||
[
|
||||
"--weight-overlay-video",
|
||||
"--minute-interval-sec",
|
||||
str(getattr(args, "minute_interval_sec", 60.0)),
|
||||
]
|
||||
)
|
||||
|
||||
print(f"Invoking fish_video_weight_evaluation.py:\n {' '.join(cmd)}")
|
||||
proc = subprocess.run(cmd, cwd=str(REPO_ROOT))
|
||||
if proc.returncode != 0:
|
||||
@@ -144,6 +124,13 @@ def _run_test_dgcnn_weight_estimator_subprocess(
|
||||
remove_outliers: bool,
|
||||
outlier_method: str,
|
||||
labels_json: Optional[str],
|
||||
weight_max_length_mm: float = 400.0,
|
||||
weight_min_length_width_ratio: float = 1.5,
|
||||
weight_length_quality_cv_threshold_pct: float = 15.0,
|
||||
weight_length_quality_max_span_mm: float = 130.0,
|
||||
weight_average_all_after_filter: bool = False,
|
||||
weight_average_all_fallback_max_if_mean_over_g: float = 400.0,
|
||||
weight_mean_pool_fallback_max_if_over_g: float = 440.0,
|
||||
) -> Tuple[Path, Dict[str, Any]]:
|
||||
if (cloud_dir is None) == (ply_list_file is None):
|
||||
raise ValueError("Exactly one of cloud_dir or ply_list_file must be set")
|
||||
@@ -187,6 +174,22 @@ def _run_test_dgcnn_weight_estimator_subprocess(
|
||||
if remove_outliers:
|
||||
cmd.append("--remove-outliers")
|
||||
cmd.extend(["--outlier-method", outlier_method])
|
||||
cmd.extend(["--max-length-mm", str(weight_max_length_mm)])
|
||||
cmd.extend(["--min-length-width-ratio", str(weight_min_length_width_ratio)])
|
||||
cmd.extend(["--length-quality-cv-threshold-pct", str(weight_length_quality_cv_threshold_pct)])
|
||||
cmd.extend(["--length-quality-max-span-mm", str(weight_length_quality_max_span_mm)])
|
||||
if weight_average_all_after_filter:
|
||||
cmd.append("--average-all-after-filter")
|
||||
fb_thr = weight_average_all_fallback_max_if_mean_over_g
|
||||
if fb_thr is not None and float(fb_thr) > 0:
|
||||
cmd.extend(
|
||||
["--average-all-fallback-max-if-mean-over-g", str(float(fb_thr))]
|
||||
)
|
||||
mp_fb = weight_mean_pool_fallback_max_if_over_g
|
||||
if mp_fb is not None and float(mp_fb) > 0:
|
||||
cmd.extend(
|
||||
["--mean-pool-fallback-max-if-over-g", str(float(mp_fb))]
|
||||
)
|
||||
|
||||
print(f"Invoking test_dgcnn_weight_estimator.py:\n {' '.join(cmd)}")
|
||||
proc = subprocess.run(cmd, cwd=str(REPO_ROOT))
|
||||
@@ -215,6 +218,9 @@ def _merge_weight_prediction_json(
|
||||
skipped = bool(summary.get("skipped"))
|
||||
avg_kg = summary.get("avg_predicted_weight_kg")
|
||||
avg_g = summary.get("avg_predicted_weight_g")
|
||||
pred_kg = summary.get("pred_weight_kg")
|
||||
pred_g = summary.get("pred_weight_g")
|
||||
pred_rule = summary.get("pred_weight_rule")
|
||||
return {
|
||||
"svo": str(svo_path),
|
||||
"svo_name": svo_name,
|
||||
@@ -228,6 +234,9 @@ def _merge_weight_prediction_json(
|
||||
"num_ply_predicted": summary.get("num_files_predicted"),
|
||||
"avg_predicted_weight_kg": None if skipped else avg_kg,
|
||||
"avg_predicted_weight_g": None if skipped else avg_g,
|
||||
"pred_weight_g": None if skipped else pred_g,
|
||||
"pred_weight_kg": None if skipped else pred_kg,
|
||||
"pred_weight_rule": None if skipped else pred_rule,
|
||||
"dgcnn_meta": dgcnn_data.get("meta"),
|
||||
"dgcnn_summary": summary,
|
||||
"weight_summary": summary,
|
||||
@@ -248,6 +257,9 @@ def _merge_weight_prediction_json_combined(
|
||||
skipped = bool(summary.get("skipped"))
|
||||
avg_kg = summary.get("avg_predicted_weight_kg")
|
||||
avg_g = summary.get("avg_predicted_weight_g")
|
||||
pred_kg = summary.get("pred_weight_kg")
|
||||
pred_g = summary.get("pred_weight_g")
|
||||
pred_rule = summary.get("pred_weight_rule")
|
||||
cloud_dirs = [output_base / p.stem / "cloud" for p in svo_paths]
|
||||
n_ply = sum(len(_collect_existing_clouds(d)) for d in cloud_dirs)
|
||||
return {
|
||||
@@ -265,6 +277,9 @@ def _merge_weight_prediction_json_combined(
|
||||
"num_ply_predicted": summary.get("num_files_predicted"),
|
||||
"avg_predicted_weight_kg": None if skipped else avg_kg,
|
||||
"avg_predicted_weight_g": None if skipped else avg_g,
|
||||
"pred_weight_g": None if skipped else pred_g,
|
||||
"pred_weight_kg": None if skipped else pred_kg,
|
||||
"pred_weight_rule": None if skipped else pred_rule,
|
||||
"dgcnn_meta": dgcnn_data.get("meta"),
|
||||
"dgcnn_summary": summary,
|
||||
"weight_summary": summary,
|
||||
@@ -297,7 +312,13 @@ def run_weight_prediction_for_svo(
|
||||
weight_outlier_method: str,
|
||||
weight_xyz_scale: float,
|
||||
weight_labels_json: Optional[str],
|
||||
force_dgcnn_subprocess: bool = False,
|
||||
weight_max_length_mm: float = 400.0,
|
||||
weight_min_length_width_ratio: float = 1.5,
|
||||
weight_length_quality_cv_threshold_pct: float = 15.0,
|
||||
weight_length_quality_max_span_mm: float = 130.0,
|
||||
weight_average_all_after_filter: bool = False,
|
||||
weight_average_all_fallback_max_if_mean_over_g: float = 400.0,
|
||||
weight_mean_pool_fallback_max_if_over_g: float = 440.0,
|
||||
) -> Dict[str, Any]:
|
||||
svo_path = svo_path.expanduser().resolve()
|
||||
if not svo_path.exists():
|
||||
@@ -314,28 +335,6 @@ def run_weight_prediction_for_svo(
|
||||
f"fish_video_weight_evaluation.py first."
|
||||
)
|
||||
|
||||
fish_wj = out_dir / "weight_estimation" / "weight_estimation_results.json"
|
||||
if not force_dgcnn_subprocess and fish_wj.is_file():
|
||||
try:
|
||||
dgcnn_data = json.loads(fish_wj.read_text(encoding="utf-8"))
|
||||
if dgcnn_data.get("summary") is not None or dgcnn_data.get("per_file"):
|
||||
print(f"Using existing DGCNN results from fish_video: {fish_wj}")
|
||||
result = _merge_weight_prediction_json(
|
||||
svo_path=svo_path,
|
||||
svo_name=svo_name,
|
||||
out_dir=out_dir,
|
||||
cloud_dir=cloud_dir,
|
||||
dgcnn_json_path=fish_wj,
|
||||
dgcnn_data=dgcnn_data,
|
||||
)
|
||||
(out_dir / "weight_prediction.json").write_text(
|
||||
json.dumps(_sanitize_for_json(result), indent=2, ensure_ascii=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"WARNING: Could not merge {fish_wj}, falling back to test_dgcnn subprocess: {e}")
|
||||
|
||||
dgcnn_path, dgcnn_data = _run_test_dgcnn_weight_estimator_subprocess(
|
||||
cloud_dir=cloud_dir,
|
||||
ply_list_file=None,
|
||||
@@ -351,6 +350,13 @@ def run_weight_prediction_for_svo(
|
||||
remove_outliers=weight_remove_outliers,
|
||||
outlier_method=weight_outlier_method,
|
||||
labels_json=weight_labels_json,
|
||||
weight_max_length_mm=weight_max_length_mm,
|
||||
weight_min_length_width_ratio=weight_min_length_width_ratio,
|
||||
weight_length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct,
|
||||
weight_length_quality_max_span_mm=weight_length_quality_max_span_mm,
|
||||
weight_average_all_after_filter=weight_average_all_after_filter,
|
||||
weight_average_all_fallback_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g,
|
||||
weight_mean_pool_fallback_max_if_over_g=weight_mean_pool_fallback_max_if_over_g,
|
||||
)
|
||||
result = _merge_weight_prediction_json(
|
||||
svo_path=svo_path,
|
||||
@@ -381,6 +387,13 @@ def run_weight_prediction_combined_svos(
|
||||
weight_outlier_method: str,
|
||||
weight_xyz_scale: float,
|
||||
weight_labels_json: Optional[str],
|
||||
weight_max_length_mm: float = 400.0,
|
||||
weight_min_length_width_ratio: float = 1.5,
|
||||
weight_length_quality_cv_threshold_pct: float = 15.0,
|
||||
weight_length_quality_max_span_mm: float = 130.0,
|
||||
weight_average_all_after_filter: bool = False,
|
||||
weight_average_all_fallback_max_if_mean_over_g: float = 400.0,
|
||||
weight_mean_pool_fallback_max_if_over_g: float = 440.0,
|
||||
) -> Dict[str, Any]:
|
||||
"""One DGCNN run over all ``<output_base>/<svo_stem>/cloud/*.ply`` (top-K / by-length applies to the union)."""
|
||||
output_base = output_base.expanduser().resolve()
|
||||
@@ -411,6 +424,13 @@ def run_weight_prediction_combined_svos(
|
||||
remove_outliers=weight_remove_outliers,
|
||||
outlier_method=weight_outlier_method,
|
||||
labels_json=weight_labels_json,
|
||||
weight_max_length_mm=weight_max_length_mm,
|
||||
weight_min_length_width_ratio=weight_min_length_width_ratio,
|
||||
weight_length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct,
|
||||
weight_length_quality_max_span_mm=weight_length_quality_max_span_mm,
|
||||
weight_average_all_after_filter=weight_average_all_after_filter,
|
||||
weight_average_all_fallback_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g,
|
||||
weight_mean_pool_fallback_max_if_over_g=weight_mean_pool_fallback_max_if_over_g,
|
||||
)
|
||||
result = _merge_weight_prediction_json_combined(
|
||||
svo_paths=svo_paths,
|
||||
@@ -502,6 +522,48 @@ def main() -> None:
|
||||
)
|
||||
parser.add_argument("--weight-xyz-scale", type=float, default=0.001)
|
||||
parser.add_argument("--weight-labels-json", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--weight-max-length-mm",
|
||||
type=float,
|
||||
default=400.0,
|
||||
help="Passed to test_dgcnn --max-length-mm: exclude length > this from aggregation (0 = off).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weight-min-length-width-ratio",
|
||||
type=float,
|
||||
default=1.5,
|
||||
help="Passed to test_dgcnn --min-length-width-ratio; 0 disables.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weight-average-all-after-filter",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=False,
|
||||
help="Passed to test_dgcnn --average-all-after-filter: mean all PLYs after filters (no top-K).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weight-average-all-fallback-max-if-mean-over-g",
|
||||
type=float,
|
||||
default=400.0,
|
||||
help="Passed to test_dgcnn --average-all-fallback-max-if-mean-over-g; 0 disables (default: 400).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weight-mean-pool-fallback-max-if-over-g",
|
||||
type=float,
|
||||
default=440.0,
|
||||
help="Passed to test_dgcnn --mean-pool-fallback-max-if-over-g; 0 disables (default: 440).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weight-length-quality-cv-threshold-pct",
|
||||
type=float,
|
||||
default=15.0,
|
||||
help="Passed to test_dgcnn --length-quality-cv-threshold-pct (length variance hint).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--weight-length-quality-max-span-mm",
|
||||
type=float,
|
||||
default=130.0,
|
||||
help="Passed to test_dgcnn --length-quality-max-span-mm; 0 disables span rule.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--reuse-existing-clouds",
|
||||
@@ -511,24 +573,6 @@ def main() -> None:
|
||||
)
|
||||
parser.add_argument("--no-reuse-existing-clouds", action="store_false", dest="reuse_existing_clouds")
|
||||
|
||||
parser.add_argument(
|
||||
"--fish-video-weight-overlay",
|
||||
action="store_true",
|
||||
help="Extra on-video header lines (top-5 / per-minute bucket). "
|
||||
"DGCNN in fish + preview weight/length labels are already enabled when weight checkpoint exists.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--minute-interval-sec",
|
||||
type=float,
|
||||
default=60.0,
|
||||
help="Passed to fish_video --minute-interval-sec for on-video bucket stats (default: 60).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-dgcnn-subprocess",
|
||||
action="store_true",
|
||||
help="Always run test_dgcnn_weight_estimator.py subprocess even if fish_video left weight_estimation_results.json.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.frame_stride < 1:
|
||||
raise SystemExit("--frame-stride must be >= 1")
|
||||
@@ -586,6 +630,13 @@ def main() -> None:
|
||||
weight_outlier_method=args.weight_outlier_method,
|
||||
weight_xyz_scale=args.weight_xyz_scale,
|
||||
weight_labels_json=args.weight_labels_json,
|
||||
weight_max_length_mm=args.weight_max_length_mm,
|
||||
weight_min_length_width_ratio=args.weight_min_length_width_ratio,
|
||||
weight_length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct,
|
||||
weight_length_quality_max_span_mm=args.weight_length_quality_max_span_mm,
|
||||
weight_average_all_after_filter=args.weight_average_all_after_filter,
|
||||
weight_average_all_fallback_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g,
|
||||
weight_mean_pool_fallback_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -612,7 +663,13 @@ def main() -> None:
|
||||
weight_outlier_method=args.weight_outlier_method,
|
||||
weight_xyz_scale=args.weight_xyz_scale,
|
||||
weight_labels_json=args.weight_labels_json,
|
||||
force_dgcnn_subprocess=args.force_dgcnn_subprocess,
|
||||
weight_max_length_mm=args.weight_max_length_mm,
|
||||
weight_min_length_width_ratio=args.weight_min_length_width_ratio,
|
||||
weight_length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct,
|
||||
weight_length_quality_max_span_mm=args.weight_length_quality_max_span_mm,
|
||||
weight_average_all_after_filter=args.weight_average_all_after_filter,
|
||||
weight_average_all_fallback_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g,
|
||||
weight_mean_pool_fallback_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -628,8 +685,11 @@ def main() -> None:
|
||||
if len(results) == 1 and "error" not in results[0]:
|
||||
r0 = results[0]
|
||||
avg_g = r0.get("avg_predicted_weight_g")
|
||||
if avg_g is not None:
|
||||
print(f"Final predicted weight (test_dgcnn): {avg_g:.2f} g")
|
||||
pred_g = r0.get("pred_weight_g")
|
||||
out_g = pred_g if pred_g is not None else avg_g
|
||||
if out_g is not None:
|
||||
label = "pred_weight" if pred_g is not None else "avg_predicted_weight"
|
||||
print(f"Final predicted weight (test_dgcnn, {label}): {float(out_g):.2f} g")
|
||||
else:
|
||||
print("Final predicted weight: N/A (skipped or no valid point clouds)")
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
SESSION_ROOT="/home/ubuntu/data/fish/2016-1-22-last"
|
||||
FISH_NAME="fish17"
|
||||
FISH_NAME="fish9"
|
||||
fish_dir="${SESSION_ROOT}/${FISH_NAME}/"
|
||||
OUT_PARENT="output_weight_estimator"
|
||||
save_out="${OUT_PARENT}/${FISH_NAME}"
|
||||
@@ -37,7 +37,7 @@ python3 predict_weigth_from_svo2.py \
|
||||
--weight-checkpoint weight_estimator/runs/dgcnn_20260312_171043/best.pt \
|
||||
--save-output "$save_out" \
|
||||
--yolo-model "/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt" \
|
||||
--conf 0.5 \
|
||||
--conf 0.8 \
|
||||
--imgsz 640 \
|
||||
--sam-device cuda \
|
||||
--max-frames 0 \
|
||||
@@ -50,4 +50,5 @@ python3 predict_weigth_from_svo2.py \
|
||||
--flatness-threshold 55.0 \
|
||||
--frame-stride 1 \
|
||||
--weight-top-k 5 \
|
||||
--weight-top-by-length
|
||||
--weight-top-by-length
|
||||
# Optional: append --no-weight-top-by-length if you want top-K by predicted weight only.
|
||||
|
||||
5
FishMeasure/weight_estimator/test_dgcc.sh
Executable file → Normal file
5
FishMeasure/weight_estimator/test_dgcc.sh
Executable file → Normal file
@@ -1,5 +1,6 @@
|
||||
python test_dgcnn_weight_estimator.py --checkpoint runs/dgcnn_20260312_171043/best.pt\
|
||||
--batch-root '/home/ubuntu/projects/FishMeasure/output_weight_estimator' --top-k=5
|
||||
python test_dgcnn_weight_estimator.py --checkpoint runs/dgcnn_20260312_171043/best.pt \
|
||||
--batch-root '/home/ubuntu/projects/FishMeasure/output_weight_estimator' --top-k=5 --top-by-length #--average-all-after-filter
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
920
FishMeasure/weight_estimator/test_dgcnn_weight_estimator.py
Executable file → Normal file
920
FishMeasure/weight_estimator/test_dgcnn_weight_estimator.py
Executable file → Normal file
File diff suppressed because it is too large
Load Diff
@@ -63,6 +63,8 @@ def _predict_weigth_from_svo2_extra_args(settings: Settings) -> List[str]:
|
||||
str(settings.predict_minute_interval_sec),
|
||||
]
|
||||
)
|
||||
if settings.predict_show_large_labels_at_top_right:
|
||||
out.append("--show-large-labels-at-top-right")
|
||||
if not settings.measure_reuse_existing_clouds:
|
||||
out.append("--no-reuse-existing-clouds")
|
||||
return out
|
||||
|
||||
@@ -95,6 +95,8 @@ class Settings(BaseSettings):
|
||||
#: 为 True 时 fish_video 内联 DGCNN + 预览叠加(更重;需 fish_video 已支持)
|
||||
predict_fish_video_weight_overlay: bool = False
|
||||
predict_minute_interval_sec: float = 60.0
|
||||
#: 为 True 时在视频右上角显示大型 weight/length 标签(10倍字体),便于查看真实/相机生成视频的标签数据
|
||||
predict_show_large_labels_at_top_right: bool = False
|
||||
|
||||
action_checkpoint: Optional[str] = None
|
||||
action_clips_per_video: int = 8
|
||||
|
||||
Reference in New Issue
Block a user