RWKV – transformer 与 RNN 的强强联合

news2024/12/26 14:50:24

在 NLP (Natural Language Processing, 自然语言处理) 领域,ChatGPT 和其他的聊天机器人应用引起了极大的关注。每个社区为构建自己的应用,也都在持续地寻求强大、可靠的开源模型。自 Vaswani 等人于 2017 年首次提出 Attention Is All You Need 之后,基于 transformer 的强大的模型一直在不断地涌现,它们在 NLP 相关任务上的表现远远超过基于 RNN (Recurrent Neural Networks, 递归神经网络) 的 SoTA 模型,甚至多数认为 RNN 已死。而本文将介绍一个集 RNN 和 transformer 两者的优势于一身的全新网络架构——RWKV!现已在 HuggingFace transformers 库中支持。

RWKV 项目概览

RWKV 项目已经启动,由 Bo Peng 主导、贡献和维护。同时项目成员在官方 Discord 也开设了不同主题的讨论频道: 如性能 (RWKV.cpp、量化等),扩展性 (数据集收集和处理),相关研究 (chat 微调、多模态微调等)。该项目中训练 RWKV 模型所需的 GPU 资源由 Stability AI 提供。

读者可以加入 官方 discord 频道 了解详情或者参与讨论。如想了解 RWKV 背后的思想,可以参考这两篇博文:

  • https://johanwind.github.io/2023/03/23/rwkv_overview.html

  • https://johanwind.github.io/2023/03/23/rwkv_details.html

Transformer 与 RNN 架构对比

RNN 架构是最早广泛用于处理序列数据的神经网络架构之一。与接收固定输入尺寸的经典架构不同,RNN 接收当前时刻的 “token”(即数据流中的当前数据点) 和先前时刻的 “状态” 作为输入,通过网络预测输出下一时刻的 “token” 和  “状态”,同时输出的 “状态” 还能继续用到后续的预测中去,一直到序列末尾。RNN 还可以用于不同的 “模式”,适用于多种不同的场景。参考 Andrej Karpathy 的博客,RNN 可以用于: 一对一 (图像分类),一对多 (图像描述),多对一 (序列分类),多对多 (序列生成),等等。

01956ab1833a25c094215c04bb691f08.png

由于 RNN 在计算每一时刻的预测值时使用的都是同一组网络权重,因此 RNN 很难解决长距离序列信息的记忆问题,这一定程度上也是训练过程中梯度消失导致的。为解决这个问题,相继有新的网络架构被提出,如 LSTM 或者 GRU,其中 transformer 是已被证实最有效的架构。

在 transformer 架构中,不同时刻的输入 token 可以在 self-attention 模块中并行处理。首先 token 经过 Q、K、V 权重矩阵做线性变换投影到不同的空间,得到的 Q、K 矩阵用于计算注意力分数 (通过 softmax,如下图所示),然后乘以 V 的隐状态得到最终的隐状态,这种架构设计可以有效缓解长距离序列问题,同时具有比 RNN 更快的训练和推理速度。

ca3d4c95db415d05bb1173181210ba4d.png 56d3e820fa1a0b4ad9204297f960beef.png

在训练过程中,Transformer 架构相比于传统的 RNN 和 CNN 有多个优势,最突出的优势是它能够学到上下文特征表达。不同于每次仅处理输入序列中一个 token 的 RNN 和 CNN,transformer 可以单次处理整个输入序列,这种特性也使得 transformer 可以很好地应对长距离序列 token 依赖问题,因此 transformer 在语言翻译和问答等多种任务中表现非常亮眼。

在推理过程中,RNN 架构在推理速度和内存效率方面会具有一些优势。例如计算简单 (只需矩阵 - 向量运算) 、内存友好 (内存不会随着推理阶段的进行而增加),速度稳定 (与上下文窗口长度一致,因为 RNN 只关注当前时刻的 token 和状态)。

RWKV 架构

