Transformer、RNN和SSM的相似性探究:揭示看似不相关的LLM架构之间的联系

news2024/9/20 22:05:22

通过探索看似不相关的大语言模型(LLM)架构之间的潜在联系,我们可能为促进不同模型间的思想交流和提高整体效率开辟新的途径。

尽管Mamba等线性循环神经网络(RNN)和状态空间模型(SSM)近来备受关注,Transformer架构仍然是LLM的主要支柱。这种格局可能即将发生变化:像Jamba、Samba和Griffin这样的混合架构展现出了巨大的潜力。这些模型在时间和内存效率方面明显优于Transformer,同时在能力上与基于注意力的LLM相比并未显著下降。

近期研究揭示了不同架构选择之间的深层联系,包括Transformer、RNN、SSM和matrix mixers,这一发现具有重要意义,因为它为不同架构间的思想迁移提供了可能。本文将深入探讨Transformer、RNN和Mamba 2,通过详细的代数分析来理解以下几点:

  1. Transformer在某些情况下可以视为RNN(第2节)
  2. 状态空间模型可能隐藏在自注意力机制的掩码中(第4节)
  3. Mamba在特定条件下可以重写为掩码自注意力(第5节)

这些联系不仅有趣,还可能对未来的模型设计产生深远影响。

LLM中的掩码自注意力机制

首先,让我们回顾一下经典的LLM自注意力层的结构:

更详细的结构如下:

自注意力层的工作流程如下:

  1. 将查询矩阵Q和键矩阵K相乘,得到一个L×L的矩阵,包含查询和键的标量积。
  2. 对结果矩阵进行归一化。
  3. 将归一化后的矩阵与L×L的注意力掩码进行元素级乘法。图中展示了默认的因果掩码——左侧的0-1矩阵。这一步骤将较早查询与较晚键的乘积置零,防止注意力机制"看到未来"。
  4. 对结果应用softmax函数。
  5. 最后,将注意力权重矩阵A与值矩阵V相乘。输出的第t行可表示为:

这意味着第i个值是通过"第t个查询对第i个键的注意力权重"来加权的。

这种架构中的多个设计选择都可能被修改。接下来我们将探讨一些可能的变体。

线性化注意力

注意力公式中的Softmax函数确保了值是以和为1的正系数混合的。这种设计保持了某些统计特性,但同时也带来了限制。例如即使我们希望利用结合律,如(QK^T)V = Q(K^TV),也无法突破Softmax的限制。

为什么结合律如此重要?因为改变乘法顺序可能显著影响计算复杂度:

左侧公式需要计算一个L×L矩阵,如果这个矩阵完全显现在内存中,复杂度为O(L²d),内存消耗为O(L²)。右侧公式需要计算一个d×d矩阵,复杂度为O(Ld²),内存消耗为O(d²)。

随着上下文长度L的增加,左侧公式的计算成本rapidly become prohibitively非常的高。为了解决这个问题,我们可以考虑移除Softmax。详细展开带有Softmax的公式:

其中

是Softmax函数。指数函数是主要的障碍,它阻止了我们从中提取任何项。如果我们直接移除指数函数:

那么归一化因子

也随之消失。

这个简化后的公式存在一个问题:q_t^T k_s不能保证为正,这可能导致值以不同符号的系数混合,这在理论上是不合理的。更糟糕的是,分母可能为零,会导致计算崩溃。为了缓解这个问题,我们可以引入一个"良好的"元素级函数φ(称为核函数):

原始研究建议使用φ(x) = 1 + elu(x)作为核函数。

这种注意力机制的变体被称为线性化注意力。它的一个重要优势是允许我们利用结合律:

括号中M, K^T和V之间的关系现在变得相当复杂,不再仅仅是普通的矩阵乘法和元素级乘法。我们将在下一节详细讨论这个计算单元。

如果M是一个因果掩码,即对角线及以下为1,对角线以上为0:

