ViT和SwinTransformer详解

news2025/1/10 16:55:50

ViT是Google brain发表于ICLR'21上的工作,开创性将transformer用在vision领域,且图像识别性能超CNN,至今引用3.8w+;原文:https://arxiv.org/pdf/2010.11929

SwinTransformer是微软亚洲研究院发表于ICCV'21上,获best paper,在多个视觉任务上获sota,打破CNN垄断vision backbone的现状,至今引用1.8w+;原文:https://openaccess.thecvf.com/content/ICCV2021/papers/Liu_Swin_Transformer_Hierarchical_Vision_Transformer_Using_Shifted_Windows_ICCV_2021_paper.pdf

建议读原文,这些文章优雅、简洁、深刻。

下面按照三部分进行,分别是Attention介绍、ViT详解、SwinTransformer详解。与常规文章讲解不同,我会多采用QA进行展开。

1. Attention介绍

这涉及到NeurIPS发表的“Attention is all you need”,这篇文章引用已经12w+,理解注意力机制是学习transformer的核心。

Q: general attention和self attention区别?

A: 相同点是均需要计算qkv,不同之处,self attention的input只有x,而general attention的input除了有x(映射得到kv),还有q(查询query)。

Self attention layer介绍

步骤:1. 输入x,通过映射矩阵Wq,Wk,Wv,得到qkv(D维)

           2. q和k进行对齐操作,如:q0会分别与不同的k进行点乘操作,得e0(e矩阵第一列)

           3. 注意力机制:softmax操作。如,从e0得a0(a矩阵第一列),为0~1之间的注意力权重

           4. 输出:v和注意力权重a的加权和,如:y0为a0和所有v的注意力加权和

=》不同于CNN的局部特性,此处的自注意力很好地体现了全局特性。

仔细观察,可以发现self-attention layer具备permutation invariant的性质(置换不变)

现实中,不管是语言token还是vision patch token,位置不同,显然我们应该得到不同的内容向量y才是合理的。

因此,有必要加入位置编码,将位置信息考虑进来进行自注意力学习。

对于每个输入xj,给出位置编码pj。使用位置编码函数pos,pj=pos(j),将位置j映射到D维向量(因为x是D维)。对于pos函数的选取此处不详细展开。

Multi-head self attention layer介绍

多头自注意力层就是transformer里核心模块。

Q:为什么要multi-head?

A:本质是为提取更好的特征,类似于CNN中卷积核也是多组,以得到多个特征谱。不同的是,CNN中卷积核小,计算量小,特征谱数量都是几十、几百。这儿的Multi-head不会很多,一般不超过10。

2. ViT详解

这篇文章的writing也可以当作范本,反复学习。

Q:标题两个keys,一个是an image is worth 16x16 words, 另一个是at scale,分别突出了什么?

A:前者突出将图像按照文字的处理方式,把一张图表示成了16x16 tokens。另一个关键点at scale,则与transformer的优势关联起来,也暗含了transformer要获的良好性能的前提。

Q:Transformer的天然优势是什么?

A:主要是excellent scalability,当模型和训练集增加时,并没有saturating performance。可以处理超大规模的训练数据。另一个是self-attention带来的computational efficiency,很多计算可以高度并行。

Q:CNN的天然优势是什么?

A:主要是inductive bias,在卷积的过程中,我们使用了translation equaivariance(平移不变性)、locality(局部性)来保留2D相邻结构。这些使得CNN在少量训练数据时候也能获得很好的性能。

Q:什么时候Transformer会比CNN更好?

A:通常,小训练数据集时候,convolutional inductive bias会很有用。当,数据集规模足够大的时候,最终large scaling training会比inductive bias表现好。这是合理的,因为Transformer学习中没有inductive bias,其特征时只能从大规模数据中学习。

Conclusion中提及的几个有前瞻性的点,现在均已经实现:)

1)self-supervised vs. supervised learning,之间的gap已经去掉;

2)scaling law,随着scaling提升,模型性能提升,现在已经是大模型发展遵循的发展规律;

3)transformer在segmentation、detection上的发展,现在已经横扫这些视觉任务。

# Transformer Encoder (depth x)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads=heads, dim_head=dim_head,dropout=dropout),
                FeedForward(dim,mlp_dim,dropout=dropout)]))
    
    def forward(self, x):
        for att, ff in self.layers:
            x = attn(x)+x
            x = ff(x)+x
        return self.norm(x)
# Multi-Head Attention
# 与self-attention layer中的operation保持一致

class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropot=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not(heads==1 and dim_head==dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim=-1)    
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim*3, bias=False)
        
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q,k,v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) # batch(b), sequence length(n), heads(h), dim(d) 
        dots = torch.matmul(q, k.transpose(-1,-2))*self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)
    
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# FeedForward (Transformer Encoder第二个部分)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout))
        
        def forward(self, x):
            return self.net(x)