RWKV 的灵感来自于 Apple 公司的 Attention Free Transformer。RWKV 该架构经过精心简化和优化,可以转换为 RNN。除此此外,为使 RWKV 性能媲美 GPT,还额外使用了许多技巧,例如 TokenShift 和  SmallInitEmb (使用的完整技巧列表在 官方 GitHub 仓库的 README 中 说明)。对于 RWKV 的训练,现有的项目仓库可以将参数量扩展到 14B,并且迭代修了 RWKV-4 的一些训练问题,例如数值不稳定性等。

RWKV 是 RNN 和 Transformer 的强强联合

如何把 transformer 和 RNN 优势结合起来?基于 transformer 的模型的主要缺点是,在接收超出上下文长度预设值的输入时,推理结果可能会出现潜在的风险,因为注意力分数是针对训练时的预设值来同时计算整个序列的。

RNN 本身支持非常长的上下文长度。即使在训练时接收的上下文长度有限,RNN 也可以通过精心的编码,来得到数百万长度的推理结果。目前,RWKV 模型使用上下文长度上为 8192 ( ctx8192) 和  ctx1024 时的训练速度和内存需求均相同。

传统 RNN 模型的主要缺陷,以及 RWKV 是如何避免的:

  1. 传统的 RNN 模型无法利用很长距离的上下文信息 (LSTM 用作语言模型时也只能有效处理约 100 个 token),而 RWKV 可以处理数千个甚至更多的 token,如下图所示:

95a84561bddaea646dd85a48f209cf70.png
  1. 传统的 RNN 模型无法并行训练,而 RWKV 更像一个 “线性 GPT”,因此比 GPT 训练得更快。

通过将这两个优势强强联合,希望 RWKV 可以实现 “1 + 1 > 2” 的效果。

RWKV 注意力公式

RWKV 模型架构与经典的 transformer 模型架构非常相似 (例如也包含 embedding 层、Layer Normalization、用于预测下一 token 的因果语言模型头、以及多个完全相同的网络层等),唯一的区别在于注意力层,它与传统的 transformer 模型架构完全不同,因此 RWKV 的注意力计算公式也不一样。

本文不会对注意力层过多的介绍,这里推荐一篇 Johan Sokrates Wind 的博文,里面有对注意力层的分数计算公式等更全面的解释。

现有检查点

纯语言模型: RWKV-4 模型

大多数采用 RWKV 架构的语言模型参数量范围从 170M 到 14B 不等。据 RWKV 概述博文 介绍,这些模型已经在 Pile 数据集上完成训练,并进行了多项不同的基准测试,取得了与其他 SoTA 模型表现相当的性能结果。

8565c3be790c98580574562632b3998b.png

指令微调/Chat 版: RWKV-4 Raven

Bo 还训练了 RWKV 架构的 “chat” 版本: RWKV-4 Raven 模型。RWKV-4 Raven 是一个在 Pile 数据集上预训练的模型,并在 ALPACA、CodeAlpaca、Guanaco、GPT4All、ShareGPT 等上进行了微调。RWKV-4 Raven 模型有多个版本,如不同语言 (仅英文、英文 + 中文 + 日文、英文 + 日文等) 和不同大小 (1.5B 参数、7B 参数、14B 参数) 等。

所有 HF 版的模型都可以在 Hugging Face Hub 的 RWKV 社区主页 找到。

集成 🤗 Transformers 库

感谢这个 Pull Request 的贡献,RWKV 架构现已集成到 🤗 transformers 库中。在作者撰写本文之时,您已经可以通过从源代码安装 transformers 库,或者使用其 main 分支。RWKV 架构也会与 transformers 库一起更新,您可以像使用任何其他架构一样使用它。

下面让我们来看一些使用示例。

文本生成示例

要在给定 prompt 的情况下生成文本,您可以使用 pipeline:

from transformers import pipeline
model_id = "RWKV/rwkv-4-169m-pile"
prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
pipe = pipeline("text-generation", model=model_id)
print(pipe(prompt, max_new_tokens=20))
>>> [{'generated_text': '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.\n\nThe researchers found that the dragons were able to communicate with each other, and that they were'}]

