DeepMind: 用ReLU取代Softmax可以让Transformer更快

news2025/1/10 16:43:41

注意力是人类认知功能的重要组成部分,当面对海量的信息时,人类可以在关注一些信息的同时,忽略另一些信息。当计算机使用神经网络来处理大量的输入信息时,也可以借鉴人脑的注意力机制,只选择一些关键的信息输入进行处理,来提高神经网络的效率。

2017年,谷歌团队的Vaswani等人发表的《Attention Is All You Need》利用注意力机制,提出Transformer机器学习框架。到目前为止,该论文已经被引用9万多次,显示出Transformer构架和注意力机制在现代机器学习领域中得到了广泛应用。

注意力机制的一个核心步骤中包含了一个 softmax函数,其作用是产生 token 的一个概率分布。数学上来讲,Softmax函数的定义很简单,就是将一个任意序列的数组转换成区间为(0,1)的数组(图1)。因为这种归一化,后者数组可以被解释成前者数组发生的概率。

因为它涉及到指数计算和对序列长度进行求和计算,执行softmax往往有较高的成本,有时候使得并行化难以执行。

图1,softmax函数的定义和说明

最近,Google DeepMind团队在Arxiv上发表一篇预印本论文,《Replacing softmax with ReLU in Vision Transformers》。该论文发现:利用某种不一定会输出概率分布的新方法,即序列长度归一化的ReLU函数,来替代 softmax 运算,可以使得注意力运算得到可以接近或匹敌传统的 softmax 注意力。这一结果为并行化带来了新方案,因为 ReLU 注意力可以在序列长度维度上并行化,其所需的求和运算少于传统的基于softmax注意力。

图2,谷歌DeepMind新论文

方法和原理

注意力机制:

虽然注意力机制有许多种实现方式,最常用的还是“点积标度注意力”机制。

点积标度注意力机制通过一个两步式流程对一个 d-维的数组 {q_i,k_i,v_i} 进行变换。其中 q, k, v 分别表示查询(query)、键(key),和值(value)。

第一步,通过下式方程(1)计算注意力矩阵【注:原文作者把下列方程中的 alpha 叫做注意力权重(attention weight)。其实 alpha 并不是注意力训练的权重。权重矩阵(weight matrix,w)是隐含在单个 q, k, v 的向量矩阵中,即 q=w_q*H, k=w_k*H, v=w_v*H。这里 H 是嵌入向量】:

它表示第 个 query 向量与第 j 个 key 向量之间的关联程度。其中的 phi 就是通常所说的softmax函数。

第二步,将注意力矩阵与对应的 v 向量相乘,得到第 i 个 query 向量更新后的矩阵,其形式化表示为

其中 Q, K, V 分别是 query、key、value 向量序列。如果忽略 softmax 激活函数,实际上它就是三个维数为 m x d_k, d_k x n, n x d_v 的矩阵相乘,得到一个维数为 m x d_v 的矩阵,也就是将维数为 m x d_k 的序列 Q 编码成了一个新的维数为 n x d_v 的序列。

这篇论文探索了使用逐点式计算的方案来替代 phi=saftmax函数的可能性。

ReLU注意力机制

在深度学习理论中,ReLU(rectified linear unit,线性整流函数)是指如下‘整流’变换:

DeepMind团队观察到,可以利用简单的被序列长度 (L) 归一化的线性整流函数,L^(-1)ReLU,替代 softmax,可以产生更加快速有效的结果。他们称这种注意力为 ‘ReLU-attention' (线性整流函数注意力机制)。


图3,各种不同转换函数的比较。softmax类似于左上的Sigmoid函数;ReLU对应于左下的曲线。

广义上来讲,我们可以定义一大类逐点注意力函数,phi=L^(-a)h,其中 a 在 [0,1] 之间取值,h 可以是 ReLU, ReLU**2, GeLU, softplus, identity, ReLU6 和 sigmoid 中的任何一种函数。

