【自然语言处理】【大模型】RWKV:基于RNN的LLM

news2024/10/7 12:24:52

相关博客
【自然语言处理】【大模型】RWKV:基于RNN的LLM
【自然语言处理】【大模型】CodeGen:一个用于多轮程序合成的代码大语言模型
【自然语言处理】【大模型】CodeGeeX:用于代码生成的多语言预训练模型
【自然语言处理】【大模型】LaMDA:用于对话应用程序的语言模型
【自然语言处理】【大模型】DeepMind的大模型Gopher
【自然语言处理】【大模型】Chinchilla:训练计算利用率最优的大语言模型
【自然语言处理】【大模型】大语言模型BLOOM推理工具测试
【自然语言处理】【大模型】GLM-130B:一个开源双语预训练语言模型
【自然语言处理】【大模型】用于大型Transformer的8-bit矩阵乘法介绍
【自然语言处理】【大模型】BLOOM:一个176B参数且可开放获取的多语言模型
【自然语言处理】【大模型】PaLM:基于Pathways的大语言模型
【自然语言处理】【chatGPT系列】大语言模型可以自我改进
【自然语言处理】【ChatGPT系列】FLAN:微调语言模型是Zero-Shot学习器
【自然语言处理】【ChatGPT系列】ChatGPT的智能来自哪里?

RWKV:基于RNN的LLM

​ 基于Transformer的LLM已经取得了巨大的成功,但是其在显存消耗和计算复杂度上都很高。RWKV是一个基于RNN的LLM,其能够像Transformer那样高效的并行训练,也能够像RNN那样高效的推理。

一、背景知识

1. RNN

​ RNN是指一类神经网络模型结构,其中最具有代表性的是LSTM:
f t = σ g ( W f x t + U f h t − 1 + b f ) i t = σ g ( W i x t + U i h t − 1 + b i ) o t = σ g ( W o x t + U o h t − 1 + b o ) c ~ t = σ c ( W c x t + U c h t − 1 + b c ) c t = f t ⊙ c t − 1 + i t ⊙ c ~ t h t = o t ⊙ σ h ( c t ) \begin{align} f_t&=\sigma_g(W_fx_t+U_f h_{t-1}+b_f) \tag*{(1)} \\ i_t&=\sigma_g(W_ix_t+U_i h_{t-1}+b_i) \tag*{(2)} \\ o_t&=\sigma_g(W_ox_t+U_o h_{t-1}+b_o) \tag*{(3)} \\ \tilde{c}_t&=\sigma_c(W_cx_t+U_c h_{t-1}+b_c) \tag*{(4)} \\ c_t&=f_t\odot c_{t-1}+i_t\odot\tilde{c}_t \tag*{(5)} \\ h_t&=o_t\odot\sigma_h(c_t) \tag*{(6)} \end{align} \\ ftitotc~tctht=σg(Wfxt+Ufht1+bf)=σg(Wixt+Uiht1+bi)=σg(Woxt+Uoht1+bo)=σc(Wcxt+Ucht1+bc)=ftct1+itc~t=otσh(ct)(1)(2)(3)(4)(5)(6)
其中, x t x_t xt是当前时间步的输入, h t − 1 h_{t-1} ht1是上一个时间步的隐藏状态,所有的 W W W U U U b b b都是可学习参数, σ \sigma σ表示 sigmoid \text{sigmoid} sigmoid函数。 f t f_t ft是“遗忘门”,用来控制前一个时间步上传递信息的比例; i t i_t it是“输入门”,用于控制当前时间步保留的信息比例; o t o_t ot是"输出门",用于产生最终的输出。

2. Transformers和AFT

