802 lines
32 KiB
Python
802 lines
32 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
SVO2 → point clouds via **fish_video_weight_evaluation.py** (subprocess), then weight via
|
|
**weight_estimator/test_dgcnn_weight_estimator.py** (subprocess), and finally renders a labelled preview video via **generate_video_with_labels.py** (subprocess).
|
|
|
|
Default ``--conf`` is 0.8 (same as ``run_predict_from_svo2_fish9.sh``); ``run_fish_evaluation_simple.sh``
|
|
uses 0.5 — override with ``--conf``. FishServer passes ``--conf`` from ``MEASURE_YOLO_CONF`` (default 0.8).
|
|
|
|
Outputs:
|
|
``<save-output>/<svo_stem>/cloud/*.ply`` — from fish_video_weight_evaluation
|
|
``<save-output>/<svo_stem>/weight_prediction.json`` — merged SVO info + DGCNN JSON (single-SVO mode)
|
|
``<save-output>/weight_prediction.json`` — when ``--batch-svo-folder`` has **multiple** ``*.svo2``:
|
|
one DGCNN run over **all** PLYs from every ``<svo_stem>/cloud/`` (top-K applies to the union)
|
|
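``<save-output>/<svo_stem>/dgcnn_test_output_top<K>[_by_length].json`` — raw test_dgcnn output (``_by_length`` when ``--weight-top-by-length`` is on, which is the default; written directly under ``<save-output>/`` in combined mode)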
|
|
|
|
With ``--reuse-existing-clouds`` (default), skips the fish subprocess when
|
|
``cloud/*.ply`` already exists for **every** target SVO; only test_dgcnn (and the preview-video step) then run.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import subprocess
|
|
import sys
|
|
import math
|
|
from pathlib import Path
|
|
|
|
if not hasattr(argparse, "BooleanOptionalAction"):
    # Older interpreters lack argparse.BooleanOptionalAction; install a small
    # stand-in so ``--flag`` / ``--no-flag`` pairs work everywhere.

    class _BooleanOptionalAction(argparse.Action):
        """Boolean option that also registers a ``--no-`` spelling for long flags."""

        def __init__(self, option_strings, dest, default=None, type=None,
                     choices=None, required=False, help=None, metavar=None):
            spellings = []
            for flag in option_strings:
                spellings.append(flag)
                # Only long options receive a negated twin.
                if flag.startswith("--"):
                    spellings.append(f"--no-{flag[2:]}")
            super().__init__(
                option_strings=spellings,
                dest=dest,
                nargs=0,
                default=default,
                type=type,
                choices=choices,
                required=required,
                help=help,
                metavar=metavar,
            )

        def __call__(self, parser, namespace, values, option_string=None):
            # The ``--no-`` spelling stores False; the plain spelling stores True.
            is_negated = option_string.startswith("--no-")
            setattr(namespace, self.dest, not is_negated)

    argparse.BooleanOptionalAction = _BooleanOptionalAction
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
|
|
# Directory containing this script; the helper scripts launched below
# (fish_video_weight_evaluation.py, generate_video_with_labels.py,
# weight_estimator/test_dgcnn_weight_estimator.py) are located relative to it.
REPO_ROOT = Path(__file__).resolve().parent
# Put the repository root first on the import path (once) so repo-local
# modules import regardless of the current working directory.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
|
|
|
|
|
|
def _fish_video_script() -> Path:
    """Locate ``fish_video_weight_evaluation.py`` in the repository root."""
    script_name = "fish_video_weight_evaluation.py"
    return REPO_ROOT / script_name
|
|
|
|
|
|
def _dgcnn_test_script() -> Path:
    """Locate ``weight_estimator/test_dgcnn_weight_estimator.py`` under the repository root."""
    return REPO_ROOT.joinpath("weight_estimator", "test_dgcnn_weight_estimator.py")
|
|
|
|
|
|
def _torch_device_to_test_arg(device: torch.device) -> str:
|
|
return "cpu" if device.type == "cpu" else "cuda"
|
|
|
|
|
|
def _collect_existing_clouds(cloud_dir: Path) -> List[Path]:
|
|
return sorted(cloud_dir.glob("*.ply"))
|
|
|
|
|
|
def has_cached_clouds(output_base: Path, svo_path: Path) -> bool:
    """Tell whether ``<output_base>/<svo stem>/cloud/`` already holds any ``*.ply``.

    Only the (expanded, resolved) SVO path's stem is used; the SVO file
    itself does not have to exist.
    """
    stem = svo_path.expanduser().resolve().stem
    return bool(_collect_existing_clouds(output_base / stem / "cloud"))
|
|
|
|
|
|
def _run_fish_video_evaluation_subprocess(args: argparse.Namespace, *, batch_folder: Optional[Path]) -> None:
    """Launch fish_video_weight_evaluation.py to produce per-SVO PLY point clouds.

    Args:
        args: Parsed CLI namespace. Detector, SAM, frame-sampling and
            point-cloud filter options are forwarded to the child script.
        batch_folder: If given, the child processes this folder via
            ``--batch-svo-folder``; otherwise ``args.svo`` is sent via ``--svo``.

    Raises:
        FileNotFoundError: The child script is missing.
        RuntimeError: The child exited with a non-zero status.
    """
    script = _fish_video_script()
    if not script.is_file():
        raise FileNotFoundError(f"Missing: {script}")

    save_root = Path(args.save_output).expanduser().resolve()
    command: List[str] = [sys.executable, str(script)]
    command += ["--save-output", str(save_root)]
    command += ["--yolo-model", args.yolo_model]
    command += ["--conf", str(args.conf)]
    command += ["--imgsz", str(args.imgsz)]
    command += ["--sam-device", args.sam_device]
    command += ["--max-frames", str(args.max_frames)]
    command += ["--frame-stride", str(args.frame_stride)]

    if batch_folder is None:
        command += ["--svo", str(Path(args.svo).expanduser().resolve())]
    else:
        command += ["--batch-svo-folder", str(batch_folder.resolve())]

    # Simple on/off switches, forwarded in a fixed order.
    toggles = (
        (args.filter_pointcloud, "--filter-pointcloud"),
        (args.use_density_filter, "--use-density-filter"),
        (args.use_clustering_filter, "--use-clustering-filter"),
        (args.use_pointcloud_classifier, "--use-pointcloud-classifier"),
    )
    command += [flag for enabled, flag in toggles if enabled]

    # The classifier threshold is forwarded only together with a classifier path.
    if args.pointcloud_classifier:
        command += [
            "--pointcloud-classifier", args.pointcloud_classifier,
            "--pointcloud-classifier-threshold", str(args.pointcloud_classifier_threshold),
        ]
    if args.use_flatness_filter:
        command += ["--use-flatness-filter", "--flatness-threshold", str(args.flatness_threshold)]
    if getattr(args, "show_large_labels_at_top_right", False):
        command.append("--show-large-labels-at-top-right")

    print(f"Invoking fish_video_weight_evaluation.py:\n {' '.join(command)}")
    completed = subprocess.run(command, cwd=str(REPO_ROOT))
    if completed.returncode != 0:
        raise RuntimeError(f"fish_video_weight_evaluation.py exited with code {completed.returncode}")
|
|
|
|
|
|
def _run_generate_video_with_labels_subprocess(
    *,
    args: argparse.Namespace,
    svo_path: Path,
    output_dir: Path,
    weight_json: Path,
) -> None:
    """Render the labelled preview video via generate_video_with_labels.py.

    Intended to run after the DGCNN step, so the weight/length values in
    *weight_json* can be drawn onto the detection boxes. This is best-effort:
    a missing script or a non-zero exit status only prints a warning.

    Args:
        args: Parsed CLI namespace (detector, frame stride, SAM device and
            label-display flags are forwarded).
        svo_path: SVO2 recording to render.
        output_dir: Directory passed as ``--save-output``.
        weight_json: Weight JSON whose values are overlaid.
    """
    script = REPO_ROOT / "generate_video_with_labels.py"
    if not script.is_file():
        print(f"Warning: generate_video_with_labels.py not found at {script}, skipping video generation")
        return

    def absolute(path: Path) -> str:
        return str(path.expanduser().resolve())

    command = [sys.executable, str(script)]
    command += ["--svo", absolute(svo_path)]
    command += ["--save-output", absolute(output_dir)]
    command += ["--yolo-model", args.yolo_model]
    command += ["--conf", str(args.conf)]
    command += ["--imgsz", str(args.imgsz)]
    command += ["--frame-stride", str(args.frame_stride)]
    command += ["--sam-device", str(args.sam_device)]
    command += ["--weight-json", absolute(weight_json)]

    optional_flags = (
        ("show_large_labels_at_top_right", "--show-large-labels-at-top-right"),
        ("summary_star", "--summary-star"),
    )
    command.extend(flag for attr, flag in optional_flags if getattr(args, attr, False))

    print(f"Invoking generate_video_with_labels.py:\n {' '.join(command)}")
    completed = subprocess.run(command, cwd=str(REPO_ROOT))
    if completed.returncode != 0:
        print(f"Warning: generate_video_with_labels.py exited with code {completed.returncode}")
|
|
|
|
|
|
def _run_test_dgcnn_weight_estimator_subprocess(
    *,
    cloud_dir: Optional[Path] = None,
    ply_list_file: Optional[Path] = None,
    checkpoint: Path,
    device: torch.device,
    out_dir: Path,
    num_points: int,
    xyz_scale: float,
    top_k: int,
    top_by_length: bool,
    length_switch_mm: float,
    topk_length: Optional[int],
    remove_outliers: bool,
    outlier_method: str,
    labels_json: Optional[str],
    weight_max_length_mm: float = 400.0,
    weight_min_length_width_ratio: float = 1.5,
    weight_length_quality_cv_threshold_pct: float = 15.0,
    weight_length_quality_max_span_mm: float = 130.0,
    weight_average_all_after_filter: bool = False,
    weight_average_all_fallback_max_if_mean_over_g: Optional[float] = 400.0,
    weight_mean_pool_fallback_max_if_over_g: Optional[float] = 440.0,
) -> Tuple[Path, Dict[str, Any]]:
    """Run test_dgcnn_weight_estimator.py and load the JSON it writes.

    Exactly one input source must be given: *cloud_dir* (sent as
    ``--ply-folder``) or *ply_list_file* (sent as ``--ply-list-file``).
    The remaining arguments are forwarded as the matching test_dgcnn options.
    The two fallback thresholds are forwarded only when not ``None`` and > 0.

    The child is told to write ``<out_dir>/dgcnn_test_output.json``. This
    function expects the child to append ``_top<K>`` (or
    ``_top<K>_by_length`` when *top_by_length*) to that stem, and reads the
    result from there.

    Returns:
        ``(path_of_output_json, parsed_output_json)``.

    Raises:
        ValueError: Both or neither of *cloud_dir* / *ply_list_file* were given.
        FileNotFoundError: The test script is missing, or this run did not
            write the expected output JSON.
        RuntimeError: The child exited with a non-zero status.
    """
    if (cloud_dir is None) == (ply_list_file is None):
        raise ValueError("Exactly one of cloud_dir or ply_list_file must be set")

    script = _dgcnn_test_script()
    if not script.is_file():
        raise FileNotFoundError(f"Missing: {script}")

    json_base = out_dir / "dgcnn_test_output.json"
    suffix = f"_top{top_k}_by_length" if top_by_length else f"_top{top_k}"
    written = json_base.parent / f"{json_base.stem}{suffix}{json_base.suffix}"
    # Remove a result left over from an earlier run. Otherwise a child that
    # exits 0 without (re)writing this file would make us silently return
    # stale weights.
    if written.is_file():
        written.unlink()

    # When no labels file is given, a non-existent placeholder path is passed
    # (presumably test_dgcnn tolerates a missing labels file — TODO confirm).
    labels_path = labels_json if labels_json else str(out_dir / "__no_labels__.json")

    cmd: List[str] = [
        sys.executable,
        str(script),
        "--checkpoint",
        str(checkpoint.expanduser().resolve()),
        "--device",
        _torch_device_to_test_arg(device),
        "--num-points",
        str(num_points),
        "--xyz-scale",
        str(xyz_scale),
        "--top-k",
        str(top_k),
        "--output-json",
        str(json_base),
        "--labels-json",
        labels_path,
    ]
    if ply_list_file is not None:
        cmd.extend(["--ply-list-file", str(ply_list_file.resolve())])
    else:
        cmd.extend(["--ply-folder", str(cloud_dir.resolve())])
    cmd.append("--top-by-length" if top_by_length else "--no-top-by-length")
    cmd.extend(["--length-switch-mm", str(length_switch_mm)])
    if topk_length is not None:
        cmd.extend(["--topk-length", str(topk_length)])
    if remove_outliers:
        cmd.append("--remove-outliers")
    cmd.extend(["--outlier-method", outlier_method])
    cmd.extend(["--max-length-mm", str(weight_max_length_mm)])
    cmd.extend(["--min-length-width-ratio", str(weight_min_length_width_ratio)])
    cmd.extend(["--length-quality-cv-threshold-pct", str(weight_length_quality_cv_threshold_pct)])
    cmd.extend(["--length-quality-max-span-mm", str(weight_length_quality_max_span_mm)])
    if weight_average_all_after_filter:
        cmd.append("--average-all-after-filter")
    fb_thr = weight_average_all_fallback_max_if_mean_over_g
    if fb_thr is not None and float(fb_thr) > 0:
        cmd.extend(["--average-all-fallback-max-if-mean-over-g", str(float(fb_thr))])
    mp_fb = weight_mean_pool_fallback_max_if_over_g
    if mp_fb is not None and float(mp_fb) > 0:
        cmd.extend(["--mean-pool-fallback-max-if-over-g", str(float(mp_fb))])

    print(f"Invoking test_dgcnn_weight_estimator.py:\n {' '.join(cmd)}")
    proc = subprocess.run(cmd, cwd=str(REPO_ROOT))
    if proc.returncode != 0:
        raise RuntimeError(f"test_dgcnn_weight_estimator.py exited with code {proc.returncode}")

    if not written.is_file():
        raise FileNotFoundError(f"Expected DGCNN output JSON at {written}")

    data = json.loads(written.read_text(encoding="utf-8"))
    return written, data
|
|
|
|
|
|
def _merge_weight_prediction_json(
    *,
    svo_path: Path,
    svo_name: str,
    out_dir: Path,
    cloud_dir: Path,
    dgcnn_json_path: Path,
    dgcnn_data: Dict[str, Any],
) -> Dict[str, Any]:
    """Assemble the per-SVO ``weight_prediction.json`` payload.

    Combines where the SVO's outputs live with the ``summary``, ``meta``,
    ``per_file`` and ``comparison`` sections of the test_dgcnn output. When
    ``summary["skipped"]`` is truthy, every weight field is ``None``; the raw
    summary is still included.
    """
    summary: Dict[str, Any] = dgcnn_data.get("summary") or {}
    skipped = bool(summary.get("skipped"))

    def weight_field(key: str) -> Any:
        # A skipped run never reports weights.
        if skipped:
            return None
        return summary.get(key)

    payload: Dict[str, Any] = {
        "svo": str(svo_path),
        "svo_name": svo_name,
        "output_dir": str(out_dir),
        "cloud_dir": str(cloud_dir),
        "num_ply_on_disk": len(_collect_existing_clouds(cloud_dir)),
        "dgcnn_output_json": str(dgcnn_json_path),
        "skipped_weight": skipped,
        "skip_reason": summary.get("skip_reason"),
        "num_clouds_used": summary.get("num_files_used_for_avg"),
        "num_ply_predicted": summary.get("num_files_predicted"),
    }
    for key in (
        "avg_predicted_weight_kg",
        "avg_predicted_weight_g",
        "pred_weight_g",
        "pred_weight_kg",
        "pred_weight_rule",
    ):
        payload[key] = weight_field(key)
    payload["dgcnn_meta"] = dgcnn_data.get("meta")
    # The same summary object is stored under both keys.
    payload["dgcnn_summary"] = summary
    payload["weight_summary"] = summary
    payload["per_cloud"] = dgcnn_data.get("per_file") or []
    payload["comparison"] = dgcnn_data.get("comparison")
    return payload
|
|
|
|
|
|
def _merge_weight_prediction_json_combined(
    *,
    svo_paths: List[Path],
    output_base: Path,
    ply_list_path: Path,
    dgcnn_json_path: Path,
    dgcnn_data: Dict[str, Any],
) -> Dict[str, Any]:
    """Assemble the combined (multi-SVO) ``weight_prediction.json`` payload.

    Lists every SVO, its ``<output_base>/<stem>/cloud`` folder and the total
    PLY count on disk, then adds the test_dgcnn ``summary``, ``meta``,
    ``per_file`` and ``comparison`` sections. When ``summary["skipped"]`` is
    truthy, every weight field is ``None``.
    """
    summary: Dict[str, Any] = dgcnn_data.get("summary") or {}
    skipped = bool(summary.get("skipped"))

    def weight_field(key: str) -> Any:
        # A skipped run never reports weights.
        if skipped:
            return None
        return summary.get(key)

    folders = [output_base / svo.stem / "cloud" for svo in svo_paths]
    ply_total = 0
    for folder in folders:
        ply_total += len(_collect_existing_clouds(folder))

    payload: Dict[str, Any] = {
        "combined": True,
        "svo_paths": [str(svo) for svo in svo_paths],
        "svo_names": [svo.stem for svo in svo_paths],
        "output_dir": str(output_base),
        "cloud_dirs": [str(folder) for folder in folders],
        "ply_list_file": str(ply_list_path),
        "num_ply_on_disk": ply_total,
        "dgcnn_output_json": str(dgcnn_json_path),
        "skipped_weight": skipped,
        "skip_reason": summary.get("skip_reason"),
        "num_clouds_used": summary.get("num_files_used_for_avg"),
        "num_ply_predicted": summary.get("num_files_predicted"),
    }
    for key in (
        "avg_predicted_weight_kg",
        "avg_predicted_weight_g",
        "pred_weight_g",
        "pred_weight_kg",
        "pred_weight_rule",
    ):
        payload[key] = weight_field(key)
    payload["dgcnn_meta"] = dgcnn_data.get("meta")
    # The same summary object is stored under both keys.
    payload["dgcnn_summary"] = summary
    payload["weight_summary"] = summary
    payload["per_cloud"] = dgcnn_data.get("per_file") or []
    payload["comparison"] = dgcnn_data.get("comparison")
    return payload
|
|
|
|
|
|
def _sanitize_for_json(obj: Any) -> Any:
|
|
if isinstance(obj, dict):
|
|
return {k: _sanitize_for_json(v) for k, v in obj.items()}
|
|
if isinstance(obj, list):
|
|
return [_sanitize_for_json(v) for v in obj]
|
|
if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)):
|
|
return None
|
|
return obj
|
|
|
|
|
|
def run_weight_prediction_for_svo(
    svo_path: Path,
    output_base: Path,
    weight_checkpoint: Path,
    weight_device: torch.device,
    num_points_model: int,
    weight_top_k: int,
    weight_top_by_length: bool,
    weight_length_switch_mm: float,
    weight_topk_length: Optional[int],
    weight_remove_outliers: bool,
    weight_outlier_method: str,
    weight_xyz_scale: float,
    weight_labels_json: Optional[str],
    weight_max_length_mm: float = 400.0,
    weight_min_length_width_ratio: float = 1.5,
    weight_length_quality_cv_threshold_pct: float = 15.0,
    weight_length_quality_max_span_mm: float = 130.0,
    weight_average_all_after_filter: bool = False,
    weight_average_all_fallback_max_if_mean_over_g: Optional[float] = 400.0,
    weight_mean_pool_fallback_max_if_over_g: Optional[float] = 440.0,
) -> Dict[str, Any]:
    """Run test_dgcnn on one SVO's cached clouds and write its ``weight_prediction.json``.

    Reads ``<output_base>/<svo stem>/cloud/*.ply``, runs test_dgcnn on that
    folder, and writes the merged result (non-finite floats replaced by
    ``null``) to ``<output_base>/<svo stem>/weight_prediction.json``. The
    ``weight_*`` / ``num_points_model`` arguments are forwarded to test_dgcnn.

    Returns:
        The merged result dict (before JSON sanitization).

    Raises:
        FileNotFoundError: *svo_path* does not exist.
        RuntimeError: No PLY clouds are present for this SVO, or test_dgcnn failed.
    """
    svo_path = svo_path.expanduser().resolve()
    # Resolve like run_weight_prediction_combined_svos does, so the paths
    # recorded in the JSON are always absolute.
    output_base = output_base.expanduser().resolve()
    if not svo_path.exists():
        raise FileNotFoundError(f"svo not found: {svo_path}")

    svo_name = svo_path.stem
    out_dir = output_base / svo_name
    cloud_dir = out_dir / "cloud"

    if not _collect_existing_clouds(cloud_dir):
        # --reuse-existing-clouds is on by default, so the way to force
        # regeneration is its --no- form.
        raise RuntimeError(
            f"No .ply files in {cloud_dir}. Run with --no-reuse-existing-clouds or run "
            "fish_video_weight_evaluation.py first."
        )

    dgcnn_path, dgcnn_data = _run_test_dgcnn_weight_estimator_subprocess(
        cloud_dir=cloud_dir,
        ply_list_file=None,
        checkpoint=weight_checkpoint,
        device=weight_device,
        out_dir=out_dir,
        num_points=num_points_model,
        xyz_scale=weight_xyz_scale,
        top_k=weight_top_k,
        top_by_length=weight_top_by_length,
        length_switch_mm=weight_length_switch_mm,
        topk_length=weight_topk_length,
        remove_outliers=weight_remove_outliers,
        outlier_method=weight_outlier_method,
        labels_json=weight_labels_json,
        weight_max_length_mm=weight_max_length_mm,
        weight_min_length_width_ratio=weight_min_length_width_ratio,
        weight_length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct,
        weight_length_quality_max_span_mm=weight_length_quality_max_span_mm,
        weight_average_all_after_filter=weight_average_all_after_filter,
        weight_average_all_fallback_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g,
        weight_mean_pool_fallback_max_if_over_g=weight_mean_pool_fallback_max_if_over_g,
    )
    result = _merge_weight_prediction_json(
        svo_path=svo_path,
        svo_name=svo_name,
        out_dir=out_dir,
        cloud_dir=cloud_dir,
        dgcnn_json_path=dgcnn_path,
        dgcnn_data=dgcnn_data,
    )
    (out_dir / "weight_prediction.json").write_text(
        json.dumps(_sanitize_for_json(result), indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    return result
|
|
|
|
|
|
def run_weight_prediction_combined_svos(
    svo_paths: List[Path],
    output_base: Path,
    weight_checkpoint: Path,
    weight_device: torch.device,
    num_points_model: int,
    weight_top_k: int,
    weight_top_by_length: bool,
    weight_length_switch_mm: float,
    weight_topk_length: Optional[int],
    weight_remove_outliers: bool,
    weight_outlier_method: str,
    weight_xyz_scale: float,
    weight_labels_json: Optional[str],
    weight_max_length_mm: float = 400.0,
    weight_min_length_width_ratio: float = 1.5,
    weight_length_quality_cv_threshold_pct: float = 15.0,
    weight_length_quality_max_span_mm: float = 130.0,
    weight_average_all_after_filter: bool = False,
    weight_average_all_fallback_max_if_mean_over_g: float = 400.0,
    weight_mean_pool_fallback_max_if_over_g: float = 440.0,
) -> Dict[str, Any]:
    """Pool every SVO's clouds into a single test_dgcnn run.

    All ``<output_base>/<svo stem>/cloud/*.ply`` files are listed (sorted,
    absolute) in ``<output_base>/_combined_ply_list.txt``, so top-K /
    by-length selection operates on the union. The merged result is written
    to ``<output_base>/weight_prediction.json`` and returned.

    Raises:
        RuntimeError: None of the SVOs has any PLY cloud, or test_dgcnn failed.
    """
    output_base = output_base.expanduser().resolve()
    svo_paths = [p.expanduser().resolve() for p in svo_paths]
    pooled: List[Path] = [
        ply
        for svo in svo_paths
        for ply in _collect_existing_clouds(output_base / svo.stem / "cloud")
    ]
    if not pooled:
        raise RuntimeError(
            f"No .ply in any cloud folder under {output_base} for SVOs: {[p.stem for p in svo_paths]}"
        )

    ply_list = output_base / "_combined_ply_list.txt"
    entries = [str(ply.resolve()) for ply in sorted(pooled)]
    ply_list.write_text("\n".join(entries) + "\n", encoding="utf-8")

    dgcnn_path, dgcnn_data = _run_test_dgcnn_weight_estimator_subprocess(
        cloud_dir=None,
        ply_list_file=ply_list,
        checkpoint=weight_checkpoint,
        device=weight_device,
        out_dir=output_base,
        num_points=num_points_model,
        xyz_scale=weight_xyz_scale,
        top_k=weight_top_k,
        top_by_length=weight_top_by_length,
        length_switch_mm=weight_length_switch_mm,
        topk_length=weight_topk_length,
        remove_outliers=weight_remove_outliers,
        outlier_method=weight_outlier_method,
        labels_json=weight_labels_json,
        weight_max_length_mm=weight_max_length_mm,
        weight_min_length_width_ratio=weight_min_length_width_ratio,
        weight_length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct,
        weight_length_quality_max_span_mm=weight_length_quality_max_span_mm,
        weight_average_all_after_filter=weight_average_all_after_filter,
        weight_average_all_fallback_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g,
        weight_mean_pool_fallback_max_if_over_g=weight_mean_pool_fallback_max_if_over_g,
    )
    merged = _merge_weight_prediction_json_combined(
        svo_paths=svo_paths,
        output_base=output_base,
        ply_list_path=ply_list,
        dgcnn_json_path=dgcnn_path,
        dgcnn_data=dgcnn_data,
    )
    serialized = json.dumps(_sanitize_for_json(merged), indent=2, ensure_ascii=False)
    (output_base / "weight_prediction.json").write_text(serialized, encoding="utf-8")
    return merged
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the CLI parser (input selection, fish_video, test_dgcnn and video options)."""
    parser = argparse.ArgumentParser(
        description="SVO2 → fish_video_weight_evaluation.py → test_dgcnn_weight_estimator.py"
    )
    src = parser.add_mutually_exclusive_group(required=True)
    src.add_argument("--svo", type=str, default="", help="Path to a single .svo2 file")
    src.add_argument("--batch-svo-folder", type=str, default="", help="Folder containing *.svo2 files")

    parser.add_argument(
        "--save-output",
        type=str,
        default="output_weight_estimator",
        help="Base output directory (same as fish_video --save-output parent)",
    )

    # fish_video_weight_evaluation (aligned with run_fish_evaluation_simple.sh)
    parser.add_argument(
        "--yolo-model",
        type=str,
        default=str(REPO_ROOT / "runs/train/fish_detection_20251127_104658/weights/best.pt"),
    )
    parser.add_argument(
        "--conf",
        type=float,
        default=0.8,
        help="YOLO confidence (fish9 script uses 0.8; simple.sh uses 0.5)",
    )
    parser.add_argument("--imgsz", type=int, default=640)
    parser.add_argument(
        "--sam-device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="SAM device for fish_video; also used as --device for test_dgcnn (cuda|cpu)",
    )
    parser.add_argument("--max-frames", type=int, default=0, help="0 = all frames")
    parser.add_argument(
        "--frame-stride",
        type=int,
        default=1,
        metavar="N",
        help="Passed to fish_video: only YOLO/SAM/PLY every N-th frame (default: 1)",
    )
    parser.add_argument("--filter-pointcloud", action="store_true")
    parser.add_argument("--use-density-filter", action="store_true")
    parser.add_argument("--use-clustering-filter", action="store_true")
    parser.add_argument("--use-pointcloud-classifier", action="store_true")
    parser.add_argument("--pointcloud-classifier", type=str, default=None)
    parser.add_argument("--pointcloud-classifier-threshold", type=float, default=0.7)
    parser.add_argument("--use-flatness-filter", action="store_true")
    parser.add_argument("--flatness-threshold", type=float, default=70.0)

    # test_dgcnn_weight_estimator
    parser.add_argument(
        "--weight-checkpoint",
        type=str,
        default=str(REPO_ROOT / "weight_estimator/runs/dgcnn_20260312_171043/best.pt"),
    )
    parser.add_argument("--num-points-model", type=int, default=768)
    parser.add_argument("--weight-top-k", type=int, default=5)
    parser.add_argument(
        "--weight-top-by-length",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Length-first top-K (default: on). Use --no-weight-top-by-length for weight-only order.",
    )
    parser.add_argument(
        "--weight-length-switch-mm",
        type=float,
        default=319.0,
        # %(default)s keeps the help text in sync with the real default.
        help="If mean top-K length (mm) exceeds this, use top-K by predicted weight (default: %(default)s).",
    )
    parser.add_argument("--weight-topk-length", type=int, default=None)
    parser.add_argument("--weight-remove-outliers", action="store_true")
    parser.add_argument(
        "--weight-outlier-method",
        type=str,
        default="iqr",
        choices=["iqr", "zscore"],
    )
    parser.add_argument("--weight-xyz-scale", type=float, default=0.001)
    parser.add_argument("--weight-labels-json", type=str, default=None)
    parser.add_argument(
        "--weight-max-length-mm",
        type=float,
        default=400.0,
        help="Passed to test_dgcnn --max-length-mm: exclude length > this from aggregation (0 = off).",
    )
    parser.add_argument(
        "--weight-min-length-width-ratio",
        type=float,
        default=1.5,
        help="Passed to test_dgcnn --min-length-width-ratio; 0 disables.",
    )
    parser.add_argument(
        "--weight-average-all-after-filter",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Passed to test_dgcnn --average-all-after-filter: mean all PLYs after filters (no top-K).",
    )
    parser.add_argument(
        "--weight-average-all-fallback-max-if-mean-over-g",
        type=float,
        default=400.0,
        help="Passed to test_dgcnn --average-all-fallback-max-if-mean-over-g; 0 disables (default: 400).",
    )
    parser.add_argument(
        "--weight-mean-pool-fallback-max-if-over-g",
        type=float,
        default=440.0,
        help="Passed to test_dgcnn --mean-pool-fallback-max-if-over-g; 0 disables (default: 440).",
    )
    parser.add_argument(
        "--weight-length-quality-cv-threshold-pct",
        type=float,
        default=15.0,
        help="Passed to test_dgcnn --length-quality-cv-threshold-pct (length variance hint).",
    )
    parser.add_argument(
        "--weight-length-quality-max-span-mm",
        type=float,
        default=130.0,
        help="Passed to test_dgcnn --length-quality-max-span-mm; 0 disables span rule.",
    )

    # Same CLI as before (--reuse-existing-clouds / --no-reuse-existing-clouds,
    # default on), expressed with the standard boolean action.
    parser.add_argument(
        "--reuse-existing-clouds",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Skip fish_video if <save-output>/<svo_stem>/cloud/*.ply already exists for every target SVO",
    )
    parser.add_argument(
        "--show-large-labels-at-top-right",
        action="store_true",
        help="Show large weight/length labels (10x font) at top right corner for real/camera generated videos.",
    )
    parser.add_argument(
        "--summary-star",
        action="store_true",
        default=False,
        help="Pass to generate_video_with_labels: whether the summary line should draw *.",
    )
    return parser


def _weight_prediction_kwargs(
    args: argparse.Namespace, weight_ckpt: Path, device: torch.device
) -> Dict[str, Any]:
    """Collect the keyword arguments shared by both ``run_weight_prediction_*`` entry points."""
    return dict(
        weight_checkpoint=weight_ckpt,
        weight_device=device,
        num_points_model=args.num_points_model,
        weight_top_k=args.weight_top_k,
        weight_top_by_length=args.weight_top_by_length,
        weight_length_switch_mm=args.weight_length_switch_mm,
        weight_topk_length=args.weight_topk_length,
        weight_remove_outliers=args.weight_remove_outliers,
        weight_outlier_method=args.weight_outlier_method,
        weight_xyz_scale=args.weight_xyz_scale,
        weight_labels_json=args.weight_labels_json,
        weight_max_length_mm=args.weight_max_length_mm,
        weight_min_length_width_ratio=args.weight_min_length_width_ratio,
        weight_length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct,
        weight_length_quality_max_span_mm=args.weight_length_quality_max_span_mm,
        weight_average_all_after_filter=args.weight_average_all_after_filter,
        weight_average_all_fallback_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g,
        weight_mean_pool_fallback_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g,
    )


def _select_weight_json_for_video(
    svo_output_dir: Path, combined_json: Path, *, prefer_combined: bool
) -> Optional[Path]:
    """Pick the weight JSON used to label one SVO's preview video (first existing wins).

    Normally the per-SVO files are tried before the combined one. When
    *prefer_combined* is set (this run just wrote a combined result), the
    combined JSON goes first, so a stale per-SVO file from an earlier run
    cannot shadow it.
    """
    per_svo = [
        svo_output_dir / "weight_prediction.json",
        svo_output_dir / "weight_estimation" / "weight_estimation_results.json",
    ]
    candidates = [combined_json, *per_svo] if prefer_combined else [*per_svo, combined_json]
    return next((path for path in candidates if path.exists()), None)


def _generate_preview_videos(
    args: argparse.Namespace,
    svo_paths: List[Path],
    output_base: Path,
    *,
    prefer_combined: bool,
) -> None:
    """Render a labelled preview video for each SVO; failures only print warnings."""
    print("\n=== Generating preview video with weight/length labels ===")
    combined_json = output_base / "weight_prediction.json"
    for svo in svo_paths:
        svo_output_dir = output_base / svo.stem
        weight_json = _select_weight_json_for_video(
            svo_output_dir, combined_json, prefer_combined=prefer_combined
        )
        if weight_json is None:
            print(f" Warning: No weight JSON found for {svo.name}, skipping video generation")
            continue
        try:
            _run_generate_video_with_labels_subprocess(
                args=args,
                svo_path=svo,
                output_dir=svo_output_dir,
                weight_json=weight_json,
            )
        except Exception as e:
            print(f" Warning: Failed to generate video for {svo.name}: {e}")


def _print_final_weight(results: List[Dict[str, Any]]) -> None:
    """Print the headline weight when there is exactly one successful result."""
    if len(results) != 1 or "error" in results[0]:
        return
    r0 = results[0]
    avg_g = r0.get("avg_predicted_weight_g")
    pred_g = r0.get("pred_weight_g")
    # Prefer the rule-based prediction; fall back to the plain average.
    out_g = pred_g if pred_g is not None else avg_g
    if out_g is not None:
        label = "pred_weight" if pred_g is not None else "avg_predicted_weight"
        print(f"Final predicted weight (test_dgcnn, {label}): {float(out_g):.2f} g")
    else:
        print("Final predicted weight: N/A (skipped or no valid point clouds)")


def main() -> None:
    """CLI entry point: SVO2 → point clouds → DGCNN weight → labelled preview video.

    1. Parse arguments; validate ``--frame-stride`` and the checkpoint path.
    2. Run fish_video_weight_evaluation.py unless ``--reuse-existing-clouds``
       is on and every target SVO already has ``cloud/*.ply``.
    3. Run test_dgcnn once over the union of all clouds (several SVOs) or
       once per SVO (a single SVO).
    4. Render a labelled preview video for every SVO.
    5. Write ``weight_predictions_summary.json``; print the final weight when
       there is exactly one successful result.

    Weight-prediction failures are reported on stderr and recorded in the
    summary with an ``"error"`` key; they do not change the exit status.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.frame_stride < 1:
        raise SystemExit("--frame-stride must be >= 1")

    output_base = Path(args.save_output).expanduser().resolve()
    output_base.mkdir(parents=True, exist_ok=True)

    weight_ckpt = Path(args.weight_checkpoint).expanduser().resolve()
    if not weight_ckpt.is_file():
        raise SystemExit(f"weight checkpoint not found: {weight_ckpt}")

    batch_folder: Optional[Path]
    if args.svo:
        svo_paths = [Path(args.svo).expanduser().resolve()]
        batch_folder = None
    else:
        batch_folder = Path(args.batch_svo_folder).expanduser().resolve()
        svo_paths = sorted(batch_folder.glob("*.svo2"))

    need_fish = True
    if svo_paths and args.reuse_existing_clouds:
        need_fish = not all(has_cached_clouds(output_base, p) for p in svo_paths)

    if need_fish:
        if args.svo:
            _run_fish_video_evaluation_subprocess(args, batch_folder=None)
        else:
            if not batch_folder or not batch_folder.is_dir():
                raise SystemExit(f"batch folder not found: {args.batch_svo_folder}")
            if not svo_paths:
                raise SystemExit(f"No .svo2 files in: {batch_folder}")
            _run_fish_video_evaluation_subprocess(args, batch_folder=batch_folder)
    else:
        print("All targets already have cloud/*.ply — skipping fish_video_weight_evaluation.py.")

    device = torch.device(args.sam_device)
    weight_kwargs = _weight_prediction_kwargs(args, weight_ckpt, device)

    results: List[Dict[str, Any]] = []
    combined_mode = len(svo_paths) > 1
    if combined_mode:
        names = ", ".join(p.name for p in svo_paths)
        print(f"\n=== Weight prediction: combined ({len(svo_paths)} SVOs) → top-{args.weight_top_k} over all PLYs ===")
        print(f" {names}")
        try:
            results.append(
                run_weight_prediction_combined_svos(
                    svo_paths=svo_paths, output_base=output_base, **weight_kwargs
                )
            )
        except Exception as e:
            # Keep going so the summary is still written, but do not hide the failure.
            print(f" Error: combined weight prediction failed: {e}", file=sys.stderr)
            results.append(
                {"combined": True, "svo_paths": [str(p) for p in svo_paths], "error": str(e)}
            )
    else:
        for svo in svo_paths:
            label = svo.name
            print(f"\n=== Weight prediction: {label} ===")
            try:
                results.append(
                    run_weight_prediction_for_svo(
                        svo_path=svo, output_base=output_base, **weight_kwargs
                    )
                )
            except Exception as e:
                print(f" Error: weight prediction failed for {label}: {e}", file=sys.stderr)
                results.append({"svo": str(svo), "error": str(e)})

    # Only trust the combined JSON first when this run actually produced it.
    combined_ok = combined_mode and bool(results) and "error" not in results[0]
    _generate_preview_videos(args, svo_paths, output_base, prefer_combined=combined_ok)

    summary_path = output_base / "weight_predictions_summary.json"
    summary_path.write_text(
        json.dumps(_sanitize_for_json(results), indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    print(f"\nSaved summary: {summary_path}")

    _print_final_weight(results)
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|