【CVPR2023】TPS详解:联合令牌剪枝与压缩以实现视觉变形器更积极的压缩

news2024/12/24 20:34:22

【CVPR2023】TPS详解:联合令牌剪枝与压缩以实现视觉变形器更积极的压缩

  • 0. 引言
  • 1. 为什么要使用TPS?
  • 2. TPS介绍
  • 3. TPS 详解
    • 3.1 重要性计算
    • 3.2 令牌压缩
      • 3.2.1 匹配
      • 3.2.2 融合
  • 4. 简化版理解
  • 5. 总结

0. 引言

虽然 Vision Transformers (ViTs)近年来在各种计算机视觉任务中展示出良好的效果,但是 Transformers 的高复杂度给计算机资源带来了沉重的负担。ViTs 方面的讲解:ViT 和 基于知识蒸馏的ViT(DeiT)。为了克服 Transformers 存在的问题,众多学者提出了自己的见解。其中主要包括以下几个方面:

  1. 最简单的方法(减少Transformers模块比重,增加CNN模块)------MobileViT详解
  2. 通过减少模型输入(正确的说:通过Mask的方法减少模型输入,然后通过Encoder-Decoder重构原始图形)。何凯明大神佳作 MAE
  3. 通过改变全局注意力计算的方式(Transformers模块复杂度过高往往是由于全局注意力的计算方式)。Swin-Transformer详解、CSWin Transformer详解
  4. 通过对令牌进行修剪和合并(通过减少Token的数量进而减少模型复杂度)。DiffRate详解

而本篇文章所提出的新的联合令牌修剪和压缩模块(TPS) ,是为了解决 由修剪策略引起的错误可能导致重大的信息丢失 的问题。首先,TPS通过剪枝得到保留子集和剪枝子集。其次,TPS通过单向最近邻匹配基于相似性融合步骤,将被修剪的令牌信息压缩为部分保留令牌。

论文名称:Joint Token Pruning and Squeezing Towards More Aggressive Compression of Vision Transformers
论文地址:https://arxiv.org/abs/2304.10716
代码地址:https://github.com/megvii-research/tps-cvpr2023
注意:截止当前,代码中只有dTPS部分,作者仍在更新完善项目。

1. 为什么要使用TPS?

与传统直接进行令牌修剪相比,联合令牌修剪和压缩在某种程度上保存了所有信息。从而防止因手动设置剪切率导致删除重要信息的情况。

在这里插入图片描述

令牌修剪范式(第二行)与联合令牌修剪和压缩(第三行)之间的比较

在上图中,上下文信息(例如示例中的sod)有助于预测,但会被令牌修剪范式丢弃。然而,TPS 方法可以将修剪过的令牌压缩到保留的令牌中,从而减轻了信息丢失。通过这种设计,我们可以应用更积极的令牌修剪同时减少性能下降示例结果来自ImageNet1K,为了可视化的清晰度,将实际的补丁网格从 14 × 14 14 × 14 14×14 减少到 7 × 7 7 × 7 7×7

为了更好地解释 TPS 的操作流程,这里采用图片 对比传统修剪、重组方法和 TPS 方法的区别。
在这里插入图片描述
如上图所示,图(a)表示令牌修剪的方法,通过计算各个 token 的重要性,选择其中最为重要的 k k k 个进行保留,删除余下的 token ;图(b)表示令牌重组的方法,在计算各个 token 的重要性后,将最重要的 k k k 个进行保留的同时,将需要删除的 token 合并成第 k + 1 k+1 k+1 个 token 进行保存;图(c)表示 TPS 方法,TPS 采用令牌修剪和压缩两步来压缩 ViTs。在TPS 方法中,在计算各个 token 的重要性后,将需要删除的 token 与保留的 token 计算相似性,将需要删除的 token 中存在的信息压缩最相似的保留的 token 中。
因此,从上述介绍中可知:TPS 方法可以与任意 令牌修剪 的方法相合并,从而得到保留子集 S r S^r Sr 和修剪子集 S p S^p Sp

2. TPS介绍

