Transformer 模型实用介绍:BERT

news2024/11/25 4:04:05

动动发财的小手,点个赞吧!

在 NLP 中,Transformer 模型架构是一场革命,极大地增强了理解和生成文本信息的能力。

本教程[1]中,我们将深入研究 BERT(一种著名的基于 Transformer 的模型),并提供一个实践示例来微调基本 BERT 模型以进行情感分析。

BERT简介

BERT 由 Google 研究人员于 2018 年推出,是一种使用 Transformer 架构的强大语言模型。 BERT 突破了早期模型架构(例如 LSTM 和 GRU)单向或顺序双向的界限,同时考虑了过去和未来的上下文。这是由于创新的“注意力机制”,它允许模型在生成表示时权衡句子中单词的重要性。

BERT 模型针对以下两个 NLP 任务进行了预训练:

  • 掩码语言模型 (MLM)

  • 下一句话预测 (NSP)

通常用作各种下游 NLP 任务的基础模型,例如我们将在本教程中介绍的情感分析。

预训练和微调

BERT 的强大之处在于它的两步过程:

  • 预训练是 BERT 在大量数据上进行训练的阶段。因此,它学习预测句子中的屏蔽词(MLM 任务)并预测一个句子是否在另一个句子后面(NSP 任务)。此阶段的输出是一个预训练的 NLP 模型,具有对该语言的通用“理解”
  • 微调是针对特定任务进一步训练预训练的 BERT 模型。该模型使用预先训练的参数进行初始化,并且整个模型在下游任务上进行训练,从而使 BERT 能够根据当前任务的具体情况微调其对语言的理解。

实践:使用 BERT 进行情感分析

完整的代码可作为 GitHub 上的 Jupyter Notebook 获取

在本次实践练习中,我们将在 IMDB 电影评论数据集(许可证:Apache 2.0)上训练情感分析模型,该数据集

会标记评论是正面还是负面。我们还将使用 Hugging Face 的转换器库加载模型。

让我们加载所有库

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, roc_curve, auc
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

# Variables to set the number of epochs and samples
num_epochs = 10
num_samples = 100  # set this to -1 to use all data

首先,我们需要加载数据集和模型标记器。

# Step 1: Load dataset and model tokenizer
dataset = load_dataset('imdb')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

接下来,我们将创建一个绘图来查看正类和负类的分布。

# Data Exploration
train_df = pd.DataFrame(dataset["train"])
sns.countplot(x='label', data=train_df)
plt.title('Class distribution')
plt.show()
alt

接下来,我们通过标记文本来预处理数据集。我们使用 BERT 的标记器,它将文本转换为与 BERT 词汇相对应的标记。

# Step 2: Preprocess the dataset
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)
alt

之后,我们准备训练和评估数据集。请记住,如果您想使用所有数据,可以将 num_samples 变量设置为 -1。

if num_samples == -1:
    small_train_dataset = tokenized_datasets["train"].shuffle(seed=42)
    small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42)
else:
    small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(num_samples)) 
    small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(num_samples)) 

然后,我们加载预训练的 BERT 模型。我们将使用 AutoModelForSequenceClassification 类,这是一个专为分类任务设计的 BERT 模型。

# Step 3: Load pre-trained model
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

现在,我们准备定义训练参数并创建一个 Trainer 实例来训练我们的模型。

# Step 4: Define training arguments
training_args = TrainingArguments("test_trainer", evaluation_strategy="epoch", no_cuda=True, num_train_epochs=num_epochs)

# Step 5: Create Trainer instance and train
trainer = Trainer(
    model=model, args=training_args, train_dataset=small_train_dataset, eval_dataset=small_eval_dataset
)

trainer.train()

结果解释

训练完我们的模型后,让我们对其进行评估。我们将计算混淆矩阵和 ROC 曲线,以了解我们的模型的表现如何。

# Step 6: Evaluation
predictions = trainer.predict(small_eval_dataset)

# Confusion matrix
cm = confusion_matrix(small_eval_dataset['label'], predictions.predictions.argmax(-1))
sns.heatmap(cm, annot=True, fmt='d')
plt.title('Confusion Matrix')
plt.show()

