Initial commit: FishServer monorepo (FishAction, FishMeasure, fish_api)
Made-with: Cursor
This commit is contained in:
106
fish_api/app/services/action.py
Normal file
106
fish_api/app/services/action.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from app.settings import Settings
|
||||
from app.state import HealthSnapshot
|
||||
|
||||
BEHAVIOR_EN_TO_ZH = {
|
||||
"feeding": "吃饵",
|
||||
"normal": "正常游行",
|
||||
"scared": "惊吓",
|
||||
}
|
||||
|
||||
|
||||
def _py_exe(settings: Settings) -> str:
|
||||
return settings.python_fish_action or sys.executable
|
||||
|
||||
|
||||
def behavior_to_health(behavior_en: str) -> str:
|
||||
if behavior_en == "scared":
|
||||
return "不健康"
|
||||
return "健康"
|
||||
|
||||
|
||||
def run_action_subprocess(mp4_path: Path, settings: Settings) -> str:
|
||||
script = settings.fish_action_root / "predict_video_x3d_3class.py"
|
||||
if not script.is_file():
|
||||
raise FileNotFoundError(f"Missing FishAction script: {script}")
|
||||
|
||||
mp4_path = mp4_path.resolve()
|
||||
path_prefix = str(mp4_path.parent)
|
||||
rel_name = mp4_path.name
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".csv", delete=False, encoding="utf-8"
|
||||
) as tmp:
|
||||
tmp.write(f"{rel_name} 0\n")
|
||||
csv_path = tmp.name
|
||||
|
||||
out_fd, out_json = tempfile.mkstemp(suffix="_pred.json")
|
||||
os.close(out_fd)
|
||||
|
||||
try:
|
||||
cmd = [
|
||||
_py_exe(settings),
|
||||
str(script),
|
||||
"--checkpoint",
|
||||
settings.action_checkpoint,
|
||||
"--csv",
|
||||
csv_path,
|
||||
"--path_prefix",
|
||||
path_prefix,
|
||||
"--clips_per_video",
|
||||
str(settings.action_clips_per_video),
|
||||
"--batch_size",
|
||||
str(settings.action_batch_size),
|
||||
"--num_workers",
|
||||
str(settings.action_num_workers),
|
||||
"--output_json",
|
||||
out_json,
|
||||
"--log_interval",
|
||||
"0",
|
||||
]
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
cwd=str(settings.fish_action_root),
|
||||
env=os.environ.copy(),
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if proc.returncode != 0:
|
||||
err = (proc.stderr or "") + (proc.stdout or "")
|
||||
raise RuntimeError(
|
||||
f"predict_video_x3d_3class.py failed ({proc.returncode}): {err[-4000:]}"
|
||||
)
|
||||
if not Path(out_json).is_file():
|
||||
raise RuntimeError(f"No output_json written: {out_json}")
|
||||
with open(out_json, encoding="utf-8") as f:
|
||||
rows = json.load(f)
|
||||
if not rows:
|
||||
raise RuntimeError("Empty prediction JSON")
|
||||
pred_en = str(rows[0].get("pred_3class", "")).strip().lower()
|
||||
if pred_en not in BEHAVIOR_EN_TO_ZH:
|
||||
raise RuntimeError(f"Unexpected pred_3class: {pred_en!r}")
|
||||
return pred_en
|
||||
finally:
|
||||
Path(csv_path).unlink(missing_ok=True)
|
||||
Path(out_json).unlink(missing_ok=True)
|
||||
|
||||
|
||||
def run_full_action(mp4_path: Path, settings: Settings) -> HealthSnapshot:
|
||||
pred_en = run_action_subprocess(mp4_path, settings)
|
||||
zh = BEHAVIOR_EN_TO_ZH[pred_en]
|
||||
health = behavior_to_health(pred_en)
|
||||
return HealthSnapshot(
|
||||
behavior_result=zh,
|
||||
health_result=health,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
raw_class_en=pred_en,
|
||||
)
|
||||
Reference in New Issue
Block a user