音视频开发之旅(90)-Vision Transformer论文解读与源码分析

news2025/1/12 1:07:44

目录

1.背景和问题

2.Vision Transformer(VIT)模型结构

3.Patch Embedding

4.实现效果

5.代码解析

6.资料

一、背景和问题

上一篇我们学习了Transformer的原理,主要介绍了在NLP领域上的应用,那么在CV(图像视频)领域该如何使用?

最直观的想法就是把每一个像素像NLP中一个文字一样处理,理论上可行,但是这样做有什么不足吗?

Transformer的自注意力机制的计算复杂度是O(n^2),其中n是序列长度,一张720*1280的图片就需要921600个token,这将导致巨大的计算开销,使得模型的训练和推理非常缓慢。图像不同像素之间存在很多冗余信息(编码时会进行帧内压缩),是否可以采用类似编码压缩技术中的宏块方案呐(把图像分割为固定大小的16x16、8x8、4x4的的块)。

二、VIT模型结构

VIT的思路和视频编码的宏块思想类似,把图像分割为固定大小pathchs,然后通过线性变换得到patch embedding,将图像的patch embeddings送入transformer的Encoder进行特征提取,在根据不同任务添加不同的Head。ViT模型原理如下图所示:

图片

模型由三个模块组成:

  • Linear Projection of Flattened Patches(该网络的前处理,把图像分割为patch,然后进行Embedding)

  • Transformer Encoder(该网络的backbone,用于特征提取)

  • MLP Head(该网络的head,用于分类任务)

主要的公式如下:

图片

图片

可以看到VIT只用到了Transfomer的Encoder作为backbone进行特征提取,TransfomerEncoderLayer也是使用Multi-head Attention,不同的是LayerNormalation放在了Multi-head Attention的前面。和Transfromer的结构主要区别在于Embedding的过程,如果对于注意力机制还不太清楚,建议复习下上一篇。

三、Patch Embedding

图片

关键点包括:

  1. 图像被分割成固定大小的patches。

  2. 每个patch通过线性投影映射到嵌入空间。

  3. 添加一个特殊的分类token。

  4. 加入位置编码以保留空间信息。

将2D图像转换为一个1D序列,使得标准Transformer架构可以直接处理图像数据,允许ViT像处理文本序列一样处理图像,充分利用了Transformer的自注意力机制来捕捉图像中的全局依赖关系。

下面我们用一个示例来说明PatchEmbedding的过程。

输入一张:256x256的rgb图像,然后把它分割为64个32x32的patchs,对patchs进行线性投影得到序列长度为64,dim为1024的Embedding,然后加上用于分类的可训练的classToken(随机初始化),最后在加上相同形状的PosEmbedding 作为TransformEncodeer的输入。

图片

图片来自:详解 Vision Transformer

图片

不同于Transfromer的PositionEmbedding(采用sin和cos固定编码),VIT中的PositionEmbedding采用了符合正态分布随机初始化,可训练的方案(bert也采用了类似方式)

论文中对学习到的positional embedding进行了可视化,发现相近的patchs的positional embedding比较相似,而且同行或同列的positional embedding也相近:

图片

需要注意的是:如果改变图像的输入大小,ViT不会改变patchs的大小,patchs的数量会发生变化,之前学习的pos_embed就维度对不上了,通常ViT采用插值的方式来解决这个问题,但效果不好,另外一篇论文给出了说明和解决措施 https://arxiv.org/pdf/2102.10882,有兴趣可以进一步研究下。

四、实验效果

ViT的训练策略:先在大数据集上做预训练,然后在小数据集上做迁移使用。

图片

如果在小数据集ImageNet上做预训练时,VIT的模型架构效果普遍低于ResNet搭建的BiT网络;当在中等数据集ImageNet-21k上做预训练时,VIT的模型架构基本位于BiT最好和最差的之间;而当在大数据集JFT-300M上做预训练时,VIT的模型架构最好的效果已经超过了BiT。

结论:VIT模型需要在大数据集上进行预训练,在大数据集上预训练的效果会比卷积神经网络的上限高

