如何微调Segment Anything Model

news2024/10/2 18:21:39

文章目录

  • 什么是SAM?
  • 什么是模型微调?
  • 为什么要微调模型?
  • 如何微调 Segment Anything 模型
    • 背景与架构
    • 创建自定义数据集
    • 输入数据预处理
    • 训练设置
    • 循环训练
    • 保存检查点并从中启动模型
  • 下游应用程序的微调

随着 Meta 上周发布的 Segment Anything Model (SAM),计算机视觉迎来了它的 ChatGPT 时刻。SAM 训练了超过 110 亿个分割掩码,是预测 AI 用例而非生成 AI 的基础模型。虽然它在对广泛的图像模式和问题空间进行分割的能力方面表现出了令人难以置信的灵活性,但它在发布时没有“微调”功能。
本教程将概述使用掩码解码器微调 SAM 的一些关键步骤,特别是描述 SAM 中使用哪些函数对数据进行预处理/后处理,使其处于良好状态以进行微调。

什么是SAM?

Segment Anything Model (SAM) 是由 Meta AI 开发的一种分割模型。它被认为是计算机视觉的第一个基础模型。SAM 在包含数百万图像和数十亿掩码的庞大数据集上进行了训练,使其非常强大。顾名思义,SAM 能够为各种图像生成准确的分割掩码。Sam 的设计允许它考虑人类提示,这使得它对于 Human In The Loop 注释特别强大。这些提示可以是多模式的:它们可以是要分割区域上的点、要分割对象周围的边界框或关于应该分割什么的文本提示。

该模型分为 3 个组件:图像编码器、提示编码器和掩码解码器。

在这里插入图片描述

图像编码器为被分割的图像生成嵌入,而提示编码器为提示生成嵌入。图像编码器是模型中特别大的组件。这与轻量级掩码解码器形成对比,轻量级掩码解码器根据嵌入预测分割掩码。Meta AI 已将在 Segment Anything 10 亿掩码 (SA-1B) 数据集上训练的模型的权重和偏差用作模型检查点。在此处的解释器博客文章中了解有关 Segment Anything 工作原理的更多信息。

什么是模型微调?

公开可用的最先进模型具有自定义架构,通常提供预训练模型权重。如果这些架构在没有权重的情况下提供,那么用户将需要从头开始训练模型,用户将需要使用大量数据集来获得最先进的性能。

模型微调是采用预训练模型(架构+权重)并显示特定用例数据的过程。这通常是模型以前没有见过的数据,或者在其原始训练数据集中代表性不足的数据。

微调模型和从头开始的区别在于权重和偏差的起始值。如果我们从头开始训练,这些将根据某种策略随机初始化。在这样的初始配置中,模型对手头的任务“一无所知”并且表现不佳。通过使用预先存在的权重和偏差作为起点,我们可以“微调”权重和偏差,以便我们的模型在我们的自定义数据集上更好地工作。例如:学习识别猫的信息(边缘检测、计数爪子)将对识别狗有用。

为什么要微调模型?

微调模型的目的是在预训练模型以前没有见过的数据上获得更高的性能。例如,在从手机摄像头收集的广泛数据集上训练的图像分割模型将主要从水平视角看到图像。

如果我们尝试将此模型用于从垂直角度拍摄的卫星图像,它的性能可能不会那么好。如果我们试图分割屋顶,该模型可能不会产生最佳结果。预训练很有用,因为模型已经学会了一般如何分割对象,所以我们想利用这个起点来构建一个可以准确分割屋顶的模型。此外,我们的自定义数据集可能不会有数百万个示例,因此我们希望微调而不是从头开始训练模型。

微调是可取的,这样我们就可以在特定用例上获得更好的性能,而不必承担从头开始训练模型的计算成本。

如何微调 Segment Anything 模型

背景与架构

我们在介绍部分概述了 SAM 体系结构。图像编码器具有包含许多参数的复杂架构。为了微调模型,我们将重点放在轻量级的掩码解码器上是有意义的,因此微调更容易、更快、内存效率更高。

