Files
FishServer/fish_api/app/services/action.py
2026-04-16 14:53:01 +08:00

180 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import os
import sys
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Tuple
from app.logging_config import format_json_pretty
from app.services.video_slice import get_video_duration, slice_video
from app.settings import Settings
from app.state import HealthSnapshot
from app.subprocess_run import run_subprocess_with_log
from loguru import logger
# Video slicing configuration
DEFAULT_SLICE_DURATION = 10.0 # cut one segment every 10 seconds
DEFAULT_MIN_DURATION_FOR_SLICE = 15.0 # only videos longer than 15 seconds are sliced
# Maps the English class label returned by the prediction script to the
# Chinese display string stored in HealthSnapshot.behavior_result.
# The keys double as the set of accepted prediction labels.
BEHAVIOR_EN_TO_ZH = {
    "feeding": "吃饵",
    "normal": "正常游行",
    "scared": "惊吓",
}
def _py_exe(settings: Settings) -> str:
    """Return the Python interpreter used to launch the FishAction script.

    Prefers ``settings.python_fish_action``; when that value is falsy
    (``None`` or an empty string) the interpreter running this process
    (``sys.executable``) is used instead.
    """
    configured = settings.python_fish_action
    if configured:
        return configured
    return sys.executable
def behavior_to_health(behavior_en: str) -> str:
    """Translate an English behavior label into a Chinese health verdict.

    Only the exact label ``"scared"`` is reported as unhealthy ("不健康");
    every other value, including unknown labels, is reported as healthy
    ("健康").
    """
    is_unhealthy = behavior_en == "scared"
    return "不健康" if is_unhealthy else "健康"
def run_action_subprocess(mp4_path: Path, settings: Settings) -> str:
    """Classify one video by running the FishAction prediction script.

    Writes a one-row CSV manifest naming ``mp4_path``, runs
    ``predict_video_x3d_3class.py`` (located under
    ``settings.fish_action_root``) in a subprocess, then reads the JSON
    file the script was asked to write and extracts ``pred_3class`` from
    its first row.

    Args:
        mp4_path: Video file to classify; resolved to an absolute path.
        settings: Supplies ``fish_action_root``, ``action_checkpoint`` and
            the interpreter used to launch the script.

    Returns:
        The lower-cased, stripped ``pred_3class`` label. It is always a key
        of ``BEHAVIOR_EN_TO_ZH``.

    Raises:
        FileNotFoundError: The prediction script does not exist.
        RuntimeError: The subprocess exited non-zero, the output JSON is
            missing, unparsable, empty or not a list of objects, or the
            predicted label is not a known behavior.
    """
    script = settings.fish_action_root / "predict_video_x3d_3class.py"
    if not script.is_file():
        raise FileNotFoundError(f"Missing FishAction script: {script}")
    mp4_path = mp4_path.resolve()
    path_prefix = str(mp4_path.parent)
    rel_name = mp4_path.name

    # Both temp paths start as None and are created inside the try block so
    # that the finally clause removes whatever exists, even when a later
    # step (the CSV write, mkstemp, the subprocess) fails.
    csv_path = None
    out_json = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".csv", delete=False, encoding="utf-8"
        ) as tmp:
            csv_path = tmp.name
            # Single manifest row "<file name> 0"; the trailing 0 is
            # presumably a placeholder label -- confirm against the script.
            tmp.write(f"{rel_name} 0\n")
        out_fd, out_json = tempfile.mkstemp(suffix="_pred.json")
        # Only the path is needed here; the script writes the content.
        os.close(out_fd)

        # Uses the script's default parameters clips_per_video=10,
        # batch_size=8, num_workers=4. To change them, edit the defaults in
        # FishAction/predict_video_x3d_3class.py.
        cmd = [
            _py_exe(settings),
            str(script),
            "--checkpoint",
            settings.action_checkpoint,
            "--csv",
            csv_path,
            "--path_prefix",
            path_prefix,
            "--output_json",
            out_json,
            "--log_interval",
            "0",
        ]
        proc = run_subprocess_with_log(
            cmd,
            cwd=str(settings.fish_action_root),
            env=os.environ.copy(),
            log_name="FishAction",
            stream_to_logger=False,
        )
        if proc.returncode != 0:
            err = proc.stdout or ""
            # Keep only the tail of the output so the message stays bounded.
            raise RuntimeError(
                f"predict_video_x3d_3class.py failed ({proc.returncode}): {err[-4000:]}"
            )
        if not Path(out_json).is_file():
            raise RuntimeError(f"No output_json written: {out_json}")
        try:
            with open(out_json, encoding="utf-8") as f:
                rows = json.load(f)
        except json.JSONDecodeError as exc:
            # mkstemp pre-creates an empty file, so a script that never
            # wrote its output is detected here rather than by is_file().
            raise RuntimeError(
                f"Invalid prediction JSON in {out_json}: {exc}"
            ) from exc
        if not rows:
            raise RuntimeError("Empty prediction JSON")
        if not isinstance(rows, list) or not isinstance(rows[0], dict):
            raise RuntimeError(
                f"Unexpected prediction JSON shape: {type(rows).__name__}"
            )
        first_row = rows[0]
        pred_en = str(first_row.get("pred_3class", "")).strip().lower()
        logger.debug(
            "[FishAction] prediction row:\n{}",
            format_json_pretty(first_row),
        )
        if pred_en not in BEHAVIOR_EN_TO_ZH:
            raise RuntimeError(f"Unexpected pred_3class: {pred_en!r}")
        return pred_en
    finally:
        for leftover in (csv_path, out_json):
            if leftover is not None:
                Path(leftover).unlink(missing_ok=True)
