From d0c53068ddb7f4eb8e5a11615e20bf11cbc17bca Mon Sep 17 00:00:00 2001
From: zaiun xu
Date: Wed, 8 Apr 2026 19:54:18 +0800
Subject: [PATCH] feat(fish_api): add FishAction folder watch via Pydantic settings

Add ACTION_WATCH_* settings, background watch loop in FastAPI lifespan,
and shared action_watch service that updates app_state. Add
fish-action-watch CLI and optional FishAction watch_folder_predict.py
for standalone use.

Made-with: Cursor
---
 FishAction/watch_folder_predict.py    | 331 ++++++++++++++++++++++++++
 fish_api/app/action_watch_cli.py      |  25 ++
 fish_api/app/main.py                  |  11 +
 fish_api/app/services/action_watch.py | 175 ++++++++++++++
 fish_api/app/settings.py              |  25 +-
 fish_api/pyproject.toml               |   3 +
 6 files changed, 569 insertions(+), 1 deletion(-)
 create mode 100644 FishAction/watch_folder_predict.py
 create mode 100644 fish_api/app/action_watch_cli.py
 create mode 100644 fish_api/app/services/action_watch.py

diff --git a/FishAction/watch_folder_predict.py b/FishAction/watch_folder_predict.py
new file mode 100644
index 0000000..bf2f07e
--- /dev/null
+++ b/FishAction/watch_folder_predict.py
@@ -0,0 +1,331 @@
+#!/usr/bin/env python3
+"""
+Continuously watch a directory and run the same X3D inference flow as fish_api on each new MP4.
+
+- Standard library only: poll the directory; a file whose size is unchanged for several consecutive polls is treated as fully written.
+- Processed paths are written to a state file to avoid repeated inference (the path can be set, or the file disabled).
+
+If you use **fish_api** with Pydantic-managed settings, set `ACTION_WATCH_DIR` in `fish_api/.env`
+(and optionally `ACTION_WATCH_POLL_INTERVAL`, etc.) and let the gateway process or
+`python -m app.action_watch_cli` do the watching; this script is then not needed.
+
+Example (standalone, offline use):
+
+    python watch_folder_predict.py \\
+        --watch-dir /data/incoming_mp4 \\
+        --checkpoint checkpoints/ptv_x3d_m/checkpoint_best.pt
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import subprocess
+import sys
+import tempfile
+import time
+import traceback
+from pathlib import Path
+from typing import Any, Dict, Set
+
+
+def _load_state(path: Path) -> Set[str]:
+    """Return the processed paths stored in *path*; empty set if missing or unreadable."""
+    if not path.is_file():
+        return set()
+    try:
+        with open(path, encoding="utf-8") as f:
+            data = json.load(f)
+        if isinstance(data, list):
+            return set(str(x) for x in data)
+        if isinstance(data, dict) and "processed" in data:
+            return set(str(x) for x in data["processed"])
+    except (json.JSONDecodeError, OSError):
+        pass
+    return set()
+
+
+def _save_state(path: Path, processed: Set[str]) -> None:
+    """Write *processed* to *path* as a sorted JSON list via a temp file plus rename."""
+    path.parent.mkdir(parents=True, exist_ok=True)
+    tmp = path.with_suffix(path.suffix + ".tmp")
+    with open(tmp, "w", encoding="utf-8") as f:
+        json.dump(sorted(processed), f, indent=0, ensure_ascii=False)
+    tmp.replace(path)
+
+
+def _run_predict(
+    repo_root: Path,
+    script: Path,
+    checkpoint: str,
+    mp4_path: Path,
+    clips_per_video: int,
+    batch_size: int,
+    num_workers: int,
+    device: str,
+    output_json: str,
+) -> None:
+    """Run *script* on one MP4 through a temporary one-row CSV; raise RuntimeError on non-zero exit."""
+    mp4_path = mp4_path.resolve()
+    path_prefix = str(mp4_path.parent)
+    rel_name = mp4_path.name
+    with tempfile.NamedTemporaryFile(
+        mode="w", suffix=".csv", delete=False, encoding="utf-8"
+    ) as tmp:
+        tmp.write(f"{rel_name} 0\n")
+        csv_path = tmp.name
+    try:
+        cmd = [
+            sys.executable,
+            str(script),
+            "--checkpoint",
+            checkpoint,
+            "--csv",
+            csv_path,
+            "--path_prefix",
+            path_prefix,
+            "--clips_per_video",
+            str(clips_per_video),
+            "--batch_size",
+            str(batch_size),
+            "--num_workers",
+            str(num_workers),
+            "--output_json",
+            output_json,
+            "--log_interval",
+            "0",
+        ]
+        if device.strip():
+            cmd.extend(["--device", device.strip()])
+        proc = subprocess.run(
+            cmd,
+            cwd=str(repo_root),
+            env=os.environ.copy(),
+            capture_output=True,
+            text=True,
+        )
+        if proc.returncode != 0:
+            err = (proc.stderr or "") + (proc.stdout or "")
+            raise RuntimeError(
+                f"predict_video_x3d_3class.py failed ({proc.returncode}): {err[-4000:]}"
+            )
+    finally:
+        Path(csv_path).unlink(missing_ok=True)
+
+
+def _iter_mp4(watch_dir: Path, recursive: bool) -> list[Path]:
+    """Return the sorted .mp4 files (case-insensitive suffix) in *watch_dir*, optionally recursing."""
+    if recursive:
+        return sorted(
+            p
+            for p in watch_dir.rglob("*")
+            if p.is_file() and p.suffix.lower() == ".mp4"
+        )
+    return sorted(
+        p
+        for p in watch_dir.iterdir()
+        if p.is_file() and p.suffix.lower() == ".mp4"
+    )
+
+
+def main() -> None:
+    """Parse the CLI arguments, then poll the watch directory and run inference on each stable MP4."""
+    parser = argparse.ArgumentParser(
+        description="Monitor a folder and run FishAction X3D on each new MP4."
+    )
+    parser.add_argument(
+        "--watch-dir",
+        type=str,
+        required=True,
+        help="Directory to watch (non-recursive unless --recursive).",
+    )
+    parser.add_argument(
+        "--recursive",
+        action="store_true",
+        help="Also scan subdirectories for .mp4.",
+    )
+    parser.add_argument("--checkpoint", type=str, required=True)
+    parser.add_argument(
+        "--poll-interval",
+        type=float,
+        default=2.0,
+        help="Seconds between directory scans (default: 2).",
+    )
+    parser.add_argument(
+        "--stable-polls",
+        type=int,
+        default=3,
+        help="Require this many consecutive polls with unchanged size (default: 3).",
+    )
+    parser.add_argument(
+        "--state-file",
+        type=str,
+        default="",
+        help="JSON file listing processed absolute paths. Default: <watch-dir>/.fishaction_watch_processed.json",
+    )
+    parser.add_argument(
+        "--no-state",
+        action="store_true",
+        help="Do not read/write state file (may re-process files on restart).",
+    )
+    parser.add_argument(
+        "--output-json-dir",
+        type=str,
+        default="",
+        help="If set, write each prediction to <dir>/<stem>_pred.json.",
+    )
+    parser.add_argument(
+        "--clips_per_video",
+        type=int,
+        default=10,
+    )
+    parser.add_argument("--batch_size", type=int, default=8)
+    parser.add_argument("--num_workers", type=int, default=4)
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="",
+        help="cuda|cpu (empty = predict script default).",
+    )
+    parser.add_argument(
+        "--once",
+        action="store_true",
+        help="Process all currently pending stable files once then exit.",
+    )
+    args = parser.parse_args()
+
+    watch_dir = Path(args.watch_dir).expanduser().resolve()
+    if not watch_dir.is_dir():
+        raise SystemExit(f"Not a directory: {watch_dir}")
+
+    repo_root = Path(__file__).resolve().parent
+    script = repo_root / "predict_video_x3d_3class.py"
+    if not script.is_file():
+        raise SystemExit(f"Missing script: {script}")
+
+    state_path = (
+        Path(args.state_file).expanduser().resolve()
+        if args.state_file.strip()
+        else watch_dir / ".fishaction_watch_processed.json"
+    )
+    processed: Set[str] = set() if args.no_state else _load_state(state_path)
+
+    out_json_dir = (
+        Path(args.output_json_dir).expanduser().resolve()
+        if args.output_json_dir.strip()
+        else None
+    )
+    if out_json_dir is not None:
+        out_json_dir.mkdir(parents=True, exist_ok=True)
+
+    # path -> (last_size, consecutive_stable_count)
+    stability: Dict[str, tuple[int, int]] = {}
+    # --once only: files whose inference raised; never retried so the run ends.
+    failed: Set[str] = set()
+
+    def process_one(mp4: Path) -> None:
+        """Run inference on *mp4*, print the prediction and record the file as processed."""
+        key = str(mp4.resolve())
+        if key in processed:
+            return
+        if out_json_dir is not None:
+            out_json = str(out_json_dir / f"{mp4.stem}_pred.json")
+        else:
+            out_fd, out_json = tempfile.mkstemp(suffix="_pred.json")
+            os.close(out_fd)
+        try:
+            print(f"[watch] inference: {mp4}", flush=True)
+            _run_predict(
+                repo_root,
+                script,
+                os.path.abspath(os.path.expanduser(args.checkpoint)),
+                mp4,
+                args.clips_per_video,
+                args.batch_size,
+                args.num_workers,
+                args.device,
+                out_json,
+            )
+            with open(out_json, encoding="utf-8") as f:
+                rows: Any = json.load(f)
+            if rows:
+                pred = rows[0].get("pred_3class", "")
+                print(f"[watch] done: {mp4.name} -> {pred}", flush=True)
+            processed.add(key)
+            if not args.no_state:
+                _save_state(state_path, processed)
+        finally:
+            if out_json_dir is None:
+                Path(out_json).unlink(missing_ok=True)
+
+    def tick() -> bool:
+        """Returns True if any file was processed this tick."""
+        did = False
+        seen_keys: Set[str] = set()
+        for mp4 in _iter_mp4(watch_dir, args.recursive):
+            key = str(mp4.resolve())
+            seen_keys.add(key)
+            if key in processed or key in failed:
+                continue
+            try:
+                st = mp4.stat()
+            except OSError:
+                continue
+            size = int(st.st_size)
+            if size <= 0:
+                stability.pop(key, None)
+                continue
+            last = stability.get(key)
+            if last is None or last[0] != size:
+                stability[key] = (size, 1)
+            else:
+                _, cnt = last
+                stability[key] = (size, cnt + 1)
+            _, cnt = stability[key]
+            if cnt >= args.stable_polls:
+                try:
+                    process_one(mp4)
+                    stability.pop(key, None)
+                    did = True
+                except Exception as e:
+                    print(f"[watch] error on {mp4}: {e}", flush=True)
+                    traceback.print_exc()
+                    if args.once:
+                        # One-shot mode must terminate: drop the file instead
+                        # of retrying it forever.
+                        failed.add(key)
+                        stability.pop(key, None)
+                    else:
+                        # Reset the stable count so a failing file is not
+                        # retried (and logged) on every single poll.
+                        stability[key] = (size, 1)
+        for k in list(stability.keys()):
+            if k not in seen_keys:
+                del stability[k]
+        return did
+
+    print(
+        f"[watch] watching {watch_dir} (poll={args.poll_interval}s, "
+        f"stable_polls={args.stable_polls}, state={state_path if not args.no_state else 'off'})",
+        flush=True,
+    )
+
+    if args.once:
+        # A file must be seen unchanged on `stable_polls` consecutive scans, so
+        # one scan alone never processes anything. Keep polling until no file
+        # is still pending (each one processed, failed, or gone).
+        while True:
+            tick()
+            if not stability:
+                break
+            time.sleep(max(args.poll_interval, 0.1))
+        return
+
+    while True:
+        tick()
+        time.sleep(max(args.poll_interval, 0.1))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/fish_api/app/action_watch_cli.py b/fish_api/app/action_watch_cli.py
new file mode 100644
index 0000000..a3a4dc8
--- /dev/null
+++ b/fish_api/app/action_watch_cli.py
@@ -0,0 +1,25 @@
+"""Standalone process: run only the directory watcher from action_watch_dir in .env / environment (FastAPI is not started)."""
+
+from __future__ import annotations
+
+import asyncio
+import sys
+
+from app.services.action_watch import run_action_watch_loop
+from app.settings import get_settings
+
+
+def main() -> None:
+    """Run the watch loop until interrupted; exit with status 1 if no watch directory is configured."""
+    s = get_settings()
+    if s.action_watch_dir is None:
+        print(
+            "未配置监控目录:请在 .env 中设置 ACTION_WATCH_DIR=/你的/mp4/目录",
+            file=sys.stderr,
+        )
+        raise SystemExit(1)
+    asyncio.run(run_action_watch_loop(s))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/fish_api/app/main.py b/fish_api/app/main.py
index f5043d0..bc98cc1 100644
--- a/fish_api/app/main.py
+++ b/fish_api/app/main.py
@@ -1,11 +1,13 @@
 from __future__ import annotations
 
+import asyncio
 from contextlib import asynccontextmanager
 
 from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 
 from app.routers import biomass, ingest
+from app.services.action_watch import run_action_watch_loop
 from app.settings import get_settings
 
 
@@ -14,7 +16,16 @@ async def lifespan(app: FastAPI):
     s = get_settings()
     s.media_root.mkdir(parents=True, exist_ok=True)
     s.stream_tmp_dir.mkdir(parents=True, exist_ok=True)
+    watch_task: asyncio.Task[None] | None = None
+    if s.action_watch_dir is not None:
+        watch_task = asyncio.create_task(run_action_watch_loop(s))
     yield
+    if watch_task is not None:
+        watch_task.cancel()
+        try:
+            await watch_task
+        except asyncio.CancelledError:
+            pass
 
 
 app = FastAPI(title="Fish API", lifespan=lifespan)
diff --git a/fish_api/app/services/action_watch.py b/fish_api/app/services/action_watch.py
new file mode 100644
index 0000000..272ec76
--- /dev/null
+++ b/fish_api/app/services/action_watch.py
@@ -0,0 +1,175 @@
+from __future__ import annotations
+
+import asyncio
+import json
+import traceback
+from pathlib import Path
+from typing import Any, Dict, Set
+
+from app.services import action as action_svc
+from app.settings import Settings
+from app.state import HealthSnapshot, app_state
+
+
+def _state_path(settings: Settings) -> Path:
+    """Return the configured state file, or the default dot-file inside the watch directory."""
+    if settings.action_watch_state_file is not None:
+        return settings.action_watch_state_file
+    assert settings.action_watch_dir is not None
+    return settings.action_watch_dir / ".fishaction_watch_processed.json"
+
+
+def load_processed(path: Path) -> Set[str]:
+    """Return the processed paths stored in *path*; empty set if missing or unreadable."""
+    if not path.is_file():
+        return set()
+    try:
+        with open(path, encoding="utf-8") as f:
+            data: Any = json.load(f)
+        if isinstance(data, list):
+            return set(str(x) for x in data)
+        if isinstance(data, dict) and "processed" in data:
+            return set(str(x) for x in data["processed"])
+    except (json.JSONDecodeError, OSError):
+        pass
+    return set()
+
+
+def save_processed(path: Path, processed: Set[str]) -> None:
+    """Write *processed* to *path* as a sorted JSON list via a temp file plus rename."""
+    path.parent.mkdir(parents=True, exist_ok=True)
+    tmp = path.with_suffix(path.suffix + ".tmp")
+    with open(tmp, "w", encoding="utf-8") as f:
+        json.dump(sorted(processed), f, indent=0, ensure_ascii=False)
+    tmp.replace(path)
+
+
+def iter_mp4(watch_dir: Path, recursive: bool) -> list[Path]:
+    """Return the sorted .mp4 files (case-insensitive suffix) in *watch_dir*, optionally recursing."""
+    if recursive:
+        return sorted(
+            p
+            for p in watch_dir.rglob("*")
+            if p.is_file() and p.suffix.lower() == ".mp4"
+        )
+    return sorted(
+        p
+        for p in watch_dir.iterdir()
+        if p.is_file() and p.suffix.lower() == ".mp4"
+    )
+
+
+async def _run_inference_and_state(
+    mp4: Path,
+    settings: Settings,
+    processed: Set[str],
+    state_file: Path,
+) -> None:
+    """Run inference on *mp4* under app_state.action_lock, update app_state, and re-raise on failure."""
+    key = str(mp4.resolve())
+    if key in processed:
+        return
+    print(f"[action-watch] inference: {mp4}", flush=True)
+    async with app_state.action_lock:
+        app_state.action_status = "running"
+        try:
+            snap = await asyncio.to_thread(action_svc.run_full_action, mp4, settings)
+            app_state.last_health = snap
+            app_state.action_status = "idle"
+            processed.add(key)
+            if settings.action_watch_use_state_file:
+                await asyncio.to_thread(save_processed, state_file, processed)
+            pred = (snap.raw_class_en or "").strip()
+            print(f"[action-watch] done: {mp4.name} -> {pred}", flush=True)
+        except Exception as e:
+            app_state.last_health = HealthSnapshot(
+                behavior_result="",
+                health_result="",
+                error=str(e),
+            )
+            app_state.action_status = "error"
+            print(f"[action-watch] error on {mp4}: {e}", flush=True)
+            traceback.print_exc()
+            raise
+
+
+async def watch_tick(
+    settings: Settings,
+    processed: Set[str],
+    stability: Dict[str, tuple[int, int]],
+    state_file: Path,
+) -> bool:
+    """Run one directory scan; return True if at least one file was processed."""
+    assert settings.action_watch_dir is not None
+    watch_dir = settings.action_watch_dir
+    did = False
+    seen_keys: Set[str] = set()
+    # Directory scans (rglob in particular) can be slow; keep them off the loop.
+    candidates = await asyncio.to_thread(
+        iter_mp4, watch_dir, settings.action_watch_recursive
+    )
+    for mp4 in candidates:
+        key = str(mp4.resolve())
+        seen_keys.add(key)
+        if key in processed:
+            continue
+        try:
+            st = mp4.stat()
+        except OSError:
+            continue
+        size = int(st.st_size)
+        if size <= 0:
+            stability.pop(key, None)
+            continue
+        last = stability.get(key)
+        if last is None or last[0] != size:
+            stability[key] = (size, 1)
+        else:
+            _, cnt = last
+            stability[key] = (size, cnt + 1)
+        _, cnt = stability[key]
+        if cnt >= settings.action_watch_stable_polls:
+            try:
+                await _run_inference_and_state(mp4, settings, processed, state_file)
+                stability.pop(key, None)
+                did = True
+            except Exception:
+                stability[key] = (size, 1)
+    for k in list(stability.keys()):
+        if k not in seen_keys:
+            del stability[k]
+    return did
+
+
+async def run_action_watch_loop(settings: Settings) -> None:
+    """Poll the watch directory until cancelled; return at once if it is not a directory."""
+    assert settings.action_watch_dir is not None
+    wd = settings.action_watch_dir
+    if not wd.is_dir():
+        print(f"[action-watch] skip: not a directory: {wd}", flush=True)
+        return
+
+    state_file = _state_path(settings)
+    processed: Set[str] = (
+        load_processed(state_file) if settings.action_watch_use_state_file else set()
+    )
+    stability: Dict[str, tuple[int, int]] = {}
+
+    print(
+        f"[action-watch] watching {wd} "
+        f"(poll={settings.action_watch_poll_interval}s, "
+        f"stable_polls={settings.action_watch_stable_polls}, "
+        f"state={'on' if settings.action_watch_use_state_file else 'off'} "
+        f"{state_file if settings.action_watch_use_state_file else ''})",
+        flush=True,
+    )
+
+    while True:
+        try:
+            await watch_tick(settings, processed, stability, state_file)
+        except Exception as e:
+            # A transient scan failure (directory unmounted, permission error)
+            # must not silently kill the background task; log and keep polling.
+            print(f"[action-watch] scan error: {e}", flush=True)
+            traceback.print_exc()
+        await asyncio.sleep(max(settings.action_watch_poll_interval, 0.1))
diff --git a/fish_api/app/settings.py b/fish_api/app/settings.py
index f0f34aa..5eab214 100644
--- a/fish_api/app/settings.py
+++ b/fish_api/app/settings.py
@@ -4,7 +4,7 @@ from functools import lru_cache
 from pathlib import Path
 from typing import Optional
 
-from pydantic import Field, model_validator
+from pydantic import Field, field_validator, model_validator
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
 
@@ -57,8 +57,31 @@ class Settings(BaseSettings):
     action_batch_size: int = 4
     action_num_workers: int = 2
 
+    #: When set, fish_api scans this directory in the background for new MP4s and runs FishAction (shares app_state with ingest)
+    action_watch_dir: Optional[Path] = None
+    action_watch_poll_interval: float = Field(default=2.0, ge=0.1)
+    action_watch_stable_polls: int = Field(default=3, ge=1)
+    action_watch_recursive: bool = False
+    #: Default: <action_watch_dir>/.fishaction_watch_processed.json
+    action_watch_state_file: Optional[Path] = None
+    action_watch_use_state_file: bool = True
+
     default_fish_species: str = "大黄鱼"
 
+    @field_validator(
+        "action_watch_dir",
+        "action_watch_state_file",
+        mode="before",
+    )
+    @classmethod
+    def _empty_str_path_none(cls, v: object) -> object:
+        """Map an empty or whitespace-only string to None; pass any other value through."""
+        if v is None:
+            return None
+        if isinstance(v, str) and not v.strip():
+            return None
+        return v
+
     @model_validator(mode="after")
     def _default_paths(self) -> "Settings":
         if not self.yolo_model:
diff --git a/fish_api/pyproject.toml b/fish_api/pyproject.toml
index 98d459d..f029ec3 100644
--- a/fish_api/pyproject.toml
+++ b/fish_api/pyproject.toml
@@ -12,3 +12,6 @@ dependencies = [
 
 [dependency-groups]
 dev = ["httpx>=0.28.1"]
+
+[project.scripts]
+fish-action-watch = "app.action_watch_cli:main"