​ Transformer是NLP中主流的一种模型架构,其依赖于注意力机制来捕获所有输入和输出tokens的关系:
Attn ( Q , K , V ) = softmax ( Q K ⊤ ) V (7) \text{Attn}(Q,K,V)=\text{softmax}(QK^\top)V \tag{7} \\ Attn(Q,K,V)=softmax(QK)V(7)
为了简洁,这里忽略了多头和缩放因子 1 d k \frac{1}{\sqrt{d_k}} dk 1 Q K ⊤ QK^\top QK是序列中每个token之间的成对注意力分数,其能够被分解为向量表示:
Attn ( Q , K , V ) t = ∑ i = 1 T e q t ⊤ k i ∑ i = 1 T e q t ⊤ k i v i = ∑ i = 1 T e q t ⊤ k i v i ∑ i = 1 T e q t ⊤ k i (8) \text{Attn}(Q,K,V)_t=\sum_{i=1}^T\frac{e^{q_t^\top k_i}}{\sum_{i=1}^T e^{q_t^\top k_i}}v_i=\frac{\sum_{i=1}^T e^{q_t^\top k_i}v_i}{\sum_{i=1}^T e^{q_t^\top k_i}}\tag{8} \\ Attn(Q,K,V)t=i=1Ti=1Teqtkieqtkivi=i=1Teqtkii=1Teqtkivi(8)
在AFT中,设计了一种注意力变体:
Attn + ( W , K , V ) t = ∑ i = 1 t e w t , i + k i v i ∑ i = 1 t e w t , i + k i (9) \text{Attn}^+(W,K,V)_t=\frac{\sum_{i=1}^t e^{w_{t,i}+k_i}v_i}{\sum_{i=1}^t e^{w_{t,i}+k_i}} \tag{9} \\ Attn+(W,K,V)t=i=1tewt,i+kii=1tewt,i+kivi(9)
其中, { w t , i } ∈ R T × T \{w_{t,i}\}\in R^{T\times T} {wt,i}RT×T是可学习的位置偏差,每个 w t , i w_{t,i} wt,i是一个标量。

​ 受AFT启发,在RWKV中的 w t , i w_{t,i} wt,i是一个乘以相对位置的时间衰减向量:
w t , i = − ( t − i ) w (10) w_{t,i}=-(t-i)w \tag{10} \\ wt,i=(ti)w(10)
其中, w ∈ ( R ≥ 0 ) d w\in (R_{\geq 0})^d w(R0)d d d d是通道数。这里需要 w w w是非负来保证 e w t , i ≤ 1 e^{w_{t,i}}\leq 1 ewt,i1并且每个信道随时间衰减。

二、RWKV(Receptance Weighted Key Value)

在这里插入图片描述

​ RWKV由一系列的基本Block组成,每个Block则由time-mixing block和channel-mixing block组成的(如上图所示)。
在这里插入图片描述

​ RWKV递归的形式可以看做是当前输入和前一个时间不输入的线性插值,如上图所示。

1. Time-mixing block

​ Time-mixing block的作用同Self-Attention相同,就是提供全局token的交互。细节如下:
r t = W r ⋅ ( μ r x t + ( 1 − u r ) x t − 1 ) k t = W k ⋅ ( μ k x t + ( 1 − u k ) x t − 1 ) v t = W v ⋅ ( μ v x t + ( 1 − μ v ) x t − 1 ) w k v t = ∑ i = 1 t − 1 e − ( t − 1 − i ) w + k i v i + e u + k t v t ∑ i = 1 t − 1 e − ( t − 1 − i ) w + k i + e u + k t o t = W o ⋅ ( σ ( r t ) ⊙ w k v t ) \begin{align} r_t&=W_r\cdot(\mu_rx_t+(1-u_r)x_{t-1}) \tag*{(11)} \\ k_t&=W_k\cdot(\mu_kx_t+(1-u_k)x_{t-1}) \tag*{(12)} \\ v_t&=W_v\cdot(\mu_vx_t+(1-\mu_v)x_{t-1}) \tag*{(13)} \\ wkv_t&=\frac{\sum_{i=1}^{t-1}e^{-(t-1-i)w+k_i}v_i+e^{u+k_t}v_t}{\sum_{i=1}^{t-1}e^{-(t-1-i)w+k_i}+e^{u+k_t}} \tag*{(14)} \\ o_t&=W_o\cdot(\sigma(r_t)\odot wkv_t) \tag*{(15)} \end{align} \\ rtktvtwkvtot=Wr(μrxt+(1ur)xt1)=Wk(μkxt+(1uk)xt1)=Wv(μvxt+(1μv)xt1)=i=1t1e(t1i)w+ki+eu+kti=1t1e(t1i)w+kivi+eu+ktvt=Wo(σ(rt)wkvt)(11)(12)(13)(14)(15)
所有的 μ \mu μ W W W都是可训练参数, r t r_t rt k t k_t kt v t v_t vt是当前输入 x t x_t xt和上一个时间步输入 x t − 1 x_{t-1} xt1的加权投影。

