CvT(ICCV 2021)论文与代码解读

news2024/11/13 9:36:59

paper:CvT: Introducing Convolutions to Vision Transformers

official implementation:https://github.com/microsoft/CvT

出发点

该论文的出发点是改进Vision Transformer (ViT) 的性能和效率。传统的ViT在处理图像分类任务时虽然表现出色,但在数据量较小的情况下,其表现不如同等规模的卷积神经网络(CNN)。研究人员认为这是因为ViT缺乏CNN固有的一些有利特性,如对局部空间信息的捕捉能力。本文提出通过在ViT结构中引入卷积操作来弥补这一不足,以获得更好的性能和鲁棒性。

创新点

本文解决了如何在保持ViT优点(如动态注意力机制、全局上下文建模和更好的泛化能力)的同时,引入卷积神经网络的优点(如局部感受野、权重共享和空间下采样)。具体来说,论文通过引入卷积的方式来增强ViT的局部信息捕捉能力和计算效率,从而在各种图像分类任务中取得更好的表现。具体如下

  1. 卷积token embedding层:在ViT的结构中引入卷积embedding层,通过卷积操作将图像转换为token,同时保留局部空间信息。这种方法使模型能够在多个阶段逐步减少令token序列长度,同时增加token特征维度,类似于CNN的设计。
  2. 卷积projection:标准Transformer模块中的线性投影替换为卷积投影。通过深度可分离卷积操作,进一步捕捉局部空间上下文,并减少注意力机制中的语义模糊性。此外,卷积投影的步幅可用于对key和value矩阵进行下采样,从而显著提高计算效率。
  3. 无需位置编码:实验表明,CvT模型可以在不使用位置编码的情况下取得良好的性能,这简化了模型设计,尤其适用于处理高分辨率图像任务。

方法介绍

CvT的整体pipeline如图2所示。作者将两种基于卷积的operation引入Vision Transformer中,即Convolutional Token Embedding和Convolutional Projection。如图2(a)所示,借鉴了CNN采用了一个多个stage的层级设计,本文一共包含三个stage。每个stage包括两部分,首先输入图片(或reshape后的二维token map)经过Convolutional Token Embedding层的处理,具体是通过一个重叠的卷积实现。这使得每个stage可以逐渐减少token的数量(即特征分辨率)并增加token的宽度(即特征的维度),从而实现空间降采样并增加特征表示的丰富性。和之前的各种视觉Transformer不同,本文在这里并没有加上一个位置编码。接下来是堆叠的多个本文提出的Convolutional Transformer Block如图2(b)所示, 其中一个深度可分离卷积作为卷积投影分别作用于query、key和value。class token只在最后一个stage添加,最后通过一个MLP head得到最终的输出预测类别。 

Convolutional Token Embedding

给定一张图片或前一个stage输出并reshape成二维的token map \(x_{i-1}\in \mathbb{R}^{H_{i-1}\times W_{i-1}\times C_{i-1}}\) 作为当前stage \(i\) 的输入,我们学习一个卷积 \(f(\cdot)\) 将 \(x_{i-1}\) 映射到新的token \(f(x_{i-1})\),卷积核大小为 \(s\times s\),步长为 \(s-o\),padding为 \(p\)。新的token map \(f(x_{i-1})\in \mathbb{R}^{H_i\times W_i\times C_i}\) 的高和宽分别为

\(f(x_{i-1})\) 然后展平成 \(H_iW_i\times C_i\) 的shape并经过layer normalization处理,然后作为输入到stage \(i\) 的后续transformer block中。

Convolution Token Embedding层使得我们可以通过调整卷积的参数来调整每个stage的token特征维度和数量。通过这种方式,每个stage我们逐渐减少token序列的长度同时增加token特征的维度,使得token能够在越来越大的空间中表示越来越复杂的视觉模式,类似于CNN的特征层。

Convolutional Projection for Attention