那么计算可以进一步简化:

这可以通过一种简单的递归方式计算:

这是在2020年ICML上首次提出线性化注意力的论文"Transformers are RNNs"。在这个公式中,我们有两个隐藏状态:向量z_t和矩阵h_t(φ(k_t)^T v_t是列向量乘以行向量,得到一个d×d矩阵。

而近期的研究often以更简化的形式呈现线性化注意力,去除了φ函数和分母:

线性化注意力具有两个主要优势:

  1. 作为递归机制,它在推理时相对于序列长度L具有线性复杂度。
  2. 作为Transformer模型,它可以高效地并行训练。

但是你可能会问:如果线性化注意力如此优秀,为什么它没有在所有LLM中广泛应用?我们在讨论注意力的二次复杂度问题?实际上基于线性化注意力的LLM在训练过程中stability较低,且capability略逊于标准自注意力。这可能是因为固定的d×d形状的瓶颈比可调整的L×L形状的瓶颈能传递的信息更少。

进一步探索

RNN和线性化注意力之间的联系在近期的多项研究中得到了重新发现和深入探讨。一个common pattern是使用具有如下更新规则的矩阵隐藏状态:

其中k_t和v_t可以视为某种"键"和"值",RNN层的输出形式为:

这本质上等同于线性注意力。下面两篇论文提供了有趣的一些样例:

1、xLSTM (2024年5月): 该论文提出了对著名的LSTM递归架构的改进。其mLSTM块包含一个矩阵隐藏状态,更新方式如下:

输出通过将这个状态与一个"查询"相乘得到。(注意:该论文的线性代数设置与我们的相反,查询、键和值是列向量而非行向量,因此v_t k_t^T的顺序看起来可能有些奇怪。)

2、Learning to (learn at test time) (2024年7月): 这是另一种具有矩阵隐藏状态的RNN架构,它的隐藏状态W是一个函数的参数,在t的迭代过程中通过梯度下降优化:

这里的设置也是转置的,因此顺序看起来有些不同。尽管数学表达比W_t = W_{t-1} + v_t k_t^T更复杂,但可以简化为这种形式。

以上两篇论文我们都详细介绍过,有兴趣的可以自行搜索

注意力掩码

在简化了掩码注意力机制后,我们可以开始探索其潜在的发展方向。一个明显的研究方向是选择不同的下三角矩阵(确保不会"看到未来")作为掩码M,而不是简单的0-1因果掩码。在进行这种探索之前,我们需要解决由此带来的效率问题。

在前一节中,我们使用了一个简单的0-1因果掩码M,这使得递归计算成为可能。但在一般情况下,这种递归技巧不再适用:

系数m_ts不再相同,也不存在将y_3与y_2关联的简单递归公式。因此,对于每个t我们都需要从头开始计算总和,这使得计算复杂度再次变为L的二次方而不是线性的。

解决这个问题的关键在于我们不能使用任意的掩码M,而应该选择特殊的、"良好"的掩码。我们需要那些可以快速与其他矩阵相乘(注意不是元素级乘法)的掩码。为了理解如何从这种特性中获益,让我们详细分析如何高效计算:

首先明确这个表达式的含义:

如果深入到单个索引级别:

为了便于后续讨论,可以用不同的颜色标记索引,而不是块:

现在我们可以提出一个四步算法:

步骤1. 利用K和V创建一个三维张量Z,其中:

(每个轴都标注了其长度。)这一步骤需要O(Ld²)的时间和内存复杂度。值得注意的是,如果我们在洋红色轴t上对这个张量求和,我们将得到矩阵乘积K^T V:

步骤2. 将M乘以这个张量(注意不是元素级乘法)。M乘以Z沿着洋红色轴t的每个"列"。

这正好得到:

将这个结果记为H。接下来只需要将所有内容乘以q,这将在接下来的两个步骤中完成。

步骤3a. 取Q并与H的每个j = const层进行元素级乘法:

这将得到:

这一步骤需要O(Ld²)的时间和内存复杂度。

步骤3b. 沿i轴对结果张量求和:

这一步骤同样需要O(Ld²)的时间和内存复杂度。最终得到了所需的结果:

在这个过程中,最关键的是第二步,我们故意省略了其复杂度分析。一个简单的估计是:

每次矩阵乘法需要O(L²)的复杂度,重复d²次

这将导致一个巨大的O(L²d²)复杂度。但是我们的目标是选择特殊的M,使得将M乘以一个向量的复杂度为O(RL),其中R是某个不太大的常数

例如如果M是0-1因果矩阵,那么与它相乘实际上就是计算累积和,这可以在O(L)时间内完成。但还存在许多其他具有快速向量乘法特性的结构化矩阵选项。

在下一节中将讨论这种矩阵类型的一个重要例子——半可分离矩阵,它与状态空间模型有着密切的联系。

半可分离矩阵与状态空间模型

让我们回顾一下(离散化的)状态空间模型(SSM)的定义。SSM是一类连接1维输入x_t、r维隐藏状态h_t和1维输出u_t的序列模型,其数学表达式如下:

在离散形式中,SSM本质上是一个带有跳跃连接的复杂线性RNN。为了简化后续讨论,我们甚至可以通过设置D_t = 0来忽略跳跃连接。

让我们将SSM表示为单个矩阵乘法:

其中

M是一个下三角矩阵,类似于我们之前讨论的注意力掩码。

这种类型的矩阵具有一个重要的优势:

一个L × L的下三角矩阵,如果其元素可以以这种方式表示,则可以使用O(rL)的内存存储,并且具有O(rL)的矩阵-向量乘法复杂度,而不是默认的O(L²)。

这意味着每个状态空间模型都对应一个结构化的注意力掩码M,可以在具有线性化注意力的高效Transformer模型中使用。

即使没有周围的查询-键-值机制,半可分离矩阵M本身已经相当复杂和富有表现力。它本身可能就是一个掩码注意力机制。我们将在下一节中详细探讨这一点。

状态空间对偶性

在这里,我们将介绍Mamba 2论文中的一个核心结果。

让我们再次考虑y = Mu,其中u = u(x)是输入的函数,M是一个可分离矩阵。如果我们考虑一个非常特殊的情况,其中每个A_t都是一个标量矩阵:A_t = a_t I。在这种情况下公式变得特别简单:

这里的

只是一个标量。还可以将C_i和B_i堆叠成矩阵B和C,使得:

现在我们还需要定义矩阵

然后就可以很容易地验证:

这个表达式是否看起来很熟悉?这实际上是一个掩码注意力机制,其中:

  • G作为掩码
  • C作为查询矩阵Q
  • B作为转置的键矩阵K^T
  • u作为值矩阵V

在经典的SSM中,B和C是常量。但在Mamba模型中,它们被设计为依赖于数据,这进一步强化了与注意力机制的对应关系。这种特定状态空间模型与掩码注意力之间的对应关系在Mamba 2论文中被称为状态空间对偶性

进一步探索

使用矩阵混合器而不是更复杂的架构并不是一个全新的idea。一个早期的例子是是MLP-Mixer,它在计算机视觉任务中使用MLP而不是卷积或注意力来进行空间混合。

尽管当前研究主要集中在大语言模型(LLM)上,但也有一些论文提出了用于编码器模型的非Transformer、矩阵混合架构。例如:

  1. 来自Google研究的FNet,其矩阵混合器M基于傅里叶变换。
  2. Hydra,除了其他创新外,还提出了半可分离矩阵在非因果(非三角)工作模式下的适应性方案。

总结

本文深入探讨了Transformer、循环神经网络(RNN)和状态空间模型(SSM)之间的潜在联系。文章首先回顾了传统的掩码自注意力机制,然后引入了线性化注意力的概念,解释了其计算效率优势。接着探讨了注意力掩码的优化,引入了半可分离矩阵的概念,并阐述了其与状态空间模型的关系。最后介绍了状态空间对偶性,揭示了特定状态空间模型与掩码注意力之间的对应关系。通过这些分析,展示了看似不同的模型架构之间存在深层联系,为未来模型设计和跨架构思想交流提供了新的视角和可能性。

https://avoid.overfit.cn/post/cc1b1bb7816b412790e9224484cd5b56

作者:Stanislav Fedotov

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2118047.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

红队C2工具Sliver探究与免杀

吉祥知识星球http://mp.weixin.qq.com/s?__bizMzkwNjY1Mzc0Nw&mid2247485367&idx1&sn837891059c360ad60db7e9ac980a3321&chksmc0e47eebf793f7fdb8fcd7eed8ce29160cf79ba303b59858ba3a6660c6dac536774afb2a6330&scene21#wechat_redirect 《网安面试指南》…

【QCA(定性比较分析)组态研究】01 基础入门

【目录】 1.理论入门1.1和个案分析的区别1.2 QCA的特点因果非对称:殊途同归:1.3 什么时候用到QCA2.QCA的一般步骤3.QCA论文精读1.理论入门 QCA(定性比较分析)是一种探索性研究方法,旨在通过系统地比较不同案例的条件组合,识别出影响结果的因果关系。它结合了定性和定量分…

HarmonyOS Next系列之实现一个左右露出中间大两边小带缩放动画的轮播图(十二)

系列文章目录 HarmonyOS Next 系列之省市区弹窗选择器实现(一) HarmonyOS Next 系列之验证码输入组件实现(二) HarmonyOS Next 系列之底部标签栏TabBar实现(三) HarmonyOS Next 系列之HTTP请求封装和Token…

ASP.NET Core 中间件

一、什么是中间件? 中间件 是一种装配到 ASP.NET Core 应用程序请求处理管道中的软件组件,用于处理 HTTP 请求和响应。 每个中间件组件可以: 选择是否将请求传递到下一个中间件:通过调用 next() 或者不调用 next() 来决定是否将…

HTML5中的数据存储sessionStorage、localStorage

第8章 HTML5中的数据存储 之前通常使用Cookie存储机制将数据保存在用户的客户端。 H5增加了两种全新的数据存储方式:Web Stroage和Web SQL Database. 前者用于临时或永久保存客户端少量数据,后者是客户端本地化的一套数据库系统。 8.1 Web Storage存…

日本“大米荒”持续!政府再次拒绝投放储备米

KlipC报道:日本多地从7月开始出现“大米荒”,有部分新米上市,但是许多超市的大米仍然存在断购或限购的情况,并且部分新米价格上涨至去年同期的两倍。大阪府官员再次呼吁日本中央政府尽快投放储备米以缓解供应紧张,但遭…

Dynamics CRM Ribbon Workbench-the solution contains non-entity components

今天在一个低版本的环境里准备用Ribbon Workbench去编辑一个按钮时,遇到了如下错误 一开始没当回事,以为是我的解决方案问题,去检查了下,只有一个组件,并且哪怕我把组件换成了某个实体也不行,尝试了其他任何…

开源NAS系统-OpenMediaVault(OMV)共享存储网盘搭建和使用(保姆级教程)

1、OpenMediaVault简介 OpenMediaVault,简称:OMV,是由原 FreeNAS 核心开发成员 Volker Theile 发起的基于 Debian Linux 的开源 NAS 操作系统,主要面向家庭用户和小型办公环境。 OpenMediaVault是一款基于Debian Linux的开源网络附加存储(NAS)操作系统,它提供了强大的存…

酒店智能轻触开关:智慧化的创新实践

在追求高品质住宿体验的今天,酒店智能轻触开关作为智慧酒店建设的关键一环,正逐步成为提升酒店服务品质、优化运营效率、增强顾客满意度的有力工具。本文将深入探讨酒店智能轻触开关如何助力酒店实现智慧化管理,以及它所带来的多重变革。 一、…

大模型时代下,nlp初学者需要怎么入门?

前言 自从 ChatGPT 横空出世以来,自然语言处理(Natural Language Processing,NLP)研究领域就出现了一种消极的声音,认为大模型技术导致 NLP “死了”。 有人认为 NLP 的市场肯定有,但 NLP 的研究会遇到麻…

图片产生3D模型

HyperHuman 上传图片,点击生成 可以多生成几次,点击应用 让效果再好一点 生成完成之后可以导出为fbx格式

实战|等保2.0 Oracle数据库测评过程

以下等保测评过程以Oracle 11g为例,通过PL/SQL进行管理,未进行任何配置、按照等保2.0标准,2021报告模板,三级系统要求进行测评。 一、身份鉴别 a) 应对登录的用户进行身份标识和鉴别,身份标识具有唯一性,…

