180 lines
5.7 KiB
Python
180 lines
5.7 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import os
|
||
import sys
|
||
import tempfile
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
from typing import List, Tuple
|
||
|
||
from app.logging_config import format_json_pretty
|
||
from app.services.video_slice import get_video_duration, slice_video
|
||
from app.settings import Settings
|
||
from app.state import HealthSnapshot
|
||
from app.subprocess_run import run_subprocess_with_log
|
||
from loguru import logger
|
||
|
||
# Video slicing configuration (values are in seconds).
DEFAULT_SLICE_DURATION = 10.0  # cut one segment every 10 seconds
DEFAULT_MIN_DURATION_FOR_SLICE = 15.0  # only videos longer than 15 seconds are sliced

# English behavior labels mapped to the Chinese text stored in results.
BEHAVIOR_EN_TO_ZH = {
    "feeding": "吃饵",
    "normal": "正常游行",
    "scared": "惊吓",
}
|
||
|
||
|
||
def _py_exe(settings: Settings) -> str:
    """Return the Python interpreter used to launch the FishAction script.

    ``settings.python_fish_action`` wins when it is set to a truthy value;
    otherwise the interpreter running this process (``sys.executable``) is used.
    """
    configured = settings.python_fish_action
    if configured:
        return configured
    return sys.executable
|
||
|
||
|
||
def behavior_to_health(behavior_en: str) -> str:
    """Map an English behavior label to a Chinese health verdict.

    Only the exact label ``"scared"`` yields ``"不健康"``; every other value,
    including unrecognised labels, yields ``"健康"``.
    """
    return "不健康" if behavior_en == "scared" else "健康"
|
||
|
||
|
||
def run_action_subprocess(mp4_path: Path, settings: Settings) -> str:
    """Classify one video by running the FishAction prediction script.

    ``predict_video_x3d_3class.py`` (looked up in ``settings.fish_action_root``)
    is launched in a subprocess. It is given a one-line temporary CSV naming the
    video and a temporary JSON path to write its result to. Both temporary files
    are removed before this function returns, whether it succeeds or raises.

    Args:
        mp4_path: Video to classify. It is resolved first; its parent directory
            is passed as ``--path_prefix`` and its file name goes into the CSV.
        settings: Supplies ``fish_action_root``, ``action_checkpoint`` and,
            through ``_py_exe``, the interpreter to use.

    Returns:
        The stripped, lower-cased ``pred_3class`` value of the first prediction
        row. It is always one of the keys of ``BEHAVIOR_EN_TO_ZH``.

    Raises:
        FileNotFoundError: The prediction script does not exist.
        RuntimeError: The subprocess exited non-zero, left the output file
            missing or empty, wrote an empty prediction list, or reported a
            class that is not in ``BEHAVIOR_EN_TO_ZH``.
    """
    script = settings.fish_action_root / "predict_video_x3d_3class.py"
    if not script.is_file():
        raise FileNotFoundError(f"Missing FishAction script: {script}")

    mp4_path = mp4_path.resolve()
    path_prefix = str(mp4_path.parent)
    rel_name = mp4_path.name

    # Record each temp path as soon as the file exists so the ``finally`` below
    # removes it even when a later setup step (the CSV write, ``mkstemp``) fails.
    csv_path = None
    out_json = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".csv", delete=False, encoding="utf-8"
        ) as tmp:
            csv_path = tmp.name
            tmp.write(f"{rel_name} 0\n")

        out_fd, out_json = tempfile.mkstemp(suffix="_pred.json")
        os.close(out_fd)

        # Relies on the script's default parameters
        # (clips_per_video=10, batch_size=8, num_workers=4).
        # To change them, edit the defaults in
        # FishAction/predict_video_x3d_3class.py.
        cmd = [
            _py_exe(settings),
            str(script),
            "--checkpoint",
            settings.action_checkpoint,
            "--csv",
            csv_path,
            "--path_prefix",
            path_prefix,
            "--output_json",
            out_json,
            "--log_interval",
            "0",
        ]
        proc = run_subprocess_with_log(
            cmd,
            cwd=str(settings.fish_action_root),
            env=os.environ.copy(),
            log_name="FishAction",
            stream_to_logger=False,
        )
        if proc.returncode != 0:
            err = proc.stdout or ""
            # Keep only the tail of the captured output in the error message.
            raise RuntimeError(
                f"predict_video_x3d_3class.py failed ({proc.returncode}): {err[-4000:]}"
            )

        # ``mkstemp`` pre-creates the file, so existence alone proves nothing:
        # a script that wrote no output leaves it at zero bytes.
        result_file = Path(out_json)
        if not result_file.is_file() or result_file.stat().st_size == 0:
            raise RuntimeError(f"No output_json written: {out_json}")

        with open(out_json, encoding="utf-8") as f:
            rows = json.load(f)
        if not rows:
            raise RuntimeError("Empty prediction JSON")

        # Only the first row is used, since the CSV lists a single video.
        pred_en = str(rows[0].get("pred_3class", "")).strip().lower()
        logger.debug(
            "[FishAction] prediction row:\n{}",
            format_json_pretty(rows[0]),
        )
        if pred_en not in BEHAVIOR_EN_TO_ZH:
            raise RuntimeError(f"Unexpected pred_3class: {pred_en!r}")
        return pred_en
    finally:
        for leftover in (csv_path, out_json):
            if leftover is not None:
                Path(leftover).unlink(missing_ok=True)
