NLP实验-基于预训练模型的文本分类

news2024/12/24 22:45:55

使用BERT及其变体实现AclImdb情感分类

    • 前言
      • 数据集介绍
      • 【Hugging Face】使用方法和如何挑选一个自己需要的模型
    • 基于BERT预训练模型的本文分类
      • 数据预处理
      • 载入文本标记器
      • 将数据转化为模型可以接受的格式
      • 训练模型
      • 加载模型
    • 基于RoBerta预训练模型的文本分类
    • 基于DeBerta预训练模型的文本分类
    • 全部代码链接

2024.05.05 17:35
实现基于预训练模型的文本分类任务,使用三种不同的预训练模型,并对比分类准确率。

前言

数据集介绍

数据集来源:链接
数据集简要概述:该数据集包含电影评论及情感极性标签,其可用于作为一个基准情绪的分类。
共50000条评论,分为25k训练集,25k测试集,25k正面评论,25k负面评论,还有50000条未标记数据用于无监督学习。
数据集结构如下:

|- test
|-- neg
|-- pos
|- train
|-- neg
|-- pos

除此之外,每个电影的评论不超过30条,因为多了会存在相关性。
负面评价≤4分,证明评价≥7分,满分10分。

【Hugging Face】使用方法和如何挑选一个自己需要的模型

参考文章
模型名称解读:在Hugging Face上,模型名称通常是对模型架构、训练数据和任务的一种描述。这些模型名称通常包含了一些关键信息,帮助用户理解模型的基本特征。

【例子】

  1. “bert-base-uncased”
    这个模型名称中的"bert"代表了模型架构为BERT(Bidirectional Encoder Representations from Transformers)。
    "base"表示这是基本版的模型,通常是指相对较小的模型规模。
    "uncased"表示这个模型是在训练数据中将所有文本转换为小写处理的,没有区分大小写。
    "bert-base-uncased"表示了一个基于BERT架构的小型模型,适用于不区分大小写的任务。

  2. “gpt2-medium”
    这个模型名称中的"gpt2"代表了模型架构为GPT-2(Generative Pre-trained Transformer 2)。
    "medium"表示这是GPT-2模型系列中的中型规模模型。
    "gpt2-medium"表示了一个中等规模的GPT-2模型。

  3. “roberta-large”
    这个模型名称中的"roberta"代表了模型架构为RoBERTa(Robustly Optimized BERT approach)。
    "large"表示这是RoBERTa模型系列中的大型规模模型。
    "roberta-large"表示了一个大型的RoBERTa模型。

  4. “distilbert-base-uncased-finetuned-sst-2-english”
    这个模型名称解释了一些特定的信息。"distilbert"指的是经过蒸馏(distillation)处理的BERT模型,特点是具有较小的模型规模和更快的推理速度。
    "base"和"uncased"与之前提到的意义相同。
    "finetuned-sst-2"表示这个模型是在SST-2(斯坦福情感树库)数据集上进行了微调(Fine-tuning)以用于情感分类任务。
    "english"表示这个模型是为英语任务预训练和微调而创建的。

  5. “t5-base”
    这个模型名称中的"t5"是指T5(Text-to-Text Transfer Transformer)模型,这是一种基于Transformer架构的文本生成模型。
    "base"与之前提到的意义相同,表示模型的基本版本。

  6. “facebook/wmt19-mu-en-1024”
    这个模型名称指的是Facebook团队针对WMT19 Multilingual Translation任务训练的英语-多语言(mu)翻译模型。
    "-en"表示英语作为源语言
    "1024"表示模型的隐藏状态大小为1024。

  7. TheBloke/Llama-2-13B-chat-GGML
    “TheBloke”:这部分可能是指该模型的创建者、团队或者用户名。
    “Llama-2-13B”:这部分可能是指模型的架构、版本或系列。它可能是从较早版本的Llama模型发展而来,或者是在Llama模型系列中的第二个版本。 “2-13B"可能指的是模型参数和规模,表明该模型具有130亿个参数。
    “chat”:这部分可能指出该模型是专门用于对话或聊天任务的。这种指明任务类型的信息有助于用户了解模型的适用性。
    GGML”:这部分可能是指模型的训练或微调框架、方法或技术。

  8. stabilityai/sd-vae-ft-mse-original
    “stabilityai”:这部分可能是指模型的创建者、提供者或组织名称。它可能代表一个名为 “stabilityai” 的实体或团队。
    “sd-vae-ft-mse-original”:这部分可能提供了关于模型的其他关键信息。例如,“sd-vae” 可能表示变分自动编码器(VAE)的一种改进或特定类型。“ft” 可能是指模型进行了微调(fine-tuning)。“mse-original” 可能是指在模型训练过程中使用了均方误差(Mean Squared Error)作为损失函数或评价指标。

