【Transformers】IMDB 分类

news2025/1/19 20:26:40

安装 transformers 库。

!pip install transformers
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds

from transformers import BertTokenizer
from sklearn.model_selection import train_test_split
from transformers import TFBertForSequenceClassification

import warnings
warnings.filterwarnings('ignore')

加载 IMDB 数据集。

# 加载IMDB数据集
train_dataset = tfds.load(name='imdb_reviews', # 数据集名称
                                split='train', # 切分为训练集
                                as_supervised=True # 返回 (input, label)
                               )
test_dataset = tfds.load(name='imdb_reviews',
                         split='test',
                         as_supervised=True)

设置要调用的预训练模型名称。

BERT 是一种自然语言处理模型,能够理解文本的含义。bert-base-uncasedbert-base-cased 是 BERT 的两个预训练模型。

bert-base-uncased 的 “uncased” 表示它是一个不区分大小写的模型。这意味着在预处理数据中,所有单词都被转换为小写字母。这种模型的优点是在处理文本时能够忽略大小写,从而减少模型需要处理的单词数量,减少计算量。

bert-base-cased 的 “cased” 表示它是一个区分大小写的模型。在预处理数据中,保留了单词的原始大小写形式。这种模型的优点是在处理文本时可以更好地保留单词的信息,因为不同的大小写形式可能具有不同的含义。

在实际应用中,使用哪种模型取决于任务本身需要处理的文本是否区分大小写。例如,如果一个任务需要处理大小写敏感的文本,那么 bert-base-cased 可能是更好的选择。如果一个任务不需要处理大小写敏感的文本,那么 bert-base-uncased 可能是更好的选择。

bert_name = 'bert-base-uncased' # 不区分大小写,Love = love
tokenizer = BertTokenizer.from_pretrained(bert_name,
                                          add_special_tokens=True,
                                          do_lower_case=True,
                                          max_length=150,
                                          pad_to_max_length=True)

这段代码使用 Hugging Face Transformers 库中的 BertTokenizer 类来初始化一个分词器对象。

bert_name 是一个字符串,指定要使用哪个预训练的 BERT 模型。例如,它可以是 “bert-base-uncased”,表示使用基础的小写版本的 BERT。

add_special_tokens 是一个布尔值,表示是否在输入序列中添加特殊标记。这些标记包括在序列开头的 [CLS] 标记和在序列中的句子对之间的 [SEP] 标记。

do_lower_case 是一个布尔值,表示是否将所有输入文本转换为小写。

max_length 是一个整数,指定经过分词后的输入序列的最大长度。如果序列长度超过这个值,它将被截断。如果长度不够,分词器将使用特殊标记来填充序列。

pad_to_max_length 是一个布尔值,表示是否将输入序列填充到 max_length 指定的最大长度。如果为 True,则分词器将使用特殊标记来填充序列,使其长度与 max_length 相同。如果为 False,则分词器将不填充序列,并返回长度小于 max_length 的序列。

总的来说,这段代码设置了一个分词器对象,用于将输入文本标记化为可以喂入 BERT 模型进行文本分类或其他自然语言处理任务的格式。

# tokenizer的返回结果
tokenizer.encode_plus("The Chinese New Year is coming.",
                      add_special_tokens=True,
                      max_length=15,
                      pad_to_max_length=True,
                      return_attention_mask=True,
                      return_token_type_ids=True)

这段代码使用了 encode_plus 函数来对给定的文本进行编码, 生成可用于输入到预训练模型的 tokens 和其他信息。

具体来说,它的参数如下:

"The Chinese New Year is coming.": 给定的文本,需要进行编码的文本,当然这只是一个示例啦。

add_special_tokens=True: 是否在 tokens 中添加特殊 tokens(例如,[CLS], [SEP])以供模型使用。

max_length=15: 编码后 tokens 的最大长度限制。如果 tokens 的长度不合法,将进行截断或填充。

pad_to_max_length=True: 是否在 tokens 末尾填充0以使它们具有相同的长度。

return_attention_mask=True: 是否返回 attention mask。这个 mask 指示哪些 tokens 是实际输入和哪些tokens是填充的。

return_token_type_ids=True: 是否返回 token type IDs。这个 ID 指示 tokens 属于哪个句子,对于单句输入,它将为0。

{
'input_ids': [101, 1996, 2822, 2047, 2095, 2003, 2746, 1012, 102, 0, 0, 0, 0, 0, 0],
'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 
'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
}

此函数返回一个字典,其中包含以下键值对:

"input_ids": 编码后的 token 的列表,每个 token 都由一个整数表示。

"attention_mask": 一个由0和1组成的列表,用于指示哪些 tokens 是实际输入和哪些 tokens 是填充的。

"token_type_ids": 一个由0和1组成的列表,用于指示 tokens 属于哪个句子,对于单句输入,它将为0。

