Files
FishServer/FishMeasure/predict_weigth_from_svo2.py
zaiun xu 5e1b2117c1 feat(fish_api): SQLite 快照投递、日志与 watch 空闲告警
- 新增 SQLite:measure/health 快照、delivery_cursor 单消费者 pop;clear/start_fresh 可清空库
- biomass GET 仅返回约定 data 字段,X-Fish-Biomass-New 表示是否有新快照;poller 读响应头
- loguru 桥接 uvicorn,子进程 stdout 流式输出;format_json_pretty 与算法摘要日志
- measure/action watch 无新任务时限流 WARNING;watch_idle 共用逻辑
- 依赖 loguru;新增 db、logging_config、subprocess_run、watch_idle、启动脚本

FishMeasure: 更新 fish_video_weight_evaluation 与 predict_weigth_from_svo2;移除未用 refbox/segmentation 脚本
Made-with: Cursor
2026-04-09 11:54:30 +08:00

632 lines
24 KiB
Python
Executable File

#!/usr/bin/env python3
"""
SVO2 → point clouds via **fish_video_weight_evaluation.py** (subprocess), then weight via
**weight_estimator/test_dgcnn_weight_estimator.py** (subprocess).
Defaults for the fish step align with ``run_fish_evaluation_simple.sh`` (conf 0.5, filters,
classifier, flatness). Override with the same flags this script exposes.
Outputs:
``<save-output>/<svo_stem>/cloud/*.ply`` — from fish_video_weight_evaluation
``<save-output>/<svo_stem>/weight_prediction.json`` — merged SVO info + DGCNN JSON (single-SVO mode)
``<save-output>/weight_prediction.json`` — when ``--batch-svo-folder`` has **multiple** ``*.svo2``:
one DGCNN run over **all** PLYs from every ``<svo_stem>/cloud/`` (top-K applies to the union)
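``<save-output>/<svo_stem>/dgcnn_test_output_top<K>[_by_length].json`` — raw test_dgcnn output when test_dgcnn runs (``_by_length`` is appended while ``--weight-top-by-length`` is on, the default; written under ``<save-output>/`` in combined mode)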
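With ``--reuse-existing-clouds`` (default), the fish subprocess is skipped when every target
SVO already has ``cloud/*.ply``; only the weight step runs (per-SVO mode also reuses an existing
``weight_estimation/weight_estimation_results.json`` unless ``--force-dgcnn-subprocess`` is given).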
"""
from __future__ import annotations
import argparse
import json
import subprocess
import sys
import math
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import torch
REPO_ROOT = Path(__file__).resolve().parent
# Prepend this script's directory to sys.path (only if absent) so modules that live
# next to it are importable regardless of the current working directory.
if not any(entry == str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))
def _fish_video_script() -> Path:
    """Return the path of ``fish_video_weight_evaluation.py`` located next to this script."""
    script_name = "fish_video_weight_evaluation.py"
    return REPO_ROOT.joinpath(script_name)
def _dgcnn_test_script() -> Path:
    """Return the path of ``weight_estimator/test_dgcnn_weight_estimator.py`` under this script's folder."""
    return REPO_ROOT.joinpath("weight_estimator", "test_dgcnn_weight_estimator.py")
def _torch_device_to_test_arg(device: torch.device) -> str:
    """Map a torch device to the ``--device`` string passed to test_dgcnn.

    Only ``device.type`` is inspected: ``cpu`` maps to ``"cpu"``; any other type
    (including indexed devices such as ``cuda:1``) maps to ``"cuda"``.
    """
    if device.type != "cpu":
        return "cuda"
    return "cpu"
def _collect_existing_clouds(cloud_dir: Path) -> List[Path]:
    """Return the ``*.ply`` entries directly inside *cloud_dir*, sorted by path.

    A directory that does not exist yields an empty list.
    """
    found = list(cloud_dir.glob("*.ply"))
    found.sort()
    return found
def has_cached_clouds(output_base: Path, svo_path: Path) -> bool:
    """Tell whether ``<output_base>/<svo stem>/cloud/`` already holds at least one ``.ply``.

    *svo_path* is user-expanded and resolved (symlinks followed) before its stem is
    taken; *output_base* is used exactly as given.
    """
    stem = svo_path.expanduser().resolve().stem
    cached = _collect_existing_clouds(output_base / stem / "cloud")
    return bool(cached)
def _run_fish_video_evaluation_subprocess(args: argparse.Namespace, *, batch_folder: Optional[Path]) -> None:
    """Run ``fish_video_weight_evaluation.py`` in a child process to produce point clouds.

    Args:
        args: Parsed CLI namespace from ``main``. The fish-step options (YOLO model,
            confidence, image size, SAM device, frame limits, filter/classifier switches
            and thresholds, and the optional weight-overlay settings) are forwarded.
        batch_folder: Folder whose ``*.svo2`` files are processed in a single child run
            (``--batch-svo-folder``), or ``None`` to process only ``args.svo``.

    Raises:
        FileNotFoundError: The fish_video script is missing next to this file.
        RuntimeError: The child process exited with a non-zero status.
    """
    import shlex  # local: only used to echo a copy-pasteable command line

    fvs = _fish_video_script()
    if not fvs.is_file():
        raise FileNotFoundError(f"Missing: {fvs}")
    out_parent = Path(args.save_output).expanduser().resolve()
    cmd: List[str] = [
        sys.executable,
        str(fvs),
        "--save-output",
        str(out_parent),
        "--yolo-model",
        args.yolo_model,
        "--conf",
        str(args.conf),
        "--imgsz",
        str(args.imgsz),
        "--sam-device",
        args.sam_device,
        "--max-frames",
        str(args.max_frames),
        "--frame-stride",
        str(args.frame_stride),
    ]
    if batch_folder is not None:
        cmd.extend(["--batch-svo-folder", str(batch_folder.resolve())])
    else:
        cmd.extend(["--svo", str(Path(args.svo).expanduser().resolve())])
    if args.filter_pointcloud:
        cmd.append("--filter-pointcloud")
    if args.use_density_filter:
        cmd.append("--use-density-filter")
    if args.use_clustering_filter:
        cmd.append("--use-clustering-filter")
    if args.use_pointcloud_classifier:
        cmd.append("--use-pointcloud-classifier")
    if args.pointcloud_classifier:
        cmd.extend(["--pointcloud-classifier", args.pointcloud_classifier])
    # Threshold values are forwarded unconditionally; the on/off switches above
    # decide whether the child actually applies them.
    cmd.extend(
        ["--pointcloud-classifier-threshold", str(args.pointcloud_classifier_threshold)]
    )
    if args.use_flatness_filter:
        cmd.append("--use-flatness-filter")
    cmd.extend(["--flatness-threshold", str(args.flatness_threshold)])
    if getattr(args, "fish_video_weight_overlay", False):
        wck = Path(args.weight_checkpoint).expanduser().resolve()
        cmd.extend(
            [
                "--run-weight-estimation",
                "--weight-estimator-checkpoint",
                str(wck),
                "--weight-overlay-video",
                "--minute-interval-sec",
                str(getattr(args, "minute_interval_sec", 60.0)),
            ]
        )
    # shlex.join quotes arguments containing spaces so the echoed command can be re-run
    # verbatim; flush so this line is not reordered after the child's own output when
    # stdout is a pipe (block-buffered) and the child inherits the same descriptor.
    print(f"Invoking fish_video_weight_evaluation.py:\n {shlex.join(cmd)}", flush=True)
    proc = subprocess.run(cmd, cwd=str(REPO_ROOT))
    if proc.returncode != 0:
        raise RuntimeError(f"fish_video_weight_evaluation.py exited with code {proc.returncode}")
def _run_test_dgcnn_weight_estimator_subprocess(
    *,
    cloud_dir: Optional[Path] = None,
    ply_list_file: Optional[Path] = None,
    checkpoint: Path,
    device: torch.device,
    out_dir: Path,
    num_points: int,
    xyz_scale: float,
    top_k: int,
    top_by_length: bool,
    length_switch_mm: float,
    topk_length: Optional[int],
    remove_outliers: bool,
    outlier_method: str,
    labels_json: Optional[str],
) -> Tuple[Path, Dict[str, Any]]:
    """Run ``test_dgcnn_weight_estimator.py`` in a child process and load its JSON result.

    Exactly one input source must be given: *cloud_dir* (forwarded as ``--ply-folder``)
    or *ply_list_file* (forwarded as ``--ply-list-file``). ``--output-json`` is set to
    ``<out_dir>/dgcnn_test_output.json``; the file read back is that name with
    ``_top<K>`` (or ``_top<K>_by_length`` when *top_by_length*) inserted before ``.json``.

    Returns:
        ``(path_of_result_json, parsed_result_json)``.

    Raises:
        ValueError: Both or neither of *cloud_dir* / *ply_list_file* were given.
        FileNotFoundError: The script is missing, or it exited cleanly without writing
            the expected result JSON.
        RuntimeError: The child process exited with a non-zero status.
    """
    import shlex  # local: only used to echo a copy-pasteable command line

    if (cloud_dir is None) == (ply_list_file is None):
        raise ValueError("Exactly one of cloud_dir or ply_list_file must be set")
    script = _dgcnn_test_script()
    if not script.is_file():
        raise FileNotFoundError(f"Missing: {script}")
    json_base = out_dir / "dgcnn_test_output.json"
    suffix = f"_top{top_k}_by_length" if top_by_length else f"_top{top_k}"
    written = json_base.parent / f"{json_base.stem}{suffix}{json_base.suffix}"
    # Remove a result left by an earlier run so that a child which exits 0 without
    # writing cannot make this function silently return stale numbers.
    written.unlink(missing_ok=True)
    # NOTE(review): without a labels file a placeholder path under out_dir is passed;
    # presumably test_dgcnn treats a missing file as "no ground truth" — confirm.
    labels_path = labels_json if labels_json else str(out_dir / "__no_labels__.json")
    cmd: List[str] = [
        sys.executable,
        str(script),
        "--checkpoint",
        str(checkpoint.expanduser().resolve()),
        "--device",
        _torch_device_to_test_arg(device),
        "--num-points",
        str(num_points),
        "--xyz-scale",
        str(xyz_scale),
        "--top-k",
        str(top_k),
        "--output-json",
        str(json_base),
        "--labels-json",
        labels_path,
    ]
    if ply_list_file is not None:
        cmd.extend(["--ply-list-file", str(ply_list_file.resolve())])
    else:
        cmd.extend(["--ply-folder", str(cloud_dir.resolve())])
    if top_by_length:
        cmd.append("--top-by-length")
    else:
        cmd.append("--no-top-by-length")
    cmd.extend(["--length-switch-mm", str(length_switch_mm)])
    if topk_length is not None:
        cmd.extend(["--topk-length", str(topk_length)])
    if remove_outliers:
        cmd.append("--remove-outliers")
    cmd.extend(["--outlier-method", outlier_method])
    # Quoted, flushed echo: copy-pasteable and not reordered behind the child's output.
    print(f"Invoking test_dgcnn_weight_estimator.py:\n {shlex.join(cmd)}", flush=True)
    proc = subprocess.run(cmd, cwd=str(REPO_ROOT))
    if proc.returncode != 0:
        raise RuntimeError(f"test_dgcnn_weight_estimator.py exited with code {proc.returncode}")
    if not written.is_file():
        raise FileNotFoundError(f"Expected DGCNN output JSON at {written}")
    data = json.loads(written.read_text(encoding="utf-8"))
    return written, data
def _merge_weight_prediction_json(
    *,
    svo_path: Path,
    svo_name: str,
    out_dir: Path,
    cloud_dir: Path,
    dgcnn_json_path: Path,
    dgcnn_data: Dict[str, Any],
) -> Dict[str, Any]:
    """Build the per-SVO ``weight_prediction.json`` payload from a DGCNN result JSON.

    Paths are stringified, the PLY count is taken from *cloud_dir* on disk, and both
    average-weight fields are ``None`` whenever the DGCNN summary reports ``skipped``.
    The same summary dict is exposed as ``dgcnn_summary`` and ``weight_summary``.
    """
    summary: Dict[str, Any] = dgcnn_data.get("summary") or {}
    skipped = bool(summary.get("skipped"))
    averages = {
        avg_key: (None if skipped else summary.get(avg_key))
        for avg_key in ("avg_predicted_weight_kg", "avg_predicted_weight_g")
    }
    return {
        "svo": str(svo_path),
        "svo_name": svo_name,
        "output_dir": str(out_dir),
        "cloud_dir": str(cloud_dir),
        "num_ply_on_disk": len(_collect_existing_clouds(cloud_dir)),
        "dgcnn_output_json": str(dgcnn_json_path),
        "skipped_weight": skipped,
        "skip_reason": summary.get("skip_reason"),
        "num_clouds_used": summary.get("num_files_used_for_avg"),
        "num_ply_predicted": summary.get("num_files_predicted"),
        **averages,
        "dgcnn_meta": dgcnn_data.get("meta"),
        "dgcnn_summary": summary,
        "weight_summary": summary,
        "per_cloud": dgcnn_data.get("per_file") or [],
        "comparison": dgcnn_data.get("comparison"),
    }
def _merge_weight_prediction_json_combined(
    *,
    svo_paths: List[Path],
    output_base: Path,
    ply_list_path: Path,
    dgcnn_json_path: Path,
    dgcnn_data: Dict[str, Any],
) -> Dict[str, Any]:
    """Build the combined (multi-SVO) ``weight_prediction.json`` payload.

    Lists every SVO and its ``<output_base>/<stem>/cloud`` folder, counts the PLYs
    across all of them, and reports ``None`` averages when the DGCNN summary says
    ``skipped``. The same summary dict appears as ``dgcnn_summary`` and ``weight_summary``.
    """
    summary: Dict[str, Any] = dgcnn_data.get("summary") or {}
    skipped = bool(summary.get("skipped"))
    averages = {
        avg_key: (None if skipped else summary.get(avg_key))
        for avg_key in ("avg_predicted_weight_kg", "avg_predicted_weight_g")
    }
    cloud_dirs = [output_base / svo.stem / "cloud" for svo in svo_paths]
    ply_count = 0
    for folder in cloud_dirs:
        ply_count += len(_collect_existing_clouds(folder))
    return {
        "combined": True,
        "svo_paths": [str(svo) for svo in svo_paths],
        "svo_names": [svo.stem for svo in svo_paths],
        "output_dir": str(output_base),
        "cloud_dirs": list(map(str, cloud_dirs)),
        "ply_list_file": str(ply_list_path),
        "num_ply_on_disk": ply_count,
        "dgcnn_output_json": str(dgcnn_json_path),
        "skipped_weight": skipped,
        "skip_reason": summary.get("skip_reason"),
        "num_clouds_used": summary.get("num_files_used_for_avg"),
        "num_ply_predicted": summary.get("num_files_predicted"),
        **averages,
        "dgcnn_meta": dgcnn_data.get("meta"),
        "dgcnn_summary": summary,
        "weight_summary": summary,
        "per_cloud": dgcnn_data.get("per_file") or [],
        "comparison": dgcnn_data.get("comparison"),
    }
def _sanitize_for_json(obj: Any) -> Any:
    """Recursively replace non-finite floats (NaN, +inf, -inf) with ``None``.

    ``json.dumps`` would otherwise emit the non-standard tokens ``NaN`` / ``Infinity``,
    which strict JSON parsers reject. Dicts are rebuilt with the same keys. Lists and
    tuples both become lists, which is the same JSON array ``json.dumps`` emits for either.
    Every other value, including ``bool`` and ``int``, is returned unchanged.
    """
    if isinstance(obj, dict):
        return {k: _sanitize_for_json(v) for k, v in obj.items()}
    # Tuples were previously returned untouched, letting NaN inside them leak into the output.
    if isinstance(obj, (list, tuple)):
        return [_sanitize_for_json(v) for v in obj]
    if isinstance(obj, float) and not math.isfinite(obj):
        return None
    return obj
def run_weight_prediction_for_svo(
    svo_path: Path,
    output_base: Path,
    weight_checkpoint: Path,
    weight_device: torch.device,
    num_points_model: int,
    weight_top_k: int,
    weight_top_by_length: bool,
    weight_length_switch_mm: float,
    weight_topk_length: Optional[int],
    weight_remove_outliers: bool,
    weight_outlier_method: str,
    weight_xyz_scale: float,
    weight_labels_json: Optional[str],
    force_dgcnn_subprocess: bool = False,
) -> Dict[str, Any]:
    """Predict fish weight for one SVO from its cached ``cloud/*.ply`` files.

    Clouds are read from ``<output_base>/<svo stem>/cloud/``. Unless
    *force_dgcnn_subprocess* is set, an existing
    ``<output_base>/<svo stem>/weight_estimation/weight_estimation_results.json`` is
    reused when it has a ``summary`` or a non-empty ``per_file`` entry. Any exception
    while loading, merging or saving that file prints a warning and falls back to
    running test_dgcnn. The merged payload is written to
    ``<output_base>/<svo stem>/weight_prediction.json`` and returned.

    NOTE(review): when that existing JSON is reused, *weight_checkpoint* and the
    ``weight_*`` / ``num_points_model`` arguments are not applied to it — confirm intended.

    Raises:
        FileNotFoundError: *svo_path* does not exist, or the DGCNN script/output is missing.
        RuntimeError: No ``.ply`` files are cached, or test_dgcnn exited non-zero.
    """
    svo_path = svo_path.expanduser().resolve()
    if not svo_path.exists():
        raise FileNotFoundError(f"svo not found: {svo_path}")
    svo_name = svo_path.stem
    out_dir = output_base / svo_name
    cloud_dir = out_dir / "cloud"
    if not _collect_existing_clouds(cloud_dir):
        raise RuntimeError(
            f"No .ply files in {cloud_dir}. Run without --reuse-existing-clouds or run "
            f"fish_video_weight_evaluation.py first."
        )

    def _merge_and_save(json_path: Path, data: Dict[str, Any]) -> Dict[str, Any]:
        # Merge the DGCNN result with this SVO's paths and persist weight_prediction.json.
        merged = _merge_weight_prediction_json(
            svo_path=svo_path,
            svo_name=svo_name,
            out_dir=out_dir,
            cloud_dir=cloud_dir,
            dgcnn_json_path=json_path,
            dgcnn_data=data,
        )
        (out_dir / "weight_prediction.json").write_text(
            json.dumps(_sanitize_for_json(merged), indent=2, ensure_ascii=False),
            encoding="utf-8",
        )
        return merged

    fish_results = out_dir / "weight_estimation" / "weight_estimation_results.json"
    if not force_dgcnn_subprocess and fish_results.is_file():
        try:
            cached = json.loads(fish_results.read_text(encoding="utf-8"))
            if cached.get("summary") is not None or cached.get("per_file"):
                print(f"Using existing DGCNN results from fish_video: {fish_results}")
                return _merge_and_save(fish_results, cached)
        except Exception as exc:
            # Best effort: any problem with the cached file just triggers a fresh DGCNN run.
            print(f"WARNING: Could not merge {fish_results}, falling back to test_dgcnn subprocess: {exc}")

    dgcnn_path, dgcnn_data = _run_test_dgcnn_weight_estimator_subprocess(
        cloud_dir=cloud_dir,
        ply_list_file=None,
        checkpoint=weight_checkpoint,
        device=weight_device,
        out_dir=out_dir,
        num_points=num_points_model,
        xyz_scale=weight_xyz_scale,
        top_k=weight_top_k,
        top_by_length=weight_top_by_length,
        length_switch_mm=weight_length_switch_mm,
        topk_length=weight_topk_length,
        remove_outliers=weight_remove_outliers,
        outlier_method=weight_outlier_method,
        labels_json=weight_labels_json,
    )
    return _merge_and_save(dgcnn_path, dgcnn_data)
def run_weight_prediction_combined_svos(
    svo_paths: List[Path],
    output_base: Path,
    weight_checkpoint: Path,
    weight_device: torch.device,
    num_points_model: int,
    weight_top_k: int,
    weight_top_by_length: bool,
    weight_length_switch_mm: float,
    weight_topk_length: Optional[int],
    weight_remove_outliers: bool,
    weight_outlier_method: str,
    weight_xyz_scale: float,
    weight_labels_json: Optional[str],
) -> Dict[str, Any]:
    """Run a single DGCNN pass over the cached PLYs of several SVOs.

    Gathers ``<output_base>/<svo stem>/cloud/*.ply`` for every SVO, writes their sorted
    absolute paths (one per line) to ``<output_base>/_combined_ply_list.txt`` and hands
    that list to test_dgcnn, so top-K / by-length selection applies to the union of all
    clouds. The merged payload is written to ``<output_base>/weight_prediction.json``.

    Raises:
        RuntimeError: No SVO has cached clouds, or test_dgcnn exited non-zero.
        FileNotFoundError: The DGCNN script or its output JSON is missing.
    """
    base = output_base.expanduser().resolve()
    resolved_svos = [svo.expanduser().resolve() for svo in svo_paths]
    all_plys: List[Path] = [
        ply
        for svo in resolved_svos
        for ply in _collect_existing_clouds(base / svo.stem / "cloud")
    ]
    if not all_plys:
        raise RuntimeError(
            f"No .ply in any cloud folder under {base} for SVOs: {[svo.stem for svo in resolved_svos]}"
        )
    list_path = base / "_combined_ply_list.txt"
    # One newline-terminated absolute path per line.
    list_path.write_text(
        "".join(f"{ply.resolve()}\n" for ply in sorted(all_plys)),
        encoding="utf-8",
    )
    dgcnn_path, dgcnn_data = _run_test_dgcnn_weight_estimator_subprocess(
        cloud_dir=None,
        ply_list_file=list_path,
        checkpoint=weight_checkpoint,
        device=weight_device,
        out_dir=base,
        num_points=num_points_model,
        xyz_scale=weight_xyz_scale,
        top_k=weight_top_k,
        top_by_length=weight_top_by_length,
        length_switch_mm=weight_length_switch_mm,
        topk_length=weight_topk_length,
        remove_outliers=weight_remove_outliers,
        outlier_method=weight_outlier_method,
        labels_json=weight_labels_json,
    )
    merged = _merge_weight_prediction_json_combined(
        svo_paths=resolved_svos,
        output_base=base,
        ply_list_path=list_path,
        dgcnn_json_path=dgcnn_path,
        dgcnn_data=dgcnn_data,
    )
    (base / "weight_prediction.json").write_text(
        json.dumps(_sanitize_for_json(merged), indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    return merged
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the CLI parser used by ``main`` (input source, fish_video, test_dgcnn, pipeline flags)."""
    parser = argparse.ArgumentParser(
        description="SVO2 → fish_video_weight_evaluation.py → test_dgcnn_weight_estimator.py"
    )
    src = parser.add_mutually_exclusive_group(required=True)
    src.add_argument("--svo", type=str, default="", help="Path to a single .svo2 file")
    src.add_argument("--batch-svo-folder", type=str, default="", help="Folder containing *.svo2 files")
    parser.add_argument(
        "--save-output",
        type=str,
        default="output_weight_estimator",
        help="Base output directory (same as fish_video --save-output parent)",
    )
    # fish_video_weight_evaluation (aligned with run_fish_evaluation_simple.sh)
    parser.add_argument(
        "--yolo-model",
        type=str,
        default="/home/ubuntu/projects/FishMeasure/runs/train/fish_detection_20251127_104658/weights/best.pt",
    )
    parser.add_argument("--conf", type=float, default=0.5, help="YOLO confidence (simple.sh uses 0.5)")
    parser.add_argument("--imgsz", type=int, default=640)
    parser.add_argument(
        "--sam-device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="SAM device for fish_video; also used as --device for test_dgcnn (cuda|cpu)",
    )
    parser.add_argument("--max-frames", type=int, default=0, help="0 = all frames")
    parser.add_argument(
        "--frame-stride",
        type=int,
        default=1,
        metavar="N",
        help="Passed to fish_video: only YOLO/SAM/PLY every N-th frame (default: 1)",
    )
    parser.add_argument("--filter-pointcloud", action="store_true")
    parser.add_argument("--use-density-filter", action="store_true")
    parser.add_argument("--use-clustering-filter", action="store_true")
    parser.add_argument("--use-pointcloud-classifier", action="store_true")
    parser.add_argument("--pointcloud-classifier", type=str, default=None)
    parser.add_argument("--pointcloud-classifier-threshold", type=float, default=0.7)
    parser.add_argument("--use-flatness-filter", action="store_true")
    parser.add_argument("--flatness-threshold", type=float, default=70.0)
    # test_dgcnn_weight_estimator
    parser.add_argument(
        "--weight-checkpoint",
        type=str,
        default=str(REPO_ROOT / "weight_estimator/runs/dgcnn_20260312_171043/best.pt"),
    )
    parser.add_argument("--num-points-model", type=int, default=768)
    parser.add_argument("--weight-top-k", type=int, default=5)
    parser.add_argument(
        "--weight-top-by-length",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Length-first top-K (default: on). Use --no-weight-top-by-length for weight-only order.",
    )
    parser.add_argument(
        "--weight-length-switch-mm",
        type=float,
        default=319.0,
        # Help previously claimed "default: 320" while the real default is 319.0.
        help="If mean top-K length (mm) exceeds this, use top-K by predicted weight (default: %(default)s).",
    )
    parser.add_argument("--weight-topk-length", type=int, default=None)
    parser.add_argument("--weight-remove-outliers", action="store_true")
    parser.add_argument(
        "--weight-outlier-method",
        type=str,
        default="iqr",
        choices=["iqr", "zscore"],
    )
    parser.add_argument("--weight-xyz-scale", type=float, default=0.001)
    parser.add_argument("--weight-labels-json", type=str, default=None)
    # BooleanOptionalAction provides both --reuse-existing-clouds and
    # --no-reuse-existing-clouds (dest reuse_existing_clouds); a store_true flag with
    # default=True could never change the value.
    parser.add_argument(
        "--reuse-existing-clouds",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Skip fish_video if <save-output>/<svo_stem>/cloud/*.ply already exists for every target",
    )
    parser.add_argument(
        "--fish-video-weight-overlay",
        action="store_true",
        help="Run fish_video with DGCNN + preview video overlay (fish weight g, top-5, per-window avg). "
        "Avoids a duplicate test_dgcnn pass when weight_estimation_results.json is present.",
    )
    parser.add_argument(
        "--minute-interval-sec",
        type=float,
        default=60.0,
        help="Passed to fish_video --minute-interval-sec for on-video bucket stats (default: 60).",
    )
    parser.add_argument(
        "--force-dgcnn-subprocess",
        action="store_true",
        help="Always run test_dgcnn_weight_estimator.py subprocess even if fish_video left weight_estimation_results.json.",
    )
    return parser


def _weight_kwargs(
    args: argparse.Namespace, *, output_base: Path, weight_ckpt: Path, device: torch.device
) -> Dict[str, Any]:
    """Keyword arguments shared by both weight-prediction entry points."""
    return {
        "output_base": output_base,
        "weight_checkpoint": weight_ckpt,
        "weight_device": device,
        "num_points_model": args.num_points_model,
        "weight_top_k": args.weight_top_k,
        "weight_top_by_length": args.weight_top_by_length,
        "weight_length_switch_mm": args.weight_length_switch_mm,
        "weight_topk_length": args.weight_topk_length,
        "weight_remove_outliers": args.weight_remove_outliers,
        "weight_outlier_method": args.weight_outlier_method,
        "weight_xyz_scale": args.weight_xyz_scale,
        "weight_labels_json": args.weight_labels_json,
    }


def _report_results(results: List[Dict[str, Any]]) -> None:
    """Print every failed target and, for a single successful result, the final average weight."""
    for entry in results:
        if "error" in entry:
            target = entry.get("svo") or ", ".join(entry.get("svo_paths") or [])
            print(f"ERROR: weight prediction failed for {target}: {entry['error']}")
    if len(results) == 1 and "error" not in results[0]:
        avg_g = results[0].get("avg_predicted_weight_g")
        if avg_g is not None:
            print(f"Final predicted weight (test_dgcnn): {avg_g:.2f} g")
        else:
            print("Final predicted weight: N/A (skipped or no valid point clouds)")


def main() -> None:
    """CLI entry point: produce point clouds if needed, then predict fish weight.

    Invalid input (``--frame-stride`` < 1, missing weight checkpoint, missing or empty
    batch folder when clouds must be generated) exits via ``SystemExit``. A failing
    fish_video child propagates as an exception. Weight-prediction failures are stored
    as ``{"error": ...}`` entries in ``<save-output>/weight_predictions_summary.json``
    and printed, without changing the exit status.
    """
    args = _build_arg_parser().parse_args()
    if args.frame_stride < 1:
        raise SystemExit("--frame-stride must be >= 1")
    output_base = Path(args.save_output).expanduser().resolve()
    output_base.mkdir(parents=True, exist_ok=True)
    weight_ckpt = Path(args.weight_checkpoint).expanduser().resolve()
    if not weight_ckpt.is_file():
        raise SystemExit(f"weight checkpoint not found: {weight_ckpt}")

    batch_folder: Optional[Path] = None
    if args.svo:
        svo_paths = [Path(args.svo).expanduser().resolve()]
    else:
        batch_folder = Path(args.batch_svo_folder).expanduser().resolve()
        svo_paths = sorted(batch_folder.glob("*.svo2"))

    # fish_video processes the whole batch folder in one run, so it is skipped only
    # when *every* SVO already has cached clouds.
    need_fish = True
    if svo_paths and args.reuse_existing_clouds:
        need_fish = not all(has_cached_clouds(output_base, p) for p in svo_paths)
    if need_fish:
        if batch_folder is None:
            _run_fish_video_evaluation_subprocess(args, batch_folder=None)
        else:
            if not batch_folder.is_dir():
                raise SystemExit(f"batch folder not found: {args.batch_svo_folder}")
            if not svo_paths:
                raise SystemExit(f"No .svo2 files in: {batch_folder}")
            _run_fish_video_evaluation_subprocess(args, batch_folder=batch_folder)
    else:
        print("All targets already have cloud/*.ply — skipping fish_video_weight_evaluation.py.")

    device = torch.device(args.sam_device)
    shared = _weight_kwargs(args, output_base=output_base, weight_ckpt=weight_ckpt, device=device)
    results: List[Dict[str, Any]] = []
    if len(svo_paths) > 1:
        names = ", ".join(p.name for p in svo_paths)
        print(f"\n=== Weight prediction: combined ({len(svo_paths)} SVOs) → top-{args.weight_top_k} over all PLYs ===")
        print(f" {names}")
        try:
            results.append(run_weight_prediction_combined_svos(svo_paths=svo_paths, **shared))
        except Exception as e:
            results.append(
                {"combined": True, "svo_paths": [str(p) for p in svo_paths], "error": str(e)}
            )
    else:
        for svo in svo_paths:
            print(f"\n=== Weight prediction: {svo.name} ===")
            try:
                results.append(
                    run_weight_prediction_for_svo(
                        svo_path=svo,
                        force_dgcnn_subprocess=args.force_dgcnn_subprocess,
                        **shared,
                    )
                )
            except Exception as e:
                results.append({"svo": str(svo), "error": str(e)})

    summary_path = output_base / "weight_predictions_summary.json"
    summary_path.write_text(
        json.dumps(_sanitize_for_json(results), indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    print(f"\nSaved summary: {summary_path}")
    _report_results(results)
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()