基于BERT预训练模型的本文分类

本文使用 https://huggingface.co/models 中的bert-base-uncased预训练模型进行实战。
首先是模型参数的下载,进入huggingface网址,搜索bert,选择下图中第一个选项,
bert模型参数
点击下载按钮,下载下图中框起来的文件到本地,文件夹命名为’bert-base-uncased‘。
在这里插入图片描述

数据预处理

要想使用Trainer进行训练,需要将数据调整到一定的规范,以下展示使用预训练模型对应的文本标记器(tokenizer)和datasets库处理原始数据。

"token"是什么?:在AI领域,token指文本或代码的最小单元,可以理解为单词或字符的更高级表示,token将文本分解成有意义的片段,例如单词、词根或标点符号,这些片段作为模型的输入,帮助模型理解和生成内容。例如:“我喜欢人工智能”这句话可以被分解成“我”、“喜欢”、“人工智能”三个token,每个token都代表一个独立的语义概念,token的长度和类型取决于具体的模型和应用,有些模型使用单个字符作为token,而有些模型则使用更长的词组或短语作为token。总而言之,token是AIGC模型处理和生成内容的基本单位,对于理解AIGC的工作原理至关重要。

from datasets import Dataset
from transformers import BertTokenizer
import os

# 载入原始数据
def load_data(base_path):
    paths = os.listdir(base_path)
    result = []
    for path in paths:
        with open(os.path.join(base_path, path), 'r', encoding='utf-8') as f:
            result.append(f.readline())
    return result

# 读入数据并转化为datasets.Dataset
def get_dataset(base_path):
		# 为了展示方便,这里只取前3个数据,真实使用需要删掉切片操作
    pos_data = load_data(os.path.join(base_path, 'pos'))[:3]
    neg_data = load_data(os.path.join(base_path, 'neg'))[:3]
    
		# 列表合并
    texts = pos_data + neg_data
		# 生成标签,其中使用 '1.' 和 '0.' 是因为需要转化为浮点数,要不然模型训练时会报错
    labels = [[1., 0.]]*len(pos_data) + [[0., 1.]] * len(neg_data)
    dataset = Dataset.from_dict({'texts':texts, 'labels':labels})
    return dataset

# 加载数据
train_dataset = get_dataset('../data/aclImdb/train/')
test_dataset = get_dataset('../data/aclImdb/test/')

# 可查看数据集结构、标签、特征等
print(train_dataset)
print(train_dataset['labels'])
print(train_dataset.features)

载入文本标记器

# cache_dir是预训练模型的地址
cache_dir="../transformersModels/bert-base-uncased/"
tokenizer = BertTokenizer.from_pretrained(cache_dir)

-注意: 这个路径的模型要自己下载,不能是transformer包下的,要不会报错。

将数据转化为模型可以接受的格式

# 设置最大长度
MAX_LENGTH = 512

# 使用文本标记器对texts进行编码
train_dataset = train_dataset.map(lambda e: tokenizer(e['texts'], truncation=True, padding='max_length', max_length=MAX_LENGTH), batched=True)
test_dataset = test_dataset.map(lambda e: tokenizer(e['texts'], truncation=True, padding='max_length', max_length=MAX_LENGTH), batched=True)

# 保存处理好的数据到本地 
# 在数据量大的时候,处理数据需要很长的时间,为了不每次都重新处理数据,可以将数据先存到本地
train_dataset.save_to_disk('./data/train_dataset')
test_dataset.save_to_disk('./data/test_dataset')

‘texts’, ‘labels’, ‘input_ids’, ‘token_type_ids’, ‘attention_mask’

训练模型

from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments, BertConfig
import torch
from datasets import Dataset
import json
import os
# 设定使用的GPU编号,也可以不设置,但trainer会默认使用多GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# 将num_labels设置为2,因为我们训练的任务为2分类
model = BertForSequenceClassification.from_pretrained('../transformersModels/bert-base-uncased/', num_labels=2)

