【论文源码实战】轻量化MobileSAM,分割一切大模型出现,模型缩小60倍,速度提高40倍

news2025/2/26 6:01:48

前言

MobileSAM模型是在2023年发布的,其对之前的SAM分割一切大模型进行了轻量化的优化处理,模型整体体积缩小了60倍,运行速度提高40倍,但分割效果却依旧很好。

MobileSAM在使用方法上沿用了SAM模型的接口,因此可以与SAM模型的使用方法几乎可以无缝对接,真的是非常Nice。唯一的区别就是在模型加载的时候需要修改一点点。

一、环境配置

创建专属环境

conda create -n MobileSAM python=3.9

​​​​​

激活环境

conda activate MobileSAM

 

安装 Pytorch 环境

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple "torch-1.13.0+cu116-cp39-cp39-win_amd64.whl" 
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple "torchvision-0.14.0+cu116-cp39-cp39-win_amd64.whl"

二、代码测试

网页版使用

安装相关库

pip install -r requirements.txt

pip install gradio -i https://pypi.tuna.tsinghua.edu.cn/simple

pip install timm -i https://pypi.tuna.tsinghua.edu.cn/simple

代码运行

python app/app.py

点击链接进入下方的网页界面

在此就可以就行在网页上进行分割操作,下方是一些分割的图片:

Instructions for point mode(点模式说明)

  1. Restart by click the Restart button(单击“重新启动”按钮重新启动)

  2. Select a point with Add Mask for the foreground (Must)(选择具有“添加遮罩”的点作为前景(必须))

  3. Select a point with Remove Area for the background (Optional)(选择具有“删除区域”的点作为背景(可选))

  4. Click the Start Segmenting.(单击“开始分割”)

纯代码实现

Predictor 方法【提示点分割代码】

加载模型
def load_sam():
    # Selecting objects with SAM
    sam_checkpoint = "./weights/mobile_sam.pt"
    model_type = "vit_t"

    device = "cuda" if torch.cuda.is_available() else "cpu"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    sam.eval()
    return SamPredictor(sam)
绘制结果
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
单点得到掩码
# 使用 MobileSAM 从提示中得到掩码对象
# 选择需要分割的图像上的一个点。点以 (x,y) 格式输入到模型中,标签为 1(前景点)或 0(背景点)。
input_point = np.array([[400, 400]])
input_label = np.array([1])

# 使用 `SamPredictor.predict`进行预测
# 返回值:掩码、这些掩码的质量预测值和低分辨率掩码数值,这些数据可传递给下一次进行迭代预测。
# 当 `multimask_output=True`(默认设置)时,SAM 会输出 3 个掩码,其中 `scores` 给出了模型对这些掩码质量的估计值。
# 此设置用于模棱两可的输入提示,帮助模型区分与提示一致的不同对象。如果设置为 "false",则将返回单一掩码。
# 对于模棱两可的提示(如单点),建议使用 `multimask_output=True`,即使只需要单个掩码;可以通过选择在 `scores` 中返回的分数最高的掩码来选择最佳的单个掩码,这通常会产生更好的掩码。
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()
多点得到掩码
# Specifying a specific object with additional points(指定具有附加点的特定对象)
# 单个输入点模棱两可,模型返回了多个与之一致的对象。要获得单一对象,可以提供多个点。
# 如果有上一次迭代的掩码,也可以提供给模型以帮助预测。
# 在使用多个提示指定单个对象时,可以通过设置 `multimask_output=False` 来得到单个掩码。
input_point = np.array([[400, 400], [450, 350]])
input_label = np.array([1, 1])

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
多点得到掩码前景和背景
input_point = np.array([[400, 400], [100, 500]])
input_label = np.array([0, 1])

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
通过方框得到掩码
input_box = np.array([190, 70, 460, 280])

masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box[None, :],
    multimask_output=False,
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()

SamAutomaticMaskGenerator 方法【一键全景分割代码】

Automatic mask generator (得到全部图像掩码)
# Automatic mask generation(自动生成掩码)
mask_generator = load_sam()

# To generate masks
masks = mask_generator.generate(image)

# Mask generation returns a list over masks, where each mask is a dictionary containing various data about the mask.
# These keys are:
# segmentation: the mask
# area: the area of the mask in pixels
# bbox: the boundary box of the mask in XYWH format
# predicted_iou: the model's own prediction for the quality of the mask
# point_coords: the sampled input point that generated this mask
# stability_score: an additional measure of mask quality
# crop_box: the crop of the image used to generate this mask in XYWH format

