# SAM segmentation module for fish detection.
# Uses SAM (Segment Anything Model) for segmentation with boxes from YOLO detection.
|
import cv2
|
|||
|
|
import numpy as np
|
|||
|
|
import torch
|
|||
|
|
import os
|
|||
|
|
import argparse
|
|||
|
|
import urllib.request
|
|||
|
|
from segment_anything import sam_model_registry, SamPredictor
|
|||
|
|
from tqdm import tqdm
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Official SAM checkpoint URLs, keyed by SAM model type.
SAM_MODEL_URLS = {
    "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
}


def download_models(model_type="vit_h", checkpoint=None):
    """Download a SAM checkpoint if it is not already present on disk.

    The file is first written to ``<checkpoint>.part`` and only renamed to
    its final name once the download has completed. An interrupted download
    therefore never leaves a truncated file behind that a later call would
    mistake for a valid checkpoint (it only checks that the path exists).

    Args:
        model_type: SAM variant to fetch; one of the keys of
            ``SAM_MODEL_URLS`` ("vit_h", "vit_l", "vit_b"). Defaults to
            "vit_h".
        checkpoint: Destination path. Defaults to the file name of the
            download URL, relative to the current working directory
            (``"sam_vit_h_4b8939.pth"`` for "vit_h").

    Returns:
        The path of the checkpoint file.

    Raises:
        ValueError: If ``model_type`` is not a known SAM variant.
        OSError: If the download (``urllib.error.URLError`` is an
            ``OSError`` subclass) or the final rename fails.
    """
    # Validate first so a bad model_type fails even if the file exists.
    try:
        url = SAM_MODEL_URLS[model_type]
    except KeyError:
        raise ValueError(
            f"未知的SAM模型类型: {model_type!r}, 可选: {sorted(SAM_MODEL_URLS)}"
        ) from None

    if checkpoint is None:
        checkpoint = os.path.basename(url)

    if os.path.exists(checkpoint):
        return checkpoint

    print("正在下载SAM模型...")
    tmp_path = checkpoint + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # Atomic on the same filesystem: the final path only ever holds a
        # complete file.
        os.replace(tmp_path, checkpoint)
    finally:
        # After a successful replace this is a no-op; after a failure it
        # removes the partial download.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("SAM模型下载完成")
    return checkpoint
|
|||
|
|
|
|||
|
|
def init_models(device="cpu", seg_model="sam", yolo_seg_weights=None):
    """Initialise and return the segmentation model.

    Args:
        device: Device the SAM model is moved to via ``sam.to(device=...)``
            (e.g. "cpu"). Not used on the YOLOv8 path.
        seg_model: Segmentation backend, "sam" or "yolov8_seg". Any other
            value falls back to SAM, and a warning is emitted so typos do
            not go unnoticed.
        yolo_seg_weights: Path to YOLOv8 segmentation weights; required when
            ``seg_model == "yolov8_seg"``.

    Returns:
        An ``ultralytics.YOLO`` model when ``seg_model == "yolov8_seg"``,
        otherwise a ``SamPredictor`` wrapping the ViT-H SAM model.

    Raises:
        ValueError: If ``seg_model == "yolov8_seg"`` and
            ``yolo_seg_weights`` is None.
        ImportError: If ``seg_model == "yolov8_seg"`` and ``ultralytics`` is
            not installed (the original import error is chained as cause).
    """
    if seg_model == "yolov8_seg":
        if yolo_seg_weights is None:
            raise ValueError("使用yolov8_seg时必须提供yolo_seg_weights参数")

        # Imported lazily so ultralytics is only required for this backend.
        try:
            from ultralytics import YOLO
        except ImportError as err:
            raise ImportError("未找到 ultralytics,请先安装: pip install ultralytics") from err

        print(f"正在加载YOLOv8分割模型: {yolo_seg_weights}")
        yolo_seg_model = YOLO(yolo_seg_weights)
        print("YOLOv8分割模型加载成功")
        return yolo_seg_model

    # Anything other than "yolov8_seg" uses SAM (deliberate default); make
    # the fallback visible for values other than the documented "sam".
    if seg_model != "sam":
        import warnings

        warnings.warn(
            f"未知的分割模型类型 {seg_model!r},将使用默认的SAM模型",
            stacklevel=2,
        )

    # Ensure the ViT-H checkpoint is present in the working directory.
    download_models()

    print("正在加载SAM模型...")
    sam_checkpoint = "sam_vit_h_4b8939.pth"
    sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
    sam.to(device=device)
    sam_predictor = SamPredictor(sam)
    print("SAM模型加载成功")
    return sam_predictor
|