Files
FishServer/fish_api/app/services/action.py
zaiun xu 5a0d7ba11b ?
2026-04-13 13:49:55 +08:00

121 lines
3.5 KiB
Python

from __future__ import annotations
import json
import os
import sys
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from app.logging_config import format_json_pretty
from app.settings import Settings
from app.state import HealthSnapshot
from app.subprocess_run import run_subprocess_with_log
from loguru import logger
# English class label (as emitted in the script's ``pred_3class`` field)
# mapped to the Chinese display string returned to callers.
BEHAVIOR_EN_TO_ZH: dict[str, str] = {
    "feeding": "吃饵",  # taking bait / feeding
    "normal": "正常游行",  # normal swimming
    "scared": "惊吓",  # startled
}
def _py_exe(settings: Settings) -> str:
    """Return the interpreter used to launch the FishAction script.

    Uses ``settings.python_fish_action`` when it holds a truthy value;
    otherwise falls back to the interpreter running this process
    (``sys.executable``).
    """
    configured = settings.python_fish_action
    # str() honours the declared return type even when the setting holds a
    # path-like object instead of a plain string.
    return str(configured) if configured else sys.executable
def behavior_to_health(behavior_en: str) -> str:
    """Map an English behavior label to the health verdict string.

    Only the exact label ``"scared"`` yields ``"不健康"`` (unhealthy); every
    other value, including labels this module does not know, yields
    ``"健康"`` (healthy). The comparison is case-sensitive.
    """
    return "不健康" if behavior_en == "scared" else "健康"
def run_action_subprocess(mp4_path: Path, settings: Settings) -> str:
    """Classify one video by running the FishAction prediction script.

    Writes a one-row temporary CSV naming the video, runs
    ``predict_video_x3d_3class.py`` from ``settings.fish_action_root`` as a
    subprocess, and reads the ``pred_3class`` field from the first row of the
    JSON file the script is asked to write. Both temporary files are removed
    before returning, whether or not the run succeeded.

    Args:
        mp4_path: Video to classify. It is resolved to an absolute path; its
            parent directory is passed as ``--path_prefix`` and its file name
            becomes the CSV entry.
        settings: Supplies the script root, interpreter, checkpoint and the
            clip / batch / worker counts forwarded on the command line.

    Returns:
        The stripped, lower-cased ``pred_3class`` label; always a key of
        ``BEHAVIOR_EN_TO_ZH``.

    Raises:
        FileNotFoundError: The prediction script does not exist.
        RuntimeError: The subprocess exited non-zero, wrote no output, or the
            prediction JSON is empty, has an unexpected shape, or carries an
            unknown label.
        json.JSONDecodeError: The output file is non-empty but not valid JSON.
    """
    script = settings.fish_action_root / "predict_video_x3d_3class.py"
    if not script.is_file():
        raise FileNotFoundError(f"Missing FishAction script: {script}")

    mp4_path = mp4_path.resolve()
    path_prefix = str(mp4_path.parent)
    rel_name = mp4_path.name

    # Both temp paths start as None so the cleanup below is safe even when
    # creating one of them fails part-way through.
    csv_path = None
    out_json = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".csv", delete=False, encoding="utf-8"
        ) as tmp:
            csv_path = tmp.name
            # One "<file name> <label>" row; the trailing 0 is presumably a
            # placeholder label the script ignores -- confirm against the script.
            tmp.write(f"{rel_name} 0\n")

        out_fd, out_json = tempfile.mkstemp(suffix="_pred.json")
        os.close(out_fd)

        cmd = [
            _py_exe(settings),
            str(script),
            "--checkpoint",
            settings.action_checkpoint,
            "--csv",
            csv_path,
            "--path_prefix",
            path_prefix,
            "--clips_per_video",
            str(settings.action_clips_per_video),
            "--batch_size",
            str(settings.action_batch_size),
            "--num_workers",
            str(settings.action_num_workers),
            "--output_json",
            out_json,
            "--log_interval",
            "0",
        ]
        proc = run_subprocess_with_log(
            cmd,
            cwd=str(settings.fish_action_root),
            env=os.environ.copy(),
            log_name="FishAction",
            stream_to_logger=False,
        )
        if proc.returncode != 0:
            err = proc.stdout or ""
            # Keep only the tail of the captured output in the message.
            raise RuntimeError(
                f"predict_video_x3d_3class.py failed ({proc.returncode}): {err[-4000:]}"
            )

        # mkstemp already created the file, so existence alone proves nothing:
        # an untouched, zero-byte file also means the script wrote no output.
        out_path = Path(out_json)
        if not out_path.is_file() or out_path.stat().st_size == 0:
            raise RuntimeError(f"No output_json written: {out_json}")

        with open(out_json, encoding="utf-8") as f:
            rows = json.load(f)
        if not rows:
            raise RuntimeError("Empty prediction JSON")
        if not isinstance(rows, list) or not isinstance(rows[0], dict):
            raise RuntimeError(
                f"Unexpected prediction JSON shape: {type(rows).__name__}"
            )

        pred_en = str(rows[0].get("pred_3class", "")).strip().lower()
        logger.debug(
            "[FishAction] prediction row:\n{}",
            format_json_pretty(rows[0]),
        )
        if pred_en not in BEHAVIOR_EN_TO_ZH:
            raise RuntimeError(f"Unexpected pred_3class: {pred_en!r}")
        return pred_en
    finally:
        for leftover in (csv_path, out_json):
            if leftover is not None:
                Path(leftover).unlink(missing_ok=True)
def run_full_action(mp4_path: Path, settings: Settings) -> HealthSnapshot:
    """Classify a video and package the result as a ``HealthSnapshot``.

    Runs ``run_action_subprocess`` on the video, translates the English label
    through ``BEHAVIOR_EN_TO_ZH``, derives the health verdict with
    ``behavior_to_health``, and logs the start and the outcome.

    Args:
        mp4_path: Video file to classify.
        settings: Configuration forwarded unchanged to
            ``run_action_subprocess``.

    Returns:
        A ``HealthSnapshot`` holding the translated behavior, the health
        verdict, the raw English label, and a timezone-aware UTC timestamp
        taken when the snapshot is built.

    Raises:
        Whatever ``run_action_subprocess`` raises; nothing is caught here.
    """
    logger.info("[FishAction] start mp4={}", mp4_path.resolve())

    pred_en = run_action_subprocess(mp4_path, settings)
    # run_action_subprocess only returns labels that are keys of the map, so
    # this lookup cannot miss.
    zh = BEHAVIOR_EN_TO_ZH[pred_en]
    health = behavior_to_health(pred_en)

    logger.info(
        "[FishAction] done mp4={} pred_3class={} behavior_zh={} health={}",
        mp4_path.name,
        pred_en,
        zh,
        health,
    )
    return HealthSnapshot(
        behavior_result=zh,
        health_result=health,
        updated_at=datetime.now(timezone.utc),
        raw_class_en=pred_en,
    )