基于transformers框架实践Bert系列1--分类器(情感分类)

news2024/10/6 16:25:25

本系列用于Bert模型实践实际场景,分别包括分类器、命名实体识别、选择题、文本摘要等等。(关于Bert的结构和详细这里就不做讲解,但了解Bert的基本结构是做实践的基础,因此看本系列之前,最好了解一下transformers和Bert等)
本篇主要讲解情感分类应用场景。本系列代码和数据集都上传到GitHub上:https://github.com/forever1986/bert_task

目录

  • 1 环境说明
  • 2 前期准备
    • 2.1 了解Bert的输入输出
    • 2.2 数据集与模型
    • 2.3 任务说明
    • 2.4 实现关键
  • 3 关键代码
    • 3.1 数据集处理
    • 3.2 模型加载
    • 3.3 评估函数
  • 4 整体代码
  • 5 运行效果

1 环境说明

1)本次实践的框架采用torch-2.1+transformer-4.37
2)另外还采用或依赖其它一些库,如:evaluate、pandas、datasets、accelerate等

2 前期准备

Bert模型是一个只包含transformer的encoder部分,并采用双向上下文和预测下一句训练而成的预训练模型。可以基于该模型做很多下游任务。

2.1 了解Bert的输入输出

Bert的输入:input_ids(使用tokenizer将句子向量化),attention_mask,token_type_ids(句子序号)、labels(结果)
Bert的输出:
last_hidden_state:最后一层encoder的输出;大小是(batch_size, sequence_length, hidden_size)
pooler_output:这是序列的第一个token(classification token)的最后一层的隐藏状态,输出的大小是(batch_size, hidden_size),它是由线性层和Tanh激活函数进一步处理的。(通常用于句子分类,至于是使用这个表示,还是使用整个输入序列的隐藏状态序列的平均化或池化,视情况而定)。(注意:这是关键输出,本次分类任务就需要获取该值,并进行一次线性层处理
hidden_states: 这是输出的一个可选项,如果输出,需要指定config.output_hidden_states=True,它也是一个元组,它的第一个元素是embedding,其余元素是各层的输出,每个元素的形状是(batch_size, sequence_length, hidden_size)
attentions:这是输出的一个可选项,如果输出,需要指定config.output_attentions=True,它也是一个元组,它的元素是每一层的注意力权重,用于计算self-attention heads的加权平均值。

2.2 数据集与模型

1)数据集:weibo_senti_100k(微博公开数据集),这里只是演示,使用其中2400条数据
2)模型:bert-base-chinese
注意:本次练习都是采用本地数据集和本地权重模型,不直接从hf下载,因为速度过慢

2.3 任务说明

情感分类任务其实就是将输出的结果映射到一个0或者1的标签上

2.4 实现关键

由于我们是判断整个句子的分类,因此需要的是输出序列的第一个token(参考Bert的输入输出),也就是pooler_output。将pooler_output输出转化为我们所需要的label值,因此需要 一个线性层实现(这一步下面我们利用transformers框架帮我们实现好的类)

3 关键代码

3.1 数据集处理

weibo_senti_100k是一个有2列数据分别label,review,其中label是结果,review则是input。我们需要做3个事情:读取数据、划分数据集、tokenizer

# 读取数据集
df = pd.read_csv(data_path)
dataset = load_dataset("csv", data_files=data_path, split='train')
dataset = dataset.filter(lambda x: x["review"] is not None)
# 划分数据集
datasets = dataset.train_test_split(test_size=0.1)  
# 数据集进行tokenizer
def process_function(datas):
    tokenized_datas = tokenizer(datas["review"], max_length=256, truncation=True)
    tokenized_datas["labels"] = datas["label"]
    return tokenized_datas

tokenized_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)

3.2 模型加载

model = BertForSequenceClassification.from_pretrained(model_path, num_labels=2)

注意:这里使用的是transformers中的BertForSequenceClassification,该类对bert的分类进行封装。如果我们不使用该类,需要自己定义一个model,继承bert,增加分类线性层(就是2.4中提到的实现关键)。另外使用AutoModelForSequenceClassification也可以,其实AutoModel最终返回的也是BertForSequenceClassification,它是根据你config中的model_type去匹配的。
这里列一下BertForSequenceClassification的关键源代码说明一下transformers帮我们做了哪些关键事情