例如下图先在有3亿张图像的JFT大数据集上预训练,然后在ImageNet上进行微调,准确率达到88.55%

图片

ViT 还可根据 Attention Map 来可视化,得知模型具体关注图像的哪个部分,

图片

五、代码解析

源码地址:https://github.com/lucidrains/vit-pytorch

图片

图片来自:Vision Transformer详解

3.1、调用

import torchfrom vit_pytorch import ViT
def test():    #VIT的具体实现在vit.py中    v = ViT(        #原始图像尺寸        image_size = 256,        #切割的每个图像块的尺寸        patch_size = 32,        #类别数量        num_classes = 1000,        #Transformer隐变量维度大小        dim = 1024,        #Transformer Encoder层的个数        depth = 6,        #Multi-Head Attention 头的个数        heads = 16,        #mlp层 hid层的维度        mlp_dim = 2048,        dropout = 0.1,        emb_dropout = 0.1    )
    img = torch.randn(1, 3, 256, 256)
    preds = v(img)

3.2、Attention和FFN的实现

# helpers#确保t为元组def pair(t):    return t if isinstance(t, tuple) else (t, t)
# classes#前馈网络class FeedForward(nn.Module):    def __init__(self, dim, hidden_dim, dropout = 0.):        super().__init__()        self.net = nn.Sequential(            nn.LayerNorm(dim),            nn.Linear(dim, hidden_dim),            nn.GELU(),            nn.Dropout(dropout),            nn.Linear(hidden_dim, dim),            nn.Dropout(dropout)        )
    def forward(self, x):        return self.net(x)
#VIT中的self-Attention实现,这里也是多头注意力机制class Attention(nn.Module):    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):        super().__init__()        inner_dim = dim_head *  heads #多头的个数heads:16 * 每个头的维度:64 =1024        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads        self.scale = dim_head ** -0.5 # dim_head =64, scale=1/8
        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim = -1)        self.dropout = nn.Dropout(dropout)        #to_qkv线性变化,将输入映射到一个三维空间,以便在多头注意力机制中生成QKV 输入特征维度为dim (1024),输出维度为inner_dim*3 (1024*3)        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) #dim:1024,inner_dim:1024
        self.to_out = nn.Sequential(            nn.Linear(inner_dim, dim),            nn.Dropout(dropout)        ) if project_out else nn.Identity()
    def forward(self, x):        x = self.norm(x)        #将输入数据x映射到三维空间,x.shape为[1,65,1024],to_qkv经过线性变换后输出维度为[1,65,1024*3]; chunk(3,-1)将最后一个维度分割为3个子张量,生成qkv元组        qkv = self.to_qkv(x).chunk(3, dim = -1)        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) #进行形状转换,生成[batchsize,heads,squcelen,dim] 值为[1,16,65,64]        #经典的attention计算, 把q和K的转置相乘除以缩放系数,得到相似性系数        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale        #沿最后一维度进行softmax归一化        attn = self.attend(dots)        attn = self.dropout(attn)        #attn[1, 16, 65, 65]点乘V [1, 16, 65, 64]输出[1, 16, 65, 64]        out = torch.matmul(attn, v)        out = rearrange(out, 'b h n d -> b n (h d)') #对多头进行concate,得到[1, 65, 1024]        return self.to_out(out)

3.3、Transfromer Encoder层的实现

#VIT中Transfromer的实现,用到了Transformer的Encoder层. 和原始的Transfromer稍微有些差异,主要是layernormalization的位置class Transformer(nn.Module):    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):#dim:1024,depth:6;heads:16;dim_head:64;mlp_dim:2048;dropout:0.1        super().__init__()        self.norm = nn.LayerNorm(dim)        self.layers = nn.ModuleList([])        for _ in range(depth):            self.layers.append(nn.ModuleList([                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),                FeedForward(dim, mlp_dim, dropout = dropout)            ]))
    def forward(self, x):        for attn, ff in self.layers:            x = attn(x) + x #Attention进行残差            x = ff(x) + x #MLP进行残差
        return self.norm(x)