本文提出的卷积映射层的目的是实现对局部context的额外建模,并通过对 \(K\) 和 \(V\) 矩阵降采样来提高效率。

图3(a)展示了ViT中使用的position-wise线性投影,图3(b)展示了本文提出的 \(s\times s\) 卷积投影。如图3(b)所示,tokens首先reshape成一个2D token map,然后通过一个深度可分离卷积实现卷积投影。最后再将projected tokens展平成1D作为后续的输入,如下

其中 \(x_i^{q/k/v}\) 是 \(i\) 层 \(Q/K/V\) 矩阵的token输入,\(conv2d\) 是一个深度可分离卷积具体实现为:\(Depthwise\ Con2d\rightarrow BatchNorm2d\rightarrow Pointwise\ Conv2d\),\(s\) 表示卷积核大小。原始的position-wise线性投影可以通过1x1卷积实现,因此这里新的卷积投影可以看作是一种推广。

实验结果

作者设计三种不同size的模型如表2所示,其中CvT-X中的X表示模型总共的transformer block的数量。CvT-224中的W表示Wide。

表3是在ImageNet数据集上和其它SOTA模型的对比。

代码解析

这里的代码是官方实现,convolutional token embedding的代码如下,在每个stage的开始都会首先经过ConvEmbed,以cvt-13为例,一共3个stage,patch_size=[7, 3, 3],patch_stride=[4, 2, 2],patch_padding=[2, 1, 1]。

class ConvEmbed(nn.Module):
    """ Image to Conv Embedding

    """

    def __init__(self,
                 patch_size=7,
                 in_chans=3,
                 embed_dim=64,
                 stride=4,
                 padding=2,
                 norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding
        )
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):
        x = self.proj(x)

        B, C, H, W = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        if self.norm:
            x = self.norm(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)

        return x

Attention的代码如下,在forward函数中会首先调用forward_conv得到q、k、v,这里的forward_conv就是本文提出的conv projection,在函数_build_projection中method='dw_bn',因此三个投影都是通过深度可分离卷积实现的。在self.forward_conv后就是普通的计算attention的过程了。

class Attention(nn.Module):
    def __init__(self,
                 dim_in,
                 dim_out,
                 num_heads,
                 qkv_bias=False,
                 attn_drop=0.,
                 proj_drop=0.,
                 method='dw_bn',
                 kernel_size=3,
                 stride_kv=1,
                 stride_q=1,
                 padding_kv=1,
                 padding_q=1,
                 with_cls_token=True,
                 **kwargs
                 ):
        super().__init__()
        self.stride_kv = stride_kv
        self.stride_q = stride_q
        self.dim = dim_out
        self.num_heads = num_heads
        # head_dim = self.qkv_dim // num_heads
        self.scale = dim_out ** -0.5
        self.with_cls_token = with_cls_token

        self.conv_proj_q = self._build_projection(
            dim_in, dim_out, kernel_size, padding_q,
            stride_q, 'linear' if method == 'avg' else method
        )
        self.conv_proj_k = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv,
            stride_kv, method
        )
        self.conv_proj_v = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv,
            stride_kv, method
        )

        self.proj_q = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_k = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_v = nn.Linear(dim_in, dim_out, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim_out, dim_out)
        self.proj_drop = nn.Dropout(proj_drop)

    def _build_projection(self,
                          dim_in,
                          dim_out,
                          kernel_size,
                          padding,
                          stride,
                          method):
        if method == 'dw_bn':
            proj = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    bias=False,
                    groups=dim_in
                )),
                ('bn', nn.BatchNorm2d(dim_in)),
                ('rearrage', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'avg':
            proj = nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool2d(
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    ceil_mode=True
                )),
                ('rearrage', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'linear':
            proj = None
        else:
            raise ValueError('Unknown method ({})'.format(method))

        return proj

    def forward_conv(self, x, h, w):
        if self.with_cls_token:
            cls_token, x = torch.split(x, [1, h*w], 1)

        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

        if self.conv_proj_q is not None:
            q = self.conv_proj_q(x)
        else:
            q = rearrange(x, 'b c h w -> b (h w) c')

        if self.conv_proj_k is not None:
            k = self.conv_proj_k(x)
        else:
            k = rearrange(x, 'b c h w -> b (h w) c')

        if self.conv_proj_v is not None:
            v = self.conv_proj_v(x)
        else:
            v = rearrange(x, 'b c h w -> b (h w) c')

        if self.with_cls_token:
            q = torch.cat((cls_token, q), dim=1)
            k = torch.cat((cls_token, k), dim=1)
            v = torch.cat((cls_token, v), dim=1)

        return q, k, v

    def forward(self, x, h, w):
        if (
            self.conv_proj_q is not None
            or self.conv_proj_k is not None
            or self.conv_proj_v is not None
        ):
            q, k, v = self.forward_conv(x, h, w)

        q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads)
        k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads)
        v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads)

        attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scale
        attn = F.softmax(attn_score, dim=-1)
        attn = self.attn_drop(attn)

        x = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
        x = rearrange(x, 'b h t d -> b t (h d)')

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1808837.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

