自称超越Transformer的新一代大模型RWKV是什么

news2024/9/24 17:21:25

论文地址:arxiv.org/pdf/2305.13048v2

项目地址:github

论文题目为:《RWKV: Reinventing RNNs for the Transformer Era》

自 Vaswani 等人于 2017 年首次提出 Attention Is All You Need 之后,基于 transformer 的强大的模型一直在不断地涌现,它们在 NLP 相关任务上的表现远远超过基于 RNN (Recurrent Neural Networks, 递归神经网络) 的 SoTA 模型,甚至多数认为 RNN 已死。而本文将介绍一个集 RNN 和 transformer 两者的优势于一身的全新网络架构 –RWKV!现已在 HuggingFace transformers 库中支持。

1 背景与动机

1.1 背景:简单介绍RNN和Transformer

RNN(Recurrent Neural Network,循环神经网络)

RNN(Recurrent Neural Network,循环神经网络)是一种专门用于处理序列数据的神经网络。它能够处理前后数据点之间的依赖关系,这使得RNN特别适合于处理时间序列数据或者任何形式的序列,如文本、语音或视频。RNN的核心特点是它具有循环连接,这允许网络在处理序列的每个元素时保持一定的“记忆”。

RNN的工作原理是,它在序列的每个时间步骤接收输入,并产生输出,同时将一部分信息传递到下一个时间步骤。这种传递的信息通常被称为“隐藏状态”(hidden state),它能够捕捉序列中过去的信息。RNN的这种结构使得它能够处理可变长度的序列,并且能够处理长期依赖关系。

Token

Token在自然语言处理(NLP)中是一个重要的概念,它指的是文本中的一个基本单位。在不同的上下文中,token可以有不同的含义:

  1. 单词或词汇单元:在一些NLP任务中,token可以是一个单词或者一个字符,这是文本分析的最小单位。
  2. 子词单元:在一些现代的NLP模型中,为了更好地处理词汇的变体和形态变化,会将单词进一步分割成更小的单元,这些单元被称为subword tokens。
  3. 标记化:在文本处理中,将原始文本转换成token序列的过程称为标记化(tokenization)。这是文本预处理的重要步骤,它使得机器学习模型能够理解和处理文本数据。

在RNN中,序列数据通常会先被标记化,转换成一系列的token,然后这些token被用作RNN的输入。RNN通过处理这些token序列,学习序列中的模式和依赖关系,进而用于各种NLP任务,如语言建模、机器翻译、情感分析等。

Transformer

Transformer 是一种深度学习模型,由 Vaswani 等人在 2017 年的论文《Attention Is All You Need》中首次提出。它主要用于处理序列数据,尤其在自然语言处理(NLP)领域取得了革命性的进展。Transformer 模型完全基于注意力机制(Attention Mechanism),摒弃了之前序列模型中常用的循环神经网络结构。

主要特点

  1. 自注意力机制(Self-Attention):Transformer 通过自注意力机制使模型能够在序列中的每个位置都同时考虑其他位置,这有助于捕捉序列内部的长距离依赖关系。自注意力机制的核心是计算序列中每个元素对其他所有元素的注意力分数,然后根据这些分数对元素进行加权求和。

  2. 并行化处理:由于自注意力机制不依赖于序列中元素之间的循环或递归调用,Transformer 可以高效地并行处理整个序列,这在传统的循环神经网络中是难以实现的。

  3. 编码器-解码器架构:标准的 Transformer 模型由编码器和解码器两个部分组成。编码器处理输入序列,解码器生成输出序列。两部分都由多个相同的层组成,每层都包含自注意力模块和前馈神经网络。

  4. 多头注意力(Multi-Head Attention):Transformer 通过多头注意力机制进一步提升模型的表达能力。它将查询(Query)、键(Key)和值(Value)通过不同的线性投影分割成多个头,然后并行计算每个头的注意力输出,最后将这些输出合并,提供更丰富的信息表示。

  5. 位置编码:由于 Transformer 本身无法捕捉序列中元素的顺序信息,因此需要加入位置编码来提供序列中每个元素的位置信息。位置编码通常是与输入嵌入相加的固定或可学习的向量。

  6. 层正规化(Layer Normalization)残差连接(Residual Connections):Transformer 在每个子层(自注意力层和前馈网络层)的输出上应用层正规化,并使用残差连接,有助于避免深层网络中的梯度消失问题,使得深层网络的训练成为可能。

1.2 动机

部分内容来自拥抱脸的介绍:RWKV -- transformer 与 RNN 的强强联合 (huggingface.co)