为了微调 SAM,我们需要提取其架构的底层部分(图像和提示编码器、掩码解码器)。我们不能使用SamPredictor.predict (链接) 有两个原因:

  • 我们只想微调掩码解码器
  • 这个函数调用SamPredictor.predict_torch,它有 @torch.no_grad()装饰器(链接),它阻止我们计算梯度。

因此,我们需要检查SamPredictor.predict函数并调用适当的函数,并在我们想要微调的部分(掩码解码器)上启用梯度计算。这样做也是了解更多有关 SAM 工作原理的好方法。

创建自定义数据集

我们需要三件事来微调我们的模型:

  • 在其上绘制分割的图像
  • 分割真值掩码
  • 提示输入模型,我正在使用边界框

我选择了印章验证数据集(链接),因为它包含 SAM 在其训练中可能没有看到的数据(即文档上的印章)。我可以通过使用预先训练的权重运行推理来验证它在这个数据集上表现良好,但并不完美。ground truth masks 也非常精确,这将使我们能够计算准确的损失。最后,该数据集包含分割掩码周围的边界框,我们可以将其用作 SAM 的提示。示例图像如下所示。这些边界框与人工注释者在寻找生成分段时所经历的工作流程非常吻合。

在这里插入图片描述

输入数据预处理

我们需要预处理从 numpy 数组到 pytorch 张量的扫描。为此,我们可以关注SamPredictor.set_image (链接) 和SamPredictor.set_torch_image (链接) 内部发生的事情,它们对图像进行预处理。首先,我们可以使用utils.transform.ResizeLongestSide来调整图像大小,因为这是预测器内部使用的转换器 ( link )。然后我们可以将图像转换为pytorch张量,并使用 SAM 预处理方法(链接)完成预处理。

训练设置

我们下载vit_b模型的模型检查点并将它们加载到:

sam_model = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')
我们可以使用默认设置 Adam 优化器,并指定要调整的参数是掩码解码器的参数:

optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters())
同时,我们可以设置我们的损失函数,例如Mean Squared Error

loss_fn = torch.nn.MSELoss()

循环训练

在主训练循环中,我们将迭代我们的数据项,生成掩码并将它们与我们的真实掩码进行比较,以便我们可以根据损失函数优化模型参数。

在此示例中,我们使用 GPU 进行训练,因为它比使用 CPU 快得多。在适当的张量上使用.to(device)非常重要,以确保我们不会在 CPU 上使用某些张量而在 GPU 上使用其他张量。

我们希望通过将编码器包装在torch.no_grad()上下文管理器中来嵌入图像,否则我们将遇到内存问题,以及我们不希望微调图像编码器的事实。

with torch.no_grad():
	image_embedding = sam_model.image_encoder(input_image)

我们还可以在 no_grad 上下文管理器中生成提示嵌入。我们使用我们的边界框坐标,转换为 pytorch 张量。

with torch.no_grad():
      sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
          points=None,
          boxes=box_torch,
          masks=None,
      )

最后,我们可以生成掩码。请注意,这里我们处于单一掩码生成模式(与通常输出的 3 个掩码形成对比)。

low_res_masks, iou_predictions = sam_model.mask_decoder(
  image_embeddings=image_embedding,
  image_pe=sam_model.prompt_encoder.get_dense_pe(),
  sparse_prompt_embeddings=sparse_embeddings,
  dense_prompt_embeddings=dense_embeddings,
  multimask_output=False,
)

这里的最后一步是将蒙版放大回原始图像大小,因为它们的分辨率很低。我们可以使用Sam.postprocess_masks来实现这一点。我们还希望根据预测的掩码生成二进制掩码,以便我们可以将它们与我们的基本事实进行比较。为了不破坏反向传播,使用torch功能很重要。

upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(device)

from torch.nn.functional import threshold, normalize

binary_mask = normalize(threshold(upscaled_masks, 0.0, 0)).to(device)

最后我们可以计算损失并运行优化步骤:

loss = loss_fn(binary_mask, gt_binary_mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()

通过在多个时期和批次上重复此操作,我们可以微调 SAM 解码器。

保存检查点并从中启动模型

一旦我们完成训练并对性能提升感到满意,我们可以使用:

torch.save(model.state_dict(), PATH)

保存调优模型的状态字典。当我们想要对与我们用于微调模型的数据相似的数据执行推理时,我们可以加载这个状态字典。

您可以在此处找到 Colab Notebook,其中包含微调 SAM 所需的所有代码。如果您想要一个开箱即用的完整解决方案,请继续阅读!

下游应用程序的微调

虽然 SAM 目前不提供开箱即用的微调,但我们正在构建与 Encord 平台集成的自定义微调器。如本文所示,我们微调解码器以实现此目的。这可以在 Web 应用程序中作为开箱即用的一键式程序使用,其中会自动设置超参数。
在这里插入图片描述
原SAM 预测结果:

在这里插入图片描述

由模型的微调版本生成的掩码:

在这里插入图片描述
我们可以看到这个面具比原来的面具更紧。这是对邮票验证数据集中的一小部分图像进行微调,然后在以前未见过的示例上运行调优模型的结果。通过进一步的训练和更多的例子,我们可以获得更好的结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/428439.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DAY 33 shell编程 常用的文本命令

sort命令####排序 sort将文件的每一行作为一个单位相互比较,比较原则是从首字符向后依次按ASCII码进行比较,最后将它们按升序输出。(以行为单位来排序输出) sort [选项] 参数​cat file | sort 选项常用选项: 常用选…

计算机综合题汇总

(数学计算题) 把6个相同的球分到三个不同的学生身上,允许有的学生没有球,请问有多少种不同的方法? C(8,2)=28。 典型的插板问题,直接套公式,C(n+m-1,m-1)。6个球,本身5个空,有同学可以不分,再加3个空,共8个空;插入2个板。 (软件选择题) 软件质量保证是什么? A. 确保…

超外差收音机的制作-电子线路课程设计-实验课

超外差收音机的制作 一、原理部分: 超外差收音机:超外差式收音机是将接收到的不同频率的高频信号全部变成一个固定的中频信号进行放大,因而电路对各种电台信号的放大量基本是相同的,这样可以使中放电路具有优良的频率特性。 超…

Adobe考证

在数字化时代,Adobe软件已成为许多人工作和创造的必备工具。为了证明自己在使用Adobe软件方面的专业能力,许多人选择参加Adobe认证考试并获取Adobe认证证书。 本文将从以下几个方面介绍Adobe考证的相关内容...... 什么是Adobe认证考试? Ado…

我的面试八股(JAVA并发)

重点AQS以及几种锁的底层需要补充!!!!! 程序计数器为什么是线程私有的? 程序计数器主要有下面两个作用: 字节码解释器通过改变程序计数器来依次读取指令,从而实现代码的流程控制&#xff0c…

VS Code配置C/C++开发环境

一、VS Code安装C/C++插件 二、配置MinGW 进入下载页面mingw-w64,找到winlibs-x86_64-mcf-seh-gcc-13.0.1-snapshot20230402-mingw-w64ucrt-11.0.0-r1.7z,点击下载。将文件放到自己想要放置的盘符下面,然后解压,将里面的mingw64目录剪切到最外层。 拷贝目录,将目录添加到…

【静态Web架构】静态站点生成器概述 Gatsby ,Hugo 和Jekyll对比

在本文中,您将看到三种最好的静态站点生成器的比较,它们的优点、缺点以及您应该使用它们的原因。网站统治着网络,无论是静态的还是动态的。虽然现在很多网站都是动态的,但是静态的仍然很受欢迎。事实上,静态网站的使用…

高频丙类谐振功率放大器【Multisim】【高频电子线路】

目录 一、实验目的与要求 二、实验仪器 三、实验内容与测试结果 1、观察输入、输出波形 2、观察不同工作状态下的集电极电流波形 3、测试负载特性 4、测试集电极调制特性 四、实验结果分析 五、参考资料 一、实验目的与要求 1、通过实验加深理解高频谐振功率放大器电路…

R -- 层次聚类和划分聚类

brief 聚类分析是一种数据归约技术,旨在揭漏一个数据集中观测值的子类。子类内部之间相似度最高,子类之间差异性最大。至于这个相似度是一个个性化的定义了,所以有很多聚类方法。 最常用的聚类方法包括层次聚类和划分聚类。 层次聚类&#…

vscode连接linux

vscode连接linux第一步:下载扩展第二步:打开左侧的那个类似小电脑的选项第三步:点击那个螺丝按钮第四步:选第一个第五步:配置config文件第六步:打开设置第七步:在搜索栏搜索:Always reveal the SSH login terminal第八步:重启vscode第八步:输入密码后,点击右上角号旁边的"…

什么是存算分离架构?

随着硬件技术的快速进步,尤其是网络和存储设备的性能迅速提升,以及云计算厂商推动软硬件协同加速的云存储服务,越来越多的企业开始基于云存储来构建数据存储服务,或数据湖,因此就需要单独再建设一个独立的计算层来提供…

C++ LinuxWebServer 2万7千字的面经长文(上)

⭐️我叫忆_恒心,一名喜欢书写博客的在读研究生👨‍🎓。 如果觉得本文能帮到您,麻烦点个赞👍呗! 前言 Linux Web Server项目虽然是现在C求职者的人手一个的项目,但是想要吃透这个项目&#xff…

不得不说的创建型模式-工厂方法模式

工厂方法模式是创建型模式之一,它定义了一个用于创建对象的接口,但将具体创建的过程延迟到子类中进行。换句话说,它提供了一种通过调用工厂方法来实例化对象的方法,而不是通过直接使用 new 关键字来实例化对象。 下面是一个使用 C…

[架构之路-167]-《软考-系统分析师》-4-据通信与计算机网络-3- 常见局域网与广域网

目录 4 . 3 局域网与广域网 4.3.1 局域网基础知识 1 . 星型结构 2 . 总线结构 3 . 环型结构 4 . 网状结构 4.3.2 以太网技术(接入网) 1 . 以太网基础 2 . 帧结构 3 . 以太网物理层规范 4.3.3 无线局域网(接入网) 1 . …

huggingface TRL是如何实现20B-LLM+Lora+RLHF

huggingface TRL实现20B-LLMLoraRLHFIntroductionWhat is TRL?Training at scale8-bit matrix multiplicationLoraWhat is PEFT?Fine-tuning 20B parameter models with Low Rank Adapter参考Introduction 作者首先表示RLHF在目前LLM的训练中是一种很powerful的方式&#xf…

SpringBoot 整合Quartz定时任务管理【SpringBoot系列18】

SpringCloud 大型系列课程正在制作中,欢迎大家关注与提意见。 程序员每天的CV 与 板砖,也要知其所以然,本系列课程可以帮助初学者学习 SpringBooot 项目开发 与 SpringCloud 微服务系列项目开发 Quartz是由Java语言编写,是OpenSym…

【环境搭建:onnx模型部署】onnxruntime-gpu安装与测试(python)

ONNX模型部署环境创建1. onnxruntime 安装2. onnxruntime-gpu 安装2.1 方法一:onnxruntime-gpu依赖于本地主机上cuda和cudnn2.2 方法二:onnxruntime-gpu不依赖于本地主机上cuda和cudnn2.2.1 举例:创建onnxruntime-gpu1.14.1的conda环境2.2.2 …

Spring整合MyBatis与JUnit

Spring整合 想必到现在我们已经对Spring有一个简单的认识了,Spring有一个容器,叫做IoC容器,里面保存bean。在进行企业级开发的时候,其实除了将自己写的类Spring管理之外,还有一部分重要的工作就是使用第三方的技术。前…

Spring —— Spring Boot 创建和使用

JavaEE传送门JavaEE Spring —— Spring简单的读取和存储对象 Ⅱ Spring —— Bean 作用域和生命周期 目录Spring Boot 创建和使用Spring BootSpring Boot 项目创建使用 IDEA 创建网页版创建Spring Boot 目录介绍运行 Spring Boothello world约定大于配置Spring Boot 创建和使…

关于SeaDAS的安装教程以及使用问题笔记

2022年硕士研究生最后半个学期,已经交完了毕业论文,因为觉得工作以后会用到SeaDAS就拿出了一些时间学习,现在已经工作快一年了,而工作中也并没有用到这个软件,估计以后也不会用到了吧。现在把当时学习整理的一些笔记分…