BERT 在处理文本时会对标点符号进行编码。

在 BERT 中,标点符号被视为一个独立的单词,与其他单词一样被嵌入到向量空间中。这意味着BERT在处理文本时不仅考虑了单词本身的含义,还考虑了它们之间的关系和上下文。
例如,在句子中使用逗号或句号可能会改变整个句子的含义,因此 BERT 会将这些标点符号编码为一个单独的词,并考虑它们在上下文中的作用。这种方法有助于提高 BERT 在处理自然语言处理任务中的性能,因为它能够更好地捕捉文本中的复杂性和细微差别。

特殊意义的编码。

[CLS] = 101[SEP] = 102[PAD] = 0

def bert_encode(text):
    text = text.numpy().decode('utf-8') 
    encode_result = tokenizer.encode_plus(text,
                                          add_special_tokens=True,
                                          max_length=150,
                                          pad_to_max_length=True,
                                          return_attention_mask=True,
                                          return_token_type_ids=True)
    input_ids = encode_result['input_ids']
    token_type_ids = encode_result['token_type_ids']
    attention_mask = encode_result['attention_mask']
    
    return input_ids, token_type_ids, attention_mask

text = text.numpy().decode('utf-8')

输入的 Tensor 张量对象是一个包含了 UTF-8 编码的文本字符串的张量。
首先,它使用 numpy() 方法将张量对象转换为 NumPy 数组对象。
然后,使用 NumPy 数组对象的 decode() 方法将 UTF-8 编码的字符串转换为 Python 字符串。

# 对数据集进行转换
bert_encode_train = [bert_encode(text) for text, label in train_dataset]
bert_encode_label = [label for text, label in train_dataset]

bert_encode_train = np.array(bert_encode_train)  # 类型转换 tensor -> array
bert_encode_label = tf.keras.utils.to_categorical(bert_encode_label, num_classes=2)  # 标签类型转换

划分数据集。

X_train, X_val, y_train, y_val = train_test_split(bert_encode_train,
                                                  bert_encode_label,
                                                  test_size=0.2,
                                                  random_state=520)

X_trainX_val 分为三部分。

train_inputs_ids, train_token_type_ids, train_attention_masks = np.split(X_train, 3, axis=1)  # 拆分

val_inputs_ids, val_token_type_ids, val_attention_masks = np.split(X_val, 3, axis=1)

减掉多余的1维。

train_inputs_ids = train_inputs_ids.squeeze()
train_token_type_ids = train_token_type_ids.squeeze()
train_attention_masks = train_attention_masks.squeeze()

val_inputs_ids = val_inputs_ids.squeeze()
val_token_type_ids = val_token_type_ids.squeeze()
val_attention_masks = val_attention_masks.squeeze()

定义构建训练和验证批数据的函数combine_dataset

def combine_dataset(input_ids, token_type_ids, attention_mask, label):
    data_format = {'input_ids' : input_ids,
                   'token_type_ids' : token_type_ids,
                   'attention_mask' : attention_mask}
    return data_format, label
# 训练批数据
train_ds = tf.data.Dataset.from_tensor_slices((train_inputs_ids,
                                               train_token_type_ids,
                                               train_attention_masks,
                                               y_train)).map(combine_dataset).shuffle(100).batch(16)

这段代码创建了一个用于训练机器学习模型的数据集(train_ds)。该数据集包含三个输入张量 train_inputs_ids、train_token_type_ids、train_attention_masks 和一个目标变量 y_train。

首先,tf.data.Dataset.from_tensor_slices() 方法将三个输入张量与目标变量 y_train 沿着第一个维度进行切片,形成了多个元组,每个元组包含四个分别对应的张量切片。这样每个元组就对应着一个样本。

然后,map() 方法将 combine_dataset 函数映射到数据集的每个元素上。combine_dataset 函数将四个输入张量打包成一个字典,键为字符串 ‘input_ids’、‘token_type_ids’ 和 ‘attention_mask’,值为对应的张量。这样,每个元组就被转换为了一个字典,其中包含了对应的输入数据和目标变量。

接下来,shuffle() 方法将数据集中的元素进行随机洗牌,以增加模型的泛化能力。参数 100 指定了洗牌时所使用的缓冲区大小,即每次随机选取的元素数量。

最后,batch() 方法将数据集中的元素按照指定的批大小进行分组,每个批包含了指定数量的样本。这里的批大小为 16,即每个批次包含了 16 个样本。这样,整个数据集被分成了若干个批次,每次训练模型时只需要处理一个批次的数据,从而减少了内存占用和计算时间。

# 验证批数据
val_ds = tf.data.Dataset.from_tensor_slices((val_inputs_ids,
                                             val_token_type_ids,
                                             val_attention_masks,
                                             y_val)).map(combine_dataset).shuffle(100).batch(16)