3.4、ViT的实现

#入口Module,这里的posEmbedding没有使用固定编码,而是像bert一样可训练的. 把image切分成多个patch,展平进行to_patch_embedding处理class ViT(nn.Module):    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):        super().__init__()        image_height, image_width = pair(image_size)        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'        # num_patches =(256//32)*(256//32)=64;  patch_dim:3*32*32=3072; dim=1024        num_patches = (image_height // patch_height) * (image_width // patch_width)        patch_dim = channels * patch_height * patch_width        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'        #使用einops的Rearrange优雅地处理张量维度        self.to_patch_embedding = nn.Sequential(            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),#这里(h p1) (w p2)就相当于h与w变为原来的1/p1,1/p2            nn.LayerNorm(patch_dim),            nn.Linear(patch_dim, dim),#patch_dim3072,dim 1024 线性变换            nn.LayerNorm(dim),        )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) # 创建一个形状为 (1, 65, 1024) 的随机张量,VIT中PE和Transformer中positionEmbedding的定义不同,这里是一个可以训练的模块        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))#创建一个随机的张量(1,1,1024)的cls_token        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool        self.to_latent = nn.Identity()
        self.mlp_head = nn.Linear(dim, num_classes)
    def forward(self, img):        x = self.to_patch_embedding(img)        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)        x = torch.cat((cls_tokens, x), dim=1)        x += self.pos_embedding[:, :(n + 1)]        x = self.dropout(x)        #输入和输出的形状都是 torch.Size([1, 65, 1024])        x = self.transformer(x)         #这里的pool为cls分类,所以沿dim=1,取第1个数据        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]        #这里的to_latent目前就是一个恒等变换层nn.Identity(),即输入和输出每个任何变化,可以去掉,这里起到占位的作用        x = self.to_latent(x)        return self.mlp_head(x)

六、资料

1.论文VIT:https://arxiv.org/pdf/2010.11929

2.源码:https://github.com/lucidrains/vit-pytorch

3.timm/models/vision_transformer.py: https://github.com/huggingface/pytorch-image-4.models/blob/main/timm/models/vision_transformer.py

5.ViT论文逐段精读【论文精读】https://www.bilibili.com/video/BV15P4y137jb

6.Vision Transformer(vit)网络详解 https://www.bilibili.com/video/BV1Jh411Y7WQ

7.李宏毅-Transformer 

https://www.bilibili.com/video/av56239558

8.详解VisionTransformer

 https://blog.csdn.net/qq_39478403/article/details/118704747

9.Vision Transformer详解  https://blog.csdn.net/qq_37541097/article/details/118242600

10.ViT代码超详细解读 https://blog.csdn.net/weixin_43334693/article/details/131836233

11.ViT PyTorch代码全解析(附图解)

https://blog.csdn.net/weixin_44966641/article/details/118733341

12.Vision Transformer(VIT)代码分析 https://blog.csdn.net/qq_38683460/article/details/127346916

13.ViT:视觉Transformer backbone网络ViT论文与代码详解 https://mp.weixin.qq.com/s/Nok5UQ2nzex94GXyrltiBg

14.可视化VIT中的注意力 https://mp.weixin.qq.com/s/O-56hxVa6Fgiz2YpjXTodQ

15."未来"的经典之作 ViT:transformer is all you need! https://www.cvmart.net/community/detail/4461

16.搞懂 Vision Transformer 原理和代码 https://mp.weixin.qq.com/s/ozUHHGMqIC0-FRWoNGhVYQ

17.3W字长文带你轻松入门视觉transformer https://zhuanlan.zhihu.com/p/308301901

18.Vision Transformer, LLM, Diffusion Model 超详细解读 (原理分析+代码解读) https://zhuanlan.zhihu.com/p/348593638

19.einops.repeat, rearrange, reduce优雅地处理张量维度 https://blog.csdn.net/qq_37297763/article/details/120348764

感谢你的阅读

