Vision Permutator(TPAMI 2022)论文与代码解析

news2024/9/21 17:32:24

paper:Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition

official implementation:https://github.com/houqb/VisionPermutator

出发点

现有的MLP模型在编码空间信息时通常会将空间维度展开并沿着展平的维度进行线性投影,这样会丢失由二维特征表示携带的位置信息。 为了解决这个问题,本文提出了Vision Permutator,一种新的纯MLP结构的网络,它分别沿着高度和宽度维度进行线性投影,从而保留精确的位置信息并捕获长距离依赖关系。 在不依赖于空间卷积或注意力机制的情况下,达到或超过了大多数CNN和视觉Transformer的性能。

创新点

  • 新颖的Permute-MLP层:与现有的MLP模型不同,Vision Permutator提出了一种新的层结构,即Permute-MLP层。该层包含三个独立的分支,分别负责沿高度、宽度和通道维度编码特征。
  • 位置敏感的输出:通过分别沿高度和宽度维度进行线性投影,Vision Permutator生成的位置敏感输出以互补的方式聚合,从而形成对目标对象的有效表示。
  • 高效的性能:在不使用额外大规模训练数据的情况下,Vision Permutator在ImageNet上达到了81.5%的top-1准确率,并且参数量仅为25M。模型扩展到88M参数时,准确率进一步提升到83.2%。

方法介绍

Permutator block的结构如图1左所示,可以看到和Transformer block相似,只是将其中的self-attention换成了Permute-MLP层,Channel-MLP和Transformer block中的FFN类似,都是由两个全连接层和一个GELU激活函数组成。对于空间信息的编码,和最近的Mixer(具体介绍见MLP-Mixer(NeurIPS 2021, Google)论文与源码解读-CSDN博客)不同,它沿着空间维度对所有的token进行线性投影,而本文提出分别沿着宽度和高度维度来处理token。

具体来说,给定一个C维的输入token \(\mathbf{X}\in \mathbb{R}^{H\times W\times C}\),Permutator可以表示如下

其中LN指Layer Norm,输出 \(\mathbf{Z}\) 作为下一个Permutator block的输入。

Permute-MLP

Permute-MLP的过程如图2所示,与vision transformer和Mixer接收一个二维("tokens x channels",即 \(HW\times C\))的输入不同,Permute-MLP接收一个三维的输入。

如图2所示,Permute-MLP包括三个分支,分别负责沿高度、宽度、通道维度编码信息。通道信息编码很简单,只需要一个权重为 \(\mathbf{W}_C\in \mathbb{R}^{C\times C}\) 的全连接层就可以对输入 \(\mathbf{X}\) 进行线性投影得到输出 \(\mathbf{X}_C\)。接下来我们详细介绍下如何通过维度之间的permutation操作来编码空间信息。