xilinx的Aurora8B10B的IP仿真及上板测试(高速收发器十七)

前文讲解了Aurora8B10B协议原理及xilinx相关IP,本文讲解如何设置该IP,并且通过示例工程完成该IP的仿真和上板。 1、生成Aurora8B10B IP 如下图所示,首先在vivado的IP catalog中输入Aurora 8B10B,双击该IP。 图1 查找Aurora 8B10…

【Python】成功解决SyntaxError: invalid syntax

【Python】成功解决SyntaxError: invalid syntax 下滑即可查看博客内容 🌈 欢迎莅临我的个人主页 👈这里是我静心耕耘深度学习领域、真诚分享知识与智慧的小天地!🎇 🎓 博主简介:985高校的普通本硕&am…

cuda学习笔记(3)

一 CPU和GPU的区别 衡量处理器优劣的重要的两个指标: 延时性:同量的数据,所需要的处理时间 吞吐性:处理速度不快,但是每次处理量很大 GPU设计理念是最大化吞吐量,使用很小的控制单元对应很小的内存 cpu的设…

类和对象(下+)_const成员、初始化列表、友元、匿名对象

类和对象(下) 文章目录 类和对象(下)前言一、const成员二、友元1.友元函数2.友元类 三、初始化列表四、explicit关键字五、匿名对象总结 前言 static成员、内部类、const成员、初始化列表、友元、匿名对象 一、const成员 将cons…

cleanmymac清理时要一直输入密码 CleanMyMac X一直提示输入密码的解决方案

CleanMyMac X是一款专业的Mac清理软件,可智能清理mac磁盘垃圾和多余语言安装包,快速释放电脑内存,轻松管理和升级Mac上的应用。同时CleanMyMac X可以强力卸载恶意软件,修复系统漏洞,一键扫描和优化Mac系统。 在使用Cle…

LeetCode | 2022.将一维数组转变为二维数组

这道题思路比较简单,比较容易想到的是先判断m和n构成的二维数组在形式上是否可以由原来的数组转变而成,若不可以返回空数组,若可以直接用一个二重循环遍历一遍即可,时间复杂度 O ( n 2 ) O(n^2) O(n2) class Solution(object):de…

数据结构初阶 · 链式二叉树的部分问题

目录 前言: 1 链式二叉树的创建 2 前序 中序 后序遍历 3 树的节点个数 4 树的高度 5 树的叶子节点个数 6 树的第K层节点个数 前言: 链式二叉树我们在C语言阶段已经实现了,这里介绍的是涉及到的部分问题,比如求树的高度&am…

三、安全工程练习题(CISSP)

1.三、安全工程练习题(CISSP)