接下来我们继续学习输出AI相关内容,欢迎关注公众号“音视频开发之旅”,一起学习成长。

欢迎交流

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2094077.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

在Diffusers中使用LoRA微调模型

在浏览稳定扩散模型共享网站(例如 CivitAI)时,你可能遇到过一些标记为“LoRA”的自定义模型。“LoRA”到底是什么—它与典型的模型检查点有何不同?LoRA 可以与Diffusers包一起使用吗?在本文中,我们将回答这…

计算机视觉基础 2. 滤波器

1. 简介 模糊滤波器是低通滤波器。它们从图像中去除高空间频率内容,只留下低频空间分量。结果是图像失去了细节,看起来很模糊。图像模糊在计算机图形学和计算机视觉中有许多应用。它可用于降低噪声(如图17.1所示),揭示…

代码时光机:Git基础速成

hello,家人们,今天咱们来介绍Git以及Git相关的操作,好啦,废话不多讲,开干. 1:Git初识 在介绍Git前,博主首先讲一个小故事. 我们学计算机的小伙伴们,在学校里头都有实验课,那么老师呢就会要求我们写实验报告并且要求我们交上去给老师检查.有一个学计算机的大学生,名字叫张三,然…

Datawhale X 李宏毅苹果书 AI夏令营-深度学习进阶task2:自适应学习率,分类

1.自适应学习率 临界点其实不一定是在训练一个网络的时候会遇到的最大的障碍。很多时候训练网络,损失不再下降,不是因为到了临界点,而是可能在山谷之间不停震荡。 以下为不同学习率对训练的影响,下图中左右平缓,上下陡…

C语言 | Leetcode C语言题解之第387题字符串中的第一个唯一字符

题目&#xff1a; 题解&#xff1a; struct hashTable {int key;int val;UT_hash_handle hh; };int firstUniqChar(char* s) {struct hashTable* position NULL;int que[26][2], left 0, right 0;int n strlen(s);for (int i 0; i < n; i) {int ikey s[i];struct has…

火语言RPA流程组件介绍--浏览选择文件夹

&#x1f6a9;【组件功能】&#xff1a;打开浏览文件夹选择对话框 配置预览 配置说明 对话框标题 支持T或# 打开浏览文件夹对话框时显示的标题。 默认打开文件夹 支持T或# 打开浏览文件夹对话框时&#xff0c;默认打开此文件夹。 取消后终止流程 “是”、“否”2种供选择…

一篇详细介绍常用第三方库的教程

作者&#xff1a;郭震 我们之前介绍过如何安装Python的各种常用第三方库.这些库为程序员提供了许多功能,能够大大简化我们的开发工作.本文将为你介绍一些最常用的第三方库,帮助你更好地理解它们的用途及基本概念. 1. NumPy NumPy是一个强大的科学计算库.它提供了多维数组对象以…

09.定时器02

#include "reg52.h"sbit led P3^6;void delay10ms() { //1. 配置定时器0工作模式位16位计时TMOD 0x01;//2. 给初值&#xff0c;定一个10ms出来TL00x00;TH00xDC;//3. 开始计时TR0 1;TF0 0; } void main() {int cnt 0;led 1;while(1){if(TF0 1)//当爆表的时候&a…

Git之2.9版本重要特性及用法实例(五十八)

简介&#xff1a; CSDN博客专家、《Android系统多媒体进阶实战》一书作者. 新书发布&#xff1a;《Android系统多媒体进阶实战》&#x1f680; 优质专栏&#xff1a; Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a; 多媒体系统工程师系列…

非关系型数据库 Redis 的安装与配置

文章目录 一 . CentOS 7 安装 Redis【版本选择说明】一 . 安装 Redis二 . 配置 Redis2.1 针对可执行程序设置符号链接2.2 针对配置文件设置符号链接2.3 修改配置文件2.3.1 设置 IP 地址2.3.2 关闭保护模式2.3.3 启动守护进程2.3.4 设置工作目录2.3.5 设置日志目录 三 . 启动 Re…

Apache SeaTunnel Zeta 引擎源码解析(一)Server端的初始化

