基于transformers框架实践Bert系列6-完形填空

news2024/9/21 22:33:39

本系列用于Bert模型实践实际场景,分别包括分类器、命名实体识别、选择题、文本摘要等等。(关于Bert的结构和详细这里就不做讲解,但了解Bert的基本结构是做实践的基础,因此看本系列之前,最好了解一下transformers和Bert等)
本篇主要讲解完形填空应用场景。本系列代码和数据集都上传到GitHub上:https://github.com/forever1986/bert_task

1 环境说明

1)本次实践的框架采用torch-2.1+transformer-4.37
2)另外还采用或依赖其它一些库,如:evaluate、pandas、datasets、accelerate等

2 前期准备

Bert模型是一个只包含transformer的encoder部分,并采用双向上下文和预测下一句训练而成的预训练模型。可以基于该模型做很多下游任务。

2.1 了解Bert的输入输出

Bert的输入:input_ids(使用tokenizer将句子向量化),attention_mask,token_type_ids(句子序号)、labels(结果)
Bert的输出:
last_hidden_state:最后一层encoder的输出;大小是(batch_size, sequence_length, hidden_size)(注意:这是关键输出,本次任务就需要获取该值,可以取出那个被mask掉的token,获取其前几个,取score最高的(当然也可以使用top_k或者top_p方式获取一定随机性)
pooler_output:这是序列的第一个token(classification token)的最后一层的隐藏状态,输出的大小是(batch_size, hidden_size),它是由线性层和Tanh激活函数进一步处理的。(通常用于句子分类,至于是使用这个表示,还是使用整个输入序列的隐藏状态序列的平均化或池化,视情况而定)。
hidden_states: 这是输出的一个可选项,如果输出,需要指定config.output_hidden_states=True,它也是一个元组,它的第一个元素是embedding,其余元素是各层的输出,每个元素的形状是(batch_size, sequence_length, hidden_size)
attentions:这是输出的一个可选项,如果输出,需要指定config.output_attentions=True,它也是一个元组,它的元素是每一层的注意力权重,用于计算self-attention heads的加权平均值。

2.2 数据集与模型

1)数据集来自:ChnSentiCorp(该数据集本身是做情感分类,但是我们只需要取其text部分即可)
2)模型权重使用:bert-base-chinese

2.3 任务说明

完形填空其实就是在一段文字中mask掉几个字,让模型能够自动填充字。这里本身就是bert模型做预训练是所做的事情之一,因此就是让数据给模型做训练的过程。

2.4 实现关键

1)数据集结构是一个带有text和label两列的数据,我们只需要获取到text部分即可。
在这里插入图片描述
2)随机mask掉部分数据,这个本身也是bert的训练过程,因此在transforms框架中DataCollatorForLanguageModeling已经实现了,你也可以自己实现随机mask掉你的数据进行训练

3 关键代码

3.1 数据集处理

数据集不需要做过多处理,只需要将text部分进行tokenizer,并制定max_length和truncation即可

def process_function(datas):
    tokenized_datas = tokenizer(datas["text"], max_length=256, truncation=True)
    return tokenized_datas
new_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)

3.2 模型加载

model = BertForMaskedLM.from_pretrained(model_path)

注意:这里使用的是transformers中的BertForMaskedLM,该类对bert模型进行封装。如果我们不使用该类,需要自己定义一个model,继承bert,增加分类线性层。另外使用AutoModelForMaskedLM也可以,其实AutoModel最终返回的也是BertForMaskedLM,它是根据你config中的model_type去匹配的。
这里列一下BertForMaskedLM的关键源代码说明一下transformers帮我们做了哪些关键事情

# 在__init__方法中增加增加了BertOnlyMLMHead,BertOnlyMLMHead其实就是一个二层神经网络,一层是BertPredictionHeadTransform(包括linear+geluAct+ln),一层是decoder(hidden_size*vocab_size大小的linear)。
self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)
# 将输出结果outputs取第一个返回值,也就是last_hidden_state
sequence_output = outputs[0]
# 将last_hidden_state输入到cls层中,获得最终结果(预测的score和词)
prediction_scores = self.cls(sequence_output)

3.3 自动并随机mask数据

关键代码在于DataCollatorForLanguageModeling,该类会实现自动mask。参考torch_mask_tokens方法。