由于 RNN 在计算每一时刻的预测值时使用的都是同一组网络权重,因此 RNN 很难解决长距离序列信息的记忆问题,这一定程度上也是训练过程中梯度消失导致的。为解决这个问题,相继有新的网络架构被提出,如 LSTM 或者 GRU,其中 transformer 是已被证实最有效的架构。

在 transformer 架构中,不同时刻的输入 token 可以在 self-attention 模块中并行处理。首先 token 经过 Q、K、V 权重矩阵做线性变换投影到不同的空间,得到的 Q、K 矩阵用于计算注意力分数 (通过 softmax,如下图所示),然后乘以 V 的隐状态得到最终的隐状态,这种架构设计可以有效缓解长距离序列问题,同时具有比 RNN 更快的训练和推理速度。

但是:

  • Transformer的局限性:尽管Transformer在自然语言处理(NLP)任务中取得了革命性的进展,但其自注意力机制的计算复杂度随着序列长度呈二次方增长,这在处理长序列时会导致显著的内存和计算负担。
  • RNN的局限性:RNN虽然在内存和计算需求上呈线性增长,但因为难以并行处理和可扩展性差,通常无法达到与Transformer相同的性能。

因此提出RWKV,RWKV 的灵感来自于 Apple 公司的 Attention Free Transformer。RWKV 该架构经过精心简化和优化,可以转换为 RNN。除此此外,为使 RWKV 性能媲美 GPT,还额外使用了许多技巧,例如 TokenShift 和 SmallInitEmb (使用的完整技巧列表在 官方 GitHub 仓库的 README 中 说明)。对于 RWKV 的训练,现有的项目仓库可以将参数量扩展到 14B,并且迭代修了 RWKV-4 的一些训练问题,例如数值不稳定性等。

2 RWKV架构

2.1 线性注意力机制

RWKV 模型架构与经典的 transformer 模型架构非常相似 (例如也包含 embedding 层、Layer Normalization、用于预测下一 token 的因果语言模型头、以及多个完全相同的网络层等),唯一的区别在于注意力层,它与传统的 transformer 模型架构完全不同,因此 RWKV 的注意力计算公式也不一样。

线性注意力机制是Transformer模型中自注意力机制的一个变体,旨在减少计算复杂度,特别是针对序列长度的二次方增长问题。在标准的Transformer模型中,自注意力的计算复杂度是O(T^2d),其中T是序列长度,d是特征维度。这种计算复杂度在处理长序列时会迅速变得不可行。线性注意力机制通过将复杂度降低到O(Td),使得模型能够更高效地处理长序列。

基本原理

线性注意力机制的核心思想是将传统的点积注意力(dot-product attention)替换为一种更高效的计算方式,同时保持对序列中各个元素间关系的捕捉能力。在点积注意力中,每个元素对其他所有元素的注意力是通过计算它们的点积并应用softmax函数来实现的。而在线性注意力中,这种计算被替换为一种更直接的方法。

计算过程

  1. 键向量和查询向量的变换:首先,对于序列中的每个元素,我们将其表示为查询(Q)、键(K)和值(V)向量。这些向量可以通过输入序列的线性变换得到。

  2. 注意力分数的计算:在传统的自注意力中,注意力分数是通过计算查询和所有键的点积得到的。在线性注意力中,我们使用一种线性化的方法来近似这种点积。一种常见的方法是使用一个可学习的权重向量w来与键向量进行点积,然后将结果与查询向量进行点积,以此来模拟传统的点积注意力:

    Attention(Q,K,V) =softmax(\frac{QK^T}{\sqrt{d_k}})

    在线性注意力中,这个计算可以被近似为:

    Attention(Q,K,V) \approx softmax(QWK^T)V

    其中W是一个可学习的权重矩阵。

  3. 值向量的加权求和:计算完注意力分数后,我们使用这些分数对值向量进行加权求和,得到最终的输出。

优点

  • 计算效率:线性注意力机制显著降低了计算复杂度,从O(T^2d)降低到O(Td),使得模型能够更高效地处理长序列。
  • 内存效率:由于计算复杂度的降低,线性注意力也减少了内存的使用,这对于大规模的模型和长序列尤为重要。

缺点

  • 精度损失:由于线性化近似,线性注意力可能会损失一些精度,尤其是在捕捉序列中复杂依赖关系时。
  • 灵活性限制:相比于传统的自注意力机制,线性注意力在模拟不同元素间复杂交互的能力上可能有所限制。

2.2 模型公式化