def prepare_action_slices(
    mp4_path: Path, settings: Settings
) -> Tuple[List[Path], float]:
    """Decide whether a video must be sliced and return the files to process.

    The video duration is read first. A video no longer than
    ``DEFAULT_MIN_DURATION_FOR_SLICE`` seconds is returned as-is, i.e.
    ``([mp4_path], duration)``. A longer video is cut into segments of
    ``DEFAULT_SLICE_DURATION`` seconds and the result is
    ``(slice_files, duration)``.

    ``settings`` is accepted but not read by this function.
    """
    logger.info("[FishAction] start mp4={}", mp4_path.resolve())
    duration = get_video_duration(mp4_path)

    needs_slicing = duration > DEFAULT_MIN_DURATION_FOR_SLICE
    if not needs_slicing:
        return [mp4_path], duration

    logger.info(
        "[FishAction] video duration {}s > {}s, slicing into {}s segments",
        duration,
        DEFAULT_MIN_DURATION_FOR_SLICE,
        DEFAULT_SLICE_DURATION,
    )
    # slice_video also returns the directory holding the slices; it is not
    # used here.
    segments, _segment_dir = slice_video(mp4_path, DEFAULT_SLICE_DURATION)
    segment_count = len(segments)
    if segment_count > 1:
        logger.info(
            "[FishAction] processing {} slices for {}",
            segment_count,
            mp4_path.name,
        )
    return segments, duration
def run_single_slice_inference(
    slice_file: Path,
    slice_index: int,
    total_slices: int,
    duration: float,
    mp4_name: str,
    settings: Settings,
) -> HealthSnapshot:
    """Classify one video file or slice and wrap the outcome in a snapshot.

    The slice's time window is derived from ``slice_index`` and
    ``DEFAULT_SLICE_DURATION`` and clamped to ``duration``; it is only used
    in the per-slice log line. Any exception raised during inference is
    logged and converted into a ``HealthSnapshot`` whose ``error`` field
    holds the message, so inference failures are not propagated to the
    caller.
    """
    window_start = slice_index * DEFAULT_SLICE_DURATION
    window_end = min(window_start + DEFAULT_SLICE_DURATION, duration)
    try:
        label_en = run_action_subprocess(slice_file, settings)
        label_zh = BEHAVIOR_EN_TO_ZH[label_en]
        health_label = behavior_to_health(label_en)
        snapshot = HealthSnapshot(
            behavior_result=label_zh,
            health_result=health_label,
            updated_at=datetime.now(timezone.utc),
            raw_class_en=label_en,
        )
        if total_slices <= 1:
            # Whole-video run: log against the file name.
            logger.info(
                "[FishAction] done mp4={} pred_3class={} behavior_zh={} health={}",
                mp4_name, label_en, label_zh, health_label,
            )
        else:
            # Sliced run: log the slice index and its time window.
            logger.info(
                "[FishAction] slice {} ({}s-{}s): pred={} behavior={} health={}",
                slice_index, window_start, window_end, label_en, label_zh, health_label,
            )
        return snapshot
    except Exception as exc:
        # Deliberate catch-all: one failed slice is reported as data rather
        # than aborting the caller.
        logger.error("[FishAction] failed to process slice {}: {}", slice_index, exc)
        failure = HealthSnapshot(
            behavior_result="处理失败",
            health_result="未知",
            updated_at=datetime.now(timezone.utc),
            raw_class_en="error",
            error=str(exc),
        )
        return failure