trainer = Trainer(model=model,
                  args=train_args,
                  train_dataset=new_datasets["train"],
                  data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15),
                  )

4 整体代码

"""
基于BERT做完形填空
1)数据集来自:ChnSentiCorp
2)模型权重使用:bert-base-chinese
"""
# step 1 引入数据库
from datasets import DatasetDict
from transformers import TrainingArguments, Trainer, BertTokenizerFast, BertForMaskedLM, DataCollatorForLanguageModeling, pipeline

model_path = "./model/tiansz/bert-base-chinese"
data_path = "./data/ChnSentiCorp"

# step 2 数据集处理
datasets = DatasetDict.load_from_disk(data_path)
tokenizer = BertTokenizerFast.from_pretrained(model_path)


def process_function(datas):
    tokenized_datas = tokenizer(datas["text"], max_length=256, truncation=True)
    return tokenized_datas


new_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)


# step 3 加载模型
model = BertForMaskedLM.from_pretrained(model_path)


# step 4 创建TrainingArguments
# 原先train是9600条数据,batch_size=32,因此每个epoch的step=300
train_args = TrainingArguments(output_dir="./checkpoints",      # 输出文件夹
                               per_device_train_batch_size=32,  # 训练时的batch_size
                               num_train_epochs=1,              # 训练轮数
                               logging_steps=30,                # log 打印的频率
                               )


# step 5 创建Trainer
trainer = Trainer(model=model,
                  args=train_args,
                  train_dataset=new_datasets["train"],
                  # 自动MASK关键所在,通过DataCollatorForLanguageModeling实现自动MASK数据
                  data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15),
                  )


# Step 6 模型训练
trainer.train()

# step 7 模型评估
pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, device=0)
str = datasets["test"][3]["text"]
str = str.replace("方便","[MASK][MASK]")
results = pipe(str)
# results[0][0]["token_str"]
print(results[0][0]["token_str"]+results[1][0]["token_str"])

5 运行效果

在这里插入图片描述

注:本文参考来自大神:https://github.com/zyds/transformers-code

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1714976.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网络故障与排除

一、Router-ID冲突导致OSPF路由环路 路由器收到相同Router-ID的两台设备发送的LSA,所以查看路由表看到的OSPF缺省路由信息就会不断变动。而当C1的缺省路由从C2中学到,C2的缺省路由又从C1中学到时,就形成了路由环路,因此出现路由不…

如何找出真正的交易信号?Anzo Capital昂首资本总结7个

匕首是一种新兴的价格走势形态,虽然不常见,但具有较高的统计可靠性。它通常预示着趋势的持续发展。该模式涉及到同时参考两个不同的时间周期进行交易,一个是短期,另一个是长期,比如一周时间框架与一天时间框架、一天时…

【微机原理及接口技术】可编程计数器/定时器8253

【微机原理及接口技术】可编程计数器/定时器8253 文章目录 【微机原理及接口技术】可编程计数器/定时器8253前言一、8253的内部结构和引脚二、8253的工作方式三、8253的编程总结 前言 本篇文章就8253芯片展开,详细介绍8253的内部结构和引脚,8253的工作方…

JRT性能演示

演示视频 君生我未生,我生君已老,这里是java信创频道JRT,真信创-不糊弄。 基础架构决定上层建筑,和给有些品种的植物种植一样,品种不对,施肥浇水再多,也是不可能长成参天大树的。JRT吸收了各方…

【数据结构:排序算法】堆排序(图文详解)

🎁个人主页:我们的五年 🔍系列专栏:数据结构课程学习 🎉欢迎大家点赞👍评论📝收藏⭐文章 目录 🍩1.大堆和小堆 🍩2.向上调整算法建堆和向下调整算法建堆:…

redis数据类型之Hash,Bitmaps

华子目录 Hash结构图相关命令hexists key fieldhmset key field1 value1 [field2 value2...]hscan key cursor [MATCH pattern] [COUNT count] Bitmaps位图相关命令setbit1. **命令描述**2. **语法**3. **参数限制**4. **内存分配与性能**5. **应用实例**6. **其他相关命令**7.…

6.S081的Lab学习——Lab5: xv6 lazy page allocation

文章目录 前言一、Eliminate allocation from sbrk() (easy)解析: 二、Lazy allocation (moderate)解析: 三、Lazytests and Usertests (moderate)解析: 总结 前言 一个本硕双非的小菜鸡,备战24年秋招。打算尝试6.S081&#xff0…