# Show all the masks overlayed on the image.
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
Automatic mask generation options
# Automatic mask generation options
# There are several tunable parameters in automatic mask generation that control how densely points are sampled and what the thresholds are for removing low quality or duplicate masks. Additionally, generation can be automatically run on crops of the image to get improved performance on smaller objects, and post-processing can remove stray pixels and holes. Here is an example configuration that samples more masks:
# 在自动掩模生成中有几个可调参数,用于控制采样点的密度以及去除低质量或重复掩模的阈值。此外,生成可以在图像的裁剪上自动运行,以提高较小对象的性能,后处理可以去除杂散像素和孔洞。
# 以下是一个示例配置,用于对更多掩码进行采样:
mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,  # Requires open-cv to run post-processing
)
masks2 = mask_generator_2.generate(image)
plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks2)
plt.axis('off')
plt.show()

Onnx 推理【提示点分割代码】

模型转换(.pt 转 .onnx)
# 得到多个输出
python scripts/export_onnx_model.py --checkpoint ./weights/mobile_mul_sam.pt --model-type vit_t --output ./weights/mobile_mul_sam.onnx

# 得到单个输出
python scripts/export_onnx_model.py --checkpoint ./weights/mobile_single_sam.pt --model-type vit_t --return-single-mask --output ./weights/mobile_single_sam.onnx
量化onnx模型
onnx_model_path = 'mobile_single_sam.onnx'

onnx_model_quantized_path = "mobile_single_sam_quantized.onnx"
# 通过对模型进行量化和优化。我们发现,这显著改善了web运行时,而质量性能的变化可以忽略不计。
quantize_dynamic(
    model_input=onnx_model_path,
    model_output=onnx_model_quantized_path,
    optimize_model=True,
    per_channel=False,
    reduce_range=False,
    weight_type=QuantType.QUInt8,
)
onnx_model_path = onnx_model_quantized_path
使用onnx 模型
# Using an ONNX model
ort_session = onnxruntime.InferenceSession(onnx_model_path)

checkpoint = "../weights/mobile_sam.pt"
model_type = "vit_t"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device='cpu')
predictor = SamPredictor(sam)
predictor.set_image(image)

image_embedding = predictor.get_image_embedding().cpu().numpy()
Onnx 推理参数

The ONNX model has a different input signature than SamPredictor.predict. The following inputs must all be supplied. Note the special cases for both point and mask inputs. All inputs are np.float32.

  1. image_embeddings: The image embedding from predictor.get_image_embedding(). Has a batch index of length 1.
  2. point_coords: Coordinates of sparse input prompts, corresponding to both point inputs and box inputs. Boxes are encoded using two points, one for the top-left corner and one for the bottom-right corner. Coordinates must already be transformed to long-side 1024. Has a batch index of length 1.
  3. point_labels: Labels for the sparse input prompts. 0 is a negative input point, 1 is a positive input point, 2 is a top-left box corner, 3 is a bottom-right box corner, and -1 is a padding point. If there is no box input, a single padding point with label -1 and coordinates (0.0, 0.0) should be concatenated.
  4. mask_input: A mask input to the model with shape 1x1x256x256. This must be supplied even if there is no mask input. In this case, it can just be zeros.
  5. has_mask_input: An indicator for the mask input. 1 indicates a mask input, 0 indicates no mask input.
  6. orig_im_size: The size of the input image in (H,W) format, before any transformation.

Additionally, the ONNX model does not threshold the output mask logits. To obtain a binary mask, threshold at sam.mask_threshold (equal to 0.0).

单点得到掩码
input_point = np.array([[250, 375]])
input_label = np.array([1])

# Add a batch index, concatenate a padding point, and transform.
onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)

onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)

# Create an empty mask input and an indicator for no mask.
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)

# Package the inputs to run in the onnx model
ort_inputs = {
    "image_embeddings": image_embedding,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array(image.shape[:2], dtype=np.float32)
}

# Predict a mask and threshold it.
masks, _, low_res_logits = ort_session.run(None, ort_inputs)
masks = masks > predictor.model.mask_threshold

plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
多点得到掩码
input_point = np.array([[250, 375], [490, 380], [375, 360]])
input_label = np.array([1, 1, 0])

# Use the mask output from the previous run. It is already in the correct form for input to the ONNX model.
onnx_mask_input = low_res_logits

# Transform the points as in the previous example.
onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)

onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)

# The `has_mask_input` indicator is now 1.
onnx_has_mask_input = np.ones(1, dtype=np.float32)

# Package inputs, then predict and threshold the mask.
ort_inputs = {
    "image_embeddings": image_embedding,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array(image.shape[:2], dtype=np.float32)
}

masks, _, _ = ort_session.run(None, ort_inputs)
masks = masks > predictor.model.mask_threshold

plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
points和box得到掩码
input_box = np.array([210, 200, 350, 500])
input_point = np.array([[275, 400]])
input_label = np.array([0])