# 加载处理好的数据
train_dataset = Dataset.load_from_disk('./data/train_dataset/')
test_dataset = Dataset.load_from_disk('./data/test_dataset/')

# 冻结BERT参数
'''
因为BERT是预训练模型,因此可以不再进行权重更新,只对尾部的分类器进行优化。
与此同时,这个设置也会减少训练时使用的时间和显存。
'''
for param in model.base_model.parameters():
    param.requires_grad = False

# 创建trainer
# 训练超参配置
training_args = TrainingArguments(
    output_dir='./my_results',          # output directory 结果输出地址
    num_train_epochs=10,              # total # of training epochs 训练总批次
    per_device_train_batch_size=128,  # batch size per device during training 训练批大小
    per_device_eval_batch_size=128,   # batch size for evaluation 评估批大小
    logging_dir='./my_logs',            # directory for storing logs 日志存储位置
)

# 创建Trainer
trainer = Trainer(
    model=model.to('cuda'),              # the instantiated 🤗 Transformers model to be trained 需要训练的模型
    args=training_args,                  # training arguments, defined above 训练参数
    train_dataset=train_dataset,         # training dataset 训练集
    eval_dataset=test_dataset,           # evaluation dataset 测试集
)

# 训练、评估和保存模型
# 开始训练
trainer.train()

# 开始评估模型
trainer.evaluate()

# 保存模型 会保存到配置的output_dir处
trainer.save_model()
torch.save(model.state_dict(), 'model_save.bin')

这里保存模型参数代码不同可以看:链接

保存模型会生成三个文件:

# 模型配置文件
config.json

# 模型数据文件
model_save.bin

# 训练配置文件
training_args.bin

加载模型

output_config_file = './my_results/config.json'
output_model_file = './my_results/model_save.bin'

config = BertConfig.from_json_file(output_config_file)
model = BertForSequenceClassification(config)
state_dict = torch.load(output_model_file)
model.load_state_dict(state_dict)

cache_dir="../transformersModels/bert-base-uncased/"
tokenizer = BertTokenizer.from_pretrained(cache_dir)
data = tokenizer(['This is a good movie', 'This is a bad movie'], max_length=512, truncation=True, padding='max_length', return_tensors="pt")
print(model(**data))

输出结果:
SequenceClassifierOutput(
loss=None, logits=tensor([
[-0.2951,  0.5463],
[-0.4638,  0.6353]], 
grad_fn=<AddmmBackward0>), 
hidden_states=None, 
attentions=None)

由于只用3条数据训练了10轮,因此结果很差,正常训练结果可以变好了。

参考文章:Transformers实战——使用本地数据进行AclImdb情感分类

基于RoBerta预训练模型的文本分类

将Bert模型中的部分代码进行修改

注:
from transformers import RobertaModelfrom transformers import RobertaForSequenceClassification 之间的区别在于它们所代表的模型的不同。

  • RobertaModel 是 Hugging Face Transformers 库中的一个类,它表示了 RoBERTa 模型的基本架构。RobertaModel 只提供了预训练的 RoBERTa 模型的基本功能,例如输入编码、注意力机制等,但不包含用于特定任务(如分类)的额外层。

  • RobertaForSequenceClassificationRobertaModel 的一个派生类,它专门用于进行序列分类任务。RobertaForSequenceClassificationRobertaModel 的基础上添加了一个用于分类的线性层(linear layer),该层接收 RoBERTa 的输出并生成分类预测。这使得 RobertaForSequenceClassification 在进行分类任务时更加方便和高效。

因此,当你想要使用 RoBERTa 模型进行序列分类任务时,推荐使用 RobertaForSequenceClassification。如果你只需要 RoBERTa 模型的基本功能,而不涉及特定的任务,那么使用 RobertaModel 就足够了。

from transformers import RobertaTokenizer, RobertaModel, RobertaConfig

tokenizer = RobertaTokenizer.from_pretrained("pretrained_model/roberta_base/")
config = RobertaConfig.from_pretrained("pretrained_model/roberta_base/")
model = RobertaModel.from_pretrained("pretrained_model/roberta_base/")

基于DeBerta预训练模型的文本分类