序列长度归一化

因为 Transformer 机制要求所有的注意力矩阵元素在某一指标(j)的求和等于1,这意味着注意力矩阵元素的平均量级应该是~1/L,或者说L^(-1)。其中 L 是序列的长度。因此,在上面方程(1)中的 phi 函数就可以是 phi~L^(-1)ReLU。

本文的结果显示,L^(-1) 的归一化对于模型的训练精度至关重要。然而,在以往类似的工作中,其他研究者并没有注意到这个归一化因子的重要性。

实验与结果

作者在不改变原模型参数的情况下,对BigVision库中的两个程序(ImageNet-21k and ImageNet-1k)进行了测试。作者对这两个模型分别进行了30和300个epoch的训练。

主要结果

图 4 的结果显示出,在 ImageNet-21k 训练方面,ReLU 注意力与 softmax 注意力有着类似的模型训练精度。但是,ReLU 注意力的一大优势是能在序列长度维度上实现并行化,其所需的收集操作比 softmax 注意力更少。

图 4:sofmax注意力和ReLU注意力机制的比较。

序列长度扩展的效果

图 5 对比了序列长度扩展方法与其它多种替代 softmax 的逐点式方案的结果。具体来说,就是用 relu、relu²、gelu、softplus、identity 等方法替代 softmax。X 轴是 α。Y 轴则是 S/32、S/16 和 S/8 视觉 Transformer 模型的准确度。最佳结果通常是在 α 接近 1 时得到。由于没有明确的最佳非线性,所以他们在主要实验中使用了 ReLU,因为它速度更快。

图5:用L^(−α)h 替换 softmax函数,其中 h ∈ {relu, relu2 , gelu, softplus, Identity, relu6, sigmoid},  L 是序列长度。 

qk-layernorm 的效果

此前的研究中,Dehghani等人提出一种叫做qk-归一化的训练机制。在该算法中,和 k 矩阵会通过 LayerNorm传递。本文的作者表示,默认使用 qk-layernorm 的原因是在扩展模型大小时有必要防止不稳定情况发生。图 6 展示了移除 qk-layernorm 的影响。这一结果表明 qk-layernorm 对这些模型的影响不大,但当模型规模变大时,情况可能会不一样。

图 6:qk-layernorm对ReLU和ReLU**2的影响。

添加gate的效果

此前也有关于移除 softmax 但是添加一个门控单元(gated unit)的做法,但这种方法无法随序列长度而扩展。具体来说,在门控注意力单元中,会有一个额外的投影产生输出,该输出是在输出投影之前通过矩阵元素相乘得到的。图 7 探究了gate的存在是否可消除对序列长度扩展的需求。总体而言,本文作者观察到,不管有没有gate,通过序列长度扩展都可以得到最佳准确度。也要注意,对于使用 ReLU 的 S/8 模型,这种门控机制会将实验所需的核心时间增多大约 9.3%。

图 4:使用门控注意力单元对 ReLU 和 ReLU**2  注意力机制的影响,其中 L 是序列长度。

小结

Softmax函数是Transformer学习机制的一个核心函数。因为它涉及到指数求和运算,该函数不利于并行化计算。此前曾有研究人员试图利用ReLU或者ReLU**2来取代softmax,但是效果并不理想。

谷歌DeepMind团队的这份研究报告显示,ReLU加上序列长度归一化,可以取得和传统softmax近似的模型训练精度。但是ReLU注意力的速度更快,更有利于并行化运算。

尽管如此,正如作者所指出的,这篇报告留下了许多悬而未决的问题。 特别是,他们不确定为什么这个L^(-1)因子可以提高模型的训练性能,或者这个因子能否通过学习获得。很显然,可能有更好的激活函数等待我们去发现。

参考文献:

