- 新增 SQLite:measure/health 快照、delivery_cursor 单消费者 pop;clear/start_fresh 可清空库
- biomass GET 仅返回约定 data 字段,X-Fish-Biomass-New 表示是否有新快照;poller 读响应头
- loguru 桥接 uvicorn,子进程 stdout 流式输出;format_json_pretty 与算法摘要日志
- measure/action watch 无新任务时限流 WARNING;watch_idle 共用逻辑
- 依赖 loguru;新增 db、logging_config、subprocess_run、watch_idle、启动脚本

FishMeasure: 更新 fish_video_weight_evaluation 与 predict_weigth_from_svo2;移除未用 refbox/segmentation 脚本

Made-with: Cursor
632 lines
24 KiB
Python
Executable File
632 lines
24 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
SVO2 → point clouds via **fish_video_weight_evaluation.py** (subprocess), then weight via
|
|
**weight_estimator/test_dgcnn_weight_estimator.py** (subprocess).
|
|
|
|
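Fish-step options mirror ``run_fish_evaluation_simple.sh`` (conf 0.5, filters, classifier, flatness), but the
filter, classifier and flatness stages are only enabled when their flags are passed here (``--filter-pointcloud``,
``--use-density-filter``, ``--use-clustering-filter``, ``--use-pointcloud-classifier``, ``--use-flatness-filter``).
Override any default with the same flags this script exposes.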
|
|
|
|
Outputs:
|
|
``<save-output>/<svo_stem>/cloud/*.ply`` — from fish_video_weight_evaluation
|
|
``<save-output>/<svo_stem>/weight_prediction.json`` — merged SVO info + DGCNN JSON (single-SVO mode)
|
|
``<save-output>/weight_prediction.json`` — when ``--batch-svo-folder`` has **multiple** ``*.svo2``:
|
|
one DGCNN run over **all** PLYs from every ``<svo_stem>/cloud/`` (top-K applies to the union)
|
|
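``<save-output>/<svo_stem>/dgcnn_test_output_top<K>[_by_length].json`` — raw test_dgcnn output (per-SVO, or combined under ``<save-output>/``); the ``_by_length`` suffix is present when ``--weight-top-by-length`` is on (the default)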
|
|
|
|
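With ``--reuse-existing-clouds`` (default; disable with ``--no-reuse-existing-clouds``), the fish subprocess is
skipped when every target already has ``cloud/*.ply``; only the weight step runs. In per-SVO mode an existing
``weight_estimation/weight_estimation_results.json`` written by fish_video is reused instead of running
test_dgcnn, unless ``--force-dgcnn-subprocess`` is given.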
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import subprocess
|
|
import sys
|
|
import math
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
|
|
# Directory containing this script. Sibling pipeline scripts are located relative to it,
# and it is put at the front of ``sys.path`` (only once) so local modules stay importable.
REPO_ROOT = Path(__file__).resolve().parents[0]
if not any(entry == str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))
|
|
|
|
|
|
def _fish_video_script() -> Path:
    """Return the path of ``fish_video_weight_evaluation.py`` inside ``REPO_ROOT``."""
    return REPO_ROOT.joinpath("fish_video_weight_evaluation.py")
|
|
|
|
|
|
def _dgcnn_test_script() -> Path:
    """Return the path of ``weight_estimator/test_dgcnn_weight_estimator.py`` under ``REPO_ROOT``."""
    return REPO_ROOT.joinpath("weight_estimator", "test_dgcnn_weight_estimator.py")
|
|
|
|
|
|
def _torch_device_to_test_arg(device: torch.device) -> str:
|
|
return "cpu" if device.type == "cpu" else "cuda"
|
|
|
|
|
|
def _collect_existing_clouds(cloud_dir: Path) -> List[Path]:
|
|
return sorted(cloud_dir.glob("*.ply"))
|
|
|
|
|
|
def has_cached_clouds(output_base: Path, svo_path: Path) -> bool:
    """Tell whether ``<output_base>/<svo stem>/cloud`` already holds at least one ``.ply``.

    The stem is taken from the *resolved* SVO path (``~`` expanded, symlinks followed).
    """
    stem = svo_path.expanduser().resolve().stem
    first_cloud = next((output_base / stem / "cloud").glob("*.ply"), None)
    return first_cloud is not None
|
|
|
|
|
|
def _run_fish_video_evaluation_subprocess(args: argparse.Namespace, *, batch_folder: Optional[Path]) -> None:
    """Run ``fish_video_weight_evaluation.py`` as a child process to produce point clouds.

    The command forwards this script's detection/segmentation options, the optional
    point-cloud filter switches and, with ``--fish-video-weight-overlay``, fish_video's own
    DGCNN weight estimation plus overlay video.

    Args:
        args: Parsed CLI namespace from ``main``.
        batch_folder: Folder of ``*.svo2`` files to process; ``None`` processes ``args.svo``.

    Raises:
        FileNotFoundError: ``fish_video_weight_evaluation.py`` is missing from ``REPO_ROOT``.
        RuntimeError: The child process exited with a non-zero status.
    """
    import shlex  # only used to log a copy-pasteable command line

    fvs = _fish_video_script()
    if not fvs.is_file():
        raise FileNotFoundError(f"Missing: {fvs}")

    out_parent = Path(args.save_output).expanduser().resolve()
    cmd: List[str] = [
        sys.executable,
        str(fvs),
        "--save-output", str(out_parent),
        "--yolo-model", args.yolo_model,
        "--conf", str(args.conf),
        "--imgsz", str(args.imgsz),
        "--sam-device", args.sam_device,
        "--max-frames", str(args.max_frames),
        "--frame-stride", str(args.frame_stride),
    ]
    if batch_folder is not None:
        cmd += ["--batch-svo-folder", str(batch_folder.resolve())]
    else:
        cmd += ["--svo", str(Path(args.svo).expanduser().resolve())]

    # Boolean stage switches are forwarded 1:1 as bare flags (order kept stable for logs).
    for enabled, flag in (
        (args.filter_pointcloud, "--filter-pointcloud"),
        (args.use_density_filter, "--use-density-filter"),
        (args.use_clustering_filter, "--use-clustering-filter"),
        (args.use_pointcloud_classifier, "--use-pointcloud-classifier"),
    ):
        if enabled:
            cmd.append(flag)
    # The classifier threshold is only sent together with a classifier path.
    if args.pointcloud_classifier:
        cmd += [
            "--pointcloud-classifier", args.pointcloud_classifier,
            "--pointcloud-classifier-threshold", str(args.pointcloud_classifier_threshold),
        ]
    if args.use_flatness_filter:
        cmd += ["--use-flatness-filter", "--flatness-threshold", str(args.flatness_threshold)]

    if getattr(args, "fish_video_weight_overlay", False):
        wck = Path(args.weight_checkpoint).expanduser().resolve()
        cmd += [
            "--run-weight-estimation",
            "--weight-estimator-checkpoint", str(wck),
            "--weight-overlay-video",
            "--minute-interval-sec", str(getattr(args, "minute_interval_sec", 60.0)),
        ]

    # Flush so this line precedes the child's output even when stdout is a pipe/file
    # (block-buffered), since the child writes to the same stdout.
    print(f"Invoking fish_video_weight_evaluation.py:\n {shlex.join(cmd)}", flush=True)
    proc = subprocess.run(cmd, cwd=str(REPO_ROOT))
    if proc.returncode != 0:
        raise RuntimeError(f"fish_video_weight_evaluation.py exited with code {proc.returncode}")
|
|
|
|
|
|
def _run_test_dgcnn_weight_estimator_subprocess(
|
|
*,
|
|
cloud_dir: Optional[Path] = None,
|
|
ply_list_file: Optional[Path] = None,
|
|
checkpoint: Path,
|
|
device: torch.device,
|
|
out_dir: Path,
|
|
num_points: int,
|
|
xyz_scale: float,
|
|
top_k: int,
|
|
top_by_length: bool,
|
|
length_switch_mm: float,
|
|
topk_length: Optional[int],
|
|
remove_outliers: bool,
|
|
outlier_method: str,
|
|
labels_json: Optional[str],
|
|
) -> Tuple[Path, Dict[str, Any]]:
|
|
if (cloud_dir is None) == (ply_list_file is None):
|
|
raise ValueError("Exactly one of cloud_dir or ply_list_file must be set")
|
|
|
|
script = _dgcnn_test_script()
|
|
if not script.is_file():
|
|
raise FileNotFoundError(f"Missing: {script}")
|
|
|
|
json_base = out_dir / "dgcnn_test_output.json"
|
|
labels_path = labels_json if labels_json else str(out_dir / "__no_labels__.json")
|
|
|
|
cmd: List[str] = [
|
|
sys.executable,
|
|
str(script),
|
|
"--checkpoint",
|
|
str(checkpoint.expanduser().resolve()),
|
|
"--device",
|
|
_torch_device_to_test_arg(device),
|
|
"--num-points",
|
|
str(num_points),
|
|
"--xyz-scale",
|
|
str(xyz_scale),
|
|
"--top-k",
|
|
str(top_k),
|
|
"--output-json",
|
|
str(json_base),
|
|
"--labels-json",
|
|
labels_path,
|
|
]
|
|
if ply_list_file is not None:
|
|
cmd.extend(["--ply-list-file", str(ply_list_file.resolve())])
|
|
else:
|
|
cmd.extend(["--ply-folder", str(cloud_dir.resolve())])
|
|
if top_by_length:
|
|
cmd.append("--top-by-length")
|
|
else:
|
|
cmd.append("--no-top-by-length")
|
|
cmd.extend(["--length-switch-mm", str(length_switch_mm)])
|
|
if topk_length is not None:
|
|
cmd.extend(["--topk-length", str(topk_length)])
|
|
if remove_outliers:
|
|
cmd.append("--remove-outliers")
|
|
cmd.extend(["--outlier-method", outlier_method])
|
|
|
|
print(f"Invoking test_dgcnn_weight_estimator.py:\n {' '.join(cmd)}")
|
|
proc = subprocess.run(cmd, cwd=str(REPO_ROOT))
|
|
if proc.returncode != 0:
|
|
raise RuntimeError(f"test_dgcnn_weight_estimator.py exited with code {proc.returncode}")
|
|
|
|
suffix = f"_top{top_k}_by_length" if top_by_length else f"_top{top_k}"
|
|
written = json_base.parent / f"{json_base.stem}{suffix}{json_base.suffix}"
|
|
if not written.is_file():
|
|
raise FileNotFoundError(f"Expected DGCNN output JSON at {written}")
|
|
|
|
data = json.loads(written.read_text(encoding="utf-8"))
|
|
return written, data
|
|
|
|
|
|
def _dgcnn_result_fields(dgcnn_data: Dict[str, Any]) -> Dict[str, Any]:
    """Project a DGCNN result JSON onto the keys shared by the per-SVO and combined reports.

    ``dgcnn_data`` is the parsed output of test_dgcnn or fish_video's
    ``weight_estimation_results.json``. The average weights are reported as ``None`` when
    its ``summary`` marks the run as skipped.
    """
    summary = dgcnn_data.get("summary") or {}
    skipped = bool(summary.get("skipped"))
    return {
        "skipped_weight": skipped,
        "skip_reason": summary.get("skip_reason"),
        "num_clouds_used": summary.get("num_files_used_for_avg"),
        "num_ply_predicted": summary.get("num_files_predicted"),
        "avg_predicted_weight_kg": None if skipped else summary.get("avg_predicted_weight_kg"),
        "avg_predicted_weight_g": None if skipped else summary.get("avg_predicted_weight_g"),
        "dgcnn_meta": dgcnn_data.get("meta"),
        "dgcnn_summary": summary,
        # Same object as ``dgcnn_summary``, kept under a second name for existing consumers.
        "weight_summary": summary,
        "per_cloud": dgcnn_data.get("per_file") or [],
        "comparison": dgcnn_data.get("comparison"),
    }


def _merge_weight_prediction_json(
    *,
    svo_path: Path,
    svo_name: str,
    out_dir: Path,
    cloud_dir: Path,
    dgcnn_json_path: Path,
    dgcnn_data: Dict[str, Any],
) -> Dict[str, Any]:
    """Build the per-SVO ``weight_prediction.json`` payload.

    The payload holds the SVO/output locations, the number of PLYs currently in
    ``cloud_dir``, and the shared DGCNN result fields.
    """
    return {
        "svo": str(svo_path),
        "svo_name": svo_name,
        "output_dir": str(out_dir),
        "cloud_dir": str(cloud_dir),
        "num_ply_on_disk": len(_collect_existing_clouds(cloud_dir)),
        "dgcnn_output_json": str(dgcnn_json_path),
        **_dgcnn_result_fields(dgcnn_data),
    }


def _merge_weight_prediction_json_combined(
    *,
    svo_paths: List[Path],
    output_base: Path,
    ply_list_path: Path,
    dgcnn_json_path: Path,
    dgcnn_data: Dict[str, Any],
) -> Dict[str, Any]:
    """Build the combined (multi-SVO) ``weight_prediction.json`` payload.

    Each SVO's cloud folder is ``<output_base>/<stem>/cloud``. ``num_ply_on_disk`` sums the
    PLYs across all of those folders.
    """
    cloud_dirs = [output_base / p.stem / "cloud" for p in svo_paths]
    return {
        "combined": True,
        "svo_paths": [str(p) for p in svo_paths],
        "svo_names": [p.stem for p in svo_paths],
        "output_dir": str(output_base),
        "cloud_dirs": [str(d) for d in cloud_dirs],
        "ply_list_file": str(ply_list_path),
        "num_ply_on_disk": sum(len(_collect_existing_clouds(d)) for d in cloud_dirs),
        "dgcnn_output_json": str(dgcnn_json_path),
        **_dgcnn_result_fields(dgcnn_data),
    }
|
|
|
|
|
|
def _sanitize_for_json(obj: Any) -> Any:
|
|
if isinstance(obj, dict):
|
|
return {k: _sanitize_for_json(v) for k, v in obj.items()}
|
|
if isinstance(obj, list):
|
|
return [_sanitize_for_json(v) for v in obj]
|
|
if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)):
|
|
return None
|
|
return obj
|
|
|
|
|
|
def _load_fish_video_weight_results(path: Path) -> Optional[Dict[str, Any]]:
    """Load fish_video's ``weight_estimation_results.json`` if it looks like a usable DGCNN result.

    Returns ``None`` so the caller falls back to test_dgcnn when:
    - the file cannot be read or parsed (a warning is printed);
    - it is not a JSON object with a dict-like ``summary`` (a warning is printed);
    - it has neither a ``summary`` nor a truthy ``per_file``.
    """
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError / UnicodeDecodeError
        print(f"WARNING: Could not merge {path}, falling back to test_dgcnn subprocess: {e}")
        return None
    if not isinstance(data, dict) or not isinstance(data.get("summary") or {}, dict):
        print(f"WARNING: Could not merge {path}, falling back to test_dgcnn subprocess: unexpected JSON layout")
        return None
    if data.get("summary") is None and not data.get("per_file"):
        return None
    return data


def run_weight_prediction_for_svo(
    svo_path: Path,
    output_base: Path,
    weight_checkpoint: Path,
    weight_device: torch.device,
    num_points_model: int,
    weight_top_k: int,
    weight_top_by_length: bool,
    weight_length_switch_mm: float,
    weight_topk_length: Optional[int],
    weight_remove_outliers: bool,
    weight_outlier_method: str,
    weight_xyz_scale: float,
    weight_labels_json: Optional[str],
    force_dgcnn_subprocess: bool = False,
) -> Dict[str, Any]:
    """Predict fish weight for one SVO from its existing point clouds.

    Reads PLYs from ``<output_base>/<svo stem>/cloud``. Unless ``force_dgcnn_subprocess`` is
    set, it reuses fish_video's ``weight_estimation/weight_estimation_results.json`` when that
    file is present and usable; otherwise it runs test_dgcnn. The merged result is written to
    ``<output_base>/<svo stem>/weight_prediction.json`` and returned.

    Raises:
        FileNotFoundError: The SVO file does not exist (or a test_dgcnn artifact is missing).
        RuntimeError: There are no PLYs in the cloud folder, or test_dgcnn failed.
    """
    svo_path = svo_path.expanduser().resolve()
    if not svo_path.exists():
        raise FileNotFoundError(f"svo not found: {svo_path}")
    # Resolve like run_weight_prediction_combined_svos so recorded paths are absolute.
    output_base = output_base.expanduser().resolve()

    svo_name = svo_path.stem
    out_dir = output_base / svo_name
    cloud_dir = out_dir / "cloud"

    if not _collect_existing_clouds(cloud_dir):
        raise RuntimeError(
            f"No .ply files in {cloud_dir}. Run with --no-reuse-existing-clouds or run "
            f"fish_video_weight_evaluation.py first."
        )

    # NOTE(review): reused fish_video results were computed with fish_video's own top-K /
    # length settings (this script does not forward --weight-top-k etc. to it); pass
    # force_dgcnn_subprocess to apply this call's settings.
    fish_wj = out_dir / "weight_estimation" / "weight_estimation_results.json"
    reused = None
    if not force_dgcnn_subprocess and fish_wj.is_file():
        reused = _load_fish_video_weight_results(fish_wj)

    if reused is not None:
        print(f"Using existing DGCNN results from fish_video: {fish_wj}")
        dgcnn_path, dgcnn_data = fish_wj, reused
    else:
        dgcnn_path, dgcnn_data = _run_test_dgcnn_weight_estimator_subprocess(
            cloud_dir=cloud_dir,
            ply_list_file=None,
            checkpoint=weight_checkpoint,
            device=weight_device,
            out_dir=out_dir,
            num_points=num_points_model,
            xyz_scale=weight_xyz_scale,
            top_k=weight_top_k,
            top_by_length=weight_top_by_length,
            length_switch_mm=weight_length_switch_mm,
            topk_length=weight_topk_length,
            remove_outliers=weight_remove_outliers,
            outlier_method=weight_outlier_method,
            labels_json=weight_labels_json,
        )

    result = _merge_weight_prediction_json(
        svo_path=svo_path,
        svo_name=svo_name,
        out_dir=out_dir,
        cloud_dir=cloud_dir,
        dgcnn_json_path=dgcnn_path,
        dgcnn_data=dgcnn_data,
    )
    (out_dir / "weight_prediction.json").write_text(
        json.dumps(_sanitize_for_json(result), indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    return result
|
|
|
|
|
|
def run_weight_prediction_combined_svos(
    svo_paths: List[Path],
    output_base: Path,
    weight_checkpoint: Path,
    weight_device: torch.device,
    num_points_model: int,
    weight_top_k: int,
    weight_top_by_length: bool,
    weight_length_switch_mm: float,
    weight_topk_length: Optional[int],
    weight_remove_outliers: bool,
    weight_outlier_method: str,
    weight_xyz_scale: float,
    weight_labels_json: Optional[str],
) -> Dict[str, Any]:
    """One DGCNN run over all ``<output_base>/<svo_stem>/cloud/*.ply`` (top-K / by-length applies to the union).

    The sorted PLY paths are written to ``<output_base>/_combined_ply_list.txt`` and passed to
    test_dgcnn as a list file. The merged result is written to
    ``<output_base>/weight_prediction.json`` and returned. SVOs that have no PLYs are left out
    of the union, and a warning names them.

    Raises:
        RuntimeError: No SVO has any PLY, or test_dgcnn failed.
        FileNotFoundError: A test_dgcnn artifact is missing.
    """
    output_base = output_base.expanduser().resolve()
    svo_paths = [p.expanduser().resolve() for p in svo_paths]

    all_plys: List[Path] = []
    without_clouds: List[str] = []
    for svo in svo_paths:
        plys = _collect_existing_clouds(output_base / svo.stem / "cloud")
        if not plys:
            without_clouds.append(svo.stem)
        all_plys.extend(plys)
    if not all_plys:
        raise RuntimeError(
            f"No .ply in any cloud folder under {output_base} for SVOs: {[p.stem for p in svo_paths]}"
        )
    if without_clouds:
        # Otherwise these SVOs would silently not contribute to the combined average.
        print(
            f"WARNING: {len(without_clouds)} SVO(s) have no .ply clouds and are excluded "
            f"from the combined prediction: {without_clouds}"
        )

    out_dir = output_base
    list_path = out_dir / "_combined_ply_list.txt"
    list_path.write_text("\n".join(str(p.resolve()) for p in sorted(all_plys)) + "\n", encoding="utf-8")

    dgcnn_path, dgcnn_data = _run_test_dgcnn_weight_estimator_subprocess(
        cloud_dir=None,
        ply_list_file=list_path,
        checkpoint=weight_checkpoint,
        device=weight_device,
        out_dir=out_dir,
        num_points=num_points_model,
        xyz_scale=weight_xyz_scale,
        top_k=weight_top_k,
        top_by_length=weight_top_by_length,
        length_switch_mm=weight_length_switch_mm,
        topk_length=weight_topk_length,
        remove_outliers=weight_remove_outliers,
        outlier_method=weight_outlier_method,
        labels_json=weight_labels_json,
    )
    result = _merge_weight_prediction_json_combined(
        svo_paths=svo_paths,
        output_base=output_base,
        ply_list_path=list_path,
        dgcnn_json_path=dgcnn_path,
        dgcnn_data=dgcnn_data,
    )
    (out_dir / "weight_prediction.json").write_text(
        json.dumps(_sanitize_for_json(result), indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    return result
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: input source, fish_video options, test_dgcnn options, pipeline switches."""
    parser = argparse.ArgumentParser(
        description="SVO2 → fish_video_weight_evaluation.py → test_dgcnn_weight_estimator.py"
    )
    src = parser.add_mutually_exclusive_group(required=True)
    src.add_argument("--svo", type=str, default="", help="Path to a single .svo2 file")
    src.add_argument("--batch-svo-folder", type=str, default="", help="Folder containing *.svo2 files")

    parser.add_argument(
        "--save-output",
        type=str,
        default="output_weight_estimator",
        help="Base output directory (same as fish_video --save-output parent)",
    )

    # fish_video_weight_evaluation (aligned with run_fish_evaluation_simple.sh)
    parser.add_argument(
        "--yolo-model",
        type=str,
        default="/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt",
    )
    parser.add_argument("--conf", type=float, default=0.5, help="YOLO confidence (simple.sh uses 0.5)")
    parser.add_argument("--imgsz", type=int, default=640)
    parser.add_argument(
        "--sam-device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="SAM device for fish_video; also used as --device for test_dgcnn (cuda|cpu)",
    )
    parser.add_argument("--max-frames", type=int, default=0, help="0 = all frames")
    parser.add_argument(
        "--frame-stride",
        type=int,
        default=1,
        metavar="N",
        help="Passed to fish_video: only YOLO/SAM/PLY every N-th frame (default: 1)",
    )
    parser.add_argument("--filter-pointcloud", action="store_true")
    parser.add_argument("--use-density-filter", action="store_true")
    parser.add_argument("--use-clustering-filter", action="store_true")
    parser.add_argument("--use-pointcloud-classifier", action="store_true")
    parser.add_argument("--pointcloud-classifier", type=str, default=None)
    parser.add_argument("--pointcloud-classifier-threshold", type=float, default=0.7)
    parser.add_argument("--use-flatness-filter", action="store_true")
    parser.add_argument("--flatness-threshold", type=float, default=70.0)

    # test_dgcnn_weight_estimator
    parser.add_argument(
        "--weight-checkpoint",
        type=str,
        default=str(REPO_ROOT / "weight_estimator/runs/dgcnn_20260312_171043/best.pt"),
    )
    parser.add_argument("--num-points-model", type=int, default=768)
    parser.add_argument("--weight-top-k", type=int, default=5)
    parser.add_argument(
        "--weight-top-by-length",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Length-first top-K (default: on). Use --no-weight-top-by-length for weight-only order.",
    )
    parser.add_argument(
        "--weight-length-switch-mm",
        type=float,
        default=319.0,
        # Use %(default)s so the help can no longer disagree with the real default.
        help="If mean top-K length (mm) exceeds this, use top-K by predicted weight (default: %(default)s).",
    )
    parser.add_argument("--weight-topk-length", type=int, default=None)
    parser.add_argument("--weight-remove-outliers", action="store_true")
    parser.add_argument(
        "--weight-outlier-method",
        type=str,
        default="iqr",
        choices=["iqr", "zscore"],
    )
    parser.add_argument("--weight-xyz-scale", type=float, default=0.001)
    parser.add_argument("--weight-labels-json", type=str, default=None)

    # The positive flag is a no-op (default already True); the --no- form disables reuse.
    parser.add_argument(
        "--reuse-existing-clouds",
        action="store_true",
        default=True,
        help="Skip fish_video if <save-output>/<svo_stem>/cloud/*.ply already exists",
    )
    parser.add_argument("--no-reuse-existing-clouds", action="store_false", dest="reuse_existing_clouds")

    parser.add_argument(
        "--fish-video-weight-overlay",
        action="store_true",
        help="Run fish_video with DGCNN + preview video overlay (fish weight g, top-5, per-window avg). "
        "Avoids a duplicate test_dgcnn pass when weight_estimation_results.json is present.",
    )
    parser.add_argument(
        "--minute-interval-sec",
        type=float,
        default=60.0,
        help="Passed to fish_video --minute-interval-sec for on-video bucket stats (default: 60).",
    )
    parser.add_argument(
        "--force-dgcnn-subprocess",
        action="store_true",
        help="Always run test_dgcnn_weight_estimator.py subprocess even if fish_video left weight_estimation_results.json.",
    )
    return parser


def main() -> None:
    """CLI entry point: SVO2 → point clouds (fish_video) → weight prediction (DGCNN).

    The fish step is skipped when clouds are reused and every target already has PLYs.
    Multiple SVOs get one combined prediction; otherwise each SVO is predicted on its own.
    Per-target failures are recorded in ``<save-output>/weight_predictions_summary.json``
    and reported on stderr; the process still exits normally in that case.
    """
    args = _build_arg_parser().parse_args()
    if args.frame_stride < 1:
        raise SystemExit("--frame-stride must be >= 1")

    output_base = Path(args.save_output).expanduser().resolve()
    output_base.mkdir(parents=True, exist_ok=True)

    weight_ckpt = Path(args.weight_checkpoint).expanduser().resolve()
    if not weight_ckpt.is_file():
        raise SystemExit(f"weight checkpoint not found: {weight_ckpt}")

    batch_folder: Optional[Path]
    if args.svo:
        svo_paths = [Path(args.svo).expanduser().resolve()]
        batch_folder = None
    else:
        batch_folder = Path(args.batch_svo_folder).expanduser().resolve()
        svo_paths = sorted(batch_folder.glob("*.svo2"))

    need_fish = True
    if svo_paths and args.reuse_existing_clouds:
        need_fish = not all(has_cached_clouds(output_base, p) for p in svo_paths)

    if need_fish:
        if args.svo:
            _run_fish_video_evaluation_subprocess(args, batch_folder=None)
        else:
            if not batch_folder or not batch_folder.is_dir():
                raise SystemExit(f"batch folder not found: {args.batch_svo_folder}")
            if not svo_paths:
                raise SystemExit(f"No .svo2 files in: {batch_folder}")
            _run_fish_video_evaluation_subprocess(args, batch_folder=batch_folder)
    else:
        print("All targets already have cloud/*.ply — skipping fish_video_weight_evaluation.py.")

    device = torch.device(args.sam_device)

    # Options shared by the per-SVO and combined weight-prediction entry points.
    weight_kwargs: Dict[str, Any] = dict(
        output_base=output_base,
        weight_checkpoint=weight_ckpt,
        weight_device=device,
        num_points_model=args.num_points_model,
        weight_top_k=args.weight_top_k,
        weight_top_by_length=args.weight_top_by_length,
        weight_length_switch_mm=args.weight_length_switch_mm,
        weight_topk_length=args.weight_topk_length,
        weight_remove_outliers=args.weight_remove_outliers,
        weight_outlier_method=args.weight_outlier_method,
        weight_xyz_scale=args.weight_xyz_scale,
        weight_labels_json=args.weight_labels_json,
    )

    results: List[Dict[str, Any]] = []
    if len(svo_paths) > 1:
        names = ", ".join(p.name for p in svo_paths)
        print(f"\n=== Weight prediction: combined ({len(svo_paths)} SVOs) → top-{args.weight_top_k} over all PLYs ===")
        print(f" {names}")
        try:
            results.append(run_weight_prediction_combined_svos(svo_paths=svo_paths, **weight_kwargs))
        except Exception as e:  # top-level boundary: record in the summary and report
            print(f"ERROR: combined weight prediction failed: {e}", file=sys.stderr)
            results.append(
                {"combined": True, "svo_paths": [str(p) for p in svo_paths], "error": str(e)}
            )
    else:
        for svo in svo_paths:
            print(f"\n=== Weight prediction: {svo.name} ===")
            try:
                results.append(
                    run_weight_prediction_for_svo(
                        svo_path=svo,
                        force_dgcnn_subprocess=args.force_dgcnn_subprocess,
                        **weight_kwargs,
                    )
                )
            except Exception as e:  # top-level boundary: record in the summary and report
                print(f"ERROR: weight prediction failed for {svo}: {e}", file=sys.stderr)
                results.append({"svo": str(svo), "error": str(e)})

    summary_path = output_base / "weight_predictions_summary.json"
    summary_path.write_text(
        json.dumps(_sanitize_for_json(results), indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    print(f"\nSaved summary: {summary_path}")

    if len(results) == 1 and "error" not in results[0]:
        avg_g = results[0].get("avg_predicted_weight_g")
        if avg_g is not None:
            print(f"Final predicted weight (test_dgcnn): {avg_g:.2f} g")
        else:
            print("Final predicted weight: N/A (skipped or no valid point clouds)")
|
|
|
|
if __name__ == "__main__":
    # Executed as a script (not imported): run the pipeline. main() returns None, so the exit status is 0.
    sys.exit(main())
|