加载模型。

from transformers import TFBertForSequenceClassification

bert_model = TFBertForSequenceClassification.from_pretrained(bert_name)

注意是 TFBertForSequenceClassification ,可不敢写成 TFBartForSequenceClassification

我就犯了这个错,因为确实存在 BART 这个模型,细心一点,避免不必要的麻烦。

# 定义优化器

optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5)

# 定义损失函数

loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# 模型编译

bert_model.compile(optimizer=optimizer,
                   loss=loss,
                   metrics=['accuracy'])
bert_model.summary()
history = bert_model.fit(train_ds,
                         epochs=1,
                         validation_data=val_ds)

BERT 模型很大哈~,我垃圾电脑跑不起来,就用 Colab 训练了一轮,结果如下:

在测试集上进行测试,和训练集相同的预处理操作。

bert_test = [bert_encode(text) for text ,label in test_dataset]
bert_test_label = [label for text, label in test_dataset]

bert_test = np.array(bert_test)
bert_test_label = tf.keras.utils.to_categorical(bert_test_label, num_classes=2)

test_inputs_ids, test_token_type_ids, test_attention_masks = np.split(bert_test, 3, axis=1)  # 拆分

test_inputs_ids = test_inputs_ids.squeeze()
test_token_type_ids = test_token_type_ids.squeeze()
test_attention_masks = test_attention_masks.squeeze()

test_ds = tf.data.Dataset.from_tensor_slices((test_inputs_ids,
                                              test_token_type_ids,
                                              test_attention_masks,
                                              bert_test_label)).map(combine_dataset).shuffle(100).batch(16)
# 模型测试

bert_model.evaluate(test_ds)

只训练了一轮哦,最后效果还挺不错的哈。




在你想要放弃的那一刻,想想为什么当初坚持了那么久。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/398001.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

IO学习、拓展贴

