Files
FishServer/FishMeasure/predict_weigth_from_svo2.py

802 lines
32 KiB
Python

#!/usr/bin/env python3
"""
SVO2 → point clouds via **fish_video_weight_evaluation.py** (subprocess), then weight via
**weight_estimator/test_dgcnn_weight_estimator.py** (subprocess).
Default ``--conf`` is 0.8 (same as ``run_predict_from_svo2_fish9.sh``); ``run_fish_evaluation_simple.sh``
uses 0.5 — override with ``--conf``. FishServer passes ``--conf`` from ``MEASURE_YOLO_CONF`` (default 0.8).
Outputs:
``<save-output>/<svo_stem>/cloud/*.ply`` — from fish_video_weight_evaluation
``<save-output>/<svo_stem>/weight_prediction.json`` — merged SVO info + DGCNN JSON (single-SVO mode)
``<save-output>/weight_prediction.json`` — when ``--batch-svo-folder`` has **multiple** ``*.svo2``:
one DGCNN run over **all** PLYs from every ``<svo_stem>/cloud/`` (top-K applies to the union)
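``<save-output>/<svo_stem>/dgcnn_test_output_top<K>[_by_length].json`` — raw test_dgcnn output (``_by_length`` is appended in the default length-first mode; written per SVO, or once directly under ``<save-output>/`` in combined mode)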
With ``--reuse-existing-clouds`` (default), the fish subprocess is skipped when every target
already has ``cloud/*.ply``; test_dgcnn always runs. Use ``--no-reuse-existing-clouds`` to force extraction.
"""
from __future__ import annotations
import argparse
import json
import subprocess
import sys
import math
from pathlib import Path
if not hasattr(argparse, "BooleanOptionalAction"):
    # argparse.BooleanOptionalAction only exists from Python 3.9 on. Older
    # interpreters get a minimal stand-in: every long option "--x" is also
    # registered as "--no-x", and the spelling used picks True / False.
    class _BooleanOptionalAction(argparse.Action):
        def __init__(self, option_strings, dest, default=None, type=None,
                     choices=None, required=False, help=None, metavar=None):
            spellings = [
                spelling
                for opt in option_strings
                for spelling in (
                    (opt, "--no-" + opt[2:]) if opt.startswith("--") else (opt,)
                )
            ]
            super().__init__(
                option_strings=spellings,
                dest=dest,
                nargs=0,
                default=default,
                type=type,
                choices=choices,
                required=required,
                help=help,
                metavar=metavar,
            )

        def __call__(self, parser, namespace, values, option_string=None):
            negated = option_string.startswith("--no-")
            setattr(namespace, self.dest, not negated)

    argparse.BooleanOptionalAction = _BooleanOptionalAction
from typing import Any, Dict, List, Optional, Tuple
import torch
REPO_ROOT = Path(__file__).resolve().parent
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
def _fish_video_script() -> Path:
return REPO_ROOT / "fish_video_weight_evaluation.py"
def _dgcnn_test_script() -> Path:
return REPO_ROOT / "weight_estimator" / "test_dgcnn_weight_estimator.py"
def _torch_device_to_test_arg(device: torch.device) -> str:
return "cpu" if device.type == "cpu" else "cuda"
def _collect_existing_clouds(cloud_dir: Path) -> List[Path]:
return sorted(cloud_dir.glob("*.ply"))
def has_cached_clouds(output_base: Path, svo_path: Path) -> bool:
    """Return True when point clouds for *svo_path* are already on disk.

    Checks for at least one ``*.ply`` in ``<output_base>/<stem>/cloud``, where
    ``<stem>`` is the stem of the user-expanded, resolved *svo_path*.
    """
    resolved = svo_path.expanduser().resolve()
    clouds = _collect_existing_clouds(output_base / resolved.stem / "cloud")
    return bool(clouds)
def _run_fish_video_evaluation_subprocess(args: argparse.Namespace, *, batch_folder: Optional[Path]) -> None:
    """Extract per-fish point clouds by running fish_video_weight_evaluation.py.

    The child writes into ``args.save_output`` (resolved to an absolute path).
    It processes every SVO in *batch_folder*, or the single ``args.svo`` when
    *batch_folder* is ``None``. YOLO, SAM and the optional point-cloud filters
    are configured from the matching attributes of *args*.

    Raises:
        FileNotFoundError: the script is not present next to this file.
        RuntimeError: the subprocess exits with a non-zero status.
    """
    script = _fish_video_script()
    if not script.is_file():
        raise FileNotFoundError(f"Missing: {script}")
    save_root = Path(args.save_output).expanduser().resolve()
    cmd: List[str] = [sys.executable, str(script)]
    cmd += ["--save-output", str(save_root)]
    cmd += ["--yolo-model", args.yolo_model]
    cmd += ["--conf", str(args.conf)]
    cmd += ["--imgsz", str(args.imgsz)]
    cmd += ["--sam-device", args.sam_device]
    cmd += ["--max-frames", str(args.max_frames)]
    cmd += ["--frame-stride", str(args.frame_stride)]
    if batch_folder is None:
        cmd += ["--svo", str(Path(args.svo).expanduser().resolve())]
    else:
        cmd += ["--batch-svo-folder", str(batch_folder.resolve())]
    # On/off filters map one-to-one onto same-named fish_video flags.
    switches = (
        ("--filter-pointcloud", args.filter_pointcloud),
        ("--use-density-filter", args.use_density_filter),
        ("--use-clustering-filter", args.use_clustering_filter),
        ("--use-pointcloud-classifier", args.use_pointcloud_classifier),
    )
    cmd += [flag for flag, enabled in switches if enabled]
    if args.pointcloud_classifier:
        cmd += ["--pointcloud-classifier", args.pointcloud_classifier]
    # NOTE(review): the threshold is forwarded even without a classifier path;
    # confirm fish_video ignores it in that case.
    cmd += ["--pointcloud-classifier-threshold", str(args.pointcloud_classifier_threshold)]
    if args.use_flatness_filter:
        cmd.append("--use-flatness-filter")
    cmd += ["--flatness-threshold", str(args.flatness_threshold)]
    if getattr(args, "show_large_labels_at_top_right", False):
        cmd.append("--show-large-labels-at-top-right")
    print(f"Invoking fish_video_weight_evaluation.py:\n {' '.join(cmd)}")
    completed = subprocess.run(cmd, cwd=str(REPO_ROOT))
    if completed.returncode != 0:
        raise RuntimeError(f"fish_video_weight_evaluation.py exited with code {completed.returncode}")
def _run_generate_video_with_labels_subprocess(
    *,
    args: argparse.Namespace,
    svo_path: Path,
    output_dir: Path,
    weight_json: Path,
) -> None:
    """Render the labelled preview video for one SVO via generate_video_with_labels.py.

    Runs after the DGCNN step, so each detection box can show the weight and
    length read from *weight_json*. This is best-effort: a missing script or a
    non-zero exit code only prints a warning.
    """
    script = REPO_ROOT / "generate_video_with_labels.py"
    if not script.is_file():
        print(f"Warning: generate_video_with_labels.py not found at {script}, skipping video generation")
        return
    cmd = [sys.executable, str(script)]
    cmd += ["--svo", str(svo_path.expanduser().resolve())]
    cmd += ["--save-output", str(output_dir.expanduser().resolve())]
    cmd += ["--yolo-model", args.yolo_model]
    cmd += ["--conf", str(args.conf)]
    cmd += ["--imgsz", str(args.imgsz)]
    cmd += ["--frame-stride", str(args.frame_stride)]
    cmd += ["--sam-device", str(args.sam_device)]
    cmd += ["--weight-json", str(weight_json.expanduser().resolve())]
    # Optional display switches; missing attributes count as "off".
    optional_switches = (
        ("--show-large-labels-at-top-right", "show_large_labels_at_top_right"),
        ("--summary-star", "summary_star"),
    )
    cmd += [flag for flag, attr in optional_switches if getattr(args, attr, False)]
    print(f"Invoking generate_video_with_labels.py:\n {' '.join(cmd)}")
    completed = subprocess.run(cmd, cwd=str(REPO_ROOT))
    if completed.returncode != 0:
        print(f"Warning: generate_video_with_labels.py exited with code {completed.returncode}")
def _run_test_dgcnn_weight_estimator_subprocess(
*,
cloud_dir: Optional[Path] = None,
ply_list_file: Optional[Path] = None,
checkpoint: Path,
device: torch.device,
out_dir: Path,
num_points: int,
xyz_scale: float,
top_k: int,
top_by_length: bool,
length_switch_mm: float,
topk_length: Optional[int],
remove_outliers: bool,
outlier_method: str,
labels_json: Optional[str],
weight_max_length_mm: float = 400.0,
weight_min_length_width_ratio: float = 1.5,
weight_length_quality_cv_threshold_pct: float = 15.0,
weight_length_quality_max_span_mm: float = 130.0,
weight_average_all_after_filter: bool = False,
weight_average_all_fallback_max_if_mean_over_g: float = 400.0,
weight_mean_pool_fallback_max_if_over_g: float = 440.0,
) -> Tuple[Path, Dict[str, Any]]:
if (cloud_dir is None) == (ply_list_file is None):
raise ValueError("Exactly one of cloud_dir or ply_list_file must be set")
script = _dgcnn_test_script()
if not script.is_file():
raise FileNotFoundError(f"Missing: {script}")
json_base = out_dir / "dgcnn_test_output.json"
labels_path = labels_json if labels_json else str(out_dir / "__no_labels__.json")
cmd: List[str] = [
sys.executable,
str(script),
"--checkpoint",
str(checkpoint.expanduser().resolve()),
"--device",
_torch_device_to_test_arg(device),
"--num-points",
str(num_points),
"--xyz-scale",
str(xyz_scale),
"--top-k",
str(top_k),
"--output-json",
str(json_base),
"--labels-json",
labels_path,
]
if ply_list_file is not None:
cmd.extend(["--ply-list-file", str(ply_list_file.resolve())])
else:
cmd.extend(["--ply-folder", str(cloud_dir.resolve())])
if top_by_length:
cmd.append("--top-by-length")
else:
cmd.append("--no-top-by-length")
cmd.extend(["--length-switch-mm", str(length_switch_mm)])
if topk_length is not None:
cmd.extend(["--topk-length", str(topk_length)])
if remove_outliers:
cmd.append("--remove-outliers")
cmd.extend(["--outlier-method", outlier_method])
cmd.extend(["--max-length-mm", str(weight_max_length_mm)])
cmd.extend(["--min-length-width-ratio", str(weight_min_length_width_ratio)])
cmd.extend(["--length-quality-cv-threshold-pct", str(weight_length_quality_cv_threshold_pct)])
cmd.extend(["--length-quality-max-span-mm", str(weight_length_quality_max_span_mm)])
if weight_average_all_after_filter:
cmd.append("--average-all-after-filter")
fb_thr = weight_average_all_fallback_max_if_mean_over_g
if fb_thr is not None and float(fb_thr) > 0:
cmd.extend(
["--average-all-fallback-max-if-mean-over-g", str(float(fb_thr))]
)
mp_fb = weight_mean_pool_fallback_max_if_over_g
if mp_fb is not None and float(mp_fb) > 0:
cmd.extend(
["--mean-pool-fallback-max-if-over-g", str(float(mp_fb))]
)
print(f"Invoking test_dgcnn_weight_estimator.py:\n {' '.join(cmd)}")
proc = subprocess.run(cmd, cwd=str(REPO_ROOT))
if proc.returncode != 0:
raise RuntimeError(f"test_dgcnn_weight_estimator.py exited with code {proc.returncode}")
suffix = f"_top{top_k}_by_length" if top_by_length else f"_top{top_k}"
written = json_base.parent / f"{json_base.stem}{suffix}{json_base.suffix}"
if not written.is_file():
raise FileNotFoundError(f"Expected DGCNN output JSON at {written}")
data = json.loads(written.read_text(encoding="utf-8"))
return written, data
def _merge_weight_prediction_json(
*,
svo_path: Path,
svo_name: str,
out_dir: Path,
cloud_dir: Path,
dgcnn_json_path: Path,
dgcnn_data: Dict[str, Any],
) -> Dict[str, Any]:
summary = dgcnn_data.get("summary") or {}
skipped = bool(summary.get("skipped"))
avg_kg = summary.get("avg_predicted_weight_kg")
avg_g = summary.get("avg_predicted_weight_g")
pred_kg = summary.get("pred_weight_kg")
pred_g = summary.get("pred_weight_g")
pred_rule = summary.get("pred_weight_rule")
return {
"svo": str(svo_path),
"svo_name": svo_name,
"output_dir": str(out_dir),
"cloud_dir": str(cloud_dir),
"num_ply_on_disk": len(_collect_existing_clouds(cloud_dir)),
"dgcnn_output_json": str(dgcnn_json_path),
"skipped_weight": skipped,
"skip_reason": summary.get("skip_reason"),
"num_clouds_used": summary.get("num_files_used_for_avg"),
"num_ply_predicted": summary.get("num_files_predicted"),
"avg_predicted_weight_kg": None if skipped else avg_kg,
"avg_predicted_weight_g": None if skipped else avg_g,
"pred_weight_g": None if skipped else pred_g,
"pred_weight_kg": None if skipped else pred_kg,
"pred_weight_rule": None if skipped else pred_rule,
"dgcnn_meta": dgcnn_data.get("meta"),
"dgcnn_summary": summary,
"weight_summary": summary,
"per_cloud": dgcnn_data.get("per_file") or [],
"comparison": dgcnn_data.get("comparison"),
}
def _merge_weight_prediction_json_combined(
*,
svo_paths: List[Path],
output_base: Path,
ply_list_path: Path,
dgcnn_json_path: Path,
dgcnn_data: Dict[str, Any],
) -> Dict[str, Any]:
summary = dgcnn_data.get("summary") or {}
skipped = bool(summary.get("skipped"))
avg_kg = summary.get("avg_predicted_weight_kg")
avg_g = summary.get("avg_predicted_weight_g")
pred_kg = summary.get("pred_weight_kg")
pred_g = summary.get("pred_weight_g")
pred_rule = summary.get("pred_weight_rule")
cloud_dirs = [output_base / p.stem / "cloud" for p in svo_paths]
n_ply = sum(len(_collect_existing_clouds(d)) for d in cloud_dirs)
return {
"combined": True,
"svo_paths": [str(p) for p in svo_paths],
"svo_names": [p.stem for p in svo_paths],
"output_dir": str(output_base),
"cloud_dirs": [str(d) for d in cloud_dirs],
"ply_list_file": str(ply_list_path),
"num_ply_on_disk": n_ply,
"dgcnn_output_json": str(dgcnn_json_path),
"skipped_weight": skipped,
"skip_reason": summary.get("skip_reason"),
"num_clouds_used": summary.get("num_files_used_for_avg"),
"num_ply_predicted": summary.get("num_files_predicted"),
"avg_predicted_weight_kg": None if skipped else avg_kg,
"avg_predicted_weight_g": None if skipped else avg_g,
"pred_weight_g": None if skipped else pred_g,
"pred_weight_kg": None if skipped else pred_kg,
"pred_weight_rule": None if skipped else pred_rule,
"dgcnn_meta": dgcnn_data.get("meta"),
"dgcnn_summary": summary,
"weight_summary": summary,
"per_cloud": dgcnn_data.get("per_file") or [],
"comparison": dgcnn_data.get("comparison"),
}
def _sanitize_for_json(obj: Any) -> Any:
if isinstance(obj, dict):
return {k: _sanitize_for_json(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_sanitize_for_json(v) for v in obj]
if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)):
return None
return obj
def run_weight_prediction_for_svo(
    svo_path: Path,
    output_base: Path,
    weight_checkpoint: Path,
    weight_device: torch.device,
    num_points_model: int,
    weight_top_k: int,
    weight_top_by_length: bool,
    weight_length_switch_mm: float,
    weight_topk_length: Optional[int],
    weight_remove_outliers: bool,
    weight_outlier_method: str,
    weight_xyz_scale: float,
    weight_labels_json: Optional[str],
    weight_max_length_mm: float = 400.0,
    weight_min_length_width_ratio: float = 1.5,
    weight_length_quality_cv_threshold_pct: float = 15.0,
    weight_length_quality_max_span_mm: float = 130.0,
    weight_average_all_after_filter: bool = False,
    weight_average_all_fallback_max_if_mean_over_g: float = 400.0,
    weight_mean_pool_fallback_max_if_over_g: float = 440.0,
) -> Dict[str, Any]:
    """Predict the weight for one SVO from its already-extracted point clouds.

    Reads ``<output_base>/<svo stem>/cloud/*.ply`` and runs test_dgcnn on that
    folder. The merged result is written to
    ``<output_base>/<svo stem>/weight_prediction.json``, with non-finite floats
    written as ``null``.

    Returns:
        The merged result dict (not sanitised).

    Raises:
        FileNotFoundError: *svo_path* does not exist, or test_dgcnn artefacts
            are missing.
        RuntimeError: the SVO has no ``.ply`` files, or test_dgcnn fails.
    """
    svo_path = svo_path.expanduser().resolve()
    if not svo_path.exists():
        raise FileNotFoundError(f"svo not found: {svo_path}")
    # Resolve like run_weight_prediction_combined_svos does. The recorded paths
    # and those handed to subprocesses (cwd=REPO_ROOT) then do not depend on
    # the caller's working directory.
    output_base = output_base.expanduser().resolve()
    svo_name = svo_path.stem
    out_dir = output_base / svo_name
    cloud_dir = out_dir / "cloud"
    if not _collect_existing_clouds(cloud_dir):
        raise RuntimeError(
            f"No .ply files in {cloud_dir}. Run with --no-reuse-existing-clouds or run "
            f"fish_video_weight_evaluation.py first."
        )
    dgcnn_path, dgcnn_data = _run_test_dgcnn_weight_estimator_subprocess(
        cloud_dir=cloud_dir,
        ply_list_file=None,
        checkpoint=weight_checkpoint,
        device=weight_device,
        out_dir=out_dir,
        num_points=num_points_model,
        xyz_scale=weight_xyz_scale,
        top_k=weight_top_k,
        top_by_length=weight_top_by_length,
        length_switch_mm=weight_length_switch_mm,
        topk_length=weight_topk_length,
        remove_outliers=weight_remove_outliers,
        outlier_method=weight_outlier_method,
        labels_json=weight_labels_json,
        weight_max_length_mm=weight_max_length_mm,
        weight_min_length_width_ratio=weight_min_length_width_ratio,
        weight_length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct,
        weight_length_quality_max_span_mm=weight_length_quality_max_span_mm,
        weight_average_all_after_filter=weight_average_all_after_filter,
        weight_average_all_fallback_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g,
        weight_mean_pool_fallback_max_if_over_g=weight_mean_pool_fallback_max_if_over_g,
    )
    result = _merge_weight_prediction_json(
        svo_path=svo_path,
        svo_name=svo_name,
        out_dir=out_dir,
        cloud_dir=cloud_dir,
        dgcnn_json_path=dgcnn_path,
        dgcnn_data=dgcnn_data,
    )
    (out_dir / "weight_prediction.json").write_text(
        json.dumps(_sanitize_for_json(result), indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    return result
def run_weight_prediction_combined_svos(
    svo_paths: List[Path],
    output_base: Path,
    weight_checkpoint: Path,
    weight_device: torch.device,
    num_points_model: int,
    weight_top_k: int,
    weight_top_by_length: bool,
    weight_length_switch_mm: float,
    weight_topk_length: Optional[int],
    weight_remove_outliers: bool,
    weight_outlier_method: str,
    weight_xyz_scale: float,
    weight_labels_json: Optional[str],
    weight_max_length_mm: float = 400.0,
    weight_min_length_width_ratio: float = 1.5,
    weight_length_quality_cv_threshold_pct: float = 15.0,
    weight_length_quality_max_span_mm: float = 130.0,
    weight_average_all_after_filter: bool = False,
    weight_average_all_fallback_max_if_mean_over_g: float = 400.0,
    weight_mean_pool_fallback_max_if_over_g: float = 440.0,
) -> Dict[str, Any]:
    """Run a single test_dgcnn pass over the point clouds of several SVOs.

    Every ``<output_base>/<svo stem>/cloud/*.ply`` is written (sorted, one
    absolute path per line) to ``<output_base>/_combined_ply_list.txt``, so
    the top-K / by-length selection applies to the union. The merged result
    is written to ``<output_base>/weight_prediction.json`` and returned.

    Raises:
        RuntimeError: none of the SVOs has any ``.ply`` file, or test_dgcnn fails.
    """
    base = output_base.expanduser().resolve()
    resolved_svos = [svo.expanduser().resolve() for svo in svo_paths]
    plys = [
        ply
        for svo in resolved_svos
        for ply in _collect_existing_clouds(base / svo.stem / "cloud")
    ]
    if not plys:
        raise RuntimeError(
            f"No .ply in any cloud folder under {base} for SVOs: {[svo.stem for svo in resolved_svos]}"
        )
    list_path = base / "_combined_ply_list.txt"
    listing = "".join(f"{ply.resolve()}\n" for ply in sorted(plys))
    list_path.write_text(listing, encoding="utf-8")
    dgcnn_json, dgcnn_payload = _run_test_dgcnn_weight_estimator_subprocess(
        cloud_dir=None,
        ply_list_file=list_path,
        checkpoint=weight_checkpoint,
        device=weight_device,
        out_dir=base,
        num_points=num_points_model,
        xyz_scale=weight_xyz_scale,
        top_k=weight_top_k,
        top_by_length=weight_top_by_length,
        length_switch_mm=weight_length_switch_mm,
        topk_length=weight_topk_length,
        remove_outliers=weight_remove_outliers,
        outlier_method=weight_outlier_method,
        labels_json=weight_labels_json,
        weight_max_length_mm=weight_max_length_mm,
        weight_min_length_width_ratio=weight_min_length_width_ratio,
        weight_length_quality_cv_threshold_pct=weight_length_quality_cv_threshold_pct,
        weight_length_quality_max_span_mm=weight_length_quality_max_span_mm,
        weight_average_all_after_filter=weight_average_all_after_filter,
        weight_average_all_fallback_max_if_mean_over_g=weight_average_all_fallback_max_if_mean_over_g,
        weight_mean_pool_fallback_max_if_over_g=weight_mean_pool_fallback_max_if_over_g,
    )
    merged = _merge_weight_prediction_json_combined(
        svo_paths=resolved_svos,
        output_base=base,
        ply_list_path=list_path,
        dgcnn_json_path=dgcnn_json,
        dgcnn_data=dgcnn_payload,
    )
    serialized = json.dumps(_sanitize_for_json(merged), indent=2, ensure_ascii=False)
    (base / "weight_prediction.json").write_text(serialized, encoding="utf-8")
    return merged
def main() -> None:
parser = argparse.ArgumentParser(
description="SVO2 → fish_video_weight_evaluation.py → test_dgcnn_weight_estimator.py"
)
src = parser.add_mutually_exclusive_group(required=True)
src.add_argument("--svo", type=str, default="", help="Path to a single .svo2 file")
src.add_argument("--batch-svo-folder", type=str, default="", help="Folder containing *.svo2 files")
parser.add_argument(
"--save-output",
type=str,
default="output_weight_estimator",
help="Base output directory (same as fish_video --save-output parent)",
)
# fish_video_weight_evaluation (aligned with run_fish_evaluation_simple.sh)
parser.add_argument(
"--yolo-model",
type=str,
default=str(REPO_ROOT / "runs/train/fish_detection_20251127_104658/weights/best.pt"),
)
parser.add_argument(
"--conf",
type=float,
default=0.8,
help="YOLO confidence (fish9 script uses 0.8; simple.sh uses 0.5)",
)
parser.add_argument("--imgsz", type=int, default=640)
parser.add_argument(
"--sam-device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="SAM device for fish_video; also used as --device for test_dgcnn (cuda|cpu)",
)
parser.add_argument("--max-frames", type=int, default=0, help="0 = all frames")
parser.add_argument(
"--frame-stride",
type=int,
default=1,
metavar="N",
help="Passed to fish_video: only YOLO/SAM/PLY every N-th frame (default: 1)",
)
parser.add_argument("--filter-pointcloud", action="store_true")
parser.add_argument("--use-density-filter", action="store_true")
parser.add_argument("--use-clustering-filter", action="store_true")
parser.add_argument("--use-pointcloud-classifier", action="store_true")
parser.add_argument("--pointcloud-classifier", type=str, default=None)
parser.add_argument("--pointcloud-classifier-threshold", type=float, default=0.7)
parser.add_argument("--use-flatness-filter", action="store_true")
parser.add_argument("--flatness-threshold", type=float, default=70.0)
# test_dgcnn_weight_estimator
parser.add_argument(
"--weight-checkpoint",
type=str,
default=str(REPO_ROOT / "weight_estimator/runs/dgcnn_20260312_171043/best.pt"),
)
parser.add_argument("--num-points-model", type=int, default=768)
parser.add_argument("--weight-top-k", type=int, default=5)
parser.add_argument(
"--weight-top-by-length",
action=argparse.BooleanOptionalAction,
default=True,
help="Length-first top-K (default: on). Use --no-weight-top-by-length for weight-only order.",
)
parser.add_argument(
"--weight-length-switch-mm",
type=float,
default=319.0,
help="If mean top-K length (mm) exceeds this, use top-K by predicted weight (default: 320).",
)
parser.add_argument("--weight-topk-length", type=int, default=None)
parser.add_argument("--weight-remove-outliers", action="store_true")
parser.add_argument(
"--weight-outlier-method",
type=str,
default="iqr",
choices=["iqr", "zscore"],
)
parser.add_argument("--weight-xyz-scale", type=float, default=0.001)
parser.add_argument("--weight-labels-json", type=str, default=None)
parser.add_argument(
"--weight-max-length-mm",
type=float,
default=400.0,
help="Passed to test_dgcnn --max-length-mm: exclude length > this from aggregation (0 = off).",
)
parser.add_argument(
"--weight-min-length-width-ratio",
type=float,
default=1.5,
help="Passed to test_dgcnn --min-length-width-ratio; 0 disables.",
)
parser.add_argument(
"--weight-average-all-after-filter",
action=argparse.BooleanOptionalAction,
default=False,
help="Passed to test_dgcnn --average-all-after-filter: mean all PLYs after filters (no top-K).",
)
parser.add_argument(
"--weight-average-all-fallback-max-if-mean-over-g",
type=float,
default=400.0,
help="Passed to test_dgcnn --average-all-fallback-max-if-mean-over-g; 0 disables (default: 400).",
)
parser.add_argument(
"--weight-mean-pool-fallback-max-if-over-g",
type=float,
default=440.0,
help="Passed to test_dgcnn --mean-pool-fallback-max-if-over-g; 0 disables (default: 440).",
)
parser.add_argument(
"--weight-length-quality-cv-threshold-pct",
type=float,
default=15.0,
help="Passed to test_dgcnn --length-quality-cv-threshold-pct (length variance hint).",
)
parser.add_argument(
"--weight-length-quality-max-span-mm",
type=float,
default=130.0,
help="Passed to test_dgcnn --length-quality-max-span-mm; 0 disables span rule.",
)
parser.add_argument(
"--reuse-existing-clouds",
action="store_true",
default=True,
help="Skip fish_video if <save-output>/<svo_stem>/cloud/*.ply already exists",
)
parser.add_argument("--no-reuse-existing-clouds", action="store_false", dest="reuse_existing_clouds")
parser.add_argument(
"--show-large-labels-at-top-right",
action="store_true",
help="Show large weight/length labels (10x font) at top right corner for real/camera generated videos.",
)
parser.add_argument(
"--summary-star",
action=argparse.BooleanOptionalAction,
default=False,
help="Pass to generate_video_with_labels: whether the summary line should draw *.",
)
args = parser.parse_args()
if args.frame_stride < 1:
raise SystemExit("--frame-stride must be >= 1")
output_base = Path(args.save_output).expanduser().resolve()
output_base.mkdir(parents=True, exist_ok=True)
weight_ckpt = Path(args.weight_checkpoint).expanduser().resolve()
if not weight_ckpt.is_file():
raise SystemExit(f"weight checkpoint not found: {weight_ckpt}")
if args.svo:
svo_paths = [Path(args.svo).expanduser().resolve()]
batch_folder: Optional[Path] = None
else:
batch_folder = Path(args.batch_svo_folder).expanduser().resolve()
svo_paths = sorted(batch_folder.glob("*.svo2"))
need_fish = True
if svo_paths and args.reuse_existing_clouds:
need_fish = not all(has_cached_clouds(output_base, p) for p in svo_paths)
if need_fish:
if args.svo:
_run_fish_video_evaluation_subprocess(args, batch_folder=None)
else:
if not batch_folder or not batch_folder.is_dir():
raise SystemExit(f"batch folder not found: {args.batch_svo_folder}")
if not svo_paths:
raise SystemExit(f"No .svo2 files in: {batch_folder}")
_run_fish_video_evaluation_subprocess(args, batch_folder=batch_folder)
else:
print("All targets already have cloud/*.ply — skipping fish_video_weight_evaluation.py.")
device = torch.device(args.sam_device)
results: List[Dict[str, Any]] = []
if len(svo_paths) > 1:
names = ", ".join(p.name for p in svo_paths)
print(f"\n=== Weight prediction: combined ({len(svo_paths)} SVOs) → top-{args.weight_top_k} over all PLYs ===")
print(f" {names}")
try:
results.append(
run_weight_prediction_combined_svos(
svo_paths=svo_paths,
output_base=output_base,
weight_checkpoint=weight_ckpt,
weight_device=device,
num_points_model=args.num_points_model,
weight_top_k=args.weight_top_k,
weight_top_by_length=args.weight_top_by_length,
weight_length_switch_mm=args.weight_length_switch_mm,
weight_topk_length=args.weight_topk_length,
weight_remove_outliers=args.weight_remove_outliers,
weight_outlier_method=args.weight_outlier_method,
weight_xyz_scale=args.weight_xyz_scale,
weight_labels_json=args.weight_labels_json,
weight_max_length_mm=args.weight_max_length_mm,
weight_min_length_width_ratio=args.weight_min_length_width_ratio,
weight_length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct,
weight_length_quality_max_span_mm=args.weight_length_quality_max_span_mm,
weight_average_all_after_filter=args.weight_average_all_after_filter,
weight_average_all_fallback_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g,
weight_mean_pool_fallback_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g,
)
)
except Exception as e:
results.append(
{"combined": True, "svo_paths": [str(p) for p in svo_paths], "error": str(e)}
)
else:
for svo in svo_paths:
label = svo.name
print(f"\n=== Weight prediction: {label} ===")
try:
results.append(
run_weight_prediction_for_svo(
svo_path=svo,
output_base=output_base,
weight_checkpoint=weight_ckpt,
weight_device=device,
num_points_model=args.num_points_model,
weight_top_k=args.weight_top_k,
weight_top_by_length=args.weight_top_by_length,
weight_length_switch_mm=args.weight_length_switch_mm,
weight_topk_length=args.weight_topk_length,
weight_remove_outliers=args.weight_remove_outliers,
weight_outlier_method=args.weight_outlier_method,
weight_xyz_scale=args.weight_xyz_scale,
weight_labels_json=args.weight_labels_json,
weight_max_length_mm=args.weight_max_length_mm,
weight_min_length_width_ratio=args.weight_min_length_width_ratio,
weight_length_quality_cv_threshold_pct=args.weight_length_quality_cv_threshold_pct,
weight_length_quality_max_span_mm=args.weight_length_quality_max_span_mm,
weight_average_all_after_filter=args.weight_average_all_after_filter,
weight_average_all_fallback_max_if_mean_over_g=args.weight_average_all_fallback_max_if_mean_over_g,
weight_mean_pool_fallback_max_if_over_g=args.weight_mean_pool_fallback_max_if_over_g,
)
)
except Exception as e:
results.append({"svo": str(svo), "error": str(e)})
# Generate preview video with weight/length labels (after DGCNN, for all runs)
print("\n=== Generating preview video with weight/length labels ===")
combined_json = output_base / "weight_prediction.json"
for svo in svo_paths:
svo_name = svo.stem
svo_output_dir = output_base / svo_name
weight_json = svo_output_dir / "weight_prediction.json"
if not weight_json.exists():
weight_json = svo_output_dir / "weight_estimation" / "weight_estimation_results.json"
if not weight_json.exists() and combined_json.exists():
weight_json = combined_json
if weight_json.exists():
try:
_run_generate_video_with_labels_subprocess(
args=args,
svo_path=svo,
output_dir=svo_output_dir,
weight_json=weight_json,
)
except Exception as e:
print(f" Warning: Failed to generate video for {svo.name}: {e}")
else:
print(f" Warning: No weight JSON found for {svo.name}, skipping video generation")
summary_path = output_base / "weight_predictions_summary.json"
summary_path.write_text(
json.dumps(_sanitize_for_json(results), indent=2, ensure_ascii=False),
encoding="utf-8",
)
print(f"\nSaved summary: {summary_path}")
if len(results) == 1 and "error" not in results[0]:
r0 = results[0]
avg_g = r0.get("avg_predicted_weight_g")
pred_g = r0.get("pred_weight_g")
out_g = pred_g if pred_g is not None else avg_g
if out_g is not None:
label = "pred_weight" if pred_g is not None else "avg_predicted_weight"
print(f"Final predicted weight (test_dgcnn, {label}): {float(out_g):.2f} g")
else:
print("Final predicted weight: N/A (skipped or no valid point clouds)")
if __name__ == "__main__":
main()