Monarch Mixer:一种性能比Transformer更强的网络架构

news2024/11/27 16:44:41

六年前,谷歌团队在arXiv上发表了革命性的论文《Attention is all you need》。作为一种优势的机器学习网络架构,Transformer技术迅速席卷全球。Transformer一直是现代基础模型背后的主力架构,并且在不同的应用程序中取得了令人印象深刻的成功:包括像BERT、ChatGPT和Flan-T5这样的预训练语言模型,到像SAM和stable diffusion这样的图像模型。

尽管如此,Transformer架构中的自注意力机制和MLP在处理长度很长的序列或者维数很大的模型的时候,速度和效率会打折扣。这主要是因为tranformer构架的时空复杂性随序列的长度和训练模型的维数按平方的依赖关系生长,即所谓“二次元”(quadratic)。

最近,斯坦福大学和纽约州立大学布法罗分校的一个研究团队,在arXiv上发表题为《Monarch Mixer: A Simple Sub-Quadratic GEMM-Based Architecture》论文。该论文提出了一种新的transformer的替代技术:Monarch Mixer(M2)。该方法去掉了 Transformer 中高成本的自注意力和 MLP,代之以富有表现力的 Monarch 矩阵,使之在语言和图像实验中以更低的成本取得了更优的表现。其复杂度随序列长度和模型维度的增长是低于二次元,即所谓次二次元的(sub-quadratic)。

斯坦福大学和布法罗大学在axXiv上发表论文截图

该论文已入选 NeurIPS 2023 并获得 Oral Presentation 资格。

算法与原理

该论文的研究灵感来自 MLP-mixer 和 ConvMixer;这两项研究观察到:许多机器学习模型的运作方式都是沿序列和模型维度轴对信息进行混合,并且它们往往对两个轴使用了单个算子。

寻找表现力强、次二次元且硬件效率高的混合算子的难度很大。举个例子,MLP-mixer 中的 MLP 和 ConvMixer 中的卷积都颇具表现力,但它们都会随输入维度二次扩展。近期有一些研究提出了一些次二次元的序列混合方法,这些方法使用了较长的卷积或状态空间模型,而且它们都会用到快速傅里叶变换( FFT),但这些模型的 FLOP 利用率很低并且在模型维度方面依然是二次扩展。与此同时,不损质量的稀疏密集 MLP 层方面也有一些颇具潜力的进展,但由于硬件利用率较低,某些模型实际上可能还比密集模型更慢。

基于这些灵感,该论文研究团队提出了 Monarch Mixer (M2),其使用到了一类富有表现力的次二次结构化矩阵:Monarch 矩阵。

Monarch矩阵

Monarch 矩阵是一类泛化了FFT的结构化矩阵,并且研究表明其涵盖了范围广泛的线性变换,包括哈达玛变换、托普利兹矩阵、AFDF 矩阵和卷积。它们可通过分块对角矩阵的积进行参数化,这些参数被称为 Monarch 因子,与排列交织。

它们的计算是次二次扩展的:如果将因子的数量设为 p,则当输入长度为 N 时,计算复杂度为O(pN^(p+1)/p),从而让计算复杂度可以位于 p = log N 时的 O (N log N) 与 p = 2 时的 O(N^3/2)之间。

M2 使用了 Monarch 矩阵来沿序列和模型维度轴混合信息。这种方法不仅易于实现,而且硬件效率也很高:使用支持 GEMM(广义矩阵乘法算法)的现代硬件就能高效地计算分块对角 Monarch 因子。

图1:Monarch 矩阵是一种简单、富有表现力且硬件效率高的次二次结构矩阵。Monarch Mixer (M2)使用Monarch=矩阵来混合输入:首先沿着序列维度,然后沿着模型维度。

图 2:Monarch 乘法可以解释为多项式求值和插值。

