# SAM segmentation module for fish detection
# Uses SAM (Segment Anything Model) for segmentation with boxes from YOLO detection
import cv2
import numpy as np
import torch
import os
import argparse
import urllib.request
from segment_anything import sam_model_registry, SamPredictor
from tqdm import tqdm

# Official SAM checkpoint download URLs, keyed by sam_model_registry model type.
SAM_MODEL_URLS = {
    "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
}


def download_models(model_type="vit_h", checkpoint=None):
    """Download the SAM checkpoint if it is not already present locally.

    Args:
        model_type: SAM backbone key in ``SAM_MODEL_URLS``
            ("vit_h", "vit_l" or "vit_b"). Defaults to "vit_h".
        checkpoint: Local path to store the checkpoint at. Defaults to the
            file name taken from the download URL, relative to the current
            working directory (e.g. "sam_vit_h_4b8939.pth" for "vit_h").

    Returns:
        The local checkpoint path.

    Raises:
        ValueError: If ``model_type`` is not a known SAM model type.
    """
    if model_type not in SAM_MODEL_URLS:
        raise ValueError(
            f"不支持的SAM模型类型: {model_type!r}(可选: {', '.join(SAM_MODEL_URLS)})"
        )
    url = SAM_MODEL_URLS[model_type]
    if checkpoint is None:
        checkpoint = os.path.basename(url)

    if os.path.exists(checkpoint):
        return checkpoint

    parent_dir = os.path.dirname(checkpoint)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    print("正在下载SAM模型...")
    # Download to a temporary file and rename only on success. Otherwise an
    # interrupted download would leave a truncated checkpoint that the
    # os.path.exists() check above would accept on every subsequent run.
    tmp_path = checkpoint + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, checkpoint)
    finally:
        # After a successful os.replace the temp file no longer exists. It is
        # only still present when the download or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("SAM模型下载完成")
    return checkpoint


def init_models(device="cpu", seg_model="sam", yolo_seg_weights=None,
                sam_model_type="vit_h", sam_checkpoint=None):
    """Initialize the segmentation model.

    Args:
        device: Device to run the model on (passed to ``sam.to``; only used
            for SAM).
        seg_model: Segmentation model type, "sam" or "yolov8_seg". Any value
            other than "yolov8_seg" falls back to SAM.
        yolo_seg_weights: Path to the YOLOv8 segmentation weights. Required
            when ``seg_model == "yolov8_seg"``.
        sam_model_type: SAM backbone to load ("vit_h", "vit_l" or "vit_b").
            Defaults to "vit_h".
        sam_checkpoint: Local SAM checkpoint path. Defaults to the
            checkpoint's standard file name in the current working directory.
            The file is downloaded if missing.

    Returns:
        A ``SamPredictor`` for SAM, or an ``ultralytics.YOLO`` model for
        "yolov8_seg".

    Raises:
        ValueError: If "yolov8_seg" is requested without ``yolo_seg_weights``,
            or ``sam_model_type`` is unknown.
        ImportError: If "yolov8_seg" is requested but ultralytics is missing.
    """
    if seg_model == "yolov8_seg":
        if yolo_seg_weights is None:
            raise ValueError("使用yolov8_seg时必须提供yolo_seg_weights参数")
        # Imported lazily so SAM-only users do not need ultralytics installed.
        try:
            from ultralytics import YOLO
        except ImportError as err:
            raise ImportError("未找到 ultralytics,请先安装: pip install ultralytics") from err
        print(f"正在加载YOLOv8分割模型: {yolo_seg_weights}")
        yolo_seg_model = YOLO(yolo_seg_weights)
        print("YOLOv8分割模型加载成功")
        return yolo_seg_model

    # Default: SAM. Make sure the checkpoint is available locally first.
    checkpoint = download_models(sam_model_type, sam_checkpoint)

    print("正在加载SAM模型...")
    sam = sam_model_registry[sam_model_type](checkpoint=checkpoint)
    sam.to(device=device)
    sam_predictor = SamPredictor(sam)
    print("SAM模型加载成功")
    return sam_predictor