Files
FishServer/fish_api/app/services/action.py
zaiun xu c1aafc69bf 验收1
2026-04-13 17:13:02 +08:00

215 lines
7.1 KiB
Python

from __future__ import annotations
import json
import os
import sys
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import List
from app.logging_config import format_json_pretty
from app.services.video_slice import get_video_duration, slice_video
from app.settings import Settings
from app.state import HealthSnapshot
from app.subprocess_run import run_subprocess_with_log
from loguru import logger
# --- Video slicing configuration (consumed by run_full_action) ---
# Length, in seconds, of each segment a long video is cut into.
DEFAULT_SLICE_DURATION: float = 10.0
# A video is sliced only when its duration is strictly greater than this
# many seconds; shorter videos are classified in a single pass.
DEFAULT_MIN_DURATION_FOR_SLICE: float = 15.0
# Predictor class name (English) -> Chinese behaviour label shown to users:
# feeding = "eating bait", normal = "normal swimming", scared = "startled".
BEHAVIOR_EN_TO_ZH: dict[str, str] = dict(
    feeding="吃饵",
    normal="正常游行",
    scared="惊吓",
)
def _py_exe(settings: Settings) -> str:
return settings.python_fish_action or sys.executable
def behavior_to_health(behavior_en: str) -> str:
    """Map an English behaviour class to a Chinese health label.

    Only ``"scared"`` is considered unhealthy ("不健康"); every other value,
    including unrecognised ones, maps to healthy ("健康").
    """
    return "不健康" if behavior_en == "scared" else "健康"
def run_action_subprocess(mp4_path: Path, settings: Settings) -> str:
    """Classify one video with the FishAction X3D three-class predictor.

    Writes a one-row CSV naming the video, runs
    ``predict_video_x3d_3class.py`` in a subprocess (cwd =
    ``settings.fish_action_root``), and reads the first row of the JSON file
    the script writes to ``--output_json``. Both temporary files are always
    removed before returning.

    Args:
        mp4_path: Video to classify. It is resolved to an absolute path; its
            parent directory is passed as ``--path_prefix`` and its file name
            is written to the CSV.
        settings: Application settings providing the interpreter, checkpoint,
            clips-per-video, batch size and worker count.

    Returns:
        The predicted class, stripped and lower-cased; guaranteed to be a key
        of ``BEHAVIOR_EN_TO_ZH``.

    Raises:
        FileNotFoundError: The predictor script does not exist.
        RuntimeError: The subprocess exits non-zero, writes no (or empty)
            output JSON, the JSON is an empty list, or the predicted class is
            not a known one.
    """
    script = settings.fish_action_root / "predict_video_x3d_3class.py"
    if not script.is_file():
        raise FileNotFoundError(f"Missing FishAction script: {script}")

    mp4_path = mp4_path.resolve()
    path_prefix = str(mp4_path.parent)
    rel_name = mp4_path.name

    # Both temp paths start as None and are set as soon as each file exists,
    # so the ``finally`` removes whatever was created even if a later step
    # (writing the CSV, creating the output file) fails.
    csv_path: str | None = None
    out_json: str | None = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".csv", delete=False, encoding="utf-8"
        ) as tmp:
            csv_path = tmp.name
            # Single row: file name (relative to --path_prefix below) and a
            # second column of 0. NOTE(review): the 0 looks like a placeholder
            # label — confirm against the script's CSV format.
            tmp.write(f"{rel_name} 0\n")

        out_fd, out_json = tempfile.mkstemp(suffix="_pred.json")
        os.close(out_fd)

        cmd = [
            _py_exe(settings),
            str(script),
            "--checkpoint",
            settings.action_checkpoint,
            "--csv",
            csv_path,
            "--path_prefix",
            path_prefix,
            "--clips_per_video",
            str(settings.action_clips_per_video),
            "--batch_size",
            str(settings.action_batch_size),
            "--num_workers",
            str(settings.action_num_workers),
            "--output_json",
            out_json,
            "--log_interval",
            "0",
        ]
        proc = run_subprocess_with_log(
            cmd,
            cwd=str(settings.fish_action_root),
            env=os.environ.copy(),
            log_name="FishAction",
            stream_to_logger=False,
        )
        if proc.returncode != 0:
            err = proc.stdout or ""
            # Keep only the tail of the output so the message stays bounded.
            raise RuntimeError(
                f"predict_video_x3d_3class.py failed ({proc.returncode}): {err[-4000:]}"
            )

        out_file = Path(out_json)
        # mkstemp already created this file, so existence alone proves
        # nothing; an empty file means the script wrote no predictions.
        if not out_file.is_file() or out_file.stat().st_size == 0:
            raise RuntimeError(f"No output_json written: {out_json}")
        with out_file.open(encoding="utf-8") as f:
            rows = json.load(f)
        if not rows:
            raise RuntimeError("Empty prediction JSON")

        # Only one video was submitted, so only the first row is relevant.
        pred_en = str(rows[0].get("pred_3class", "")).strip().lower()
        logger.debug(
            "[FishAction] prediction row:\n{}",
            format_json_pretty(rows[0]),
        )
        if pred_en not in BEHAVIOR_EN_TO_ZH:
            raise RuntimeError(f"Unexpected pred_3class: {pred_en!r}")
        return pred_en
    finally:
        for tmp_name in (csv_path, out_json):
            if tmp_name is not None:
                Path(tmp_name).unlink(missing_ok=True)