该论文研究团队实现了一个 M2 层来进行概念验证。代码完全使用 PyTorch 来编写,代码行数不到 40(包括 import 软件包),而且其只需依赖矩阵乘法、转置、改造和逐元素乘积(见图 1 中部的伪代码)。结果,对于大小为 64k 的输入,这些代码在一台 A100 GPU 上实现了 25.6% 的 FLOP 利用率。在 RTX 4090 等更新的架构上,对于同样大小的输入,一个简单的 CUDA 实现就能实现 41.4% 的 FLOP 利用率。

表1:RTX 4090 上各种混频器层的 FLOP 成本和利用率(输入维数64K)

实验测试结果

该研究团队在 Transformer 已占主导地位的三个任务上对 Monarch Mixer 和 Transformer 进行了比较:(1)BERT 风格的非因果掩码语言建模任务;(2)ViT 风格的图像分类任务;(3)GPT 风格的因果语言建模任务。

在每个任务上,实验结果表明新提出的方法在不使用注意力和 MLP 的前提下均能达到与 Transformer 相媲美的水平。他们还在 BERT 设置中评估了新方法相较于强大 Transformer 基准模型的加速情况。

(1)非因果语言建模

对于非因果语言建模任务,论文作者构建了一种基于 M2 的架构:M2-BERT。M2-BERT 可以直接替代 BERT 风格的语言模型,而 BERT 是 Transformer 架构的一大主力应用。对于 M2-BERT 的训练,使用了在 C4 上的掩码语言建模,token 化器则是 bert-base-uncased。 

M2-BERT 基于 Transformer 骨干,但其中的注意力层和 MLP 被 M2 层替换,如图 3 所示。

图 3:M2-BERT 使用 Monarch 矩阵在序列混合器中创建双向门控长卷积,并使用 Monarch 矩阵替换维度混合器中的线性层。

在序列混合器中,注意力被带残差卷积的双向门控卷积替代(见图 3 左侧)。为了恢复卷积,论文作者将 Monarch 矩阵设置为 DFT 和逆 DFT 矩阵。他们还在投射步骤之后添加了逐深度的卷积。

在维度混合器中,MLP 中两个密集矩阵被替换成了学习得到的分块对角矩阵(1 阶 Monarch 矩阵,b = 4)。

作者预训练了 4 个 M2-BERT 模型:其中两个是大小分别为 80M 和 110M 的 M2-BERT-base 模型,另外两个是大小分别为 260M 和 341M 的 M2-BERT-large 模型。它们分别相当于 BERT-base 和 BERT-large。

表 3 给出了相当于 BERT-base 的模型的性能表现。

表 3:M2-BERT-base 与 BERT-base相比的平均 GLUE 分数,以及参数和 GLUE 分数的变化

表 4 给出了相当于 BERT-large 的模型的性能表现。

表 4:M2-BERT-large 与 BERT-large 相比的平均 GLUE 得分,以及变化参数和 GLUE 分数

从这些表中结果可以看到,在 GLUE 基准上,M2-BERT-base 的表现可以媲美 BERT-base,同时参数还少了 27%;而当两者参数数量相当时,M2-BERT-base 胜过 BERT-base 1.3 分。类似地,参数少 24% 的 M2-BERT-large 与 BERT-large 表现相当,而参数数量一样时,M2-BERT-large 有 0.7 分的优势。

表 5 给出了相当于 BERT-base 的模型的前向吞吐量情况。其中报告的是在 A100-40GB GPU 上每毫秒处理的 token 数,这能反映推理时间。

表 5:M2-BERT-base (80M) 吞吐量与 BERT-base 比较结果(以token/毫秒为单位)

可以看到,M2-BERT-base 的吞吐量甚至超过了经过高度优化的 BERT 模型;相较于在 4k 序列长度上的标准 HuggingFace 实现,M2-BERT-base 的吞吐量可达其 9.1 倍!

表 6 则报告了 M2-BERT-base (80M) 和 BERT-base 的 CPU 推理时间 —— 结果是直接运行这两个模型的 PyTorch 实现得到的。

