Segment Anything模型部分结构和代码解析

news2024/11/20 3:23:41

0x0. 前言

上个月Meta提出的Segment Anything Model(SAM)希望在视觉领域通过Prompt+基础大模型的套路来解决目标分割的问题。经过实测,在大多数场景中SAM的表现都足够惊艳,并且基于SAM的各种二创工作也爆发了比如 检测一切的Grounded-Segment-Anything(https://github.com/IDEA-Research/Grounded-Segment-Anything),将Segment Anything扩展到医学图像领域 。但目前中文社区似乎并没有怎么对SAM的模型做细致的解析,所以这里 fork了SAM仓库并且对模型实现部分做了详细的代码解析,fork仓库的地址如下:https://github.com/Oneflow-Inc/segment-anything

本文会对照论文的结构图和fork的SAM仓库中的代码注释尝试梳理一下SAM模型部分的代码。最后,我也会介绍一下如果你想用oneflow来跑SAM应该怎么做,实际上在预测脚本里面加2行代码就可以了:

import oneflow.mock_torch as mock
mock.enable(lazy=True, extra_dict={"torchvision": "flowvision"})

最后汇总一下这个fork的SAM仓库做的事情:

  • 对 https://github.com/Oneflow-Inc/segment-anything/tree/main/notebooks 下面的推理脚本进行汉化。
  • 对 https://github.com/Oneflow-Inc/segment-anything/blob/main/README_zh.md 进行汉化。
  • 对 https://github.com/Oneflow-Inc/segment-anything/tree/main/segment_anything/modeling SAM的模型实现进行全面解析,为每个函数代码实现添加中文注释。
  • 基于oneflow的mock torch技术一键切换 oneflow 后端运行SAM模型推理,方便基于oneflow做二次开发以及性能优化。

欢迎点击star: https://github.com/Oneflow-Inc/segment-anything

在这里插入图片描述

0x1. 模型+代码解析

在这里插入图片描述
实际上模型实现部分就对应了这张图。

其中绿色部分表示将原始图像编码为向量,SAM中使用VIT来实现图像编码器。原始图像被等比和 padding 的缩放到1024大小(对应https://github.com/Oneflow-Inc/segment-anything/blob/main/segment_anything/modeling/sam.py#L131),然后采用kernel size16stride也为16的卷积将图像离散化为batch_size x 64x64X768的向量(对应 https://github.com/Oneflow-Inc/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L482-L518,),向量在W和C上被顺序展平(https://github.com/Oneflow-Inc/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L208)后再进入多层的transformer encoder,vit输出的向量再通过两层的卷积(kernel size分别为13,每层输出接LayerNorm2d)压缩到特征维度为256https://github.com/Oneflow-Inc/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L98-L114)。

image encoder部分的详细代码细节的解释请查看:https://github.com/Oneflow-Inc/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

接下来紫色部分表示prompt encoder,prompt encoder的输出包括对点,box和text进行编码组成的sparse_embeddings以及对输入mask进行编码的dense_embeddings (对应https://github.com/Oneflow-Inc/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L251)。最后,sparse_embeddings的输出shape是batch_sizexNx(embed_dim=256),其中N由输入的点和框的个数决定。而dense_embeddings的输出shape是batch_sizex(embed_dim=256)x(embed_H)x(embed_W),其中embed_H和embed_H都等于64。(https://github.com/Oneflow-Inc/segment-anything/blob/main/segment_anything/build_sam.py#L73)。注意图上的对mask的卷积操作对应 https://github.com/Oneflow-Inc/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py#L64-L71

prompt encoder部分的详细代码细节的解释请查看:https://github.com/Oneflow-Inc/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py

最后我们看一下 Mask Decoder部分,也就是图中的橙色部分。Mask Decoder的细节可以用下图来表示:

在这里插入图片描述

这里的image embedding(256x64x64)就是上面的image decoder的输出,因为输入到Mask Decoder的时候是对batch维度进行遍历然后处理,所以这里的维度没有Batch。然后左下角的output tokens+prompt tokens( N t o k e n s × 256 N_{tokens}\times 256 Ntokens×256)分别表示iou token embedding和3个分割结果 token的embedding(sparse_embeddings+dense_embeddings)。(对应:https://github.com/Oneflow-Inc/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py#LL171C9-L173C1)。这里还需要注意的一个细节是prompt embedding部分的dense embedding直接叠加到了image embedding上。(对应https://github.com/Oneflow-Inc/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py#L175-L177C18)。

接着在 Transformer 实现中每一层都做

  • token embedding 做self attention 计算。
  • token embedding 和src 之间做cross attention 计算。
  • src 和 token embedding 之间做 cross attention 计算。
  • 第 2 和 3 之间有前馈 MLP 网络;cross attention的结果通过残差方式相加并norm。

详细的代码解释请看:https://github.com/Oneflow-Inc/segment-anything/blob/main/segment_anything/modeling/transformer.py#L133-L244

这个Transformer 块里面的右上角有一个 x2 ,这里的意思是Transformer的层数为2。然后这里的紫色和绿色的箭头表示当前的Attention模块的query, key, value的来源,每层有1个self attention和2个cross attention模块。transform最终输出前,token embedding 还需要和src 做一次cross attention,也就是图中的token to image attn。

最后,Transform 返回的3个 mask token 的 embedding 经过3层mlp后,与对齐后的图像embedding点积得到 3 个最终的分割结果;iou token 经过mlp得到3个分割结果置信度得分。(对应:https://github.com/Oneflow-Inc/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py#L182-L199

0x2. 切换SAM的后端

SAM的推理脚本默认是使用PyTorch运行,如果你想使用oneflow来执行并尝试获得推理加速,可以在运行脚本之前加上:

import oneflow.mock_torch as mock
mock.enable(lazy=True, extra_dict={"torchvision": "flowvision"})

OneFlow版本需要安装nightly,这样就可以用上OneFlow作为后端来推理SAM了。关于mock torch 黑魔法可以查看 https://docs.oneflow.org/master/cookies/oneflow_torch.html 这个官方文档。

oneflow nightly版本的安装方法如下:https://github.com/Oneflow-Inc/oneflow#install-with-pip-package

遗憾的是,我们还未来得及做调优工作,如果对使用OneFlow对SAM做推理加速感兴趣的读者可以自行尝试活着联系我本人一起讨论和实施。

0x3. 总结

本文介绍了 https://github.com/Oneflow-Inc/segment-anything 做的一些事情并解析了SAM的结构和代码实现。对于SAM来说,相比于模型更重要的是最数据进行处理,关于这方面大家可以参考:https://zhuanlan.zhihu.com/p/620355474

0x4. 后续工作

后面有时间的话会继续汉化onnx导出的jupyet notebook,并且做一下相关的性能调优工作以及剩余的SamAutomaticMaskGenerator的解析。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/502346.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Prometheus监控报警-web,域名,端口, 钉钉版本

Prometheus监控报警-web,域名,端口, 钉钉版本 采用文章 https://www.yuque.com/youngfit/qok2pe/nypstd#616a2e58https://www.jianshu.com/p/eae67b770c3ehttps://blog.csdn.net/Sebastien23/article/details/113645177https://www.cnblogs…

Unix套接字(UDS,Unix Domain Socket)

【知识简介】 在​​Linux​​系统中,有很多进程间通信方式,套接字(Socket)就是其中的一种。但传统的套接字的用法都是基于TCP/IP协议栈的,需要指定IP地址。如果不同主机上的两个进程进行通信,当然这样做没…

BEVFusion A Simple and Robust LiDAR-Camera Fusion Framework 论文学习

论文地址:BEVFusion: A Simple and Robust LiDAR-Camera Fusion Framework 论文学习 Github 地址:BEVFusion: A Simple and Robust LiDAR-Camera Fusion Framework 论文学习 1. 解决了什么问题? 将相机和 LiDAR 融合已经成为 3D 检测任务事…

【MySQL】数据库的增删改查二:CURD

目录 ​ 需要知道 🌟一、增加数据 🌈1、语法 🌈2、单行数据,全列插入 🌈3、多行数据,全列插入 🌟二、查询数据 🌈1、全列查询 🌈2、指定列查询 &#…

Nacos集群部署配置Nginx负载均衡

Nacos集群部署配置Nginx负载均衡 1|新建nacos文件夹 mkdir nacos 新建文件夹 cd nacos 进入文件夹2|下载Nacos安装包(前提是云服务器,有网。也可以在windows下载好再上传) wget https://github.com/alibaba/nacos/releases/download/2…

新品发布 | 12通道CAN FD转USB接口卡全新上市!

新品发布 ON 05.05 TC1018是同星智能开发的一款12路CAN FD总线转USB接口卡,配合我们的TSMaster软件可以监控、分析和仿真CAN FD总线数据。广泛应用于汽车、工业、特种机械和其他行业,用于CAN总线测试与分析、UDS诊断和ECU刷写等方面。 TC1018-产品简介…

weka3.8.6的安装与使用

目录 背景 一、安装 二、使用explorer 1. 介绍 2.打开自带的数据集(Preprocess) 1.打开步骤 2.查看属性和数据编辑 3.classify 4.Cluster 5.Associate 6.Select attributes 7.Visualize 待补充 背景 Weka的全名是怀卡托智能分析环境(Waikato Environme…

进程调度/页面置换/磁盘调度算法

进程调度算法 进程调度算法也称 CPU 调度算法,毕竟进程是由 CPU 调度的。 当 CPU 空闲时,操作系统就选择内存中的某个「就绪状态」的进程,并给其分配 CPU。 什么时候会发生 CPU 调度呢?通常有以下情况: 当进程从运…

AIGC:【LLM(二)】——LangChain:由LLMs驱动的应用开发框架

文章目录 一.背景介绍二.LangChain简介2.1 常见应用场景 三.LangChain特点3.1 优点3.2 不足 四.LangChain功能4.1 基础功能4.2 功能模块4.2.1 LLM和Prompts4.2.2 Chain4.2.3 Agent4.2.4 Memory4.2.5 Embedding4.2.6 Models4.2.7 Indexes 五.实战案例5.1 背景需求5.2 数据准备5.…

抖音seo矩阵系统源码是什么?

抖音SEO矩阵系统源码是一款功能强大的营销工具,能够帮助用户进行抖音视频的SEO优化,使其在抖音平台上获得更高的曝光度和流量。该系统结合了SEO的相关算法和技巧,提供了完整的优化方案,可帮助用户提高视频的曝光率、获得更多的点赞…

阻塞队列原理及Java实现

目录 1.阻塞队列 1.举例:包饺子 1.通过多线程来实现 2.通过阻塞队列来实现 2.消息队列 1.解耦 2.削峰填谷 用消息队列来解决 3.异步操作 3.实现一个阻塞队列 使用循环数组 4.实现生产者和消费者模型 完整代码 5.虚假唤醒 1.概念及原因 2.解决方法 1…

关于GD32替换STM32(pin to pin)搭载rt-thread操作系统,需要注意的问题总结

1、SystemInit()函数 该函数位于启动文件中的Reset_Handler中(具体实现在GD32位于system_gd32f4xx.c,STM32位于system_stm32f4xx.c中,几乎所有的文件,你只要把gd换成st就能找到对应的文件),gd的叫startup_gd32Fxxx.s,…

4.HIVE函数

1.hive函数 1.1 空值替换 两个输入:nvl(col,default_num) : 如果colum不为null,返回col.否则返回default_num 多个输入:coalesce(col1, col2, col3, ....) :从左到右找第一个不为null的值 例如:求所有员工的平均薪…

【操作系统】总结

依旧是小林coding 的内容 存储架构 现代 CPU 都是多核心的,线程可能在不同 CPU 核心来回切换执行,这对 CPU Cache 不是有利的,虽然 L3 Cache 是多核心之间共享的,但是 L1 和 L2 Cache 都是每个核心独有的,如果一个线…

VMWare安装windows7虚拟机提示Operating System not found

前提:下载windows7 Gost并创建虚拟机,启动报错:Operating System not found 解决办法 用微PE工具制作iso系统,对虚拟机进行分区 下载地址:https://www.wepe.com.cn/ 制作方法,双击安装程序,选…

最困难的也是最简单的,做好这两点不盈利天理难容

投资者应该时刻记住,在外汇交易中复杂的方法并不总是最好的。Forexclub发现交易中最困难的是正确识别进场点和出场点。 从技术上来说,进入交易是非常容易的,你只需要点击一个按钮,你就在那里交易。但是你会从中获利吗?没人能回答…

【Linux Network】网络编程套接字(代码练习)—UDP

目录 1. 常用接口 2. C/S 回声模拟 3. C/S myshell 的制作 Linux网络编程✨ 1. 常用接口 socket:创建套接字: // 创建 socket 文件描述符 int socket(int domain, int type, int protocol); 返回值: 套接字创建成功返回一个文件描述符 &…

GAMMA电源维修直流高压电源模块RR300-1P

美国GAMMA高压电源维修参数(RR分离式): 输入:220VAC 或 380VAC(视型号而定) 输出电压:550KV,功率:0-10KW或定制 纹波率0.01 ;稳定度0.01/1H 控制部分19英…

(只需两步)让ChatGPT帮你制作出漂亮的PPT

目录 第一步:生成 PPT 代码 第二步:将代码转化为 PPT 还在为制作PPT而烦恼吗? 让ChatGPT来帮您! 本篇文章介绍如何利用ChatGPT一键生成PPT文字和样式,省时省力又专业! (真的只需两步&#xf…

案例实践|云智慧ITSM产品在利星行汽车的运维实践

ITSM(信息技术服务管理)是一种以客户为中心的方法,旨在提高信息技术的效率和效果。在传统零售行业,ITSM可以帮助连锁零售企业提升客户服务水平,通过IT服务台提供快速响应和解决客户的问题和需求。同时, ITS…