TPS 方法存在两种变体:dTPSeTPS,分别指 块间(在两个 Transformer Block 之间压缩 token)块内(在 Transformer Block中间压缩 token)令牌压缩。其中,块间压缩的 Class Token Attention 的理解可以看 DiffRate详解:高效Vision Transformers的可微压缩率。
在这里插入图片描述
具体而言:

  • dTPS 采用dynamicViT 中的可学习令牌分数预测头,通过直通Gumbel Softmax对二值决策掩码进行可微性采样利用Gumbel Softmax,可以使目标函数对于该mask参数可微);
  • eTPS使用类令牌关注值来衡量令牌作为EViT的重要性
  • 在两种变体的推理阶段,基于token分数,使用给定固定token压缩比 ρ ρ ρ 的 Top-k 操作设计token选择策略;
  • 这两种变体都保证了恒定的形状,从而从计算图的推理优化中获益。

3. TPS 详解

3.1 重要性计算

论文中作者没有详述重要性计算公式。结合作者给出的代码,相关代码如下所示。

pred_score = self.score_predictor[p_count](
    spatial_x, prev_decision).reshape(B, -1, 2)
if self.training:
    # use gumbel-softmax and mask-attention with policy
    hard_keep_decision = gumbel_softmax(pred_score, hard=True)[
        :, :, 0:1] * prev_decision
    # TODO: dTPS and eTPS
    current_pruned_decision = (
        1-hard_keep_decision) * prev_decision
    spatial_x = self.tps(
        spatial_x, None, hard_keep_decision, current_pruned_decision)
    x = F.concat([x[:, :1, :], spatial_x], axis=1)
    hard_decision_list.append(
        hard_keep_decision.reshape(B, init_n))
    cls_policy = F.ones(
        (B, 1, 1), dtype=hard_keep_decision.dtype, device=hard_keep_decision.device)
    policy = F.concat([cls_policy, hard_keep_decision], axis=1)
    x = blk(x, policy=policy)
    prev_decision = hard_keep_decision
else:
    score = pred_score[:, :, 0]
    num_keep_node = int(init_n * self.keep_ratio_list[p_count])
    sort_idxs = F.argsort(score, descending=True)
    keep_idxs = sort_idxs[:, :num_keep_node]
    drop_idxs = sort_idxs[:, num_keep_node:]
    spatial_x = self.tps(batch_index_select(
        spatial_x, keep_idxs), batch_index_select(spatial_x, drop_idxs), None, None)
    x = F.concat([x[:, :1, :], spatial_x], axis=1)
    x = blk(x)
p_count += 1

上述代码为 dTPS 模型计算重要性, eTPS作者暂未给出。在上述计算过程中,当模型训练的时候使用可学习的分数,然后使用Gumbel Softmax 进行二值决策。当模型训练完成后,采用令牌压缩机制进行操作(类似于DeiT中的知识蒸馏,也许这就是为什么模型文件叫做 tps_deit.py 的原因)。

3.2 令牌压缩

考虑到保留令牌贡献了大部分正确的预测,作者的目的是设计一个过程,在保留大多数注意令牌的同时压缩来自删除令牌的信息,从而保持模型的整体性能。为了避免生成额外的令牌,作者将修剪过的令牌注入到类似的保留令牌中。因此,作者以多对一的方式应用了从 S p S^p Sp S r S^r Sr单向最近邻匹配算法。然后,作者采用一种基于相似性的融合方法将信息从被修剪的令牌中吸收到部分保留令牌中。
将上述过程概括为两个步骤:匹配融合

3.2.1 匹配

