Files
FishServer/FishMeasure/predict_weigth_from_svo2.py
2026-04-10 10:30:01 +08:00

639 lines
24 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
SVO2 → point clouds via **fish_video_weight_evaluation.py** (subprocess), then weight via
**weight_estimator/test_dgcnn_weight_estimator.py** (subprocess).
Defaults for the fish step align with ``run_fish_evaluation_simple.sh`` (conf 0.5, filters,
classifier, flatness). Override with the same flags this script exposes.
Outputs:
``<save-output>/<svo_stem>/cloud/*.ply`` — from fish_video_weight_evaluation
``<save-output>/<svo_stem>/weight_prediction.json`` — merged SVO info + DGCNN JSON (single-SVO mode)
``<save-output>/weight_prediction.json`` — when ``--batch-svo-folder`` has **multiple** ``*.svo2``:
one DGCNN run over **all** PLYs from every ``<svo_stem>/cloud/`` (top-K applies to the union)
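``<save-output>/<svo_stem>/dgcnn_test_output_top<K>[_by_length].json`` — raw test_dgcnn output
(per-SVO; in combined mode it is written under ``<save-output>/``). The ``_by_length`` suffix is
present when length-first top-K is enabled, which is the default.
With ``--reuse-existing-clouds`` (default), the fish subprocess is skipped when every target SVO
already has ``cloud/*.ply``. In per-SVO mode, a ``weight_estimation/weight_estimation_results.json``
left by fish_video is merged directly unless ``--force-dgcnn-subprocess`` is given; otherwise
test_dgcnn is run.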
"""
from __future__ import annotations
import argparse
import json
import subprocess
import sys
import math
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import torch
# Directory holding this script; the helper scripts invoked below live beneath it.
REPO_ROOT = Path(__file__).resolve().parents[0]
# Prepend it to the import path (once) so sibling modules import regardless of cwd.
if not any(entry == str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))
def _fish_video_script() -> Path:
    """Path of ``fish_video_weight_evaluation.py``, expected directly in ``REPO_ROOT``."""
    return REPO_ROOT.joinpath("fish_video_weight_evaluation.py")
def _dgcnn_test_script() -> Path:
    """Path of ``weight_estimator/test_dgcnn_weight_estimator.py`` under ``REPO_ROOT``."""
    return REPO_ROOT.joinpath("weight_estimator", "test_dgcnn_weight_estimator.py")
def _torch_device_to_test_arg(device: torch.device) -> str:
return "cpu" if device.type == "cpu" else "cuda"
def _collect_existing_clouds(cloud_dir: Path) -> List[Path]:
return sorted(cloud_dir.glob("*.ply"))
def has_cached_clouds(output_base: Path, svo_path: Path) -> bool:
    """Tell whether ``<output_base>/<svo stem>/cloud`` already holds at least one ``*.ply``.

    The stem is taken from the SVO path after ``~`` expansion and ``resolve()``
    (so a symlinked SVO contributes its target's stem).
    """
    stem = svo_path.expanduser().resolve().stem
    clouds = _collect_existing_clouds(output_base / stem / "cloud")
    return bool(clouds)
def _run_fish_video_evaluation_subprocess(args: argparse.Namespace, *, batch_folder: Optional[Path]) -> None:
    """Generate point clouds by running fish_video_weight_evaluation.py in a child process.

    The child is started with the current interpreter and ``cwd=REPO_ROOT``. It
    receives ``--save-output`` as an absolute path, plus the YOLO/SAM, frame,
    filter, classifier and flatness options taken from *args*.

    Args:
        args: Parsed CLI namespace of this script.
        batch_folder: When not ``None``, forwarded as ``--batch-svo-folder``;
            otherwise ``args.svo`` is forwarded as ``--svo``.

    Raises:
        FileNotFoundError: fish_video_weight_evaluation.py does not exist.
        RuntimeError: the child process exits with a non-zero status.
    """
    import shlex

    def _path_for_child(path_str: str) -> str:
        # The child runs with cwd=REPO_ROOT, so a path the user wrote relative to
        # *their* cwd would be looked up in the wrong place. Rewrite it to an
        # absolute path only when it exists from here; otherwise pass it through
        # untouched (it may be relative to REPO_ROOT, or a name the child resolves).
        candidate = Path(path_str).expanduser()
        return str(candidate.resolve()) if candidate.exists() else path_str

    fvs = _fish_video_script()
    if not fvs.is_file():
        raise FileNotFoundError(f"Missing: {fvs}")
    out_parent = Path(args.save_output).expanduser().resolve()
    cmd: List[str] = [
        sys.executable,
        str(fvs),
        "--save-output",
        str(out_parent),
        "--yolo-model",
        _path_for_child(args.yolo_model),
        "--conf",
        str(args.conf),
        "--imgsz",
        str(args.imgsz),
        "--sam-device",
        args.sam_device,
        "--max-frames",
        str(args.max_frames),
        "--frame-stride",
        str(args.frame_stride),
    ]
    if batch_folder is not None:
        cmd.extend(["--batch-svo-folder", str(batch_folder.resolve())])
    else:
        cmd.extend(["--svo", str(Path(args.svo).expanduser().resolve())])
    if args.filter_pointcloud:
        cmd.append("--filter-pointcloud")
    if args.use_density_filter:
        cmd.append("--use-density-filter")
    if args.use_clustering_filter:
        cmd.append("--use-clustering-filter")
    if args.use_pointcloud_classifier:
        cmd.append("--use-pointcloud-classifier")
    if args.pointcloud_classifier:
        cmd.extend(["--pointcloud-classifier", _path_for_child(args.pointcloud_classifier)])
    cmd.extend(["--pointcloud-classifier-threshold", str(args.pointcloud_classifier_threshold)])
    if args.use_flatness_filter:
        cmd.append("--use-flatness-filter")
    cmd.extend(["--flatness-threshold", str(args.flatness_threshold)])
    # When the weight checkpoint exists, also run DGCNN inside fish_video. It writes
    # weight_estimation_results.json, and the preview video can then overlay
    # weight/length. run_weight_prediction_for_svo later merges that JSON instead of
    # re-running test_dgcnn.
    wck = Path(args.weight_checkpoint).expanduser().resolve()
    if wck.is_file():
        cmd.extend(["--run-weight-estimation", "--weight-estimator-checkpoint", str(wck)])
    if getattr(args, "fish_video_weight_overlay", False):
        cmd.extend(
            [
                "--weight-overlay-video",
                "--minute-interval-sec",
                str(getattr(args, "minute_interval_sec", 60.0)),
            ]
        )
    # shlex.join quotes arguments so the echoed command can be pasted into a shell.
    print(f"Invoking fish_video_weight_evaluation.py:\n {shlex.join(cmd)}")
    proc = subprocess.run(cmd, cwd=str(REPO_ROOT))
    if proc.returncode != 0:
        raise RuntimeError(f"fish_video_weight_evaluation.py exited with code {proc.returncode}")
def _run_test_dgcnn_weight_estimator_subprocess(
    *,
    cloud_dir: Optional[Path] = None,
    ply_list_file: Optional[Path] = None,
    checkpoint: Path,
    device: torch.device,
    out_dir: Path,
    num_points: int,
    xyz_scale: float,
    top_k: int,
    top_by_length: bool,
    length_switch_mm: float,
    topk_length: Optional[int],
    remove_outliers: bool,
    outlier_method: str,
    labels_json: Optional[str],
) -> Tuple[Path, Dict[str, Any]]:
    """Run test_dgcnn_weight_estimator.py on a PLY folder or PLY list and load its JSON.

    Exactly one of *cloud_dir* (forwarded as ``--ply-folder``) or *ply_list_file*
    (``--ply-list-file``) must be given. The child runs with ``cwd=REPO_ROOT`` and
    is told to write ``<out_dir>/dgcnn_test_output.json``. The file read back is
    ``dgcnn_test_output_top<K>[_by_length].json`` in the same directory, which is
    the naming this wrapper expects test_dgcnn to produce.

    Returns:
        ``(path_of_json_read, parsed_json)``.

    Raises:
        ValueError: both or neither of *cloud_dir* / *ply_list_file* were given.
        FileNotFoundError: the script is missing, or it did not produce the
            expected JSON.
        RuntimeError: the child exits with a non-zero status.
    """
    import shlex

    if (cloud_dir is None) == (ply_list_file is None):
        raise ValueError("Exactly one of cloud_dir or ply_list_file must be set")
    script = _dgcnn_test_script()
    if not script.is_file():
        raise FileNotFoundError(f"Missing: {script}")
    # The child runs with cwd=REPO_ROOT. Hand it absolute paths so it writes and
    # reads where this process expects, whatever our own cwd is.
    out_dir = out_dir.expanduser().resolve()
    json_base = out_dir / "dgcnn_test_output.json"
    if labels_json:
        labels_candidate = Path(labels_json).expanduser()
        labels_path = str(labels_candidate.resolve()) if labels_candidate.exists() else labels_json
    else:
        # NOTE(review): presumably a non-existent file so test_dgcnn runs without
        # ground-truth labels — confirm against test_dgcnn_weight_estimator.py.
        labels_path = str(out_dir / "__no_labels__.json")
    suffix = f"_top{top_k}_by_length" if top_by_length else f"_top{top_k}"
    written = json_base.parent / f"{json_base.stem}{suffix}{json_base.suffix}"
    # Drop a leftover from an earlier run. Otherwise a child that writes nothing
    # (or writes under another name) would go unnoticed and stale results would be
    # returned.
    written.unlink(missing_ok=True)
    cmd: List[str] = [
        sys.executable,
        str(script),
        "--checkpoint",
        str(checkpoint.expanduser().resolve()),
        "--device",
        _torch_device_to_test_arg(device),
        "--num-points",
        str(num_points),
        "--xyz-scale",
        str(xyz_scale),
        "--top-k",
        str(top_k),
        "--output-json",
        str(json_base),
        "--labels-json",
        labels_path,
    ]
    if ply_list_file is not None:
        cmd.extend(["--ply-list-file", str(ply_list_file.resolve())])
    else:
        cmd.extend(["--ply-folder", str(cloud_dir.resolve())])
    cmd.append("--top-by-length" if top_by_length else "--no-top-by-length")
    cmd.extend(["--length-switch-mm", str(length_switch_mm)])
    if topk_length is not None:
        cmd.extend(["--topk-length", str(topk_length)])
    if remove_outliers:
        cmd.append("--remove-outliers")
        cmd.extend(["--outlier-method", outlier_method])
    # shlex.join quotes arguments so the echoed command can be pasted into a shell.
    print(f"Invoking test_dgcnn_weight_estimator.py:\n {shlex.join(cmd)}")
    proc = subprocess.run(cmd, cwd=str(REPO_ROOT))
    if proc.returncode != 0:
        raise RuntimeError(f"test_dgcnn_weight_estimator.py exited with code {proc.returncode}")
    if not written.is_file():
        raise FileNotFoundError(f"Expected DGCNN output JSON at {written}")
    data = json.loads(written.read_text(encoding="utf-8"))
    return written, data
def _merge_weight_prediction_json(
    *,
    svo_path: Path,
    svo_name: str,
    out_dir: Path,
    cloud_dir: Path,
    dgcnn_json_path: Path,
    dgcnn_data: Dict[str, Any],
) -> Dict[str, Any]:
    """Combine per-SVO context with a DGCNN result JSON into one flat record.

    Fields are lifted from ``dgcnn_data["summary"]`` (treated as ``{}`` when
    missing or falsy). Both averaged weights are forced to ``None`` when the
    summary reports ``skipped``. The whole summary is kept under both
    ``dgcnn_summary`` and ``weight_summary``.
    """
    summary = dgcnn_data.get("summary") or {}
    skipped = bool(summary.get("skipped"))

    def _unless_skipped(key: str) -> Any:
        # Averages are reported as null when the prediction was skipped.
        return None if skipped else summary.get(key)

    return dict(
        svo=str(svo_path),
        svo_name=svo_name,
        output_dir=str(out_dir),
        cloud_dir=str(cloud_dir),
        num_ply_on_disk=len(_collect_existing_clouds(cloud_dir)),
        dgcnn_output_json=str(dgcnn_json_path),
        skipped_weight=skipped,
        skip_reason=summary.get("skip_reason"),
        num_clouds_used=summary.get("num_files_used_for_avg"),
        num_ply_predicted=summary.get("num_files_predicted"),
        avg_predicted_weight_kg=_unless_skipped("avg_predicted_weight_kg"),
        avg_predicted_weight_g=_unless_skipped("avg_predicted_weight_g"),
        dgcnn_meta=dgcnn_data.get("meta"),
        dgcnn_summary=summary,
        weight_summary=summary,
        per_cloud=dgcnn_data.get("per_file") or [],
        comparison=dgcnn_data.get("comparison"),
    )
def _merge_weight_prediction_json_combined(
    *,
    svo_paths: List[Path],
    output_base: Path,
    ply_list_path: Path,
    dgcnn_json_path: Path,
    dgcnn_data: Dict[str, Any],
) -> Dict[str, Any]:
    """Build the merged record for one DGCNN run over the clouds of several SVOs.

    Each SVO's cloud folder is taken to be ``<output_base>/<stem>/cloud``;
    ``num_ply_on_disk`` is the total ``*.ply`` count across those folders.
    Summary-derived fields follow the single-SVO record, and both averages
    become ``None`` when the summary reports ``skipped``.
    """
    summary = dgcnn_data.get("summary") or {}
    skipped = bool(summary.get("skipped"))
    cloud_dirs = [output_base / svo.stem / "cloud" for svo in svo_paths]
    total_ply = sum(len(_collect_existing_clouds(folder)) for folder in cloud_dirs)
    averages = {
        unit_key: (None if skipped else summary.get(unit_key))
        for unit_key in ("avg_predicted_weight_kg", "avg_predicted_weight_g")
    }
    return dict(
        combined=True,
        svo_paths=[str(svo) for svo in svo_paths],
        svo_names=[svo.stem for svo in svo_paths],
        output_dir=str(output_base),
        cloud_dirs=[str(folder) for folder in cloud_dirs],
        ply_list_file=str(ply_list_path),
        num_ply_on_disk=total_ply,
        dgcnn_output_json=str(dgcnn_json_path),
        skipped_weight=skipped,
        skip_reason=summary.get("skip_reason"),
        num_clouds_used=summary.get("num_files_used_for_avg"),
        num_ply_predicted=summary.get("num_files_predicted"),
        avg_predicted_weight_kg=averages["avg_predicted_weight_kg"],
        avg_predicted_weight_g=averages["avg_predicted_weight_g"],
        dgcnn_meta=dgcnn_data.get("meta"),
        dgcnn_summary=summary,
        weight_summary=summary,
        per_cloud=dgcnn_data.get("per_file") or [],
        comparison=dgcnn_data.get("comparison"),
    )
def _sanitize_for_json(obj: Any) -> Any:
if isinstance(obj, dict):
return {k: _sanitize_for_json(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_sanitize_for_json(v) for v in obj]
if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)):
return None
return obj
def _load_fish_weight_results(fish_wj: Path) -> Optional[Dict[str, Any]]:
    """Load the DGCNN results JSON written by fish_video, or return ``None`` if unusable.

    The JSON is unusable when the file is missing or cannot be read or parsed.
    It is also unusable when it is not a JSON object, when its ``summary`` is
    present but not an object, or when it has neither a ``summary`` nor a
    non-empty ``per_file``. Read, parse and layout problems print a warning; a
    missing file or an empty result is silent.
    """
    if not fish_wj.is_file():
        return None
    try:
        data = json.loads(fish_wj.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError and UnicodeDecodeError
        print(f"WARNING: Could not merge {fish_wj}, falling back to test_dgcnn subprocess: {e}")
        return None
    if not isinstance(data, dict) or not isinstance(data.get("summary") or {}, dict):
        print(
            f"WARNING: Could not merge {fish_wj}, falling back to test_dgcnn subprocess: "
            "unexpected JSON layout"
        )
        return None
    if data.get("summary") is None and not data.get("per_file"):
        return None
    return data


def run_weight_prediction_for_svo(
    svo_path: Path,
    output_base: Path,
    weight_checkpoint: Path,
    weight_device: torch.device,
    num_points_model: int,
    weight_top_k: int,
    weight_top_by_length: bool,
    weight_length_switch_mm: float,
    weight_topk_length: Optional[int],
    weight_remove_outliers: bool,
    weight_outlier_method: str,
    weight_xyz_scale: float,
    weight_labels_json: Optional[str],
    force_dgcnn_subprocess: bool = False,
) -> Dict[str, Any]:
    """Predict fish weight for one SVO from its cached point clouds and save the result.

    Uses ``<output_base>/<svo_stem>/weight_estimation/weight_estimation_results.json``
    (written by fish_video) when it is present and usable, unless
    *force_dgcnn_subprocess* is set. Otherwise it runs test_dgcnn over
    ``<output_base>/<svo_stem>/cloud``. The merged record is written to
    ``<output_base>/<svo_stem>/weight_prediction.json`` and returned.

    NOTE(review): the fish_video command in this file does not forward the
    ``weight_top_k`` / by-length settings, so reused fish results may not
    reflect them. Pass ``force_dgcnn_subprocess=True`` to apply them.

    Raises:
        FileNotFoundError: the SVO file does not exist.
        RuntimeError: no ``.ply`` exists in the SVO's cloud folder (or test_dgcnn failed).
    """
    # Resolve up front, like the combined variant: the test_dgcnn child runs with
    # cwd=REPO_ROOT, so a relative output_base would point elsewhere for it.
    output_base = Path(output_base).expanduser().resolve()
    svo_path = svo_path.expanduser().resolve()
    if not svo_path.exists():
        raise FileNotFoundError(f"svo not found: {svo_path}")
    svo_name = svo_path.stem
    out_dir = output_base / svo_name
    cloud_dir = out_dir / "cloud"
    if not _collect_existing_clouds(cloud_dir):
        raise RuntimeError(
            f"No .ply files in {cloud_dir}. Run without --reuse-existing-clouds or run "
            f"fish_video_weight_evaluation.py first."
        )
    fish_wj = out_dir / "weight_estimation" / "weight_estimation_results.json"
    dgcnn_data = None if force_dgcnn_subprocess else _load_fish_weight_results(fish_wj)
    if dgcnn_data is not None:
        print(f"Using existing DGCNN results from fish_video: {fish_wj}")
        dgcnn_path = fish_wj
    else:
        dgcnn_path, dgcnn_data = _run_test_dgcnn_weight_estimator_subprocess(
            cloud_dir=cloud_dir,
            ply_list_file=None,
            checkpoint=weight_checkpoint,
            device=weight_device,
            out_dir=out_dir,
            num_points=num_points_model,
            xyz_scale=weight_xyz_scale,
            top_k=weight_top_k,
            top_by_length=weight_top_by_length,
            length_switch_mm=weight_length_switch_mm,
            topk_length=weight_topk_length,
            remove_outliers=weight_remove_outliers,
            outlier_method=weight_outlier_method,
            labels_json=weight_labels_json,
        )
    result = _merge_weight_prediction_json(
        svo_path=svo_path,
        svo_name=svo_name,
        out_dir=out_dir,
        cloud_dir=cloud_dir,
        dgcnn_json_path=dgcnn_path,
        dgcnn_data=dgcnn_data,
    )
    (out_dir / "weight_prediction.json").write_text(
        json.dumps(_sanitize_for_json(result), indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    return result
def run_weight_prediction_combined_svos(
    svo_paths: List[Path],
    output_base: Path,
    weight_checkpoint: Path,
    weight_device: torch.device,
    num_points_model: int,
    weight_top_k: int,
    weight_top_by_length: bool,
    weight_length_switch_mm: float,
    weight_topk_length: Optional[int],
    weight_remove_outliers: bool,
    weight_outlier_method: str,
    weight_xyz_scale: float,
    weight_labels_json: Optional[str],
) -> Dict[str, Any]:
    """Run one DGCNN pass over the clouds of several SVOs and save the merged result.

    Every ``<output_base>/<svo_stem>/cloud/*.ply`` is listed, sorted, in
    ``<output_base>/_combined_ply_list.txt`` and handed to test_dgcnn via
    ``--ply-list-file``. Top-K / by-length selection therefore operates on the
    union. The merged record is written to ``<output_base>/weight_prediction.json``
    and returned.

    Raises:
        RuntimeError: none of the SVOs has any ``.ply`` on disk (or test_dgcnn failed).
    """
    output_base = output_base.expanduser().resolve()
    svo_paths = [p.expanduser().resolve() for p in svo_paths]
    all_plys: List[Path] = [
        ply
        for svo in svo_paths
        for ply in _collect_existing_clouds(output_base / svo.stem / "cloud")
    ]
    if not all_plys:
        raise RuntimeError(
            f"No .ply in any cloud folder under {output_base} for SVOs: {[p.stem for p in svo_paths]}"
        )
    list_path = output_base / "_combined_ply_list.txt"
    ply_lines = [str(ply.resolve()) for ply in sorted(all_plys)]
    list_path.write_text("\n".join(ply_lines) + "\n", encoding="utf-8")
    dgcnn_path, dgcnn_data = _run_test_dgcnn_weight_estimator_subprocess(
        cloud_dir=None,
        ply_list_file=list_path,
        checkpoint=weight_checkpoint,
        device=weight_device,
        out_dir=output_base,
        num_points=num_points_model,
        xyz_scale=weight_xyz_scale,
        top_k=weight_top_k,
        top_by_length=weight_top_by_length,
        length_switch_mm=weight_length_switch_mm,
        topk_length=weight_topk_length,
        remove_outliers=weight_remove_outliers,
        outlier_method=weight_outlier_method,
        labels_json=weight_labels_json,
    )
    result = _merge_weight_prediction_json_combined(
        svo_paths=svo_paths,
        output_base=output_base,
        ply_list_path=list_path,
        dgcnn_json_path=dgcnn_path,
        dgcnn_data=dgcnn_data,
    )
    prediction_file = output_base / "weight_prediction.json"
    prediction_file.write_text(
        json.dumps(_sanitize_for_json(result), indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    return result
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: input source, fish_video options, test_dgcnn options, reuse/overlay switches."""
    parser = argparse.ArgumentParser(
        description="SVO2 → fish_video_weight_evaluation.py → test_dgcnn_weight_estimator.py"
    )
    src = parser.add_mutually_exclusive_group(required=True)
    src.add_argument("--svo", type=str, default="", help="Path to a single .svo2 file")
    src.add_argument("--batch-svo-folder", type=str, default="", help="Folder containing *.svo2 files")
    parser.add_argument(
        "--save-output",
        type=str,
        default="output_weight_estimator",
        help="Base output directory (same as fish_video --save-output parent)",
    )
    # fish_video_weight_evaluation (aligned with run_fish_evaluation_simple.sh)
    parser.add_argument(
        "--yolo-model",
        type=str,
        default="/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt",
    )
    parser.add_argument("--conf", type=float, default=0.5, help="YOLO confidence (simple.sh uses 0.5)")
    parser.add_argument("--imgsz", type=int, default=640)
    parser.add_argument(
        "--sam-device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="SAM device for fish_video; also used as --device for test_dgcnn (cuda|cpu)",
    )
    parser.add_argument("--max-frames", type=int, default=0, help="0 = all frames")
    parser.add_argument(
        "--frame-stride",
        type=int,
        default=1,
        metavar="N",
        help="Passed to fish_video: only YOLO/SAM/PLY every N-th frame (default: 1)",
    )
    parser.add_argument("--filter-pointcloud", action="store_true")
    parser.add_argument("--use-density-filter", action="store_true")
    parser.add_argument("--use-clustering-filter", action="store_true")
    parser.add_argument("--use-pointcloud-classifier", action="store_true")
    parser.add_argument("--pointcloud-classifier", type=str, default=None)
    parser.add_argument("--pointcloud-classifier-threshold", type=float, default=0.7)
    parser.add_argument("--use-flatness-filter", action="store_true")
    parser.add_argument("--flatness-threshold", type=float, default=70.0)
    # test_dgcnn_weight_estimator
    parser.add_argument(
        "--weight-checkpoint",
        type=str,
        default=str(REPO_ROOT / "weight_estimator/runs/dgcnn_20260312_171043/best.pt"),
    )
    parser.add_argument("--num-points-model", type=int, default=768)
    parser.add_argument("--weight-top-k", type=int, default=5)
    parser.add_argument(
        "--weight-top-by-length",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Length-first top-K (default: on). Use --no-weight-top-by-length for weight-only order.",
    )
    parser.add_argument(
        "--weight-length-switch-mm",
        type=float,
        default=319.0,
        # %(default)s keeps the help in sync with the real default (was hard-coded as 320).
        help="If mean top-K length (mm) exceeds this, use top-K by predicted weight (default: %(default)s).",
    )
    parser.add_argument("--weight-topk-length", type=int, default=None)
    parser.add_argument("--weight-remove-outliers", action="store_true")
    parser.add_argument(
        "--weight-outlier-method",
        type=str,
        default="iqr",
        choices=["iqr", "zscore"],
    )
    parser.add_argument("--weight-xyz-scale", type=float, default=0.001)
    parser.add_argument("--weight-labels-json", type=str, default=None)
    # BooleanOptionalAction provides both --reuse-existing-clouds and
    # --no-reuse-existing-clouds with the same dest as the former flag pair.
    parser.add_argument(
        "--reuse-existing-clouds",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Skip fish_video if <save-output>/<svo_stem>/cloud/*.ply already exists for every target SVO",
    )
    parser.add_argument(
        "--fish-video-weight-overlay",
        action="store_true",
        help="Extra on-video header lines (top-5 / per-minute bucket). "
        "DGCNN in fish + preview weight/length labels are already enabled when weight checkpoint exists.",
    )
    parser.add_argument(
        "--minute-interval-sec",
        type=float,
        default=60.0,
        help="Passed to fish_video --minute-interval-sec for on-video bucket stats (default: 60).",
    )
    parser.add_argument(
        "--force-dgcnn-subprocess",
        action="store_true",
        help="Always run test_dgcnn_weight_estimator.py subprocess even if fish_video left weight_estimation_results.json.",
    )
    return parser


def main() -> None:
    """CLI entry point: produce point clouds via fish_video (unless cached), then predict weight.

    Writes ``<save-output>/weight_predictions_summary.json`` with one record per
    prediction run. Exits with a non-zero status (``SystemExit``) on invalid
    arguments, a missing weight checkpoint or batch folder, or a failed
    weight prediction.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if not (args.svo or args.batch_svo_folder):
        # `--svo ""` would otherwise fall through to batch mode on the current directory.
        parser.error("--svo / --batch-svo-folder must be a non-empty path")
    if args.frame_stride < 1:
        raise SystemExit("--frame-stride must be >= 1")
    output_base = Path(args.save_output).expanduser().resolve()
    output_base.mkdir(parents=True, exist_ok=True)
    weight_ckpt = Path(args.weight_checkpoint).expanduser().resolve()
    if not weight_ckpt.is_file():
        raise SystemExit(f"weight checkpoint not found: {weight_ckpt}")

    batch_folder: Optional[Path] = None
    if args.svo:
        svo_paths = [Path(args.svo).expanduser().resolve()]
    else:
        batch_folder = Path(args.batch_svo_folder).expanduser().resolve()
        svo_paths = sorted(batch_folder.glob("*.svo2"))

    need_fish = True
    if svo_paths and args.reuse_existing_clouds:
        need_fish = not all(has_cached_clouds(output_base, p) for p in svo_paths)
    if need_fish:
        if batch_folder is None:
            _run_fish_video_evaluation_subprocess(args, batch_folder=None)
        else:
            if not batch_folder.is_dir():
                raise SystemExit(f"batch folder not found: {args.batch_svo_folder}")
            if not svo_paths:
                raise SystemExit(f"No .svo2 files in: {batch_folder}")
            _run_fish_video_evaluation_subprocess(args, batch_folder=batch_folder)
    else:
        print("All targets already have cloud/*.ply — skipping fish_video_weight_evaluation.py.")

    device = torch.device(args.sam_device)
    # Settings shared by the combined and the per-SVO prediction paths.
    common: Dict[str, Any] = dict(
        output_base=output_base,
        weight_checkpoint=weight_ckpt,
        weight_device=device,
        num_points_model=args.num_points_model,
        weight_top_k=args.weight_top_k,
        weight_top_by_length=args.weight_top_by_length,
        weight_length_switch_mm=args.weight_length_switch_mm,
        weight_topk_length=args.weight_topk_length,
        weight_remove_outliers=args.weight_remove_outliers,
        weight_outlier_method=args.weight_outlier_method,
        weight_xyz_scale=args.weight_xyz_scale,
        weight_labels_json=args.weight_labels_json,
    )
    results: List[Dict[str, Any]] = []
    if len(svo_paths) > 1:
        names = ", ".join(p.name for p in svo_paths)
        print(f"\n=== Weight prediction: combined ({len(svo_paths)} SVOs) → top-{args.weight_top_k} over all PLYs ===")
        print(f" {names}")
        try:
            results.append(run_weight_prediction_combined_svos(svo_paths=svo_paths, **common))
        except Exception as e:  # recorded in the summary JSON; reported and turned into exit 1 below
            results.append(
                {"combined": True, "svo_paths": [str(p) for p in svo_paths], "error": str(e)}
            )
    else:
        for svo in svo_paths:
            print(f"\n=== Weight prediction: {svo.name} ===")
            try:
                results.append(
                    run_weight_prediction_for_svo(
                        svo_path=svo,
                        force_dgcnn_subprocess=args.force_dgcnn_subprocess,
                        **common,
                    )
                )
            except Exception as e:  # recorded in the summary JSON; reported and turned into exit 1 below
                results.append({"svo": str(svo), "error": str(e)})

    summary_path = output_base / "weight_predictions_summary.json"
    summary_path.write_text(
        json.dumps(_sanitize_for_json(results), indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    print(f"\nSaved summary: {summary_path}")
    failed = [r for r in results if "error" in r]
    if len(results) == 1 and not failed:
        avg_g = results[0].get("avg_predicted_weight_g")
        if avg_g is not None:
            print(f"Final predicted weight (test_dgcnn): {avg_g:.2f} g")
        else:
            print("Final predicted weight: N/A (skipped or no valid point clouds)")
    if failed:
        # Previously a failed prediction only landed in the JSON and the script exited 0.
        for r in failed:
            print(f"ERROR: weight prediction failed: {r['error']}", file=sys.stderr)
        raise SystemExit(1)


if __name__ == "__main__":
    main()