引入 本系列文章是基于 Apache SeaTunnel 2.3.6版本&#xff0c;围绕Zeta引擎给大家介绍其任务是如何从提交到运行的全流程&#xff0c;希望通过这篇文档&#xff0c;对刚刚上手SeaTunnel的朋友提供一些帮助。 我们整体的文章将会分成三篇&#xff0c;从以下方向给大家介绍&am…

掌握数据利器:AWS Glue与数据基盘概览

引言 随着数字化进程的不断推进&#xff0c;企业现在能够积累并分析海量且多样化的数据。这一优势使得许多企业开始采用数据驱动型经营&#xff08;即基于数据的经营策略&#xff09;。通过基于数据的客观判断&#xff0c;企业及其管理者可以获得诸多好处。 然而&#xff0c;…

DeepMind 机器人学习打乒乓球,朝着「专业运动员水平的速度和性能」发展

这几天全球各界最火热的话题非奥运会莫属&#xff0c;而其中乒乓球比赛更是引起了互联网的讨论热潮&#xff0c;无论是欢呼也好、争议也罢&#xff0c;在现实世界人类的乒乓球大赛风生水起的同时&#xff0c;AI已经偷偷在乒乓球上“出师”了—— ——DeepMind近日发布一项新工作…

机器学习 第7章 贝叶斯分类器

目录 7.1 贝叶斯决策论7.2 极大似然估计7.3 朴素贝叶斯分类器7.4 半朴素贝叶斯分类器7.5 贝叶斯网7.5.1 结构7.5.2 学习7.5.3 推断 7.6 EM算法 7.1 贝叶斯决策论 对分类任务来说&#xff0c;在所有相关概率都己知的理想情形下&#xff0c;贝叶斯决策论考虑如何基于这些概率和误…

如何删除浏览器每次登录自动保存的密码,以防自动登录泄露自己的隐私

今天小编以 Microsoft edge 浏览器为例&#xff0c;如何在自己离职或毕业以后留给他人的电脑是干净的&#xff0c;不会在任何网页登录时显示已保存的密码&#xff0c;让他人自动登录。 ①在电脑上打开 Microsoft edge 浏览器后&#xff0c;点击“设置” ②进入设置界面后&…

基于SSM的咖啡馆管理系统

基于SSM的咖啡馆管理系统的设计与实现~ 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringSpringMVCMyBatisJSP工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 前台界面 后台界面 摘要 在当前这个信息爆炸的时代&#xff0c;众多行业正经历着…

Python酷库之旅-第三方库Pandas(114)

目录 一、用法精讲 501、pandas.DataFrame.mode方法 501-1、语法 501-2、参数 501-3、功能 501-4、返回值 501-5、说明 501-6、用法 501-6-1、数据准备 501-6-2、代码示例 501-6-3、结果输出 502、pandas.DataFrame.pct_change方法 502-1、语法 502-2、参数 502…

[知识分享]华为铁三角工作法

在通信技术领域&#xff0c;尤其是无线通信和物联网领域&#xff0c;“华为铁三角”是华为公司内部的一种销售、交付和服务一体化的运作模式。这种模式强调的是以客户为中心&#xff0c;通过市场、销售、交付和服务三个关键环节的紧密协作&#xff0c;快速响应客户需求&#xf…

2.12 滑动条事件

目录 实验原理 实验代码 运行结果 实验原理 在 OpenCV 中&#xff0c;滑动条设计的主要目的是在视频播放帧中选择特定帧&#xff0c;而在调节图像参数时也会经常用到。在使用滑动条前&#xff0c;需要给滑动条赋予一个名字&#xff08;通常是一个字符串&#xff09;&#x…

Java | Leetcode Java题解之第388题文件的最长绝对路径

题目&#xff1a; 题解&#xff1a; class Solution {public int lengthLongestPath(String input) {int n input.length();int pos 0;int ans 0;int[] level new int[n 1];while (pos < n) {/* 检测当前文件的深度 */int depth 1;while (pos < n && inpu…