【RWKV】如何新增一个自定义的Tokenizer和模型到HuggingFace

news2025/1/19 20:21:23

在这里插入图片描述

0x0. 前言

RWKV社区在Huggingface上放了rwkv-4-world和rwkv-5-world相关的一系列模型,见:https://huggingface.co/BlinkDL/rwkv-4-world & https://huggingface.co/BlinkDL/rwkv-5-world ,然而这些模型的格式是以PyTorch的格式进行保存的即*.pt文件,并没有将其转换为标准的Huggingface模型。后来了解到这里还有一个问题是RWKV的世界模型系列的tokenizer是自定义的,在Huggingface里面并没有与之对应的Tokenizer。没有标准的Huggingface模型就没办法使用TGI进行部署,也不利于模型的传播以及和其它模型一起做评测等等。

让RWKV world系列模型登陆Huggingface社区是必要的,这篇文章介绍了笔者为了达成这个目标所做的一些努力,最后成功让rwkv-4-world和rwkv-5-world系列模型登陆Huggingface。大家可以在 https://huggingface.co/RWKV 这个空间找到目前所有登陆的RWKV模型:

在这里插入图片描述

本系列的工作都整理开源在 https://github.com/BBuf/RWKV-World-HF-Tokenizer ,包含将 RWKV world tokenizer 实现为 Huggingface 版本,实现 RWKV 5.0 的模型,提供模型转换脚本,Lambda数据集ppl正确性检查工具 等等。

0x1. 效果

以 RWKV/rwkv-4-world-3b 为例,下面分别展示一下CPU后端和CUDA后端的执行代码和效果。

CPU

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-world-3b")
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-world-3b", trust_remote_code=True)

text = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
prompt = f'Question: {text.strip()}\n\nAnswer:'

inputs = tokenizer(prompt, return_tensors="pt")
output = model.generate(inputs["input_ids"], max_new_tokens=256)
print(tokenizer.decode(output[0].tolist(), skip_special_tokens=True))

输出:

Question: In a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.

Answer: The dragons in the valley spoke perfect Chinese, according to the scientist who discovered them.

GPU

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-4-world-3b", torch_dtype=torch.float16).to(0)
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-world-3b", trust_remote_code=True)

text = "你叫什么名字?"
prompt = f'Question: {text.strip()}\n\nAnswer:'

inputs = tokenizer(prompt, return_tensors="pt").to(0)
output = model.generate(inputs["input_ids"], max_new_tokens=40)
print(tokenizer.decode(output[0].tolist(), skip_special_tokens=True))

输出:

Question: 你叫什么名字?

Answer: 我是一个人工智能语言模型,没有名字。

我们可以在本地通过上述代码分别运行CPU/GPU上的wkv-4-world-3b模型,当然这需要安装transformers和torch库。

0x2. 教程

下面展示一下在 https://github.com/BBuf/RWKV-World-HF-Tokenizer 做的自定义实现的RWKV world tokenizer的测试,RWKV world模型转换,检查lambda数据集正确性等的教程。