# ViT framework
# 包括输入图像的处理方式以及具体的任务

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        
        num_pathes = (image_height//patch_height)*(image_width//patch_width) # sequence length
        patch_dim = channels*patch_height*patch_width
        assert pool in {'cls','mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' #输入序列的第一个位置会添加一个特殊的标记,称为 [CLS] 标记

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )
        
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))
        self.cls_token = nn.Parameter(torch.randn(1,1,dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes) # classification task

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b,n,_ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:,:(n+1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim=1) if self.pool=='mean' else x[:,0]

        x = self.to_latent(x)
        return self.mlp_head(x)  

        

此处,pos_embedding是随机给的,transformer的输出后pool只能选cls或者mean中之一,然后进行MLP对任务的预测。

这里没有涉及到transformer decoder设计。

3. SwinTransformer详解

Q:ViT不好吗,SwinTransformer主要解决哪些关键问题?

A:如果图像分辨率变大,按照patch的size进行切分,这时候图像块的数量会增加,相应的计算复杂度quadratic增加,除此外切分的patch也相对较大(下采样倍数高),特征提取信息不准。不能很好处理高分辨率图。除此外,ViT这种固定的图像块切分方法对于不同大小的视觉实体而言不是很合理,当物体远小于或者大于patch时候,很难有效提取特征。不同物体的尺寸和比例差异很大,不像单词的长度相对固定。不能很好处理大小变化的视觉物体

个人感觉SwinTransformer中窗口概念,类似与CNN中卷积核,窗口shift类似于CNN中stride,不同这个shift(向右、向下)更加灵活。不同在于,CNN针对局部图像感受野直接去求W,而SwinTransformer则是利用Self-attention更高级的方式去求局部图像的特征。

很多技巧都是用于减少(分窗口、窗口移动)运算量。

Q:SwinTransformer主要贡献?

A:第一,层级的特征谱方式使计算复杂度对于图像尺寸而言是linear而不是quadratic,可以处理高分辨率的图像。第二、shifted window很好解决了视觉物体大小变化的特点。

Q:SwinTransformer主要设计思想?

A:全局的注意力机制只在小范围内做,然后在不同层级上提特征(W-MSA,提出窗口的概念,窗口内进行多头注意力机制)。此外,利用shifted window将各个窗口之间的信息进行通信,完美达到捕获全局的上下文信息的优势(SW-MSA,此处就是滑动窗口的多头注意力机制)。这两部分就是Swin Transformer blocks的主要组成部分。

对比:

1)SwinTransformer有很多窗口(红色框),且在不同的层级上,窗口的划分是不同的。ViT将整图作为一个窗口,一直进行全局注意力机制计算。

2)SwinTransformer先进行4x下采样,将4*4个pixels作为一个小patch,在划定的窗口内进行注意力计算,然后是8x下采样,最后是16x下采样。ViT直接下采样16倍,后面保持相同的下采样规律。

SwinTransformer Blocks

1. W-MSA介绍(窗口间不涉及信息传递)

这个提出的目的是在窗口内进行kqv的求解,既能减少计算复杂度,也能使用更小的patch size,使下采样倍数不用很大。

Q:具体减少了多少运算量?

A:运算量主要分为三部分:1)to kqv,2)qk对齐,3)与v加权和。Att(Q,K,V)=Softmax(\frac{QK^{T}}{\sqrt{d}})V

1) X^{hw\times C}通过矩阵运算W^{C\times C}生成Q^{hw\times C },K^{hw\times C},V^{hw\times C}.总运算量为3hwC^{2}

2) qk对齐,运算量为(hw)^{2}C,得A^{hw\times hw}

3) 与v加权,得B^{hw\times C},运算量为(hw)^{2}C

4) 多头注意力机制,多了一个融合矩阵W,B^{hw\times C}\cdot W^{C\times C}=O^{hw\times C},计算量hwC^{2}

总计,4hwC^{2}+2(hw)^{2}C 公式一

假设W-MSA的窗口长和宽为M,代入上面公式为,

4M^{2}C^{2}+2M^{4}C

\frac{h}{M}\times \frac{w}{M}窗口,所以为,4hwC^{2}+2M^{2}hwC 公式二

缺点:减少了运算量,但窗口之间由于没有任何通信,导致确实全局感受野。

2.SW-MSA介绍(窗口间进行信息传递)

两层之间发生了窗口的移动(Shift),偏移的量是:往右、往下偏移M/2个像素。移动后,划分出的第二列3个窗口能够完成相邻窗口的信息交流。