公式(14)中, w w w u u u是可训练参数,分子的第一项 ∑ i = 1 t − 1 e − ( t − 1 − i ) w + k i v i \sum_{i=1}^{t-1}e^{-(t-1-i)w+k_i}v_i i=1t1e(t1i)w+kivi表示前 t − 1 t-1 t1步的加权结果, − ( t − 1 − i ) w + k i -(t-1-i)w+k_i (t1i)w+ki是随相对距离逐步衰减; e u + k t v t e^{u+k_t}v_t eu+ktvt则是当前时间步的结果。

公式(15)中,则通过 σ ( r t ) \sigma(r_t) σ(rt)控制最终输出的比例。

2. Channel-mixing block

​ Channel-mixing block类似于Transformer中的FFN部分,细节如下:
r t = W r ⋅ ( μ r x t − ( 1 − μ r ) x t − 1 ) k t = W k ⋅ ( μ k x t − ( 1 − μ k ) x t − 1 ) o t = σ ( r t ) ⊙ ( W v ⋅ max ⁡ ( k t , 0 ) 2 ) \begin{align} r_t&=W_r\cdot(\mu_rx_t-(1-\mu_r)x_{t-1}) \tag*{(16)} \\ k_t&=W_k\cdot(\mu_kx_t-(1-\mu_k)x_{t-1}) \tag*{(17)} \\ o_t&=\sigma(r_t)\odot(W_v\cdot\max(k_t,0)^2) \tag*{(18)} \\ \end{align} \\ rtktot=Wr(μrxt(1μr)xt1)=Wk(μkxt(1μk)xt1)=σ(rt)(Wvmax(kt,0)2)(16)(17)(18)

三、并行训练和序列解码

​ RWKV可以类似Transformer那样高效的并行。设batch size为B、seq_length为T、channels为d,计算量主要来自于矩阵乘法 W □ , □ ∈ { r , k , v , o } W_\square,\square\in \{r,k,v,o\} W,{r,k,v,o},单层的时间复杂度为 O ( B T d 2 ) O(BTd^2) O(BTd2)。此外,更新注意力分数 w k v t wkv_t wkvt需要顺序扫描,其时间复杂度为 O ( B T d ) O(BTd) O(BTd)。矩阵乘法可以像Transformer那样并行,但是WKV的计算是依赖时间步的,所以只能在其他维度上并行。

​ RWKV具有类似RNN的结构,解码时将 t t t步的输出作为 t + 1 t+1 t+1步的输入。相比于自注意力机制随着序列长度,计算复杂度呈平方次增长,RWKV则是与序列长度呈线性关系。因此,RWKV能够更高效的处理更长的序列。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1017606.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL数据库详解 三:索引、事务和存储引擎