# Add a batch index, concatenate a box and point inputs, add the appropriate labels for the box corners, and transform. There is no padding point since the input includes a box input.
onnx_box_coords = input_box.reshape(2, 2)
onnx_box_labels = np.array([2,3])

onnx_coord = np.concatenate([input_point, onnx_box_coords], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, onnx_box_labels], axis=0)[None, :].astype(np.float32)

onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)

# Package inputs, then predict and threshold the mask.
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)

ort_inputs = {
    "image_embeddings": image_embedding,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array(image.shape[:2], dtype=np.float32)
}

masks, _, _ = ort_session.run(None, ort_inputs)
masks = masks > predictor.model.mask_threshold

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()

三、总结

从结果来看,MobileSAM相比于SAM,模型整体体积缩小了60倍,运行速度提高40倍,但分割效果却保持相当水平。个人认为,这对于视觉大模型在移动端的部署与应用是具有里程碑意义的。

关于MobileSAM模型的相关代码、论文PDF、预训练模型、使用方法等,我都已打包好,供需要的小伙伴交流研究,获取方式如下:

关注公众号,回复:MobileSAM,即可获取MobileSAM相关代码、论文、预训练模型、使用方法示例

四、链接作者

欢迎关注我的公众号:@AI算法与电子竞赛

硬性的标准其实限制不了无限可能的我们,所以啊!少年们加油吧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1608875.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

matlab学习003-绘制由差分方程表示的离散系统图像

目录 1,题目 2,使用函数求解差分方程 1)基础知识 ①filter函数和impz函数 ②zeros函数 ☀ 2)绘制图像 ​☀ 3)对应代码 如果连简单的信号都不会的,建议先看如下文章👇,之…

UE4 相机围绕某点旋转

关卡(一个相机CameraActor,一个Cube(名叫Target)): 关卡蓝图里的逻辑(为了大家看得清楚,特意连得很紧凑,也比较乱,不然一张截图放不下): 只对Yaw 只Pitch: 同样对Roll: 围绕任…

Java maven项目打包自动测试并集成jacoco生成代码测试覆盖度报告

引入Junit 引入 junit5 单元测试依赖 <properties><junit.version>5.10.2</junit.version><jacoco.version>0.8.12</jacoco.version></properties><dependencies><!-- 单元测试 --><dependency><groupId>org.jun…

墨子web3时事周报

蚂蚁集团Web3研发进展与布局 国内Web3赛道的领军企业——蚂蚁集团&#xff0c;凭借其在前沿科技领域的深耕不辍&#xff0c;已在Web3技术研发疆域缔造了卓越战绩。特别是在引领行业革新的关键时刻&#xff0c;集团于今年四月末震撼推出了颠覆性的Web3全套解决方案&#xff0c…

easyui datagrid单元格点击进入编辑时,行会自动向上错位

现象描述&#xff0c;点击第20行可编辑的单元格进入编辑状态时&#xff0c;滚动条自动滚动到第19行了。导致第20行被分页遮挡&#xff0c;看不到无法编辑。 排查了一天百度AI说是滚动定位问题&#xff0c;最后发现是自己设置的列有问题&#xff0c;表格总共五列&#xff0c;全…

如何搭建线下陪玩系统(本地伴游、多玩圈子)APP小程序H5多端前后端源码交付,支持二开!

一、卡顿的优化方法 1、对陪玩系统源码中流媒体传输的上行进行优化&#xff0c;通过提升推流端的设备性能配置、推流边缘CDN节点就近选择等方式解决音视频数据源流的卡顿。 2、对陪玩系统源码中音视频数据的下载链路进行优化&#xff0c;通过选择更近更优质的CDN边缘节点来减少…

软航H5 PDF签章产品经nginx代理之后浏览器中PDF盖章时提示:签章失败:网络错误 的问题排查及解决办法

目录 问题现象 问题排查思路 问题处理办法 附&#xff1a;软航H5 PDF签章产品介绍 软航电子签章系统 软航版式文档签批系统 问题现象 问题描述&#xff1a;在系统中集成了软航H5 PDF签章产品&#xff0c;软航H5 PDF签章产品的对应服务是通过nginx代理的&#xff0c;在奇安…

FPGA_verilog语法整理

FPGA_verilog语法整理 verilog的逻辑值 verilog的常数表达 位宽中指定常数的宽度&#xff08;表示成二进制数的位数&#xff09;&#xff0c;单引号加表示该常数为几进制的底数符号。 二进制底数符号为b&#xff0c;八进制为 o&#xff0c;十进制为d&#xff0c;十六进制为 h…

Windows 下使用 CMake + Visual Studio 2022 编译 OpenCV 4.8.1 及其扩展模块

