Files
FishServer/fish_api/app/services/action.py

211 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import json
import os
import sys
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import List
from app.logging_config import format_json_pretty
from app.services.video_slice import get_video_duration, slice_video
from app.settings import Settings
from app.state import HealthSnapshot
from app.subprocess_run import run_subprocess_with_log
from loguru import logger
# --- Video slicing configuration (durations in seconds) ---
#: Length of each segment produced when a long video is sliced.
DEFAULT_SLICE_DURATION = 10.0
#: Only videos strictly longer than this are sliced before classification.
DEFAULT_MIN_DURATION_FOR_SLICE = 15.0

#: English FishAction class label (``pred_3class``) -> Chinese behavior label
#: reported to clients. The keys double as the set of accepted predictions.
BEHAVIOR_EN_TO_ZH = dict(
    feeding="吃饵",
    normal="正常游行",
    scared="惊吓",
)
def _py_exe(settings: Settings) -> str:
return settings.python_fish_action or sys.executable
def behavior_to_health(behavior_en: str) -> str:
    """Translate an English behavior label into a health verdict.

    Only ``"scared"`` counts as unhealthy (``"不健康"``); every other label,
    including unrecognised ones, is reported as healthy (``"健康"``).
    """
    return "不健康" if behavior_en == "scared" else "健康"
def run_action_subprocess(mp4_path: Path, settings: Settings) -> str:
    """Classify a single video with the FishAction X3D 3-class predictor.

    Writes a one-row input CSV, runs ``predict_video_x3d_3class.py`` in a
    subprocess and reads the first row of the JSON file it produces. All
    temporary files are removed before returning, on success or failure.

    Args:
        mp4_path: Video file to classify.
        settings: Supplies ``fish_action_root``, ``action_checkpoint`` and the
            interpreter (via ``_py_exe``).

    Returns:
        The stripped, lower-cased ``pred_3class`` value; guaranteed to be a
        key of ``BEHAVIOR_EN_TO_ZH``.

    Raises:
        FileNotFoundError: The predictor script is missing.
        RuntimeError: The subprocess exits non-zero, writes no (or empty)
            output, outputs an empty JSON, or predicts an unknown class.
    """
    script = settings.fish_action_root / "predict_video_x3d_3class.py"
    if not script.is_file():
        raise FileNotFoundError(f"Missing FishAction script: {script}")

    mp4_path = mp4_path.resolve()
    # The video is passed as a directory prefix plus a bare file name;
    # presumably the script joins --path_prefix with the CSV entry — confirm.
    path_prefix = str(mp4_path.parent)
    rel_name = mp4_path.name

    # Record each temp file as soon as it exists so the ``finally`` removes
    # it even when a later setup step (CSV write, mkstemp) raises.
    csv_path: str | None = None
    out_json: str | None = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".csv", delete=False, encoding="utf-8"
        ) as tmp:
            csv_path = tmp.name
            # One "<name> <label>" row; the label 0 looks like a placeholder
            # for inference — TODO confirm. NOTE(review): the row is
            # space-separated, so a file name containing spaces may be
            # mis-parsed by the script — verify.
            tmp.write(f"{rel_name} 0\n")

        out_fd, out_json = tempfile.mkstemp(suffix="_pred.json")
        os.close(out_fd)

        # Relies on the script's default parameters (clips_per_video=10,
        # batch_size=8, num_workers=4); to change them, edit the defaults in
        # FishAction/predict_video_x3d_3class.py.
        cmd = [
            _py_exe(settings),
            str(script),
            "--checkpoint",
            settings.action_checkpoint,
            "--csv",
            csv_path,
            "--path_prefix",
            path_prefix,
            "--output_json",
            out_json,
            "--log_interval",
            "0",
        ]
        proc = run_subprocess_with_log(
            cmd,
            cwd=str(settings.fish_action_root),
            env=os.environ.copy(),
            log_name="FishAction",
            stream_to_logger=False,
        )
        if proc.returncode != 0:
            # NOTE(review): only stdout is reported; presumably the helper
            # folds stderr into stdout — confirm in app.subprocess_run.
            err = proc.stdout or ""
            raise RuntimeError(
                f"predict_video_x3d_3class.py failed ({proc.returncode}): {err[-4000:]}"
            )

        out_file = Path(out_json)
        # mkstemp pre-creates the file, so existence alone does not prove the
        # script wrote anything; an empty file means no output was produced.
        if not out_file.is_file() or out_file.stat().st_size == 0:
            raise RuntimeError(f"No output_json written: {out_json}")
        with out_file.open(encoding="utf-8") as f:
            rows = json.load(f)
        if not rows:
            raise RuntimeError("Empty prediction JSON")

        pred_en = str(rows[0].get("pred_3class", "")).strip().lower()
        logger.debug(
            "[FishAction] prediction row:\n{}",
            format_json_pretty(rows[0]),
        )
        if pred_en not in BEHAVIOR_EN_TO_ZH:
            raise RuntimeError(f"Unexpected pred_3class: {pred_en!r}")
        return pred_en
    finally:
        for leftover in (csv_path, out_json):
            if leftover is not None:
                Path(leftover).unlink(missing_ok=True)