def run_full_action(mp4_path: Path, settings: Settings) -> tuple[HealthSnapshot, list[HealthSnapshot]]:
    """Run FishAction health detection on a video, slicing long videos.

    Videos longer than ``DEFAULT_MIN_DURATION_FOR_SLICE`` seconds are cut into
    ``DEFAULT_SLICE_DURATION``-second slices. Each slice is classified as an
    independent video, and a failing slice yields an error snapshot instead of
    aborting the whole run.

    Args:
        mp4_path: Input video path.
        settings: Application settings.

    Returns:
        ``(first snapshot, all snapshots)``:
        - sliced video: ``(first slice's snapshot, one snapshot per slice)``;
          if slicing produced no files, an empty-field snapshot and ``[]``;
        - unsliced video: ``(snapshot, [snapshot])`` for the whole video.

    Raises:
        Whatever ``get_video_duration`` / ``slice_video`` raise. For unsliced
        videos, errors from ``run_action_subprocess`` also propagate.
    """
    logger.info("[FishAction] start mp4={}", mp4_path.resolve())
    duration = get_video_duration(mp4_path)
    if duration > DEFAULT_MIN_DURATION_FOR_SLICE:
        return _run_sliced_action(mp4_path, settings, duration)

    # Short video: classify it in a single pass.
    pred_en, zh, health = _classify_video(mp4_path, settings)
    logger.info(
        "[FishAction] done mp4={} pred_3class={} behavior_zh={} health={}",
        mp4_path.name,
        pred_en,
        zh,
        health,
    )
    snap = _success_snapshot(pred_en, zh, health)
    return snap, [snap]


def _classify_video(video: Path, settings: Settings) -> tuple[str, str, str]:
    """Run FishAction on one video and return ``(pred_en, behavior_zh, health_zh)``."""
    pred_en = run_action_subprocess(video, settings)
    return pred_en, BEHAVIOR_EN_TO_ZH[pred_en], behavior_to_health(pred_en)


def _success_snapshot(pred_en: str, zh: str, health: str) -> HealthSnapshot:
    """Wrap a successful classification in a HealthSnapshot stamped with the current UTC time."""
    return HealthSnapshot(
        behavior_result=zh,
        health_result=health,
        updated_at=datetime.now(timezone.utc),
        raw_class_en=pred_en,
    )


def _run_sliced_action(
    mp4_path: Path, settings: Settings, duration: float
) -> tuple[HealthSnapshot, list[HealthSnapshot]]:
    """Slice a long video and classify every slice independently."""
    logger.info(
        "[FishAction] video duration {}s > {}s, slicing into {}s segments",
        duration,
        DEFAULT_MIN_DURATION_FOR_SLICE,
        DEFAULT_SLICE_DURATION,
    )
    # NOTE(review): the directory returned by slice_video is not removed here;
    # if it is a temporary directory its slices accumulate — confirm ownership.
    slice_files, _slice_dir = slice_video(mp4_path, DEFAULT_SLICE_DURATION)
    if len(slice_files) > 1:
        logger.info(
            "[FishAction] processing {} slices for {}",
            len(slice_files),
            mp4_path.name,
        )

    all_snaps: list[HealthSnapshot] = []
    for i, slice_file in enumerate(slice_files):
        # Nominal time window of this slice within the original video; the
        # last slice is clamped to the video's duration.
        start_time = i * DEFAULT_SLICE_DURATION
        end_time = min(start_time + DEFAULT_SLICE_DURATION, duration)
        try:
            pred_en, zh, health = _classify_video(slice_file, settings)
            snap = _success_snapshot(pred_en, zh, health)
            logger.info(
                "[FishAction] slice {} ({}s-{}s): pred={} behavior={} health={}",
                i, start_time, end_time, pred_en, zh, health,
            )
            all_snaps.append(snap)
        except Exception as e:
            # Deliberate best-effort: one bad slice must not sink the whole
            # video. logger.exception keeps the traceback that logger.error
            # would discard.
            logger.exception("[FishAction] failed to process slice {}: {}", i, e)
            all_snaps.append(
                HealthSnapshot(
                    behavior_result="处理失败",
                    health_result="未知",
                    updated_at=datetime.now(timezone.utc),
                    raw_class_en="error",
                    error=str(e),
                )
            )

    logger.info(
        "[FishAction] done mp4={} total_slices={}",
        mp4_path.name,
        len(slice_files),
    )
    # Report the first slice as the headline result; fall back to an
    # empty-field snapshot when slicing produced nothing.
    first_snap = all_snaps[0] if all_snaps else HealthSnapshot(
        behavior_result="",
        health_result="",
        updated_at=datetime.now(timezone.utc),
    )
    return first_snap, all_snaps