[ICLR 2023] Token Merging: Your ViT But Faster

news2024/12/26 12:04:54

Contents

  • Introduction
  • Token Merging
  • Experiments
    • Image Experiments
      • Design choices
      • Model Sweep
      • Comparison to Other Works
      • Visualizations
    • Video Experiments
    • Audio Experiments
  • References

Introduction

  • 作者提出了一种 token 合并方法 Token Merging (ToMe),能够在不进行额外训练的情况下提高 ViT 推理速度,做到了真正的即插即用。通过使用轻量匹配算法逐步合并相似 tokens,ToMe 在运行速度上和 token pruning 方法一样快,并且还能使得模型精度更高 (prune 会损失掉被裁减 token 的信息,但 merge 只会在合并不相似 token 时损失信息)。当然,ToMe 也可以在训练时使用,从而加速训练过程,并进一步提高 ToMe 的推理精度
  • 在不进行额外训练的情况下 (i.e., Off-the-shelf),ToME 能够将 ViT-L @ 512 和 ViT-H @ 518 on images 的吞吐量加速到 2 × 2\times 2×,将 ViT-L on video 的吞吐量加速到 2.2 × 2.2\times 2.2× 并且只产生 0.2%~0.3% 的精度损失。当在训练时使用 ToMe 时,Tome 能将 ViT-B on audio 加速到 2 × 2\times 2×,并且只有 0.4% mAP 的精度损失

Token Merging

  • Strategy. 在每个 block 的 attention 层之后、MLP 层之前设置 ToMe 层用于 token 合并 (之所以设置在 block 中间而非每个 block 的开始处是因为可以利用 attention 中的信息帮助计算 token 间的相似度),每个 block 合并 r r r 次,即减少 r r r 个 tokens,假设总共有 L L L 个 blocks,则一共减少 r L rL rL 个 tokens. r r r 为超参,用于控制速度和精度的平衡。如下图 a 所示,狗的毛发对应的 token 最终被合并为了相同的 token
    在这里插入图片描述
  • Token Similarity. 衡量 token 间相似度最直接的方法就是计算 token embed 间的距离,但这并不是最优的,因为 ViT 的中间特征是 overparameterized 的,例如 ViT-B/16 的特征维度完全可以编码每个 patch 的 rgb pixel value ( 16 × 16 × 3 = 768 16\times16\times3=768 16×16×3=768),也就是说,ViT 的中间特征往往存在很多噪声,直接将其用于相似度计算并不能有效反映 token 间的相似度。为了解决上述问题,作者直接使用 self-attention 中 token 的 keys (K) 计算余弦相似度来估计 token 间的相似度 (the keys (K) already summarize the information contained in each token for use in dot product similarity)
  • Bipartite Soft Matching (二部图匹配). 作者提出了一种快速二部图匹配算法用于 token 合并:把所有 token 按照顺序交替分成两个集合,集合两两之间计算相似度,只保留集合 A A A 到集合 B B B 最相似的边,最终保留 r r r 条最相似的边,通过加权平均来合并相连的 r r r 对 tokens,权重为 token size s s s,代表该 token 是多少个原始 patch token 合并后的结果 (number of patches the token represents).
    在这里插入图片描述
def bipartite_soft_matching (k: torch . Tensor , r: int ) -> torch . Tensor :
	""" Input is k from attention , size [batch , tokens , channels ]. """
	k = k / k. norm ( dim =-1, keepdim = True )
	a, b = k[... , ::2, :], k[... , 1::2, :]
	scores = a @ b. transpose (-1, -2)
	
	scores [... , 0, :] = - math . inf # don ’t merge cls token
	
	node_max , node_idx = scores . max ( dim =-1)
	edge_idx = node_max . argsort ( dim =-1, descending = True )[... , None ]
	
	unm_idx = edge_idx [... , r:, :] # Unmerged Tokens
	src_idx = edge_idx [... , :r, :] # Merged Tokens
	dst_idx = node_idx [... , None ]. gather ( dim =-2, index = src_idx )
	
	unm_idx = unm_idx . sort (dim =-2)[0] # Sort cls token back to idx 0
	
	def merge (x: torch . Tensor ) -> torch . Tensor :
		""" Input is of shape [batch , tokens , channels ]. """
		src , dst = x[... , ::2, :], x[... , 1::2, :]
		n, t1 , c = src . shape
		unm = src. gather ( dim=-2, index = unm_idx . expand (n, t1 - r, c))
		src = src. gather ( dim=-2, index = src_idx . expand (n, r, c))
		dst = dst. scatter_add (-2, dst_idx . expand (n, r, c), src )
		return torch . cat([unm , dst ], dim=-2)
	
	return merge
  • Tracking Token Size. 作者认为 token size 越大,该 token 在 softmax attention 里的重要性也应该越大,为此,作者提出了 proportional attention.
    在这里插入图片描述softmax 里的 log ⁡ s \log s logs 相当于给 attention score 乘上一个系数 s s s,相当于是有 s s s 份相同的 keys

Experiments

Image Experiments