缺点:原来4个窗口,移动后变成9个窗口,且大小不一。总之,移动后窗口的数量增多,从\frac{h}{M}\times \frac{w}{M}变成(\frac{h}{M}+1)\times (\frac{w}{M}+1),有些窗口会变小。

解决办法:

naive方案,把所有变小的窗口pad后,计算attention时候把pad数值掩膜。但这样,存在很多没必要的运算。

Efficient batch computation approach by cyclic-shifting toward the top-left direction

主要思想:将移动模式作为flag,只对有相邻关系的子窗口计算,不相邻的,减去100,使得softmax计算后概率接近0。

示意图传递了跟之前的窗口类似的计算,对于不相关的信息加上了mask,softmax后得到的概率接近0,使其达到mask的作用。最后再通过reverse cyclic shift移动回去。

SwinTransformer Framework

整体架构的思想:4个阶段,每个阶段构建不同大小的特征图,不断缩小分辨率,类似CNN逐渐增大感受野。

  • Patch partition, 本质就是矩阵的reshape,以4*4为一个图像块,对输入图片进行分块,然后在channel方向上进行拼接
  • Linear embedding, 经过线性变换,通道数从48变成C
  • Patch merging, 本质就是降采样,只是比pooling的方式来得更复杂一些,有学习的参数

关于relative position bias这里不展开,因为其在图像分类上提高了,但是在目标检测任务上降低了,具体理解可以参考[4]。

Application:下一篇会介绍SwinIR,揭开该方法如何在底层视觉的图像修复上施展魔法。

参考:

[1] cs231n课件

[2] vit-pytorch/vit_pytorch/vit.py at main · lucidrains/vit-pytorch · GitHub

[3] Swin Transformer:屠榜各大CV任务的视觉Transformer模型 (high-level介绍)

[4] Swin Transformer 详解(detail-level理解,很不错)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1986670.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用python CodeGeeX 辅助数据处理xml

1 背景:手头上有N 张算是开发完成的报表,但是由于每个报表是不同的人开发的,每个人不同的编码风格,准备看看报表是否都定义了Title,是否都定义了报表的描述,是否有不带where条件的前台查询,是否同一个参数定义一致.现在AI 代码助手功能据说很强大了,试试描述需求让机器来辅助编…

【中项】系统集成项目管理工程师-第10章 项目整合管理-10.3指导与管理项目工作

前言:系统集成项目管理工程师专业,现分享一些教材知识点。觉得文章还不错的喜欢点赞收藏的同时帮忙点点关注。 软考同样是国家人社部和工信部组织的国家级考试,全称为“全国计算机与软件专业技术资格(水平)考试”&…

常用在线 Webshell 查杀工具推荐

一、简介 这篇文章将介绍几款常用的在线 Webshell 查杀工具,包括长亭牧云、微步在线云沙箱、河马和VirusTotal。每个工具都有其独特的特点和优势,用于帮助用户有效检测和清除各类恶意 Webshell,保障网站和服务器的安全。文章将深入探讨它们的…

实现自定义QDateEdit可删除日期值

在Qt框架中,QDateEdit是一个用于编辑日期的控件,如果想要删除QDateEdit不是特别好做,如果直接获取QDateEdit中的QLineEdti并设置显示删除按钮(代码如下所示),删除按钮会一直显示,效果并不好&…

SIP 消息的路由和 7 个相关的 Header IMS-HSS 中的透明数据及非透明数据(VoNR、VoLTE均用)

目录 1. SIP 消息的路由和 7 个相关的 Header 1.1 SIP 消息路由相关的7个Header 1.2 理解 Record-Route 和 Route 1.3 Record-Route 和 Route 流程举例 1.4 SIP 请求消息的路由原则 1.5 SIP 请求消息路由举例 1.6 SIP 请求消息路由原则和流程举例 2. IMS-HSS 中的透明数…

【C++指南】命名空间

💓 博客主页:倔强的石头的CSDN主页 📝Gitee主页:倔强的石头的gitee主页 ⏩ 文章专栏:《C指南》 期待您的关注 目录 一、命名空间的重要性 1. C语言中没有命名空间而存在的问题 2. C引入了命名空间解决的问题 3.…

【论文速读】《LLM4CP: Adapting Large Language Models for Channel Prediction》

论文地址: https://ieeexplore.ieee.org/document/10582829 前言:之前就想,大语言模型是否可以通过微调用于通信系统的无线空口应用,这篇文章给出了答案。通过讲信道状态信息进行嵌入和注意力操作,变成大语言模型可以…

Map和Set及哈希--的奥秘(详解)

目录: 一 搜索树: 二. 搜索相关概念 三.Map 的说明 四. Set 的说明 五.哈希表: 一 搜索树: 1.概念: 二叉搜索树又称二叉排序树,它或者是一棵空树,或者是具有以下性质的二叉树: (1). 若它的左子树不为空&a…

