使用BERT进行文本分类

news2024/12/23 17:55:27

a1bdee5d0531017b62484b534620e1bc.png

本范例我们微调transformers中的BERT来处理文本情感分类任务。

我们的数据集是美团外卖的用户评论数据集。

模型目标是把评论分成好评(标签为1)和差评(标签为0)。

#安装库
#!pip install datasets 
#!pip install transformers[torch]
#!pip install torchkeras

公众号算法美食屋后台回复关键词 torchkeras, 获取本文源代码和waimai评论数据集。

一,准备数据

准备数据阶段主要需要用到的是datasets.Dataset 和transformers.AutoTokenizer。

1,数据加载

HuggingFace的datasets库提供了类似TensorFlow中的tf.data.Dataset的功能。

import numpy as np 
import pandas as pd 

import torch 
from torch.utils.data import DataLoader 

import datasets
df = pd.read_csv("data/waimai_10k.csv")
ds = datasets.Dataset.from_pandas(df) 
ds = ds.shuffle(42) #打乱顺序
ds = ds.rename_columns({"review":"text","label":"labels"})
ds[0]
{'labels': 0, 'text': '晚了半小时,七元套餐饮料就给的罐装的可乐,真是可以'}
ds[0:4]["text"]
['晚了半小时,七元套餐饮料就给的罐装的可乐,真是可以',
 '很好喝!天天都喝~~',
 '东西很少,像半分每次都是这样失望',
 '配送比较慢(不是高峰时间点的结果1个多小时才送到);菜品备注了“老人吃请少油少盐”,结果还是很咸很油,哎…失望']

2,文本分词

transformers库使用tokenizer进行文本分词。

from transformers import AutoTokenizer #BertTokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese') #需要和模型一致
print(tokenizer)
BertTokenizerFast(name_or_path='bert-base-chinese', vocab_size=21128, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True)
#tokenizer可以使用 __call__,encode,encode_plus,batch_encode_plus等方法编码
#可以使用decode,batch_decode等方法进行解码
text_codes = tokenizer(text = '晚了半小时,七元套餐饮料就给的罐装的可乐,真是可以',
                       text_pair = None,
                       max_length = 100, #为空则默认为模型最大长度,如BERT是512,GPT是1024
                       truncation = True,
                       padding= 'do_not_pad') #可选'longest','max_length','do_not_pad'