Design choices

  • Token Similarity. 下表对比了计算 token 间相似度的不同方法, X pre \text{X}_{\text{pre}} Xpre 为输入 block 的 token feature, X \text{X} X 为 attention 后的 token feature, K,Q,V \text{K,Q,V} K,Q,V 分别为计算相似度使用的自注意力层特征
    在这里插入图片描述下表对比了计算相似度的不同距离函数
    在这里插入图片描述为了模型更加高效,作者选择 average K \text{K} K over the attention heads instead of concatenating them
    在这里插入图片描述
  • Algorithmic Choices. 下表对比了合并 token 的不同方法
    在这里插入图片描述下表对比了将所有 token 划分为两个集合的不同方法
    在这里插入图片描述
  • Proportional Attention. 作者发现 proportional attention 对 supervised models (e.g., AugReg, SWAG, DeiT) 比较有用,但对 MAE 没用,这可能是因为 MAE 在训练时就会丢弃 tokens. 因此作者对除了 off-the-shelf MAE models 之外的模型使用了 proportional attention
    在这里插入图片描述
  • Comparing Matching Algorithms.
    在这里插入图片描述
  • Selecting a Merging Schedule. 作者将每层固定合并 r r r 次的策略和随机采样的 1500 种合并策略进行了比较,由下图可以发现固定合并 r r r 次的策略是接近最优的
    在这里插入图片描述此外,作者还发现 linearly decreasing schedule 和最好的随机采样的合并策略相比效果较好,并且能将模型吞吐量提高到 ∼ 3 × \sim3\times 3×,因此作者也定义了 “decreasing” schedule. 与 constant schedule 相比,它们最终合并的 token 数一样多,但 decreasing schedule 在模型早期合并的 token 数更多,因此吞吐量更大
    在这里插入图片描述

Model Sweep

在这里插入图片描述

  • Re-evaluating. 在 Fig. 3c 中,作者测试了使用 ToMe 进行微调后的模型精度。值得一提的是,我们并不需要给每组 r r r 都重新训练一次模型,而是只需要训练一组 r r r 然后在其他 r r r 值上重新测试模型即可,这样做相比 off-the-shelf 也能提升模型精度 (For instance, the baseline ViT-L model we train in Fig. 3c gets 85.7% accuracy. If we re-evaluate our r = 5 r = 5 r=5 trained model with r = 0 r = 0 r=0, we obtain 85.8% accuracy.)

Comparison to Other Works

在这里插入图片描述
在这里插入图片描述

Visualizations

在这里插入图片描述

Video Experiments

  • Results.
    在这里插入图片描述
  • Throughput.
    在这里插入图片描述
  • Clip Count.
    在这里插入图片描述
  • Visualization.
    在这里插入图片描述

Audio Experiments

  • Results.
    在这里插入图片描述

References

  • Bolya, Daniel, et al. “Token Merging: Your ViT But Faster.” (ICLR 2023).
  • code: https://github.com/facebookresearch/tome

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/462974.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Tasking_IDE】-1-如何让目录下的C文件不参与编译

案例背景: 当您在使用Tasking TriCore Eclipse IDE集成开发环境编译时,是不是有时遇到这样一个问题:导入了一个算法/驱动文件夹,但文件夹里面不是所有的C文件都要参与编译,于是您可能想到把这些“不参与编译的文件”删…

Kafka3.0.0版本——生产者 数据去重

目录 一、数据传递语义1.1、至少一次1.2、最多一次1.3、精确一次 二、幂等性2.1、幂等性原理2.2、重复数据的判断标准2.3、如何使用幂等性 三、生产者 事务3.1、Kafka事务原理3.2、Kafka事务注意事项3.3、Kafka事务的5个API3.3.1、初始化事务API3.3.2、开启事务API3.3.3、在事务…

CMake Tutorial Step1

CMake Tutorial Step1 参考资料:Step 1: A Basic Starting Point — CMake 3.26.3 Documentation Tutorial工程:官方Tutorial工程 开发环境:CLion CMake简介 方便起见直接问New Bing。 为什么要学习CMake? CMake的最大特点和…

微服务---分布式搜索引擎 elasticsearch基础

分布式搜索引擎 elasticsearch基础 0.学习目标 1.初识elasticsearch 1.1.了解ES 1.1.1.elasticsearch的作用 elasticsearch是一款非常强大的开源搜索引擎,具备非常多强大功能,可以帮助我们从海量数据中快速找到需要的内容 例如: 在GitH…

centos7操作yum命令失败

前言设置网卡开机自动启动设置国内dns服务器系统修改CentOS-Base.repo中的地址 前言 刚安装完的CentOS7的系统,发现无法使用yum命令进行更新,在更新的时候会出现下面这种内容,为此问题有以下这些解决方案可以尝试。 One of the configured r…

两段视频合成一个视频用什么软件 怎么把两段视频合成一段看不出来

两段视频合成一个视频用什么软件?无论是两段视频的合成,还是三段视频的合成,用视频编辑软件都能轻松搞定。但怎么把两段视频合成一段看不出来?这就比较考验制作者的功力了,不过我们还是有捷径的,下面一起来…

new和delete