一. 背景 目前维护的某个项目是在 Windows 下运行的&#xff0c;并且使用了 OpenCV 4.5.2 版本。 我本地的开发环境是 Mac 并使用了比较新的 OpenCV 4.8.1 版本。为了和本地开发环境保持一致&#xff0c;我打算对项目中使用的 OpenCV 进行升级&#xff0c;因为该项目还是用了扩…

BI数据分析有什么优势?引入BI工具为何能加快企业数字化进程?

进化是人类社会永恒不变的主题。从早期的猿人到现在的人类&#xff0c;从久远的石器时代到现在的信息时代&#xff0c;人类社会历经一次次的进化才积攒了今日的科技与智慧。人类的文明史&#xff0c;实质是科学和信息的进化史。 如今&#xff0c;数字化浪潮席卷全球&#xff0…

React Hooks(常用)笔记

一、useState&#xff08;保存组件状态&#xff09; 1、基本使用 import { useState } from react;function Example() {const [initialState, setInitialState] useState(default); } useState(保存组件状态) &#xff1a;React hooks是function组件(无状态组件) &#xf…

ruby 配置代理 ip(核心逻辑)

在 Ruby 中配置代理 IP&#xff0c;可以通过设置 Net::HTTP 类的 Proxy 属性来实现。以下是一个示例&#xff1a; require net/http// 获取代理Ip&#xff1a;https://www.kuaidaili.com/?refrg3jlsko0ymg proxy_address 代理IP:端口 uri URI(http://www.example.com)Net:…

【EI会议征稿】2024年先进机械电子、电气工程与自动化国际学术会议(ICAMEEA 2024)

2024 International Conference on Advanced Mechatronic, Electrical Engineering and Automation ●会议简介 2024年先进机械电子、电气工程与自动化国际学术会议&#xff08;ICAMEEA 2024&#xff09;将汇聚全球机械电子、电气工程与自动化领域的专家学者&#xff0c;共同…

网际互联及OSI七层模型

1什么是OSI七层模型 2OSI每一个Layer的定义 及用途 3如何使用OSI参考模型分析网络通信过程 一、网际互联 &#xff08;一&#xff09;OSI的概念&#xff1a; open system interconnect开放系统互联参考模型&#xff0c;是有ISO&#xff08;国际标准化组织&#xff09;定义…

科技的崛起:国内机器视觉蓬勃发展

文 | BFT机器人 在工业4.0的浪潮下&#xff0c;随着科技的蓬勃发展&#xff0c;机器视觉逐渐走入大众视野&#xff0c;机器视觉产品的普及范围也越来越广。 大家知道机器视觉的由来吗&#xff1f; 机器视觉的由来可以追溯到20世纪70年代&#xff0c;美国麻省理工学院&#xff…

【Leetcode每日一题】 动态规划 - 地下城游戏(难度⭐⭐⭐)(61)

1. 题目解析 题目链接&#xff1a;174. 地下城游戏 这个问题的理解其实相当简单&#xff0c;只需看一下示例&#xff0c;基本就能明白其含义了。 2.算法原理 一、状态表定义 在解决地下城游戏问题时&#xff0c;我们首先需要对状态进行恰当的定义。一个直观的想法是&#x…

Python 将PowerPoint (PPT/PPTX) 转为HTML格式

PPT是传递信息、进行汇报和推广产品的重要工具。然而&#xff0c;有时我们需要将这些精心设计的PPT演示文稿发布到网络上&#xff0c;以便于更广泛的访问和分享。本文将介绍如何使用Python将PowerPoint文档转换为网页友好的HTML格式。包含两个示例&#xff1a; 目录 Python 将…

如何用idm下载迅雷文件 idm怎么安装到浏览器 idm怎么设置中文

如果不是vip用户使用迅雷下载数据文件&#xff0c;其下载速度是很慢的&#xff0c;有的时候还会被限速&#xff0c;所以很多小伙们就开始使用idm下载迅雷文件&#xff0c;idm这款软件最大的优势就是下载速度快&#xff0c;还有就是具备网页捕获功能&#xff0c;能够下载网页上的…

C++ - STL详解—vector类

一. vector的概念 向量&#xff08;Vector&#xff09;是一个封装了动态大小数组的顺序容器&#xff08;Sequence Container&#xff09;。跟任意其它类型容器一样&#xff0c;它能够存放各种类型的对象。可以简单的认为&#xff0c;向量是一个能够存放任意类型的动态数组。 …

探索MATLAB在计算机视觉与深度学习领域的实战应用

随着人工智能技术的快速发展&#xff0c;计算机视觉与深度学习已成为科技领域中最热门、最具挑战性的研究方向之一。 它们的应用范围从简单的图像处理扩展到了自动驾驶、医疗影像分析、智能监控行业等多个领域。 在这样的背景下&#xff0c;《MATLAB计算机视觉与深度学习实战…