from __future__ import annotations from functools import lru_cache from pathlib import Path from typing import Optional from pydantic import Field, field_validator, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict def fish_repo_root() -> Path: # fish_api/app/settings.py -> parent[2] = repo root (contains FishMeasure/, fish_api/) return Path(__file__).resolve().parents[2] def _default_stream_tmp() -> Path: return fish_repo_root() / "fish_api" / ".data" / "ingest" def _default_media_root() -> Path: return fish_repo_root() / "fish_api" / ".data" / "media" def _default_sqlite_path() -> Path: return fish_repo_root() / "fish_api" / ".data" / "app.db" class Settings(BaseSettings): model_config = SettingsConfigDict( env_file=".env", env_file_encoding="utf-8", extra="ignore", ) public_base_url: str = "http://127.0.0.1:8000" ingest_api_key: str = "" stream_tmp_dir: Path = Field(default_factory=_default_stream_tmp) media_root: Path = Field(default_factory=_default_media_root) sqlite_path: Path = Field(default_factory=_default_sqlite_path) fish_measure_root: Path = fish_repo_root() / "FishMeasure" fish_action_root: Path = fish_repo_root() / "FishAction" measure_output_root: Path = fish_repo_root() / "FishMeasure" / "output_weight_estimator" python_fish_measure: str = "" python_fish_action: str = "" yolo_model: Optional[str] = None weight_checkpoint: Optional[str] = None sam_device: str = "cuda" predict_conf: float = 0.5 predict_imgsz: int = 640 predict_max_frames: int = 0 predict_frame_stride: int = 1 #: 传给 predict_weigth_from_svo2.py 的点云/权重选项(与命令行一致,可用 .env 覆盖) predict_filter_pointcloud: bool = True predict_use_density_filter: bool = True predict_use_clustering_filter: bool = False #: 留空则在 _default_paths 中设为 FishMeasure 下默认 PointNet++ 权重(若文件存在) predict_pointcloud_classifier: Optional[str] = None predict_use_pointcloud_classifier: bool = True predict_pointcloud_classifier_threshold: float = 0.7 predict_use_flatness_filter: bool = True predict_flatness_threshold: float = 55.0 measure_weight_top_k: int = 5 measure_weight_top_by_length: bool = True #: 为 True 时 fish_video 内联 DGCNN + 预览叠加(更重;需 fish_video 已支持) predict_fish_video_weight_overlay: bool = False predict_minute_interval_sec: float = 60.0 action_checkpoint: Optional[str] = None action_clips_per_video: int = 8 action_batch_size: int = 4 action_num_workers: int = 2 #: 非空时由 fish_api 在后台持续扫描该目录中的新 MP4 并跑 FishAction(与 ingest 共用 SQLite 最新结果) action_watch_dir: Optional[Path] = None action_watch_poll_interval: float = Field(default=2.0, ge=0.1) action_watch_stable_polls: int = Field(default=3, ge=1) action_watch_recursive: bool = False #: 默认:/.fishaction_watch_processed.json action_watch_state_file: Optional[Path] = None action_watch_use_state_file: bool = True #: 非空时后台持续扫描该目录中的新 .svo2 并跑 FishMeasure(与 ingest 共用 SQLite 最新结果) measure_watch_dir: Optional[Path] = None measure_watch_poll_interval: float = Field(default=2.0, ge=0.1) measure_watch_stable_polls: int = Field(default=3, ge=1) measure_watch_recursive: bool = False measure_watch_state_file: Optional[Path] = None measure_watch_use_state_file: bool = True default_fish_species: str = "大黄鱼" @field_validator( "action_watch_dir", "action_watch_state_file", "measure_watch_dir", "measure_watch_state_file", mode="before", ) @classmethod def _empty_str_path_none(cls, v: object) -> object: if v is None: return None if isinstance(v, str) and not v.strip(): return None return v @model_validator(mode="after") def _default_paths(self) -> "Settings": if not self.yolo_model: object.__setattr__( self, "yolo_model", str( self.fish_measure_root / "runs/train/fish_detection_20251127_104658/weights/best.pt" ), ) if not self.weight_checkpoint: object.__setattr__( self, "weight_checkpoint", str( self.fish_measure_root / "weight_estimator/runs/dgcnn_20260312_171043/best.pt" ), ) if not self.action_checkpoint: object.__setattr__( self, "action_checkpoint", str(self.fish_action_root / "checkpoints/ptv_x3d_m/checkpoint_best.pt"), ) if not self.predict_pointcloud_classifier: _pc = ( self.fish_measure_root / "pointcloud_classifier" / "Pointnet_Pointnet2_pytorch" / "log" / "classification" / "fish_pointnet2_finetune" / "checkpoints" / "best_model.pth" ) if _pc.is_file(): object.__setattr__(self, "predict_pointcloud_classifier", str(_pc)) return self @lru_cache def get_settings() -> Settings: return Settings()