E212: Can‘t open file for writing

如图 1. 查看当前用户的用户名和所属组 如果你只想查看当前登录用户的用户名和所属组,可以使用以下命令: whoami groups 检查文件和目录权限: ls -ld /private/var/log/wyhy ls -l /private/var/log/wyhy/market.log 修改文件权限&#…

RAKsmart美国大带宽服务器租用体验怎么样?

RAKsmart是一家提供全球服务器租用服务的知名供应商,其在美国的服务器产品种类多样,包括大带宽服务器、多IP站群服务器以及高防御服务器等,以适应不同业务的需求。rak小编为您整理发布。 下面是对RAKsmart美国大带宽服务器租用的具体介绍&…

计算机毕业设计 大学志愿填报系统 Java+SpringBoot+Vue 前后端分离 文档报告 代码讲解 安装调试

🍊作者:计算机编程-吉哥 🍊简介:专业从事JavaWeb程序开发,微信小程序开发,定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事,生活就是快乐的。 🍊心愿:点…

随机分类,保持均衡水平Python

1、目的: 10000个样本有4个指标,按照逾期金额分10类,确保每类别逾期金额均衡。 2、数据: 3、思路: 将10000个样本按照逾期金额排序, 等距分箱为2500个类别 增加一列随机数 根据类别和随机数升序排列 增加…