或者可以运行下面的代码片段:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
prompt = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
inputs = tokenizer(prompt, return_tensors="pt")
output = model.generate(inputs["input_ids"], max_new_tokens=20)
print(tokenizer.decode(output[0].tolist()))
>>> In a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.\n\nThe researchers found that the dragons were able to communicate with each other, and that they were

使用 Raven 模型 (chat 模型) 示例

您可以以 alpaca 风格使用提示 chat 版模型,示例如下:

from transformers import AutoTokenizer, AutoModelForCausalLM
model_id = "RWKV/rwkv-raven-1b5"
model = AutoModelForCausalLM.from_pretrained(model_id).to(0)
tokenizer = AutoTokenizer.from_pretrained(model_id)
question = "Tell me about ravens"
prompt = f"### Instruction: {question}\n### Response:"
inputs = tokenizer(prompt, return_tensors="pt").to(0)
output = model.generate(inputs["input_ids"], max_new_tokens=100)
print(tokenizer.decode(output[0].tolist(), skip_special_tokens=True))
>>> ### Instruction: Tell me about ravens
### Response: RAVENS are a type of bird that is native to the Middle East and North Africa. They are known for their intelligence, adaptability, and their ability to live in a variety of environments. RAVENS are known for their intelligence, adaptability, and their ability to live in a variety of environments. They are known for their intelligence, adaptability, and their ability to live in a variety of environments.

据 Bo 所述,这条 discord 消息 (访问超链接时请确保已加入 discord 频道)  中有更详细的书写指令技巧。

7401ac9db22e52885182fc221b24147d.png

权重转换

任何用户都可以使用 transformers 库中提供的转换脚本轻松地将原始 RWKV 模型权重转换为 HF 格式。具体步骤为: 首先,将 “原始” 权重 push 到 Hugging Face Hub (假定目标仓库为 RAW_HUB_REPO,目标权重文件为 RAW_FILE),然后运行以下转换脚本:

python convert_rwkv_checkpoint_to_hf.py --repo_id RAW_HUB_REPO --checkpoint_file RAW_FILE --output_dir OUTPUT_DIR

如果您想将转换后的模型 push 到 Hub 上 (假定推送目录为 dummy_user/converted-rwkv),首先请确保在 push 模型之前使用 huggingface-cli login 登录 HF 账号,然后运行:

python convert_rwkv_checkpoint_to_hf.py --repo_id RAW_HUB_REPO --checkpoint_file RAW_FILE --output_dir OUTPUT_DIR --push_to_hub --model_name dummy_user/converted-rwkv

未来工作

多语言 RWKV

Bo 目前正在研究在多语言语料库上训练 RWKV 模型,最近发布了一个新的 多语言分词器。

社区后续研究方向

RWKV 社区非常活跃,致力于几个后续研究方向。项目清单可以在 RWKV 的 discord 专用频道中找到 (访问超链接时请确保已加入 discord 频道)。欢迎加入这个 RWKV 研究频道,以及对 RWKV 的积极贡献!

模型压缩与加速

由于只需要矩阵 - 向量运算,对于非标准化和实验性的计算硬件,RWKV 是一个非常理想的架构选择,例如光子处理器/加速器。

因此自然地,RWKV 架构也可以使用经典的加速和压缩技术 (如 ONNX、4 位/8 位量化等)。我们希望集成了 transformer 的 RWKV 架构能够使更多开发者和从业者受益。

在不久的将来,RWKV 还可以使用 optimum 库提出的加速技术。rwkv.cpp 或 rwkv-cpp-cuda 仓库涉及的其中一些技术在库中已标明。

致谢

我们 Hugging Face 团队非常感谢 Bo 和 RWKV 社区抽出宝贵时间来回答关于架构的问题,以及非常感谢他们的帮助和支持。我们很期待在 HF 生态中看到更多 RWKV 模型的应用。我们还要感谢 Johan Wind 发布的关于 RWKV 的博文,这对我们理解架构本身和其潜力有很大帮助。最后,我们着重感谢 ArEnSc 开启 RWKV 集成到 transformers 库的 PR 所做的工作,以及感谢 Merve Noyan、Maria Khalusova 和 Pedro Cuenca 审阅和校对本篇文章!

引用