M Wortsman, J Lee, J Gilmer, S Kornblith, Google DeepMind, Replacing softmax with ReLU in Vision Transformers. arXiv:2309.08586v1 [cs.CV] 15 Sep 2023. https://arxiv.org/pdf/2309.08586.pdf

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1037700.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

常见限流算法学习

文章目录 常见限流算法学习前言限流算法基本介绍固定窗口计数器限流算法计数器限流算法相关介绍计数器限流算法的实现(基于共享变量)计数器限流算法的实现(基于Redis) 滑动窗口计数器算法滑动时间窗口算法相关介绍介绍滑动时间窗口…

【软件设计师-从小白到大牛】上午题基础篇:第五章 结构化开发方法

文章目录 前言结构化设计1、基本原则真题链接2、内聚与耦合真题链接3、系统结构/模块结构真题链接用户界面设计的黄金原则(补充)真题链接数据流图(补充)真题链接系统文档(补充)真题链接 前言 ​ 本系列文章…

如何使用 Git 进行多人协作开发(全流程图解)

文章目录 分支管理策略1.什么是Feature Branching?2.Feature Branching如何工作? 多人协作一:单分支1.准备工作2.创建分支3.在分支上开发4.分支合并5.清理 多人协作二:多分支1.创建分支2.在分支上开发3. pull request4.清理 在软件…

/usr/bin/ld: cannot find -lmysqlcllient

文章目录 1. question: /usr/bin/ld: cannot find -lmysqlcllient2. solution 1. question: /usr/bin/ld: cannot find -lmysqlcllient 2. solution 在 使用编译命令 -lmysqlclient时,如果提示这个信息。 先确认一下 有没有安装mysql-devel 执行如下命令 yum inst…

js对象属性

在面向对象的语言中有一个标志,那就是都有类,通过类可以创建任意多个相同属性、方法的对象。在js中没有类的存在,所以js中的对象,相对于类语言中对象有所不同。 js中定义对象为:“无序属性的集合,其属性可…

新版绿豆视频APP视频免授权源码 V6.6插件版

新版绿豆视频APP视频免授权源码 V6.6插件版 简介: 新版绿豆视频APP视频免授权源码 插件版 后端插件开源,可直接反编译修改方便 对接苹果cms,自定义DIY页面布局! 绿豆影视APP对接苹果cms 所有页面皆可通过后端自由定制 此版本后端源码 前…

二叉树创建、前序遍历、中序遍历、后序遍历、层序遍历

#define _CRT_SECURE_NO_WARNINGS #include<stdio.h> #include<malloc.h> #define N 100 typedef char data_t;typedef struct tree {data_t data;//存放本节点数据struct tree* l_child;//存放左孩子节点地址struct tree* r_child;//存放右孩子节点地址 }Tree;Tre…

Zig实现Hello World

1. 什么是zig 先列出一段官方的介绍: Zig is a general-purpose programming language and toolchain for maintaining robust, optimal, and reusable software. 大概意思就是说&#xff1a; Zig是一种通用编程语言和工具链&#xff0c;用于维护健壮、最佳和可重用的软件。 官…

电脑计算机xinput1_3.dll丢失的解决方法分享,四种修复手段解决问题

日常生活中可能会遇到的问题——xinput1_3.dll丢失的解决方法。我相信&#xff0c;在座的很多朋友都曾遇到过这个问题&#xff0c;那么接下来&#xff0c;我将分享如何解决这个问题的解决方法。 首先&#xff0c;让我们来了解一下xinput1_3.dll文件。xinput1_3.dll是一个动态链…

服务注册发现_高可用Eureka注册中心搭建

在微服务架构这样的分布式环境中&#xff0c;我们需要充分考虑发生故障的情况&#xff0c;所以在生产环境中必须对各个组件进行高可用部署&#xff0c;对于微服务如此&#xff0c;对于服务注册中心也一样。 问题&#xff1a; Spring-Cloud为基础的微服务架构&#xff0c;所有的…

