Files
FishServer/fish_api/app/services/measure.py
2026-04-10 10:30:01 +08:00

370 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import math
import os
import re
import shutil
import subprocess
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from app.logging_config import format_json_pretty
from app.settings import Settings
from app.state import MeasureSnapshot
from app.subprocess_run import run_subprocess_with_log
from loguru import logger
def _py_exe(settings: Settings) -> str:
return settings.python_fish_measure or sys.executable
def _predict_weigth_from_svo2_extra_args(settings: Settings) -> List[str]:
"""Flags aligned with FishMeasure/predict_weigth_from_svo2.py CLI."""
out: List[str] = []
if settings.predict_filter_pointcloud:
out.append("--filter-pointcloud")
if settings.predict_use_density_filter:
out.append("--use-density-filter")
if settings.predict_use_clustering_filter:
out.append("--use-clustering-filter")
if (
settings.predict_use_pointcloud_classifier
and settings.predict_pointcloud_classifier
and Path(settings.predict_pointcloud_classifier).is_file()
):
out.extend(
[
"--pointcloud-classifier",
settings.predict_pointcloud_classifier,
"--use-pointcloud-classifier",
"--pointcloud-classifier-threshold",
str(settings.predict_pointcloud_classifier_threshold),
]
)
if settings.predict_use_flatness_filter:
out.append("--use-flatness-filter")
out.extend(["--flatness-threshold", str(settings.predict_flatness_threshold)])
out.extend(["--weight-top-k", str(settings.measure_weight_top_k)])
if settings.measure_weight_top_by_length:
out.append("--weight-top-by-length")
else:
out.append("--no-weight-top-by-length")
if settings.predict_fish_video_weight_overlay:
out.extend(
[
"--fish-video-weight-overlay",
"--minute-interval-sec",
str(settings.predict_minute_interval_sec),
]
)
if not settings.measure_reuse_existing_clouds:
out.append("--no-reuse-existing-clouds")
return out
def run_measure_subprocess(svo_path: Path, settings: Settings) -> None:
    """Run FishMeasure's ``predict_weigth_from_svo2.py`` on one SVO file.

    Output goes under ``settings.measure_output_root`` (created if needed). The
    script runs with ``settings.fish_measure_root`` as its working directory and
    a copy of the current environment, via ``run_subprocess_with_log``.

    Raises:
        FileNotFoundError: the script is missing from ``fish_measure_root``.
        RuntimeError: the script exits non-zero; the message carries the last
            4000 characters of the returned ``proc.stdout``.
    """
    measure_root = settings.fish_measure_root
    script = measure_root / "predict_weigth_from_svo2.py"
    if not script.is_file():
        raise FileNotFoundError(f"Missing FishMeasure script: {script}")

    output_root = settings.measure_output_root
    output_root.mkdir(parents=True, exist_ok=True)

    # Mandatory (flag, value) pairs; optional flags are appended afterwards.
    flag_values = (
        ("--svo", str(svo_path.resolve())),
        ("--save-output", str(output_root.resolve())),
        ("--yolo-model", settings.yolo_model),
        ("--weight-checkpoint", settings.weight_checkpoint),
        ("--conf", str(settings.predict_conf)),
        ("--imgsz", str(settings.predict_imgsz)),
        ("--sam-device", settings.sam_device),
        ("--max-frames", str(settings.predict_max_frames)),
        ("--frame-stride", str(settings.predict_frame_stride)),
    )
    cmd = [_py_exe(settings), str(script)]
    for flag, value in flag_values:
        cmd.extend((flag, value))
    cmd += _predict_weigth_from_svo2_extra_args(settings)

    proc = run_subprocess_with_log(
        cmd,
        cwd=str(measure_root),
        env=os.environ.copy(),
        log_name="FishMeasure",
    )
    if proc.returncode == 0:
        return
    tail = (proc.stdout or "")[-4000:]
    raise RuntimeError(
        f"predict_weigth_from_svo2.py failed ({proc.returncode}): {tail}"
    )