给定两个子集 S r S^r Sr S p S^p Sp , I r I^r Ir I p I^p Ip S r S^r Sr S p S^p Sp 对应的 token 序号。对于所有 i ∈ I p i∈I^p iIp j ∈ I r j∈I^r jIr相似度矩阵 c i , j c_{i,j} ci,j 表示匹配令牌之间的相互作用。对于每一个被删减的令牌 x i ∈ S p x_i∈S^p xiSp,从保留令牌集 S r S^r Sr 中找到距离最近的令牌 x ∗ h o s t ∈ S r x^{host}_∗∈S^r xhostSr 作为它的 host token:
x ∗ h o s t = a r g m a x    c i , j        ( 1 ) x j ∈ S r x_*^{host} = \mathop argmax\ \ {c_{i,j}} \ \ \ \ \ \ (1) \\ x_j \in {S^r} xhost=argmax  ci,j      (1)xjSr注意,由于令牌匹配步骤从 S p S^p Sp S r S^r Sr 是单向的,因此多个被修剪的令牌可以共享同一个主机令牌,而不是每个保留令牌都可以作为主机令牌。
然后,将匹配结果记录在mask 矩阵 M ∈ R N p × N r M∈R^{N^p×N^r} MRNp×Nr 中,其值由下式计算得出:
m i , j = { 1 , x j i s   t h e   h o s t   t o k e n   o f   x i , 0 , o t h e r w i s e ,      ( 2 ) m_{i,j}=\begin{cases} 1, \boldsymbol x_j{is \ the \ host \ token \ of \ }\boldsymbol x_i,\\ 0, otherwise, \end{cases} \ \ \ \ (2) mi,j={1xjis the host token of xi,0otherwise    (2)式中, N p N^p Np N p N^p Np 分别表示两个子集的令牌个数mask 有助于在排除不匹配对影响的同时,对 S r S^r Sr S p S^p Sp 进行规则的矩阵运算来进行以下融合步骤。
虽然注意图是衡量令牌之间相互作用的一种自然而自由的选择,但我们可以通过 S r S^r Sr S p S^p Sp 之间的余弦相似度获得更高的性能。因此,在文章的所有的实验中,相似度矩阵定义为:
c i , j = x i T x j ∥ x i ∥ ∥ x j ∥   , f o r   i ∈ I p , j ∈ I r     ( 3 ) c_{i,j} = \frac{{\boldsymbol x_i{^T }}{\boldsymbol x_j}}{ {\|} \boldsymbol x_i{\|\|}\boldsymbol x_j{\|}} \ , for \ i\in I^ p, j \in I^ r \ \ \ (3) ci,j=xi∥∥xjxiTxj ,for iIp,jIr   (3)由于相似矩阵 c i , j c_{i,j} ci,j 是直接由输入特征生成的,所以在匹配步骤中没有引入额外的参数

3.2.2 融合

由于不同标记之间的差异,简单地平均标记可能导致特性分散EViT 利用令牌重要性分数来重新加权聚合令牌。因此,作者使用基于相似性的加权方案。它扩大了 closer tokenshost tokens 的影响,同时也避免了 impact token 评分带来的潜在缺陷
如前所述,融合步骤包含来自两个子集的所有令牌,并由 mask M M M 控制,以确保只混合 host tokens已修剪令牌。这引入了一些冗余计算,但由于常规矩阵运算的效率,增加了实际训练和推理吞吐量。
具体来说,通过剪枝保留下来的令牌 x j ∈ S r x_j \in S^r xjSr 通过融合原始特征被修剪令牌的特征来更新,具体操作如下所示:
y j = w j x j + ∑ x i ∈ S p w i x i ,     ( 4 ) y_j = w_j x_j + \sum_{x_i \in S^{p}} w_ix_i , \ \ \ (4) yj=wjxj+xiSpwixi,   (4)其中, w i w_i wi 为每个被修剪令牌 x i ∈ S p x_i∈S^p xiSp 的权值, w j w_j wj 为保留令牌本身的权值, y j y_j yj 为更新后的令牌。融合权值 w i w_i wi 取决于掩码值 m i , j m_{i,j} mi,j相似度 c i , j c_{i,j} ci,j w i w_i wi 的具体计算公式如下:
w i = exp ⁡ ( c i , j ) m i , j ∑ x i ∈ S p exp ⁡ ( c i , j ) m i , j + e     ( 5 ) w_i= \frac{\exp (c_{i,j})m_{i,j}}{\sum_{\boldsymbol x_i \in S^p }\exp (c_{i,j})m_{i,j} + \mathrm e} \ \ \ (5) wi=xiSpexp(ci,j)mi,j+eexp(ci,j)mi,j   (5)在计算过程中,保留令牌总是具有最大的融合权值 w j w_j wj,因为 x j x_j xj 与自己的相似度等于1(即 exp ⁡ ( c i , j ) m i , j = e \exp (c_{i,j})m_{i,j}=\mathrm e exp(ci,j)mi,j=e),而其余令牌与之相似度小于1。因此, w j w_j wj 的计算公式如下所示:
w j = e ∑ x i ∈ S p exp ⁡ ( c i , j ) m i , j + e     ( 6 ) w_j= \frac{\mathrm e}{\sum_{\boldsymbol x_i \in S^p }\exp (c_{i,j})m_{i,j} + \mathrm e} \ \ \ (6) wj=xiSpexp(ci,j)mi,j+ee   (6)根据上述方程,未被选为 host token 的保留令牌保持不变,而被修剪过的令牌被压缩进 host token ,替换原有令牌。 可以看到,匹配和融合步骤确保处理令牌的数量等于保留令牌的数量,从而保持有效推理的恒定形状。

4. 简化版理解

可能看了上述的内容,大家对于 TPS 的整体还是不太理解。这里对文章内容进行口语式解答来帮助大家理解文章内容。
TPS 这篇文章总的来说通过将需要修剪的信息压缩融合到最近似无需修剪的信息部分(可能存在多个块融合进一个块的情况),既提升了模型的运算速度又不丢失所有信息。
具体而言:

  • 首先,确定哪些 token 的重要性较低会被删除,哪些重要性较高会保留。
  • 然后,依次匹配需要删除的token与保留的token中哪个最相似。
  • 最后,将所有需要删除的token与最相似的保留的token相融合。

注意:可能存在一个保留的token融合多个需要删除token的情况,也存在保留的token与任意一个需要删除toiken也不融合的情况。

5. 总结

作者的实验证明:与最先进的方法相比,TPS方法在所有令牌修剪强度下都优于它们。特别是当将小型计算预算缩减到35%时,与ImageNet分类的基线相比,它的准确率提高了1%-6%。该方法可将DeiT-small的吞吐量提高到超过DeiT-tiny,准确率比DeiT-tiny提高4.78%。在各种变压器上的实验证明了该方法的有效性,分析实验证明了该方法对令牌修剪策略的误差具有较高的鲁棒性。如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。

到此,有关TPS的内容就基本讲完了。如果觉得这篇文章对你有用,记得点赞、收藏并分享给你的小伙伴们哦😄。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/617141.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

小文智能宣布接入ChatGPT,智能化客户服务,开创全新用户体验

小文智能是一家致力于用AI技术解放劳动力的公司,最近我们接入了ChatGPT技术,深度探索AI在智能对话机器人领域应用的更多可能,这将为我们的客户带来更为优质的人机对话服务和全新的用户体验。 ChatGPT是一种基于人工智能的自然语言处理技术&a…

案例31:基于Springboot企业员工薪酬关系系统开题报告设计

博主介绍:✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专…

逍遥自在学C语言 | break-循环的中断与跳转

前言 在C语言中,break语句是一种控制流语句,它用于终止当前所在的循环结构(for、while、do-while)或者switch语句,从而跳出循环或者结束switch语句的执行。 一、人物简介 第一位闪亮登场,有请今后会一直…

ML算法——梯度下降随笔【机器学习】

文章目录 11、梯度下降 11、梯度下降 梯度下降如何帮助参数优化? 梯度下降是一种用于参数优化的常见方法。它的基本思想是通过迭代地更新参数,以减小损失函数|代价函数的值,从而找到一个最优解。 梯度方向:→|向右|正向 ←|向左|反…

PostGIS(1):PostGIS概述

作为对象关系型数据库PostGreSQL的拓展模块,PostGIS可用于存储GIS数据,并提供了对基于GiST的R树索引支持、以及面向GIS对象的分析和处理相关的函数。 以下是PostGIS官网对其特征的介绍, (1) 先看一下百度对PostGIS的介…

Langchain-ChatGLM:基于本地知识库问答

文章目录 ChatGLM与Langchain简介ChatGLM-6B简介ChatGLM-6B是什么ChatGLM-6B具备的能力ChatGLM-6B具备的应用 Langchain简介Langchain是什么Langchain的核心模块Langchain的应用场景 ChatGLM与Langchain项目介绍知识库问答实现步骤ChatGLM与Langchain项目特点 项目主体结构项目…

php7新特性详细介绍(二)

一、PHP 7 异常 PHP 7 异常用于向下兼容及增强旧的assert()函数。它能在生产环境中实现零成本的断言,并且提供抛出自定义异常及错误的能力。 assert() 配置 | 配置项默认值可选值zend.assertions11 - 生成和执行代码 (开发模式) 0 - 生成代码,但在执…

智警杯excel和sql实训盲点

目录 excel基础操作: excel函数:智警杯赛前学习1.2--excel统计函数_lulu001128的博客-CSDN博客知识点https://blog.csdn.net/lulu001128/article/details/130936259?spm1001.2014.3001.5501 excel报表实战: excel数据透视及绘图&#xff…

Amino框架无锁算法实现并发线程安装组件(一)

Amino是无锁并行框架,线程安装,该框架封装了无锁算法,提供了可用于线程安全的一些数据结构,同时还内置了一些多线程调度模式。使用Amino进行软件开发有以下的优势: 1.对死锁的问题免疫 2.确保系统并发的整体进度 3.降低高并发下无锁竞争带…

java设计模式之:建造者模式

文章目录 建造者模式介绍建造者模式适用场景案例场景一坨坨代码实现重构代码 与工厂模式区别建造者模式优缺点总结 该说不说几乎是程序员都知道或者了解设计模式,但大部分小伙伴写代码总是习惯于一把梭。好的代码不只为了完成现有功能,也会考虑后续扩展。…

springboot自动配置源码解析

概述 使用springboog的时候引入starter就自动为我们加载,例如我们引入 spring-boot-starter-web 之后,就自动引入了 Spring MVC 相关的 jar 包,从而自动配置 Spring MVC 。 自动装配原理 SpringBootApplication SpringBootApplication: Spri…

Java的引用

一、概述 其实java有4种引用,4种可分为强、软、弱、虚。我们将从这四个方面入手进行介绍。 二、强引用 首先看到我们有一个类叫M,在这个类里我重写了一个方法叫finalize(),我们可以看到这个方法是已经被废弃的方法,为什么要重写…

【jupyter】Jupyter Notebook如何导入导出文件

目录 0.系统:windows 1.打开 Jupyter Notebook 2.Jupyter Notebook导入文件 3.Jupyter Notebook导出文件 0.系统:windows 1.打开 Jupyter Notebook 1)下载【Anaconda】后,直接点击【Jupyter Notebook】即可在网页打开 Jupyte…

用户研究干货——这一篇就够啦

一、基本概念: ①工作内容:用户研究的首要目的是帮助企业定义产品目标用户群,明确、细化产品概念,并通过对用户的任务操作特性、知觉特征、认知心理特征的研究,使用户的实际需求成为产品设计的导向,使产品…

建面超72万㎡,南山红花岭旧改规划公示,配套近15万㎡宿舍

近日,深圳市南山区城市更新和土地整备局发布关于桃源街道红花岭工业南区更新单元(暂定名)03-01、02-02地块《建设工程规划许可证》及总平面图的公告。 此次批复的红花岭工业南区02-02、03-01块,总建面超72万㎡,用地单…

nginx+tomcat 负载均衡、动静分离集群

文章目录 一、NginxTomcat负载均衡的组合原因1.1 Nginx实现负载均衡的原理1.2 Nginx实现负载均衡的主要配置项1.3 NginxTomcat负载均衡的组合的优点1.4 NginxTomcat负载均衡的实验设计 二、动静分离部署2.1 部署TOMCAT后端服务器2.2部署nginx服务器2.3安装nginx动态服务器 一、…

java中try-with-resources自动关闭io流

在传统的输入输出流处理中,我们一般使用的结构如下所示,使用try - catch - finally结构捕获相关异常,最后不管是否有异常,我们都将流进行关闭处理: try {//todo } catch (IOException e) {log.error("read xxx f…

《Lua程序设计》--学习1

前言&#xff1a; --> 表示一条语句的输出或表达式求值的结果 -- 单行注释 > 标注 一些代码需要在交互模式下输入 如果需要打印表达式求值的结果&#xff0c;必须在每个表达式前加上一个等号 <--> 表示两者完全等价 语言基础 我们将Lua语言执行的每一…

html选择器

基本选择器 基本选择器 : 标签选择器 , 类选择器 , ID选择器 标签选择器 代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEed…

小企业软件项目流程六步法

小企业软件项目流程六步法&#xff0c;很有效 软件项目的沟通成本是巨大的 软件生产是非常特殊的一套流程 没有过程控制&#xff0c;最终一定失控或废弃 趣讲大白话&#xff1a;输入垃圾&#xff0c;输出也是垃圾 【趣讲信息科技188期】 **************************** 软件行业…