Files
FishServer/fish_api/app/settings.py
2026-04-15 09:01:45 +08:00

309 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from functools import lru_cache
from pathlib import Path
from typing import Optional
from pydantic import AliasChoices, Field, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
def _fish_api_env_file() -> Path:
    """Return the path of ``fish_api/.env``.

    The path is derived from this module's own location, so it does not depend
    on the launch cwd; running uvicorn from the repository root still finds
    the ``.env`` file.
    """
    this_file = Path(__file__).resolve()
    # parents[0] is the directory holding this module; parents[1] is one level up.
    return this_file.parents[1].joinpath(".env")
def fish_repo_root() -> Path:
    """Return the repository root, located relative to this module's file.

    Layout assumed by the original comment:
    ``fish_api/app/settings.py`` -> ``parents[2]`` is the repo root (the
    directory that contains ``FishMeasure/`` and ``fish_api/``).
    """
    this_file = Path(__file__).resolve()
    return this_file.parents[2]
def models_dir() -> Path:
    """Return the unified in-repo weights directory (``<repo root>/models``).

    Holds model weights (YOLO / DGCNN / PointNet / X3D / SAM, etc.) and is
    decoupled from the FishMeasure code directory.
    """
    repo_root = fish_repo_root()
    return repo_root.joinpath("models")
def _default_stream_tmp() -> Path:
    """Return the default ingest directory: ``<repo root>/fish_api/.data/ingest``."""
    repo_root = fish_repo_root()
    return repo_root.joinpath("fish_api", ".data", "ingest")
def _default_media_root() -> Path:
    """Return the default media directory: ``<repo root>/fish_api/.data/media``."""
    repo_root = fish_repo_root()
    return repo_root.joinpath("fish_api", ".data", "media")
def _default_sqlite_path() -> Path:
    """Return the default SQLite file path: ``<repo root>/fish_api/.data/app.db``."""
    repo_root = fish_repo_root()
    return repo_root.joinpath("fish_api", ".data", "app.db")
def _default_action_output_root() -> Path:
    """Return the default action output directory: ``<repo root>/fish_api/.data/action_output``."""
    repo_root = fish_repo_root()
    return repo_root.joinpath("fish_api", ".data", "action_output")
def _default_measure_debug_log_dir() -> Path:
    """Return the default measure debug-log directory.

    Intended for debuggable text such as the DGCNN weight-estimation process
    (a calculation log matching the terminal output). Resolves to
    ``<repo root>/fish_api/.data/logs/measure``.
    """
    repo_root = fish_repo_root()
    return repo_root.joinpath("fish_api", ".data", "logs", "measure")