# 在__init__方法中增加dropout和分类线性层
self.dropout = nn.Dropout(classifier_dropout)
# 注意此处的num_labels可以在config中配置(使用标签:id2label),或者在from_pretrained传入,默认是2,当多分类的时候可以自己定义
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# 在forward方法中,将bert输出outputs中的第二个值(也就是前面讲到的pooler_output),进行dropout处理并关联线性层
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)

3.3 评估函数

这里采用evaluate库加载accuracy准确度计算方式来做评估,本次实验将accuracy的计算py文件下载下来,因此也是本地加载

# 评估函数:此处的评估函数可以从https://github.com/huggingface/evaluate下载到本地
acc_metric = evaluate.load("./evaluate/metric_accuracy.py")
def evaluate_function(prepredictions):
    predictions, labels = prepredictions
    predictions = predictions.argmax(axis=1)
    acc = acc_metric.compute(predictions=predictions, references=labels)
    return acc```
## 3.4 设置TrainingArguments

```python
# step 5 创建TrainingArguments
# 2400条数据,其中train和test比例9:1,因此train数据为2160条,batch_size=32,因此step=68,
train_args = TrainingArguments(output_dir="./checkpoints",      # 输出文件夹
                               per_device_train_batch_size=32,  # 训练时的batch_size
                               # gradient_accumulation_steps=2,   # *** 因为显存够用,本次实验没有使用梯度累积 ***
                               gradient_checkpointing=True,     # *** 梯度检查点 ***
                               per_device_eval_batch_size=32,    # 验证时的batch_size
                               num_train_epochs=3,              # 训练轮数
                               logging_steps=20,                # log 打印的频率
                               evaluation_strategy="epoch",     # 评估策略
                               save_strategy="epoch",           # 保存策略
                               save_total_limit=3,              # 最大保存数
                               learning_rate=2e-5,              # 学习率
                               weight_decay=0.01,               # weight_decay
                               metric_for_best_model="accuracy",      # 设定评估指标
                               load_best_model_at_end=True)     # 训练完成后加载最优模型

4 整体代码

"""
基于BERT做情感分析
1)数据集来自:weibo_senti_100k(微博公开数据集),这里只是演示,使用2400条数据
2)模型权重使用:bert-base-chinese
"""

# step 1 引入数据库
import evaluate
import pandas as pd
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer, \
    DataCollatorWithPadding, pipeline

model_path = "./model/tiansz/bert-base-chinese"
data_path = "./data/weibo_senti_100k.csv"

# step 2 数据集处理
df = pd.read_csv(data_path)
dataset = load_dataset("csv", data_files=data_path, split='train')
dataset = dataset.filter(lambda x: x["review"] is not None)
datasets = dataset.train_test_split(test_size=0.1)  # 划分数据集
print("train data size:", len(datasets["train"]["label"]))
tokenizer = BertTokenizer.from_pretrained(model_path)


def process_function(datas):
    tokenized_datas = tokenizer(datas["review"], max_length=256, truncation=True)
    tokenized_datas["labels"] = datas["label"]
    return tokenized_datas


new_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)

# step 3 加载模型
model = BertForSequenceClassification.from_pretrained(model_path, num_labels=2)

# step 4 评估函数:此处的评估函数可以从https://github.com/huggingface/evaluate下载到本地
acc_metric = evaluate.load("./evaluate/metric_accuracy.py")


def evaluate_function(prepredictions):
    predictions, labels = prepredictions
    predictions = predictions.argmax(axis=1)
    acc = acc_metric.compute(predictions=predictions, references=labels)
    return acc


# step 5 创建TrainingArguments
# 2400条数据,其中train和test比例9:1,因此train数据为2160条,batch_size=32,每个epoch的step=68,epoch=3,因此总共step=204,
train_args = TrainingArguments(output_dir="./checkpoints",      # 输出文件夹
                               per_device_train_batch_size=32,  # 训练时的batch_size
                               # gradient_accumulation_steps=2,   # *** 梯度累加 ***
                               gradient_checkpointing=True,     # *** 梯度检查点 ***
                               per_device_eval_batch_size=32,    # 验证时的batch_size
                               num_train_epochs=3,              # 训练轮数
                               logging_steps=20,                # log 打印的频率
                               evaluation_strategy="epoch",     # 评估策略
                               save_strategy="epoch",           # 保存策略
                               save_total_limit=3,              # 最大保存数
                               learning_rate=2e-5,              # 学习率
                               weight_decay=0.01,               # weight_decay
                               metric_for_best_model="accuracy",      # 设定评估指标
                               load_best_model_at_end=True)     # 训练完成后加载最优模型