RWKV可以被公式化为Transformer或RNN,这使得它在训练时能够并行化计算,并在推理时保持线性复杂度。

模型公式化的组成

  1. Receptance (R): 接收向量,用于捕捉和存储过去的信息。
  2. Weight (W): 位置权重衰减向量,一个可训练的参数,用于模拟时间衰减。
  3. Key (K): 键向量,在传统的注意力机制中用于与查询向量计算关系分数。
  4. Value (V): 值向量,在注意力机制中用于与计算得到的权重相乘,生成输出。

公式化过程

RWKV模型的核心是其独特的注意力机制,即WKV操作符,它将传统的点积注意力替换为一种线性注意力形式。以下是RWKV模型中一些关键的公式化步骤:

  1. 时间混合 (Time Mixing):

    • 时间混合层通过RWKV操作符结合了时间维度上的混合,允许模型在序列的不同时间步之间传递信息。
    • 公式化可以表示为:WKV_t = \sum_{i=1}^{t-1}e^{-w(t-i)}K_i \cdot V_i +e^u K_t \cdot V_t
    • 其中,w 是时间衰减因子,u 是当前时间步的加权因子。
  2. 通道混合 (Channel Mixing):

    • 通道混合层则处理特征维度上的混合,允许模型在不同特征通道之间共享信息。
    • 公式化可以表示为:CWKV_t = max(R_t' \cdot K_t\ ,0)^2
    • 其中,Rt′和 Kt′分别是通道混合层的接收向量和键向量。
  3. 输出门控 (Output Gating):

    • 输出门控通过sigmoid函数控制信息流,增强模型对信息的选择性传递。
    • 公式化可以表示为:O_t = W_o \cdot \sigma (R_t) \cdot WKV_t
    • 其中,σσ 表示sigmoid函数,WoWo​ 是输出权重。
  4. 序列计算 (Sequential Computation):

    • RWKV模型在序列的每个时间步上递归地计算上述操作,从而实现序列的动态处理。
    • 公式化可以表示为:H_t = LayerNorm(H_{t-1}+O_t)
    • 其中,HtHt​ 是第t步的隐藏状态,LayerNorm是层归一化操作。
  1. 传统的 RNN 模型无法并行训练,而 RWKV 更像一个 “线性 GPT”,因此比 GPT 训练得更快。
  2. 传统的 RNN 模型无法利用很长距离的上下文信息 (LSTM 用作语言模型时也只能有效处理约 100 个 token),而 RWKV 可以处理数千个甚至更多的 token

2.3 参数规模

研究者们将RWKV模型的参数规模扩展到140亿,这是迄今为止训练的最大密集RNN,并且发现其性能与同样规模的Transformer相当。

性能与效率

3.1 基准测试

论文通过在多个NLP任务上的测试,展示了RWKV在大规模模型上的性能和效率。

3.2 预训练模型

作者发布了从1.69亿到140亿参数的预训练模型,并在Pile数据集上进行了训练

技术细节

4.1 时间混合和通道混合

RWKV模型由堆叠的残差块组成,每个块包含时间混合和通道混合子块,这些子块利用过去的信息。

一个RWKV块(左)和完整的RWKV剩余块内的元素,配备了一个用于语言建模的最终头部

4.2 RWKV操作符

模型中的WKV操作符与传统的注意力机制相似,但通过相对位置和时间衰减向量来修改,以实现循环行为。

用于语言建模的RWKV架构

4.3 输出门控

在时间混合和通道混合块中使用sigmoid激活函数的输出门控来控制信息流。

训练与推理

5.1 训练阶段

  1. 并行化训练

    • RWKV模型利用Transformer架构的优势,实现训练过程中的并行化。这与传统的RNN不同,后者由于其递归性质,在训练时通常需要逐步迭代,难以实现并行处理。
    • 并行化处理可以显著加快训练速度,尤其是在处理大规模数据集时。
  2. 时间并行模式

    • 时间并行模式允许模型在训练时同时处理序列中的所有元素,这得益于RWKV的线性注意力机制,它不需要在时间步之间交换信息。
    • 这种模式下,模型的计算复杂度主要来自于矩阵乘法操作,这些操作可以很容易地在现代硬件上并行执行。
  3. 优化策略

    • 为了提高训练效率,RWKV模型采用了多种优化策略,包括自定义CUDA内核和小型初始化嵌入等。
    • 自定义CUDA内核可以针对特定的计算任务优化性能,而小型初始化嵌入有助于模型更快地收敛。

5.2 推理阶段

推理阶段是模型将学到的知识应用到新数据上,进行预测或决策的过程。RWKV模型在推理时采用以下策略:

  1. 序列化推理

    • 与训练阶段的并行化不同,RWKV在推理时采用序列化处理,这与RNN的处理方式类似。
    • 在序列化推理中,模型逐个处理序列中的元素,每个元素的输出依赖于之前的计算结果,这使得模型在处理长序列时具有线性的时间复杂度。
  2. 时间序列模式

    • 时间序列模式允许模型在推理时利用其RNN结构的优势,通过递归地更新内部状态来处理序列数据。
    • 这种模式下,模型可以有效地处理长序列,同时保持较低的内存和计算需求。
  3. 输出门控和状态更新

    • RWKV模型在每个时间步使用输出门控机制来控制信息的流动,这有助于模型在推理时更加专注于重要的信息。
    • 状态更新是RWKV模型推理的核心,模型通过更新其内部状态来捕捉序列中的长期依赖关系。
  4. 效率与性能

    • RWKV模型在推理时展现出了高效的性能,这得益于其线性复杂度的计算特性和优化的算法设计。
    • 这种设计使得RWKV模型在处理实际应用中的长序列数据时,既能保持较高的准确率,又能实现快速响应。

优化策略

6.1 自定义内核

  • 为了提高计算效率,特别是在执行WKV操作时,作者开发了自定义的CUDA内核。
  • 这些内核针对特定的计算任务进行了优化,利用GPU的并行处理能力,以加速模型的训练和推理过程。

6.2 小初始化嵌入

  • 在训练的初期阶段,作者采用了小型初始化嵌入的方法,即用较小的值初始化嵌入矩阵。
  • 这种方法有助于模型从初始状态快速收敛,因为它减少了初始阶段的噪声,并允许模型更平稳地开始学习过程。

6.3 时间并行模式

  • RWKV模型在训练时采用时间并行模式,这意味着模型可以同时处理序列中的所有元素。
  • 这种并行化处理减少了训练时间,因为它允许模型在多个时间步上并行执行计算,而不是逐个时间步顺序执行。

总结

  • RWKV为序列数据处理提供了一种新的高效且可扩展的架构,通过线性复杂度的注意力机制和有效的训练动态,展示了与传统Transformer相当的性能。

以上是读论文的内容,简单的记录下这一最新的网络,目前不少研究是基于该框架的,希望有多一个新的浪头 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2110473.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

jmeter同步定时器、固定定时器、统一随机定时器、常数吞吐量定时器详解

一、同步定时器 可以让多个线程同时向服务器发送请求,实现瞬间并发(相当于现实中同步秒杀商品)类似于集合点 例如:10个人约定去旅游,出发前提前会在某一个地方等到10个人同时都到了约定地点之后再一同排队上车 在任意接口下添加同步定时器模…

C#基础(6)值类型和引用类型

前言 我们先前已经完成了数组相关的学习,今天我们就要来详细介绍一下数据类型了。 引用类型是指变量存储的是对象的引用或地址,而不是实际的数据。在引用类型中,变量存储的是指向对象的指针,通过这个指针可以访问对象的实际数据…

电阻负载柜的故障排除方法有哪些?如何解决常见问题?

电阻负载柜是电力系统中的重要设备,主要用于模拟实际负载,对电力设备进行测试和调试。然而,在使用过程中,可能会出现各种故障。以下是一些常见的电阻负载柜故障及其排除方法: 1. 电源无法启动:首先检查电源…

[数据集][目标检测]西红柿成熟度检测数据集VOC+YOLO格式3241张5类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):3241 标注数量(xml文件个数):3241 标注数量(txt文件个数):3241 标注…

