Files
FishServer/FishMeasure/seg.py

70 lines
2.3 KiB
Python
Raw Normal View History

# Segmentation module for fish detection.
# Uses SAM (Segment Anything Model) prompted with boxes from YOLO detection by
# default; a YOLOv8 segmentation model can be selected instead (see init_models).
import cv2
import numpy as np
import torch
import os
import argparse
import urllib.request
from segment_anything import sam_model_registry, SamPredictor
from tqdm import tqdm
def download_models(model_type="vit_h", checkpoint=None):
    """Download the SAM checkpoint if it is not already on disk.

    The file is first written to ``<checkpoint>.part`` and renamed to its final
    name only after the transfer completes. Without this, an interrupted download
    would leave a truncated file at the final path, and the existence check would
    make every later call skip the download and hand a corrupt checkpoint to SAM.

    Args:
        model_type: SAM backbone to fetch: "vit_h" (default), "vit_l" or "vit_b".
        checkpoint: Destination path. Defaults to the official file name of the
            chosen checkpoint (e.g. "sam_vit_h_4b8939.pth") in the current
            working directory.

    Returns:
        The path of the checkpoint file (``checkpoint`` as resolved above).

    Raises:
        ValueError: If ``model_type`` is not one of the known backbones.
        Any exception raised by ``urllib.request.urlretrieve`` (e.g. a network
        error) propagates after the partial file has been removed.
    """
    sam_model_urls = {
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
    }
    try:
        url = sam_model_urls[model_type]
    except KeyError:
        raise ValueError(
            f"未知的SAM模型类型: {model_type!r}，可选: {sorted(sam_model_urls)}"
        ) from None

    if checkpoint is None:
        # Use the official file name, i.e. the last path component of the URL.
        checkpoint = url.rsplit("/", 1)[-1]

    if os.path.exists(checkpoint):
        return checkpoint

    print("正在下载SAM模型...")
    partial_path = f"{checkpoint}.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        # Rename within the same directory, so the final name only ever refers
        # to a fully downloaded file.
        os.replace(partial_path, checkpoint)
    except BaseException:
        # Also covers KeyboardInterrupt: remove the partial file so the next
        # call retries instead of trusting a truncated checkpoint.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    print("SAM模型下载完成")
    return checkpoint
def init_models(device="cpu", seg_model="sam", yolo_seg_weights=None):
    """Load and return the segmentation model selected by ``seg_model``.

    Args:
        device: Device the SAM model is moved to (passed to ``sam.to``). It is
            not used when the YOLOv8 segmentation model is selected.
        seg_model: "yolov8_seg" selects a YOLOv8 segmentation model; any other
            value (including the default "sam") falls back to SAM.
        yolo_seg_weights: Path to the YOLOv8 segmentation weights. Required when
            ``seg_model == "yolov8_seg"``.

    Returns:
        For "yolov8_seg", an ``ultralytics.YOLO`` model; otherwise a
        ``SamPredictor`` wrapping the "vit_h" SAM model.

    Raises:
        ValueError: ``seg_model == "yolov8_seg"`` but no weights were given.
        ImportError: ``seg_model == "yolov8_seg"`` but ultralytics is not installed.
    """
    if seg_model != "yolov8_seg":
        # Default path: make sure the SAM checkpoint is on disk, then load it.
        download_models()
        print("正在加载SAM模型...")
        sam_model = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
        sam_model.to(device=device)
        predictor = SamPredictor(sam_model)
        print("SAM模型加载成功")
        return predictor

    if yolo_seg_weights is None:
        raise ValueError("使用yolov8_seg时必须提供yolo_seg_weights参数")
    # ultralytics is imported only on this path, so it is needed only when the
    # YOLOv8 segmentation model is selected.
    try:
        from ultralytics import YOLO
    except ImportError:
        raise ImportError("未找到 ultralytics请先安装: pip install ultralytics")
    print(f"正在加载YOLOv8分割模型: {yolo_seg_weights}")
    model = YOLO(yolo_seg_weights)
    print("YOLOv8分割模型加载成功")
    return model