vulhub venom

文章目录 靶场环境信息收集ftp服务二、信息利用三、任意文件上传三 sudo提权靶场环境 `vmware 靶场信息:https://www.vulnhub.com/entry/venom-1,701/ 下载地址:https://download.vulnhub.com/venom/venom.zip 新建虚拟机打开下载后的ovf文件 遇见导入失败合规性检查时,重试…

找不到d3dcompiler_43.dll,无法继续执行代码如何解决

d3dcompiler_47.dll 是一个与 DirectX 相关的动态链接库&#xff08;DLL&#xff09;&#xff0c;它包含了 DirectX 图形编译器的一些功能。当您的电脑出现 d3dcompiler_47.dll 丢失的情况时&#xff0c;可能会导致一些基于 DirectX 的游戏或应用程序无法正常运行。下面我们将介…

【刷题笔记9.24】LeetCode:只出现一次的数字

LeetCode&#xff1a;只出现一次的数字 一、题目描述 给你一个 非空 整数数组 nums &#xff0c;除了某个元素只出现一次以外&#xff0c;其余每个元素均出现两次。找出那个只出现了一次的元素。 你必须设计并实现线性时间复杂度的算法来解决此问题&#xff0c;且该算法只使…

ImportError: Java package ‘edu‘ not found, requested by alias ‘edu‘

参考issue&#xff1a; https://github.com/ncbi-nlp/NegBio/issues/44 我目前的解决办法 pip uninstall jpype1 -y可以成功运行。

CCNP-OSPFv3

现在在企业中&#xff0c;用的IPv4居多&#xff0c;在我们的手机上&#xff0c;数据中心&#xff0c;运营商以及一些大企业用的都是IPv6&#xff1b; 为啥用IPv6啊&#xff0c;因为IPv4地址不够用&#xff0c;IPv4地址只有32bit&#xff0c;而IPv6足足有128bit&#xff1b; 那…

【23-24 秋学期】 NNDL 作业2

习题2-1 分析为什么平方损失函数不适用于分类问题&#xff0c;交叉熵损失函数不适用于回归问题 平方损失函数 平方损失函数&#xff08;Quadratic Loss Function&#xff09;经常用在预测标签&#x1d466;为实数值的任务中 表达式为&#xff1a; 交叉熵损失函数 交叉熵损失函…

RGB-D转3D点云原理及实现代码

在图像处理和计算机视觉领域&#xff0c;RGBD 是指结合图像颜色和深度信息的数据格式。文本介绍如何使用Python将RGBD数据转换为3D点云&#xff0c;可以使用 NSDT 3DConvert 在线查看3D点云或者进行格式转换&#xff1a; 1、RGBD 颜色深度 缩写 RGB 代表三基色通道&#xf…

Unity中Shader用到的向量的乘积

文章目录 前言一、向量的乘法1、点积2、差积 二、点积&#xff08;结果是一个标量&#xff09;1、数学表示法2、几何表示法 三、叉积1、向量叉积的结果 与 两个相乘的向量互相垂直2、判断结果正负方向的方法&#xff1a;右手法则 前言 Unity中Shader用到的向量的点积 一、向量…

华为OD机试 - 最小传输时延 - 深度优先搜索DFS(Java 2023 B卷 100分)

目录 专栏导读一、题目描述二、输入描述三、输出描述四、解题思路五、Java算法源码六、效果展示1、输入2、输出3、说明计算源节点1到目的节点5&#xff0c;符合要求的时延集合 华为OD机试 2023B卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《华为OD机试&…

Gnomon绑定基础(约束 IK 节点)

点约束 方向约束 父约束 目标约束 修改后 对象方向 IK控制柄 直的骨骼&#xff0c;指定IK怎么弯曲 直的骨骼&#xff0c;指定IK怎么弯曲 样条曲线 数学节点 乘除节点 混合节点 注意