文章目录 1. 索引1.1 索引的概念1.2 索引的作用1.3 如何实现索引1.4 索引的缺点1.5 建立索引的原则依据1.6 索引的分类和创建1.6.1 普通索引1.6.2 唯一索引1.6.3 主键索引1.6.4 组合索引1.6.5 全文索引 1.7 查看索引1.8 删除索引 2. 事务2.1 事务的概念2.2 事务的ACID特性2.2.1…

Java 高频疑难问题系列一

​​​​​​​ 目录 ​编辑​​​​​​​ 1.零长度 2.redis的有序集的排序 3.Unsafe类 4.带资源的try语句 5.Spring如何实现计划任务 6.Java中普通代码块,构造代码块,静态代码块执行顺序 7.MyBatis缓存机制 8.Redis Java 2种类型操作转换 9.CAS底层原理和问题 1…

【数据分享】2006-2021年我国城市级别的市容环境卫生相关指标(20多项指标)

《中国城市建设统计年鉴》中细致地统计了我国城市市政公用设施建设与发展情况,在之前的文章中,我们分享过基于2006-2021年《中国城市建设统计年鉴》整理的2006—2021年我国城市级别的市政设施水平相关指标、2006-2021年我国城市级别的各类建设用地面积数…

【pytorch】模型常用函数(conv2d、linear、loss、maxpooling等)