如果您希望在工作中使用 RWKV,请使用此 cff 引用。https://github.com/BlinkDL/RWKV-LM/blob/main/CITATION.cff


英文原文: https://hf.co/blog/rwkv

作者: BlinkDL, Harrison Vanderbyl, Sylvain Gugger, Younes Belkada

译者: SuSung-boy

审校/排版: zhongdongy (阿东)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/592168.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SAP-MM-发票行项目格式

目的:SAP提供标准事物代码屏幕变式,但因各个公司运用的方式不同,可采用屏幕变式进行自定义,方便最终用户使用; 方案:采用屏幕变式T-CODE:OLMRLIST对相应事物代码进行调整: 案例&am…

ASCII Unicode UTF-8等等编码介绍

目录 背景 Unicode UTF-8 ISO-8859-1 GB2312和GBK ANSI UTF-16LE 和UTF-16BE UTF-16 LE 和BE是什么 如何处理字节序问题 "带有BOM的UTF-8"又是什么? 背景 由于计算机是美国人发明的,因此最早只有127个字母被编码到计算机中&#x…

1.5k star,搭建一个属于自己或团队的WIKI知识管理系统

项目简介 zyplayer-doc是一款适合团队和个人使用的WIKI文档管理工具,同时还包含数据库文档、Api接口文档。 体验地址:http://zyplayer.com 在线文档:http://doc.zyplayer.com 各模块介绍 zyplayer-doc-manage 文档管理后台 具有项目模块…

ue4技术方向学习路线如何安排?

作为初学者,您可能会感到UE4(Unreal Engine 4)的学习曲线很陡峭。但是,只要您有一个良好的学习路线和资源,就可以轻松地掌握UE4技术方向。下面是一个可以帮助您快速入门UE4技术方向的学习路线安排。 第一阶段&#xff…

【资料分享】TVS二极管

瞬态电压抑制二极管也称为TVS二极管,是一种保护用的电子零件,可以保护电器设备不受导线引入的电压尖峰破坏。 TVS二极管会和要保护的电路并联。 当其电压超过突崩溃准位时,直接分流过多的电流。TVS二极管是箝位器,会抑制超过其崩…

java企业级信息系统开发学习笔记08 初探MyBatis实现简单查询

文章目录 一、学习目标1.了解什么是MyBatis框架2.掌握基于配置文件方式使用MyBatis3.掌握采用接口方式使用MyBatis 二、创建数据库与表1.在数据库中创建一个用户表 三、基于配置文件方式使用MyBatis(一)、创建一个Maven项目(二)、添加相关依赖&#xff0…

Android H5移动调试技巧——Chrome /Edge /QQ Inspect

Android APP项目原生的Html5做,cordova框架技术。苦于调试不便。稍微了解一下浏览器调试插件。 Chrome Inspect 使用该调试技巧,调试移动端中—自身开发的Hybrid APP、微信、UC、QQ、抖音等打开的页面,就像在电脑chrome调试一样方便。 手机…

WSL2 Ubuntu 22.04.02LTS + MobaXterm配置问题

参考了几篇文章: https://zhuanlan.zhihu.com/p/150555651https://zhuanlan.zhihu.com/p/150555651 https://zhuanlan.zhihu.com/p/151853503https://zhuanlan.zhihu.com/p/151853503 https://github.com/QMonkey/wsl-tutorial/issues/11https://github.com/QMon…

# jmeter 实战使用

jmeter 实战使用 官网 Apache JMeter - Apache JMeter™ Apache JMeter应用程序是开源软件, 一个纯Java应用程序设计 加载测试功能行为并测量性能。它是 最初设计用于测试 Web 应用程序,但具有 自扩展到其他测试功能后。 性能测试目的 提高页面响应速度…

1.5 初探Spring AOP

一、提出游吟诗人唱赞歌任务 骑士执行任务前和执行任务后,游吟诗人唱赞歌 1、采用传统方式实现 修改day04子包的勇敢骑士类 修改day04子包里的救美骑士类 执行测试类 - TestKnight 二、采用配置方式使用AOP 1、创建本讲所需子包 在net.hf.spring包里创建da…

