121 lines
3.5 KiB
Python
121 lines
3.5 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
from app.logging_config import format_json_pretty
|
|
from app.settings import Settings
|
|
from app.state import HealthSnapshot
|
|
from app.subprocess_run import run_subprocess_with_log
|
|
from loguru import logger
|
|
|
|
# English class labels produced by the FishAction model (``pred_3class``)
# mapped to the Chinese behaviour labels stored in HealthSnapshot.behavior_result:
# feeding -> "eating bait", normal -> "normal swimming", scared -> "startled".
BEHAVIOR_EN_TO_ZH = dict(
    feeding="吃饵",
    normal="正常游行",
    scared="惊吓",
)
|
|
|
|
|
|
def _py_exe(settings: Settings) -> str:
|
|
return settings.python_fish_action or sys.executable
|
|
|
|
|
|
def behavior_to_health(behavior_en: str) -> str:
    """Map an English behaviour label to a Chinese health label.

    Only the exact label ``"scared"`` is considered unhealthy ("不健康").
    Every other value, including unknown labels, maps to healthy ("健康").
    The comparison is case-sensitive.
    """
    is_unhealthy = behavior_en == "scared"
    return "不健康" if is_unhealthy else "健康"
|
|
|
|
|
|
def run_action_subprocess(mp4_path: Path, settings: Settings) -> str:
    """Classify the fish behaviour in one video by running the FishAction script.

    A one-line CSV manifest (``"<file name> 0"``) is written into a private
    temporary directory. ``predict_video_x3d_3class.py`` is then run with
    ``settings.fish_action_root`` as the working directory, using the
    configured checkpoint and inference options. Finally, the first row of the
    JSON file it writes is read back. The temporary directory and its contents
    are removed on every exit path.

    Args:
        mp4_path: Video to classify. It is resolved to an absolute path. Its
            directory is passed as ``--path_prefix`` and its file name is
            written to the CSV manifest.
        settings: Application settings providing the script location,
            interpreter, checkpoint and batching/worker options.

    Returns:
        The stripped, lower-cased ``pred_3class`` value of the first prediction
        row. It is always a key of ``BEHAVIOR_EN_TO_ZH``.

    Raises:
        FileNotFoundError: If the prediction script does not exist.
        RuntimeError: If the script exits non-zero, writes no (or an empty)
            output file, or writes JSON that is empty or not a list of objects.
            Also raised when the predicted class is not in
            ``BEHAVIOR_EN_TO_ZH``.
        json.JSONDecodeError: If the output file is non-empty but not valid JSON.
    """
    script = settings.fish_action_root / "predict_video_x3d_3class.py"
    if not script.is_file():
        raise FileNotFoundError(f"Missing FishAction script: {script}")

    mp4_path = mp4_path.resolve()
    path_prefix = str(mp4_path.parent)
    rel_name = mp4_path.name

    # One temporary directory holds both files, so nothing leaks if any step
    # fails. Unlike mkstemp, it does not pre-create the output file. That lets
    # the "no output written" check below actually detect a silent failure.
    with tempfile.TemporaryDirectory(prefix="fish_action_") as tmp_dir:
        work_dir = Path(tmp_dir)
        csv_path = work_dir / "input.csv"
        out_json = work_dir / "output_pred.json"

        # NOTE(review): the manifest line is "<name> <label>" separated by a
        # space. A video file name containing spaces may be mis-parsed by the
        # script; confirm.
        csv_path.write_text(f"{rel_name} 0\n", encoding="utf-8")

        cmd = [
            _py_exe(settings),
            str(script),
            "--checkpoint",
            # str() matches the other arguments and tolerates a Path-typed setting.
            str(settings.action_checkpoint),
            "--csv",
            str(csv_path),
            "--path_prefix",
            path_prefix,
            "--clips_per_video",
            str(settings.action_clips_per_video),
            "--batch_size",
            str(settings.action_batch_size),
            "--num_workers",
            str(settings.action_num_workers),
            "--output_json",
            str(out_json),
            "--log_interval",
            "0",
        ]
        proc = run_subprocess_with_log(
            cmd,
            cwd=str(settings.fish_action_root),
            env=os.environ.copy(),
            log_name="FishAction",
            stream_to_logger=False,
        )
        if proc.returncode != 0:
            # NOTE(review): assumes run_subprocess_with_log captures the
            # script's output (possibly with stderr merged) in .stdout; confirm.
            err = proc.stdout or ""
            # Keep only the tail so the exception message stays bounded.
            raise RuntimeError(
                f"predict_video_x3d_3class.py failed ({proc.returncode}): {err[-4000:]}"
            )
        if not out_json.is_file() or out_json.stat().st_size == 0:
            raise RuntimeError(f"No output_json written: {out_json}")
        with out_json.open(encoding="utf-8") as f:
            rows = json.load(f)

    if not rows:
        raise RuntimeError("Empty prediction JSON")
    # Expect a list whose first element is an object. Fail with a clear message
    # rather than an opaque KeyError/AttributeError on an unexpected shape.
    first_row = rows[0] if isinstance(rows, list) else None
    if not isinstance(first_row, dict):
        raise RuntimeError(
            f"Unexpected prediction JSON (expected a list of objects): {str(rows)[:200]}"
        )

    pred_en = str(first_row.get("pred_3class", "")).strip().lower()
    logger.debug(
        "[FishAction] prediction row:\n{}",
        format_json_pretty(first_row),
    )
    if pred_en not in BEHAVIOR_EN_TO_ZH:
        raise RuntimeError(f"Unexpected pred_3class: {pred_en!r}")
    return pred_en
|
|
|
|
|
|
def run_full_action(mp4_path: Path, settings: Settings) -> HealthSnapshot:
    """Classify one video and package the outcome as a HealthSnapshot.

    The function runs the FishAction subprocess to obtain the English class
    label. It translates that label into the Chinese behaviour label, derives
    the health label from it, and logs the start and end of the run. It
    returns a snapshot stamped with the current UTC time.

    Args:
        mp4_path: Video file to classify.
        settings: Application settings forwarded to the subprocess runner.

    Returns:
        A HealthSnapshot containing the Chinese behaviour label, the health
        label, a UTC ``updated_at`` timestamp, and the raw English class.

    Raises:
        Whatever ``run_action_subprocess`` raises; nothing is caught here.
    """
    logger.info("[FishAction] start mp4={}", mp4_path.resolve())

    label_en = run_action_subprocess(mp4_path, settings)
    # run_action_subprocess only returns keys of BEHAVIOR_EN_TO_ZH, so this
    # lookup cannot miss.
    label_zh = BEHAVIOR_EN_TO_ZH[label_en]
    health_label = behavior_to_health(label_en)

    logger.info(
        "[FishAction] done mp4={} pred_3class={} behavior_zh={} health={}",
        mp4_path.name, label_en, label_zh, health_label,
    )

    return HealthSnapshot(
        raw_class_en=label_en,
        behavior_result=label_zh,
        health_result=health_label,
        updated_at=datetime.now(timezone.utc),
    )
|