1、二维卷积函数——cnv2d(): in_channels (int): 输入通道数 out_channels (int): 输出通道数 kernel_size (int or tuple): 卷积核大小 stride (int or tuple, optional): 步长 Default: 1 padding (int, tuple or str, optional): 填充 Default: 0 padding_mode (str, optio…

计算机是如何工作的下篇

操作系统(Operating System ) 操作系统是一组做计算机资源管理的软件的统称。目前常见的操作系统有:Windows系列、Unix系列、Linux系列、OSX系列、Android系列、iOS系列、鸿蒙等. 操作系统由两个基本功能: 对下,要管理硬件设备. 对上,要给…

数据标注赋能机器学习进行内容审核

数据标注一直以来都是人工智能的基础,是机器学习得以训练的不可或缺的步骤。随着互联网的兴起,如何创建和维护一个健康的网络环境将成为互联网平台不断解决的问题,但对于与日俱增的用户增长和铺天盖地的网络信息,人工审核内容变得…

【牛客网】BC146 添加逗号

一.题目描述 牛客网题目链接:添加逗号_牛客题霸_牛客网 描述: 对于一个较大的整数 N(1<N<2,000,000,000) 比如 980364535&#xff0c;我们常常需要一位一位数这个数字是几位数&#xff0c;但是如果在这 个数字每三位加一个逗号&#xff0c;它会变得更加易于朗读。 因此&a…

指针扩展之——函数指针

前言&#xff1a;小伙伴们好久不见&#xff0c;本篇文章我们继续讲解一个指针的扩展——函数指针。 一.何为函数指针 我们通过对指针的学习已经知道&#xff0c;凡是叫什么什么指针的&#xff0c;都是指指向这个东西的指针。 所以所谓函数指针&#xff0c;也就是指向函数的指…

001 linux 导学

前言 本文建立在您已经安装好linux环境后&#xff0c;本文会向您介绍Shell的一些常用指令 什么是linux Linux是一种自由和开放源代码的类UNIX操作系统&#xff0c;该操作系统的内核由林纳斯托瓦兹在1991年首次发 布&#xff0c;之后&#xff0c;在加上用户空间的应用程序之后…

TypeScript 从入门到进阶之基础篇(一) ts类型篇

系列文章目录 文章目录 系列文章目录前言一、安装必要软件二、TypeScript 基础类型1.基础类型之 数字类型 number2.基础类型之 字符串类型 string3.基础类型之 布尔类型 boolean4.基础类型之 空值类型 void5.基础类型之 null 、undefined类型6.基础类型之 任意类型 any &#x…

Dell戴尔笔记本电脑灵越系列Inspiron 5598原厂Windows10系统2004

戴尔灵越原装出厂系统自带显卡、声卡、蓝牙、网卡等所有驱动、出厂主题壁纸、系统属性戴尔专属LOGO标志、Office办公软件、MyDell等预装程序 链接&#xff1a;https://pan.baidu.com/s/1VYUa7u0-Az4c9bOnWV9GZQ?pwd550m 提取码&#xff1a;550m

常见的查找算法以及分块搜索算法的简明教程

顺序查找 最基本的查找算法 举例 // 顺序查找public static int searchSequence(int[] arr, int target) {int i 0;for (int arr2 : arr) {if (arr2 target) {return i;}i;}return -1;}二分查找 [! warning] 值得注意的是这个二分查找算法只对无重复元素的递增或递减的数组有…

常用的辅助类(必会)

1.CountDownLatch package com.kuang.add;import java.util.concurrent.CountDownLatch;//计数器 减法 public class CountDownLatchDemo {public static void main(String[] args) throws InterruptedException {//总数是6&#xff0c;必须要执行任务的时候&#xff0c;再使用…

ARM接口编程—ADC(exynos 4412平台)

ADC简介 ADC ADC(Analog to Digital Converter)即模数转换器&#xff0c;指一个能将模拟信号转化为数字信号的电子元件 ADC主要参数 分辨率 ADC的分辨率一般以输出二进制数的位数来表示&#xff0c;当最大输入电压一定时&#xff0c;位数越高&#xff0c;分辨率越高&#xf…

全国职业技能大赛云计算--高职组赛题卷①(私有云)

全国职业技能大赛云计算--高职组赛题卷①&#xff08;私有云&#xff09; 第一场次题目&#xff1a;OpenStack平台部署与运维任务1 基础运维任务&#xff08;5分&#xff09;任务2 OpenStack搭建任务&#xff08;15分&#xff09;任务3 OpenStack云平台运维&#xff08;15分&am…

LC142. 环形链表 II

题目大意 给你一个链表&#xff0c;要求判断是否有环&#xff0c;若有环&#xff0c;找出环的入口结点。 142. 环形链表 II 判断是否有环 判环比较简单&#xff0c;用一个一次走一个结点的快指针&#xff0c;和一个一次走一个结点的慢指针同时遍历链表&#xff0c;若两指针相…

第一个Three的demo实例

Three的第一个Demo 前言效果图1、导入threejs2、创建场景3、创建相机4、创建渲染器5、创建几何图形6、创建材质7、创建网格8、将网格添加到场景中9、设置相机的位置10、渲染11、整体代码 前言 创建第一个demo实例—旋转的方格 效果图 1、导入threejs import * as THREE from…

漏洞赏金猎人开源工具集合,自动辅助渗透测试工具

漏洞赏金猎人开源工具集合&#xff0c;自动辅助渗透测试工具。 公开收集的一个国外白帽子用的比较多的开源工具列表 这是一款半自动渗透测试的工具&#xff0c;当前版本多用于渗透测试的信息搜集&#xff0c;每周保持更新&#xff0c;最终的目标是类似于linpeas的全自动渗透测…

直接插入排序、希尔排序详解。及性能比较

直接插入排序、希尔排序详解。及性能比较 一、 直接插入排序1.1 插入排序原理1.2 代码实现1.3 直接插入排序特点总结 二、希尔排序 ( 缩小增量排序 )2.1 希尔排序原理2.2 代码实现2.3 希尔排序特点总结 三、直接插入排序和希尔排序性能大比拼 !!!3.1 如何对比性能&#xff1f;准…

4款视频号数据分析平台!

很多人在做视频号的时候就会有创作参考的需求&#xff0c;那么你们知道视频号中有哪些数据平台&#xff1f;今天就和大家来分享一下 接下来就总结一下视频号数据平台有哪些&#xff1f;排名不分前后。 1&#xff1a;视频号助手&#xff08;channels.weixin.qq.com&#xff09…