From 5e1b2117c1c7ba6fa44ebd2efe6248a807bf3640 Mon Sep 17 00:00:00 2001 From: zaiun xu Date: Thu, 9 Apr 2026 11:54:30 +0800 Subject: [PATCH] =?UTF-8?q?feat(fish=5Fapi):=20SQLite=20=E5=BF=AB=E7=85=A7?= =?UTF-8?q?=E6=8A=95=E9=80=92=E3=80=81=E6=97=A5=E5=BF=97=E4=B8=8E=20watch?= =?UTF-8?q?=20=E7=A9=BA=E9=97=B2=E5=91=8A=E8=AD=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 SQLite:measure/health 快照、delivery_cursor 单消费者 pop;clear/start_fresh 可清空库 - biomass GET 仅返回约定 data 字段,X-Fish-Biomass-New 表示是否有新快照;poller 读响应头 - loguru 桥接 uvicorn,子进程 stdout 流式输出;format_json_pretty 与算法摘要日志 - measure/action watch 无新任务时限流 WARNING;watch_idle 共用逻辑 - 依赖 loguru;新增 db、logging_config、subprocess_run、watch_idle、启动脚本 FishMeasure: 更新 fish_video_weight_evaluation 与 predict_weigth_from_svo2;移除未用 refbox/segmentation 脚本 Made-with: Cursor --- .../detect_refbox/prepare_refbox_dataset.py | 317 ----------- .../detect_refbox/train_refbox_yolo.py | 162 ------ FishMeasure/fish_video_weight_evaluation.py | 452 ++++++++++++---- FishMeasure/predict_weigth_from_svo2.py | 55 ++ FishMeasure/segmentation/README.md | 99 ---- FishMeasure/segmentation/__init__.py | 6 - .../segmentation/prepare_yolo_seg_dataset.py | 490 ------------------ FishMeasure/segmentation/train_yolo_seg.py | 151 ------ .../segmentation/visualize_yolo_seg_labels.py | 215 -------- fish_api/app/action_watch_cli.py | 2 + fish_api/app/db.py | 421 +++++++++++++++ fish_api/app/logging_config.py | 82 +++ fish_api/app/main.py | 6 + fish_api/app/routers/biomass.py | 126 +++-- fish_api/app/routers/ingest.py | 31 +- fish_api/app/services/action.py | 23 +- fish_api/app/services/action_watch.py | 90 ++-- fish_api/app/services/measure.py | 79 ++- fish_api/app/services/measure_watch.py | 69 ++- fish_api/app/settings.py | 38 +- fish_api/app/state.py | 16 +- fish_api/app/subprocess_run.py | 41 ++ fish_api/app/watch_idle.py | 56 ++ fish_api/pyproject.toml | 3 +- fish_api/start.sh | 24 + fish_api/start_fresh.sh | 69 +++ scripts/biomass_poller.py | 28 +- scripts/run_fishserver.sh | 17 +- scripts/start_fishapi_fresh.sh | 10 + 29 files changed, 1464 insertions(+), 1714 deletions(-) delete mode 100755 FishMeasure/detect_refbox/prepare_refbox_dataset.py delete mode 100755 FishMeasure/detect_refbox/train_refbox_yolo.py delete mode 100755 FishMeasure/segmentation/README.md delete mode 100755 FishMeasure/segmentation/__init__.py delete mode 100755 FishMeasure/segmentation/prepare_yolo_seg_dataset.py delete mode 100755 FishMeasure/segmentation/train_yolo_seg.py delete mode 100755 FishMeasure/segmentation/visualize_yolo_seg_labels.py create mode 100644 fish_api/app/db.py create mode 100644 fish_api/app/logging_config.py create mode 100644 fish_api/app/subprocess_run.py create mode 100644 fish_api/app/watch_idle.py create mode 100644 fish_api/start.sh create mode 100755 fish_api/start_fresh.sh create mode 100755 scripts/start_fishapi_fresh.sh diff --git a/FishMeasure/detect_refbox/prepare_refbox_dataset.py b/FishMeasure/detect_refbox/prepare_refbox_dataset.py deleted file mode 100755 index 49c4779..0000000 --- a/FishMeasure/detect_refbox/prepare_refbox_dataset.py +++ /dev/null @@ -1,317 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -从递归目录中的 LabelMe 风格 JSON 构建 YOLO 检测数据集(仅参考物 ref)。 - -- 仅使用同时满足:存在对应图像、JSON 内至少有一个可转换的 ref 矩形框 的样本。 -- 图像与 JSON 同目录,或通过 imagePath 解析;若仅有 imageData 则解码写出。 -- 输出扁平唯一文件名(相对路径转 __),避免不同子目录同名帧冲突。 -""" - -from __future__ import annotations - -import argparse -import base64 -import io -import json -import random -import shutil -import sys -from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple - -IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} -# 归一到单一类别 ref(class 0) -REF_LABELS: Set[str] = {"ref", "reference", "refbox", "参考", "参考物"} - - -def parse_args() -> argparse.Namespace: - p = argparse.ArgumentParser(description="准备 ref 框 YOLO 数据集") - p.add_argument( - "--source", - type=Path, - default=Path("/home/ubuntu/data/fish/2016-1-22-last_images"), - help="含递归子目录与 JSON 的根目录", - ) - p.add_argument( - "--out", - type=Path, - default=None, - help="输出数据集根目录(默认:本仓库 detect_refbox/dataset)", - ) - p.add_argument("--val-ratio", type=float, default=0.2, help="验证集比例") - p.add_argument("--seed", type=int, default=42, help="划分随机种子") - p.add_argument( - "--copy-images", - action="store_true", - help="复制图像;默认硬链接(同盘失败时回退复制)", - ) - return p.parse_args() - - -def repo_root() -> Path: - return Path(__file__).resolve().parents[1] - - -def norm_bbox_yolo( - x1: float, y1: float, x2: float, y2: float, w: int, h: int -) -> Tuple[float, float, float, float]: - x_min, y_min = min(x1, x2), min(y1, y2) - x_max, y_max = max(x1, x2), max(y1, y2) - bw = max(0.0, x_max - x_min) - bh = max(0.0, y_max - y_min) - cx = x_min + bw / 2.0 - cy = y_min + bh / 2.0 - if w <= 0 or h <= 0: - raise ValueError("invalid image size") - return cx / w, cy / h, bw / w, bh / h - - -def load_json(path: Path) -> Optional[Dict[str, Any]]: - for enc in ("utf-8", "gbk", "gb2312", "latin-1"): - try: - with open(path, "r", encoding=enc) as f: - return json.load(f) - except (UnicodeDecodeError, json.JSONDecodeError): - continue - return None - - -def find_image_path(json_path: Path, data: Dict[str, Any]) -> Optional[Path]: - ip = (data.get("imagePath") or "").strip() - if ip: - cand = (json_path.parent / ip).resolve() - if cand.exists() and cand.suffix.lower() in IMG_EXTS: - return cand - name = Path(ip).name - for ext in IMG_EXTS: - c2 = json_path.parent / f"{Path(name).stem}{ext}" - if c2.exists(): - return c2 - stem = json_path.stem - for ext in IMG_EXTS: - c = json_path.parent / f"{stem}{ext}" - if c.exists(): - return c - return None - - -def image_from_image_data(data: Dict[str, Any]) -> Optional[bytes]: - raw = data.get("imageData") - if not raw or not isinstance(raw, str): - return None - try: - return base64.b64decode(raw) - except Exception: - return None - - -def shapes_to_yolo_lines( - data: Dict[str, Any], img_w: int, img_h: int -) -> List[str]: - shapes = data.get("shapes") or data.get("annotations") or [] - if not isinstance(shapes, list): - return [] - lines: List[str] = [] - for shp in shapes: - st = shp.get("shape_type") - if st not in (None, "rectangle", "bbox", "box"): - continue - label = str(shp.get("label", "")).strip() - if label not in REF_LABELS: - continue - pts = shp.get("points") - if not pts or len(pts) < 2: - continue - (x1, y1), (x2, y2) = pts[0], pts[1] - try: - xc, yc, bw, bh = norm_bbox_yolo( - float(x1), float(y1), float(x2), float(y2), img_w, img_h - ) - except Exception: - continue - if bw <= 0 or bh <= 0: - continue - lines.append(f"0 {xc:.6f} {yc:.6f} {bw:.6f} {bh:.6f}") - return lines - - -def ensure_image_size( - data: Dict[str, Any], img_path: Optional[Path] -) -> Tuple[int, int]: - w = int(data.get("imageWidth", 0) or 0) - h = int(data.get("imageHeight", 0) or 0) - if w > 0 and h > 0: - return w, h - if img_path and img_path.exists(): - try: - from PIL import Image - - with Image.open(img_path) as im: - return im.size - except Exception: - pass - blob = image_from_image_data(data) - if blob: - try: - from PIL import Image - - with Image.open(io.BytesIO(blob)) as im: - return im.size - except Exception: - pass - return 0, 0 - - -def unique_stem(json_path: Path, source_root: Path) -> str: - rel = json_path.parent.relative_to(source_root) - prefix = rel.as_posix().replace("/", "__") - return f"{prefix}__{json_path.stem}" - - -def write_dataset_yaml(out_root: Path) -> Path: - yaml_path = out_root / "refbox.yaml" - # Ultralytics:path 为数据集根;train/val 为相对 path 的图像目录 - text = ( - f"path: {out_root.resolve()}\n" - "train: images/train\n" - "val: images/val\n" - "nc: 1\n" - "names:\n" - " 0: ref\n" - ) - yaml_path.write_text(text, encoding="utf-8") - return yaml_path - - -def main() -> int: - args = parse_args() - source = args.source.expanduser().resolve() - if not source.is_dir(): - print(f"[错误] 数据目录不存在: {source}", file=sys.stderr) - return 1 - - out_root = ( - args.out.expanduser().resolve() - if args.out - else repo_root() / "detect_refbox" / "dataset" - ) - img_train = out_root / "images" / "train" - img_val = out_root / "images" / "val" - lbl_train = out_root / "labels" / "train" - lbl_val = out_root / "labels" / "val" - for d in (img_train, img_val, lbl_train, lbl_val): - d.mkdir(parents=True, exist_ok=True) - - records: List[Tuple[Path, Path, List[str], str]] = [] - # (json_path, src_image_path, yolo_lines, unique_stem) - - json_files = sorted(source.rglob("*.json")) - skipped = 0 - for jp in json_files: - data = load_json(jp) - if not data: - skipped += 1 - continue - - img_path = find_image_path(jp, data) - if not img_path: - blob = image_from_image_data(data) - if not blob: - skipped += 1 - continue - ext = Path((data.get("imagePath") or "img.png")).suffix.lower() - if ext not in IMG_EXTS: - ext = ".png" - stem = unique_stem(jp, source) - tmp_img = out_root / "_tmp_decode" / f"{stem}{ext}" - tmp_img.parent.mkdir(parents=True, exist_ok=True) - tmp_img.write_bytes(blob) - img_path = tmp_img - - iw, ih = ensure_image_size(data, img_path) - if iw <= 0 or ih <= 0: - skipped += 1 - continue - - lines = shapes_to_yolo_lines(data, iw, ih) - if not lines: - skipped += 1 - continue - - stem = unique_stem(jp, source) - ext = img_path.suffix.lower() - if ext not in IMG_EXTS: - ext = ".png" - records.append((jp, img_path, lines, stem + ext)) - - if not records: - print("[错误] 没有可用样本(需 JSON + 图像 + ref 矩形)", file=sys.stderr) - return 1 - - rng = random.Random(args.seed) - rng.shuffle(records) - n_val = int(round(len(records) * args.val_ratio)) - n_val = max(1, n_val) if len(records) >= 2 else 0 - n_val = min(n_val, len(records) - 1) if len(records) >= 2 else 0 - # 仅 1 张:训练与验证共用同一张(写入两个目录),避免 YOLO 无 val - single_dup = len(records) == 1 - val_set = set(range(len(records) - n_val, len(records))) if not single_dup else set() - - n_tr = 0 - n_va = 0 - - def materialize( - src_img: Path, lines: List[str], fname: str, is_val: bool - ) -> bool: - nonlocal n_tr, n_va, skipped - idir = img_val if is_val else img_train - ldir = lbl_val if is_val else lbl_train - dst_img = idir / fname - stem = Path(fname).stem - dst_lbl = ldir / f"{stem}.txt" - dst_img.parent.mkdir(parents=True, exist_ok=True) - try: - if args.copy_images: - shutil.copy2(src_img, dst_img) - else: - if dst_img.exists(): - dst_img.unlink() - try: - dst_img.hardlink_to(src_img) - except OSError: - shutil.copy2(src_img, dst_img) - except Exception as e: - print(f"[跳过] 复制图像失败 {src_img}: {e}", file=sys.stderr) - skipped += 1 - return False - dst_lbl.write_text("\n".join(lines) + "\n", encoding="utf-8") - if is_val: - n_va += 1 - else: - n_tr += 1 - return True - - for i, (_, src_img, lines, fname) in enumerate(records): - if single_dup: - materialize(src_img, lines, fname, is_val=False) - materialize(src_img, lines, fname, is_val=True) - break - is_val = i in val_set - materialize(src_img, lines, fname, is_val=is_val) - - yaml_path = write_dataset_yaml(out_root) - tmp = out_root / "_tmp_decode" - if tmp.exists(): - shutil.rmtree(tmp, ignore_errors=True) - - print( - f"完成: 训练 {n_tr} / 验证 {n_va}(跳过 {skipped} 个 JSON)\n" - f"数据集: {out_root}\n" - f"YAML: {yaml_path}" - ) - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/FishMeasure/detect_refbox/train_refbox_yolo.py b/FishMeasure/detect_refbox/train_refbox_yolo.py deleted file mode 100755 index eb9fd23..0000000 --- a/FishMeasure/detect_refbox/train_refbox_yolo.py +++ /dev/null @@ -1,162 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -使用 Ultralytics YOLO 训练参考物(ref)单类检测模型。 - -默认权重:仓库根目录 yolo26n.pt -默认数据:先运行 prepare_refbox_dataset.py 生成 detect_refbox/dataset/refbox.yaml - -小样本建议:较轻增强、适中 epoch、较小 batch。 -""" - -from __future__ import annotations - -import argparse -import subprocess -import sys -from datetime import datetime -from pathlib import Path - - -def repo_root() -> Path: - return Path(__file__).resolve().parents[1] - - -def parse_args() -> argparse.Namespace: - root = repo_root() - p = argparse.ArgumentParser(description="训练 ref 框 YOLO(yolo26n)") - p.add_argument( - "--model", - type=Path, - default=root / "yolo26n.pt", - help="初始权重路径", - ) - p.add_argument( - "--data", - type=Path, - default=root / "detect_refbox" / "dataset" / "refbox.yaml", - help="数据集 YAML", - ) - p.add_argument( - "--source", - type=Path, - default=Path("/home/ubuntu/data/fish/2016-1-22-last_images"), - help="原始 JSON 图像根目录(仅 --prepare 时)", - ) - p.add_argument( - "--dataset-out", - type=Path, - default=root / "detect_refbox" / "dataset", - help="prepare 输出目录", - ) - p.add_argument("--prepare", action="store_true", help="训练前先执行数据集准备") - p.add_argument("--val-ratio", type=float, default=0.2, help="验证比例(prepare)") - p.add_argument("--copy-images", action="store_true", help="prepare 时复制图像") - p.add_argument("--epochs", type=int, default=200) - p.add_argument("--batch", type=int, default=8) - p.add_argument("--imgsz", type=int, default=640) - p.add_argument("--device", type=str, default="", help="如 0 或 cpu,空则自动") - p.add_argument( - "--project", - type=Path, - default=root / "detect_refbox" / "runs", - help="Ultralytics project 目录", - ) - p.add_argument("--name", type=str, default="", help="运行名,默认带时间戳") - p.add_argument("--workers", type=int, default=4) - p.add_argument("--patience", type=int, default=80) - p.add_argument("--seed", type=int, default=42) - p.add_argument("--exist-ok", action="store_true") - return p.parse_args() - - -def run_prepare(args: argparse.Namespace) -> int: - prep = Path(__file__).resolve().parent / "prepare_refbox_dataset.py" - cmd = [ - sys.executable, - str(prep), - "--source", - str(args.source), - "--out", - str(args.dataset_out), - "--val-ratio", - str(args.val_ratio), - "--seed", - str(args.seed), - ] - if args.copy_images: - cmd.append("--copy-images") - return subprocess.call(cmd) - - -def main() -> int: - args = parse_args() - if args.prepare: - rc = run_prepare(args) - if rc != 0: - return rc - - data_yaml = args.data.expanduser().resolve() - if not data_yaml.is_file(): - print( - f"[错误] 未找到 {data_yaml}。请先运行:\n" - f" python3 {Path(__file__).parent / 'prepare_refbox_dataset.py'} " - f"--source {args.source}", - file=sys.stderr, - ) - return 1 - - model_path = args.model.expanduser().resolve() - if not model_path.is_file(): - print(f"[错误] 未找到权重: {model_path}", file=sys.stderr) - return 1 - - try: - from ultralytics import YOLO - except ImportError as e: - print("[错误] 需要 ultralytics: pip install ultralytics", file=sys.stderr) - print(e, file=sys.stderr) - return 1 - - name = args.name or f"refbox_y26n_{datetime.now().strftime('%Y%m%d_%H%M%S')}" - args.project = args.project.expanduser().resolve() - args.project.mkdir(parents=True, exist_ok=True) - - # 小数据:关闭强 mosaic/mixup,减轻过拟合风险 - model = YOLO(str(model_path)) - model.train( - data=str(data_yaml), - epochs=args.epochs, - imgsz=args.imgsz, - batch=args.batch, - device=args.device if args.device else None, - project=str(args.project), - name=name, - workers=args.workers, - patience=args.patience, - seed=args.seed, - exist_ok=args.exist_ok, - verbose=True, - mosaic=0.0, - mixup=0.0, - copy_paste=0.0, - degrees=5.0, - translate=0.08, - scale=0.25, - shear=0.0, - perspective=0.0, - flipud=0.0, - fliplr=0.5, - hsv_h=0.01, - hsv_s=0.4, - hsv_v=0.3, - close_mosaic=0, - ) - - save_dir = args.project / name - print(f"\n完成。权重目录: {save_dir / 'weights'}") - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/FishMeasure/fish_video_weight_evaluation.py b/FishMeasure/fish_video_weight_evaluation.py index 0af9a48..87e033f 100755 --- a/FishMeasure/fish_video_weight_evaluation.py +++ b/FishMeasure/fish_video_weight_evaluation.py @@ -5,6 +5,7 @@ Pure OpenCV + YOLO + SAM for viewing images from a folder. """ import argparse +import re import cv2 import json import numpy as np @@ -268,6 +269,7 @@ def run_weight_estimation( top_k: int = 5, top_by_length: bool = True, length_switch_to_weight_mm: float = 319.0, + ply_files: Optional[List[Path]] = None, ) -> Optional[Dict]: """ Run weight estimation on point clouds in a folder using DGCNN (test_dgcnn_weight_estimator). @@ -301,15 +303,19 @@ def run_weight_estimation( output_dir = Path(output_dir).expanduser().resolve() output_dir.mkdir(parents=True, exist_ok=True) - # Find PLY files - ply_files = sorted(cloud_folder.glob("*.ply")) - if not ply_files: + # Find PLY files (optional explicit list for subsets / per-minute windows) + if ply_files is not None: + ply_list = sorted({Path(p).expanduser().resolve() for p in ply_files}) + ply_list = [p for p in ply_list if p.is_file()] + else: + ply_list = sorted(cloud_folder.glob("*.ply")) + if not ply_list: print(f" No PLY files found in: {cloud_folder}") return None if verbose: print(f"\n{'='*60}") - print(f"Running DGCNN weight estimation on {len(ply_files)} point clouds...") + print(f"Running DGCNN weight estimation on {len(ply_list)} point clouds...") print(f"{'='*60}") try: @@ -336,6 +342,7 @@ def run_weight_estimation( outlier_field="length_input", iqr_factor=1.5, zscore_threshold=2.5, + ply_files=ply_list, ) # Check CV threshold @@ -356,7 +363,7 @@ def run_weight_estimation( length_str = f"{length:.1f}mm" if np.isfinite(length) else "nan" print(f" {ply_name}: len={length_str} | {g:.2f}g{outlier_marker}") - print(f"\n Files processed: {summary.get('num_files_predicted', len(ply_files))}") + print(f"\n Files processed: {summary.get('num_files_predicted', len(ply_list))}") if summary.get('num_outliers_removed', 0) > 0: print(f" Outliers removed: {summary['num_outliers_removed']}") if cv_pct is not None and np.isfinite(cv_pct): @@ -464,6 +471,174 @@ def draw_detections(image, results, class_names, depth_stats_list=None): return image +def draw_fish_boxes_from_arrays( + image: np.ndarray, + boxes: np.ndarray, + class_ids: np.ndarray, + track_ids: Optional[np.ndarray], + class_names: Dict[int, str], + weights_by_track: Optional[Dict[int, float]] = None, +) -> np.ndarray: + """Draw boxes from numpy arrays; labels show DGCNN weight (g), not YOLO conf or depth.""" + if boxes is None or len(boxes) == 0: + return image + for i, box in enumerate(boxes): + x1, y1, x2, y2 = map(int, box) + cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2) + cls_id = int(class_ids[i]) if i < len(class_ids) else 0 + cname = class_names.get(cls_id, "fish") if class_names else "fish" + tid = int(track_ids[i]) if track_ids is not None and i < len(track_ids) else -1 + wg: Optional[float] = None + if weights_by_track is not None and tid >= 0 and tid in weights_by_track: + wg = weights_by_track[tid] + if wg is not None and np.isfinite(wg): + label = f"ID:{tid} {cname} weight: {wg:.0f} g" + else: + label = f"ID:{tid} {cname} weight: -- g" + (text_w, text_h), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 1) + cv2.rectangle(image, (x1, y1 - text_h - baseline - 5), (x1 + text_w, y1), (0, 255, 0), -1) + cv2.putText( + image, label, (x1, y1 - baseline - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 1, cv2.LINE_AA + ) + return image + + +def draw_fish_boxes_with_weight( + image: np.ndarray, + results, + class_names: Dict[int, str], + weights_by_track: Optional[Dict[int, float]] = None, +) -> np.ndarray: + """Draw YOLO boxes with fish weight (g) per track; no confidence, no depth.""" + if results is None or results.boxes is None: + return image + boxes = results.boxes.xyxy.cpu().numpy() + class_ids = ( + results.boxes.cls.cpu().numpy().astype(int) + if results.boxes.cls is not None + else np.zeros(len(boxes), dtype=int) + ) + tid = None + if hasattr(results.boxes, "id") and results.boxes.id is not None: + tid = results.boxes.id.cpu().numpy().astype(int) + return draw_fish_boxes_from_arrays(image, boxes, class_ids, tid, class_names, weights_by_track) + + +def draw_overlay_header(image: np.ndarray, lines: List[str]) -> None: + y = 22 + for line in lines: + if not line: + continue + cv2.putText(image, line, (10, y), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 255, 255), 2, cv2.LINE_AA) + y += 22 + + +def _parse_tid_from_ply_name(name: str) -> Optional[int]: + m = re.search(r"_tid(\d+)", name) + if not m: + return None + return int(m.group(1)) + + +def _parse_frame_num_from_ply_name(name: str) -> Optional[int]: + m = re.search(r"frame_(\d+)", name) + if not m: + return None + return int(m.group(1)) + + +def build_track_weights_minute_top5( + per_file: List[Dict[str, Any]], + *, + fps: float, + minute_interval_sec: float, +) -> Tuple[Dict[int, float], Dict[int, float], List[float]]: + """track_id -> max weight g; minute_bucket -> mean weight in window; global top-5 weights (desc).""" + tid_max: Dict[int, float] = {} + bucket_vals: Dict[int, List[float]] = {} + for it in per_file: + ply = str(it.get("ply", "")) + g = float(it.get("predicted_weight_g", float("nan"))) + if not np.isfinite(g): + continue + tid = _parse_tid_from_ply_name(Path(ply).name) + if tid is not None: + tid_max[tid] = max(tid_max.get(tid, float("-inf")), g) + fn = _parse_frame_num_from_ply_name(Path(ply).name) + if fn is None or fps <= 0: + continue + t_sec = (fn - 1) / fps + b = int(t_sec // minute_interval_sec) + bucket_vals.setdefault(b, []).append(g) + minute_avg: Dict[int, float] = {} + for b, vals in bucket_vals.items(): + if vals: + minute_avg[b] = float(np.mean(vals)) + top5 = sorted(tid_max.values(), reverse=True)[:5] + return tid_max, minute_avg, top5 + + +def finalize_preview_video_with_weights( + video_buffer: List[Dict[str, Any]], + *, + fps_video: float, + fps_timeline: float, + minute_interval_sec: float, + weights_by_track: Dict[int, float], + minute_avg: Dict[int, float], + top5: List[float], + class_names: Dict[int, str], + svo_name: str, + output_images_folder: Path, +) -> None: + """Redraw buffered frames with DGCNN weights + top-5 / per-minute lines; write side-by-side mp4.""" + if not video_buffer: + return + out_frames: List[np.ndarray] = [] + for entry in video_buffer: + left_raw = entry["left_raw"] + right = entry["right"] + frame_idx = int(entry["frame_idx"]) + boxes = np.asarray(entry["boxes"], dtype=np.float32) + cls_ids = np.asarray(entry["class_ids"], dtype=np.int64) + _tid = entry.get("track_ids") + if _tid is None: + tids = np.zeros(len(boxes), dtype=np.int64) + else: + tids = np.asarray(_tid, dtype=np.int64) + left_disp = draw_fish_boxes_from_arrays( + left_raw.copy(), + boxes, + cls_ids, + tids, + class_names, + weights_by_track=weights_by_track, + ) + bucket = int((frame_idx / max(fps_timeline, 1e-6)) // minute_interval_sec) + mav = minute_avg.get(bucket) + mav_s = f"{mav:.0f} g" if mav is not None and np.isfinite(mav) else "--" + top5_s = ", ".join(f"{w:.0f}" for w in top5 if np.isfinite(w)) if top5 else "--" + lines = [ + f"Top-5 weights (g, all fish so far): {top5_s}", + f"This {int(minute_interval_sec)}s window ~min {bucket + 1}: avg {mav_s}", + ] + draw_overlay_header(left_disp, lines) + info = f"[{frame_idx + 1}] Detections" + cv2.putText( + left_disp, info, (10, left_disp.shape[0] - 24), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 255, 0), 2, cv2.LINE_AA + ) + combined = np.hstack([left_disp, right]) + out_frames.append(combined) + video_path = output_images_folder / f"{svo_name}_preview.mp4" + h, w = out_frames[0].shape[:2] + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + vw = cv2.VideoWriter(str(video_path), fourcc, float(fps_video), (w, h)) + for fr in out_frames: + vw.write(fr) + vw.release() + print(f"✓ Saved video (weight overlay): {video_path.name} ({len(out_frames)} frames @ {fps_video} fps)") + + def segment_with_sam(sam_predictor, image_bgr, boxes_xyxy, device): """Segment fish using SAM with YOLO detection boxes. Returns list of individual masks for each detection.""" @@ -1057,7 +1232,9 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de save_raw_pointclouds=False, correct_tail_rotation=False, tail_rotation_distance_threshold=5.0, tail_rotation_min_tail_ratio=0.7, tail_rotation_min_angle=5.0, - max_cv_length=None, output_dir_stem=None): + max_cv_length=None, output_dir_stem=None, + weight_overlay_video: bool = False, + minute_interval_sec: float = 60.0): """Process a single SVO2 file with pre-loaded YOLO and SAM models. Args: @@ -1090,6 +1267,10 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de svo_name = svo_path.stem class_names = yolo_model.names if hasattr(yolo_model, 'names') else {} + try: + class_names = {int(k): v for k, v in class_names.items()} + except (TypeError, ValueError): + class_names = {} out_stem = output_dir_stem if output_dir_stem else svo_name # Setup output folders @@ -1168,7 +1349,22 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de STATIONARY_THRESHOLD = 10 MOVEMENT_THRESHOLD = 5.0 - video_frames = [] + try: + cam_cfg = zed_reader.zed.get_camera_information().camera_configuration + _fps_raw = getattr(cam_cfg, "fps", None) + try: + fps_timeline = float(_fps_raw) if _fps_raw is not None else 30.0 + except (TypeError, ValueError): + fps_timeline = 30.0 + if fps_timeline <= 0: + fps_timeline = 30.0 + except Exception: + fps_timeline = 30.0 + print(f" Timeline FPS (for per-minute buckets): {fps_timeline:.2f}") + + defer_video = bool(do_weight_estimation and weight_overlay_video and not save_images) + video_frames: List[np.ndarray] = [] + video_defer_buffer: List[Dict[str, Any]] = [] idx = 0 # List to track point clouds that passed PointNet++ classifier (if enabled) @@ -1318,8 +1514,34 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de previous_boxes = None depth_stats_list = [] - # Draw detections with depth info - left_display = draw_detections(img.copy(), results, class_names, depth_stats_list) + # Left panel: DGCNN mass (g) after finalize — not YOLO conf, not depth (depth is mm to camera, not mass) + if defer_video and num_dets > 0 and results is not None and results.boxes is not None: + bx = results.boxes.xyxy.cpu().numpy() + cls_np = ( + results.boxes.cls.cpu().numpy().astype(int) + if results.boxes.cls is not None + else np.zeros(len(bx), dtype=int) + ) + tid_np = ( + results.boxes.id.cpu().numpy().astype(int) + if results.boxes.id is not None + else np.zeros(len(bx), dtype=int) + ) + video_defer_buffer.append( + { + "left_raw": img.copy(), + "right": right_display.copy(), + "frame_idx": idx, + "boxes": bx, + "class_ids": cls_np, + "track_ids": tid_np, + } + ) + left_display = img.copy() + else: + left_display = draw_fish_boxes_with_weight( + img.copy(), results, class_names, weights_by_track=None + ) info = f"[{idx + 1}] {frame_name} | Detections: {num_dets}" cv2.putText(left_display, info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA) cv2.putText(left_display, "Detection", (10, left_display.shape[0] - 20), @@ -1336,8 +1558,8 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de cv2.imwrite(str(image_path), combined_display) if idx % 30 == 0: print(f" Saved image: {image_path.name}") - else: - # Collect for video + elif not defer_video: + # Collect for video (immediate composite; no deferred weight overlay) video_frames.append(combined_display.copy()) # Save point clouds @@ -1494,10 +1716,14 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de print(f" Point cloud {fish_idx + 1}: WARNING - Tail rotation correction failed: {e}") # Continue with original points if correction fails - # Save point cloud (passed all checks) + # Save point cloud (passed all checks); include track id for DGCNN→video weight mapping filtered_count = len(points) postfix = f"_{fish_idx + 1}" if len(individual_masks) > 1 else "" - ply_path = output_cloud_folder / f"cloud_{idx+1:04d}_{frame_name}{postfix}.ply" + track_id = int(active_track_ids[fish_idx]) + ply_path = ( + output_cloud_folder + / f"cloud_{idx+1:04d}_{frame_name}_tid{track_id}{postfix}.ply" + ) write_ply_file(ply_path, points, colors) # Track point clouds that passed PointNet++ classifier (if enabled) @@ -1517,17 +1743,68 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de idx += 1 - # Create video (only if not saving individual images) - if not save_images and video_frames: - video_path = output_images_folder / f"{svo_name}_preview.mp4" - h, w = video_frames[0].shape[:2] - fps = 10.0 - fourcc = cv2.VideoWriter_fourcc(*'mp4v') - video_writer = cv2.VideoWriter(str(video_path), fourcc, fps, (w, h)) - for frame in video_frames: - video_writer.write(frame) - video_writer.release() - print(f"✓ Saved video: {video_path.name} ({len(video_frames)} frames)") + # DGCNN on saved clouds first (needed for deferred weight overlay on video) + if do_weight_estimation and kept_pointclouds: + weight_output_dir = output_base / "weight_estimation" + weight_output_dir.mkdir(parents=True, exist_ok=True) + wres = run_weight_estimation( + cloud_folder=output_cloud_folder, + output_dir=weight_output_dir, + topk_length=weight_topk_length, + remove_outliers=weight_remove_outliers, + outlier_method=weight_outlier_method, + max_cv_length=max_cv_length, + verbose=True, + top_k=weight_top_k, + top_by_length=weight_top_by_length, + length_switch_to_weight_mm=weight_length_switch_mm, + ) + if wres is None: + print(f" WARNING: Weight estimation failed") + elif do_weight_estimation and not kept_pointclouds: + if use_pointcloud_classifier: + print(f" Warning: Weight estimation requested but no point clouds passed PointNet++ classifier") + else: + print(f" Warning: Weight estimation requested but no point clouds were saved") + + # Preview video after weights (so labels show mass in g, not depth mm) + if not save_images: + if defer_video and video_defer_buffer: + wdict: Dict[int, float] = {} + mavg: Dict[int, float] = {} + top5: List[float] = [] + wjson = output_base / "weight_estimation" / "weight_estimation_results.json" + if wjson.is_file(): + try: + wr = json.loads(wjson.read_text(encoding="utf-8")) + per = wr.get("per_file") or [] + wdict, mavg, top5 = build_track_weights_minute_top5( + per, fps=fps_timeline, minute_interval_sec=minute_interval_sec + ) + except Exception as e: + print(f" WARNING: Could not parse weight results for video overlay: {e}") + finalize_preview_video_with_weights( + video_defer_buffer, + fps_video=10.0, + fps_timeline=fps_timeline, + minute_interval_sec=minute_interval_sec, + weights_by_track=wdict, + minute_avg=mavg, + top5=top5, + class_names=class_names, + svo_name=svo_name, + output_images_folder=output_images_folder, + ) + elif video_frames: + video_path = output_images_folder / f"{svo_name}_preview.mp4" + h, w = video_frames[0].shape[:2] + fps_v = 10.0 + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + video_writer = cv2.VideoWriter(str(video_path), fourcc, fps_v, (w, h)) + for frame in video_frames: + video_writer.write(frame) + video_writer.release() + print(f"✓ Saved video: {video_path.name} ({len(video_frames)} frames)") elif save_images: print(f"✓ Saved {idx} frames as individual images") @@ -1572,34 +1849,6 @@ def process_single_svo2(svo_path, output_base, yolo_model, sam_predictor, sam_de else: print(f" Note: No point clouds were saved") - # Run weight estimation if requested - if do_weight_estimation and kept_pointclouds: - # Create output directory for weight estimation - weight_output_dir = output_base / "weight_estimation" - weight_output_dir.mkdir(parents=True, exist_ok=True) - - # Run weight estimation - results = run_weight_estimation( - cloud_folder=output_cloud_folder, - output_dir=weight_output_dir, - topk_length=weight_topk_length, - remove_outliers=weight_remove_outliers, - outlier_method=weight_outlier_method, - max_cv_length=max_cv_length, - verbose=True, - top_k=weight_top_k, - top_by_length=weight_top_by_length, - length_switch_to_weight_mm=weight_length_switch_mm, - ) - - if results is None: - print(f" WARNING: Weight estimation failed") - elif do_weight_estimation and not kept_pointclouds: - if use_pointcloud_classifier: - print(f" Warning: Weight estimation requested but no point clouds passed PointNet++ classifier") - else: - print(f" Warning: Weight estimation requested but no point clouds were saved") - print(f"✓ Processed {idx} frames from {svo_path.name}") return True @@ -1626,7 +1875,9 @@ def process_batch_svo2_folder(svo_folder, output_base, yolo_model, sam_predictor save_raw_pointclouds=False, correct_tail_rotation=False, tail_rotation_distance_threshold=5.0, tail_rotation_min_tail_ratio=0.7, tail_rotation_min_angle=5.0, - max_cv_length=None, batch_svo_recursive=False): + max_cv_length=None, batch_svo_recursive=False, + weight_overlay_video: bool = False, + minute_interval_sec: float = 60.0): """Process all SVO2 files in a folder with pre-loaded YOLO and SAM models. Args: @@ -1739,6 +1990,8 @@ def process_batch_svo2_folder(svo_folder, output_base, yolo_model, sam_predictor tail_rotation_min_angle=tail_rotation_min_angle, max_cv_length=max_cv_length, output_dir_stem=output_dir_stem, + weight_overlay_video=weight_overlay_video, + minute_interval_sec=minute_interval_sec, ) if success: @@ -1865,6 +2118,18 @@ def main(): help="Remove outliers before computing average weight (default: True)") parser.add_argument("--weight-outlier-method", type=str, default="iqr", choices=["iqr", "zscore"], help="Outlier detection method for weight estimation (default: iqr)") + parser.add_argument( + "--weight-overlay-video", + action="store_true", + help="With --run-weight-estimation: defer preview mp4 until DGCNN runs, then burn in fish weight (g) per " + "track, top-5 masses, and per-window average (no YOLO confidence, no depth median on the preview).", + ) + parser.add_argument( + "--minute-interval-sec", + type=float, + default=60.0, + help="Length of each time bucket (seconds) for the on-video per-window average line (default: 60).", + ) parser.add_argument("--save-raw-pointclouds", action="store_true", help="Save point clouds to raw_pc folder before passing to PointNet++ classifier. Useful for debugging why some point clouds fail classification.") parser.add_argument("--correct-tail-rotation", action="store_true", @@ -1982,6 +2247,8 @@ def main(): tail_rotation_min_angle=args.tail_rotation_min_angle, max_cv_length=args.max_cv_length, batch_svo_recursive=args.batch_svo_recursive, + weight_overlay_video=args.weight_overlay_video, + minute_interval_sec=args.minute_interval_sec, ) return @@ -2097,47 +2364,44 @@ def main(): if args.save_raw_pointclouds: print(f"✓ Output raw point cloud folder (before classifier): {output_raw_pc_folder.resolve()}") - # Check if output folder already exists and contains point clouds - # If so, skip data generation and directly run weight estimation - if output_base.exists() and output_cloud_folder.exists(): - # Check if there are point cloud files - point_cloud_files = list(output_cloud_folder.glob("*.ply")) - if point_cloud_files and args.run_weight_estimation: - print(f"\n{'='*60}") - print(f"Output folder already exists with {len(point_cloud_files)} point cloud files") - print(f"Skipping data generation, directly running weight estimation...") - print(f"{'='*60}") - - # Load weight estimator if not already loaded - if _weight_estimator_model is None: - ckpt_path = Path(args.weight_estimator_checkpoint).expanduser().resolve() - if not load_weight_estimator(str(ckpt_path), device=args.sam_device): - print("ERROR: Failed to load weight estimator.") - return - - # Run weight estimation directly - weight_output_dir = output_base / "weight_estimation" - results = run_weight_estimation( - cloud_folder=output_cloud_folder, - output_dir=weight_output_dir, - topk_length=args.weight_topk_length, - remove_outliers=args.weight_remove_outliers, - outlier_method=args.weight_outlier_method, - max_cv_length=args.max_cv_length, - verbose=True, - top_k=args.weight_top_k, - top_by_length=args.weight_top_by_length, - length_switch_to_weight_mm=args.weight_length_switch_mm, - ) - return - elif point_cloud_files and not args.run_weight_estimation: - print(f"\n{'='*60}") - print(f"Output folder already exists with {len(point_cloud_files)} point cloud files") - print(f"Weight estimation not requested (--run-weight-estimation not set)") - print(f"Skipping processing...") - print(f"{'='*60}") - return - # If folder exists but no point clouds, continue with normal processing + if use_svo: + process_single_svo2( + str(svo_path), + str(Path(args.save_output).expanduser().resolve()), + yolo_model, + sam_predictor, + sam_device, + conf=args.conf, + imgsz=args.imgsz, + max_frames=args.max_frames, + frame_stride=args.frame_stride, + save_images=args.save_images, + filter_pointcloud=args.filter_pointcloud, + use_clustering_filter=args.use_clustering_filter, + use_density_filter=args.use_density_filter, + pointcloud_classifier=pointcloud_classifier, + use_pointcloud_classifier=args.use_pointcloud_classifier, + pointcloud_classifier_threshold=args.pointcloud_classifier_threshold, + flatness_threshold=args.flatness_threshold, + use_flatness_filter=args.use_flatness_filter, + do_weight_estimation=args.run_weight_estimation, + weight_topk_length=args.weight_topk_length, + weight_top_k=args.weight_top_k, + weight_top_by_length=args.weight_top_by_length, + weight_length_switch_mm=args.weight_length_switch_mm, + weight_remove_outliers=args.weight_remove_outliers, + weight_outlier_method=args.weight_outlier_method, + save_raw_pointclouds=args.save_raw_pointclouds, + correct_tail_rotation=args.correct_tail_rotation, + tail_rotation_distance_threshold=args.tail_rotation_distance_threshold, + tail_rotation_min_tail_ratio=args.tail_rotation_min_tail_ratio, + tail_rotation_min_angle=args.tail_rotation_min_angle, + max_cv_length=args.max_cv_length, + output_dir_stem=None, + weight_overlay_video=args.weight_overlay_video, + minute_interval_sec=args.minute_interval_sec, + ) + return else: window_name = "Fish Detection & Segmentation Preview" cv2.namedWindow(window_name, cv2.WINDOW_NORMAL) diff --git a/FishMeasure/predict_weigth_from_svo2.py b/FishMeasure/predict_weigth_from_svo2.py index 7fd24b7..971b457 100755 --- a/FishMeasure/predict_weigth_from_svo2.py +++ b/FishMeasure/predict_weigth_from_svo2.py @@ -102,6 +102,19 @@ def _run_fish_video_evaluation_subprocess(args: argparse.Namespace, *, batch_fol cmd.append("--use-flatness-filter") cmd.extend(["--flatness-threshold", str(args.flatness_threshold)]) + if getattr(args, "fish_video_weight_overlay", False): + wck = Path(args.weight_checkpoint).expanduser().resolve() + cmd.extend( + [ + "--run-weight-estimation", + "--weight-estimator-checkpoint", + str(wck), + "--weight-overlay-video", + "--minute-interval-sec", + str(getattr(args, "minute_interval_sec", 60.0)), + ] + ) + print(f"Invoking fish_video_weight_evaluation.py:\n {' '.join(cmd)}") proc = subprocess.run(cmd, cwd=str(REPO_ROOT)) if proc.returncode != 0: @@ -277,6 +290,7 @@ def run_weight_prediction_for_svo( weight_outlier_method: str, weight_xyz_scale: float, weight_labels_json: Optional[str], + force_dgcnn_subprocess: bool = False, ) -> Dict[str, Any]: svo_path = svo_path.expanduser().resolve() if not svo_path.exists(): @@ -293,6 +307,28 @@ def run_weight_prediction_for_svo( f"fish_video_weight_evaluation.py first." ) + fish_wj = out_dir / "weight_estimation" / "weight_estimation_results.json" + if not force_dgcnn_subprocess and fish_wj.is_file(): + try: + dgcnn_data = json.loads(fish_wj.read_text(encoding="utf-8")) + if dgcnn_data.get("summary") is not None or dgcnn_data.get("per_file"): + print(f"Using existing DGCNN results from fish_video: {fish_wj}") + result = _merge_weight_prediction_json( + svo_path=svo_path, + svo_name=svo_name, + out_dir=out_dir, + cloud_dir=cloud_dir, + dgcnn_json_path=fish_wj, + dgcnn_data=dgcnn_data, + ) + (out_dir / "weight_prediction.json").write_text( + json.dumps(_sanitize_for_json(result), indent=2, ensure_ascii=False), + encoding="utf-8", + ) + return result + except Exception as e: + print(f"WARNING: Could not merge {fish_wj}, falling back to test_dgcnn subprocess: {e}") + dgcnn_path, dgcnn_data = _run_test_dgcnn_weight_estimator_subprocess( cloud_dir=cloud_dir, ply_list_file=None, @@ -468,6 +504,24 @@ def main() -> None: ) parser.add_argument("--no-reuse-existing-clouds", action="store_false", dest="reuse_existing_clouds") + parser.add_argument( + "--fish-video-weight-overlay", + action="store_true", + help="Run fish_video with DGCNN + preview video overlay (fish weight g, top-5, per-window avg). " + "Avoids a duplicate test_dgcnn pass when weight_estimation_results.json is present.", + ) + parser.add_argument( + "--minute-interval-sec", + type=float, + default=60.0, + help="Passed to fish_video --minute-interval-sec for on-video bucket stats (default: 60).", + ) + parser.add_argument( + "--force-dgcnn-subprocess", + action="store_true", + help="Always run test_dgcnn_weight_estimator.py subprocess even if fish_video left weight_estimation_results.json.", + ) + args = parser.parse_args() if args.frame_stride < 1: raise SystemExit("--frame-stride must be >= 1") @@ -551,6 +605,7 @@ def main() -> None: weight_outlier_method=args.weight_outlier_method, weight_xyz_scale=args.weight_xyz_scale, weight_labels_json=args.weight_labels_json, + force_dgcnn_subprocess=args.force_dgcnn_subprocess, ) ) except Exception as e: diff --git a/FishMeasure/segmentation/README.md b/FishMeasure/segmentation/README.md deleted file mode 100755 index 1a902e5..0000000 --- a/FishMeasure/segmentation/README.md +++ /dev/null @@ -1,99 +0,0 @@ -## Fish Body Segmentation (YOLOv8-seg) - -This folder provides a quick pipeline to train a **body-only** fish segmentation model using **Labelme polygon annotations**. - -The goal is to produce a mask that **excludes fins and tail** (as much as possible), so the depth->pointcloud becomes cleaner for weight estimation. - -### 1) Labeling in Labelme - -- Use `Labelme` polygon tool. -- Recommended class name: `body` (you can use other names; see `--classes` below). -- Each image produces a `.json` annotation file. - -### 2) Convert Labelme JSON -> YOLOv8-seg dataset - -This will create a YOLO dataset folder: - -``` -/ - images/train, images/val, images/test - labels/train, labels/val, labels/test - dataset.yaml -``` - -Example: - -```bash -python3 segmentation/prepare_yolo_seg_dataset.py \ - --source_dir /path/to/labelme_export \ - --out_dir ./datasets/fish_body_seg \ - --classes body \ - --train_ratio 0.8 --val_ratio 0.1 --test_ratio 0.1 \ - --seed 42 \ - --copy -``` - -Notes: -- YOLOv8-seg label format is: ` x1 y1 x2 y2 ... xn yn` (all normalized to [0,1]). -- If an image has no valid polygons, an empty label file will be written (you can change this later if desired). - -### 2b) Filter existing prepared dataset (if only some images are labeled) - -If you already have a prepared YOLO-seg dataset but only some images have labels, use this mode to filter and keep only labeled images: - -```bash -python3 segmentation/prepare_yolo_seg_dataset.py \ - --prepared_dataset /home/ubuntu/data/fish/fish_measure_intermediates/yolo_seg \ - --out_dir ./datasets/fish_body_seg_filtered \ - --classes body \ - --copy -``` - -This will: -- Scan `images/train/`, `images/val/`, `images/test/` for images -- Check for corresponding `.txt` label files in `labels/train/`, `labels/val/`, `labels/test/` -- Only copy/symlink images that have labels -- Generate a clean `dataset.yaml` for training - -### 3) Visualize labels (optional - verify conversion correctness) - -Before training, you can visualize the converted labels to verify they're correct: - -```bash -python3 segmentation/visualize_yolo_seg_labels.py \ - --dataset ./datasets/fish_body_seg_filtered \ - --output ./visualizations/yolo_seg_labels \ - --split train \ - --max_images 50 \ - --classes fishbody \ - --alpha 0.5 -``` - -This will: -- Load images and their corresponding `.txt` label files -- Draw polygon masks (semi-transparent overlay) on images -- Save visualized images to the output directory -- Useful for checking that Labelme → YOLO conversion preserved polygon shapes correctly - -### 4) Train YOLOv8 segmentation - -```bash -python3 segmentation/train_yolo_seg.py \ - --data ./datasets/fish_body_seg/dataset.yaml \ - --model yolov8s-seg.pt \ - --epochs 200 \ - --batch 16 \ - --imgsz 640 \ - --project runs/seg \ - --name fish_body_seg_$(date +%Y%m%d_%H%M%S) -``` - -Outputs: -- `runs/seg//weights/best.pt` - -### 5) Next step (pipeline integration) - -After training, you can run segmentation on: -- full image, or -- **cropped image from detector bbox** (often better when fish is small in the frame). - diff --git a/FishMeasure/segmentation/__init__.py b/FishMeasure/segmentation/__init__.py deleted file mode 100755 index 15f4787..0000000 --- a/FishMeasure/segmentation/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Segmentation training/inference utilities. - -Currently focused on YOLOv8-seg (Ultralytics) and Labelme polygon annotations. -""" - diff --git a/FishMeasure/segmentation/prepare_yolo_seg_dataset.py b/FishMeasure/segmentation/prepare_yolo_seg_dataset.py deleted file mode 100755 index b52ce39..0000000 --- a/FishMeasure/segmentation/prepare_yolo_seg_dataset.py +++ /dev/null @@ -1,490 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -Prepare a YOLOv8-seg dataset from Labelme JSON polygon annotations OR filter an existing prepared dataset. - -Mode 1: Convert from Labelme JSONs - Input (Labelme): - - one JSON per image - - JSON contains: imagePath (recommended), imageHeight, imageWidth, shapes[] - - each shape is a polygon with: label, points[[x,y],...], shape_type="polygon" - - Example: - python3 segmentation/prepare_yolo_seg_dataset.py \ - --source_dir /data/labelme \ - --out_dir ./datasets/fish_body_seg \ - --classes body \ - --train_ratio 0.8 --val_ratio 0.1 --test_ratio 0.1 \ - --seed 42 --copy - -Mode 2: Filter existing prepared dataset (keep only images with labels) - Input: Existing YOLO-seg dataset with images/ and labels/ folders - - Only images that have corresponding .txt label files are kept - - Useful when only some images in the dataset are labeled - - Example: - python3 segmentation/prepare_yolo_seg_dataset.py \ - --prepared_dataset /home/ubuntu/data/fish/fish_measure_intermediates/yolo_seg \ - --out_dir ./datasets/fish_body_seg_filtered \ - --classes body \ - --copy - -Output (Ultralytics YOLO segmentation dataset): -/ - images/{train,val,test}/xxx.jpg - labels/{train,val,test}/xxx.txt - dataset.yaml - -Label format (YOLOv8-seg): - ... -where coordinates are normalized to [0,1] by (x/img_w, y/img_h). -""" - -from __future__ import annotations - -import argparse -import base64 -import json -import os -import random -import shutil -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import cv2 -import numpy as np - -IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"} - - -def parse_args() -> argparse.Namespace: - p = argparse.ArgumentParser(description="Prepare YOLOv8-seg dataset from Labelme JSONs or filter existing prepared dataset") - mode = p.add_mutually_exclusive_group(required=True) - mode.add_argument( - "--source_dir", - type=str, - default="", - help="Folder containing Labelme JSONs (and images) - use this for Labelme conversion mode", - ) - mode.add_argument( - "--prepared_dataset", - type=str, - default="", - help="Path to existing prepared YOLO-seg dataset (images/ and labels/ folders) - use this to filter/validate existing dataset", - ) - p.add_argument("--out_dir", type=str, required=True, help="Output dataset directory") - p.add_argument("--train_ratio", type=float, default=0.8) - p.add_argument("--val_ratio", type=float, default=0.1) - p.add_argument("--test_ratio", type=float, default=0.1) - p.add_argument("--seed", type=int, default=42) - - # classes - p.add_argument( - "--classes", - type=str, - default="body", - help="Comma-separated class names, e.g. 'body' or 'body,fin,tail' (order defines class_id)", - ) - p.add_argument( - "--allow_unknown_labels", - action="store_true", - help="If set, unknown labels will be ignored (default behavior is also ignore).", - ) - - # image placing - g = p.add_mutually_exclusive_group() - g.add_argument("--copy", action="store_true", help="Copy images into output dataset") - g.add_argument("--symlink", action="store_true", help="Symlink images into output dataset") - # default: hardlink - - p.add_argument( - "--skip_non_polygon", - action="store_true", - default=True, - help="Ignore non-polygon shapes (default: True)", - ) - p.add_argument( - "--drop_empty", - action="store_true", - help="Drop images with no valid polygons (default: keep with empty label file)", - ) - return p.parse_args() - - -def ensure_dirs(root: Path) -> None: - for sub in [ - "images/train", - "images/val", - "images/test", - "labels/train", - "labels/val", - "labels/test", - ]: - (root / sub).mkdir(parents=True, exist_ok=True) - - -def place_image(src: Path, dst: Path, mode: str) -> None: - dst.parent.mkdir(parents=True, exist_ok=True) - if mode == "copy": - shutil.copy2(src, dst) - elif mode == "symlink": - if dst.exists(): - dst.unlink() - os.symlink(src, dst) - else: # hardlink - if dst.exists(): - dst.unlink() - try: - os.link(src, dst) - except OSError: - shutil.copy2(src, dst) - - -def write_label(label_path: Path, lines: List[str]) -> None: - label_path.parent.mkdir(parents=True, exist_ok=True) - label_path.write_text("\n".join(lines) + ("\n" if lines else ""), encoding="utf-8") - - -def _load_image_size(image_path: Path) -> Optional[Tuple[int, int]]: - img = cv2.imread(str(image_path), cv2.IMREAD_UNCHANGED) - if img is None: - return None - h, w = img.shape[:2] - if w <= 0 or h <= 0: - return None - return int(w), int(h) - - -def _decode_labelme_image_data(image_data_b64: str) -> Optional[np.ndarray]: - try: - raw = base64.b64decode(image_data_b64.encode("utf-8")) - arr = np.frombuffer(raw, dtype=np.uint8) - img = cv2.imdecode(arr, cv2.IMREAD_COLOR) - return img - except Exception: - return None - - -def resolve_labelme_image_path(source_dir: Path, json_path: Path, meta: Dict) -> Optional[Path]: - # 1) preferred: imagePath from JSON - image_path = meta.get("imagePath", "") or "" - if image_path: - p = (json_path.parent / image_path).resolve() if not os.path.isabs(image_path) else Path(image_path) - if p.exists() and p.suffix.lower() in IMG_EXTS: - return p - - # sometimes imagePath has only basename but image lives elsewhere under source_dir - b = Path(image_path).name - found = list(source_dir.rglob(b)) - for fp in found: - if fp.exists() and fp.suffix.lower() in IMG_EXTS: - return fp - - # 2) fallback: same stem with common image extensions next to json - for ext in sorted(IMG_EXTS): - p = json_path.with_suffix(ext) - if p.exists(): - return p - - return None - - -def _normalize_polygon(points_xy: List[List[float]], w: int, h: int) -> Optional[List[Tuple[float, float]]]: - if w <= 0 or h <= 0: - return None - if not points_xy or len(points_xy) < 3: - return None - - pts: List[Tuple[float, float]] = [] - for p in points_xy: - if not isinstance(p, (list, tuple)) or len(p) != 2: - continue - x, y = float(p[0]), float(p[1]) - xn = x / float(w) - yn = y / float(h) - # clip (labelme can slightly exceed bounds) - xn = 0.0 if xn < 0.0 else (1.0 if xn > 1.0 else xn) - yn = 0.0 if yn < 0.0 else (1.0 if yn > 1.0 else yn) - pts.append((xn, yn)) - - # remove duplicated last==first (optional) - if len(pts) >= 4 and pts[0] == pts[-1]: - pts = pts[:-1] - - # ensure at least 3 unique points - uniq = list(dict.fromkeys(pts)) - if len(uniq) < 3: - return None - - return pts - - -def generate_yaml(out_dir: Path, names: List[str]) -> None: - yaml_path = out_dir / "dataset.yaml" - content = ( - f"path: {out_dir.resolve()}\n" - f"train: images/train\n" - f"val: images/val\n" - f"test: images/test\n" - f"names: {names}\n" - ) - yaml_path.write_text(content, encoding="utf-8") - print(f"[OK] wrote: {yaml_path}") - - -def is_labelme_json(meta: Dict) -> bool: - return isinstance(meta.get("shapes", None), list) - - -def find_labeled_images_in_prepared_dataset(prepared_dir: Path) -> Dict[str, List[Tuple[Path, Path]]]: - """ - Scan a prepared YOLO-seg dataset and return only images that have corresponding label files. - Supports both .txt (YOLO format) and .json (Labelme format) label files. - Returns: {"train": [(img_path, label_path), ...], "val": [...], "test": [...]} - """ - prepared_dir = prepared_dir.expanduser().resolve() - if not prepared_dir.exists(): - raise SystemExit(f"prepared_dataset not found: {prepared_dir}") - - result: Dict[str, List[Tuple[Path, Path]]] = {"train": [], "val": [], "test": []} - - for split in ["train", "val", "test"]: - img_dir = prepared_dir / "images" / split - lbl_dir = prepared_dir / "labels" / split - - if not img_dir.exists(): - print(f"[info] {split}: images directory not found: {img_dir}") - continue - - if not lbl_dir.exists(): - print(f"[info] {split}: labels directory not found: {lbl_dir}") - continue - - # find all images - img_count = 0 - lbl_count = 0 - for img_path in img_dir.iterdir(): - if img_path.suffix.lower() not in IMG_EXTS: - continue - img_count += 1 - - # check for corresponding label (.txt or .json) - lbl_path_txt = lbl_dir / f"{img_path.stem}.txt" - lbl_path_json = lbl_dir / f"{img_path.stem}.json" - - if lbl_path_txt.exists(): - result[split].append((img_path, lbl_path_txt)) - lbl_count += 1 - elif lbl_path_json.exists(): - result[split].append((img_path, lbl_path_json)) - lbl_count += 1 - # else: image has no label, skip it - - print(f"[info] {split}: found {img_count} images, {lbl_count} with labels") - - return result - - -def process_prepared_dataset(prepared_dir: Path, out_dir: Path, place_mode: str, classes: List[str]) -> None: - """Filter and copy/symlink only labeled images from a prepared dataset.""" - labeled = find_labeled_images_in_prepared_dataset(prepared_dir) - - ensure_dirs(out_dir) - - total_kept = 0 - for split in ["train", "val", "test"]: - items = labeled.get(split, []) - print(f"{split}: {len(items)} images with labels") - - for img_src, lbl_src in items: - dst_img = out_dir / f"images/{split}/{img_src.name}" - - # If source label is .json, convert to .txt format; otherwise keep as-is - if lbl_src.suffix.lower() == ".json": - # Convert Labelme JSON to YOLO .txt format - try: - meta = json.loads(lbl_src.read_text(encoding="utf-8")) - img_w = int(meta.get("imageWidth", 0) or 0) - img_h = int(meta.get("imageHeight", 0) or 0) - if img_w <= 0 or img_h <= 0: - # Try to load from image - wh = _load_image_size(img_src) - if wh is None: - print(f"[warn] cannot determine size for {img_src.name}, skipping") - continue - img_w, img_h = wh - - lines: List[str] = [] - name2id = {n: i for i, n in enumerate(classes)} - for sh in meta.get("shapes", []): - label = (sh.get("label", "") or "").strip() - if label not in name2id: - continue - shape_type = (sh.get("shape_type", "polygon") or "polygon").lower() - if shape_type != "polygon": - continue - pts = _normalize_polygon(sh.get("points", []), w=img_w, h=img_h) - if pts is None: - continue - cls_id = name2id[label] - flat = " ".join([f"{x:.6f} {y:.6f}" for x, y in pts]) - lines.append(f"{cls_id} {flat}") - - dst_lbl = out_dir / f"labels/{split}/{img_src.stem}.txt" - write_label(dst_lbl, lines) - except Exception as e: - print(f"[warn] failed to convert {lbl_src.name}: {e}") - continue - else: - # Already .txt format, just copy - dst_lbl = out_dir / f"labels/{split}/{lbl_src.name}" - try: - shutil.copy2(lbl_src, dst_lbl) - except Exception as e: - print(f"[warn] failed to copy label {lbl_src.name}: {e}") - continue - - try: - place_image(img_src, dst_img, place_mode) - total_kept += 1 - except Exception as e: - print(f"[warn] failed to place {img_src.name}: {e}") - - generate_yaml(out_dir, classes) - print(f"[done] kept={total_kept} labeled images out={out_dir}") - - -def main() -> None: - args = parse_args() - random.seed(args.seed) - - out_dir = Path(args.out_dir).expanduser().resolve() - - classes = [c.strip() for c in (args.classes or "").split(",") if c.strip()] - if not classes: - raise SystemExit("No classes provided. Example: --classes body") - - # Mode 1: Process prepared dataset (filter to only labeled images) - if args.prepared_dataset: - prepared_dir = Path(args.prepared_dataset).expanduser().resolve() - place_mode = "copy" if args.copy else ("symlink" if args.symlink else "hardlink") - process_prepared_dataset(prepared_dir, out_dir, place_mode, classes) - return - - # Mode 2: Convert from Labelme JSONs (original behavior) - source_dir = Path(args.source_dir).expanduser().resolve() - if not source_dir.exists(): - raise SystemExit(f"source_dir not found: {source_dir}") - - name2id = {n: i for i, n in enumerate(classes)} - - json_files = sorted(source_dir.rglob("*.json")) - if not json_files: - raise SystemExit(f"No .json found under: {source_dir}") - - items: List[Tuple[Path, Path, Dict]] = [] - bad = 0 - for jp in json_files: - try: - meta = json.loads(jp.read_text(encoding="utf-8")) - except Exception: - bad += 1 - continue - if not is_labelme_json(meta): - continue - img_path = resolve_labelme_image_path(source_dir, jp, meta) - if img_path is None: - # allow imageData-only workflows: decode and write next to json - if meta.get("imageData", None): - img = _decode_labelme_image_data(meta["imageData"]) - if img is not None: - # choose png - img_path = jp.with_suffix(".png") - cv2.imwrite(str(img_path), img) - else: - bad += 1 - continue - else: - bad += 1 - continue - - items.append((jp, img_path, meta)) - - if not items: - raise SystemExit(f"No valid Labelme JSON found under: {source_dir} (bad_json={bad})") - - # split - idx = list(range(len(items))) - random.shuffle(idx) - n = len(idx) - n_train = int(n * args.train_ratio) - n_val = int(n * args.val_ratio) - n_test = n - n_train - n_val - train_set = set(idx[:n_train]) - val_set = set(idx[n_train : n_train + n_val]) - test_set = set(idx[n_train + n_val :]) - - print(f"total={n} train={len(train_set)} val={len(val_set)} test={len(test_set)} bad_json={bad}") - - ensure_dirs(out_dir) - place_mode = "copy" if args.copy else ("symlink" if args.symlink else "hardlink") - - kept = 0 - dropped_empty = 0 - for i, (json_path, img_path, meta) in enumerate(items): - if i in train_set: - split = "train" - elif i in val_set: - split = "val" - else: - split = "test" - - # size - w = int(meta.get("imageWidth", 0) or 0) - h = int(meta.get("imageHeight", 0) or 0) - if w <= 0 or h <= 0: - wh = _load_image_size(img_path) - if wh is None: - continue - w, h = wh - - # shapes -> yolo seg lines - lines: List[str] = [] - for sh in meta.get("shapes", []): - label = (sh.get("label", "") or "").strip() - if label not in name2id: - # ignore unknown labels - continue - shape_type = (sh.get("shape_type", "polygon") or "polygon").lower() - if args.skip_non_polygon and shape_type != "polygon": - continue - pts = _normalize_polygon(sh.get("points", []), w=w, h=h) - if pts is None: - continue - cls_id = name2id[label] - flat = " ".join([f"{x:.6f} {y:.6f}" for x, y in pts]) - lines.append(f"{cls_id} {flat}") - - if args.drop_empty and not lines: - dropped_empty += 1 - continue - - dst_img = out_dir / f"images/{split}/{img_path.name}" - dst_lbl = out_dir / f"labels/{split}/{img_path.with_suffix('.txt').name}" - - try: - place_image(img_path, dst_img, place_mode) - except Exception: - continue - - write_label(dst_lbl, lines) - kept += 1 - - generate_yaml(out_dir, classes) - print(f"[done] kept={kept} dropped_empty={dropped_empty} out={out_dir}") - - -if __name__ == "__main__": - main() - diff --git a/FishMeasure/segmentation/train_yolo_seg.py b/FishMeasure/segmentation/train_yolo_seg.py deleted file mode 100755 index e95a528..0000000 --- a/FishMeasure/segmentation/train_yolo_seg.py +++ /dev/null @@ -1,151 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -Ultralytics YOLOv8 segmentation training script. - -Example (using filtered dataset): - python3 segmentation/train_yolo_seg.py \ - --data ./datasets/fish_body_seg_filtered/dataset.yaml \ - --model yolo26s-seg.pt \ - --epochs 100 \ - --batch 16 \ - --imgsz 640 \ - --project runs/seg \ - --name fish_body_seg_$(date +%Y%m%d_%H%M%S) - -Example (with more options): - python3 segmentation/train_yolo_seg.py \ - --data ./datasets/fish_body_seg_filtered/dataset.yaml \ - --model yolov8s-seg.pt \ - --epochs 300 \ - --batch 32 \ - --imgsz 640 \ - --device 0 \ - --workers 8 \ - --patience 50 \ - --pretrained \ - --cache \ - --project runs/seg \ - --name fish_body_seg_yolov8s_$(date +%Y%m%d_%H%M%S) - -Dependency: - pip install ultralytics -""" - -from __future__ import annotations - -import argparse -import os -import sys -from datetime import datetime - - -def parse_args() -> argparse.Namespace: - p = argparse.ArgumentParser(description="Ultralytics YOLOv8-seg training") - p.add_argument( - "--data", - type=str, - default="./datasets/fish_body_seg_filtered/dataset.yaml", - help="dataset.yaml path (default: ./datasets/fish_body_seg_filtered/dataset.yaml)", - ) - p.add_argument( - "--model", - type=str, - default="yolo26l-seg.pt", - help="model weights/arch, e.g. yolov8n-seg.pt/yolov8s-seg.pt or your .pt", - ) - p.add_argument("--epochs", type=int, default=100) - p.add_argument("--batch", type=int, default=16) - p.add_argument("--imgsz", type=int, default=640) - p.add_argument("--device", type=str, default="", help="CUDA device like '0' or '0,1'. Empty=auto") - p.add_argument("--project", type=str, default="runs/seg", help="output project dir") - p.add_argument("--name", type=str, default="", help="run name (default: model + timestamp)") - p.add_argument("--workers", type=int, default=8) - p.add_argument("--patience", type=int, default=50) - p.add_argument("--lr0", type=float, default=0.01) - p.add_argument("--pretrained", action="store_true", help="use pretrained weights") - p.add_argument("--cache", action="store_true") - p.add_argument("--seed", type=int, default=0) - p.add_argument("--exist-ok", action="store_true") - p.add_argument("--resume", action="store_true") - p.add_argument("--export", action="store_true", help="export ONNX/TorchScript after training") - return p.parse_args() - - -def main() -> None: - args = parse_args() - - try: - from ultralytics import YOLO - except Exception as e: - print("[error] ultralytics not found. Install with: pip install ultralytics") - print(f"details: {e}") - sys.exit(1) - - if not os.path.exists(args.data): - print(f"[error] dataset yaml not found: {args.data}") - sys.exit(1) - - if not args.name: - model_stem = os.path.splitext(os.path.basename(args.model))[0] - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - args.name = f"{model_stem}_{timestamp}" - - os.makedirs(args.project, exist_ok=True) - - print("======== YOLOv8-seg Train ========") - print(f"data : {args.data}") - print(f"model : {args.model}") - print(f"epochs : {args.epochs}") - print(f"batch : {args.batch}") - print(f"imgsz : {args.imgsz}") - print(f"device : {args.device or 'auto'}") - print(f"project : {args.project}") - print(f"name : {args.name}") - print("=================================") - - model = YOLO(args.model) - model.train( - data=args.data, - epochs=args.epochs, - imgsz=args.imgsz, - batch=args.batch, - device=args.device if args.device else None, - project=args.project, - name=args.name, - pretrained=args.pretrained, - cache=args.cache, - workers=args.workers, - patience=args.patience, - lr0=args.lr0, - seed=args.seed, - exist_ok=args.exist_ok, - resume=args.resume, - verbose=True, - ) - - save_dir = os.path.join(args.project, args.name) - best_pt = os.path.join(save_dir, "weights", "best.pt") - last_pt = os.path.join(save_dir, "weights", "last.pt") - print("\n======== Train done ========") - print(f"save_dir : {save_dir}") - if os.path.exists(best_pt): - print(f"best.pt : {best_pt}") - if os.path.exists(last_pt): - print(f"last.pt : {last_pt}") - - if args.export and os.path.exists(best_pt): - try: - exp = YOLO(best_pt) - onnx_path = exp.export(format="onnx", imgsz=args.imgsz) - ts_path = exp.export(format="torchscript", imgsz=args.imgsz) - print(f"export onnx : {onnx_path}") - print(f"export torchscript: {ts_path}") - except Exception as e: - print(f"[warn] export failed: {e}") - - -if __name__ == "__main__": - main() - diff --git a/FishMeasure/segmentation/visualize_yolo_seg_labels.py b/FishMeasure/segmentation/visualize_yolo_seg_labels.py deleted file mode 100755 index 4c0e94e..0000000 --- a/FishMeasure/segmentation/visualize_yolo_seg_labels.py +++ /dev/null @@ -1,215 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -""" -Visualize YOLOv8-seg labels by drawing polygon masks on images. - -This helps verify that label conversion (e.g., from Labelme JSON to YOLO .txt) is correct. - -Example: - python3 segmentation/visualize_yolo_seg_labels.py \ - --dataset ./datasets/fish_body_seg_filtered \ - --output ./visualizations/yolo_seg_labels \ - --max_images 50 \ - --split train -""" - -from __future__ import annotations - -import argparse -import random -from pathlib import Path -from typing import List, Tuple - -import cv2 -import numpy as np - - -def parse_yolo_seg_label(label_path: Path, img_w: int, img_h: int) -> List[np.ndarray]: - """ - Parse YOLO segmentation label file. - Returns list of polygons (each as Nx2 numpy array in pixel coordinates). - """ - if not label_path.exists(): - return [] - - polygons: List[np.ndarray] = [] - lines = label_path.read_text(encoding="utf-8").strip().split("\n") - - for line in lines: - line = line.strip() - if not line: - continue - - parts = line.split() - if len(parts) < 7: # class_id + at least 3 points (x,y pairs) - continue - - try: - _class_id = int(parts[0]) - coords = [float(x) for x in parts[1:]] - if len(coords) % 2 != 0: - continue - - # Convert normalized [0,1] to pixel coordinates - points = [] - for i in range(0, len(coords), 2): - x_norm = coords[i] - y_norm = coords[i + 1] - x_px = int(x_norm * img_w) - y_px = int(y_norm * img_h) - points.append([x_px, y_px]) - - if len(points) >= 3: - polygons.append(np.array(points, dtype=np.int32)) - except (ValueError, IndexError): - continue - - return polygons - - -def draw_polygons_on_image( - img: np.ndarray, polygons: List[np.ndarray], class_colors: List[Tuple[int, int, int]], alpha: float = 0.5 -) -> np.ndarray: - """ - Draw polygons as semi-transparent masks on image. - Returns a new image with overlays. - """ - overlay = img.copy() - mask = np.zeros(img.shape[:2], dtype=np.uint8) - - for i, poly in enumerate(polygons): - color_idx = i % len(class_colors) - color = class_colors[color_idx] - cv2.fillPoly(mask, [poly], 255) - cv2.fillPoly(overlay, [poly], color) - # Also draw outline - cv2.polylines(overlay, [poly], isClosed=True, color=color, thickness=2) - - # Blend overlay with original - result = cv2.addWeighted(overlay, alpha, img, 1.0 - alpha, 0) - return result - - -def visualize_dataset( - dataset_dir: Path, - output_dir: Path, - split: str = "train", - max_images: int = 50, - class_names: List[str] = None, - alpha: float = 0.5, -) -> None: - """ - Visualize YOLO segmentation labels on images. - """ - dataset_dir = dataset_dir.expanduser().resolve() - if not dataset_dir.exists(): - raise SystemExit(f"dataset_dir not found: {dataset_dir}") - - img_dir = dataset_dir / "images" / split - lbl_dir = dataset_dir / "labels" / split - - if not img_dir.exists(): - raise SystemExit(f"images directory not found: {img_dir}") - if not lbl_dir.exists(): - raise SystemExit(f"labels directory not found: {lbl_dir}") - - # Collect image-label pairs - pairs: List[Tuple[Path, Path]] = [] - for img_path in sorted(img_dir.iterdir()): - if img_path.suffix.lower() not in {".jpg", ".jpeg", ".png", ".bmp"}: - continue - lbl_path = lbl_dir / f"{img_path.stem}.txt" - if lbl_path.exists(): - pairs.append((img_path, lbl_path)) - - if not pairs: - raise SystemExit(f"No image-label pairs found in {split} split") - - # Limit number of images - if max_images > 0 and len(pairs) > max_images: - random.seed(42) - pairs = random.sample(pairs, max_images) - - print(f"Visualizing {len(pairs)} images from {split} split...") - - # Generate colors for classes (BGR format for OpenCV) - if class_names is None: - class_names = ["class0", "class1", "class2"] - colors = [ - (0, 255, 0), # green - (255, 0, 0), # blue - (0, 0, 255), # red - (255, 255, 0), # cyan - (255, 0, 255), # magenta - (0, 255, 255), # yellow - ] - class_colors = colors[: len(class_names)] - - output_dir = output_dir.expanduser().resolve() - output_dir.mkdir(parents=True, exist_ok=True) - - for img_path, lbl_path in pairs: - # Load image - img = cv2.imread(str(img_path)) - if img is None: - print(f"[warn] failed to load: {img_path}") - continue - - h, w = img.shape[:2] - - # Parse labels - polygons = parse_yolo_seg_label(lbl_path, w, h) - - if not polygons: - print(f"[warn] no polygons found in: {lbl_path}") - # Still save the original image for reference - out_path = output_dir / f"{img_path.stem}_no_labels{img_path.suffix}" - cv2.imwrite(str(out_path), img) - continue - - # Draw polygons - vis_img = draw_polygons_on_image(img, polygons, class_colors, alpha=alpha) - - # Save visualization - out_path = output_dir / f"{img_path.stem}_vis{img_path.suffix}" - cv2.imwrite(str(out_path), vis_img) - - print(f"[done] saved {len(pairs)} visualizations to: {output_dir}") - - -def main() -> None: - parser = argparse.ArgumentParser(description="Visualize YOLOv8-seg labels on images") - parser.add_argument("--dataset", type=str, required=True, help="Path to YOLO-seg dataset directory") - parser.add_argument("--output", type=str, required=True, help="Output directory for visualizations") - parser.add_argument( - "--split", type=str, default="train", choices=["train", "val", "test"], help="Dataset split to visualize" - ) - parser.add_argument("--max_images", type=int, default=50, help="Maximum number of images to visualize (0=all)") - parser.add_argument( - "--classes", - type=str, - default="fishbody", - help="Comma-separated class names (for color assignment, e.g. 'fishbody' or 'body,fin,tail')", - ) - parser.add_argument( - "--alpha", type=float, default=0.5, help="Transparency of mask overlay (0.0=transparent, 1.0=opaque)" - ) - args = parser.parse_args() - - class_names = [c.strip() for c in (args.classes or "").split(",") if c.strip()] - if not class_names: - class_names = ["class0"] - - visualize_dataset( - dataset_dir=Path(args.dataset), - output_dir=Path(args.output), - split=args.split, - max_images=args.max_images, - class_names=class_names, - alpha=args.alpha, - ) - - -if __name__ == "__main__": - main() diff --git a/fish_api/app/action_watch_cli.py b/fish_api/app/action_watch_cli.py index a3a4dc8..254fccf 100644 --- a/fish_api/app/action_watch_cli.py +++ b/fish_api/app/action_watch_cli.py @@ -5,12 +5,14 @@ from __future__ import annotations import asyncio import sys +from app.db import init_db from app.services.action_watch import run_action_watch_loop from app.settings import get_settings def main() -> None: s = get_settings() + init_db(s) if s.action_watch_dir is None: print( "未配置监控目录:请在 .env 中设置 ACTION_WATCH_DIR=/你的/mp4/目录", diff --git a/fish_api/app/db.py b/fish_api/app/db.py new file mode 100644 index 0000000..d537b28 --- /dev/null +++ b/fish_api/app/db.py @@ -0,0 +1,421 @@ +"""仅 FastAPI 进程使用 SQLite:落库测量/健康结果与 watch 已处理路径。 + +FishMeasure / FishAction 子进程不连接、不依赖本库;它们只读写各自文件(如 output 下 +weight_prediction.json、临时 pred.json 等),由 fish_api 在子进程结束后读文件并写入本表。 +视频仍使用 measure_output_root、media_root 等原路径。 +""" + +from __future__ import annotations + +import json +import sqlite3 +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Optional, Set, Tuple + +from app.settings import Settings +from app.state import HealthSnapshot, MeasureSnapshot + + +def _connect(path: Path) -> sqlite3.Connection: + path.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(str(path), check_same_thread=False, isolation_level=None) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA foreign_keys=ON") + return conn + + +def init_db(settings: Settings) -> None: + conn = _connect(settings.sqlite_path) + try: + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS measure_snapshots ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + created_at TEXT NOT NULL, + result_json TEXT NOT NULL, + video_left TEXT NOT NULL DEFAULT '', + video_right TEXT NOT NULL DEFAULT '', + error TEXT, + raw_prediction_path TEXT, + source_path TEXT + ); + + CREATE TABLE IF NOT EXISTS health_snapshots ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + created_at TEXT NOT NULL, + behavior_result TEXT NOT NULL DEFAULT '', + health_result TEXT NOT NULL DEFAULT '', + raw_class_en TEXT NOT NULL DEFAULT '', + error TEXT, + source_path TEXT + ); + + CREATE TABLE IF NOT EXISTS watch_processed ( + path TEXT NOT NULL, + kind TEXT NOT NULL CHECK (kind IN ('measure', 'action')), + PRIMARY KEY (path, kind) + ); + + CREATE TABLE IF NOT EXISTS delivery_cursor ( + kind TEXT PRIMARY KEY CHECK (kind IN ('measure', 'health')), + last_delivered_id INTEGER NOT NULL DEFAULT 0 + ); + """ + ) + _ensure_delivery_cursors(conn) + finally: + conn.close() + + +def _ensure_delivery_cursors(conn: sqlite3.Connection) -> None: + """为每条流插入一行游标;首次插入时 last_delivered_id=当前 MAX(id),避免升级后逐条投递历史快照。""" + for kind, table in ( + ("measure", "measure_snapshots"), + ("health", "health_snapshots"), + ): + row = conn.execute( + "SELECT 1 FROM delivery_cursor WHERE kind = ?", (kind,) + ).fetchone() + if row is None: + mid = conn.execute( + f"SELECT COALESCE(MAX(id), 0) FROM {table}" + ).fetchone()[0] + conn.execute( + "INSERT INTO delivery_cursor (kind, last_delivered_id) VALUES (?, ?)", + (kind, int(mid)), + ) + conn.commit() + + +def save_measure_snapshot( + settings: Settings, + snap: MeasureSnapshot, + source_path: Optional[str] = None, +) -> None: + init_db(settings) + conn = _connect(settings.sqlite_path) + try: + ts = ( + snap.updated_at.isoformat() + if snap.updated_at + else datetime.now(timezone.utc).isoformat() + ) + conn.execute( + """ + INSERT INTO measure_snapshots ( + created_at, result_json, video_left, video_right, + error, raw_prediction_path, source_path + ) VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + ts, + json.dumps(snap.result, ensure_ascii=False), + snap.video_left, + snap.video_right, + snap.error, + snap.raw_prediction_path, + source_path, + ), + ) + finally: + conn.close() + + +def save_health_snapshot( + settings: Settings, + snap: HealthSnapshot, + source_path: Optional[str] = None, +) -> None: + init_db(settings) + conn = _connect(settings.sqlite_path) + try: + ts = ( + snap.updated_at.isoformat() + if snap.updated_at + else datetime.now(timezone.utc).isoformat() + ) + conn.execute( + """ + INSERT INTO health_snapshots ( + created_at, behavior_result, health_result, + raw_class_en, error, source_path + ) VALUES (?, ?, ?, ?, ?, ?) + """, + ( + ts, + snap.behavior_result, + snap.health_result, + snap.raw_class_en, + snap.error, + source_path, + ), + ) + finally: + conn.close() + + +def _parse_dt(s: Optional[str]) -> Optional[datetime]: + if not s: + return None + try: + return datetime.fromisoformat(s.replace("Z", "+00:00")) + except ValueError: + return None + + +def get_latest_measure(settings: Settings) -> MeasureSnapshot: + init_db(settings) + conn = _connect(settings.sqlite_path) + try: + row = conn.execute( + """ + SELECT created_at, result_json, video_left, video_right, + error, raw_prediction_path + FROM measure_snapshots + ORDER BY id DESC + LIMIT 1 + """ + ).fetchone() + if row is None: + return MeasureSnapshot(result=[], video_left="", video_right="") + data: Any = json.loads(row["result_json"]) + if not isinstance(data, list): + data = [] + return MeasureSnapshot( + result=data, + video_left=row["video_left"] or "", + video_right=row["video_right"] or "", + updated_at=_parse_dt(row["created_at"]), + error=row["error"], + raw_prediction_path=row["raw_prediction_path"], + ) + finally: + conn.close() + + +def get_latest_health(settings: Settings) -> HealthSnapshot: + init_db(settings) + conn = _connect(settings.sqlite_path) + try: + row = conn.execute( + """ + SELECT created_at, behavior_result, health_result, + raw_class_en, error + FROM health_snapshots + ORDER BY id DESC + LIMIT 1 + """ + ).fetchone() + if row is None: + return HealthSnapshot(behavior_result="", health_result="") + return HealthSnapshot( + behavior_result=row["behavior_result"] or "", + health_result=row["health_result"] or "", + updated_at=_parse_dt(row["created_at"]), + error=row["error"], + raw_class_en=row["raw_class_en"] or "", + ) + finally: + conn.close() + + +def _last_delivered_id( + conn: sqlite3.Connection, kind: str, snapshots_table: str +) -> int: + row = conn.execute( + "SELECT last_delivered_id FROM delivery_cursor WHERE kind = ?", (kind,) + ).fetchone() + if row is not None: + return int(row["last_delivered_id"]) + mid = conn.execute( + f"SELECT COALESCE(MAX(id), 0) FROM {snapshots_table}" + ).fetchone()[0] + conn.execute( + "INSERT INTO delivery_cursor (kind, last_delivered_id) VALUES (?, ?)", + (kind, int(mid)), + ) + return int(mid) + + +def pop_next_measure( + settings: Settings, +) -> Tuple[MeasureSnapshot, bool, Optional[int]]: + """取队首未投递的 measure 快照并推进游标;无未投递时 has_new=False。""" + init_db(settings) + conn = _connect(settings.sqlite_path) + try: + conn.execute("BEGIN IMMEDIATE") + last_id = _last_delivered_id(conn, "measure", "measure_snapshots") + + next_row = conn.execute( + """ + SELECT id, created_at, result_json, video_left, video_right, + error, raw_prediction_path + FROM measure_snapshots + WHERE id > ? + ORDER BY id ASC + LIMIT 1 + """, + (last_id,), + ).fetchone() + + if next_row is None: + conn.commit() + return MeasureSnapshot(result=[], video_left="", video_right=""), False, None + + nid = int(next_row["id"]) + conn.execute( + "UPDATE delivery_cursor SET last_delivered_id = ? WHERE kind = ?", + (nid, "measure"), + ) + conn.commit() + + data: Any = json.loads(next_row["result_json"]) + if not isinstance(data, list): + data = [] + snap = MeasureSnapshot( + result=data, + video_left=next_row["video_left"] or "", + video_right=next_row["video_right"] or "", + updated_at=_parse_dt(next_row["created_at"]), + error=next_row["error"], + raw_prediction_path=next_row["raw_prediction_path"], + ) + return snap, True, nid + except Exception: + conn.rollback() + raise + finally: + conn.close() + + +def pop_next_health(settings: Settings) -> Tuple[HealthSnapshot, bool, Optional[int]]: + """取队首未投递的 health 快照并推进游标;无未投递时 has_new=False。""" + init_db(settings) + conn = _connect(settings.sqlite_path) + try: + conn.execute("BEGIN IMMEDIATE") + last_id = _last_delivered_id(conn, "health", "health_snapshots") + + next_row = conn.execute( + """ + SELECT id, created_at, behavior_result, health_result, + raw_class_en, error + FROM health_snapshots + WHERE id > ? + ORDER BY id ASC + LIMIT 1 + """, + (last_id,), + ).fetchone() + + if next_row is None: + conn.commit() + return HealthSnapshot(behavior_result="", health_result=""), False, None + + nid = int(next_row["id"]) + conn.execute( + "UPDATE delivery_cursor SET last_delivered_id = ? WHERE kind = ?", + (nid, "health"), + ) + conn.commit() + + snap = HealthSnapshot( + behavior_result=next_row["behavior_result"] or "", + health_result=next_row["health_result"] or "", + updated_at=_parse_dt(next_row["created_at"]), + error=next_row["error"], + raw_class_en=next_row["raw_class_en"] or "", + ) + return snap, True, nid + except Exception: + conn.rollback() + raise + finally: + conn.close() + + +def _load_json_processed_set(path: Path) -> Set[str]: + if not path.is_file(): + return set() + try: + with open(path, encoding="utf-8") as f: + data: Any = json.load(f) + if isinstance(data, list): + return set(str(x) for x in data) + if isinstance(data, dict) and "processed" in data: + return set(str(x) for x in data["processed"]) + except (json.JSONDecodeError, OSError): + pass + return set() + + +def load_watch_processed(settings: Settings, state_file: Path, kind: str) -> Set[str]: + """从 SQLite 读取已处理路径;若存在旧版 JSON 状态文件则合并导入(幂等)。""" + assert kind in ("measure", "action") + init_db(settings) + conn = _connect(settings.sqlite_path) + try: + for p in _load_json_processed_set(state_file): + conn.execute( + "INSERT OR IGNORE INTO watch_processed (path, kind) VALUES (?, ?)", + (p, kind), + ) + conn.commit() + cur = conn.execute( + "SELECT path FROM watch_processed WHERE kind = ?", (kind,) + ) + return {r[0] for r in cur} + finally: + conn.close() + + +def add_watch_processed(settings: Settings, path: str, kind: str) -> None: + assert kind in ("measure", "action") + init_db(settings) + conn = _connect(settings.sqlite_path) + try: + conn.execute( + "INSERT OR IGNORE INTO watch_processed (path, kind) VALUES (?, ?)", + (path, kind), + ) + conn.commit() + finally: + conn.close() + + +def remove_sqlite_database_files(settings: Settings) -> None: + """删除 SQLite 主库及 WAL/SHM 副文件;不存在则忽略。下次 init_db 会重建空库。""" + base = settings.sqlite_path.resolve() + for p in (base, Path(str(base) + "-wal"), Path(str(base) + "-shm")): + try: + if p.is_file(): + p.unlink() + except OSError: + pass + + +def clear_watch_cache_and_snapshots(settings: Settings) -> None: + """清空 watch 已处理路径与对应快照,便于重新跑推理(与 measure/action_watch 的 use_state_file 开关一致)。""" + init_db(settings) + conn = _connect(settings.sqlite_path) + try: + if settings.measure_watch_use_state_file: + conn.execute("DELETE FROM watch_processed WHERE kind = ?", ("measure",)) + conn.execute("DELETE FROM measure_snapshots") + conn.execute( + "UPDATE delivery_cursor SET last_delivered_id = 0 WHERE kind = ?", + ("measure",), + ) + if settings.action_watch_use_state_file: + conn.execute("DELETE FROM watch_processed WHERE kind = ?", ("action",)) + conn.execute("DELETE FROM health_snapshots") + conn.execute( + "UPDATE delivery_cursor SET last_delivered_id = 0 WHERE kind = ?", + ("health",), + ) + conn.commit() + finally: + conn.close() diff --git a/fish_api/app/logging_config.py b/fish_api/app/logging_config.py new file mode 100644 index 0000000..fac46f6 --- /dev/null +++ b/fish_api/app/logging_config.py @@ -0,0 +1,82 @@ +"""将 loguru 与标准 logging 桥接,使 uvicorn / FastAPI / Starlette 日志走同一套格式。""" + +from __future__ import annotations + +import json +import logging +import os +import sys +from typing import Any + +from loguru import logger + +_loguru_sink_configured = False + + +def format_json_pretty( + data: Any, + *, + indent: int = 2, + max_chars: int | None = 24_000, +) -> str: + """将对象格式化为带缩进、保留中文的 JSON 字符串,供 loguru 多行输出。 + + ``max_chars`` 用于避免单条日志过大;超长时截断并标注。 + """ + try: + s = json.dumps(data, ensure_ascii=False, indent=indent, default=str) + except (TypeError, ValueError): + return repr(data) + if max_chars is not None and len(s) > max_chars: + return s[:max_chars] + "\n... (truncated, max_chars=" + str(max_chars) + ")" + return s + + +class InterceptHandler(logging.Handler): + """将标准 logging 记录转发到 loguru。""" + + def emit(self, record: logging.LogRecord) -> None: + try: + level: str | int = logger.level(record.levelname).name + except ValueError: + level = record.levelno + logger.opt(depth=6, exception=record.exc_info).log(level, record.getMessage()) + + +def setup_logging() -> None: + """配置 loguru sink,并把 uvicorn / fastapi 等 logger 接到 loguru。 + + 可在模块 import 时与 lifespan 启动时各调用一次;后者用于覆盖 uvicorn 启动后写入的 handler。 + """ + global _loguru_sink_configured + level = os.environ.get("LOG_LEVEL", "INFO").upper() + if not _loguru_sink_configured: + logger.remove() + logger.add( + sys.stderr, + level=level, + colorize=sys.stderr.isatty(), + format=( + "{time:YYYY-MM-DD HH:mm:ss.SSS} | " + "{level: <8} | " + "{message}" + ), + ) + _loguru_sink_configured = True + + intercept = InterceptHandler() + logging.root.handlers = [intercept] + logging.root.setLevel(logging.DEBUG) + + for name in ( + "uvicorn", + "uvicorn.error", + "uvicorn.access", + "uvicorn.asgi", + "fastapi", + "starlette", + "starlette.requests", + ): + lg = logging.getLogger(name) + lg.handlers.clear() + lg.propagate = True diff --git a/fish_api/app/main.py b/fish_api/app/main.py index 76febdd..dded7e2 100644 --- a/fish_api/app/main.py +++ b/fish_api/app/main.py @@ -6,15 +6,21 @@ from contextlib import asynccontextmanager from fastapi import FastAPI from fastapi.staticfiles import StaticFiles +from app.logging_config import setup_logging +from app.db import init_db from app.routers import biomass, ingest from app.services.action_watch import run_action_watch_loop from app.services.measure_watch import run_measure_watch_loop from app.settings import get_settings +setup_logging() + @asynccontextmanager async def lifespan(app: FastAPI): + setup_logging() s = get_settings() + init_db(s) s.media_root.mkdir(parents=True, exist_ok=True) s.stream_tmp_dir.mkdir(parents=True, exist_ok=True) tasks: list[asyncio.Task[None]] = [] diff --git a/fish_api/app/routers/biomass.py b/fish_api/app/routers/biomass.py index cdb9a6a..20260b4 100644 --- a/fish_api/app/routers/biomass.py +++ b/fish_api/app/routers/biomass.py @@ -1,55 +1,101 @@ from __future__ import annotations -from fastapi import APIRouter +from fastapi import APIRouter, Depends +from starlette.responses import JSONResponse -from app.state import app_state +from app.db import pop_next_health, pop_next_measure +from app.settings import Settings, get_settings router = APIRouter(prefix="/api/v1/biomass", tags=["biomass"]) +# 是否有新快照被本次 GET 消费(1/0);body 保持与客户端约定字段一致,不写入 has_new。 +HEADER_BIOMASS_NEW = "X-Fish-Biomass-New" + + +def _new_headers(has_new: bool) -> dict[str, str]: + return {HEADER_BIOMASS_NEW: "1" if has_new else "0"} + @router.get("/real/camera/") -async def get_real_camera(): - """双目实时结果(轮询最新一次 FishMeasure 完成快照)。""" - m = app_state.last_measure - if m.error: - return { - "code": 500, - "msg": m.error, - "data": { - "result": [], - "video_left": "", - "video_right": "", +async def get_real_camera(settings: Settings = Depends(get_settings)): + """双目实时结果:每次 GET 投递下一条未消费的 FishMeasure 快照(SQLite 游标)。""" + m, has_new, _ = pop_next_measure(settings) + if not has_new: + return JSONResponse( + content={ + "code": 200, + "msg": "成功", + "data": { + "result": [], + "video_left": "", + "video_right": "", + }, + }, + headers=_new_headers(False), + ) + if m.error: + return JSONResponse( + content={ + "code": 500, + "msg": m.error, + "data": { + "result": [], + "video_left": "", + "video_right": "", + }, + }, + headers=_new_headers(True), + ) + return JSONResponse( + content={ + "code": 200, + "msg": "成功", + "data": { + "result": m.result, + "video_left": m.video_left, + "video_right": m.video_right, }, - } - return { - "code": 200, - "msg": "成功", - "data": { - "result": m.result, - "video_left": m.video_left, - "video_right": m.video_right, }, - } + headers=_new_headers(True), + ) @router.get("/health/result/") -async def get_health_result(): - """行为 / 健康结果(轮询最新一次 FishAction 完成快照)。""" - h = app_state.last_health - if h.error: - return { - "code": 500, - "msg": h.error, - "data": { - "behavior_result": "", - "health_result": "", +async def get_health_result(settings: Settings = Depends(get_settings)): + """行为 / 健康结果:每次 GET 投递下一条未消费的 FishAction 快照(SQLite 游标)。""" + h, has_new, _ = pop_next_health(settings) + if not has_new: + return JSONResponse( + content={ + "code": 200, + "msg": "成功", + "data": { + "behavior_result": "", + "health_result": "", + }, + }, + headers=_new_headers(False), + ) + if h.error: + return JSONResponse( + content={ + "code": 500, + "msg": h.error, + "data": { + "behavior_result": "", + "health_result": "", + }, + }, + headers=_new_headers(True), + ) + return JSONResponse( + content={ + "code": 200, + "msg": "成功", + "data": { + "behavior_result": h.behavior_result, + "health_result": h.health_result, }, - } - return { - "code": 200, - "msg": "成功", - "data": { - "behavior_result": h.behavior_result, - "health_result": h.health_result, }, - } + headers=_new_headers(True), + ) diff --git a/fish_api/app/routers/ingest.py b/fish_api/app/routers/ingest.py index 35d235a..39794fe 100644 --- a/fish_api/app/routers/ingest.py +++ b/fish_api/app/routers/ingest.py @@ -5,6 +5,7 @@ from pathlib import Path from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, Response +from app.db import save_health_snapshot, save_measure_snapshot from app.deps import require_ingest_auth from app.services import action as action_svc from app.services import measure as measure_svc @@ -27,14 +28,18 @@ async def _measure_job_serial(svo_path: Path, settings: Settings) -> None: snap = await asyncio.to_thread( measure_svc.run_full_measure, svo_path, settings ) - app_state.last_measure = snap + save_measure_snapshot(settings, snap, source_path=str(svo_path.resolve())) app_state.measure_status = "idle" except Exception as e: - app_state.last_measure = MeasureSnapshot( - result=[], - video_left="", - video_right="", - error=str(e), + save_measure_snapshot( + settings, + MeasureSnapshot( + result=[], + video_left="", + video_right="", + error=str(e), + ), + source_path=str(svo_path.resolve()), ) app_state.measure_status = "error" @@ -46,13 +51,17 @@ async def _action_job_serial(mp4_path: Path, settings: Settings) -> None: snap = await asyncio.to_thread( action_svc.run_full_action, mp4_path, settings ) - app_state.last_health = snap + save_health_snapshot(settings, snap, source_path=str(mp4_path.resolve())) app_state.action_status = "idle" except Exception as e: - app_state.last_health = HealthSnapshot( - behavior_result="", - health_result="", - error=str(e), + save_health_snapshot( + settings, + HealthSnapshot( + behavior_result="", + health_result="", + error=str(e), + ), + source_path=str(mp4_path.resolve()), ) app_state.action_status = "error" diff --git a/fish_api/app/services/action.py b/fish_api/app/services/action.py index 7cfcae4..135d307 100644 --- a/fish_api/app/services/action.py +++ b/fish_api/app/services/action.py @@ -2,14 +2,16 @@ from __future__ import annotations import json import os -import subprocess import sys import tempfile from datetime import datetime, timezone from pathlib import Path +from app.logging_config import format_json_pretty from app.settings import Settings from app.state import HealthSnapshot +from app.subprocess_run import run_subprocess_with_log +from loguru import logger BEHAVIOR_EN_TO_ZH = { "feeding": "吃饵", @@ -67,15 +69,14 @@ def run_action_subprocess(mp4_path: Path, settings: Settings) -> str: "--log_interval", "0", ] - proc = subprocess.run( + proc = run_subprocess_with_log( cmd, cwd=str(settings.fish_action_root), env=os.environ.copy(), - capture_output=True, - text=True, + log_name="FishAction", ) if proc.returncode != 0: - err = (proc.stderr or "") + (proc.stdout or "") + err = proc.stdout or "" raise RuntimeError( f"predict_video_x3d_3class.py failed ({proc.returncode}): {err[-4000:]}" ) @@ -86,6 +87,10 @@ def run_action_subprocess(mp4_path: Path, settings: Settings) -> str: if not rows: raise RuntimeError("Empty prediction JSON") pred_en = str(rows[0].get("pred_3class", "")).strip().lower() + logger.info( + "[FishAction] prediction row:\n{}", + format_json_pretty(rows[0]), + ) if pred_en not in BEHAVIOR_EN_TO_ZH: raise RuntimeError(f"Unexpected pred_3class: {pred_en!r}") return pred_en @@ -95,9 +100,17 @@ def run_action_subprocess(mp4_path: Path, settings: Settings) -> str: def run_full_action(mp4_path: Path, settings: Settings) -> HealthSnapshot: + logger.info("[FishAction] start mp4={}", mp4_path.resolve()) pred_en = run_action_subprocess(mp4_path, settings) zh = BEHAVIOR_EN_TO_ZH[pred_en] health = behavior_to_health(pred_en) + logger.info( + "[FishAction] done mp4={} pred_3class={} behavior_zh={} health={}", + mp4_path.name, + pred_en, + zh, + health, + ) return HealthSnapshot( behavior_result=zh, health_result=health, diff --git a/fish_api/app/services/action_watch.py b/fish_api/app/services/action_watch.py index 272ec76..26222bc 100644 --- a/fish_api/app/services/action_watch.py +++ b/fish_api/app/services/action_watch.py @@ -1,14 +1,20 @@ from __future__ import annotations import asyncio -import json -import traceback from pathlib import Path -from typing import Any, Dict, Set +from typing import Dict, Set +from loguru import logger + +from app.db import add_watch_processed, load_watch_processed, save_health_snapshot from app.services import action as action_svc from app.settings import Settings from app.state import HealthSnapshot, app_state +from app.watch_idle import IdleWatchWarnState, idle_warn_interval_sec, maybe_warn_idle_watch + +_ACTION_IDLE_WARN_INTERVAL_SEC = idle_warn_interval_sec( + "FISH_ACTION_WATCH_IDLE_WARN_INTERVAL_SEC" +) def _state_path(settings: Settings) -> Path: @@ -18,29 +24,6 @@ def _state_path(settings: Settings) -> Path: return settings.action_watch_dir / ".fishaction_watch_processed.json" -def load_processed(path: Path) -> Set[str]: - if not path.is_file(): - return set() - try: - with open(path, encoding="utf-8") as f: - data: Any = json.load(f) - if isinstance(data, list): - return set(str(x) for x in data) - if isinstance(data, dict) and "processed" in data: - return set(str(x) for x in data["processed"]) - except (json.JSONDecodeError, OSError): - pass - return set() - - -def save_processed(path: Path, processed: Set[str]) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - tmp = path.with_suffix(path.suffix + ".tmp") - with open(tmp, "w", encoding="utf-8") as f: - json.dump(sorted(processed), f, indent=0, ensure_ascii=False) - tmp.replace(path) - - def iter_mp4(watch_dir: Path, recursive: bool) -> list[Path]: if recursive: return sorted( @@ -64,27 +47,30 @@ async def _run_inference_and_state( key = str(mp4.resolve()) if key in processed: return - print(f"[action-watch] inference: {mp4}", flush=True) + logger.info("[action-watch] inference: {}", mp4) async with app_state.action_lock: app_state.action_status = "running" try: snap = await asyncio.to_thread(action_svc.run_full_action, mp4, settings) - app_state.last_health = snap + save_health_snapshot(settings, snap, source_path=key) app_state.action_status = "idle" processed.add(key) if settings.action_watch_use_state_file: - save_processed(state_file, processed) + add_watch_processed(settings, key, "action") pred = (snap.raw_class_en or "").strip() - print(f"[action-watch] done: {mp4.name} -> {pred}", flush=True) + logger.info("[action-watch] done: {} -> {}", mp4.name, pred) except Exception as e: - app_state.last_health = HealthSnapshot( - behavior_result="", - health_result="", - error=str(e), + save_health_snapshot( + settings, + HealthSnapshot( + behavior_result="", + health_result="", + error=str(e), + ), + source_path=key, ) app_state.action_status = "error" - print(f"[action-watch] error on {mp4}: {e}", flush=True) - traceback.print_exc() + logger.exception("[action-watch] error on {}: {}", mp4, e) raise @@ -136,24 +122,36 @@ async def run_action_watch_loop(settings: Settings) -> None: assert settings.action_watch_dir is not None wd = settings.action_watch_dir if not wd.is_dir(): - print(f"[action-watch] skip: not a directory: {wd}", flush=True) + logger.warning("[action-watch] skip: not a directory: {}", wd) return state_file = _state_path(settings) processed: Set[str] = ( - load_processed(state_file) if settings.action_watch_use_state_file else set() + load_watch_processed(settings, state_file, "action") + if settings.action_watch_use_state_file + else set() ) stability: Dict[str, tuple[int, int]] = {} - print( - f"[action-watch] watching {wd} " - f"(poll={settings.action_watch_poll_interval}s, " - f"stable_polls={settings.action_watch_stable_polls}, " - f"state={'on' if settings.action_watch_use_state_file else 'off'} " - f"{state_file if settings.action_watch_use_state_file else ''})", - flush=True, + logger.info( + "[action-watch] watching {} (poll={}s, stable_polls={}, state={} {})", + wd, + settings.action_watch_poll_interval, + settings.action_watch_stable_polls, + "on" if settings.action_watch_use_state_file else "off", + state_file if settings.action_watch_use_state_file else "", ) + idle_warn_state = IdleWatchWarnState() while True: - await watch_tick(settings, processed, stability, state_file) + did = await watch_tick(settings, processed, stability, state_file) + maybe_warn_idle_watch( + did_work=did, + log_tag="action-watch", + algo_name="FishAction", + idle_hint="目录内无 .mp4、已全部处理完毕,或文件尚未达到稳定轮询次数", + watch_dir=wd, + state=idle_warn_state, + interval_sec=_ACTION_IDLE_WARN_INTERVAL_SEC, + ) await asyncio.sleep(max(settings.action_watch_poll_interval, 0.1)) diff --git a/fish_api/app/services/measure.py b/fish_api/app/services/measure.py index 2bc4390..59c9606 100644 --- a/fish_api/app/services/measure.py +++ b/fish_api/app/services/measure.py @@ -7,16 +7,61 @@ import subprocess import sys from datetime import date, datetime, timezone from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple +from app.logging_config import format_json_pretty from app.settings import Settings from app.state import MeasureSnapshot +from app.subprocess_run import run_subprocess_with_log +from loguru import logger def _py_exe(settings: Settings) -> str: return settings.python_fish_measure or sys.executable +def _predict_weigth_from_svo2_extra_args(settings: Settings) -> List[str]: + """Flags aligned with FishMeasure/predict_weigth_from_svo2.py CLI.""" + out: List[str] = [] + if settings.predict_filter_pointcloud: + out.append("--filter-pointcloud") + if settings.predict_use_density_filter: + out.append("--use-density-filter") + if settings.predict_use_clustering_filter: + out.append("--use-clustering-filter") + if ( + settings.predict_use_pointcloud_classifier + and settings.predict_pointcloud_classifier + and Path(settings.predict_pointcloud_classifier).is_file() + ): + out.extend( + [ + "--pointcloud-classifier", + settings.predict_pointcloud_classifier, + "--use-pointcloud-classifier", + "--pointcloud-classifier-threshold", + str(settings.predict_pointcloud_classifier_threshold), + ] + ) + if settings.predict_use_flatness_filter: + out.append("--use-flatness-filter") + out.extend(["--flatness-threshold", str(settings.predict_flatness_threshold)]) + out.extend(["--weight-top-k", str(settings.measure_weight_top_k)]) + if settings.measure_weight_top_by_length: + out.append("--weight-top-by-length") + else: + out.append("--no-weight-top-by-length") + if settings.predict_fish_video_weight_overlay: + out.extend( + [ + "--fish-video-weight-overlay", + "--minute-interval-sec", + str(settings.predict_minute_interval_sec), + ] + ) + return out + + def run_measure_subprocess(svo_path: Path, settings: Settings) -> None: script = settings.fish_measure_root / "predict_weigth_from_svo2.py" if not script.is_file(): @@ -46,16 +91,16 @@ def run_measure_subprocess(svo_path: Path, settings: Settings) -> None: "--frame-stride", str(settings.predict_frame_stride), ] + cmd.extend(_predict_weigth_from_svo2_extra_args(settings)) - proc = subprocess.run( + proc = run_subprocess_with_log( cmd, cwd=str(settings.fish_measure_root), env=os.environ.copy(), - capture_output=True, - text=True, + log_name="FishMeasure", ) if proc.returncode != 0: - err = (proc.stderr or "") + (proc.stdout or "") + err = proc.stdout or "" raise RuntimeError( f"predict_weigth_from_svo2.py failed ({proc.returncode}): {err[-4000:]}" ) @@ -171,9 +216,28 @@ def build_measure_snapshot(svo_path: Path, settings: Settings) -> MeasureSnapsho "date": today, } + logger.info( + "[FishMeasure] parsed {}\navg_length_mm={} avg_weight_g={}\nweight_summary:\n{}", + svo_path.name, + length_mm, + weight_g, + format_json_pretty(summary if summary else {}), + ) + logger.info( + "[FishMeasure] API result_item:\n{}", + format_json_pretty(result_item), + ) + out_dir = Path(data.get("output_dir", settings.measure_output_root / svo_path.stem)) lv, rv = _find_preview_videos(out_dir) v_left, v_right = _publish_media(lv, rv, settings) + logger.info( + "[FishMeasure] media preview_paths={} {} | published_left={} published_right={}", + lv, + rv, + v_left or "(none)", + v_right or "(none)", + ) return MeasureSnapshot( result=[result_item], @@ -187,5 +251,8 @@ def build_measure_snapshot(svo_path: Path, settings: Settings) -> MeasureSnapsho def run_full_measure(svo_path: Path, settings: Settings) -> MeasureSnapshot: + logger.info("[FishMeasure] start svo={}", svo_path.resolve()) run_measure_subprocess(svo_path, settings) - return build_measure_snapshot(svo_path, settings) + snap = build_measure_snapshot(svo_path, settings) + logger.info("[FishMeasure] done svo={} result_len={}", svo_path.name, len(snap.result)) + return snap diff --git a/fish_api/app/services/measure_watch.py b/fish_api/app/services/measure_watch.py index 89c1f53..0cd1c0a 100644 --- a/fish_api/app/services/measure_watch.py +++ b/fish_api/app/services/measure_watch.py @@ -1,16 +1,22 @@ -"""后台轮询目录中的 .svo2,跑 FishMeasure,写入 app_state.last_measure(与 ingest 共用状态)。""" +"""后台轮询目录中的 .svo2,跑 FishMeasure,写入 SQLite(与 ingest 共用)。""" from __future__ import annotations import asyncio -import traceback from pathlib import Path from typing import Dict, Set +from loguru import logger + +from app.db import add_watch_processed, load_watch_processed, save_measure_snapshot from app.services import measure as measure_svc -from app.services.action_watch import load_processed, save_processed from app.settings import Settings from app.state import MeasureSnapshot, app_state +from app.watch_idle import IdleWatchWarnState, idle_warn_interval_sec, maybe_warn_idle_watch + +_MEASURE_IDLE_WARN_INTERVAL_SEC = idle_warn_interval_sec( + "FISH_MEASURE_WATCH_IDLE_WARN_INTERVAL_SEC" +) def _state_path(settings: Settings) -> Path: @@ -43,29 +49,32 @@ async def _run_measure_and_state( key = str(svo.resolve()) if key in processed: return - print(f"[measure-watch] inference: {svo}", flush=True) + logger.info("[measure-watch] inference: {}", svo) async with app_state.measure_lock: app_state.measure_status = "running" try: snap = await asyncio.to_thread(measure_svc.run_full_measure, svo, settings) - app_state.last_measure = snap + save_measure_snapshot(settings, snap, source_path=key) app_state.measure_status = "idle" processed.add(key) if settings.measure_watch_use_state_file: - save_processed(state_file, processed) + add_watch_processed(settings, key, "measure") r0 = snap.result[0] if snap.result else {} w = r0.get("weight", "") - print(f"[measure-watch] done: {svo.name} weight={w!r}", flush=True) + logger.info("[measure-watch] done: {} weight={!r}", svo.name, w) except Exception as e: - app_state.last_measure = MeasureSnapshot( - result=[], - video_left="", - video_right="", - error=str(e), + save_measure_snapshot( + settings, + MeasureSnapshot( + result=[], + video_left="", + video_right="", + error=str(e), + ), + source_path=key, ) app_state.measure_status = "error" - print(f"[measure-watch] error on {svo}: {e}", flush=True) - traceback.print_exc() + logger.exception("[measure-watch] error on {}: {}", svo, e) raise @@ -116,24 +125,36 @@ async def run_measure_watch_loop(settings: Settings) -> None: assert settings.measure_watch_dir is not None wd = settings.measure_watch_dir if not wd.is_dir(): - print(f"[measure-watch] skip: not a directory: {wd}", flush=True) + logger.warning("[measure-watch] skip: not a directory: {}", wd) return state_file = _state_path(settings) processed: Set[str] = ( - load_processed(state_file) if settings.measure_watch_use_state_file else set() + load_watch_processed(settings, state_file, "measure") + if settings.measure_watch_use_state_file + else set() ) stability: Dict[str, tuple[int, int]] = {} - print( - f"[measure-watch] watching {wd} " - f"(poll={settings.measure_watch_poll_interval}s, " - f"stable_polls={settings.measure_watch_stable_polls}, " - f"state={'on' if settings.measure_watch_use_state_file else 'off'} " - f"{state_file if settings.measure_watch_use_state_file else ''})", - flush=True, + logger.info( + "[measure-watch] watching {} (poll={}s, stable_polls={}, state={} {})", + wd, + settings.measure_watch_poll_interval, + settings.measure_watch_stable_polls, + "on" if settings.measure_watch_use_state_file else "off", + state_file if settings.measure_watch_use_state_file else "", ) + idle_warn_state = IdleWatchWarnState() while True: - await watch_tick(settings, processed, stability, state_file) + did = await watch_tick(settings, processed, stability, state_file) + maybe_warn_idle_watch( + did_work=did, + log_tag="measure-watch", + algo_name="FishMeasure", + idle_hint="目录内无 .svo2、已全部处理完毕,或文件尚未达到稳定轮询次数", + watch_dir=wd, + state=idle_warn_state, + interval_sec=_MEASURE_IDLE_WARN_INTERVAL_SEC, + ) await asyncio.sleep(max(settings.measure_watch_poll_interval, 0.1)) diff --git a/fish_api/app/settings.py b/fish_api/app/settings.py index 3641781..be1631b 100644 --- a/fish_api/app/settings.py +++ b/fish_api/app/settings.py @@ -21,6 +21,10 @@ def _default_media_root() -> Path: return fish_repo_root() / "fish_api" / ".data" / "media" +def _default_sqlite_path() -> Path: + return fish_repo_root() / "fish_api" / ".data" / "app.db" + + class Settings(BaseSettings): model_config = SettingsConfigDict( env_file=".env", @@ -34,6 +38,7 @@ class Settings(BaseSettings): stream_tmp_dir: Path = Field(default_factory=_default_stream_tmp) media_root: Path = Field(default_factory=_default_media_root) + sqlite_path: Path = Field(default_factory=_default_sqlite_path) fish_measure_root: Path = fish_repo_root() / "FishMeasure" fish_action_root: Path = fish_repo_root() / "FishAction" @@ -52,12 +57,28 @@ class Settings(BaseSettings): predict_max_frames: int = 0 predict_frame_stride: int = 1 + #: 传给 predict_weigth_from_svo2.py 的点云/权重选项(与命令行一致,可用 .env 覆盖) + predict_filter_pointcloud: bool = True + predict_use_density_filter: bool = True + predict_use_clustering_filter: bool = False + #: 留空则在 _default_paths 中设为 FishMeasure 下默认 PointNet++ 权重(若文件存在) + predict_pointcloud_classifier: Optional[str] = None + predict_use_pointcloud_classifier: bool = True + predict_pointcloud_classifier_threshold: float = 0.7 + predict_use_flatness_filter: bool = True + predict_flatness_threshold: float = 55.0 + measure_weight_top_k: int = 5 + measure_weight_top_by_length: bool = True + #: 为 True 时 fish_video 内联 DGCNN + 预览叠加(更重;需 fish_video 已支持) + predict_fish_video_weight_overlay: bool = False + predict_minute_interval_sec: float = 60.0 + action_checkpoint: Optional[str] = None action_clips_per_video: int = 8 action_batch_size: int = 4 action_num_workers: int = 2 - #: 非空时由 fish_api 在后台持续扫描该目录中的新 MP4 并跑 FishAction(与 ingest 共用 app_state) + #: 非空时由 fish_api 在后台持续扫描该目录中的新 MP4 并跑 FishAction(与 ingest 共用 SQLite 最新结果) action_watch_dir: Optional[Path] = None action_watch_poll_interval: float = Field(default=2.0, ge=0.1) action_watch_stable_polls: int = Field(default=3, ge=1) @@ -66,7 +87,7 @@ class Settings(BaseSettings): action_watch_state_file: Optional[Path] = None action_watch_use_state_file: bool = True - #: 非空时后台持续扫描该目录中的新 .svo2 并跑 FishMeasure(与 ingest 共用 app_state) + #: 非空时后台持续扫描该目录中的新 .svo2 并跑 FishMeasure(与 ingest 共用 SQLite 最新结果) measure_watch_dir: Optional[Path] = None measure_watch_poll_interval: float = Field(default=2.0, ge=0.1) measure_watch_stable_polls: int = Field(default=3, ge=1) @@ -117,6 +138,19 @@ class Settings(BaseSettings): "action_checkpoint", str(self.fish_action_root / "checkpoints/ptv_x3d_m/checkpoint_best.pt"), ) + if not self.predict_pointcloud_classifier: + _pc = ( + self.fish_measure_root + / "pointcloud_classifier" + / "Pointnet_Pointnet2_pytorch" + / "log" + / "classification" + / "fish_pointnet2_finetune" + / "checkpoints" + / "best_model.pth" + ) + if _pc.is_file(): + object.__setattr__(self, "predict_pointcloud_classifier", str(_pc)) return self diff --git a/fish_api/app/state.py b/fish_api/app/state.py index 94b6b8c..015b957 100644 --- a/fish_api/app/state.py +++ b/fish_api/app/state.py @@ -2,8 +2,8 @@ from __future__ import annotations import asyncio from dataclasses import dataclass, field -from datetime import datetime, timezone -from typing import Any, List, Optional +from datetime import datetime +from typing import List, Optional @dataclass @@ -30,17 +30,7 @@ class AppState: measure_lock: asyncio.Lock = field(default_factory=asyncio.Lock) action_lock: asyncio.Lock = field(default_factory=asyncio.Lock) - last_measure: MeasureSnapshot = field( - default_factory=lambda: MeasureSnapshot(result=[], video_left="", video_right="") - ) - last_health: HealthSnapshot = field( - default_factory=lambda: HealthSnapshot( - behavior_result="", - health_result="", - ) - ) - - # job status for optional polling + # job status for optional polling(业务结果见 SQLite) measure_status: str = "idle" action_status: str = "idle" diff --git a/fish_api/app/subprocess_run.py b/fish_api/app/subprocess_run.py new file mode 100644 index 0000000..c369aba --- /dev/null +++ b/fish_api/app/subprocess_run.py @@ -0,0 +1,41 @@ +"""子进程运行并把 stdout/stderr 流式写入 loguru,便于查看 FishMeasure / FishAction 中间输出。""" + +from __future__ import annotations + +import os +import subprocess +from typing import Dict, List, Optional + +from loguru import logger + + +def run_subprocess_with_log( + cmd: List[str], + *, + cwd: str, + env: Optional[Dict[str, str]] = None, + log_name: str, +) -> subprocess.CompletedProcess[str]: + """运行子进程,合并 stderr 到 stdout,按行输出到 loguru。 + + 返回 CompletedProcess,stdout 为完整输出,便于失败时拼进异常信息。 + """ + proc = subprocess.Popen( + cmd, + cwd=cwd, + env=env if env is not None else os.environ.copy(), + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + lines: List[str] = [] + if proc.stdout is not None: + for line in proc.stdout: + lines.append(line) + s = line.rstrip() + if s: + logger.info("[{}] {}", log_name, s) + rc = proc.wait() + out = "".join(lines) + return subprocess.CompletedProcess(cmd, rc, out, "") diff --git a/fish_api/app/watch_idle.py b/fish_api/app/watch_idle.py new file mode 100644 index 0000000..fa29ec3 --- /dev/null +++ b/fish_api/app/watch_idle.py @@ -0,0 +1,56 @@ +"""后台 watch 无新任务时的限流 warning(FishMeasure / FishAction 共用)。""" + +from __future__ import annotations + +import os +import time +from pathlib import Path + +from loguru import logger + + +def idle_warn_interval_sec(override_env: str) -> float: + """优先读 override_env,否则 FISH_WATCH_IDLE_WARN_INTERVAL_SEC,默认 120。""" + v = os.environ.get(override_env) + if v is not None and str(v).strip() != "": + return float(v) + return float(os.environ.get("FISH_WATCH_IDLE_WARN_INTERVAL_SEC", "120")) + + +class IdleWatchWarnState: + """记录上次告警时间;有推理成功时重置计时。""" + + __slots__ = ("_last_mono",) + + def __init__(self) -> None: + self._last_mono = time.monotonic() + + def on_did_work(self) -> None: + self._last_mono = time.monotonic() + + +def maybe_warn_idle_watch( + *, + did_work: bool, + log_tag: str, + algo_name: str, + idle_hint: str, + watch_dir: Path, + state: IdleWatchWarnState, + interval_sec: float, +) -> None: + """本轮未跑推理时按 interval 限流打 warning;interval<=0 则每轮无任务都打。""" + if did_work: + state.on_did_work() + return + now = time.monotonic() + if interval_sec > 0 and now - state._last_mono < interval_sec: + return + state._last_mono = now + logger.warning( + "[{}] 本轮无新数据可供 {} 处理({}) | dir={}", + log_tag, + algo_name, + idle_hint, + watch_dir, + ) diff --git a/fish_api/pyproject.toml b/fish_api/pyproject.toml index dad87b2..68f0c76 100644 --- a/fish_api/pyproject.toml +++ b/fish_api/pyproject.toml @@ -6,12 +6,13 @@ readme = "README.md" requires-python = ">=3.11" dependencies = [ "fastapi>=0.115.0", + "loguru>=0.7.0", "uvicorn[standard]>=0.32.0", "pydantic-settings>=2.6.0", ] [dependency-groups] -dev = ["httpx>=0.28.1", "loguru>=0.7.0"] +dev = ["httpx>=0.28.1"] [project.scripts] fish-action-watch = "app.action_watch_cli:main" diff --git a/fish_api/start.sh b/fish_api/start.sh new file mode 100644 index 0000000..d6aaea7 --- /dev/null +++ b/fish_api/start.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash +# 一键启动 Fish API(在 fish_api 目录下执行 uvicorn,读取同目录 .env) +# +# bash fish_api/start.sh +# PORT=8001 HOST=0.0.0.0 bash fish_api/start.sh +# +# 首次使用请先:cd fish_api && uv sync +# +set -euo pipefail + +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$DIR" + +export PUBLIC_BASE_URL="${PUBLIC_BASE_URL:-http://127.0.0.1:8000}" +unset PYTHON_FISH_MEASURE PYTHON_FISH_ACTION 2>/dev/null || true + +PORT="${PORT:-8000}" +HOST="${HOST:-0.0.0.0}" + +if command -v uv >/dev/null 2>&1; then + exec uv run uvicorn app.main:app --host "$HOST" --port "$PORT" +else + exec uvicorn app.main:app --host "$HOST" --port "$PORT" +fi diff --git a/fish_api/start_fresh.sh b/fish_api/start_fresh.sh new file mode 100755 index 0000000..d308c1a --- /dev/null +++ b/fish_api/start_fresh.sh @@ -0,0 +1,69 @@ +#!/usr/bin/env bash +# 一键启动 Fish API:删除整个 SQLite 库文件(含 -wal/-shm),并删除旧版 watch JSON 状态文件,再启动服务。 +# +# bash fish_api/start_fresh.sh +# PORT=8001 HOST=0.0.0.0 bash fish_api/start_fresh.sh +# +# 首次使用请先:cd fish_api && uv sync +# +set -euo pipefail + +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$DIR" + +export PUBLIC_BASE_URL="${PUBLIC_BASE_URL:-http://127.0.0.1:8000}" +unset PYTHON_FISH_MEASURE PYTHON_FISH_ACTION 2>/dev/null || true + +if command -v uv >/dev/null 2>&1; then + PY=(uv run python) +else + PY=(python3) +fi + +"${PY[@]}" - <<'PY' +from pathlib import Path + +from app.db import remove_sqlite_database_files +from app.settings import get_settings + +s = get_settings() + + +def _rm(path: Path | None) -> None: + if path is None: + return + try: + if path.is_file(): + path.unlink() + print(f"[start-fresh] removed legacy JSON {path}", flush=True) + except OSError as e: + print(f"[start-fresh] skip {path}: {e}", flush=True) + + +remove_sqlite_database_files(s) +print(f"[start-fresh] removed SQLite database at {s.sqlite_path} (and -wal/-shm if present).", flush=True) + +# 旧版 JSON 若仍存在,启动时会被 load_watch_processed 合并进 SQLite,必须一并删除 +if s.measure_watch_use_state_file: + if s.measure_watch_state_file is not None: + _rm(s.measure_watch_state_file) + elif s.measure_watch_dir is not None: + _rm(s.measure_watch_dir / ".fishmeasure_watch_processed.json") + +if s.action_watch_use_state_file: + if s.action_watch_state_file is not None: + _rm(s.action_watch_state_file) + elif s.action_watch_dir is not None: + _rm(s.action_watch_dir / ".fishaction_watch_processed.json") + +print("[start-fresh] done.", flush=True) +PY + +PORT="${PORT:-8000}" +HOST="${HOST:-0.0.0.0}" + +if command -v uv >/dev/null 2>&1; then + exec uv run uvicorn app.main:app --host "$HOST" --port "$PORT" +else + exec uvicorn app.main:app --host "$HOST" --port "$PORT" +fi diff --git a/scripts/biomass_poller.py b/scripts/biomass_poller.py index 6ecf59e..e0556de 100755 --- a/scripts/biomass_poller.py +++ b/scripts/biomass_poller.py @@ -2,6 +2,8 @@ """ 独立进程:轮询 Fish API 的两个结果接口,用 loguru 输出响应。 +接口每次 GET 会「消费」一条未投递快照;仅当响应头 X-Fish-Biomass-New: 1(或 JSON code!=200)时打日志。 + BIOMASS_API_BASE=http://127.0.0.1:8000 POLL_INTERVAL=5 \\ python scripts/biomass_poller.py @@ -29,24 +31,16 @@ def _fmt_body(resp: httpx.Response) -> Any: return resp.text -_last_camera: dict | None = None -_last_health: dict | None = None - - -def _has_data(body: Any) -> bool: - """非空业务数据:code==200 且 data 里至少有一个非空字段。""" +def _should_log(resp: httpx.Response, body: Any) -> bool: + """有新投递的快照(X-Fish-Biomass-New: 1)或业务错误(JSON code!=200)时打日志。""" if not isinstance(body, dict): - return False + return bool(body) if body.get("code") != 200: - return True # 错误也算"新信息" - data = body.get("data", {}) - if not isinstance(data, dict): - return bool(data) - return any(bool(v) for v in data.values()) + return True + return resp.headers.get("X-Fish-Biomass-New", "").strip() == "1" async def poll_once(client: httpx.AsyncClient, base: str) -> None: - global _last_camera, _last_health base = base.rstrip("/") camera_url = f"{base}/api/v1/biomass/real/camera/" health_url = f"{base}/api/v1/biomass/health/result/" @@ -54,14 +48,10 @@ async def poll_once(client: httpx.AsyncClient, base: str) -> None: r2 = await client.get(health_url) b1 = _fmt_body(r1) b2 = _fmt_body(r2) - changed1 = b1 != _last_camera - changed2 = b2 != _last_health - if changed1 and _has_data(b1): + if _should_log(r1, b1): logger.info("[real/camera/] HTTP {} | {}", r1.status_code, json.dumps(b1, ensure_ascii=False)) - if changed2 and _has_data(b2): + if _should_log(r2, b2): logger.info("[health/result/] HTTP {} | {}", r2.status_code, json.dumps(b2, ensure_ascii=False)) - _last_camera = b1 - _last_health = b2 async def poll_loop(base: str, interval: float) -> None: diff --git a/scripts/run_fishserver.sh b/scripts/run_fishserver.sh index 670124e..58fd304 100755 --- a/scripts/run_fishserver.sh +++ b/scripts/run_fishserver.sh @@ -1,20 +1,11 @@ #!/usr/bin/env bash -# 在「单一 Python 环境」下启动网关:FishMeasure / FishAction 子进程默认使用同一解释器。 +# 仓库根目录入口:与 fish_api/start.sh 等价 # -# conda activate fishserver -# export PUBLIC_BASE_URL=http://<本机对外IP>:8001 # 可选,填 video URL 前缀 +# conda activate fishserver # 若不用 uv +# export PUBLIC_BASE_URL=http://<本机对外IP>:8001 # PORT=8001 bash scripts/run_fishserver.sh # set -euo pipefail ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" -cd "$ROOT/fish_api" - -export PUBLIC_BASE_URL="${PUBLIC_BASE_URL:-http://127.0.0.1:8000}" -# 单一环境:勿设置 PYTHON_FISH_MEASURE / PYTHON_FISH_ACTION,使用当前 uvicorn 的 Python -unset PYTHON_FISH_MEASURE PYTHON_FISH_ACTION 2>/dev/null || true - -PORT="${PORT:-8000}" -HOST="${HOST:-0.0.0.0}" - -exec uvicorn app.main:app --host "$HOST" --port "$PORT" +exec bash "$ROOT/fish_api/start.sh" diff --git a/scripts/start_fishapi_fresh.sh b/scripts/start_fishapi_fresh.sh new file mode 100755 index 0000000..551018c --- /dev/null +++ b/scripts/start_fishapi_fresh.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +# 仓库根目录入口:删除 SQLite 库文件(及旧 JSON 状态文件)后启动 Fish API +# +# bash scripts/start_fishapi_fresh.sh +# PORT=8001 bash scripts/start_fishapi_fresh.sh +# +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +exec bash "$ROOT/fish_api/start_fresh.sh"