DeepViT 论文与代码解析

news2024/11/14 12:20:47

paper:DeepViT: Towards Deeper Vision Transformer

official implementation:https://github.com/zhoudaquan/dvit_repo

出发点

尽管浅层ViTs在视觉任务中表现优异,但随着网络深度增加,性能提升变得困难。研究发现,这种性能饱和的主要原因是注意力崩溃问题,即在深层变压器中,attention map逐渐变得相似,导致feature map在顶层趋于一致,从而限制了模型的表示学习能力。本文旨在研究如何有效地加深ViT模型,并提出了一种新的自注意力机制Re-attention来解决这个问题。

创新点

  • 注意力崩溃问题的提出与分析:首次提出并深入分析了注意力崩溃问题,发现这是导致深层ViT模型性能饱和的主要原因。
  • Re-attention机制:提出了一种简单但有效的Re-attention机制,通过在不同注意力头之间交换信息,以增加不同层的注意力图的多样性。该方法在计算和内存开销上几乎可以忽略不计。
  • 性能提升:通过替换现有ViT模型中的多头自注意力(MHSA)模块,成功训练了具有32个Transformer block的深层ViT模型,在ImageNet上的Top-1分类准确率提高了1.6%。

方法介绍

由于deep CNNs的成功,作者也系统研究了随深度变化ViT性能的变化,其中hidden dimension和head数量分别固定为384和12,然后堆叠不同数量的Transformer block(从12到32),结果如图1所示,可以看到,随着模型深度的增加,分类精度提升缓慢,饱和速度较快,且达到24个block后,性能不再有提升。

之前在CNN中也存在这个问题,但随着残差连接的提出,该问题得到了解决。而ViT和CNN的最大区别就在于self-attention机制,因此作者研究了自注意力或者更具体的说是生成的attention map随着网络深度的增加是如何变化的。作者计算了不同层的attention map之间的相似性来衡量注意力图的变化,如下

 

其中 \(M^{p,q}\) 是 \(p\) 层和 \(q\) 层注意力图之间的余弦相似度矩阵,每个元素 \(M^{p,q}_{h,t}\) 表示head \(h\) 和 token \(t\) 的相似度。

根据式(2),作者在ImageNet上训练了一个包含32个block的ViT,并研究了attention map之间的相似度,结果如图3(a)所示,可以看到,在第17个block之后,注意力图之间的相似度超过了比例超过了90%。这表示后面学习到的attention map是相似度,Transformer block可能退化为一个MLP。

为了理解attention collapse是如何影响ViT的性能的,作者进一步研究了它是如何影响更深层网络的特征学习的。因此作者也绘制出了随网络深度变化feature map之间的相似度变化曲线,如图4(left)所示,可以看到feature map的变化曲线和attention map的变化曲线比较相似,这一结果表明,注意力崩溃是导致ViT模型non-scalable的原因。

 

Re-Attention

在实验过程中,作者发现来自同一block不同head之间的attention map的相似度很小,如图3(c)所示。这表明来自同一自注意力层的不同head关注输入token的不同方面。基于此观察,作者提出建立cross-head通信来重新生成attention map。

具体来说,通过动态地聚合来自不同head的注意力图来生成一组新的注意力图。作者定义了一个可学习的变换矩阵 \(\Theta \in\mathbb{R}^{H\times H}\) 并用它来混合不同head的注意力图,具体如下

其中 \(\Theta\) 和注意力图 \(\mathbf{A}\) 沿head维度相乘,Norm是归一化函数用来减少层之间的方差,\(\Theta\) 是端到端可学习的。

实验结果

如图1所示,在将ViT中的self-attention换成Re-Attention后得到的DeepViT,随着网络深度的增加并没有像ViT那样过早的出现性能饱和,而是继续提升。

如图8(a)所示,Re-Attention的相邻block注意力图的相似度显著降低。

作者定义了DeepViT-S和DeepViT-L,具体配置如下,其中split ratio表示不用Re-Attention和使用Re-Attention的block数的比例,如图3(a)所示,只有在网络的深层注意力图和特征图之间的相似度才会变高,因此没必要在所有层的block中都使用Re-Attention。 