#input_ids是编码后的数字,token_type_ids表示来自第1个句子还是第2个句子
#attention_mask在padding的位置是0其它位置是1
print(text_codes)
{'input_ids': [101, 3241, 749, 1288, 2207, 3198, 8024, 673, 1039, 1947, 7623, 7650, 3160, 2218, 5314, 4638, 5380, 6163, 4638, 1377, 727, 8024, 4696, 3221, 1377, 809, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
tokenizer.decode(text_codes["input_ids"][0])
'[CLS]'
tokenizer.batch_decode(text_codes["input_ids"])
['[CLS]',
 '晚',
 '了',
 '半',
 '小',
 '时',
 ',',
 '七',
 '元',
 '套',
 '餐',
 '饮',
 '料',
 '就',
 '给',
 '的',
 '罐',
 '装',
 '的',
 '可',
 '乐',
 ',',
 '真',
 '是',
 '可',
 '以',
 '[SEP]']
tokens = tokenizer.tokenize(ds['text'][0])
print("tokens=",tokens)
ids = tokenizer.convert_tokens_to_ids(tokens)
print("ids = ",ids)
tokens= ['晚', '了', '半', '小', '时', ',', '七', '元', '套', '餐', '饮', '料', '就', '给', '的', '罐', '装', '的', '可', '乐', ',', '真', '是', '可', '以']
ids =  [3241, 749, 1288, 2207, 3198, 8024, 673, 1039, 1947, 7623, 7650, 3160, 2218, 5314, 4638, 5380, 6163, 4638, 1377, 727, 8024, 4696, 3221, 1377, 809]

3,传入DataLoader

ds_encoded = ds.map(lambda example:tokenizer(example["text"],
                    max_length=50,truncation=True,padding='max_length'),
                    batched=True,
                    batch_size=20,
                    num_proc=2) #支持批处理和多进程map
#转换成pytorch中的tensor 
ds_encoded.set_format(type="torch",columns = ["input_ids",'attention_mask','token_type_ids','labels'])
#ds_encoded.reset_format() 
ds_encoded[0]
#分割成训练集和测试集
ds_train_val,ds_test = ds_encoded.train_test_split(test_size=0.2).values()
ds_train,ds_val = ds_train_val.train_test_split(test_size=0.2).values()
#在collate_fn中可以做动态批处理(dynamic batching)

def collate_fn(examples):
    return tokenizer.pad(examples) #return_tensors='pt'

#以下方式等价
#from transformers import DataCollatorWithPadding
#collate_fn = DataCollatorWithPadding(tokenizer=tokenizer)

dl_train = torch.utils.data.DataLoader(ds_train, batch_size=16, collate_fn = collate_fn)
dl_val = torch.utils.data.DataLoader(ds_val, batch_size=16,  collate_fn = collate_fn)
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=16,  collate_fn = collate_fn)
for batch in dl_train:
    break

二,定义模型

一个完整的模型(Model)包括模型架构(Architecture)和模型权重(Checkpoints/Weights)。

transformers提供了3种指定模型架构的方法。

  • 第1种是指定模型架构(如: from transformers import BertModel)

  • 第2种是自动推断模型架构(如: from transformers import AutoModel)

  • 第3种是自动推断模型架构并自动添加Head (如: from transformers import AutoModelForSequenceClassification )

第1种方案和第2种方案用户可以灵活地根据自己要做的任务设计Head,并且需要对基础模型有一定的了解。

此处我们使用第3种方案。

from transformers import AutoModelForSequenceClassification 

#加载模型 (会添加针对特定任务类型的Head)
model = AutoModelForSequenceClassification.from_pretrained('bert-base-chinese',num_labels=2)
dict(model.named_children()).keys()
dict_keys(['bert', 'dropout', 'classifier'])

我们可以用一个batch的数据去试算一下

output = model(**batch)
output.loss
tensor(0.6762, grad_fn=<NllLossBackward0>)

三,训练模型

下面使用我们的梦中情炉 torchkeras 来实现最优雅的微调训练循环。🤗🤗

from torchkeras import KerasModel 

#我们需要修改StepRunner以适应transformers的数据集格式

class StepRunner:
    def __init__(self, net, loss_fn, accelerator, stage = "train", metrics_dict = None, 
                 optimizer = None, lr_scheduler = None
                 ):
        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage
        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler
        self.accelerator = accelerator
        if self.stage=='train':
            self.net.train() 
        else:
            self.net.eval()
    
    def __call__(self, batch):
        
        out = self.net(**batch)
        
        #loss
        loss= out.loss
        
        #preds
        preds =(out.logits).argmax(axis=1) 
    
        #backward()
        if self.optimizer is not None and self.stage=="train":
            self.accelerator.backward(loss)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
        
        all_loss = self.accelerator.gather(loss).sum()
        
        labels = batch['labels']
        acc = (preds==labels).sum()/((labels>-1).sum())
        
        all_acc = self.accelerator.gather(acc).mean()
        
        #losses
        step_losses = {self.stage+"_loss":all_loss.item(), self.stage+'_acc':all_acc.item()}
        
        #metrics
        step_metrics = {}
        
        if self.stage=="train":
            if self.optimizer is not None:
                step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']
            else:
                step_metrics['lr'] = 0.0
        return step_losses,step_metrics
    
KerasModel.StepRunner = StepRunner
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

keras_model = KerasModel(model,
                   loss_fn=None,
                   optimizer = optimizer
                   )
keras_model.fit(
    train_data = dl_train,
    val_data= dl_val,
    ckpt_path='bert_waimai.pt',
    epochs=100,
    patience=10,
    monitor="val_acc", 
    mode="max",
    plot = True,
    wandb = False,
    quiet = True
)

d121f802b2db56ebc991f5f7a7e70bd7.png

此处准确率不是很高,可能冻结BERT除head之外的部分的参数进行训练效果会更好一些,感兴趣的小伙伴可以尝试。

四,评估模型

可以使用huggingFace的evaluate库来进行模型评估。

通过evaluate的load方法可以加载一些常用的评估指标。

可以用add_batch逐批次地往这些评估指标上添加数据,最后用compute计算评估结果。

!pip install evaluate
import evaluate
metric = evaluate.load("accuracy")
model.eval()
dl_test = keras_model.accelerator.prepare(dl_test)
for batch in dl_test:
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()
{'accuracy': 0.9128440366972477}

五,使用模型

texts = ["味道还不错,下次再来","这他妈也太难吃了吧","感觉不是很新鲜","还行我家狗狗很爱吃"]
batch = tokenizer(texts,padding=True,return_tensors="pt")
batch = {k:v.to(keras_model.accelerator.device) for k,v in batch.items()}
from torch import nn 
logits = model(**batch).logits 
scores = nn.Softmax(dim=-1)(logits)[:,-1]
print(scores)
#可以看到得分与人的预期是高度一致的
tensor([0.9510, 0.0133, 0.1020, 0.6223], device='cuda:0',
       grad_fn=<SelectBackward0>)

也可以用pipeline把tokenizer和model组装在一起

from transformers import pipeline
classifier = pipeline(task="text-classification",tokenizer = tokenizer,model=model.cpu())
classifier("挺好吃的哦")
[{'label': 'LABEL_1', 'score': 0.9468138813972473}]

六,保存模型

保存model和tokenizer之后,我们可以用一个pipeline加载,并进行批量预测。

model.config.id2label = {0:"差评",1:"好评"}
model.save_pretrained("waimai_10k_bert")
tokenizer.save_pretrained("waimai_10k_bert")
('waimai_10k_bert/tokenizer_config.json',
 'waimai_10k_bert/special_tokens_map.json',
 'waimai_10k_bert/vocab.txt',
 'waimai_10k_bert/added_tokens.json',
 'waimai_10k_bert/tokenizer.json')
from transformers import pipeline 
classifier = pipeline("text-classification",model="waimai_10k_bert")
classifier(["味道还不错,下次再来","我去,吃了我吐了三天"])
[{'label': '好评', 'score': 0.950958251953125},
 {'label': '差评', 'score': 0.9617311954498291}]

更多有趣范例,公众号算法美食屋后台回复关键词:torchkeras,可在torchkeras仓库获取范例源码。

如果你喜欢torchkeras,记得给吃货一颗星星⭐️哦~~😋😋

06935dd7ca03441a2f15df5cdd975299.png

171890a96a068b13ec2828ec8988eaf4.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/614413.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

你知道ping命令是如何工作的吗?

你知道ping命令是如何工作的吗&#xff1f; 我们用来测试一台机器与另一台机器的网络连通性一般会使用ping命令&#xff0c;那么你知道ping命令是如何工作的吗&#xff1f;ping命令是基于ICMP协议工作的。 一、介绍ICMP协议 因特网控制报文协议ICMP&#xff08;Internet Con…

安卓大作业 咖啡展示App

系列文章 安卓大作业 咖啡展示App 文章目录 系列文章1&#xff0e;背景2&#xff0e;所有截图3&#xff0e;总结4. 源代码获取 1&#xff0e;背景 为了便于用户对于咖啡信息的浏览&#xff0c;我设计了一个咖啡展示的app。可以在这个App中查看到一些咖啡的信息&#xff0c;点…

爬虫语言最好用的是那种?

目前最好用的爬虫语言有多种选择&#xff0c;具体的选择取决于你的需求和个人偏好。Python是较为流行的爬虫语言之一&#xff0c;其生态系统丰富&#xff0c;拥有大量优秀的爬虫框架和工具。另外&#xff0c;JavaScript、Go、Ruby等编程语言也可以用于爬虫开发。总之&#xff0…

设置ubuntu下SVN服务开机自启

目录 0.背景环境 1.开机自启步骤 0.背景环境 1&#xff09;ubuntu下&#xff0c;已搭建好svn版本库&#xff0c;具体搭建方法参考文末的其他博客链接 2&#xff09;在搭svn服务器的过程中&#xff0c;发现ubuntu重启后&#xff0c;svn服务就关闭了 svn正常开启时见下图 所以…

为什么大多数企业数字化转型失败率高达80%?

数字化转型失败率为什么这么高&#xff1f; 多年的转型研究表明&#xff0c;企业数字化转型的成功率还不到 30%。 麦肯锡2023年报告显示&#xff0c;只有 16% 的受访者表示他们组织的数字化转型成功地提高了绩效&#xff0c;并使他们能够长期维持变革。 即使是精通数字技术的行…

十五周算法训练营——回溯算法

今天是十五周算法训练营的第十周&#xff0c;主要讲回溯算法专题。&#xff08;欢迎加入十五周算法训练营&#xff0c;与小伙伴一起卷算法&#xff09; 解决一个回溯问题&#xff0c;实际上就是一个决策树的遍历过程&#xff0c;只需要思考三个问题&#xff1a; 路径&#xff1…

市值暴涨8000亿,马斯克告诉了美国同行,为啥需要中国市场?

马斯克访华仅仅40多个小时&#xff0c;却带动了股价连涨5天&#xff0c;涨幅最高达到20%&#xff0c;市值飙涨8000亿元人民币&#xff0c;马斯克也因此再度问鼎全球首富之位&#xff0c;凸显出中国之行给他带来的巨大好处。 一、中国市场带动了特斯拉的辉煌 2018年马斯克为产能…

依据换行符分割字符串numpy.char.splitlines()含换行符与回车符的区别

【小白从小学Python、C、Java】 【等级考试500强双证书考研】 【Python-数据分析】 依据换行符分割字符串 numpy.char.splitlines() 含换行符与回车符的区别 [太阳]选择题 以下说法错误的是&#xff1a; import numpy as np a "I\nLove\rChina\r\nforever" print(&q…

APACHE-ATLAS-2.1.0 - 安装HIVE HOOK(六)

写在前面 本博文以获取HIVE元数据为例&#xff0c;进行流程和源码的分析。 请提前安装好HADOOP和HIVE的环境&#xff0c;用于测试。 ATLAS官网&#xff1a;https://atlas.apache.org/#/HookHive ATLAS支持的元数据源 什么是Hive Hook&#xff08;钩子&#xff09; HOOK是一种…

可视化的三种图结构方案 (canvas、fabric、G6)

原生 canvas、fabric 以及 G6&#xff0c;三种方案各有优劣势 原生 canvasfabricG6优点灵活、自由、可定制化非常强封装了 canvas 的 api&#xff0c;使用简单灵活提供了复杂树、图等 api&#xff0c;只需要按照文档配置即可缺点开发复杂、耗时对于构建大型树、图等复杂、耗时…

chatgpt赋能python:选择函数:Python实现之道

选择函数&#xff1a;Python 实现之道 什么是选择函数&#xff1f; 在 SEO 中&#xff0c;选择函数是指搜索引擎在对网站内容进行排名时所采用的一种规则。选择函数由搜索引擎定义&#xff0c;其目的在于建立一个算法来确定哪些网站会出现在搜索结果的前几页中。对于网站管理…

百度视频质量评测的实践之路

视频编解码技术日新月异&#xff0c;新的编解码技术赋予视频业务新的应用场景和新的用户视听体验。同时&#xff0c;视频作为带宽消耗大户&#xff0c;如何在视听体验和视频带宽之间取得最优的平衡是一个永恒的话题。视频质量评测主要用来回答&#xff1a;体验是否改善、带宽是…

chatgpt赋能python:如何用Python制作动画?

如何用Python制作动画&#xff1f; Python作为一种优秀的编程语言&#xff0c;可以用于不同领域的编程。其中&#xff0c;Python也可以被用于创建动画。使用Python的主要好处之一是其强大的Matplotlib库&#xff0c;它可以帮助我们更轻松地创建可视化效果。 什么是Matplotlib…

chatgpt赋能python:Python如何取出List中的数据

Python如何取出List中的数据 在Python中&#xff0c;列表&#xff08;List&#xff09;是一种非常常见的数据类型&#xff0c;它可以存储任意类型的数据&#xff0c;并且可以按照下标索引来访问其中的元素。本篇文章将介绍如何使用Python来取出List中的数据。 常规方法 在Py…

Android BlueToothBLE入门(一)——低功耗蓝牙介绍

学更好的别人&#xff0c; 做更好的自己。 ——《微卡智享》 本文长度为3150字&#xff0c;预计阅读8分钟 前言 距上篇文章发布都一个多月了&#xff0c;先声明&#xff0c;我可不会停更。这么长时间没更新文章&#xff0c;其实原因就三点&#xff1a; 原因一是工作上事确实多&…

JavaScript之事件(十)

JavaScript之事件 1、事件绑定2、事件类型3、事件的传播4、事件对象1、事件对象常用的属性2、事件对象常用的方法 事件可用于处理、验证用户输入、用户动作和浏览器动作。 1、事件绑定 事件绑定就是给HTML标签绑定特定的事件&#xff0c;当该事件触发&#xff0c;则会执行相应的…

一款射频芯片的layout设计指导案例-篇章2

《一款射频芯片的layout设计指导案例-篇章1》中&#xff0c;我们阐述了RTL8762元件布局顺序、DC/DC电路元件布局走线、电源Bypass布局规范、外部flash布局走线、RF布局走线&#xff0c; 本篇阐述晶振、ESD、板层等相关指导建议—— 40MHz晶振布局走线规范 在没有结构限制情况下…

chatgpt赋能python:Python如何在同一行输入三个数?

Python如何在同一行输入三个数&#xff1f; Python语言是一门广泛使用的编程语言&#xff0c;被广泛应用于数据分析、机器学习、Web开发、科学计算、人工智能等领域。但是&#xff0c;有时候我们需要在同一行输入多个变量或数字&#xff0c;这可能给一些初学者带来一些困惑。本…

暑期实习开始啦「GitHub 热点速览」

作者&#xff1a;HelloGitHub-小鱼干 无巧不成书&#xff0c;刚好最近有小伙伴在找实习&#xff0c;而 GitHub 热榜又有收录实习信息的项目在榜。所以&#xff0c;无意外本周特推就收录了这个实习项目&#xff0c;当然还有国内版本。除了应景的实习 repo 之外&#xff0c;还有帮…

快手 | 后端Java实习生 | 一面

目录 1.Redis缓存和MySQL数据一致性如何保证&#xff1f;2.你使用缓存&#xff0c;在高并发的情况下&#xff0c;如果多个缓存同时过期了怎么办&#xff1f;3.消息队列积压了怎么办&#xff1f;4.jdk1.8之后Java内存模型分别哪几个部分&#xff1f;每个部分用一句话概括一下&am…