训练自己的GPT2模型(中文),踩坑与经验

news2024/9/21 1:53:40

GPT2与Bert、T5之类的模型很不一样!!!

如果你对Bert、T5、BART的训练已经很熟悉,想要训练中文GPT模型,务必了解以下区别!!!
官方文档里虽然已经有教程,但是都是英文,自己实践过才知道有很多坑!!!
中文也有一些教程,但是使用了TextDataset这种已经过时的方法,不易于理解GPT2的真正工作原理。
在这里插入图片描述

开门见山说结论,与bert的最主要区别:

  1. GPT2Tokenizer,是以字节为单位的字节对编码,不是以中文的字或词为单位的!
    对于英文,GPT2Tokenizer大部分时候是以单词为单位进行切分的,但是对中文则完全不同,有时候2个id代表一个中文字,有时候又是1个?这一奇怪的现象正是因为采用字节对编码的结果。
    这也是为什么很多中文GPT使用BertTokenizer作为分词器,因为比较符合直观。
  2. GPT2Tokenizer没有默认的【pad_token】,需要自己设置,而且需要padding在左边!
  3. 训练时GPT2的【labels】和【input_ids】是一样的!所以使用的DataCollator不同

与T5的主要区别:

  1. generate时的设置不同,因为input本身也是output的一部分,所以最好设置max_new_tokens

下面对这几点分别介绍:

1.tokenizer问题
官方介绍:如下
Construct a GPT-2 tokenizer. Based on byte-level Byte-Pair-Encoding.
This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will be encoded differently whether it is at the beginning of the sentence (without space) or not:

 from transformers import GPT2Tokenizer
 tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
 tokenizer("Hello world")['input_ids']
[15496, 995]
 tokenizer(" Hello world")['input_ids']
[18435, 995]

You can get around that behavior by passing add_prefix_space=True when instantiating this tokenizer or when you call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance.

总结起来就是:

  1. GPT-2 tokenizer 基于字节对进行编码。更多介绍可以看Byte-Pair-Encoding
  2. GPT-2 tokenizer 会把空格视为token的一部分(T5也是如此),例如“hello”与“ hello”的encode结果截然不同
  3. 你可以设置add_prefix_space,来避免上述情况,但是模型效果会下降

tokenize过程:
由于英文字母转换为字节再转换为单字节字符后和原来是一样的,所以英文tokenize看起来和bert差不多。(单字节字符共有256个,是ascii码的扩充,0-128和ascii码一样,所以不影响英文编码)
然而中文则面目全非,GPT-2 tokenizer的vocab里面看不见一个中文,因为vocab全都是单字节字符的组合。如下图:
在这里插入图片描述