Day 8 ~ 9: 队列

队列的原理 定义 定义:限制在两端进行插入操作和删除操作的线性表。 队尾:允许进行存入操作的一端。 对头:允许进行删除操作的一端。 特点 先进先出。 比如:食堂点餐,先进先出,银行叫号,先进先出…

ConfigBus

Config&Bus 构建server端 构建client端 config执行流程 配置git本地库 Server安全连接 Config集成eureka提升可用性 Config client快速检测 动态刷新配置 Bus 配置配置刷新的两种方式 消息通知解析 XXApplicationEvent类型共有属性 AckRemoteApplicationEvent 事件驱动模型…

浅谈SOC片上系统LoRa-STM32WLE5数据安全防御机制

随着物联网设备的普及,数以亿计的设备正在通过无线网络进行通信,传输大量的敏感数据。这种大规模的设备联网带来了便捷性,但也伴随着巨大的安全风险。SoC片上系统通过将无线通信、处理器、存储和安全机制集成在同一个芯片中,为物联…

启动spring boot项目时,第三方jar包扫描不到的问题

讲述一下遇到的问题: 在启动类Application上使用ComponentScan 这个注解来扫描第三方的包,然后就会出现报错。异常就是无法加载本地的bean,但是可以加载到第三方的bean; 了解过spring boot启动流程的都知道,Springboo…