def _summary_entry_matches_svo(item: Dict[str, Any], svo_path: Path) -> bool:
stem = svo_path.stem
resolved = str(svo_path.resolve())
svo_key = item.get("svo")
if svo_key:
try:
if Path(str(svo_key)).resolve() == svo_path.resolve():
return True
except OSError:
pass
if str(svo_key) == resolved:
return True
if item.get("svo_name") == stem:
return True
return False
def _load_weight_json(svo_path: Path, settings: Settings) -> Dict[str, Any]:
"""读取 FishMeasure 合并结果。优先 per-SVO 的 weight_prediction.json否则从 weight_predictions_summary.json 取匹配项predict 脚本在权重步失败时仍 exit 0 只写 summary"""
stem = svo_path.stem
root = settings.measure_output_root
candidate = root / stem / "weight_prediction.json"
if candidate.is_file():
with open(candidate, encoding="utf-8") as f:
return json.load(f)
summary_path = root / "weight_predictions_summary.json"
if summary_path.is_file():
with open(summary_path, encoding="utf-8") as f:
summary_list: Any = json.load(f)
if isinstance(summary_list, list):
for item in reversed(summary_list):
if not isinstance(item, dict):
continue
if not _summary_entry_matches_svo(item, svo_path):
continue
err = item.get("error")
if err:
raise RuntimeError(
f"FishMeasure 权重步骤失败({svo_path.name}: {err}"
)
if item.get("per_cloud") or item.get("per_file") or item.get(
"dgcnn_summary"
):
return item
break
combined_path = root / "weight_prediction.json"
if combined_path.is_file():
with open(combined_path, encoding="utf-8") as f:
combined: Any = json.load(f)
if isinstance(combined, dict) and combined.get("combined"):
names = combined.get("svo_names") or []
if stem in names:
return combined
raise FileNotFoundError(
f"未找到测量结果 JSON{candidate}(且 summary 中无本条 SVO 的成功记录)"
)
_TID_RE = re.compile(r"_tid(\d+)")
def _parse_tid_from_ply_name(name: str) -> Optional[int]:
"""与 FishMeasure/fish_video_weight_evaluation._parse_tid_from_ply_name 一致。"""
m = _TID_RE.search(name)
return int(m.group(1)) if m else None
def _safe_media_prefix(stem: str) -> str:
s = re.sub(r"[^\w.\-]+", "_", stem, flags=re.UNICODE).strip("._") or "svo"
return s[:120]
def _result_from_weight_prediction(data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Aggregate per-PLY predictions into one record per track id.

    Entries come from ``data["per_cloud"]`` (or ``data["per_file"]``). The track
    id is parsed from each entry's ``ply`` file name. For each track, the entry
    with the largest finite ``predicted_weight_g`` wins, and that same entry's
    ``length_input`` (mm) is reported. Tracks whose winning entry has no finite
    length are dropped. Output is sorted by id as ``{"id", "weight", "length"}``.
    """
    entries = data.get("per_cloud") or data.get("per_file") or []
    if not isinstance(entries, list):
        return []
    # track id -> (heaviest weight in g, length in mm of that same entry)
    heaviest: Dict[int, Tuple[float, float]] = {}
    for entry in entries:
        if not isinstance(entry, dict) or not entry.get("ply"):
            continue
        track_id = _parse_tid_from_ply_name(Path(str(entry["ply"])).name)
        if track_id is None:
            continue
        try:
            weight = float(entry.get("predicted_weight_g", float("nan")))
        except (TypeError, ValueError):
            continue
        if not math.isfinite(weight):
            continue
        try:
            length = float(entry.get("length_input", float("nan")))
        except (TypeError, ValueError):
            length = float("nan")
        previous = heaviest.get(track_id)
        if previous is None or weight > previous[0]:
            heaviest[track_id] = (weight, length)
    return [
        {"id": tid, "weight": w, "length": ln}
        for tid, (w, ln) in sorted(heaviest.items())
        if math.isfinite(ln)
    ]
def _find_preview_videos(output_dir: Path) -> Tuple[Optional[Path], Optional[Path]]:
previews = sorted(output_dir.rglob("*preview*.mp4"))
if len(previews) >= 2:
return previews[0], previews[1]
all_mp4 = sorted(output_dir.rglob("*.mp4"))
if len(all_mp4) >= 2:
return all_mp4[0], all_mp4[1]
if len(all_mp4) == 1:
return all_mp4[0], all_mp4[0]
if len(previews) == 1:
return previews[0], previews[0]
return None, None
def _split_sbs_video(src: Path, left_dst: Path, right_dst: Path) -> bool:
"""Split a side-by-side stereo video (W x H where W == 2*H_single) into left/right halves.
Returns True if split succeeded, False otherwise (caller should fall back to copy).
"""
probe = subprocess.run(
[
"ffprobe", "-v", "quiet", "-print_format", "json",
"-show_streams", str(src),
],
capture_output=True, text=True,
)
if probe.returncode != 0:
return False
import json as _json
try:
streams = _json.loads(probe.stdout).get("streams", [])
vstream = next((s for s in streams if s.get("codec_type") == "video"), None)
if vstream is None:
return False
w, h = int(vstream["width"]), int(vstream["height"])
except Exception:
return False
half_w = w // 2
if half_w < 1 or w < h:
return False
for crop, dst in [
(f"crop={half_w}:{h}:{half_w}:0", left_dst),
(f"crop={half_w}:{h}:0:0", right_dst),
]:
r = subprocess.run(
["ffmpeg", "-y", "-i", str(src), "-vf", crop, "-an", "-q:v", "5", str(dst)],
capture_output=True, text=True,
)
if r.returncode != 0:
return False
return True
def _publish_media(
    left: Optional[Path],
    right: Optional[Path],
    settings: Settings,
    file_prefix: str,
) -> Tuple[str, str]:
    """Place preview videos under ``settings.media_root`` and return their URLs.

    Files are named ``<safe prefix>_left.mp4`` / ``<safe prefix>_right.mp4`` and
    exposed as ``<public_base_url>/media/<name>``. When the same existing file
    is passed for both sides, it is first split as a side-by-side video. If that
    fails, or the sides differ, each source is copied as-is. A missing source
    yields ``""`` for its side.
    """
    media_dir = settings.media_root
    media_dir.mkdir(parents=True, exist_ok=True)
    prefix = _safe_media_prefix(file_prefix)
    left_dst = media_dir / f"{prefix}_left.mp4"
    right_dst = media_dir / f"{prefix}_right.mp4"
    url_base = settings.public_base_url.rstrip("/")

    def url_for(dst: Path) -> str:
        return f"{url_base}/media/{dst.name}"

    same_source = left is not None and left == right and left.is_file()
    if same_source and _split_sbs_video(left, left_dst, right_dst):
        return url_for(left_dst), url_for(right_dst)

    def copy_and_link(src: Optional[Path], dst: Path) -> str:
        if src is None or not src.is_file():
            return ""
        shutil.copy2(src, dst)
        return url_for(dst)

    return copy_and_link(left, left_dst), copy_and_link(right, right_dst)
def build_measure_snapshot(svo_path: Path, settings: Settings) -> MeasureSnapshot:
    """Build the MeasureSnapshot for one processed SVO file.

    Loads the FishMeasure result via ``_load_weight_json`` and reduces its
    averaged weight (g) and length (mm) to a single ``result`` row with id 1.
    That row is omitted when either value is missing, unparseable or
    non-finite. Preview videos are then located in the run's output directory
    and published via ``_publish_media``.
    """
    data = _load_weight_json(svo_path, settings)
    summary = data.get("dgcnn_summary") or data.get("weight_summary") or {}
    # Prefer the summary's averages; fall back to top-level fields of ``data``.
    weight_g = summary.get("avg_predicted_weight_g")
    length_mm = summary.get("avg_length_input_topk")
    if weight_g is None:
        weight_g = data.get("avg_predicted_weight_g")
    if length_mm is None:
        length_mm = summary.get("avg_length_input") or data.get("avg_length_input")

    result: List[Dict[str, Any]] = []
    if weight_g is not None and length_mm is not None:
        try:
            weight = float(weight_g)
            length = float(length_mm)
        except (TypeError, ValueError):
            pass  # unparseable numbers -> leave ``result`` empty
        else:
            if math.isfinite(weight) and math.isfinite(length):
                result = [{"id": 1, "weight": weight, "length": length}]

    # NOTE(review): "(top5)" is fixed label text; the k actually requested is
    # settings.measure_weight_top_k.
    logger.info(
        "[FishMeasure] parsed {}\navg_weight_g(top5)={} avg_length_mm(top5)={}\nresult:\n{}\ndgcnn_summary:\n{}",
        svo_path.name,
        weight_g,
        length_mm,
        format_json_pretty(result),
        format_json_pretty(summary),
    )

    # A missing, null or empty "output_dir" falls back to <output_root>/<stem>:
    # Path(None) raises TypeError and Path("") would rglob the current directory.
    out_dir = Path(
        data.get("output_dir") or (settings.measure_output_root / svo_path.stem)
    )
    lv, rv = _find_preview_videos(out_dir)
    prefix = (
        f"{datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%S')}_{svo_path.stem}"
    )
    v_left, v_right = _publish_media(lv, rv, settings, prefix)
    logger.info(
        "[FishMeasure] media preview_paths={} {} | published_left={} published_right={}",
        lv,
        rv,
        v_left or "(none)",
        v_right or "(none)",
    )
    return MeasureSnapshot(
        result=result,
        video_left=v_left,
        video_right=v_right,
        updated_at=datetime.now(timezone.utc),
        # NOTE(review): always the per-SVO path, even when the data was read
        # from the summary or combined file instead.
        raw_prediction_path=str(
            settings.measure_output_root / svo_path.stem / "weight_prediction.json"
        ),
    )
def run_full_measure(svo_path: Path, settings: Settings) -> MeasureSnapshot:
    """Run FishMeasure on *svo_path*, then parse and publish its output.

    Exceptions from the subprocess step or from snapshot building propagate
    unchanged.
    """
    resolved = svo_path.resolve()
    logger.info("[FishMeasure] start svo={}", resolved)
    run_measure_subprocess(svo_path, settings)
    snapshot = build_measure_snapshot(svo_path, settings)
    result_count = len(snapshot.result)
    logger.info(
        "[FishMeasure] done svo={} result_len={}", svo_path.name, result_count
    )
    return snapshot