Huggingface微调BART的代码示例:WMT16数据集训练新的标记进行翻译

news2025/2/27 4:58:43

BART模型是用来预训练seq-to-seq模型的降噪自动编码器(autoencoder)。它是一个序列到序列的模型,具有对损坏文本的双向编码器和一个从左到右的自回归解码器,所以它可以完美的执行翻译任务。

如果你想在翻译任务上测试一个新的体系结构,比如在自定义数据集上训练一个新的标记,那么处理起来会很麻烦,所以在本文中,我将介绍添加新标记的预处理步骤,并介绍如何进行模型微调。

因为Huggingface Hub有很多预训练过的模型,可以很容易地找到预训练标记器。但是我们要添加一个标记可能就会有些棘手,下面我们来完整的介绍如何实现它,首先加载和预处理数据集。

加载数据集

我们使用WMT16数据集及其罗马尼亚语-英语子集。load_dataset()函数将从Huggingface下载并加载任何可用的数据集。

 importdatasets
 
 dataset=datasets.load_dataset("stas/wmt16-en-ro-pre-processed", cache_dir="./wmt16-en_ro")

在上图1中可以看到数据集内容。我们需要将其“压平”,这样可以更好的访问数据,让后将其保存到硬盘中。

 defflatten(batch):
     batch['en'] =batch['translation']['en']
     batch['ro'] =batch['translation']['ro']
     
     returnbatch
 
 # Map the 'flatten' function
 train=dataset['train'].map( flatten )
 test=dataset['test'].map( flatten )
 validation=dataset['validation'].map( flatten )
 
 # Save to disk
 train.save_to_disk("./dataset/train")
 test.save_to_disk("./dataset/test")
 validation.save_to_disk("./dataset/validation")

下图2可以看到,已经从数据集中删除了“translation”维度。

标记器

标记器提供了训练标记器所需的所有工作。它由四个基本组成部分:(但这四个部分不是所有的都是必要的)

Models:标记器将如何分解每个单词。例如,给定单词“playing”:i) BPE模型将其分解为“play”+“ing”两个标记,ii) WordLevel将其视为一个标记。

Normalizers:需要在文本上发生的一些转换。有一些过滤器可以更改Unicode、小写字母或删除内容。

Pre-Tokenizers:为操作文本提供更大灵活性处理的函数。例如,如何处理数字。数字100应该被认为是“100”还是“1”、“0”、“0”?

Post-Processors:后处理具体情况取决于预训练模型的选择。例如,将 [BOS](句首)或 [EOS](句尾)标记添加到 BERT 输入。

下面的代码使用BPE模型、小写Normalizers和空白Pre-Tokenizers。然后用默认值初始化训练器对象,主要包括

1、词汇量大小使用50265以与BART的英语标记器一致

2、特殊标记,如和,

3、初始词汇量,这是每个模型启动过程的预定义列表。

 fromtokenizersimportnormalizers, pre_tokenizers, Tokenizer, models, trainers
 
 # Build a tokenizer
 bpe_tokenizer=Tokenizer(models.BPE())
 bpe_tokenizer.normalizer=normalizers.Lowercase()
 bpe_tokenizer.pre_tokenizer=pre_tokenizers.Whitespace()
 
 trainer=trainers.BpeTrainer(
     vocab_size=50265,
     special_tokens=["<s>", "<pad>", "</s>", "<unk>", "<mask>"],
     initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
 )

使用Huggingface的最后一步是连接Trainer和BPE模型,并传递数据集。根据数据的来源,可以使用不同的训练函数。我们将使用train_from_iterator()。

 defbatch_iterator():
     batch_length=1000
     foriinrange(0, len(train), batch_length):
         yieldtrain[i : i+batch_length]["ro"]
         
 bpe_tokenizer.train_from_iterator( batch_iterator(), length=len(train), trainer=trainer )
 
 bpe_tokenizer.save("./ro_tokenizer.json")
 

BART微调

现在可以使用使用新的标记器了。

 fromtransformersimportAutoTokenizer, PreTrainedTokenizerFast
 
 en_tokenizer=AutoTokenizer.from_pretrained( "facebook/bart-base" );
 ro_tokenizer=PreTrainedTokenizerFast.from_pretrained( "./ro_tokenizer.json" );
 ro_tokenizer.pad_token=en_tokenizer.pad_token
 
 deftokenize_dataset(sample):
     input=en_tokenizer(sample['en'], padding='max_length', max_length=120, truncation=True)
     label=ro_tokenizer(sample['ro'], padding='max_length', max_length=120, truncation=True)
 
     input["decoder_input_ids"] =label["input_ids"]
     input["decoder_attention_mask"] =label["attention_mask"]
     input["labels"] =label["input_ids"]
 
     returninput
 
 train_tokenized=train.map(tokenize_dataset, batched=True)
 test_tokenized=test.map(tokenize_dataset, batched=True)
 validation_tokenized=validation.map(tokenize_dataset, batched=True)