使用此仓库(https://github.com/BBuf/RWKV-World-HF-Tokenizer)的Huggingface项目

上传转换后的模型到Huggingface上时,如果bin文件太大需要使用这个指令 transformers-cli lfs-enable-largefiles 解除大小限制.

  • RWKV/rwkv-5-world-169m
  • RWKV/rwkv-4-world-169m
  • RWKV/rwkv-4-world-430m
  • RWKV/rwkv-4-world-1b5
  • RWKV/rwkv-4-world-3b
  • RWKV/rwkv-4-world-7b

RWKV World模型的HuggingFace版本的Tokenizer

下面的参考程序比较了原始tokenizer和HuggingFace版本的tokenizer对不同句子的编码和解码结果。

from transformers import AutoModelForCausalLM, AutoTokenizer
from rwkv_tokenizer import TRIE_TOKENIZER
token_path = "/Users/bbuf/工作目录/RWKV/RWKV-World-HF-Tokenizer/rwkv_world_tokenizer"

origin_tokenizer = TRIE_TOKENIZER('/Users/bbuf/工作目录/RWKV/RWKV-World-HF-Tokenizer/rwkv_vocab_v20230424.txt')

from transformers import AutoTokenizer
hf_tokenizer = AutoTokenizer.from_pretrained(token_path, trust_remote_code=True)

# 测试编码器
assert hf_tokenizer("Hello")['input_ids'] == origin_tokenizer.encode('Hello')
assert hf_tokenizer("S:2")['input_ids'] == origin_tokenizer.encode('S:2')
assert hf_tokenizer("Made in China")['input_ids'] == origin_tokenizer.encode('Made in China')
assert hf_tokenizer("今天天气不错")['input_ids'] == origin_tokenizer.encode('今天天气不错')
assert hf_tokenizer("男:听说你们公司要派你去南方工作?")['input_ids'] == origin_tokenizer.encode('男:听说你们公司要派你去南方工作?')

# 测试解码器
assert hf_tokenizer.decode(hf_tokenizer("Hello")['input_ids']) == 'Hello'
assert hf_tokenizer.decode(hf_tokenizer("S:2")['input_ids']) == 'S:2'
assert hf_tokenizer.decode(hf_tokenizer("Made in China")['input_ids']) == 'Made in China'
assert hf_tokenizer.decode(hf_tokenizer("今天天气不错")['input_ids']) == '今天天气不错'
assert hf_tokenizer.decode(hf_tokenizer("男:听说你们公司要派你去南方工作?")['input_ids']) == '男:听说你们公司要派你去南方工作?'

Huggingface RWKV World模型转换

使用脚本scripts/convert_rwkv_world_model_to_hf.sh,将huggingface的BlinkDL/rwkv-4-world项目中的PyTorch格式模型转换为Huggingface格式。这里,我们以0.1B为例。

#!/bin/bash
set -x

cd scripts
python convert_rwkv_checkpoint_to_hf.py --repo_id BlinkDL/rwkv-4-world \
 --checkpoint_file RWKV-4-World-0.1B-v1-20230520-ctx4096.pth \
 --output_dir ../rwkv4-world4-0.1b-model/ \
 --tokenizer_file /Users/bbuf/工作目录/RWKV/RWKV-World-HF-Tokenizer/rwkv_world_tokenizer \
 --size 169M \
 --is_world_tokenizer True
cp /Users/bbuf/工作目录/RWKV/RWKV-World-HF-Tokenizer/rwkv_world_tokenizer/rwkv_vocab_v20230424.json ../rwkv4-world4-0.1b-model/
cp /Users/bbuf/工作目录/RWKV/RWKV-World-HF-Tokenizer/rwkv_world_tokenizer/tokenization_rwkv_world.py ../rwkv4-world4-0.1b-model/
cp /Users/bbuf/工作目录/RWKV/RWKV-World-HF-Tokenizer/rwkv_world_tokenizer/tokenizer_config.json ../rwkv4-world4-0.1b-model/

使用脚本 scripts/convert_rwkv5_world_model_to_hf.sh,将来自 huggingface BlinkDL/rwkv-5-world 项目的 PyTorch 格式模型转换为 Huggingface 格式。在这里,我们以 0.1B 为例。

#!/bin/bash
set -x

cd scripts
python convert_rwkv5_checkpoint_to_hf.py --repo_id BlinkDL/rwkv-5-world \
 --checkpoint_file RWKV-5-World-0.1B-v1-20230803-ctx4096.pth \
 --output_dir ../rwkv5-world-169m-model/ \
 --tokenizer_file /Users/bbuf/工作目录/RWKV/RWKV-World-HF-Tokenizer/rwkv_world_tokenizer \
 --size 169M \
 --is_world_tokenizer True

cp /Users/bbuf/工作目录/RWKV/RWKV-World-HF-Tokenizer/rwkv_world_v5.0_model/configuration_rwkv5.py ../rwkv5-world-169m-model/
cp /Users/bbuf/工作目录/RWKV/RWKV-World-HF-Tokenizer/rwkv_world_v5.0_model/modeling_rwkv5.py ../rwkv5-world-169m-model/
cp /Users/bbuf/工作目录/RWKV/RWKV-World-HF-Tokenizer/rwkv_world_tokenizer/rwkv_vocab_v20230424.json ../rwkv5-world-169m-model/
cp /Users/bbuf/工作目录/RWKV/RWKV-World-HF-Tokenizer/rwkv_world_tokenizer/tokenization_rwkv_world.py ../rwkv5-world-169m-model/
cp /Users/bbuf/工作目录/RWKV/RWKV-World-HF-Tokenizer/rwkv_world_tokenizer/tokenizer_config.json ../rwkv5-world-169m-model/

另外,您需要在生成文件夹中的 config.json 文件开头添加以下几行:

"architectures": [
    "RwkvForCausalLM"
],
"auto_map": {
    "AutoConfig": "configuration_rwkv5.Rwkv5Config",
    "AutoModelForCausalLM": "modeling_rwkv5.RwkvForCausalLM"
},

运行Huggingface的RWKV World模型

run_hf_world_model_xxx.py演示了如何使用Huggingface的AutoModelForCausalLM加载转换后的模型,以及如何使用通过AutoTokenizer加载的自定义RWKV World模型的HuggingFace版本的Tokenizer进行模型推断。请看0x1节,这里就不赘述了。

检查Lambda

如果你想运行这两个脚本,首先需要下载一下 https://github.com/BlinkDL/ChatRWKV ,然后cd到
rwkv_pip_package 目录做pip install -e . ,然后再回到ChatRWKV的v2目录下面,运行这里提供的lambda_pt.pylambda_hf.py

check_lambda文件夹下的lambda_pt.pylambda_hf.py文件分别使用RWKV4 World 169M的原始PyTorch模型和HuggingFace模型对lambda数据集进行评估。从日志中可以看出,他们得到的评估结果基本上是一样的。

lambda_pt.py lambda评估日志
# Check LAMBADA...
# 100 ppl 42.41 acc 34.0
# 200 ppl 29.33 acc 37.0
# 300 ppl 25.95 acc 39.0
# 400 ppl 27.29 acc 36.75
# 500 ppl 28.3 acc 35.4
# 600 ppl 27.04 acc 35.83
...
# 5000 ppl 26.19 acc 35.84
# 5100 ppl 26.17 acc 35.88
# 5153 ppl 26.16 acc 35.88
lambda_hf.py lambda评估日志
# Check LAMBADA...
# 100 ppl 42.4 acc 34.0
# 200 ppl 29.3 acc 37.0
# 300 ppl 25.94 acc 39.0
# 400 ppl 27.27 acc 36.75
# 500 ppl 28.28 acc 35.4
# 600 ppl 27.02 acc 35.83
...
# 5000 ppl 26.17 acc 35.82
# 5100 ppl 26.15 acc 35.86
# 5153 ppl 26.14 acc 35.86

从lambda的输出结果可以验证原始基于ChatRWKV系统运行的模型和转换后的Huggingface模型是否精度是等价的。

0x3. 实现

Tokenizer的实现

Tokenizer的实现分为两步。

因为目前社区的RWKV world模型tokenizer文件是一个txt文件:https://github.com/BBuf/RWKV-World-HF-Tokenizer/blob/main/rwkv_vocab_v20230424.txt 。我们需要将其转换为Huggingface的AutoTokenizer可以读取的json文件。这一步是通过 https://github.com/BBuf/RWKV-World-HF-Tokenizer/blob/main/scripts/convert_vocab_json.py 这个脚本实现的,我们对比一下执行前后的效果。

在这里插入图片描述

转换为json文件后:

在这里插入图片描述

这里存在一个转义的关系,让gpt4解释一下\u0000和\x00的关系:

在这里插入图片描述

有了这个json文件之后,我们就可以写一个继承PreTrainedTokenizer类的RWKVWorldTokenizer了,由于RWKV world tokenzier的原始实现是基于Trie树(https://github.com/BBuf/RWKV-World-HF-Tokenizer/blob/main/rwkv_tokenizer.py#L5),所以我们实现 RWKVWorldTokenizer 的时候也要使用 Trie 树这个数据结构。具体的代码实现在 https://github.com/BBuf/RWKV-World-HF-Tokenizer/blob/main/rwkv_world_tokenizer/tokenization_rwkv_world.py 这个文件。

需要注意 https://github.com/BBuf/RWKV-World-HF-Tokenizer/blob/main/rwkv_world_tokenizer/tokenization_rwkv_world.py#L211 这一行,当解码的token id是0时表示句子的结束(eos),这个时候就停止解码。

RWKV World 5.0模型实现

实现了rwkv world tokenizer之后就可以完成所有rwkv4 world模型转换到HuggingFace模型格式了,因为rwkv4 world模型的模型结构和目前transformers里面支持的rwkv模型的代码完全一样,只是tokenzier有区别而已。

但是如果想把 https://huggingface.co/BlinkDL/rwkv-5-world 这里的模型也转换成 HuggingFace模型格式,那么我们就需要重新实现一下模型了。下面红色部分的这个模型就是rwkv5.0版本的模型,剩下的都是5.2版本的模型。

在这里插入图片描述

我对照着ChatRWKV里面rwkv world 5.0的模型实现完成了HuggingFace版本的rwkv world 5.0版本模型代码实现,具体见:https://github.com/BBuf/RWKV-World-HF-Tokenizer/blob/main/rwkv_world_v5.0_model/modeling_rwkv5.py ,是在transformers官方提供的实现上进一步修改得来。

实现了这个模型之后就可以完成rwkv world 5.0的模型转换为HuggingFace结果了,教程请看上一节或者github仓库。成品:https://huggingface.co/RWKV/rwkv-5-world-169m

0x4. 踩坑

在实现tokenizer的时候,由于RWKV world tokenzier的eos是0,我们没办法常规的去插入eos token,所以直接在Tokenzier的解码位置特判0。

RWKVWorldTokenizer的初始化函数在 super().__init__ 的时机应该是在构造self.encoder之后,即:

在这里插入图片描述

否则会在当前最新的transformers==4.34版本中造成下面的错误:

在这里插入图片描述

在模型实现部分踩了一个坑,在embedding的时候第一次需要需要对embedding权重做一个pre layernorm的操作:

在这里插入图片描述

这个操作只能在第一次输入的时候做一般就是prefill,我没注意到这一点导致后面decoding的时候也错误做了这个操作导致解码的时候乱码,后面排查了不少时间才定位到这个问题。

另外,在做check lambda的时候,如果不开启torch.no_grad上下文管理器禁用梯度计算,显存会一直涨直到oom:https://github.com/BBuf/RWKV-World-HF-Tokenizer/blob/main/check_lambda/lambda_hf.py#L48 。

0x5. 总结

这件事情大概花了国庆的一半时间加一个完整的周六才搞定,所以这篇文章记录一下也可以帮助有相关需求的小伙伴们少踩点坑。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1095656.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

$ vue -Vbash: vue: command not found

$ vue -V bash: vue: command not found报这个错,我们需要找到vue安装路径,添加在环境变量的用户变量中: 1、vue安装路径 2、编辑环境变量 然后重新打开命令框,就可以了

嵌入式数据库sqlite3【基础篇】基本命令操作,小白一看就懂(C/C++)

目录 前言 一、sqlite概念和特性 二、sqlite安装 三、sqlite3数据类型 四、sqlite数据库约束 五、sqlite常用命令 六、SQL语句(增删改查) 七、sqlite使用实例(教学管理数据库) 总结 前言 数据在实际工作中应用非常广泛…

Linux网络编程系列之服务器编程——阻塞IO模型

Linux网络编程系列 (够吃,管饱) 1、Linux网络编程系列之网络编程基础 2、Linux网络编程系列之TCP协议编程 3、Linux网络编程系列之UDP协议编程 4、Linux网络编程系列之UDP广播 5、Linux网络编程系列之UDP组播 6、Linux网络编程系列之服务器编…

Stream流中的 max()和 sorted()方法

需求:某个公司的开发部门,分为开发 一部 和 二部 ,现在需要进行年中数据结算。分析: 员工信息至少包含了(名称、性别、工资、奖金、处罚记录)开发一部有 4 个员工、开发二部有 5 名员工分别筛选出 2 个部门…

1.2 向量的长度与点积

一、向量的点积 两个向量 v ( v 1 , v 2 ) \boldsymbol v(v_1,v_2) v(v1​,v2​) 与 w ( w 1 , w 2 ) \boldsymbol w(w_1,w_2) w(w1​,w2​)的点积或内积是数字 v ⋅ w \boldsymbol v\cdot\boldsymbol w v⋅w: v ⋅ w v 1 w 1 v 2 w 2 ( 1.2.1 ) \boldsymbo…

【Agora UID 踩坑记录 Java 数据类型】

目录 负数二进制表示Java中32位无符号数的取法项目踩坑记录Java 0xffffffff隐式类型转换的坑 负数二进制表示 由于计算机中数据都以二进制表示,而负数的二级制是根据正数二进制取补码(补码就是先取反码,然后加1)得到,…

关于Gym变成Gymnasium

根据网页搜索的gym官网,发现如下网站https://www.gymlibrary.dev/ 刚进页面时 翻译一下,意思就是 Gym 的所有开发都已迁移到 Gymnasium,这是 Farama 基金会中的一个新软件包,由过去 18 个月来维护 Gym 的同一团队开发人员维护。如果您已经在使用最新版…

Java基础(运算符篇)

一、算术运算符 正号-负号加法-减法*乘法/除法%模运算(取余)自增--自减 算术运算符的使用比较简单,只需要注意一些细节。 tips: 加号( )除了可以作为正号,还可以用于字符串拼接。 public c…

解读下SWD协议以及其应用

SWD协议原理 SWD(Serial Wire Debug)协议是一种用于ARM Cortex微控制器的调试接口协议。它定义了主机计算机与目标设备之间通过SWD线进行通信的格式和规范。 SWD协议使用两根线进行通信:SWDIO(Serial Wire Debug I/O&#xff09…

跳石板(牛客)

目录 一、题目 二、代码 &#xff08;一&#xff09;解法一&#xff1a;超时了 &#xff08;二&#xff09;优化 一、题目 跳石板_牛客题霸_牛客网 二、代码 &#xff08;一&#xff09;解法一&#xff1a;超时了 #include <climits> #include <iostream> …

python的pyecharts第三方模块绘制高端统计图表

pyecharts库 python的pyecharts库是一个用于生成 Echarts 图表的python第三方类库&#xff0c;可以绘制很高端的统计图表以便展示数据。 安装方法 pip安装 pip install pyecharts或者github拉取下载安装 git clone https://github.com/pyecharts/pyecharts.git cd pyechar…

Android 内容提供者和内容观察者:数据共享和实时更新的完美组合

任务要求 一个作为ContentProvider提供联系人数据另一个作为Observer监听联系人数据的变化&#xff1a; 1、创建ContactProvider项目&#xff1b; 2、在ContactProvider项目中用Sqlite数据库实现联系人的读写功能&#xff1b; 3、在ContactProvider项目中通过ContentProvid…

1.安装 docker 容器并配置镜像加速器

1.2.1 实验环境准备 实验环境&#xff1a; rockylinux8.8 可以去官网下载 下载 Rocky | Rocky Linux 主机名&#xff1a; xuegod63 主机 ip: 192.168.1.63&#xff08;这个 ip 大家可以根据自己所在环境去配置&#xff0c;配置成静态 IP&#xff09; 2g 内存、2vCPU、50G 硬…

诊断DLL——CAPL_DLL集成安全访问算法

文章目录 前言一、CAPL DLL简介DLL生成C2338报错解决方案:二、添加27服务解锁算法三、CAPL调用dll前言 在实际诊断工程应用中,如UDS刷写——27服务,经常会遇到一些Seed2Key的算法问题,为了安全保密,这个算法的源码不便公开,我们可以将其打包成DLL,然后在CANoe诊断控制面…

树莓派4b配置通过smbus2使用LCD灯

出现报错&#xff1a; FileNotFoundError: [Errno 2] No such file or directory: ‘/dev/i2c-1’ 则说明没有打开I2C&#xff0c;可通过如下步骤进行设置 1、打开树莓派配置 sudo raspi-config2、进入Interface Options&#xff0c;配置I2C允许 目前很多python3版本已经不…

【LeetCode热题100】--121.买卖股票的最佳时机

121.买卖股票的最佳时机 class Solution {public int maxProfit(int[] prices) {int minprice Integer.MAX_VALUE;int maxprofit 0;for(int i 0;i<prices.length;i){if(prices[i] < minprice){minprice prices[i]; //找到最小值}else if(prices[i] - minprice > ma…

C++指针解读(7)-- 指针和函数

前面我们讲过用指针变量作为函数参数。这里讲指向函数的指针变量和返回指针的函数。 1、指向函数的指针变量 跟变量一样&#xff0c;函数也会存储在内存空间中&#xff0c;函数的内存空间有一个起始地址&#xff0c;每次调用函数时就从该入口地址开始执行函数代码。 既然函数…

Wordpress自定义小工具logo调用设置(可视化)

在主题开发中&#xff0c;需要调用网站的logo&#xff0c;最简单的办法就是用wp自带的函数&#xff0c;那就是the_custom_logo()&#xff0c;使用它还可以通过后台-自定义-logo&#xff0c;边修改边预览&#xff0c;还是很香的。 自定义徽标支持应首先使用add_theme_support()添…

【后端】韩顺平Java学习笔记(基础篇01)

因为之前有c基础&#xff0c;所以差不多一样的就简写了owo 来源&#xff1a;韩顺平 零基础30天学会Java 目录 I. 控制结构&#xff08;简&#xff09; 一、介绍 1. 顺序 → 从上到下执行&#xff0c;无跳转 2. 分支 → 单、双、多、嵌套 1&#xff09;单&#xff0c;即…

【GO入门】环境配置及Vscode配置

1 GO环境配置 欢迎来到Go的世界&#xff0c;让我们开始探索吧&#xff01; Go是一种新的语言&#xff0c;一种并发的、带垃圾回收的、快速编译的语言。它具有以下特点&#xff1a; 它可以在一台计算机上用几秒钟的时间编译一个大型的Go程序。Go为软件构造提供了一种模型&…