Files
FishServer/FishAction/watch_folder_predict.py
zaiun xu d0c53068dd feat(fish_api): add FishAction folder watch via Pydantic settings
Add ACTION_WATCH_* settings, background watch loop in FastAPI lifespan,
and shared action_watch service that updates app_state. Add fish-action-watch
CLI and optional FishAction watch_folder_predict.py for standalone use.

Made-with: Cursor
2026-04-08 19:54:18 +08:00

311 lines
9.3 KiB
Python

#!/usr/bin/env python3
"""
持续监控某个目录,对新出现的 MP4 做与 fish_api 相同流程的 X3D 推理。
- 仅标准库:轮询目录 + 连续多次文件大小不变视为写入完成。
- 已处理路径写入 state 文件,避免重复推理(可指定路径或关闭)。
若使用 **fish_api** 与 Pydantic 管理配置,请在 `fish_api/.env` 中设置 `ACTION_WATCH_DIR`
(及可选的 `ACTION_WATCH_POLL_INTERVAL` 等),由网关进程或 `python -m app.action_watch_cli`
统一监控;无需再跑本脚本。
示例(独立离线使用):
python watch_folder_predict.py \\
--watch-dir /data/incoming_mp4 \\
--checkpoint checkpoints/ptv_x3d_m/checkpoint_best.pt
"""
from __future__ import annotations
import argparse
import json
import os
import subprocess
import sys
import tempfile
import time
import traceback
from pathlib import Path
from typing import Any, Dict, Set
def _load_state(path: Path) -> Set[str]:
if not path.is_file():
return set()
try:
with open(path, encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, list):
return set(str(x) for x in data)
if isinstance(data, dict) and "processed" in data:
return set(str(x) for x in data["processed"])
except (json.JSONDecodeError, OSError):
pass
return set()
def _save_state(path: Path, processed: Set[str]) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
tmp = path.with_suffix(path.suffix + ".tmp")
with open(tmp, "w", encoding="utf-8") as f:
json.dump(sorted(processed), f, indent=0, ensure_ascii=False)
tmp.replace(path)
def _run_predict(
repo_root: Path,
script: Path,
checkpoint: str,
mp4_path: Path,
clips_per_video: int,
batch_size: int,
num_workers: int,
device: str,
output_json: str,
) -> None:
mp4_path = mp4_path.resolve()
path_prefix = str(mp4_path.parent)
rel_name = mp4_path.name
with tempfile.NamedTemporaryFile(
mode="w", suffix=".csv", delete=False, encoding="utf-8"
) as tmp:
tmp.write(f"{rel_name} 0\n")
csv_path = tmp.name
try:
cmd = [
sys.executable,
str(script),
"--checkpoint",
checkpoint,
"--csv",
csv_path,
"--path_prefix",
path_prefix,
"--clips_per_video",
str(clips_per_video),
"--batch_size",
str(batch_size),
"--num_workers",
str(num_workers),
"--output_json",
output_json,
"--log_interval",
"0",
]
if device.strip():
cmd.extend(["--device", device.strip()])
proc = subprocess.run(
cmd,
cwd=str(repo_root),
env=os.environ.copy(),
capture_output=True,
text=True,
)
if proc.returncode != 0:
err = (proc.stderr or "") + (proc.stdout or "")
raise RuntimeError(
f"predict_video_x3d_3class.py failed ({proc.returncode}): {err[-4000:]}"
)
finally:
Path(csv_path).unlink(missing_ok=True)
def _iter_mp4(watch_dir: Path, recursive: bool) -> list[Path]:
if recursive:
return sorted(
p
for p in watch_dir.rglob("*")
if p.is_file() and p.suffix.lower() == ".mp4"
)
return sorted(
p
for p in watch_dir.iterdir()
if p.is_file() and p.suffix.lower() == ".mp4"
)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the standalone folder watcher."""
    parser = argparse.ArgumentParser(
        description="Monitor a folder and run FishAction X3D on each new MP4."
    )
    parser.add_argument(
        "--watch-dir",
        type=str,
        required=True,
        help="Directory to watch (non-recursive unless --recursive).",
    )
    parser.add_argument(
        "--recursive",
        action="store_true",
        help="Also scan subdirectories for .mp4.",
    )
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument(
        "--poll-interval",
        type=float,
        default=2.0,
        help="Seconds between directory scans (default: 2).",
    )
    parser.add_argument(
        "--stable-polls",
        type=int,
        default=3,
        help="Require this many consecutive polls with unchanged size (default: 3).",
    )
    parser.add_argument(
        "--state-file",
        type=str,
        default="",
        help="JSON file listing processed absolute paths. Default: <watch-dir>/.fishaction_watch_processed.json",
    )
    parser.add_argument(
        "--no-state",
        action="store_true",
        help="Do not read/write state file (may re-process files on restart).",
    )
    parser.add_argument(
        "--output-json-dir",
        type=str,
        default="",
        help="If set, write each prediction to <dir>/<video_stem>_pred.json.",
    )
    parser.add_argument(
        "--clips_per_video",
        type=int,
        default=10,
    )
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument(
        "--device",
        type=str,
        default="",
        help="cuda|cpu (empty = predict script default).",
    )
    parser.add_argument(
        "--once",
        action="store_true",
        help="Process all currently pending stable files once then exit.",
    )
    return parser


def main() -> None:
    """Watch a directory and run X3D inference on every new, fully written MP4.

    The directory is polled every ``--poll-interval`` seconds (at least 0.1).
    A file counts as fully written once its size is non-zero and has been
    observed unchanged for ``--stable-polls`` consecutive scans.

    Successfully processed paths are kept in memory and, unless
    ``--no-state``, persisted to the state file so that a restart skips them.
    A failed file is retried only after it re-stabilises.

    With ``--once`` the directory is scanned ``--stable-polls`` times (spaced
    by the poll interval) so that files already present can become eligible.
    The function then returns; files still changing or failing are skipped.

    Raises:
        SystemExit: If the watch directory or the prediction script is missing.
    """
    args = _build_arg_parser().parse_args()
    watch_dir = Path(args.watch_dir).expanduser().resolve()
    if not watch_dir.is_dir():
        raise SystemExit(f"Not a directory: {watch_dir}")
    repo_root = Path(__file__).resolve().parent
    script = repo_root / "predict_video_x3d_3class.py"
    if not script.is_file():
        raise SystemExit(f"Missing script: {script}")
    state_path = (
        Path(args.state_file).expanduser().resolve()
        if args.state_file.strip()
        else watch_dir / ".fishaction_watch_processed.json"
    )
    processed: Set[str] = set() if args.no_state else _load_state(state_path)
    out_json_dir = (
        Path(args.output_json_dir).expanduser().resolve()
        if args.output_json_dir.strip()
        else None
    )
    if out_json_dir is not None:
        out_json_dir.mkdir(parents=True, exist_ok=True)
    # Loop invariants: resolved once (this script never calls os.chdir).
    checkpoint = os.path.abspath(os.path.expanduser(args.checkpoint))
    poll_sleep = max(args.poll_interval, 0.1)
    # path -> (last_size, consecutive_stable_count)
    stability: Dict[str, tuple[int, int]] = {}

    def process_one(mp4: Path) -> None:
        """Run inference on *mp4*; on success mark it processed and save state."""
        key = str(mp4.resolve())
        if key in processed:
            return
        if out_json_dir is not None:
            # NOTE: the name depends only on the stem, so with --recursive two
            # videos sharing a stem in different subdirectories write the same file.
            out_json = str(out_json_dir / f"{mp4.stem}_pred.json")
        else:
            out_fd, out_json = tempfile.mkstemp(suffix="_pred.json")
            os.close(out_fd)
        try:
            print(f"[watch] inference: {mp4}", flush=True)
            _run_predict(
                repo_root,
                script,
                checkpoint,
                mp4,
                args.clips_per_video,
                args.batch_size,
                args.num_workers,
                args.device,
                out_json,
            )
            with open(out_json, encoding="utf-8") as f:
                rows: Any = json.load(f)
            if rows:
                pred = rows[0].get("pred_3class", "")
                print(f"[watch] done: {mp4.name} -> {pred}", flush=True)
            processed.add(key)
            if not args.no_state:
                _save_state(state_path, processed)
        finally:
            if out_json_dir is None:
                Path(out_json).unlink(missing_ok=True)

    def tick() -> bool:
        """Scan once, update stability counters, process files stable long enough.

        Returns True if any file was processed this tick.
        """
        try:
            candidates = _iter_mp4(watch_dir, args.recursive)
        except OSError as e:
            # e.g. the watch directory was removed or became unreadable: keep
            # the long-running watcher alive and retry on the next poll.
            print(f"[watch] scan failed for {watch_dir}: {e}", flush=True)
            return False
        did = False
        seen_keys: Set[str] = set()
        for mp4 in candidates:
            key = str(mp4.resolve())
            seen_keys.add(key)
            if key in processed:
                continue
            try:
                size = int(mp4.stat().st_size)
            except OSError:
                continue
            if size <= 0:
                stability.pop(key, None)
                continue
            last = stability.get(key)
            cnt = last[1] + 1 if last is not None and last[0] == size else 1
            stability[key] = (size, cnt)
            if cnt >= args.stable_polls:
                try:
                    process_one(mp4)
                    stability.pop(key, None)
                    did = True
                except Exception as e:
                    print(f"[watch] error on {mp4}: {e}", flush=True)
                    traceback.print_exc()
                    # Reset the counter so a failing file is retried only after
                    # it re-stabilises, not on every poll (avoids log spam).
                    stability[key] = (size, 1)
        # Forget counters for files that disappeared since the last scan.
        for k in stability.keys() - seen_keys:
            del stability[k]
        return did

    print(
        f"[watch] watching {watch_dir} (poll={args.poll_interval}s, "
        f"stable_polls={args.stable_polls}, state={state_path if not args.no_state else 'off'})",
        flush=True,
    )
    if args.once:
        # A file needs `stable_polls` observations before it becomes eligible,
        # so a single scan would process nothing with the default of 3. Take
        # that many spaced scans; a file stable throughout is processed on or
        # before the last one.
        scans = max(args.stable_polls, 1)
        for i in range(scans):
            tick()
            if i + 1 < scans:
                time.sleep(poll_sleep)
        return
    while True:
        tick()
        time.sleep(poll_sleep)
if __name__ == "__main__":
    # Script entry point. main() returns None on normal completion, so this
    # exits with status 0; SystemExit raised inside main() propagates unchanged.
    raise SystemExit(main())