70 lines
2.3 KiB
Python
Executable File
70 lines
2.3 KiB
Python
Executable File
# SAM segmentation module for fish detection
|
||
# Uses SAM (Segment Anything Model) for segmentation with boxes from YOLO detection, or optionally a YOLOv8 segmentation model
|
||
|
||
import cv2
|
||
import numpy as np
|
||
import torch
|
||
import os
|
||
import argparse
|
||
import urllib.request
|
||
from segment_anything import sam_model_registry, SamPredictor
|
||
from tqdm import tqdm
|
||
|
||
|
||
def download_models(model_type="vit_h", checkpoint=None):
    """Download a SAM checkpoint unless it already exists locally.

    The file is first written to ``<checkpoint>.part`` and only renamed to
    its final name once the download has finished. An interrupted download
    therefore never leaves a truncated file that a later run would mistake
    for a complete checkpoint.

    Args:
        model_type: SAM backbone to fetch: "vit_h" (default), "vit_l" or "vit_b".
        checkpoint: Destination path for the weights. Defaults to the official
            file name for ``model_type`` (e.g. "sam_vit_h_4b8939.pth"),
            relative to the current working directory.

    Returns:
        The checkpoint path, whether it was just downloaded or already present.

    Raises:
        ValueError: If ``model_type`` is not one of the known SAM backbones.
    """
    # Official SAM checkpoint URLs, keyed by backbone type.
    sam_model_urls = {
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
    }

    try:
        url = sam_model_urls[model_type]
    except KeyError:
        raise ValueError(
            f"未知的SAM模型类型: {model_type!r}, 可选: {sorted(sam_model_urls)}"
        ) from None

    if checkpoint is None:
        # Keep the upstream file name (last URL path component).
        checkpoint = url.rsplit("/", 1)[-1]

    if os.path.exists(checkpoint):
        return checkpoint

    print("正在下载SAM模型...")
    partial = checkpoint + ".part"
    try:
        urllib.request.urlretrieve(url, partial)
        # Atomic rename: the final path only ever holds a complete file.
        os.replace(partial, checkpoint)
    finally:
        # On failure/interruption, do not leave a half-written file behind.
        if os.path.exists(partial):
            os.remove(partial)
    print("SAM模型下载完成")
    return checkpoint
|
||
|
||
def init_models(device="cpu", seg_model="sam", yolo_seg_weights=None):
    """Initialize the segmentation model.

    Args:
        device: Device the SAM model is moved to via ``sam.to(device=...)``.
            It is not applied to the YOLOv8 backend in this function.
        seg_model: Segmentation backend, "sam" or "yolov8_seg". Any other
            value falls back to SAM (a warning is emitted so typos are visible).
        yolo_seg_weights: Path to YOLOv8 segmentation weights; required when
            ``seg_model == "yolov8_seg"``.

    Returns:
        A ``SamPredictor`` for the SAM backend, or an ultralytics ``YOLO``
        model for the "yolov8_seg" backend.

    Raises:
        ValueError: If ``seg_model == "yolov8_seg"`` and no weights are given.
        ImportError: If "yolov8_seg" is requested but ultralytics is missing.
    """
    if seg_model == "yolov8_seg":
        if yolo_seg_weights is None:
            raise ValueError("使用yolov8_seg时必须提供yolo_seg_weights参数")

        # Imported lazily so ultralytics is only needed for this backend.
        try:
            from ultralytics import YOLO
        except ImportError as err:
            raise ImportError("未找到 ultralytics,请先安装: pip install ultralytics") from err

        print(f"正在加载YOLOv8分割模型: {yolo_seg_weights}")
        yolo_seg_model = YOLO(yolo_seg_weights)
        print("YOLOv8分割模型加载成功")
        return yolo_seg_model

    if seg_model != "sam":
        # Historical behaviour: unknown values fall back to SAM. Warn instead
        # of failing so existing callers keep working, but typos are visible.
        import warnings
        warnings.warn(f"未知的分割模型类型 {seg_model!r}, 将使用SAM", stacklevel=2)

    # SAM backend (default): make sure the vit_h checkpoint is on disk first.
    download_models()

    # Must match the file name download_models() saves the vit_h weights as.
    sam_checkpoint = "sam_vit_h_4b8939.pth"

    print("正在加载SAM模型...")
    sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
    sam.to(device=device)
    sam_predictor = SamPredictor(sam)
    print("SAM模型加载成功")
    return sam_predictor
|