表 6:在不同输入序列长度下批量大小为 1 的 CPU 推理延迟(以毫秒为单位)。 在运行 Intel Cascade Lake 处理器的 GCP n2-standard-48 系列的 48 vCPU、96 GB RAM 实例上对 10 多个示例进行了平均测量

当序列较短时,数据局部性的影响依然主导着 FLOP 的减少情况,而过滤器生成(BERT 中没有)等操作的成本更高。而当序列长度超过 1K 时,M2-BERT-base 的加速优势就渐渐起来了,当序列长度达 8K 时,速度优势可达 6.5 倍。

(2) ViT风格的图像分类

在非因果建模方面,为了验证新方法在图像上也有在语言上一样的优势,该团队还评估了 M2 在图像分类任务上的表现。

表 7 给出了 Monarch Mixer、ViT-b、HyenaViT-b 和 ViT-b-Monarch(用 Monarch 矩阵替换了标准 ViT-b 中的 MLP 模块)在 ImageNet-1k 上的性能表现。

表 7:ImageNet-1k 上的准确性。 ResNet-152 提供供参考。

Monarch Mixer 优势非常明显:只需一半的参数量,其表现就能胜过原始 ViT-b 模型。而更让人惊讶的是,参数更少的 Monarch Mixer 很能胜过 ResNet-152;要知道,ResNet-152 可是专门针对 ImageNet 任务设计的。

(3)GPT风格因果语言建模

GPT 风格的因果语言建模是 Transformer 的一大关键应用。该团队为因果语言建模构建了一个基于 M2 的架构:M2-GPT。

对于序列混合器,M2-GPT 组合使用了来自 Hyena 的卷积过滤器、当前最佳的无注意力语言模型以及来自 H3 的跨多头参数共享。他们使用因果参数化替换了这些架构中的 FFT,并完全移除了 MLP 层。所得到的架构完全没有注意力,也完全没有 MLP。

他们在因果语言建模的标准数据集 PILE 上对 M2-GPT 进行了预训练。结果见表 8。

表 8:针对不同数量的标记进行训练时,PILE 上的困惑度。

可以看到,尽管基于新架构的模型完全没有注意力和 MLP,但其在预训练的困惑度指标上依然胜过 Transformer 和 Hyena。这些结果表明,与 Transformer 大不相同的模型也可能在因果语言建模取得出色表现。

下一步研究计划

在最近的博客中(hazyresearch.stanford.edu/blog/2023-07-25-m2-bert),作者列出了他们的下一步研究计划:

  1. 我们今天发布了 BERT 代码以及 80M 和 110M 模型的检查点代码,使用序列长度 128 的标准配方进行了预训练 - 请继续关注更长的序列! 查看我们的代码和检查点(80M、110M)。
  2. 在接下来的几周内,请留意进一步的发布,因为我们将训练长序列 BERT 并开始追溯 Transformers 的历史 - 在 ImageNet、因果语言建模、T5 风格模型以及对长序列功能的探索。
  3. 作为此版本的一部分,您将找到一些用于 M2 层前向传递的优化 CUDA 代码(我们将其用于基准测试)——我们将在未来几周内继续优化并发布更新。 当我们探索计算权衡空间时,期待有关这些的另一系列博客和材料!
  4. 当然,更完整的论文将在 arXiv 中推出!

小结

该论文的研究工作为机器学习领域带来了新的思路,挑战了传统Transformer模型的优越性。他们的研究不仅探索了Monarch Mixer的理论基础,还进行了一系列实验来验证其性能。这篇文章的发表为机器学习社区提供了一个全新的研究方向,也让人们重新思考了在自然语言处理和计算机视觉任务中的模型选择。

总的来说,Monarch Mixer(M2)是一种具有次二次复杂度的新型模型架构,能够在不使用传统Transformer中的注意力和MLP的情况下,在自然语言处理和计算机视觉任务中表现出色。它的硬件效率和参数效率使其成为一个有望取代传统Transformer的新选择,为深度学习研究领域带来了新的思考。

参考文献:

Daniel Y. Fu, Simran Arora, Jessica Grogan, Isys Johnson, Sabri Eyuboglu, Armin W. Thomas, Benjamin Spector, Michael Poli, Atri Rudra, Christopher Ré. “Monarch Mixer: A Simple Sub-Quadratic GEMM-Based Architecture”,Oct 18, 2023, https://arxiv.org/abs/2310.12109

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1175339.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[云原生1. ] Docker consul的详细介绍(容器服务的更新与发现)

文章目录 1. 服务注册与发现的概述1.1 cmp问题1.2 解决方法 2. Consul的概述2.1 简介2.2 为什么要使用Consul服务模块2.2 Consul的服务架构2.3 Consul的一些关键特性 3. consul服务部署3.1 前置准备3.2 Consul服务器3.2.1 建立 Consul 服务3.2.2 设置代理,在后台启动…

Linux开发工具的使用(vim、gcc/g++)

文章目录 vimvim基本概念vim的常用三种模式vim三种模式的相互转换vim命令模式下的命令集移动光标删除文字剪切/删除复制替换撤销和恢复跳转至指定行 vim底行模式下的命令集 gcc/ggcc/g的作用gcc/g的语法预处理编译汇编链接函数库动静态库动态链接的优缺点 静态链接的优缺点 vim…

注意,注意,weak_ptr有坑

class Test { public:Test(){cout << "构造函数\n";}~Test(){cout << "析构函数\n";} }; void *operator new(size_t nsize) {void *ptmp std::malloc(nsize);printf("申请内存:%d,%p\n",nsize, ptmp);return ptmp; }void operator…

【油猴脚本】学习笔记

目录 新建用户脚本模板源注释 测试代码获取图标 Tampermonkey v4.19.0 原教程&#xff1a;手写油猴脚本&#xff0c;几分钟学会新技能——王子周棋洛   Tampermonkey首页   面向 Web 开发者的文档   Greasy Fork 新建用户脚本 打开【管理面板】 点击【】&#xff0c;即…

微服务使用指南

微服务使用指南 1.初识微服务 微服务可以认为是一种分布式架构的解决方案&#xff0c;提供服务的独立性和完整性&#xff0c;做到服务的高内聚、低耦合。 目前服务架构主要包含&#xff1a;单体架构和分布式架构。 1.1 单体架构 单体架构&#xff1a;把所有业务功能模块都…

YoloV8目标检测与实例分割——目标检测onnx模型推理

一、模型转换 1.onnxruntime ONNX Runtime&#xff08;ONNX Runtime或ORT&#xff09;是一个开源的高性能推理引擎&#xff0c;用于部署和运行机器学习模型。它的设计目标是优化执行使用Open Neural Network Exchange&#xff08;ONNX&#xff09;格式定义的模型&#xff0c;…

微信怎么批量保存大量照片

8-2 本文要解决的问题是自动或者快速地保存微信收到的图片的事情&#xff0c;如果你的工作中有一个事情是需要每天或者经常保存大量的从微信收到的图片或者视频的&#xff0c;也许本文适合你&#xff0c;本文介绍的方法&#xff0c;可以自动保存各个群或者人发来的图片和视频。…

【LeetCode每日一题合集】2023.9.18-2023.9.24(⭐拓扑排序⭐设计数据结构:LRU缓存实现 LinkedHashMap⭐)

文章目录 337. 打家劫舍 III&#xff08;树形DP&#xff09;2560. 打家劫舍 IV&#xff08;二分查找动态规划&#xff09;LCP 06. 拿硬币&#xff08;简单贪心模拟&#xff09;2603. 收集树中金币⭐思路——拓扑排序删边 2591. 将钱分给最多的儿童&#xff08;分类讨论&#xf…

MATLAB_5MW风电永磁直驱发电机-1200V直流并网MATLAB仿真模型

仿真软件&#xff1a;matlab2016b 风机传动模块、PMSG模块、蓄电池模块、超级电容模块、无穷大电源、蓄电池控制、风机控制、逆变器控制等模块。 逆变器输出电压&#xff1a; 混合储能系统SOC&#xff1a; 威♥关注“电击小子程高兴的MATLAB小屋”获取更多精彩资料&#xff0…

