Files
FishServer/fish_api/app/settings.py

160 lines
5.5 KiB
Python
Raw Normal View History

from __future__ import annotations
from functools import lru_cache
from pathlib import Path
from typing import Optional
from pydantic import Field, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
def fish_repo_root() -> Path:
    """Return the repository root directory.

    This module lives at ``<repo>/fish_api/app/settings.py``; the root is the
    directory that contains ``FishMeasure/`` and ``fish_api/``.
    """
    this_file = Path(__file__).resolve()
    # parents[0] -> app/, parents[1] -> fish_api/, parents[2] -> repo root.
    return this_file.parents[2]
def _default_stream_tmp() -> Path:
    """Default for ``Settings.stream_tmp_dir``: ``<repo>/fish_api/.data/ingest``."""
    return fish_repo_root().joinpath("fish_api", ".data", "ingest")
def _default_media_root() -> Path:
    """Default for ``Settings.media_root``: ``<repo>/fish_api/.data/media``."""
    return fish_repo_root().joinpath("fish_api", ".data", "media")
def _default_sqlite_path() -> Path:
    """Default for ``Settings.sqlite_path``: ``<repo>/fish_api/.data/app.db``."""
    return fish_repo_root().joinpath("fish_api", ".data", "app.db")
class Settings(BaseSettings):
    """Runtime configuration for ``fish_api``.

    Loaded by pydantic-settings from environment variables and a ``.env``
    file (UTF-8); unknown keys are ignored (``extra="ignore"``). ``env_file``
    is a relative path, so the ``.env`` file is looked up relative to the
    process's current working directory, not relative to this module.

    Output/model/checkpoint paths that are left unset are filled in by
    ``_default_paths`` after field validation.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    public_base_url: str = "http://127.0.0.1:8000"
    ingest_api_key: str = ""

    # Local data locations; defaults are under <repo>/fish_api/.data/.
    stream_tmp_dir: Path = Field(default_factory=_default_stream_tmp)
    media_root: Path = Field(default_factory=_default_media_root)
    sqlite_path: Path = Field(default_factory=_default_sqlite_path)

    # Sibling project checkouts (class-level defaults, evaluated once at import).
    fish_measure_root: Path = fish_repo_root() / "FishMeasure"
    fish_action_root: Path = fish_repo_root() / "FishAction"
    #: When not configured explicitly, ``_default_paths`` re-derives this as
    #: ``fish_measure_root / "output_weight_estimator"`` so it follows an
    #: overridden ``fish_measure_root``.
    measure_output_root: Path = fish_repo_root() / "FishMeasure" / "output_weight_estimator"
    python_fish_measure: str = ""
    python_fish_action: str = ""

    # FishMeasure prediction. Unset model paths are filled by _default_paths.
    yolo_model: Optional[str] = None
    weight_checkpoint: Optional[str] = None
    sam_device: str = "cuda"
    predict_conf: float = 0.5
    predict_imgsz: int = 640
    predict_max_frames: int = 0
    predict_frame_stride: int = 1
    #: Point-cloud / weight options forwarded to predict_weigth_from_svo2.py
    #: (same meaning as its command-line flags; overridable via .env).
    predict_filter_pointcloud: bool = True
    predict_use_density_filter: bool = True
    predict_use_clustering_filter: bool = False
    #: If left empty, _default_paths sets this to the default PointNet++
    #: checkpoint under FishMeasure (only if that file exists).
    predict_pointcloud_classifier: Optional[str] = None
    predict_use_pointcloud_classifier: bool = True
    predict_pointcloud_classifier_threshold: float = 0.7
    predict_use_flatness_filter: bool = True
    predict_flatness_threshold: float = 55.0
    measure_weight_top_k: int = 5
    measure_weight_top_by_length: bool = True
    #: When True, fish_video runs DGCNN inline and draws a preview overlay
    #: (heavier; requires fish_video to support it).
    predict_fish_video_weight_overlay: bool = False
    predict_minute_interval_sec: float = 60.0

    # FishAction. An unset checkpoint is filled by _default_paths.
    action_checkpoint: Optional[str] = None
    action_clips_per_video: int = 8
    action_batch_size: int = 4
    action_num_workers: int = 2

    #: When set, fish_api keeps scanning this directory in the background for
    #: new MP4 files and runs FishAction on them (latest results share the
    #: SQLite database with ingest).
    action_watch_dir: Optional[Path] = None
    action_watch_poll_interval: float = Field(default=2.0, ge=0.1)
    action_watch_stable_polls: int = Field(default=3, ge=1)
    action_watch_recursive: bool = False
    #: Default: <action_watch_dir>/.fishaction_watch_processed.json
    #: (this class leaves it None; presumably the watcher applies that
    #: default — TODO confirm).
    action_watch_state_file: Optional[Path] = None
    action_watch_use_state_file: bool = True

    #: When set, keeps scanning this directory in the background for new
    #: .svo2 files and runs FishMeasure on them (latest results share the
    #: SQLite database with ingest).
    measure_watch_dir: Optional[Path] = None
    measure_watch_poll_interval: float = Field(default=2.0, ge=0.1)
    measure_watch_stable_polls: int = Field(default=3, ge=1)
    measure_watch_recursive: bool = False
    measure_watch_state_file: Optional[Path] = None
    measure_watch_use_state_file: bool = True

    default_fish_species: str = "大黄鱼"

    @field_validator(
        "action_watch_dir",
        "action_watch_state_file",
        "measure_watch_dir",
        "measure_watch_state_file",
        mode="before",
    )
    @classmethod
    def _empty_str_path_none(cls, v: object) -> object:
        """Map blank strings (e.g. ``ACTION_WATCH_DIR=`` in .env) to None.

        Runs before type coercion, so an empty or whitespace-only value
        disables the feature instead of reaching Path validation
        (``Path("")`` is ``Path(".")``, i.e. the current directory).
        """
        if v is None or (isinstance(v, str) and not v.strip()):
            return None
        return v

    @model_validator(mode="after")
    def _default_paths(self) -> "Settings":
        """Fill in output/model/checkpoint paths that were left unset.

        Values are written with ``object.__setattr__``, bypassing pydantic's
        ``BaseModel.__setattr__``.

        * ``measure_output_root``: if no settings source supplied it, it is
          re-derived from ``fish_measure_root``. The result is identical to the
          class default unless ``fish_measure_root`` was overridden.
        * ``yolo_model``, ``weight_checkpoint``, ``action_checkpoint``: if
          empty/None, set to a bundled checkpoint path. The file's existence
          is not checked.
        * ``predict_pointcloud_classifier``: if empty/None, set to the bundled
          PointNet++ checkpoint only when that file exists.

        Returns:
            ``self``, as required for an ``after`` model validator.
        """
        measure_root = self.fish_measure_root

        # model_fields_set holds fields supplied via init/env/.env. Anything
        # else still carries the class default, which is pinned to the in-repo
        # FishMeasure rather than to the configured fish_measure_root.
        if "measure_output_root" not in self.model_fields_set:
            object.__setattr__(
                self,
                "measure_output_root",
                measure_root / "output_weight_estimator",
            )

        bundled_checkpoints = {
            "yolo_model": measure_root
            / "runs/train/fish_detection_20251127_104658/weights/best.pt",
            "weight_checkpoint": measure_root
            / "weight_estimator/runs/dgcnn_20260312_171043/best.pt",
            "action_checkpoint": self.fish_action_root
            / "checkpoints/ptv_x3d_m/checkpoint_best.pt",
        }
        for field_name, checkpoint in bundled_checkpoints.items():
            if not getattr(self, field_name):
                object.__setattr__(self, field_name, str(checkpoint))

        if not self.predict_pointcloud_classifier:
            pointnet_ckpt = measure_root.joinpath(
                "pointcloud_classifier",
                "Pointnet_Pointnet2_pytorch",
                "log",
                "classification",
                "fish_pointnet2_finetune",
                "checkpoints",
                "best_model.pth",
            )
            # Optional feature: only default to the classifier if it is present.
            if pointnet_ckpt.is_file():
                object.__setattr__(
                    self, "predict_pointcloud_classifier", str(pointnet_ckpt)
                )
        return self
@lru_cache(maxsize=128)
def get_settings() -> Settings:
    """Return the process-wide ``Settings`` instance.

    The first call constructs ``Settings()``. Because the function takes no
    arguments, every later call returns that same cached object;
    ``get_settings.cache_clear()`` forces the next call to rebuild it.
    """
    settings = Settings()
    return settings