def _classify_video(video_path: Path, settings: Settings) -> tuple[str, str, str]:
    """Run FishAction on one file; return (English class, Chinese behavior, health)."""
    pred_en = run_action_subprocess(video_path, settings)
    return pred_en, BEHAVIOR_EN_TO_ZH[pred_en], behavior_to_health(pred_en)


def _health_snapshot(pred_en: str, behavior_zh: str, health: str) -> HealthSnapshot:
    """Build a UTC-timestamped snapshot for a successful classification."""
    return HealthSnapshot(
        behavior_result=behavior_zh,
        health_result=health,
        updated_at=datetime.now(timezone.utc),
        raw_class_en=pred_en,
    )


def _failed_slice_snapshot(exc: Exception) -> HealthSnapshot:
    """Build the placeholder snapshot recorded for a slice that failed."""
    return HealthSnapshot(
        behavior_result="处理失败",
        health_result="未知",
        updated_at=datetime.now(timezone.utc),
        raw_class_en="error",
        error=str(exc),
    )


def _run_sliced_action(
    mp4_path: Path,
    slice_files: list[Path],
    duration: float,
    settings: Settings,
) -> tuple[HealthSnapshot, list[HealthSnapshot]]:
    """Classify every slice independently; expects at least two slices.

    A slice that raises is logged (with traceback) and recorded as a failure
    snapshot so the remaining slices are still processed.
    """
    logger.info(
        "[FishAction] processing {} slices for {}",
        len(slice_files),
        mp4_path.name,
    )
    all_snaps: list[HealthSnapshot] = []
    for i, slice_file in enumerate(slice_files):
        # Nominal time window of this slice, used for logging only.
        start_time = i * DEFAULT_SLICE_DURATION
        end_time = min(start_time + DEFAULT_SLICE_DURATION, duration)
        try:
            pred_en, zh, health = _classify_video(slice_file, settings)
        except Exception as e:
            # Best-effort per slice: keep going, but keep the traceback.
            logger.exception("[FishAction] failed to process slice {}: {}", i, e)
            all_snaps.append(_failed_slice_snapshot(e))
            continue
        logger.info(
            "[FishAction] slice {} ({}s-{}s): pred={} behavior={} health={}",
            i, start_time, end_time, pred_en, zh, health,
        )
        all_snaps.append(_health_snapshot(pred_en, zh, health))

    logger.info(
        "[FishAction] done mp4={} total_slices={}",
        mp4_path.name,
        len(slice_files),
    )
    # Callers pass at least two slices and each iteration appends exactly one
    # snapshot, so all_snaps is never empty here.
    return all_snaps[0], all_snaps


def run_full_action(mp4_path: Path, settings: Settings) -> tuple[HealthSnapshot, list[HealthSnapshot]]:
    """Run FishAction health detection, slicing long videos first.

    Videos longer than ``DEFAULT_MIN_DURATION_FOR_SLICE`` seconds are cut into
    ``DEFAULT_SLICE_DURATION``-second slices, and each slice is classified as
    an independent video. If slicing yields fewer than two files, the whole
    video is classified directly instead.

    Args:
        mp4_path: Input video path.
        settings: Application settings forwarded to FishAction.

    Returns:
        ``(first_snapshot, all_snapshots)``:
        - sliced: the first slice's snapshot and one snapshot per slice
          (a failed slice gets a "处理失败"/"未知" snapshot carrying the error);
        - not sliced: ``(snap, [snap])`` for the whole video.

    Raises:
        Whatever ``run_action_subprocess`` raises when the whole video is
        classified directly; per-slice errors are caught and recorded instead.
    """
    logger.info("[FishAction] start mp4={}", mp4_path.resolve())

    duration = get_video_duration(mp4_path)
    if duration > DEFAULT_MIN_DURATION_FOR_SLICE:
        logger.info(
            "[FishAction] video duration {}s > {}s, slicing into {}s segments",
            duration,
            DEFAULT_MIN_DURATION_FOR_SLICE,
            DEFAULT_SLICE_DURATION,
        )
        # NOTE(review): the slice directory returned here is neither used nor
        # removed; if slice_video writes into a temporary directory, confirm
        # its cleanup happens elsewhere or slices will accumulate on disk.
        slice_files, _slice_dir = slice_video(mp4_path, DEFAULT_SLICE_DURATION)
        if len(slice_files) > 1:
            return _run_sliced_action(mp4_path, slice_files, duration, settings)
        # Fewer than two slices: fall through and classify the whole video.

    # Short video (or nothing worth slicing): classify it in one pass.
    pred_en, zh, health = _classify_video(mp4_path, settings)
    logger.info(
        "[FishAction] done mp4={} pred_3class={} behavior_zh={} health={}",
        mp4_path.name,
        pred_en,
        zh,
        health,
    )
    snap = _health_snapshot(pred_en, zh, health)
    return snap, [snap]