ValueError: Couldn't instantiate the backend tokenizer from one of: 
(1) a `tokenizers` library serialization file, 
(2) a slow tokenizer instance to convert or 
(3) an equivalent slow tokenizer class to instantiate and convert. 
You need to have sentencepiece installed to convert a slow tokenizer to a fast one.

pip install sentencepiece
之后报错

TypeError: stat: path should be string, bytes, os.PathLike or integer, not NoneType
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")

看到某个博客下写的解决方案:
就是把你在huggingface上参考的那个模型,models下面其他相关文件别管有用没用全传上去,我之前就传了两,报这个错
试了下,成功!

全部代码链接

链接如下,欢迎star

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2043886.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用STM32定时器的PWM功能控制电机

目录 概述 1 系统框架结构 1.1 框架结构介绍 1.2 STM32 Cube配置PWM参数 2 软件实现 2.1 STM32Cube生成项目 2.2 PWM功能的User函数接口 3 测试 3.1 编写测试函数 3.2 功能测试 概述 本文主要介绍使用STM32定时器TIMER-8功能生成4路PWM&#xff0c;用于控制两路电机…

五种Slowing Changing Dimensions(SCD)方法及案例

SCD Type Description Key Features Type 1 Overwriting the existing data with new data, without keeping any history of the previous values. 直接覆盖&#xff0c;不留痕迹 - Overwrites Existing Data - No Historical Data - Simple Implementation Type…

Composerize神器:自动化转换Docker运行命令至Compose配置,简化容器部署流程

Composerize神器&#xff1a;自动化转换Docker运行命令至Compose配置&#xff0c;简化容器部署流程 在现代的微服务架构中&#xff0c;Docker Compose 是管理多容器应用的重要工具&#xff0c;它允许我们通过一个简单的 docker-compose.yml 文件来定义和运行多个关联的容器。然…

重发布实验

一、实验拓扑图 二、实验需求 1.如图搭建网络拓扑&#xff0c;所有路由器各自创建一个环回接口&#xff0c;合理规划IP地址 2.R1-R2-R3-R4-R6之间使用OSPF协议&#xff0c;R4-R5-R6之间使用RIP协议 3.R1环回重发布方式引入OSPF网络 4.R4/R6上进行双点双向重发布 5.分析网络…

CSP-CCF 202012-1 期末预测之安全指数

一、问题描述 二、解答 #include<iostream> using namespace std; int main() {int n;cin >> n;int w[100001] { 0 };int score[100001] { 0 };for (int i 1; i < n; i){cin >> w[i] >> score[i];}int y 0;for (int i 1; i < n; i){y y …

Java——反射(2/4):获取构造器对象并使用(获取类的构造器、并对其进行操作,获取类构造器的作用,代码实例)

目录 获取类的构造器 获取类的构造器、并对其进行操作 代码实例一 代码实例二 获取类构造器的作用 代码示例三 获取类的构造器 获取类的构造器、并对其进行操作 Class提供了从类中获取构造器的方法。 方法说明Constructor<?>[] getConstructors()获取全部构造器…

激光测距传感器

系列文章目录 1.元件基础 2.电路设计 3.PCB设计 4.元件焊接 5.板子调试 6.程序设计 7.算法学习 8.编写exe 9.检测标准 10.项目举例 11.职业规划 文章目录 前言一、产品原理&#xff1a;二、产品介绍&#xff1a;三、应用特点四、应用案例&#xff1a;1.冶金钢铁板卷材开卷工…

深入理解JavaScript性能优化:从基础到高级

引言 在当今快速发展的Web世界中,性能已经成为衡量应用质量的关键指标。随着Web应用复杂度的不断提升,JavaScript作为前端开发的核心语言,其性能优化变得尤为重要。本文旨在全面深入地探讨JavaScript性能优化的各个方面,从基础概念到高级技巧,帮助开发者构建高效、流畅的Web应用…

Android Studio本地加速安装gradle

Android Studio本地加速安装gradle 镜像下载依赖本地JAVA-JDK配置阿里云镜像配置环境变量验证gradle项目文件的介绍项目配置gradle项目Gradle-Wrapper加速配置&#xff0c;防止下载失败Gradle的常用命令 镜像下载 腾讯软件镜像源&#xff1a;https://mirrors.cloud.tencent.co…