Hibernate框架【三】——基本映射——一对一映射

系列文章目录 一对一映射 系列文章目录前言一、一对一映射是什么&#xff1f;二、一对一单向关联①一对一主键关联映射②一对一唯一外键映射 三、一对一双向关联①一对一主键关联映射&#xff08;双向关联Person<----->IdCard&#xff09;②h一对一唯一外键关联映射&…

3ds MAX 绘制简易沙发

目的是用3dMAX绘制简易的沙发 用到了长方体、圆角长方体、平移、旋转、阵列、组等操作 首先按照前面的教程创建一个圆角长方体&#xff0c;参数如图&#xff1a; 接着用圆柱体绘制靠背的支柱&#xff0c;为了方便使用阵列一次绘制出十个 我们设置的沙发长度为1&#xff4d;这…

Kibana:为 Kibana 定制自己的品牌 - 8.8

我还记得我刚加入 Elastic 的时候&#xff0c;那个时候在论坛里看到一个贴上是关于如何替换 Kibana 中的 Logo 从而重新定制自己的品牌。在今天&#xff0c;我非常高兴地告诉你&#xff0c;在 Elastic Stack 8.8 中&#xff0c;这个是完全可以通过配置的方式来实现的。 安装 如…

软考A计划-试题模拟含答案解析-卷十三

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例 &#x1f449;关于作者 专注于Android/Unity和各种游戏开发技巧&#xff0c;以及各种资源分享&am…

JS中数组22种常用API总结,slice、splice、map、reduce、shift、filter、indexOf......

大厂面试题分享 面试题库 前后端面试题库 &#xff08;面试必备&#xff09; 推荐&#xff1a;★★★★★ 地址&#xff1a;前端面试题库 web前端面试题库 VS java后端面试题库大全 一、引言 在前端开发中&#xff0c;数组是一种常见且重要的数据结构。数组提供了许多便捷的…

【录用案例】1区SCI仅1个月14天录用,新增19篇录用、5篇见刊、6篇检索

2023年5月20日-2023年5月26日&#xff0c;经核实&#xff0c;由我处Unionpub学术推荐的19篇论文已被期刊部录用、5篇见刊、6篇检索&#xff1a; 2区肿瘤类SCI 【期刊简介】IF:4.5-5.0&#xff0c;JCR2区&#xff0c;中科院2区 【检索情况】SCI 在检&#xff0c;正刊 【征稿…

一文包你学会网络数据抓包

本篇将图文并茂教你如何使用抓包工具&#xff0c;并在文章最后教大家如何偷取FTP的用户名密码。 一、安装 本文为大家介绍一个非常好用的抓包工具&#xff0c;科来。 下载地址&#xff1a; http://www.colasoft.com.cn/ 下载科莱 下载完毕&#xff0c;双击直接下一步即可安…

msvcr71.dll丢失的解决方法,多种修复方法全方位分享

当我们在使用某些软件时&#xff0c;可能会出现提示“msvcr71.dll丢失”的错误信息。这个错误信息意味着我们的电脑缺少msvcr71.dll文件&#xff0c;这个文件是由Microsoft Visual C 2003运行库提供的。如果我们遇到这个问题&#xff0c;我们需要采取措施来解决它。本文将介绍m…

体验管理|如何快速低成本开始体验相关的数字化工作‼️

Guofu 第 95⭐️ 篇原创文章分享 &#xff08;点击&#x1f446;&#x1f3fb;上方卡片关注我&#xff0c;加⭐️星标⭐️~&#xff09; &#x1f68f; 写在前面 在体验经济时代&#xff0c;传统企业在应对新需求、新挑战的时候&#xff0c;也需要用新的方式进行企业升级和转型…

DailyMart03:如何基于DDD设计商城的领域模型?

大家好&#xff0c;我是飘渺。既然有人催更那今天咱们就继续更新DDD&微服务系列&#xff01; 在面向对象开发中&#xff0c;所有事物都可以看作是对象。然而&#xff0c;在日常开发中&#xff0c;我们通常从数据出发来设计对象的表现形式&#xff0c;这种做法侧重于数据属性…