String的几个常见面试题及其解析

String s3 new String("a") new String("b")会不会在常量池中创建对象&#xff1f; 答案&#xff1a;不会&#xff0c;首先需要解释“”字符串拼接的理解。 采用 运算符拼接字符串时&#xff1a; 如果拼接的都是字符串直接量&#xff0c;则在编译时编…

基于信号功率谱特征和GRNN广义回归神经网络的信号调制类型识别算法matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 MATLAB2022a 3.部分核心程序 ................................................................ %调制识别 len1 func_f…

【代码】【5 二叉树】d3

关键字&#xff1a; 非叶子结点数、k层叶子结点数、层次遍历、找双亲结点、找度为1、叶子结点数

MySQL EXPLAIN查看执行计划

MySQL 执⾏计划是 MySQL 查询优化器分析 SQL 查询时⽣成的⼀份详细计划&#xff0c;包括表如何连 接、是否⾛索引、表扫描⾏数等。通过这份执⾏计划&#xff0c;我们可以分析这条 SQL 查询中存在的 问题&#xff08;如是否出现全表扫描&#xff09;&#xff0c;从⽽进⾏针对优化…

好用的MybatisX插件~

MybatisX插件&#xff1a; MyBatis-Plus为我们提供了强大的mapper和service模板&#xff0c;能够大大的提高开发效率。但是在真正开发过程中&#xff0c;MyBatis-Plus并不能为我们解决所有问题&#xff0c;例如一些复杂的SQL&#xff0c;多表联查&#xff0c;我们就需要自己去…

Web前端—网页制作(以“学成在线”为例)

版本说明 当前版本号[20231105]。 版本修改说明20231105初版 目录 文章目录 版本说明目录day07-学成在线01-项目目录02-版心居中03-布局思路04-header区域-整体布局HTML结构CSS样式 05-header区域-logo06-header区域-导航HTML结构CSS样式 07-header区域-搜索布局HTML结构CSS…

Gin学习笔记

Gin学习笔记 Gin文档&#xff1a;https://pkg.go.dev/github.com/gin-gonic/gin 1、快速入门 1.1、安装Gin go get -u github.com/gin-gonic/gin1.2、main.go package mainimport ("github.com/gin-gonic/gin""net/http" )func main() {// 创建路由引…

打通你学习C语言的任督二脉-函数栈帧的创建和销毁(上)

&#x1f308;个人主页: Aileen_0v0&#x1f525;系列专栏:C语言学习&#x1f4ab;个人格言:"没有罗马,那就自己创造罗马~" 待解决疑惑: 局部变量是怎么创建的? 为什么局部变量的值是随机值? 函数是怎么传参的?传参的顺序是怎样的? 形参和实参是什么关系? 函数调…

3.25每日一题(知线性常系数方程的特解求线性方程)

思路&#xff1a;通过特解可以知道特征根&#xff0c;通过特征根可以求出特征方程&#xff0c;通过特征方程可以求出线性方程

C语言strcat函数再学习

之前学习了strcat函数&#xff1b;下面继续学习此函数&#xff1b; 它的功能描述是&#xff0c; 功能 把src所指向的字符串&#xff08;包括“\0”&#xff09;复制到dest所指向的字符串后面&#xff08;删除*dest原来末尾的“\0”&#xff09;。要保证*dest足够长&#xff0…

【数智化人物展】觉非科技CEO李东旻:数据闭环,智能驾驶数智时代发展的新引擎...

李东旻 本文由觉非科技CEO李东旻投递并参与《2023中国企业数智化转型升级先锋人物》榜单/奖项评选。 大数据产业创新服务媒体 ——聚焦数据 改变商业 数智化的主要作用是帮助决策。它的核心是大数据&#xff0c;以大数据为基础&#xff0c;匹配合适的AI技术&#xff0c;促使数…