上面代码的第5行,为罗马尼亚语的标记器设置填充标记是非常必要的。因为它将在第9行使用,标记器使用填充可以使所有输入都具有相同的大小。

下面就是训练的过程:

 fromtransformersimportBartForConditionalGeneration
 fromtransformersimportSeq2SeqTrainingArguments, Seq2SeqTrainer
 
 model=BartForConditionalGeneration.from_pretrained(  "facebook/bart-base" )
 
 training_args=Seq2SeqTrainingArguments(
     output_dir="./",
     evaluation_strategy="steps",
     per_device_train_batch_size=2,
     per_device_eval_batch_size=2,
     predict_with_generate=True,
     logging_steps=2,  # set to 1000 for full training
     save_steps=64,  # set to 500 for full training
     eval_steps=64,  # set to 8000 for full training
     warmup_steps=1,  # set to 2000 for full training
     max_steps=128, # delete for full training
     overwrite_output_dir=True,
     save_total_limit=3,
     fp16=False, # True if GPU
 )
 
 trainer=Seq2SeqTrainer(
     model=model,
     args=training_args,
     train_dataset=train_tokenized,
     eval_dataset=validation_tokenized,
 )
 
 trainer.train()

过程也非常简单,加载bart基础模型(第4行),设置训练参数(第6行),使用Trainer对象绑定所有内容(第22行),并启动流程(第29行)。上述超参数都是测试目的,所以如果要得到最好的结果还需要进行超参数的设置,我们使用这些参数是可以运行的。

推理

推理过程也很简单,加载经过微调的模型并使用generate()方法进行转换就可以了,但是需要注意的是对源 (En) 和目标 (RO) 序列使用适当的分词器。

总结

虽然在使用自然语言处理(NLP)时,标记化似乎是一个基本操作,但它是一个不应忽视的关键步骤。HuggingFace的出现可以方便的让我们使用,这使得我们很容易忘记标记化的基本原理,而仅仅依赖预先训练好的模型。但是当我们希望自己训练新模型时,了解标记化过程及其对下游任务的影响是必不可少的,所以熟悉和掌握这个基本的操作是非常有必要的。

本文代码:https://avoid.overfit.cn/post/6a533780b5d842a28245c81bf46fac63

作者:Ala Alam Falaki

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/412876.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java Stream API 操作完全攻略:让你的代码更加出色 (四)

前言 Java Stream 是一种强大的数据处理工具&#xff0c;可以帮助开发人员快速高效地处理和转换数据流。使用 Stream 操作可以大大简化代码&#xff0c;使其更具可读性和可维护性&#xff0c;从而提高开发效率。本文将为您介绍 Java Stream 操作的所有方面&#xff0c;包括 ran…

交友项目【通用设置】三个功能实现

目录 1&#xff1a;交友项目【通用设置】 1.1&#xff1a;查询通用设置 1.1.1&#xff1a;接口地址 1.1.2&#xff1a;流程分析 1.1.3&#xff1a;代码实现 1.2&#xff1a;设置陌生人问题 1.2.1&#xff1a;接口地址 1.2.2&#xff1a;流程分析 1.2.3&#xff1a;代码…

Python 小型项目大全 51~55

五十一、九十九瓶的变体 原文&#xff1a;http://inventwithpython.com/bigbookpython/project51.html 在歌曲“九十九瓶”的这个版本中&#xff0c;该程序通过删除一个字母、交换一个字母的大小写、调换两个字母或重叠一个字母&#xff0c;在每个小节中引入了一些小的不完美。…

4月,我从外包公司离职了

先说一下自己的情况&#xff0c;大专生&#xff0c;18年通过校招进入湖南某软件公司&#xff0c;干了接近4年的功能测试&#xff0c;今年年初&#xff0c;感觉自己不能够在这样下去了&#xff0c;长时间呆在一个舒适的环境会让一个人堕落!而我已经在一个企业干了四年的功能测试…

python学习

1.安装 Download Python | Python.org 安装时&#xff0c;点击添加路径。 1.1 python的解释器 我们把代码写进.py结尾的文件里&#xff0c;然后 python 路径文件名就可以运行它了。 2.字面量 例如print("我们"),"我们",就是字符串字面量&#xff0c;…

简化你的代码,提高生产力:这10个Lambda表达式必须掌握

前言 Lambda表达式是一种在现代编程语言中越来越常见的特性&#xff0c;可以简化代码、提高生产力。这篇文章将介绍10个必须掌握的Lambda表达式&#xff0c;这些表达式涵盖了在实际编程中经常用到的常见场景&#xff0c;例如列表操作、函数组合、条件筛选等。通过学习这些Lambd…

JUC源码系列-CountDownLatch源码研读

前言 CountDownLatch是一个很有用的工具&#xff0c;latch是门闩的意思&#xff0c;该工具是为了解决某些操作只能在一组操作全部执行完成后才能执行的情景。例如&#xff0c;小组早上开会&#xff0c;只有等所有人到齐了才能开&#xff1b;再如&#xff0c;游乐园里的过山车&…

运行时内存数据区之堆(二)