C++STL容器系列(三)list的详细用法和底层实现

目录 一:介绍二:list的创建和方法创建list方法 三:list的具体用法3.1 push_back、pop_back、push_front、pop_front3.2 insert() 和 erase()3.3 splice 函数 四:list容器底层实现4.1 list 容器节点结构5.2 list容器迭代器的底层实…

磁带存储:“不老的传说”依然在继续

现在是一个数据指数增长的时代,根据IDC数据预测,2025年全世界将产生175ZB的数据。 这里面大部分数据是不需要存储的,在2025预计每年需要存储11ZB的数据。换算个容易理解的说法,1ZB是10^18Bytes, 相当于要写5556万块容量18TB的硬盘…

markdown语法保存

这里写自定义目录标题 欢迎使用Markdown编辑器新的改变功能快捷键合理的创建标题,有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中、居左、居右SmartyPants 创建一个自定义列表如何创建一个…

PaddleOCR2.7+Qt5

章节一:Windows 下的 PIP 安装 官网安装教程地址 按照里面的教程去安装 如果使用cuda版本的还要安装tensorrt,不然后面运行demo程序的程序会报如下错。 下载TensorRT 8版本,tensorrt下载地址 章节二:编译源码 进入官网源码地址 下…

深入解析Web前端三大主流框架:Angular、React和Vue

Web前端三大主流框架分别是Angular、React和Vue。下面我将为您详细介绍这三大框架的特点和使用指南。 Angular 核心概念: 组件(Components): 组件是Angular应用的构建块,每个组件由一个带有装饰器的类、一个HTML模板、一个CSS样式表组成。组件通过输入(@Input)和输出(…

旧手机翻身成为办公利器——PalmDock的介绍也使用

旧手机有吧!!! 破电脑有吧!!! 那恭喜你,这篇文章可能对你有点用了。 介绍 这是一个旧手机废物利用变成工作利器的软件。可以在 Android 手机上快捷打开 windows 上的文件夹、文件、程序、命…

解决文件传输难题:如何绕过Gitee的100MB上传限制

引言 在版本控制和代码托管领域,Gitee作为一个流行的平台,为用户提供了便捷的服务。然而,其对单个文件大小设定的100MB限制有时会造成一些不便。 使用云存储服务 推荐理由: 便捷性:多数云存储服务如: Dro…

构造+模拟,CF1148C. Crazy Diamond

一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 Problem - 1148C - Codeforces 二、解题报告 1、思路分析 题目提示O(5n)的解法了,事实上我们O(3n)就能解决,关键在于1,n的处理 我们读入数据a[],代表初始数组…

(函数)求一元二次方程的根(C语言)

一、运行结果&#xff1b; 二、源代码&#xff1b; # define _CRT_SECURE_NO_WARNINGS # include <stdio.h> # include <math.h>//声明函数&#xff1b; //判断条件等于0时&#xff1b; void zeor(double a, double b);//判断条件大于0时&#xff1b; void bigzeo…

【Linux】Socket中的心跳机制(心跳包)

Socket中的心跳机制(心跳包) 1. 什么是心跳机制&#xff1f;(心跳包) 在客户端和服务端长时间没有相互发送数据的情况下&#xff0c;我们需要一种机制来判断连接是否依然存在。直接发送任何数据包可以实现这一点&#xff0c;但为了效率和简洁&#xff0c;通常发送一个空包&am…

java高级——Collection集合之List探索(包含ArrayList、LinkedList、Vector底层实现及区别,非常详细哦)

java高级——Collection集合之List探索 前情提要文章介绍提前了解的知识点1. 数组2. 单向链表3. 双向链表4. 为什么单向链表使用的较多5. 线程安全和线程不安全的概念 ArrayList介绍1. 继承结构解析1.1 三个标志性接口1.2 AbstractList和AbstractCollection 2. ArrayList底层代…

⌈ 传知代码 ⌋ YOLOv9最新最全代码复现

&#x1f49b;前情提要&#x1f49b; 本文是传知代码平台中的相关前沿知识与技术的分享~ 接下来我们即将进入一个全新的空间&#xff0c;对技术有一个全新的视角~ 本文所涉及所有资源均在传知代码平台可获取 以下的内容一定会让你对AI 赋能时代有一个颠覆性的认识哦&#x…