kuka6轴机器人配置外部启动信号(学习记录,可能不对)

文档认为最重要的信号配置 我自己的信号配置(只配红框,输出部分有需要再添加) 外部启动的时序 有个点注意:外部启动后,为了“骗”BCO,需要在main程序的开头写上一段运动指令(走当前位置即可&…

python中的分支语句

注意: 在python中,每一个对象都有一个布尔值, >>>>>> True 或者 False >>>>>> 且只能判断 0 或者 1 举个例子: n % 2 :就是如果结果等于1 才会执行下一句, 所以要判断是偶数…

Kafka【十四】生产者发送消息时的消息分区策略

【1】分区策略 Kafka中Topic是对数据逻辑上的分类,而Partition才是数据真正存储的物理位置。所以在生产数据时,如果只是指定Topic的名称,其实Kafka是不知道将数据发送到哪一个Broker节点的。我们可以在构建数据传递Topic参数的同时&#xff…

GS-SLAM论文阅读笔记--LoopSplat

介绍 这篇文章看标题是解决GS-SLAM回环检测的,GS-SLAM回环检测之前文章很少,但他对于SLAM又很重要,确实值得阅读一番。而且这些作者的学校又是很厉害的。 文章目录 介绍1.背景介绍2.关键内容2.1 Gaussian Splatting SLAM2.2 Gaussian Splat…

C语言之联合体和枚举

目录 前言 一、联合体类型的声明 二、联合体的特点 三、联合体的大小计算 四、联合体的适用场景举例: 五、枚举类型的声明 六、枚举类型的优点 总结 前言 本文主要讲述C语言的两种结构类型:联合体和枚举。 ❤️感谢支持,点赞关注不迷路❤️ 一…

计算polydata相交

使用vtk.vtkBooleanOperationPolyDataFilter() 可以进行求交,差,并操作 并且可以填充交面,不会形成一个缺口 vtkBooleanOperationPolyDataFilter 计算由两个输入表面定义的体积计算出的并集、交集或差集的边界。 这两个表面不需要是流形的…

六,Spring Boot 容器中 Lombok 插件的详细使用,简化配置,提高开发效率

六,Spring Boot 容器中 Lombok 插件的详细使用,简化配置,提高开发效率 文章目录 六,Spring Boot 容器中 Lombok 插件的详细使用,简化配置,提高开发效率1. Lombok 介绍2. Lombok 常用注解2.1 ToString2.2 Se…

数字经济时代,零售企业如何实现以消费者为中心的数字化转型?

在数字经济时代,零售企业正面临着前所未有的挑战与机遇。随着消费者行为的数字化和多样化,传统的零售模式已难以满足市场需求。为了在激烈的市场竞争中立于不败之地,零售企业必须实现以消费者为中心的数字化转型。这一转型不仅仅是技术的升级…

[ios]准备好app后使用xcode发布ios操作

在app代码完成后,点击xcode进行发布

嵌入式开发学习路线(25届校招学习) 嵌入式学习路线七年规划:从大一小白到校招大佬 (学习路线汇总)

嵌入式开发学习路线(25届校招可以参考) 嵌入式系统作为当前最热门且最有发展前途的IT应用领域之一,吸引了大量有志于从事该行业的学习者。为了系统地掌握嵌入式开发技能,以下是一条详细的学习路线,旨在帮助初学者逐步…

算法:图片压缩算法【Z字行扫描】(Java实现)

要在Java中实现Z字形扫描,我们需要遍历一个给定的nn矩阵,并按照Z字形的顺序输出其元素。Z字形扫描的路径通常是从矩阵的左上角开始,沿着对角线方向交替向下和向上移动,直到遍历完整个矩阵。 下面是一个简单的Java实现示例&#xf…

不同vlan之间的通信方法

1.通过路由器的物理接口 1.给PC1,PC2配置IP地址,网关2.进入交换机配置vlan,交换机所有口都配置access口并绑定vlan3.配置路由器,进入路由器的两个接口配置网关IP和掩码缺点:成本高,每增加一个vlan就需要一个物理端口和…