Minor GC、Major GC、与Full GC JVM在进行GC时&#xff0c;并非每次都对上面三个内存&#xff08;新生代、老年代&#xff1a;方法区&#xff09;区域一起回收的&#xff0c;大部分时候回收的都是指新生代。 针对HotSpot VM的实现&#xff0c;它里面的GC按照回收区域又分为两…

浅谈 如果做微服务了 这个模块怎么去划分?

如果做微服务了 这个模块怎么去划分&#xff1f; 还是高内聚 低耦合的一个思想吧 &#xff0c;单一职责的设计原则&#xff0c;也是一个封装的思想吧&#xff0c; 业务维度&#xff1a; ​ 按照业务的关联程度来决定&#xff0c;关联比较密切的业务适合拆分为一个微服务&…

C++语法(14)---- 模板进阶

C语法&#xff08;13&#xff09;---- 模拟实现priority_queue_哈里沃克的博客-CSDN博客https://blog.csdn.net/m0_63488627/article/details/130069707?spm1001.2014.3001.5501 目录 1.非类型模板参数 2.模板的特化 1.函数模板(仿函数) 2.类模板 1.全特化 2.半特化、偏…

INFINONE XC164单片机逆向记录(6)C语言学习

本人所写的博客都为开发之中遇到问题记录的随笔,主要是给自己积累些问题。免日后无印象,如有不当之处敬请指正(欢迎进扣群 24849632 探讨问题); 写在专栏前面https://blog.csdn.net/Junping1982/article/details/129955766 INFINONE XC164单片机逆向记录(1)资料准备

FusionCharts Suite XT v3.20.0 Crack

FusionCharts Suite XT v3.20.0 改进了仪表的径向条形图和调整大小功能。2023 年 4 月 11 日 - 9:37新版本特征 添加了一个新方法“_changeXAxisCoordinates”&#xff0c;它允许用户将 x 轴更改为在图例或数据交互时自动居中对齐。更新了 Angular 集成以支持 Angular 版本 14 …

【无功优化】基于多目标差分进化算法的含DG配电网无功优化模型【IEEE33节点】(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

SAM - 分割一切图像【AI大模型】

如果你认为 AI 领域已经通过 ChatGPT、GPT4 和 Stable Diffusion 快速发展&#xff0c;那么请系好安全带&#xff0c;为 AI 的下一个突破性创新做好准备。 推荐&#xff1a;用 NSDT场景设计器 快速搭建3D场景。 Meta 的 FAIR 实验室刚刚发布了 Segment Anything Model (SAM)&am…

电脑软件:推荐一款Windows剪贴板增强软件——ClipX

目录 ClipX能做什么&#xff1f; 软件优点 软件不足之处 今天要介绍的剪切板神器——ClipX&#xff0c;拥有它可以作为弥补Windows 自带的剪贴板的短板的增强型工具软件。 ClipX能做什么&#xff1f; 1. 扩充剪贴板数量&#xff0c;数量可以自己设置 ClipX支持4到1024个剪…

Flutter(三)--可滚动布局

之前介绍了布局和容器&#xff0c;它们都用于摆放一个或多个子组件&#xff0c;而实际应用中&#xff0c;受限于手机、Pad、电脑的屏幕大小&#xff0c;一个布局不可能摆放无限个组件&#xff0c;我们往往采取滚动的方式&#xff0c;来使得一部分组件展示在屏幕上&#xff0c;一…

L2-041 插松枝PTA

人造松枝加工场的工人需要将各种尺寸的塑料松针插到松枝干上&#xff0c;做成大大小小的松枝。他们的工作流程&#xff08;并不&#xff09;是这样的&#xff1a; 每人手边有一只小盒子&#xff0c;初始状态为空。每人面前有用不完的松枝干和一个推送器&#xff0c;每次推送一…

piwigo安装及初步使用

一 摘要 本文主要介绍piwigo 安装及初步使用&#xff0c;nginx \php\mysql 等使用 docker 安装 二 环境信息 2.1 操作系统 CentOS Linux release 7.9.2009 (Core)2.2 piwigo piwigo-13.6.0.zip三 安装 3.1安装资源下载 piwigo 请到官网下载https://piwigo.org 安装步骤也…

【STL九】关联容器——map容器、multimap容器

【STL九】关联容器——map容器、multimap容器一、map简介二、头文件三、模板类四、map的内部结构五、成员函数1、迭代器2、元素访问3、容量4、修改操作~~5、操作~~5、查找6、查看操作六、demo1、查找find2、查找lower_bound、upper_bound3、insert、emplace() 和 emplace_hint(…

超详细!Apache+Tomcat+mod_jk搭建负载均衡集群

目录 0.流程图&#xff1a; 1.集群环境&#xff1a; 2.Apache服务器安装httpd&#xff1a; 3.tomcat1服务器和tomcat2服务器安装jdk和Tomcat 4.tomcat1服务器和tomcat2服务器创建页面&#xff1a; 5.Apache服务器的mod_jk模块的安装&#xff1a; 6.查看是否mod_jk.so模块…