1. 字节流 1.1 FileInputStream import org.junit.jupiter.api.Test;import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.IOException;/*** 演示FileInputStream的使用(字节输入流 文件--> 程序)*/ public class FileInputStream_ {pu…

10款最佳项目管理工具推荐,总有一款适合你

为什么需要项目管理工具? 如今企业规模不断扩大,业务逐渐复杂化,项目管理已经成为现代企业管理中不可或缺的一环; 如果没有合适的项目管理工具,我们的项目管理和跟踪就会变得非常困难。这可能导致项目延期或者出现一…

免费Api接口汇总(亲测可用,可写项目)

免费Api接口汇总(亲测可用)1. 聚合数据2. 用友API3. 天行数据4. Free Api5. 购物商城6. 网易云音乐API7. 疫情API8. 免费Api合集1. 聚合数据 https://www.juhe.cn/ 2. 用友API http://iwenwiki.com/wapicovid19/ 3. 天行数据 https://www.tianapi.com…

RK356x U-Boot研究所(命令篇)3.9 scsi命令的用法

平台U-Boot 版本Linux SDK 版本RK356x2017.09v1.2.3文章目录 一、设备树与config配置二、scsi命令的定义三、scsi命令的用法3.1 scsi总线扫描3.2 scsi设备读写一、设备树与config配置 RK3568支持SATA接口,例如ROC-RK3568-PC: 原理图如下: 可以新建一个rk3568-sata.config配…

Oracle listagg,wm_concat函数行转列结果去重Oracle 11g/19c版本

1、准备数据表 2、根据学生名(stu_name)分组,学生名相同的,学生年龄(stu_age)用逗号拼接,使用 listagg()函数法拼接 3、上图中出现了两个12,12,实现去重 3.1 listagg() 函数 去重 【…

网络协议(十一):单向散列函数、对称加密、非对称加密、混合密码系统、数字签名、证书

网络协议系列文章 网络协议(一):基本概念、计算机之间的连接方式 网络协议(二):MAC地址、IP地址、子网掩码、子网和超网 网络协议(三):路由器原理及数据包传输过程 网络协议(四):网络分类、ISP、上网方式、公网私网、NAT 网络…

怎么把tif格式转成jpg?快速无损转换

怎么把tif格式转成jpg?在编辑使用图片的时候,弄清各种图片格式的特点是很重要的,因为图片总因自身格式具备的特点不同常常出现打不开的情况,或者占的体积大,这都会直接影响我们的使用。所以目前很多的图片格式都需要提…

spring boot整合RabbitMQ

文章目录 目录 文章目录 前言 一、环境准备 二、使用步骤 2.1 RabbitMQ高级特性 2.1.1 消息的可靠性传递 2.1.2 Consumer Ack 2.2.3 TTL 2.2.4 死信队列 总结 前言 一、环境准备 引入依赖生产者和消费都引入这个依赖 <dependency><groupId>org.springframework…

自动化测试总结--断言

采购对账测试业务流程中&#xff0c;其中一个测试步骤总是失败&#xff0c;原因是用例中参数写错及断言不明确 一、问题现象&#xff1a; 采购对账主流程中&#xff0c;其中一个步骤失败了&#xff0c;会导致这个套件一直失败 图&#xff08;1&#xff09;测试报告视图中&…

Navicate远程连接Linux上docker安装的MySQL容器

Navicate远程连接Linux上docker安装的MySQL容器失败 来自&#xff1a;https://bluebeastmight.github.io/ 问题描述&#xff1a;windows端的navicat远程连接不上Linux上docker安装的mysql&#xff08;5.7版本&#xff09;容器&#xff0c;错误代码10060 标注&#xff1a; 1、…

XSS攻击防御

XSS攻击防御XSS Filter过滤方法输入验证数据净化输出编码过滤方法Web安全编码规范XSS Filter XSS Filter的作用是通过正则的方式对用户&#xff08;客户端&#xff09;请求的参数做脚本的过滤&#xff0c;从而达到防范XSS攻击的效果。 XSS Filter作为防御跨站攻击的主要手段之…

C++ Primer阅读笔记--书包程序

1--该章节新知识点 ① 在 UNIX 和 Windows 系统中&#xff0c;执行完一个程序后&#xff0c;可以通过 echo 命令获得其返回值&#xff1b; # UNIX系统中&#xff0c;通过如下命令获得状态 echo $? ② 在标准库中&#xff0c;定义了两个输出流 ostream 对象&#xff1a;cerr…

运维效率狂飙,都在告警管理上

随着数字化进程的加速&#xff0c;企业IT设备和系统越来越多&#xff0c;告警和流程中断风险也随之增加。每套系统和工具发出的警报&#xff0c;听起来像是一场喧嚣的聚会&#xff0c;各自谈论不同的话题。更糟糕的是&#xff0c;安全和运维团队正在逐渐丧失对告警的敏感度&…

2.Fully Convolutional Networks for Semantic Segmentation论文记录

欢迎访问个人网络日志&#x1f339;&#x1f339;知行空间&#x1f339;&#x1f339; 文章目录1.基础介绍2.分类网络转换成全卷积分割网络3.转置卷积进行上采样4.特征融合5.一个pytorch源码实现参考资料1.基础介绍 论文:Fully Convolutional Networks for Semantic Segmentati…

如何用postman实现接口自动化测试

postman使用 开发中经常用postman来测试接口&#xff0c;一个简单的注册接口用postman测试&#xff1a; 接口正常工作只是最基本的要求&#xff0c;经常要评估接口性能&#xff0c;进行压力测试。 postman进行简单压力测试 下面是压测数据源&#xff0c;支持json和csv两个格…

Java反序列化漏洞——jdbc反序列化漏洞利用

漏洞原理如果攻击者能够控制JDBC连接设置项&#xff0c;那么就可以通过设置其指向恶意MySQL服务器进行ObjectInputStream.readObject()的反序列化攻击从而RCE。具体点说&#xff0c;就是通过JDBC连接MySQL服务端时&#xff0c;会有几个内置的SQL查询语句要执行&#xff0c;其中…

汽车用CAN通讯接口简介

随着新能源的普及,汽车用的芯片数量也越来越多,汽车在进行新四化(电动化、网联化、智能化、共享化),Gateway整车控制中心、TBox网联设备、IVI智能座舱、智驾域控制器等等ECU变得更智能,车控指令和车内通信变得更加丰富。车内ECU通讯比如CAN、LIN、蓝牙还有人提出高速以太…

pyflink学习笔记(四):datastream_api

现pyflink环境为1.16 &#xff0c;下面介绍下常用的datastream算子。现我整理的都是简单的、常用的&#xff0c;后期会继续补充。官网&#xff1a;https://nightlies.apache.org/flink/flink-docs-release-1.16/docs/dev/python/datastream/intro_to_datastream_api/from pyfli…

面向新时代,海泰方圆战略升级!“1465”隆重发布!

过去四年&#xff0c;海泰方圆“1344”战略一直在引领公司前行&#xff0c;搭建了非常坚实的战略框架基座&#xff0c;并推动全员在实践和行动中达成深度共识。 “1344”战略 1个定位&#xff0c;代表着当前机构用户的一组共性需求&#xff0c;密码安全数据治理信创工程。 3…

【项目精选】基于JAVA的私人牙科诊所管理系统(视频+论文+源码)

点击下载源码 摘要 随着科技的飞速发展&#xff0c;计算机已经广泛的应用于各个领域之中。在医学领域中&#xff0c;计算机主要应用于两个方面&#xff1a;一是医疗设备智能化&#xff0c;以硬件为主。另一种是病例信息管理系统&#xff08;HIS&#xff09;以软件建设为主&…