假设隐藏维度C为384,输入图像的分辨率为224x224。为了沿高度维度对空间信息进行编码,我们首先进行一个height-channel维度的permutation操作。给定输入 \(\mathbf{X}\in \mathbb{R}^{H\times W\times C}\),我们首先沿通道维度将其均分成 \(S\) 份,得到 \([\mathbf{X}_{H_1},\mathbf{X}_{H_2},...,\mathbf{X}_{H_S}]\),且满足 \(C=N*S\)(本文\(N=H=W\))。如果patch大小设置为14x14,则 \(N=16\) 且 \(\mathbf{X}_{H_i}\in \mathbb{R}^{H\times W\times N}, \ (i\in \{1,...,S\})\)。然后我们对每个 \(\mathbf{X}_{H_i}\) 进行height-channel的permutation操作(就是转换第一个高度维度和第三个通道维度,\((H, W, C)\rightarrow(C,W,H)\),得到输出 \([\mathbf{X}_{H_1}^{\top}, \mathbf{X}_{H_2}^{\top}, \cdots, \mathbf{X}_{H_S}^{\top}]\),然后沿通道维度拼接。接着一个权重为 \(\mathbf{W}_H\in \mathbb{R}^{C\times C}\) 的全连接层用来混合高度信息。为了恢复到原始维度,只需要再执行一次height-channel permutation操作即可,得到最终输出 \(\mathbf{X}_{H}\)。类似的,在第二个分支,我们执行width-channel的permuation操作,然后得到输出 \(\mathbf{X}_{W}\)。然后将三个分支的输出进行element-wise summation,再通过一个全连接层得到Permute-MLP层的输出,如下

其中 \(FC(\cdot)\) 表示一个权重为 \(\mathbf{W}_P\in \mathbb{R}^{C\times C}\) 的全连接层。Permute-MLP的PyTorch代码如下所示

Weighted Permute-MLP

在式(3)中我们只是简单地将三个分支的输出进行相加,作者进一步提出了Weighted Permute-MLP来重新校正三个分支的重要性。具体采用ResNeSt中的split attention(具体介绍见ResNeSt-CSDN博客)来得到加权权重,区别在于ResNeSt中的split attention是在每个cardinal group内进行的,对每个radix group求一个权重。而这里是对 \(\mathbf{X}_{H},\mathbf{X}_{W},\mathbf{X}_{C}\) 进行的,求得三个权重。

实验结果

作者设计了5个不同大小的ViP,具体配置如下

作者比较了ViP和CNN、Transformer以及MLP类模型在ImageNet数据集上性能,首先是MLP类的性能如下表所示,可以看到ViP在MLP类的backbone中取得了最优的性能。

下表是和CNN、Transformer的代表网络的性能对比,可以看到和一些经典的卷积网络例如ResNet、RegNet相比在相似的模型大小下取得了更好的结果。和一些transformer模型例如DeiT、Swin Transformer相比效果也更好。但和一些最新的SOTA模型例如NFNet(86.5%)、CaiT(86.5%)相比,还有较大的差距。

代码解析

具体实现非常简单,官方实现的weighted permute mlp的代码如下,其中self.segment_dim就是文章中的N,即沿通道划分成S份后每份的维度。然后就是将H维度与channel维度调换,执行MLP;将W维度与channel维度调换,执行MLP;直接沿输入channel维度执行MLP。最后通过split attention得到三者的权重,最后加权求和得到最终输出结果。

class WeightedPermuteMLP(nn.Module):
    def __init__(self, dim, segment_dim=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.segment_dim = segment_dim

        self.mlp_c = nn.Linear(dim, dim, bias=qkv_bias)
        self.mlp_h = nn.Linear(dim, dim, bias=qkv_bias)
        self.mlp_w = nn.Linear(dim, dim, bias=qkv_bias)

        self.reweight = Mlp(dim, dim // 4, dim * 3)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape

        S = C // self.segment_dim
        h = x.reshape(B, H, W, self.segment_dim, S).permute(0, 3, 2, 1, 4).reshape(B, self.segment_dim, W, H * S)  # (B, seg_dim, W, H, S)->(B, seg_dim, W, H*S)
        h = self.mlp_h(h).reshape(B, self.segment_dim, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C)

        w = x.reshape(B, H, W, self.segment_dim, S).permute(0, 1, 3, 2, 4).reshape(B, H, self.segment_dim, W * S)
        w = self.mlp_w(w).reshape(B, H, self.segment_dim, W, S).permute(0, 1, 3, 2, 4).reshape(B, H, W, C)

        c = self.mlp_c(x)

        a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2)  # (B,H,W,C)->(B,C,H,W)->(B,C,H*W)->(B,C)
        a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2)
        # (B,3C)->(B,C,3)->(3,B,C)->(3,B,C)->(3,B,1,C)->(3,B,1,1,C)

        x = h * a[0] + w * a[1] + c * a[2]

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1943589.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《Java初阶数据结构》----3.<线性表---LinkedList与链表>

目录 前言 一、链表的简介 1.1链表的概念 1.2链表的八种结构 重点掌握两种 1.3单链表的常见方法 三、单链表的模拟实现 四、LinkedList的模拟实现(双链表) 4.1 什么是LinkedList 4.2LinkedList的使用 五、ArrayList和LinkedList的区别 前言 …

无法连接到internet怎么办?已连接但无internet访问,其实并不难

有时我们会遇到无法连接到Internet的问题,由多种原因引起,包括硬件故障、软件设置问题、网络供应商故障等。本文将介绍无法连接到Internet时可以采取的步骤。 简述 当你无法连接到Internet时,可以按照以下步骤进行检查和解决: 1…

数据结构C++——优先队列

文章目录 一、定义二、ADT三、优先队列的描述3.1 线性表3.2 堆3.2.1 最大堆的ADT3.2.2 最大堆的插入3.2.3 最大堆的删除3.2.4 最大堆的初始化3.3 左高树 LT3.3.1 高度优先左高树HBLT3.3.2 重量优先左高树WBLT3.3.3 最大HBLT的插入3.3.4 最大HBLT的删除3.3.5 合并两棵最大HBLT3.…

自用:磁传感器数据解算

协议格式: 详细计算磁场如下: 3字节数据的格式为有符号整型数,数据为补码格式,最高位为符号位。需要先将补码格式的数据转化为10进制的实际值,方法如下: 当数据小于时为正数,实际值为本身&…

Mac中maven配置安装路径

Mac中maven配置安装路径 没有下载maven的可以先下载:(这里建议maven版本不要下高了) 如果你的bash_profile中没有配置JAVA_HOME路径,可以按照下面的命令配置一下 获取JAVA的安装路径: /usr/libexec/java_home -V …

Nest.js 实战 (三):使用 Swagger 优雅地生成 API 文档

什么是 Swagger ? Swagger 是一组围绕 OpenAPI 规范构建的开源工具,可以帮助您设计、构建、记录和使用 REST API。主要的 Swagger 工具 包括: Swagger Editor:基于浏览器的编辑器,您可以在其中编写 OpenAPI 定义Swagger UI&…

NSSCTF[堆][tcache]

1. [CISCN 2021 初赛]lonelywolf 题目地址:[CISCN 2021 初赛]lonelywolf | NSSCTF 思路: 修开tcache结构,伪造一个0x91的chunk,伪造0x91chunk的数量(填满tcache),再将其释放free进入unsortedb…

Linux中,MySQL数据库基础

21 世纪,人类迈入了“信息爆炸时代”,大量的数据、信息在不断产生,伴随而来的就是如何安全、有效地存储、检索和管理它们。对数据的有效存储、高效访问、方便共享和安全控制已经成为信息时代亟待解决的问题。 数据库简介 使用数据库的必要性…

MATLAB--文件操作相关指令

文章目录 文件操作相关指令前言 M文件创建MATLAB文件操作指令MATLAB文件流控制 文件操作相关指令 前言 记录一下M文件创建、操作、获取信息等相关资料。   MATLAB的M文件是用来代替MATLAB命令行窗口输入指令的文件。因此所有的MATLAB指令都可以再MATLAB的M文件中调用. M文件…

算法力扣刷题记录 五十七【236. 二叉树的最近公共祖先】和【235. 二叉搜索树的最近公共祖先】

前言 公共祖先解决。二叉树和二叉搜索树条件下的最近公共祖先。 二叉树篇继续。 一、【236. 二叉树的最近公共祖先】题目阅读 给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为:“对于有根树 T 的两个节点 p、q&#xff…

Spring Bean介绍

目录 1.什么是bean 2.获取bean 3.bean的作用域 4.第三方bean 5.Bean的生命周期 6.Bean的种类 7.为什么使用Bean? 1.什么是bean Bean是Java世界中的一种组件,用于封装数据和逻辑,以便在应用程序中重用和维护。它不仅可以装在数据&#x…

Redis哨兵模式实践

本次环境为Centos7.6,redis-7.0.4 1:主备模式:即主节点的数据自动同步到从节点,但当主节点挂了,从节点需要手动设置为主节点,比较麻烦。 2:哨兵模式:当主节点挂了,自动投…

PCL-基于SAC_IA和NDT结合的点云配准算法

一、原理概述1.点云配准流程图2.快速点特征直方图FPFH3.采样一致性SAC_IA粗配准4.正态分布变换NDT精配准 二、实验代码三、实验结果四、总结五、参考 一、原理概述 1.点云配准流程图 2.快速点特征直方图FPFH 快速点特征直方图(Fast Point Feature Histogram&#…

Oracle SQL:了解执行计划和性能调优

查询优化类似于制作完美食谱的艺术——它需要对成分(数据)、厨房(数据库系统)和使用的技术(查询优化器)有深入的了解。每个数据库系统都有自己的处理和运行 SQL 查询的方式,“解释”计划向我们展…

Mysql注意事项(一)

Mysql注意事项(一) 最近回顾了一下MySQL,发现了一些MySQL需要注意的事项,同时也作为学习笔记,记录下来。–2020年05月13日 1、通配符* 检索所有的列。 不建议使用 通常,除非你确定需要表中的每个列&am…

每日刷题记录(codetop版)

7.21 7.22 7.23 复习7.21和7.22

每日OJ_牛客DD1 连续最大和

目录 牛客DD1 连续最大和 解析代码 牛客DD1 连续最大和 连续最大和_牛客题霸_牛客网 解析代码 本题是一个经典的动规问题,简称dp问题,但这个问题是非常简单的dp问题,而且经常会考察,所以一定要把这个题做会。本题题意很简单&am…

探寻安全新时代:叉车AI智能影像防撞系统,守护生命之光

在繁忙的工业现场,叉车司机常常面临着视线受阻的困境,那些被货物遮挡的盲区,仿佛隐藏着无法预知的危险。然而,这样的隐患在一次惨痛的事故中暴露无遗,一名无辜的行人因叉车司机的视线受阻而不幸被撞身亡。这起悲剧让我…

机械设计基础B(学习笔记)

绪论 机构:是一些具备各自特点的和具有确定的相对运动的基本组合的统称。 组成机构的各个相对运动部分称为构件。构件作为运动单元,它可以是单一的整体,也可以是由几个最基本的事物(通常称为零件)组成的刚性结构。 构件…

python·数据分析基础知识

numpy 一个数值计算包 python列表与numpy矩阵区别 python中修改列表元素和列表相加 for循环 :[x1 for x in a] 多个元素需要用zip捆绑:[xy for(x,y) in zip(a,b)] numpy矩阵自动进行相应元素计算 np.array()1各元素1 ab各元素相加 a*b矩阵相乘或者是…