目录 malloc: 开辟失败:返回值为空指针 new: 内置类型: 申请一个int对象(开辟一块存储int类型数据的空间,只能存储一个int数据): 申请5个int对象(开辟一块存储int类型数据的空间&#xff…

Blender3.5 边的操作

目录 1. 边操作1.1 边的细分 Subdivide1.2 边的滑移 Edge Slide1.3 边的删除1.4 边的溶解 Dissolve1.5 边线倒角 Bevel1.6 循环边 Loop Edges1.7 并排边 Ring Edges1.8 桥接循环边 1. 边操作 1.1 边的细分 Subdivide 在边选择模式,选中一条边,右键&…

JVM系列(十一) 垃圾收集器之 Concurrent Mark Sweep 并发标记清除

垃圾收集器之 Concurrent Mark Sweep 并发标记清除 上几篇文章我们讲解了单线程垃圾收集器 Serial/SerialOld ,多线程垃圾收集器 Parallel Scavenge/Old, 本文我们讲解下 Concurrent Mark Sweep 简称CMS垃圾收集器 垃圾收集器 新生代收集器: Serial、ParNew、Par…

图解 | 原来这就是网络

​​ 你是一台电脑,你的名字叫 A 很久很久之前,你不与任何其他电脑相连接,孤苦伶仃。 ​ 直到有一天,你希望与另一台电脑 B 建立通信,于是你们各开了一个网口,用一根网线连接了起来。 ​ 用一根网线连接起来…

[晕事]今天做了件晕事7

今天在使用iptables与grep的时候碰到一件晕事; 第一步添加了一条rule到OUTPUT: iptables -A OUTPUT --source 10.87.51.2 --destination 10.87.51.10 -p tcp --sport 5060 -j DROP 第二步使用:iptables -nL | grep DROP 发现这条记录跑到了FO…

玩转ESP32 PWM输出,制作炫酷呼吸灯效果

文章目录 什么是PWM软硬件使用ESP32实现PWM输出代码讲解结语 什么是PWM PWM(Pulse Width Modulation)是一种常用的模拟信号产生技术,它通过对一个定时器的计数值进行调整来改变输出信号的占空比,从而控制输出信号的平均电压值&am…

idea使用 ( 二 ) 创建java项目并导入依赖jar

3.创建java项目 3.1.创建普通java项目 3.1.1.打开创建向导 接 2.3.1.创建新的项目 也可以 从菜单选择建立项目 会打开下面的选择界面 3.1.2.不使用模板 3.1.3.设置项目名 Project name : 项目名 Project location : 项目存放的位置 确认创建 3.1.4.关闭tips 将 Dont s…

Spring Boot集成ShardingSphere实现数据分片(一) | Spring Cloud 40

一、背景 传统的将数据集中存储至单一节点的解决方案,在性能、可用性和运维成本这三方面已经难于满足海量数据的场景。 从性能方面来说,由于关系型数据库大多采用 B 树类型的索引,在数据量超过阈值的情况下,索引深度的增加也将使…

Mail 邮件服务

~ Postfix ~ sdskill.com 的邮件发送服务器 ~~ 支持smtps(465)协议连接,使用Rserver颁发的证书,证书路径/CA/cacert.pem ~ 创建邮箱账户“user1~user99”(共99个用户),密码为Chinaskill20!; ~ Dovecot ~ sdskill.com 的邮件接收服务器; ~ 支持imap…

6.微服务项目实战---Sleuth--链路追踪

6.1 链路追踪介绍 在大型系统的微服务化构建中,一个系统被拆分成了许多模块。这些模块负责不同的功能,组合成 系统,最终可以提供丰富的功能。在这种架构中,一次请求往往需要涉及到多个服务。互联网应用构建在不同的软件模块集上…

Docker compose-实现多服务、nginx负载均衡、--scale参数解决端口冲突问题

Docker compose-实现多服务、nginx负载均衡、--scale参数解决端口冲突问题 问题:scale参数端口冲突解决方法:nginx实现多服务、负载均衡修改docker-compose.yml配置新增nginx本地配置文件验证启动容器查看容器状态访问web应用 问题:scale参数…

《二》HTTP 请求报文和响应报文、请求方法、状态码

请求报文和响应报文: 请求报文: 客户端向服务器发送的请求信息,就叫做请求报文。 客户端发送一个 HTTP 请求到服务器,请求信息包含四部分:请求行、请求头、空行、请求体。 请求行:包含三部分,分别是请…

查看库文件是32位还是64位|查看lib是静态库还是导入库|判断是debug模式还是release模式

文章目录 dll位数查看lib位数查看查看lib库是静态库还是导入库dll库文件信息查看lib库文件内容查看dll库查看编译模式是debug还是release方法一方法二方法三 lib静态库查看编译模式是debug还是release方法一方法二 lib导入库查看编译模式是debug还是release查看Linux下的.a库&a…

ROS学习第十五节——常用API(C++)

由于时间问题,从这一节开始只记录C实现效果,加油 以下附上这一节调试用的程序 https://download.csdn.net/download/qq_45685327/87708069 1.初始化函数 void init(int &argc, char **argv, const std::string& name, uint32_t options 0); …