class Settings(BaseSettings):
    """Application settings for fish_api.

    Values are read by pydantic-settings from the environment and from the
    ``.env`` file returned by ``_fish_api_env_file()`` (UTF-8). Unknown keys
    are ignored (``extra="ignore"``).

    Optional path-like fields treat an empty / whitespace-only value as
    "not set" (see ``_empty_str_path_none``); ``action_checkpoint`` and
    ``measure_pointcloud_classifier`` then receive computed defaults in
    ``_default_paths``.
    """

    model_config = SettingsConfigDict(
        env_file=_fish_api_env_file(),
        env_file_encoding="utf-8",
        extra="ignore",
    )

    #: Externally reachable API base URL (no trailing `/`), used to build the
    #: absolute URLs of `video_left` / `video_right` in biomass (and similar)
    #: JSON. Environment variable: **PUBLIC_BASE_URL**
    public_base_url: str = Field(
        default="http://127.0.0.1:8000",
        validation_alias=AliasChoices("PUBLIC_BASE_URL", "public_base_url"),
    )
    ingest_api_key: str = ""
    stream_tmp_dir: Path = Field(default_factory=_default_stream_tmp)
    media_root: Path = Field(default_factory=_default_media_root)
    sqlite_path: Path = Field(default_factory=_default_sqlite_path)
    fish_measure_root: Path = fish_repo_root() / "FishMeasure"
    fish_action_root: Path = fish_repo_root() / "FishAction"
    #: FishMeasure inference output (lives under fish_api/.data together with
    #: SQLite and the media cache; the startup script keeps it by default,
    #: set CLEAR_MEASURE_OUTPUT=1 to clear it).
    measure_output_root: Path = fish_repo_root() / "fish_api" / ".data" / "measure_output"
    #: Directory for debug text such as the weight-estimation process
    #: (default fish_api/.data/logs/measure). **MEASURE_DEBUG_LOG_DIR**
    measure_debug_log_dir: Path = Field(
        default_factory=_default_measure_debug_log_dir,
        validation_alias=AliasChoices("MEASURE_DEBUG_LOG_DIR", "measure_debug_log_dir"),
    )
    #: When False, nothing is written to the directory above (logger output
    #: is still emitted). **MEASURE_DEBUG_LOG_WRITE**
    measure_debug_log_write: bool = Field(
        default=True,
        validation_alias=AliasChoices("MEASURE_DEBUG_LOG_WRITE", "measure_debug_log_write"),
    )
    #: Reserved directory on the FishAction side (symmetric with measure; the
    #: startup script keeps it by default, set CLEAR_ACTION_OUTPUT=1 to clear it).
    action_output_root: Path = Field(default_factory=_default_action_output_root)
    python_fish_measure: str = ""
    python_fish_action: str = ""
    #: SAM/CUDA device: cuda or cpu.
    sam_device: str = "cuda"
    #: When True, show large weight/length labels in the top-right corner of
    #: the video (10x font).
    predict_show_large_labels_at_top_right: bool = False
    #: For the FishMeasure YOLO confidence see ``measure_yolo_conf`` /
    #: ``MEASURE_YOLO_CONF``; other in-script parameters can still be changed
    #: in the FishMeasure directory.
    #: For FishAction core parameters see ``action_checkpoint`` etc.
    #: FishAction X3D model path (when unset, models/action_x3d/checkpoint_best.pt is used).
    action_checkpoint: Optional[str] = None
    #: When True, reuse existing cloud/*.ply (passes --reuse-existing-clouds).
    #: When False, force regeneration of point clouds (passes --no-reuse-existing-clouds).
    measure_reuse_existing_clouds: bool = True
    #: YOLO detection confidence, passed to ``predict_weigth_from_svo2.py --conf``
    #: (aligned with the 0.8 used by FishMeasure ``run_predict_from_svo2_fish9.sh``
    #: etc.). **MEASURE_YOLO_CONF**
    measure_yolo_conf: float = Field(
        default=0.8,
        ge=0.0,
        le=1.0,
        validation_alias=AliasChoices("MEASURE_YOLO_CONF", "measure_yolo_conf"),
    )
    #: Passed to FishMeasure ``--filter-pointcloud`` (enabled by default,
    #: aligned with the fish9 script).
    measure_filter_pointcloud: bool = Field(
        default=True,
        validation_alias=AliasChoices(
            "MEASURE_FILTER_POINTCLOUD", "measure_filter_pointcloud"
        ),
    )
    #: Passed to FishMeasure ``--use-density-filter`` (enabled by default,
    #: aligned with the fish9 script).
    measure_use_density_filter: bool = Field(
        default=True,
        validation_alias=AliasChoices(
            "MEASURE_USE_DENSITY_FILTER", "measure_use_density_filter"
        ),
    )
    #: Passed to FishMeasure ``--use-pointcloud-classifier`` (enabled by
    #: default, aligned with the fish9 script).
    measure_use_pointcloud_classifier: bool = Field(
        default=True,
        validation_alias=AliasChoices(
            "MEASURE_USE_POINTCLOUD_CLASSIFIER", "measure_use_pointcloud_classifier"
        ),
    )
    #: PointNet2 point-cloud classifier threshold, passed to
    #: ``--pointcloud-classifier-threshold``.
    measure_pointcloud_classifier_threshold: float = Field(
        default=0.7,
        ge=0.0,
        le=1.0,
        validation_alias=AliasChoices(
            "MEASURE_POINTCLOUD_CLASSIFIER_THRESHOLD",
            "measure_pointcloud_classifier_threshold",
        ),
    )
    #: Point-cloud classifier model path, passed to ``--pointcloud-classifier``.
    measure_pointcloud_classifier: Optional[Path] = Field(
        default=None,
        validation_alias=AliasChoices(
            "MEASURE_POINTCLOUD_CLASSIFIER", "measure_pointcloud_classifier"
        ),
    )
    #: Passed to FishMeasure ``--use-flatness-filter`` (enabled by default,
    #: aligned with the fish9 script).
    measure_use_flatness_filter: bool = Field(
        default=True,
        validation_alias=AliasChoices(
            "MEASURE_USE_FLATNESS_FILTER", "measure_use_flatness_filter"
        ),
    )
    #: Flatness threshold, passed to ``--flatness-threshold``.
    measure_flatness_threshold: float = Field(
        default=55.0,
        validation_alias=AliasChoices(
            "MEASURE_FLATNESS_THRESHOLD", "measure_flatness_threshold"
        ),
    )

    # ── Weight aggregation rules (passed to predict_weigth_from_svo2.py →
    #    test_dgcnn_weight_estimator.py) ──
    #: DGCNN top-K frame count, passed to ``--weight-top-k``. **MEASURE_WEIGHT_TOP_K**
    measure_weight_top_k: int = Field(
        default=5,
        ge=1,
        validation_alias=AliasChoices("MEASURE_WEIGHT_TOP_K", "measure_weight_top_k"),
    )
    #: Select the top-K by length, passed to ``--weight-top-by-length``.
    #: **MEASURE_WEIGHT_TOP_BY_LENGTH**
    measure_weight_top_by_length: bool = Field(
        default=True,
        validation_alias=AliasChoices(
            "MEASURE_WEIGHT_TOP_BY_LENGTH", "measure_weight_top_by_length"
        ),
    )
    #: When selecting top-K by length, if the mean length of the K frames is
    #: greater than this value, switch to selecting by weight; passed to
    #: ``--weight-length-switch-mm``. **MEASURE_WEIGHT_LENGTH_SWITCH_MM**
    measure_weight_length_switch_mm: float = Field(
        default=319.0,
        validation_alias=AliasChoices(
            "MEASURE_WEIGHT_LENGTH_SWITCH_MM", "measure_weight_length_switch_mm"
        ),
    )
    #: Geometric filter: frames with length > this value are excluded; passed
    #: to ``--weight-max-length-mm`` (0 disables). **MEASURE_WEIGHT_MAX_LENGTH_MM**
    measure_weight_max_length_mm: float = Field(
        default=400.0,
        validation_alias=AliasChoices(
            "MEASURE_WEIGHT_MAX_LENGTH_MM", "measure_weight_max_length_mm"
        ),
    )
    #: Geometric filter: frames whose PCA length/width ratio is < this value
    #: are excluded; passed to ``--weight-min-length-width-ratio`` (0 disables).
    #: **MEASURE_WEIGHT_MIN_LENGTH_WIDTH_RATIO**
    measure_weight_min_length_width_ratio: float = Field(
        default=1.5,
        validation_alias=AliasChoices(
            "MEASURE_WEIGHT_MIN_LENGTH_WIDTH_RATIO", "measure_weight_min_length_width_ratio"
        ),
    )
    #: Whole-pool mean mode, passed to ``--weight-average-all-after-filter``.
    #: **MEASURE_WEIGHT_AVERAGE_ALL_AFTER_FILTER**
    measure_weight_average_all_after_filter: bool = Field(
        default=False,
        validation_alias=AliasChoices(
            "MEASURE_WEIGHT_AVERAGE_ALL_AFTER_FILTER", "measure_weight_average_all_after_filter"
        ),
    )
    #: When the whole-pool mean is > this value, use max instead (rule A);
    #: passed to ``--weight-average-all-fallback-max-if-mean-over-g``
    #: (0 disables). **MEASURE_WEIGHT_AVG_ALL_FALLBACK_MAX_G**
    measure_weight_avg_all_fallback_max_g: float = Field(
        default=400.0,
        validation_alias=AliasChoices(
            "MEASURE_WEIGHT_AVG_ALL_FALLBACK_MAX_G", "measure_weight_avg_all_fallback_max_g"
        ),
    )
    #: When the mean of the whole-pool candidates is > this value, use max
    #: instead (rule B, 440 g guard); passed to
    #: ``--weight-mean-pool-fallback-max-if-over-g`` (0 disables).
    #: **MEASURE_WEIGHT_MEAN_POOL_FALLBACK_MAX_G**
    measure_weight_mean_pool_fallback_max_g: float = Field(
        default=440.0,
        validation_alias=AliasChoices(
            "MEASURE_WEIGHT_MEAN_POOL_FALLBACK_MAX_G", "measure_weight_mean_pool_fallback_max_g"
        ),
    )
    #: Outlier-removal switch, passed to ``--weight-remove-outliers``.
    #: **MEASURE_WEIGHT_REMOVE_OUTLIERS**
    measure_weight_remove_outliers: bool = Field(
        default=False,
        validation_alias=AliasChoices(
            "MEASURE_WEIGHT_REMOVE_OUTLIERS", "measure_weight_remove_outliers"
        ),
    )
    #: Outlier-removal method: iqr / zscore; passed to
    #: ``--weight-outlier-method``. **MEASURE_WEIGHT_OUTLIER_METHOD**
    measure_weight_outlier_method: str = Field(
        default="iqr",
        validation_alias=AliasChoices(
            "MEASURE_WEIGHT_OUTLIER_METHOD", "measure_weight_outlier_method"
        ),
    )

    #: When set, fish_api continuously scans this directory in the background
    #: for new MP4 files and runs FishAction (shares the latest SQLite results
    #: with ingest).
    action_watch_dir: Optional[Path] = None
    action_watch_poll_interval: float = Field(default=2.0, ge=0.1)
    action_watch_stable_polls: int = Field(default=3, ge=1)
    action_watch_recursive: bool = False
    #: State management: true = persist to SQLite (remembered across
    #: restarts); false = in-memory mode (cleared on restart).
    action_watch_use_state_file: bool = True
    #: Preferred source file for the "water-surface video"; when unset, the
    #: newest .mp4 in ACTION_WATCH_DIR is used (the FishAction input).
    #: **BIOMASS_WATER_VIDEO_SOURCE**
    biomass_water_video_source: Optional[Path] = None
    #: H.264 file name published under MEDIA_ROOT. **BIOMASS_WATER_VIDEO_MEDIA_NAME**
    biomass_water_video_media_name: str = "biomass_water_surface.mp4"
    #: Preferred source file for the "sonar video"; when unset, the newest
    #: .mp4 in BIOMASS_SONAR_VIDEO_DIR is used. **BIOMASS_SONAR_VIDEO_SOURCE**
    biomass_sonar_video_source: Optional[Path] = None
    #: Sonar MP4 directory (independent of ACTION_WATCH_DIR, to avoid mixing
    #: with water-surface videos). **BIOMASS_SONAR_VIDEO_DIR**
    biomass_sonar_video_dir: Optional[Path] = None
    #: Whether to search SONAR_VIDEO_DIR recursively for .mp4 files.
    #: **BIOMASS_SONAR_VIDEO_RECURSIVE**
    biomass_sonar_video_recursive: bool = False
    #: H.264 file name published under MEDIA_ROOT. **BIOMASS_SONAR_VIDEO_MEDIA_NAME**
    biomass_sonar_video_media_name: str = "biomass_sonar.mp4"
    #: When set, continuously scans this directory in the background for new
    #: .svo2 files and runs FishMeasure (shares the latest SQLite results
    #: with ingest).
    measure_watch_dir: Optional[Path] = None
    measure_watch_poll_interval: float = Field(default=2.0, ge=0.1)
    measure_watch_stable_polls: int = Field(default=3, ge=1)
    measure_watch_recursive: bool = False
    #: State management: true = persist to SQLite (remembered across
    #: restarts); false = in-memory mode (cleared on restart).
    measure_watch_use_state_file: bool = True
    default_fish_species: str = "大黄鱼"

    @field_validator(
        "action_checkpoint",
        "action_watch_dir",
        "biomass_water_video_source",
        "biomass_sonar_video_source",
        "biomass_sonar_video_dir",
        "measure_pointcloud_classifier",
        "measure_watch_dir",
        mode="before",
    )
    @classmethod
    def _empty_str_path_none(cls, v: object) -> object:
        """Map ``None`` and empty / whitespace-only strings to ``None``.

        Runs before type coercion on every optional path-like field, so a
        blank entry such as ``MEASURE_WATCH_DIR=`` means "not set".
        ``action_checkpoint`` and ``measure_pointcloud_classifier`` are
        included so that a blank value still reaches the fallbacks in
        ``_default_paths``; an empty string turned into a ``Path`` would
        otherwise be ``.`` and skip the ``is None`` check there.

        Any other value is returned unchanged for normal validation.
        """
        if v is None:
            return None
        if isinstance(v, str) and not v.strip():
            return None
        return v

    @model_validator(mode="after")
    def _default_paths(self) -> "Settings":
        """Fill in computed defaults for unset model paths and return ``self``.

        * ``action_checkpoint`` -> ``models_dir()/action_x3d/checkpoint_best.pt``
          (stored as ``str``) when falsy.
        * ``measure_pointcloud_classifier`` -> the PointNet2 ``best_model.pth``
          under ``fish_measure_root`` when ``None``.
        """
        md = models_dir()
        # Values are stored with object.__setattr__, i.e. without going through
        # the model's own __setattr__.
        if not self.action_checkpoint:
            object.__setattr__(
                self, "action_checkpoint", str(md / "action_x3d" / "checkpoint_best.pt")
            )
        if self.measure_pointcloud_classifier is None:
            object.__setattr__(
                self,
                "measure_pointcloud_classifier",
                self.fish_measure_root
                / "pointcloud_classifier"
                / "Pointnet_Pointnet2_pytorch"
                / "log"
                / "classification"
                / "fish_pointnet2_finetune"
                / "checkpoints"
                / "best_model.pth",
            )
        return self
@lru_cache
def get_settings() -> Settings:
    """Return the cached ``Settings`` instance.

    The function takes no arguments and is wrapped in ``lru_cache``, so
    ``Settings()`` is constructed on the first call and the same object is
    returned on later calls.
    """
    settings = Settings()
    return settings