Files

70 lines
2.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Segmentation module for fish detection.
# Uses SAM (Segment Anything Model) prompted with boxes from YOLO detection,
# or optionally a YOLOv8 segmentation model (see init_models).
import cv2
import numpy as np
import torch
import os
import argparse
import urllib.request
from segment_anything import sam_model_registry, SamPredictor
from tqdm import tqdm
def download_models(model_type="vit_h", checkpoint=None):
    """
    Download a SAM checkpoint if it is not already present on disk.

    Args:
        model_type: SAM backbone variant, one of "vit_h", "vit_l" or "vit_b".
            Defaults to "vit_h", the variant init_models() loads.
        checkpoint: Local path to save the checkpoint to. Defaults to the file
            name of the official download URL, relative to the current
            working directory (e.g. "sam_vit_h_4b8939.pth" for "vit_h").

    Returns:
        The path of the checkpoint file (existing or freshly downloaded).

    Raises:
        ValueError: if model_type is not one of the known SAM variants.
    """
    sam_model_urls = {
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
    }
    if model_type not in sam_model_urls:
        raise ValueError(
            f"未知的SAM模型类型: {model_type!r} (可选: {', '.join(sam_model_urls)})"
        )
    url = sam_model_urls[model_type]
    if checkpoint is None:
        # The official checkpoint file name is the last path component of the URL.
        checkpoint = url.rsplit("/", 1)[-1]

    if not os.path.exists(checkpoint):
        print("正在下载SAM模型...")
        # Download to a sibling temp file and rename only on success. Otherwise
        # an interrupted download would leave a truncated checkpoint that the
        # exists() check above would accept as complete on the next run.
        tmp_path = checkpoint + ".part"
        try:
            urllib.request.urlretrieve(url, tmp_path)
            os.replace(tmp_path, checkpoint)
        except BaseException:
            # Also covers KeyboardInterrupt; the error is always re-raised.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        print("SAM模型下载完成")
    return checkpoint
def init_models(device="cpu", seg_model="sam", yolo_seg_weights=None):
    """
    Initialise the segmentation model.

    Args:
        device: Device the SAM model is moved to via ``sam.to(device=...)``.
            Not applied on the YOLOv8 path.
        seg_model: Segmentation backend, "sam" or "yolov8_seg". Any value other
            than "yolov8_seg" falls back to SAM.
        yolo_seg_weights: Path to YOLOv8 segmentation weights; required when
            seg_model == "yolov8_seg".

    Returns:
        A ``SamPredictor`` wrapping the ViT-H SAM model, or an ultralytics
        ``YOLO`` model when seg_model == "yolov8_seg".

    Raises:
        ValueError: if seg_model == "yolov8_seg" and yolo_seg_weights is None.
        ImportError: if seg_model == "yolov8_seg" and ultralytics cannot be
            imported.
    """
    if seg_model == "yolov8_seg":
        if yolo_seg_weights is None:
            raise ValueError("使用yolov8_seg时必须提供yolo_seg_weights参数")
        try:
            # Imported lazily so ultralytics is only needed for this backend.
            from ultralytics import YOLO
        except ImportError as err:
            # Chain the original error so the real cause (e.g. a missing
            # dependency inside ultralytics) stays visible in the traceback.
            raise ImportError("未找到 ultralytics请先安装: pip install ultralytics") from err
        print(f"正在加载YOLOv8分割模型: {yolo_seg_weights}")
        yolo_seg_model = YOLO(yolo_seg_weights)
        # NOTE(review): `device` is not passed to the YOLO model here — confirm
        # whether callers expect it to be honoured on this path.
        print("YOLOv8分割模型加载成功")
        return yolo_seg_model

    # Default backend: SAM with the ViT-H backbone.
    download_models()
    print("正在加载SAM模型...")
    # Must match the file name download_models() saves the ViT-H checkpoint as.
    sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
    sam.to(device=device)
    sam_predictor = SamPredictor(sam)
    print("SAM模型加载成功")
    return sam_predictor