MCU6.用keil新建项目

1.新建项目 打开keil4 2.选择单片机的类型 STC并没有出现在其中,但兼容8051芯片,选Atmel的AT89C51或AT89C52均可 本文选AT89C52 弹出的窗口点否 3.查看项目 4.新建文件 5.保存文件 6.将文件添加到工程 双击Source Group 1 点击Add 7.添加已有的工程 如果要添加已有的工程 8…

Java并发编程实战 09 | 为什么需要

什么是守护线程? 守护线程(Daemon Thread)是Java中的一种特殊线程,那么相对于普通线程它有什么特别之处呢? 在了解守护线程之前,我们先来思考一个问题:JVM在什么情况下会正常退出?…

腾讯公众号种类这么多,为什么小程序能脱颖而出

在微信公众平台中,公众号和小程序是两种不同的功能实体,它们各自承担着不同的角色和使命。然而,随着小程序的崛起,它在众多功能中逐渐脱颖而出,成为商家和开发者的新宠。具体分析如下: 技术优势与用户体验 …

深入探索协同过滤:从原理到推荐模块案例

文章目录 前言一、协同过滤1. 基于用户的协同过滤(UserCF)2. 基于物品的协同过滤(ItemCF)3. 相似度计算方法 二、相似度计算方法1. 欧氏距离2. 皮尔逊相关系数3. 杰卡德相似系数4. 余弦相似度 三、推荐模块案例1.基于文章的协同过…