常见中间件漏洞复现之【Jboss】!

Jboss介绍 JBoss是⼀个基于J2EE的开发源代码的应⽤服务器。JBoss代码遵循LGPL许可,可以在任何商业应⽤中免费使⽤。JBoss是⼀个管理EJB的容器和服务器,⽀持EJB1.1、EJB 2.0和EJB3的规范。但JBoss核⼼服务不包括⽀持servlet/JSP的WEB容器,⼀般…

61 函数参数——可变长度参数

可变长度参数在定义函数时主要有两种形式:*parameter 和 **parameter,前者主要用来接收任意多个实参并将其放在一个元组中,后者接收类似于关键参数一样显示赋值形式的多个实参并将其放入字典中。 # 无论调用该函数时传递了多少实参&#xff…

鸿蒙Harmony开发:onFrame逐帧回调规范

通过返回应用onFrame逐帧回调的方式,让开发者在应用侧的每一帧都可以设置属性值,从而实现设置了该属性值对应组件的动画效果。 使用animator实现动画效果 使用如下步骤可以创建一个简单的animator,并且在每个帧回调中打印当前插值。 引入相…

萌新的Java入门日记18

一、mybatis范围筛选 1.第一种表示方法 <!--resultType 查出来的结果自贡每一行都要映射到该类型的对象--><select id"getStaff" resultType"com.easy.bean.Staff">select * from staff<!--根据参数不同组合出不同的SQL语句 动态SQL语句…

java之IO篇——工具包Commons-io和Hutool

前言 结束了IO篇的File、基本流和高级流。还要认识IO流的一些工具包Commons-io和hutool&#xff0c;不算是框架&#xff0c;但是非常实用。 目录 前言 一、Commons-io 1.来历及作用 2.使用 二、Hutool 1.简介 2.使用 一、Commons-io 1.来历及作用 Commons-io是apache…

C++第七篇 模板初阶和STL简介

目录 一&#xff0c;模板初阶 1.泛型编程 2.函数模板 2.1 函数模板概念 2.2 函数模板格式 2.3 函数模板的原理 2.4 函数模板的实例化 2.5 模板参数的匹配原则 3.类模板(模板类&#xff0c;模板函数) 3.1 类模板定义格式 二&#xff0c;STL简介 1. 什么是STL 2. ST…

【JUC】并发编程与源码分析 1-7章

1 线程基础知识复习 1把锁&#xff1a;synchronized&#xff08;后面细讲&#xff09; 2个并&#xff1a; 并发&#xff08;concurrent&#xff09;&#xff1a;是在同一实体上的多个事件&#xff0c;是在一台机器上“同时”处理多个任务&#xff0c;同一时刻&#xff0c;其…

【学习笔记】A2X通信的协议(三)- A2X PC5通信(一)

目录 6. A2X通信 6.1 A2X PC5通信 6.1.1 一般说明 6.1.2 通过NR-PC5的单播模式A2X通信 6.1.2.1 概述 6.1.2.2 A2X PC5单播链路建立程序 6.1.2.2.1 一般说明 6.1.2.2.2 发起UE启动A2X PC5单播链路建立程序 6.1.2.2.3 目标UE接受的A2X PC5单播链路建立程序 6.1.2.2.5 目…

学单片机怎么在3-5个月内找到工作?

每个初学者&#xff0c;都如履薄冰&#xff0c;10几年前&#xff0c;我自学单片机时&#xff0c;也一样。 想通过学习&#xff0c;找一份体面点的工作&#xff0c;又害怕辛辛苦苦学出来&#xff0c;找不到工作。 好在&#xff0c;当初执行力&#xff0c;还算可以&#xff0c;自…

WebLogic

二、WebLogic 2.1 后台弱口令GetShell 漏洞描述 通过弱口令进入后台界面&#xff0c;上传部署war包&#xff0c;getshell 影响范围 全版本(前提后台存在弱口令) 漏洞复现 默认账号密码:weblogic/Oracle123weblogic常用弱口令: Default Passwords | CIRT.net这里注意&am…

设计模式--结构型

类适配器 #include <queue> #include <iostream> #include <algorithm> #include <iterator>using namespace std;// 目标接口 class Target {public:virtual ~Target() {}virtual void method() 0; };// 适配者类 class Adaptee {public:void spec_…

CHIESI凯西医药:外企入职测评综合能力及性格测试SHL题库测评真题解析

CHIESI凯西医药是一家意大利国际制药集团&#xff0c;以研发为核心&#xff0c;专注于呼吸道健康、罕见疾病和专科治疗的创新治疗方案。集团总部位于意大利帕尔马市&#xff0c;拥有超过85年的历史&#xff0c;业务遍及全球31个国家和地区&#xff0c;拥有7,000多名员工。2023年…