# step 6 创建Trainer
trainer = Trainer(model=model,
                  args=train_args,
                  train_dataset=new_datasets["train"],
                  eval_dataset=new_datasets["test"],
                  data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
                  compute_metrics=evaluate_function,
                  )

# step 7 训练
trainer.train()

# step 8 模型评估
evaluate_result = trainer.evaluate(new_datasets["test"])
print(evaluate_result)

# step 9:模型预测
id2_label = {0: "消极", 1: "积极"}
model.config.id2label = id2_label
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
sen = "不止啦,金沙 碧水 翠苇 飞鸟 游鱼 远山 彩荷七大要素哒"
print(pipe(sen))

5 运行效果

在这里插入图片描述

注:本文参考来自大神:https://github.com/zyds/transformers-code

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1690590.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

node版本管理nvm详细教程

安装 nvm 之前先清理node相关的所有配置,如环境变量、.npmrc文件、node_cache、node_global 等 一、下载nvm 任选一处下载即可 官网:Releases coreybutler/nvm-windows (github.com) 码云:nvm下载仓库: nvm下载仓库 百度网盘&#xff1…

基于GA遗传优化的CNN-GRU的时间序列回归预测matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 4.1 CNN-GRU模型架构 4.2 GA优化CNN-GRU流程 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 MATLAB2022a 3.部分核心程序 ...........................................…

雷电预警监控系统:守护安全的重要防线

TH-LD1在自然界中,雷电是一种常见而强大的自然现象。它既有震撼人心的壮观景象,又潜藏着巨大的安全风险。为了有效应对雷电带来的威胁,雷电预警监控系统应运而生,成为现代社会中不可或缺的安全防护工具。 雷电预警监控系统的基本…

代码随想录算法训练营第14天 |● 理论基础 ● 递归遍历 ● 迭代遍历 ● 统一迭代

文章目录 前言二叉树的递归遍历💖递归算法基本要素代码 迭代遍历-需要先理清思路再写前向迭代法后序迭代中序迭代 迭代法统一写法总结 前言 理论基础 需要了解 二叉树的种类,存储方式,遍历方式 以及二叉树的定义 记录我容易忘记的点 题目…

打造AI虚拟伴侣 - 优化方案

第一部分:框架优化概述 1、精确定位: 构建一个高度灵活且用户友好的平台,旨在通过无缝集成多种大型语言模型(LLMs)后端,为用户创造沉浸式的角色交互体验。不仅适配电脑端,还特别优化移动端体验,满足二次元AI虚拟伴侣市场的特定需求。 2、核心功能强化: 增强后端兼容…

大数据Hive中的UDF:自定义数据处理的利器(下)

