Initial commit: FishServer monorepo (FishAction, FishMeasure, fish_api)
Made-with: Cursor
This commit is contained in:
69
FishMeasure/seg.py
Executable file
69
FishMeasure/seg.py
Executable file
@@ -0,0 +1,69 @@
|
||||
# SAM segmentation module for fish detection
|
||||
# Uses SAM (Segment Anything Model) for segmentation with boxes from YOLO detection
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
import argparse
|
||||
import urllib.request
|
||||
from segment_anything import sam_model_registry, SamPredictor
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def download_models():
|
||||
"""
|
||||
下载所需的模型文件
|
||||
"""
|
||||
# 下载 SAM 模型
|
||||
sam_model_urls = {
|
||||
"vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
|
||||
"vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
|
||||
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
|
||||
}
|
||||
|
||||
# 下载 SAM 模型
|
||||
sam_checkpoint = "sam_vit_h_4b8939.pth"
|
||||
if not os.path.exists(sam_checkpoint):
|
||||
print(f"正在下载SAM模型...")
|
||||
urllib.request.urlretrieve(sam_model_urls["vit_h"], sam_checkpoint)
|
||||
print(f"SAM模型下载完成")
|
||||
|
||||
def init_models(device="cpu", seg_model="sam", yolo_seg_weights=None):
|
||||
"""
|
||||
初始化模型
|
||||
|
||||
Args:
|
||||
device: 运行设备
|
||||
seg_model: 分割模型类型 ("sam" 或 "yolov8_seg")
|
||||
yolo_seg_weights: YOLOv8分割模型权重路径(当seg_model="yolov8_seg"时必需)
|
||||
|
||||
Returns:
|
||||
sam_predictor 或 yolo_seg_model: 分割模型
|
||||
"""
|
||||
if seg_model == "yolov8_seg":
|
||||
# 初始化YOLOv8分割模型
|
||||
if yolo_seg_weights is None:
|
||||
raise ValueError("使用yolov8_seg时必须提供yolo_seg_weights参数")
|
||||
|
||||
try:
|
||||
from ultralytics import YOLO
|
||||
except ImportError:
|
||||
raise ImportError("未找到 ultralytics,请先安装: pip install ultralytics")
|
||||
|
||||
print(f"正在加载YOLOv8分割模型: {yolo_seg_weights}")
|
||||
yolo_seg_model = YOLO(yolo_seg_weights)
|
||||
print("YOLOv8分割模型加载成功")
|
||||
return yolo_seg_model
|
||||
else:
|
||||
# 默认使用SAM
|
||||
# 下载模型
|
||||
download_models()
|
||||
|
||||
# 初始化SAM
|
||||
print("正在加载SAM模型...")
|
||||
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
|
||||
sam.to(device=device)
|
||||
sam_predictor = SamPredictor(sam)
|
||||
print("SAM模型加载成功")
|
||||
return sam_predictor
|
||||
Reference in New Issue
Block a user