找素数第二、三种方法

文章目录 第一种 :使用标签第二种:本质是方法的分装 第一种 :使用标签 没有使用信号量。break和continue作用范围只是最近的循环,无法控制外部循环。 此时使用标签 对外部循环进行操作。 package com.zhang; /* 找素数 第二种方…

【已解决】FileNotFoundError: [Errno 3] No such file or directory: ‘xxx‘

😎 作者介绍:我是程序员行者孙,一个热爱分享技术的制能工人。计算机本硕,人工制能研究生。公众号:AI Sun,视频号:AI-行者Sun 🎈 本文专栏:本文收录于《AI实战中的各种bug…

【C语言】03.分支结构

本文用以介绍分支结构,主要的实现方式为if语句和switch语句。 一、if语句 1.1 if语句 if (表达式)语句表达式为真则执行语句,为假就不执行。在C语言中,0表示假,非0表示真.下图表示if的执行过程: 1.2 else语句 当…

数字孪生概念、数字孪生技术架构、数字孪生应用场景,深度长文学习

一、数字孪生起源与发展 1.1 数字孪生产生背景 数字孪生的概念最初由Grieves教授于2003年在美国密歇根大学的产品全生命周期管理课程上提出,并被定义为三维模型,包括实体产品、虚拟产品以及二者间的连接,如下图所示: 2011年&…

32位和64位的Windows7均不支持UEFI启动方式?试试看!

前言 今天小白突然想起:自己已经接近8年没有安装过32位的Windows系统了,这8年装的上百台电脑都是用的64位Windows。 今天 闲来无事 嗯……应该算是有小伙伴提出了个问题: 这位小伙伴表示:自己无论安装32位还是64位的Windows7都…

OSPF LSA头部详解

LSA概述 LSA是OSPF的本质 , 对于网工来说能否完成OSPF的排错就是基于OSPF的LSDB掌握程度 . 其中1/2类LAS是负责区域内部的 类似于设备的直连路由 . 加上对端的设备信息 3 类LSA是区域间的 指的是Area0和其他Area的区域间关系 , 设计多区域的初衷就是避免大型OSPF环境LSA太多…

14-特殊函数——静态函数、递归函数、函数指针、回调函数、内联函数、变参函数

14-特殊函数——静态函数、递归函数、函数指针、回调函数、内联函数、变参函数 文章目录 14-特殊函数——静态函数、递归函数、函数指针、回调函数、内联函数、变参函数一、静态函数1.1 语法 二、递归函数2.1 示例:输出n个自然数2.2 内存变化 三、函数指针四、指针函…

C++必修:探索C++的内存管理

✨✨ 欢迎大家来到贝蒂大讲堂✨✨ 🎈🎈养成好习惯,先赞后看哦~🎈🎈 所属专栏:C学习 贝蒂的主页:Betty’s blog 1. C/C的内存分布 我们首先来看一段代码及其相关问题 int globalVar 1; static…

软件测试--Mysql快速入门

文章目录 软件测试-mysql快速入门sql主要划分mysql常用的数据类型sql基本操作常用字段的约束:连接查询mysql内置函数存储过程视图事务索引 软件测试-mysql快速入门 sql主要划分 sql语言主要分为: DQL:数据查询语言,用于对数据进…

作业-day-240607

思维导图 C编程 要求: 搭建一个货币的场景,创建一个名为 RMB 的类,该类具有整型私有成员变量 yuan(元)、jiao(角)和 fen(分),并且具有以下功能:…

---java 抽象类 和 接口---

抽象类 再面向对对象的语言中,所以的对象都是通过类来描述的,但如果这个类无法准确的描述对象的 话,那么就可以把这个类设置为抽象类。 实例 这里用到abstract修饰,表示这个类或方法是抽象方法 因为会重写motifs里的show方法…

某药监局后缀(第一部分)

声明 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关! 本文章未经许可禁止转载&#xff…