在上一篇文章中,我们对第一种用户定义函数(UDF)进行了基础介绍。接下来,本文将带您深入了解剩余的两种UDF函数类型。 文章目录 1. UDAF1.1 简单UDAF1.2 通用UDAF 2. UDTF3. 总结 1. UDAF 1.1 简单UDAF 第一种方式是 Simple(简单…

叶面积指数(LAI)数据、NPP数据、GPP数据、植被覆盖度数据获取

引言 多种卫星遥感数据反演叶面积指数(LAI)产品是地理遥感生态网推出的生态环境类数据产品之一。产品包括2000-2009年逐8天数据,值域是-100-689之间,数据类型为32bit整型。该产品经过遥感数据获取、计算归一化植被指数、解译植被类…

测量模拟量的优选模块:新型设备M-SENS3 8

| 具有8路自由选择通道的新型设备M-SENS3 8 IPETRONIK推出的模拟量测量设备——M-SENS3 8是新一代设备的新成员。该模块具有8个通道,能够自由选择测量模式,不仅支持高精度电压和电流的测量,还新增了频率测量模式。各通道分辨率高达18位&…

Selenium常用命令(python版)

日升时奋斗,日落时自省 目录 1、Selenium 2、常见问题 1、Selenium 安装Python和配置环境没有涉及 注:如有侵权,立即删除 首先安装selenium包,安装方式很简单 pip install selenium 注:我这里已经安装好了,所以…

spring boot集成Knife4j

文章目录 一、Knife4j是什么?二、使用步骤1.引入依赖2.新增相关的配置类3.添加配置信息4.新建测试类5. 启动项目 三、其他版本集成时常见异常1. Failed to start bean ‘documentationPluginsBootstrapper2.访问地址后报404 一、Knife4j是什么? 前言&…

弘君资本股市行情:股指预计保持震荡上扬格局 关注汽车、银行等板块

弘君资本指出,近期商场体现全体分化,指数层面上看,沪指一路震动上行,创出年内新高,创业板指和科创50指数体现相对较弱,依然是底部震动走势。从盘面体现上看,轮动依然是当时商场的主基调&#xf…

逻辑分析仪 - 采样率/采样深度

采样深度(Sampling Depth) 采样深度指的是逻辑分析仪在一次捕获过程中可以记录的最大样本数量。简单来说,采样深度越大,逻辑分析仪可以记录的数据量就越多。这对于分析长时间的信号变化或复杂的信号序列非常重要。 采样率&#…

WEB攻防【2】——ASPX/.NET项目/DLL反编译/未授权访问/配置调试报错

ASP:windowsiisaspaccess .net:windowsiisaspxsqlserver IIS上的安全问题也会影响到 WEB漏洞:本身源码上的问题 服务漏洞:1、中间件 2、数据库 3、第三方软件 #知识点: 1、.NET:配置调试-信息泄绵 2、.NET:源码反编译-DLL…

使用Flask ORM进行数据库操作的技术指南

文章目录 安装Flask SQLAlchemy配置数据库连接创建模型类数据库操作插入数据查询数据更新数据删除数据 总结 Flask是一个轻量级的Python Web框架,其灵活性和易用性使其成为开发人员喜爱的选择。而ORM(对象关系映射)则是一种将数据库中的表与面…

HCIP-Datacom-ARST自选题库__OSPF单选【80道题】

1.OSPFV2是运行在IPV4网络的IGP,OSPFV3是运行在IPV6网络的ICP,OSPFV3与OSPFv2的报文类型相同,包括Hello报文、DD报文、LSR报文、LSU报文和LSAck报文。关于OSPFv3报文,以下哪个说法是正确的 OSPFv3使用报文头部的认证字段完成报文…

揭秘齿轮加工工艺的选用原则:精准打造高效传动的秘密武器

在机械制造领域,齿轮作为传动系统中的重要组成部分,其加工工艺的选择至关重要。不同的齿轮加工工艺会影响齿轮的精度、耐用性和效率。本文将通过递进式结构,深入探讨齿轮加工工艺的选用原则,带您了解如何精准打造高效传动的秘密武…

最简单的 UDP-RTP 协议解析程序

最简单的 UDP-RTP 协议解析程序 最简单的 UDP-RTP 协议解析程序原理源程序结果下载链接参考 最简单的 UDP-RTP 协议解析程序 本文介绍网络协议数据的处理程序。网络协议数据在视频播放器中的位置如下所示。 本文中的程序是一个 UDP/RTP 协议流媒体数据解析器。该程序可以分析 …

Java | Leetcode Java题解之第109题有序链表转换二叉搜索树

题目: 题解: class Solution {ListNode globalHead;public TreeNode sortedListToBST(ListNode head) {globalHead head;int length getLength(head);return buildTree(0, length - 1);}public int getLength(ListNode head) {int ret 0;while (head…

彩虹聚合二级域名DNS管理系统源码v1.3

聚合DNS管理系统可以实现在一个网站内管理多个平台的域名解析, 目前已支持的域名平台有:阿里云、腾讯云、华为云、西部数码、CloudFlare。 本系统支持多用户,每个用户可分配不同的域名解析权限;支持API接口, 支持获…

结合反序列化注入tomcat内存马

0x01 前提概述 通过前几个内存马的学习我们可以知道,将内存马写在jsp文件上传并不是传统意义上的内存马注入,jsp文件本质上就是一个servlet,servlet会编译成class文件,也会实现文件落地。借用木头师傅的一张图 结合反序列化注入内…