|
||
|
||
|
||
def prepare_action_slices(
    mp4_path: Path, settings: Settings
) -> Tuple[List[Path], float]:
    """Decide whether a video needs slicing and return the files to analyse.

    The duration is read with ``get_video_duration``. A video that is not
    longer than ``DEFAULT_MIN_DURATION_FOR_SLICE`` is returned untouched; a
    longer one is cut by ``slice_video`` into ``DEFAULT_SLICE_DURATION``-second
    segments.

    Args:
        mp4_path: Video file to inspect.
        settings: Accepted for signature compatibility; not used here.

    Returns:
        ``(files, duration)``. ``files`` is ``[mp4_path]`` when no slicing
        happened, otherwise the list returned by ``slice_video``.
    """
    logger.info("[FishAction] start mp4={}", mp4_path.resolve())
    duration = get_video_duration(mp4_path)

    # Written as a negated ">" so that any duration for which the comparison
    # is false (including NaN) takes the no-slicing path.
    if not (duration > DEFAULT_MIN_DURATION_FOR_SLICE):
        return [mp4_path], duration

    logger.info(
        "[FishAction] video duration {}s > {}s, slicing into {}s segments",
        duration,
        DEFAULT_MIN_DURATION_FOR_SLICE,
        DEFAULT_SLICE_DURATION,
    )
    # The second value returned by slice_video is not needed here.
    pieces, _ = slice_video(mp4_path, DEFAULT_SLICE_DURATION)
    piece_count = len(pieces)
    if piece_count > 1:
        logger.info(
            "[FishAction] processing {} slices for {}",
            piece_count,
            mp4_path.name,
        )
    return pieces, duration
|
||
|
||
|
||
def run_single_slice_inference(
    slice_file: Path,
    slice_index: int,
    total_slices: int,
    duration: float,
    mp4_name: str,
    settings: Settings,
) -> HealthSnapshot:
    """Run FishAction inference on one video file or slice.

    Any exception raised during inference is logged and converted into an
    error ``HealthSnapshot`` instead of propagating.

    Args:
        slice_file: File handed to ``run_action_subprocess``.
        slice_index: Zero-based position of this slice; used to derive the
            time window shown in the log and in the failure message.
        total_slices: Number of slices for the video; selects which success
            log line is written.
        duration: Duration of the whole video, used to clamp the window end.
        mp4_name: Name of the source video, logged in the single-file case.
        settings: Passed through to ``run_action_subprocess``.

    Returns:
        A ``HealthSnapshot`` with the prediction on success. On failure it
        carries ``raw_class_en="error"`` and the exception text in ``error``.
    """
    window_start = slice_index * DEFAULT_SLICE_DURATION
    window_end = min(window_start + DEFAULT_SLICE_DURATION, duration)

    try:
        label_en = run_action_subprocess(slice_file, settings)
        label_zh = BEHAVIOR_EN_TO_ZH[label_en]
        verdict = behavior_to_health(label_en)
        snapshot = HealthSnapshot(
            behavior_result=label_zh,
            health_result=verdict,
            updated_at=datetime.now(timezone.utc),
            raw_class_en=label_en,
        )
        if total_slices <= 1:
            # Unsliced video: report against the original file name.
            logger.info(
                "[FishAction] done mp4={} pred_3class={} behavior_zh={} health={}",
                mp4_name, label_en, label_zh, verdict,
            )
        else:
            logger.info(
                "[FishAction] slice {} ({}s-{}s): pred={} behavior={} health={}",
                slice_index, window_start, window_end, label_en, label_zh, verdict,
            )
        return snapshot
    except Exception as exc:
        # Best-effort by design: a failed slice yields an error snapshot.
        logger.error("[FishAction] failed to process slice {}: {}", slice_index, exc)
        return HealthSnapshot(
            behavior_result="处理失败",
            health_result="未知",
            updated_at=datetime.now(timezone.utc),
            raw_class_en="error",
            error=str(exc),
        )
|