#!/usr/bin/env python3
"""
Continuously watch a directory and run the same X3D inference flow as fish_api
on every MP4 that newly appears.

- Standard library only: the directory is polled, and a file is considered
  completely written once its size is unchanged for several consecutive polls.
- Processed paths are written to a state file so inference is not repeated
  (the path can be chosen, or the state file can be disabled).

If you use **fish_api** with Pydantic-managed settings, set `ACTION_WATCH_DIR`
(and optionally `ACTION_WATCH_POLL_INTERVAL` etc.) in `fish_api/.env`; the
gateway process or `python -m app.action_watch_cli` then does the monitoring
and this script does not need to run.

Example (standalone / offline use):
  python watch_folder_predict.py \\
    --watch-dir /data/incoming_mp4 \\
    --checkpoint checkpoints/ptv_x3d_m/checkpoint_best.pt
"""

from __future__ import annotations

import argparse
import json
import os
import subprocess
import sys
import tempfile
import time
import traceback
from pathlib import Path
from typing import Any, Dict, Set


def _load_state(path: Path) -> Set[str]:
    """Return the set of already-processed paths stored in *path*.

    Two JSON layouts are accepted: a plain list, or an object whose
    ``"processed"`` key holds a list. Every entry is converted with ``str``.
    A missing, unreadable, undecodable or unexpectedly shaped file yields an
    empty set, so a damaged state file never stops the watcher.
    """
    if not path.is_file():
        return set()
    try:
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
    except (ValueError, OSError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        return set()
    if isinstance(data, dict):
        data = data.get("processed")
    if isinstance(data, list):
        return {str(x) for x in data}
    return set()


def _save_state(path: Path, processed: Set[str]) -> None:
    """Write *processed* to *path* as a sorted JSON list.

    The parent directory is created when missing. The data is written to a
    sibling ``<name>.tmp`` file first and then moved over *path*, so a reader
    never sees a half-written state file.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(sorted(processed), f, indent=0, ensure_ascii=False)
    tmp.replace(path)


def _run_predict(
    repo_root: Path,
    script: Path,
    checkpoint: str,
    mp4_path: Path,
    clips_per_video: int,
    batch_size: int,
    num_workers: int,
    device: str,
    output_json: str,
) -> None:
    """Run the predict script on a single MP4 in a child Python process.

    A temporary one-row CSV (``"<file name> 0"``) is written and passed via
    ``--csv``, with the video's parent directory as ``--path_prefix``. The
    child runs with *repo_root* as its working directory and a copy of the
    current environment. ``--device`` is only forwarded when *device* is not
    blank. The temporary CSV is always removed afterwards.

    Raises:
        RuntimeError: if the child exits with a non-zero status; the message
            carries the last 4000 characters of its stderr + stdout.
    """
    mp4_path = mp4_path.resolve()
    path_prefix = str(mp4_path.parent)
    rel_name = mp4_path.name
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", delete=False, encoding="utf-8"
    ) as tmp:
        # NOTE(review): the trailing 0 is presumably a placeholder label
        # expected by the predict script's CSV format - confirm.
        tmp.write(f"{rel_name} 0\n")
        csv_path = tmp.name
    try:
        cmd = [
            sys.executable,
            str(script),
            "--checkpoint",
            checkpoint,
            "--csv",
            csv_path,
            "--path_prefix",
            path_prefix,
            "--clips_per_video",
            str(clips_per_video),
            "--batch_size",
            str(batch_size),
            "--num_workers",
            str(num_workers),
            "--output_json",
            output_json,
            "--log_interval",
            "0",
        ]
        if device.strip():
            cmd.extend(["--device", device.strip()])
        proc = subprocess.run(
            cmd,
            cwd=str(repo_root),
            env=os.environ.copy(),
            capture_output=True,
            text=True,
        )
        if proc.returncode != 0:
            err = (proc.stderr or "") + (proc.stdout or "")
            raise RuntimeError(
                f"predict_video_x3d_3class.py failed ({proc.returncode}): {err[-4000:]}"
            )
    finally:
        Path(csv_path).unlink(missing_ok=True)


def _iter_mp4(watch_dir: Path, recursive: bool) -> list[Path]:
    """Return the sorted regular files under *watch_dir* with an ``.mp4`` suffix.

    The suffix comparison is case-insensitive. Subdirectories are only
    descended into when *recursive* is true.
    """
    entries = watch_dir.rglob("*") if recursive else watch_dir.iterdir()
    return sorted(
        p for p in entries if p.is_file() and p.suffix.lower() == ".mp4"
    )


def main() -> None:
    """Parse the command line and watch a directory for new MP4 files.

    Each MP4 whose size stays unchanged for ``--stable-polls`` consecutive
    polls is passed to ``predict_video_x3d_3class.py`` (expected next to this
    file). Successfully processed paths are remembered, and persisted to the
    state file unless ``--no-state`` is given. With ``--once`` the function
    returns after the files that are currently stable have been handled;
    otherwise it polls forever.

    Raises:
        SystemExit: if the watch directory or the predict script is missing.
    """
    parser = argparse.ArgumentParser(
        description="Monitor a folder and run FishAction X3D on each new MP4."
    )
    parser.add_argument(
        "--watch-dir",
        type=str,
        required=True,
        help="Directory to watch (non-recursive unless --recursive).",
    )
    parser.add_argument(
        "--recursive",
        action="store_true",
        help="Also scan subdirectories for .mp4.",
    )
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument(
        "--poll-interval",
        type=float,
        default=2.0,
        help="Seconds between directory scans (default: 2).",
    )
    parser.add_argument(
        "--stable-polls",
        type=int,
        default=3,
        help="Require this many consecutive polls with unchanged size (default: 3).",
    )
    parser.add_argument(
        "--state-file",
        type=str,
        default="",
        help="JSON file listing processed absolute paths. "
        "Default: WATCH_DIR/.fishaction_watch_processed.json",
    )
    parser.add_argument(
        "--no-state",
        action="store_true",
        help="Do not read/write state file (may re-process files on restart).",
    )
    parser.add_argument(
        "--output-json-dir",
        type=str,
        default="",
        help="If set, write each prediction to OUTPUT_JSON_DIR/VIDEO_STEM_pred.json.",
    )
    parser.add_argument(
        "--clips_per_video",
        type=int,
        default=10,
    )
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument(
        "--device",
        type=str,
        default="",
        help="cuda|cpu (empty = predict script default).",
    )
    parser.add_argument(
        "--once",
        action="store_true",
        help="Process all currently pending stable files once then exit.",
    )
    args = parser.parse_args()

    watch_dir = Path(args.watch_dir).expanduser().resolve()
    if not watch_dir.is_dir():
        raise SystemExit(f"Not a directory: {watch_dir}")

    repo_root = Path(__file__).resolve().parent
    script = repo_root / "predict_video_x3d_3class.py"
    if not script.is_file():
        raise SystemExit(f"Missing script: {script}")

    state_path = (
        Path(args.state_file).expanduser().resolve()
        if args.state_file.strip()
        else watch_dir / ".fishaction_watch_processed.json"
    )
    processed: Set[str] = set() if args.no_state else _load_state(state_path)

    out_json_dir = (
        Path(args.output_json_dir).expanduser().resolve()
        if args.output_json_dir.strip()
        else None
    )
    if out_json_dir is not None:
        out_json_dir.mkdir(parents=True, exist_ok=True)

    # Resolved once: the value is the same for every video.
    checkpoint = os.path.abspath(os.path.expanduser(args.checkpoint))
    poll_interval = max(args.poll_interval, 0.1)

    # path -> (last_size, consecutive_stable_count)
    stability: Dict[str, tuple[int, int]] = {}

    def process_one(mp4: Path) -> None:
        """Run inference on *mp4*, print the prediction and mark it processed."""
        key = str(mp4.resolve())
        if key in processed:
            return
        if out_json_dir is not None:
            out_json = str(out_json_dir / f"{mp4.stem}_pred.json")
        else:
            # No output directory requested: use a throw-away file that is
            # deleted again in the ``finally`` below.
            out_fd, out_json = tempfile.mkstemp(suffix="_pred.json")
            os.close(out_fd)
        try:
            print(f"[watch] inference: {mp4}", flush=True)
            _run_predict(
                repo_root,
                script,
                checkpoint,
                mp4,
                args.clips_per_video,
                args.batch_size,
                args.num_workers,
                args.device,
                out_json,
            )
            with open(out_json, encoding="utf-8") as f:
                rows: Any = json.load(f)
            if rows:
                pred = rows[0].get("pred_3class", "")
                print(f"[watch] done: {mp4.name} -> {pred}", flush=True)
            processed.add(key)
            if not args.no_state:
                _save_state(state_path, processed)
        finally:
            if out_json_dir is None:
                Path(out_json).unlink(missing_ok=True)

    def tick() -> bool:
        """Scan once; return True if any file was processed this tick."""
        try:
            candidates = _iter_mp4(watch_dir, args.recursive)
        except OSError as e:
            # A transient listing failure must not kill a long-running
            # watcher; skip this scan and try again on the next poll.
            print(f"[watch] cannot scan {watch_dir}: {e}", flush=True)
            return False

        did = False
        seen_keys: Set[str] = set()
        for mp4 in candidates:
            key = str(mp4.resolve())
            seen_keys.add(key)
            if key in processed:
                continue
            try:
                st = mp4.stat()
            except OSError:
                continue
            size = int(st.st_size)
            if size <= 0:
                stability.pop(key, None)
                continue
            last = stability.get(key)
            if last is None or last[0] != size:
                stability[key] = (size, 1)
            else:
                _, cnt = last
                stability[key] = (size, cnt + 1)
            _, cnt = stability[key]
            if cnt >= args.stable_polls:
                try:
                    process_one(mp4)
                    stability.pop(key, None)
                    did = True
                except Exception as e:
                    print(f"[watch] error on {mp4}: {e}", flush=True)
                    traceback.print_exc()
                    # Reset the stable counter so a failing file is not
                    # retried on every poll and does not flood the log.
                    stability[key] = (size, 1)
        # Forget files that disappeared from the directory.
        for k in list(stability.keys()):
            if k not in seen_keys:
                del stability[k]
        return did

    print(
        f"[watch] watching {watch_dir} (poll={args.poll_interval}s, "
        f"stable_polls={args.stable_polls}, state={state_path if not args.no_state else 'off'})",
        flush=True,
    )

    if args.once:
        # A file needs `stable_polls` consecutive unchanged observations
        # before it is processed, so the first `stable_polls - 1` scans can
        # never process anything. Run them (spaced by the poll interval)
        # before draining; otherwise the very first tick returns False and
        # --once would exit without processing a single file.
        for _ in range(max(args.stable_polls, 1) - 1):
            tick()
            time.sleep(poll_interval)
        while tick():
            pass
        return

    while True:
        tick()
        time.sleep(poll_interval)


if __name__ == "__main__":
    main()