# ROC Curve
fpr, tpr, _ = roc_curve(small_eval_dataset['label'], predictions.predictions[:, 1])
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(1.618 * 55))
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([01], [01], color='navy', lw=2, linestyle='--')
plt.xlim([0.01.0])
plt.ylim([0.01.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic')
plt.legend(loc="lower right")
plt.show()
alt
alt

混淆矩阵详细说明了我们的预测如何与实际标签相匹配,而 ROC 曲线则向我们展示了各种阈值设置下真阳性率(灵敏度)和假阳性率(1 - 特异性)之间的权衡。

最后,为了查看我们的模型的实际效果,让我们用它来推断示例文本的情绪。

# Step 7: Inference on a new sample
sample_text = "This is a fantastic movie. I really enjoyed it."
sample_inputs = tokenizer(sample_text, padding="max_length", truncation=True, max_length=512, return_tensors="pt")

# Move inputs to device (if GPU available)
sample_inputs.to(training_args.device)

# Make prediction
predictions = model(**sample_inputs)
predicted_class = predictions.logits.argmax(-1).item()

if predicted_class == 1:
    print("Positive sentiment")
else:
    print("Negative sentiment")

总结

通过浏览 IMDb 电影评论的情感分析示例,我希望您能够清楚地了解如何将 BERT 应用于现实世界的 NLP 问题。我在此处包含的 Python 代码可以进行调整和扩展,以处理不同的任务和数据集,为更复杂和更准确的语言模型铺平道路。

Reference

[1]

Source: https://towardsdatascience.com/practical-introduction-to-transformer-models-bert-4715ed0deede

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/781097.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

uniapp app运行到ios详细流程

uniapp运行到IOS真机调试(windows系统) 工具步骤1.首先数据线连接电脑和手机2.右键点击桌面上的HBuilder,打开文件所在位置3.打开HBuilder编辑器里要运行的项目,点击运行>运行到手机或模拟器>运行到IOS APP基座>勾选你的…

【Java虚拟机学习2】HotSpot虚拟机下对象的创建及在Java堆中对象的内存分配、布局和对象的访问

HotSpot虚拟机下对象的创建及在Java堆中对象的内存分配、布局和对象的访问 一、对象的创建 Step1:类加载检查 虚拟机遇到一条new指令时,首先将检查是否能在常量池中定位到这个类的符号引用,并且检查这个符号引用代表的类是否已被加载过、解…

【深度学习Week2】卷积神经网络

卷积神经网络 Convolutional Neural Networks,CNN 【第一部分:代码练习】1.MNIST 数据集分类2.CIFAR10 数据集分类3.使用 VGG16 对 CIFAR10 分类 【第二部分:问题总结】 【第一部分:代码练习】 1.MNIST 数据集分类 1.1 加载数据…

STM32入门学习之USART串口通信:

1.串口通信简介:通用异步收发传输器UART(Universal Asynchronous Receiver/Transmitter)是负责处理数据总线和串口之间的串/并通信的设备。UART通信规定了数据帧的格式:起始位、数据位、校验位、停止位等。UART异步通信只需要通信双方设置好数据帧的格式…

html2Canvas+JsPDF 导出pdf 无法显示网络图片

html2CanvasJsPDF 导出pdf 问题:类似于下面着这种网络图片使用img导出的时候是空白的 https://gimg3.baidu.com/search/srchttp%3A%2F%2Fpics4.baidu.com%2Ffeed%2F7e3e6709c93d70cf827fb2fda054500cb8a12bc9.jpeg%40f_auto%3Ftoken%3Dd97d3f0fd06e680e592584f8c7a2…

Devart UniDAC Crack

Devart UniDAC Crack 通用数据访问组件(UniDAC)是一个强大的非可视化跨数据库数据访问组件库,适用于Delphi、Delphi for.NET、CBuilder和Lazarus(Free Pascal)。我们将长期成功开发的经验结合到一个产品中,提供对流行数据库服务器的统一访问,…

Sublime Text 4 激活教程(Windows+Mac)

下载安装 官网 https://www.sublimetext.com 点击跳转 2023.7.21 版本为4143 Windows激活方式 一、激活License方式 入口在菜单栏中"Help” -> “Enter License” 注意格式,可能会过期失效,失效就用方式二 Mifeng User Single User License E…

SUSE宣布推出免费RHEL分叉以保留企业级Linux的选择权

导读在Red Hat宣布将限制AlmaLinuxOS或Rocky Linux等社区发行版对其公共仓库的访问后,最近Red Hat与IBM之间发生了一些争论,有鉴于此,SUSE今天宣布计划为RHEL和CentOS用户提供一个免费的替代方案。 SUSE已经开发了SUSE Linux Enterprise (SLE…

【数据挖掘】PCA/LDA/ICA:A成分分析算法比较

一、说明 在深入研究和比较算法之前,让我们独立回顾一下它们。请注意,本文的目的不是深入解释每种算法,而是比较它们的目标和结果。 如果您想了解更多关于PCA和ZCA之间的区别,请查看我之前基于numpy的帖子: PCA 美白与…

Fatdog64 Linux 814发布

导读Fatdog64 Linux是一个小型、桌面、64位的Linux发行版。 最初是作为Puppy Linux的衍生品,并增加了一些应用程序。该项目最新的版本,Fatdog64 814,是8xx系列的最后一个版本,未来的版本将转向9xx基础。 尽管它是该系列的最后一个…

红黑树概念

这里写目录标题 红黑树概念红黑树的性质红黑树节点的定义红黑树的插入 红黑树概念 红黑树,是一种二叉搜索树,但在每个结点上增加一个存储位表示结点的颜色,可以是Red或Black。 通过对任何一条从根到叶子的路径上各个结点着色方式的限制&…

Docker Compose 解析:定义和管理多容器应用,从多角度探索其优势和应用场景

🌷🍁 博主 libin9iOak带您 Go to New World.✨🍁 🦄 个人主页——libin9iOak的博客🎐 🐳 《面试题大全》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~&#x1f33…

【代码随想录 | Leetcode | 第十天】哈希表 | 三数之和 | 四数之和

前言 欢迎来到小K的Leetcode|代码随想录|专题化专栏,今天将为大家带来哈希法~三数之和 | 四数之和的分享✨ 目录 前言15. 三数之和18. 四数之和总结 15. 三数之和 ✨题目链接点这里 给你一个整数数组 nums ,判断是否存在三元组 [nums[i], nums[j], num…

flask 页面新增文件,存在重复文件时,返回错误消息

(40条消息) flask 读取文件夹文件,展示在页面,可以通过勾选删除_U盘失踪了的博客-CSDN博客 项目结构 这是一个基本的Flask应用程序,主要有两个路由,一个是index,用于显示所有存在的文件以及用于删除已选的文件&#…

C# SolidWorks 二次开发 -从零开始创建一个插件(2)

上一篇我详细讲解了如何创建一个插件,但是无界面无按钮,这种插件适合配合事件偷偷的在后台做点什么事情。今天这篇讲一下如何增加一些按钮到工具栏、菜单上去。 先告诉大家这个东西注册表在哪,因为solidworks在这方面做的不太好,…

七大排序算法和计数排序

文章目录 一、直接插入排序二、希尔排序三、直接选择排序四、堆排序五、冒泡排序六、快速排序6.1递归实现快速排序6.2非递归实现快速排序 七、归并排序7.1递归实现归并排序7.2非递归实现归并排序 八、计数排序 以下排序以从小到大排序为例 一、直接插入排序 时间复杂度&#x…

如何从gitee上下载项目并把它在本地运行起来

有时候我们会想到在gitee上下载下来项目,那么怎么把项目下载到本地并跑起来呢? 第一步:在git上找到你想要克隆下来的项目,按照如下操作复制项目地址连接,如下图: 以上可以选择HTTPS和SSH两种形式。 第二步…

在SPringBoot中整合Mybatis-plus以及mybatis-puls的基本使用

创建SPringBoot项目 1.选择创建项目 2.创建SPringBoot项目 3.选择SPringBoot的版本和依赖 4.导入mysql,druid,mybatis-plus和lombok的依赖,导入后记得更新依赖 <dependency><groupId>com.mysql</groupId><artifactId>mysql-connector-j</artifactId…

Mybatis单元测试,不使用spring

平时开发过程中需要对mybatis的Mapper类做单元测试&#xff0c;主要是验证语法是否正确&#xff0c;尤其是一些复杂的动态sql&#xff0c;一般项目都集成了spring或springboot&#xff0c;当项比较大时&#xff0c;每次单元测试启动相当慢&#xff0c;可能需要好几分钟&#xf…

Mac 四大常用清理软件推荐,软件特色下载教程横向评测

Mac 一般来说基本是不会中毒的&#xff0c;而且像 现在的 windows 也是很少中毒&#xff0c;但我们可能还是需要一款杀毒清理软件&#xff0c;主要是为了清理垃圾&#xff0c;统一查看并管理软件开机自启、权限信息等&#xff0c;统一卸载清理等功能&#xff0c;另外我们可能还…