和其它SOTA模型在ImageNet上的性能对比如下所示

 

代码解析

Re-Attention的实现如下,其中 \(\Theta\) 是通过卷积定义的,归一化采用的BN。

class ReAttention(nn.Module):
    """
    It is observed that similarity along same batch of data is extremely large. 
    Thus can reduce the bs dimension when calculating the attention map.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., expansion_ratio=3,
                 apply_transform=True, transform_scale=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.apply_transform = apply_transform
        
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        if apply_transform:
            self.reatten_matrix = nn.Conv2d(self.num_heads, self.num_heads, 1, 1)
            self.var_norm = nn.BatchNorm2d(self.num_heads)
            self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)
            self.reatten_scale = self.scale if transform_scale else 1.0
        else:
            self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)
        
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, atten=None):
        B, N, C = x.shape
        # x = self.fc(x)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if self.apply_transform:
            attn = self.var_norm(self.reatten_matrix(attn)) * self.reatten_scale
        attn_next = attn
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn_next

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1970854.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

共享打印机0x0000011b错误解决方法

日打印机故障一直是一个热门话题,特别是共享打印机0x0000011b错误特别头疼,有很多网友经常遇到共享打印机0x0000011b错误。0x0000011b有更新补丁导致的、有访问共享打印机服务异常、有访问共享打印机驱动异常等问题导致的,针对共享打印机0x00…

问题易如反掌?5个常用的AI人工智能助手推荐

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 如今的人工智能技术正以惊人的速度改变着我们的生活方式和工作方式。作为这一变革的关键驱动力,人工智能不仅在科技…

一个方法解决看世界时区 做外贸和跨境电商的必备小工具

一个方法解决看世界时区 做外贸和跨境电商的必备小工具。做过外贸或跨境电商的伙伴们都知道,看世界各地时区是一个比较繁琐的事情。 很多公司都有自己专注的几个地区业务,经常要看业务地区的时间,这样方便和客户沟通。做生意的人都知道&…

uniapp - APP分享到微信,通过h5页面跳转至对应的app页面

目录 项目场景: 效果展示: 解决方案: 第一步: 第二步 : 1、微信跳转APP:wx-open-launch-app 第三步: 总结: 项目场景: uniapp框架开发的app(Android和ios)&…

Zilliz 2025届校园招聘正式启动,寻找向量数据库内核开发工程师

为了解决非结构化数据处理问题,我们构建了向量数据库-Milvus! Milvus 数据库不仅是顶级开源基金会 LF AI&Data 的毕业项目,还曾登上数据库顶会SIGMOD、VLDB,在全球首届向量检索比赛中夺冠。目前,Milvus 项目已获得超过 2.8w s…

算法工程师必知必会的数学基础之线性代数

1. 线性代数 线性代数是机器学习和深度学习中一个非常重要的数学基础。下面我将详细介绍线性代数中的一些基本概念,并使用 Python 的 NumPy 库来演示这些概念的应用。 1.1 向量(Vectors)与 矩阵(Matrices) 向量&…

Qt 登录界面

本文代码效果如下: 本文代码: https://download.csdn.net/download/Sakuya__/89607657https://download.csdn.net/download/Sakuya__/89607657 代码之路 LoginTitleBar.h 自定义的透明标题栏 #ifndef LOGINTITLEBAR_H #define LOGINTITLEBAR_H#in…

【书生大模型实战营】基础岛-8G 显存玩转书生大模型 Demo

8G 显存玩转书生大模型 Demo 【书生大模型实战营】基础岛-8G 显存玩转书生大模型 DemoInternLM2-Chat-1.8B 模型的部署代码运行StreamLit部署 InternLM-XComposer2-VL-1.8B 模型的部署InternVL2-2B 模型的部署 【书生大模型实战营】基础岛-8G 显存玩转书生大模型 Demo InternL…

“八股文“在现代编程面试中的角色重塑:助力、阻力还是桥梁?

🌈所属专栏:【其它】✨作者主页: Mr.Zwq✔️个人简介:一个正在努力学技术的Python领域创作者,擅长爬虫,逆向,全栈方向,专注基础和实战分享,欢迎咨询! 您的点…

【全网最全】文心智能体平台介绍和应用

什么是智能体平台? 文心智能体平台(Wenxin Intelligent Agent Platform)是由百度开发的一个全面集成多种人工智能技术的开放平台,旨在为企业和开发者提供强大的智能化服务和解决方案。支持广大开发者根据自身行业领域、应用场景&…

LoRa无线通讯,让光伏机器人实现无“线”管理

光伏清洁机器人,作为光伏电站运维的新兴关键设备,已跃升为继组件、支架、光伏逆变器之后的第四大核心组件,正逐步成为光伏电站的标准配置。鉴于光伏电站普遍坐落于偏远无人区或地形复杂之地,光伏清洁机器人必须具备远程操控能力、…

Charles怎么修改参数

Charles怎么修改参数 1、再【Structure】下,找到需要抓取的包,鼠标右键,点中断点。 2、在【Proxy】-点击【Breakpoint Settings…】 3、双击设置断点的接口 4、勾选后,点击【OK】。 5、再次刷新,重新发请求&#…

海思35XX系列(三)sensor(传感器)

刚开始接触这个概念的时候感觉比较模糊,简单记录一下吧 Sensor(传感器)是一种可以感知外部环境并将感知到的信息转化为可用的电信号或其他形式的工具。传感器广泛应用于电子设备、工业自动化、汽车、医疗器械等领域,用于测量、监…

【JVM】常见面试题

🥰🥰🥰来都来了,不妨点个关注叭! 👉博客主页:欢迎各位大佬!👈 文章目录 1. JVM 中的内存区域划分2. JVM 的类加载机制2.1 加载(Loading)✨双亲委派模型2.2 验证(Verification)2.3 准…

AI1-PaddleOCR2.8在VS2019编译运行基于C++引擎推理CPU版本

1、下载PaddleOCR-release-2.8开源项目 https://github.com/PaddlePaddle/PaddleOCR https://github.com/PaddlePaddle/PaddleOCR/releases https://gitee.com/paddlepaddle/PaddleOCR?_fromgitee_search 2、下载安装Windows预测库 https://paddleinference.paddlepaddle.o…

STL—vector—模拟实现【深度理解vector】【模拟实现vector基本接口】

STL—vector—模拟实现 经过了前面对于vector的初步了解,我们已经具备了使用vector的能力了,现在我们就来深度学习一下vector,并做到能模拟实现vector的基础功能。 1.vector深度解析 要想深度了解vector,我们就要去看它的源代码…

“常温”前端网站框架(四)-- 音乐播放器【附源码】

开篇(请大家看完):此网站写给挚爱,后续页面还会慢慢更新,大家敬请期待~ ~ ~ 此前端框架,主要侧重于前端页面的视觉效果和交互体验。通过运用各种前端技术和创意,精心打造了一系列引人入胜的页面…

高龙海洋增收不增利:毛利率有所下滑,产能利用率下降仍扩产?

《港湾商业观察》廖紫雯 日前,高龙海洋集团有限公司(以下简称:高龙海洋)递表港交所,保荐机构为越秀融资。高龙海洋国内运营主体为福建高龙海洋生物工程有限公司。 自2008年公司成立以来,高龙海洋一直从事…

vue3中 provide/inject用法详解

依赖注入:provide 和 inject 什么情况下推荐provide/inject使用:Prop 多层级数据透传 通常情况下,当我们需要从父组件向子组件传递数据时,会使用 props。想象一下这样的结构:有一些多层级嵌套的组件,形成了…

云HIS综合管理系统源码,云端SaaS服务,与监管系统有序对接,扩展性强

云HIS系统: 本套云HIS系统是一款适用于二级及以下医院、专科医院和社区卫生机构的综合性医院信息系统,它包含门诊预约挂号、收费结算、排班、医护协同、药房、药库、电子病历等10大功能模块,支持门诊、住院、医技、后勤各项核心业务。 采用…