50ETF期权移仓是什么?50ETF期权移仓要注意什么?

今天带你了解50ETF期权移仓是什么&#xff1f;50ETF期权移仓要注意什么&#xff1f;当前火热的期权交易市场&#xff0c;“移仓”同样是一门非常重要的技术。上证50ETF期权投资的过程中&#xff0c;我们可以进行一定的移仓操作的&#xff0c;如果移仓操作得好&#xff0c;可以很…

CSP-CCF 202104-1 灰度直方图

一、问题描述 二、解答 思路&#xff1a;用一个二维数组和一个一维数组、以及三个嵌套的for循环即可 代码&#xff1a; #include<iostream> using namespace std; int A[500][500] { 0 }; int main() {int n, m, L;cin >> n >> m >> L;int h[256] …

CocoaPods 官宣进入维护模式,不在积极开发新功能,未来将是 Swift Package Manager 的时代

昨天 CocoaPods 官宣现在项目**处于维护模式 **&#xff0c;简单来说&#xff0c;就是 CocoaPods 不会再像以前一样积极投入资源进行开发&#xff0c;这里的维护模式&#xff0c;就是让项目处于「可用」的状态&#xff0c;而此时距离 CocoaPods 的出现&#xff0c;也过去了有 1…

一套完整的NVR网络硬盘录像机解决方案和NVR程序源码介绍

随着网络技术的发展&#xff0c;视频数据存储的需求激增&#xff0c;促使硬盘录像机&#xff08;DVR&#xff09;逐渐演变为具备网络功能的网络视频录像机&#xff08;NVR&#xff09;。NVR&#xff0c;即网络视频录像机&#xff0c;负责网络视音频信号的接入、存储、转发、解码…

鸿蒙开发入门day05-ArkTs语言(接口与关键字)

(创作不易&#xff0c;感谢有你&#xff0c;你的支持&#xff0c;就是我前行的最大动力&#xff0c;如果看完对你有帮助&#xff0c;还请三连支持一波哇ヾ(&#xff20;^∇^&#xff20;)ノ&#xff09; 目录 ArkTS语言介绍 接口 接口属性 接口继承 泛型类型和函数 泛型…

Unity(2022.3.38LTS) - 变换组件和约束

目录 一. 变换组件 二. 约束 一. 变换组件 在 Unity 中&#xff0c;变换组件&#xff08;Transform Component&#xff09;是每个游戏对象都必备的组件&#xff0c;用于控制对象在场景中的位置、旋转和缩放。 位置&#xff08;Position&#xff09;&#xff1a; 表示对象在…

opencv-python实战项目十:二维码识别

文章目录 一&#xff1a;简介二&#xff1a;opencv二维码识别流程三&#xff1a;整体代码四&#xff1a;效果 一&#xff1a;简介 二维码识别是一种利用图像处理技术&#xff0c;从数字图像中提取并解析二维码信息的过程。该技术广泛应用于信息快速交换、移动支付、产品追踪等…

SpringCloud的能源管理系统-能源管理平台源码

介绍 基于SpringCloud的能源管理系统-能源管理平台源码-能源在线监测平台-双碳平台源码-SpringCloud全家桶-能管管理系统源码 软件架构

MySQL的InnoDB存储引擎中的Buffer Pool机制

目录 Buffer Pool 简介 定义 为什么需要Buffer Pool 图解重点知识 Buffer Pool 的组成 数据页&#xff08;Data Pages&#xff09; 索引页&#xff08;Index Pages&#xff09; 插入缓冲页&#xff08;Insert Buffer Pages&#xff09; undo页&#xff08;Undo Pages&a…

idea鼠标悬浮显示注释

鼠标悬停在代码上的时候会出现快速文档&#xff0c;如下图&#xff0c;这里介绍下如何去除快速文档的显示 2020版本之前 依次找到 File—>Settings—>Editor—>General 去掉勾选 Show quick documentation on mouse move 2020版本之后 依次找到 File—>Settings…

Python数据可视化案例——地图

目录 简单案例&#xff1a; 进阶案例&#xff1a; 继上文数据可视化案例&#xff0c;今天学习用pyecharts练习数据可视化案例2-构建地图。 简单案例&#xff1a; 首先构建一个简单的地图。 代码&#xff1a; import json from pyecharts.charts import MapmapMap() data[…