那么中文是怎么变成id的呢?中文转换过程如下(这部分比较烦,不看不影响模型的训练
外部看起来的情况:中文(utf-8)–>字节串(一个中文3个字节)–>每个字节对应一个单字节字符–>单字节字符串–>寻找vocab里对应的子串,进行分词–>转变为input_ids
实际情况:中文(utf-8)–>字节串(一个中文3个字节)–>寻找vocab里对应的子字节串,进行分词–>转变为input_ids
可以看下面例子理解以上过程:

>>> '中国'.encode('utf-8')
b'\xe4\xb8\xad\xe5\x9b\xbd'

>>> [tokenizer.byte_encoder[b] for b in b'\xe4\xb8\xad\xe5\x9b\xbd']
['ä', '¸', 'Ń', 'å', 'Ľ', '½']

>>> ''.join(['ä', '¸', 'Ń', 'å', 'Ľ', '½'])
 'ä¸ŃåĽ½'

>>> tokenizer.tokenize('中国')
['ä¸Ń', 'åĽ', '½']

>>> tokenizer.convert_tokens_to_ids(['ä¸Ń', 'åĽ', '½'])
[40792, 32368, 121]

>>> tokenizer.tokenize('ä¸ŃåĽ½')
['ä', 'Â', '¸', 'Å', 'ĥ', 'Ã¥', 'Ä', '½', '½']

#由于python的encode命令默认使用utf-8编码,而不是单字节字符集,
#所以这里将“中国”的分词结果拼回去在分词,结果会不一样

>>> tokenizer.byte_decoder['ä']  #此处使用单字节字符集,将'ä'映射为一个字节
228   #十进制228对应十六进制0xe4
>>> bytearray([228])
bytearray(b'\xe4')

>>> 'ä'.encode('utf-8')  #此处使用默认encode,将'ä'映射为2个字节
b'\xc3\xa4'

2.Padding问题
由于gpt是自回归语言模型,理论上来说,是不需要pad的,因为生成的id必须立即接在输入的id后面,中间不能有pad_token。
但是当一个batch训练时,难免出现输入句子不一样长的情况,所以需要在前面添加pad_token而不是像Bert一样默认添加在后面。所以需要在加载tokenizer时设置:

tokenizer = GPT2Tokenizer.from_pretrained(model_path,padding_side='left')

3.训练label问题

  1. 对于GPT,训练数据集里没有输入输出的区别,没有question与answer之分。训练时,一整句话,既是input,也是label。所以labels与input_ids 完全一致。举例如下:
    假设我希望训练模型,使其能进行如下问答:question:“中国是首都是什么?”answer:“北京”
    T5:input_ids :“中国是首都是什么?”,labels:“北京”
    GPT2:input_ids :“中国是首都是什么?北京”,labels:“中国是首都是什么?北京”

  2. 当你的数据集已经有question和answer列,那么需要将question和answer拼接在一起,再tokenizer处理为input_ids与attention_mask列

  3. 当你的数据集已经有input_ids与attention_mask列,那么就使用 transformers提供的DataCollatorForLanguageModeling即可让dataloader自动生成labels。如下是训练一个epoch的方式:

#dataset已经经过处理,有input_ids与attention_mask列
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
data_loader = DataLoader(dataset, batch_size=batch_size,
                             shuffle=True, collate_fn=data_collator, drop_last=False)
# acclelrator包装
model, data_loader = accelerator.prepare(model, data_loader)
#训练一个epoch
for step, batch in enumerate(data_loader):
    optimizer.zero_grad()
    outputs = model(**batch)
    loss = outputs[0]
    accelerator.backward(loss)
    optimizer.step()

4.Generate问题

  1. 由于模型的config中pad_token一般为None,但在生成一个batch的时候,因为设置了early_stopping=True,所以生成的序列不一样长,难免要用到padding,所以这一项需要设置 :pad_token_id=tokenizer.pad_token_id,使所有生成序列一样长。
  2. GPT2生成的结果,max_length表示prompt+generate的总长度,max_new_tokens表示generate的长度,通常我们想要限制的都是generate的长度,input_ids的长度一般不算在内,所以设置 max_length=None, max_new_tokens=256。 T5模型则一般设置max_length,因为decoder部分一般没有前缀。
  3. 前面提到过的,input需要padding,但需要pad在左边,pad_token一般与eos_token相同,不影响生成结果。
  4. 由于是生成(test)不是训练,所以input_ids和训练时不同。训练时输入 问题+答案;测试时只输入 问题,不需要提供labels
    举个例子,训练时,input_ids是“中国是首都是什么?北京”;测试时,input_ids则为“中国是首都是什么?”,然后模型生成“中国是首都是什么?北京”,需要自己再把后面部分截取出来作为 答案
input_ids=tokenizer("中国是首都是什么?")['input_ids']
attention_mask=tokenizer("中国是首都是什么?")['attention_mask']
generated_ids = model.generate(
   input_ids=input_ids,
   attention_mask=attention_mask,
   min_length=3,
   max_length=None,
   max_new_tokens=256,
   pad_token_id=tokenizer.pad_token_id,
   repetition_penalty=3.5,
   length_penalty=2.5,
   early_stopping=True,)
   
decoded_preds = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

>>> decoded_preds 
'中国是首都是什么?北京'

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/196605.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

手撸低代码平台搭建(四)组件拖动自由布局的实现

前言 大家好,在前两篇文章中,我们走进了前端低代码的世界,并揭秘了低代码的核心——页面设计器的实现。在揭秘页面设计器时,我们重点分享了顺序排列布局的组件拖动方式,那篇文章的评论中,有小伙伴问到自由布局的实现,那么我们在这篇文章中来分享一下自由布局拖动的实现…

Echarts柱形头部圆弧处理

第008个点击查看专栏目录对于柱状图来说,我们想要的效果是圆柱的上面进行圆弧的处理,产生顺滑的感觉,怎么处理呢,只要设置好样式即可,参考源代码圆角半径,单位px,支持传入数组分别指定 4 个圆角…

VMware 多站点容灾之SRM部署实践

一、背景 在VMware 多云场景中,我们最初会通过vmware的副本机制手动克隆或主从模式完成一些节点的灾备,虽然在初期不会出现什么问题,但一旦出现灾备恢复的复杂度和数据丢失风险还是一大考验,基于此,我们可借助VMware v…

Qt 获取网络信息

在Qt Network模块中使用QHostAddress存放IP地址,QHostInfo类来获取主机名和IP。 进行TCP/UDP编程时,需要将连接的主机名解析为IP地址,这个操作用DNS协议执行。 在互联网中现在有两种IP类型:IVP4和IVP6。 IP地址是给每一个连接在互…

Linux操作系统之基础IO

目录 系统IO调用接口 open write read 理解文件描述符fd 理解Linux操作系统的设计哲学,多态的思想是如何应用在Linux文件系统的 输出,追加,输入重定向的本质 子进程共享父进程的文件 IO的两个缓冲区 Linux特有的EXT文件系统 磁盘系…

代码训练营第二十天|530.二叉搜索树的最小绝对差 ● 501.二叉搜索树中的众数 ● 236. 二叉树的最近公共祖先

530 .二叉搜索树的最小绝对差 看完题后的思路 因为是二叉搜索树,所以直接按照二叉搜索树中序遍历,得到递增序列。遍历过程中一个指针指向遍历过的前一个元素 prenull; void f(root)if rootnull return递归 f&#x…

git语义化定制版本规范

目录说明说明 语义化版本控制规范,语义化的版本控制规范要求版本号由三部分构成:x.y.z MAJOR(X):这个是主版本号,一般是涉及到不兼容的 API 更改时,这个会变化。MINOR(Y)&#xff…

剑指Offer pow() 函数实现(快速幂)!!!

剑指 Offer 16. 数值的整数次方 实现 pow(x, n) ,即计算 x 的 n 次幂函数(即,xn)。不得使用库函数,同时不需要考虑大数问题。 示例 1: 输入:x 2.00000, n 10 输出:1024.00000 示…

早已有所耳闻的堆排序,你知道如何用C语言实现吗? 【堆排序|C语言版】

目录 0.写在前面 1.什么是堆? 2. 堆排序 2.1 建堆 2.1.1 AdjustUp(向上调整算法) 2.1.2 AdjustDown(向下调整算法) 2.2 两种建堆算法的时间复杂度 2.2.1 AdjustUp建堆的时间复杂度 2.2.2 AdjustDown建堆的时间…

神经网络(模型)量化介绍 - PTQ 和 QAT

神经网络(模型)量化介绍 - PTQ 和 QAT1. 需求目的2. 量化简介3. 三种量化模式3.1 Dynamic Quantization - 动态量化3.2 Post-Training Static Quantization - 训练后静态量化3.3 Quantization Aware Training - 量化感知训练4. PTQ 和 QAT 简介5. 设备和…

Flutter 小技巧之 3.7 性能优化background isolate

Flutter 3.7 的 background isolate 绝对是一大惊喜,尽管它在 release note 里被一笔带过 ,但是某种程度上它可以说是 3.7 里最实用的存在:因为使用简单,提升又直观。 Background isolate YYDS 前言 我们知道 Dart 里可以通过新建…

CODESYS开发教程9-文件读写(CAA File库)

今天继续我们的小白教程,老鸟就不要在这浪费时间了😊。 前面一期我们介绍了CODESYS的定时器及触发相关的功能块。这一期主要介绍CODESYS的CAA.File库中的目录和文件读写功能块,主要包括文件路径、名称、大小的获取以及文件的创建、打开、读、…

软测(概念) · 软件测试的基本概念 · 什么是需求 · 测试用例的概念 · 软件错误(bug)的概念

一、什么是软件测试软件测试和开发的区别测试和调试的区别一个优秀的软件测试人员具备的素质二、什么是需求从测试人员角度看待需求三、测试用例的概念四、软件错误(bug)的概念一、什么是软件测试 最常见的解释是:软件测试就是找 BUG&#x…

个人博客美化

总体参考: Butterfly 文档:https://butterfly.js.organzhiyu :https://anzhiy.cn张洪 Heo :https://blog.zhheo.comLeonus :https://blog.leonus.cn 注:博客所有美化大部分(全部)都参…

React项目实战之租房app项目(九)登录模块基础布局和功能实现

前言 目录前言一、房屋详情模块二、登录模块2.1 登录模块效果图2.2 基础布局2.3 调用接口实现登录2.4 实现表单验证功能2.4.1 formik介绍2.4.2 formik基本使用2.4.3 添加表单验证2.5 代码优化总结一、房屋详情模块 房屋详情模块主要是展示之前获取到的房源信息,由于…

为防护加码,飞凌嵌入式i.MX93系列开发板让通信安全又稳定

来源:飞凌嵌入式官网www.forlinx.com随着新基建的加快推进,智能制造迎来了更好的发展时机,嵌入式板卡等智能设备也在更多的应用场景中大放异彩。但随着现场的设备数量的剧增,环境中的各种干扰信号也随之增加,这就对设备…

windows下GitHub的SSH key配置

SSH Key 是一种方法来确定受信任的计算机,从而实现免密码登录。 Git是分布式的代码管理工具,远程的代码管理是基于SSH的,所以要使用远程的Git则需要SSH的配置。 下面的步骤将完成 生成SSH密钥 并 添加公共密钥到GitHub上的帐户 先设置GitHub…

Apifox接口测试工具详细解析

最近发现一款接口测试工具--apifox,我我们很难将它描述为一款接口管理工具 或 接口自测试工具。 官方给了一个简单的公式,更能说明apifox可以做什么。 Apifox Postman Swagger Mock JMeter Apifox的特点: 接口文档定义: Apif…

接口测试学习第二天

1、全局变量 概念:在postman全局生效的变量,全局唯一。设置: 代码设置:pm.globals.set("glb_age",100)//示例: pm.globals.set("glb_age",100) 获取: 代码获取:var 接收值…

Java的内部类详解(成员内部类、静态内部类、局部内部类、匿名内部类)

Java知识点总结:想看的可以从这里进入 目录2.2.4、 内部类1、成员内部类2、静态内部类3、局部内部类4、匿名内部类2.2.4、 内部类 一个类定义在另一个类内,那么这个类就是一个内部类,比如:在类A中定义一个类B,B就是内…