huggingface实现中文文本分类

news2024/9/28 23:30:30

目录

1  自定义数据集

2  分词 

 2.1  重写collate_fn方法¶

3  用BertModel加载预训练模型 

 4  模型试算

 5  定义下游任务¶

6  训练 

7  测试 


 

#导包
import torch
from datasets import load_from_disk  #用于加载本地磁盘的datasets文件

1  自定义数据集

#自定义数据集
#需要继承 torch.utils.data.Dataset,
#并且实现__init__(self)/__len__(self)/__getitem__(self,i)这些方法
class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
        #加载本地磁盘的datasets
        self.datasets = load_from_disk('../data/ChnSentiCorp')  #self.datasets是一个字典,包含训练、校验、测试的datatset
        self.dataset = self.datasets[split]  #使用split来区分获取的是训练、校验、测试的datatset中的哪一个
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, i):
        """让数据集像列表一样可以根据索引获取数据‘text’与“label”"""
        text = self.dataset[i]['text']
        label = self.dataset[i]['label']
        return text, label
    
dataset = Dataset(split='train')
dataset

 

<__main__.Dataset at 0x2afb31f03a0>
dataset.dataset

 

Dataset({
    features: ['text', 'label'],
    num_rows: 9600
})
len(dataset)

 9600

dataset[0]

 

('选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般',
 1)

2  分词 

#分词工具导包
from transformers import BertTokenizer
#加载字典和分词工具    huggingface自带的中文词典bert-base-chinese加载进来
tokenizer = BertTokenizer.from_pretrained(r'../data/bert-base-chinese/')
tokenizer

 

BertTokenizer(name_or_path='../data/bert-base-chinese/', vocab_size=21128, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

 2.1  重写collate_fn方法¶

自定义取数据的方法 

def collate_fn(data):
    #从传入的数据集data(dataset)中分离出文本句子sents和标签labels
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]
    
    #编码:完成句子的分词操作
    data = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents,
                                       truncation=True,      #文本超过最大长度会被截断
                                       padding='max_length',  #不足最大长度,补充《pad》
                                       return_tensors='pt',   #返回的数据是pythorch类型
                                       return_length='True')
    
    #获取编码之后的数字 索引
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)   #labels在pythorch中一般设置成long类型
    
    return input_ids, attention_mask, token_type_ids, labels
#创建数据加载器
loader = torch.utils.data.DataLoader(
    dataset=dataset,  #将数据集传进来
    batch_size=16,    #一批数据16个
    collate_fn=collate_fn,  #传入自定义的取数据的方法
    shuffle=True,    #打乱数据
    drop_last=True   #最后一批数据若不满足16个数据,就删除
)

for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    break   #只完成赋值,不输出打印
print(len(loader))
600
print(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels.shape)

 

torch.Size([16, 512]) torch.Size([16, 512]) torch.Size([16, 512]) torch.Size([16])

3  用BertModel加载预训练模型 

#Bert模型导包
from transformers import BertModel


#加载预训练模型
pretrained = BertModel.from_pretrained('../data/bert-base-chinese/')

#固定bert的参数:  遍历参数,修改每一个参数的requires_grad_,使其不能进行求导、梯度下降
for param in pretrained.parameters():
    param.requires_grad_(False)  #变量最右边添加下划线,表示直接修改变量的原始属性

 4  模型试算

out = pretrained(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)


out.last_hidden_state.shape  
#torch.Size([16, 512, 768])
#16:一批数据有16个
#512:一个句子的长度
#768:yincan


#torch.Size([16, 512, 768])


out.last_hidden_state[:, 0].shape  #获取cls特殊词的输出结果


#torch.Size([16, 768])

 5  定义下游任务¶

class Model(torch.nn.Module):
    def __init__(self, pretrained_model):
        super().__init__()
        #预训练模型层
        self.pretrained_model = pretrained_model
        #输出层:全连接层
        #768:表示将上一层预训练模型层768个的输出结果作为全连接层的输入数据
        #2:表示二分类问题就会有两个输出结果
        self.fc = torch.nn.Linear(768, 2)
        
        
    #前向传播
    def forward(self, input_ids, attention_mask, token_type_ids):
        with torch.no_grad():
            out = pretrained(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
            
        #在Bert中,获取cls特殊词的输出结果来做分类任务
        out = self.fc(out.last_hidden_state[:, 0])
        out.softmax(dim=1)
        return out
    
#声明模型
model = Model(pretrained)
#创建模型:输入参数
model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids).shape

#torch.Size([16, 2])  16个句子对,2个分类结果
torch.Size([16, 2]) 
torch.cuda.is_available()

 True

#设置设备
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

 device(type='cuda', index=0)

6  训练 

# AdamW  梯度下降的优化算法  在Adam的基础上稍微改进了一点
from transformers import AdamW  


#训练
optimizer = AdamW(model.parameters(), lr=5e-4)  #设置优化器
#声明损失函数
loss = torch.nn.CrossEntropyLoss()

#建模
model = Model(pretrained)
#模型训练
model.train()
#将模型传到设备上
model.to(device)
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    #把传入的数据都传到设备上
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    token_type_ids = token_type_ids.to(device)
    labels = labels.to(device)
    
    out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    
    #计算损失函数
    l = loss(out, labels)  #out:预测值   labels:真实值
    #用损失函数来反向传播
    l.backward()
    #梯度更新
    optimizer.step()
    #梯度清零
    optimizer.zero_grad()
    
    #每隔5次,计算一下准确率
    if i % 5 == 0:
        out = out.argmax(dim=1)  #计算预测值
        #计算准确率
        accuracy =(out == labels).sum().item() / len(labels)   #item()是拿到求和之后的数字
        print(i, l.item(), accuracy)
        
    if i == 100:
        break  

 

0 0.6882429718971252 0.5
5 0.7822732329368591 0.3125
10 0.7996063828468323 0.25
15 0.7967076301574707 0.3125
20 0.839418888092041 0.3125
25 0.6795901656150818 0.5625
30 0.7707732319831848 0.3125
35 0.6784831285476685 0.4375
40 0.728607177734375 0.375
45 0.7425007224082947 0.375
50 0.6188052892684937 0.5625
55 0.7185056805610657 0.375
60 0.8377469778060913 0.1875
65 0.7717736959457397 0.3125
70 0.7421607375144958 0.4375
75 0.7337921857833862 0.375
80 0.8023619651794434 0.3125
85 0.7294195890426636 0.5625
90 0.7909258008003235 0.3125
95 0.7105788588523865 0.4375
100 0.7786014676094055 0.5

7  测试 

model.eval()用于将模型设置为评估模式。‌ 在评估模式下,模型将关闭一些在训练过程中使用的特性,如‌Dropout和BatchNorm层的训练模式,以确保模型在推理时能够给出准确的结果。使用model.eval()可以帮助我们更好地评估模型的性能,并发现潜在的问题。 

def test():
    model.eval()
    correct = 0
    total = 0
    
    loader_test = torch.utils.data.DataLoader(dataset = Dataset('validation'),
                                              batch_size = 32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        #输入的数据传入到设备上
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        labels = labels.to(device)
        
        if i == 5:
            break
        print(i)
        with torch.no_grad():
            out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        out = out.argmax(dim=1)
        correct += (out==labels).sum().item()
        total += len(labels)
    print(correct / total)
test()

 

0
1
2
3
4
0.40625

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2175112.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

番外篇 | 复现AC-YOLOv5,进行自动化织物缺陷检测

前言:Hello大家好,我是小哥谈。我们提出了一种基于AC-YOLOv5的新型纺织缺陷检测方法。将空洞空间金字塔池化(ASPP)模块引入YOLOv5主干网络中,提出了squeeze-and-excitation(CSE)通道注意力模块,并将其引入到YOLOv5主干网络中。🌈 目录 🚀1.基础概念 🚀2.添…

【d54_2】【Java】【力扣】142.环形链表

思路 关于判断是否重复的就hashSet&#xff0c;这种有主动去重性质的类 新建一个hashSet 遍历链表并放进hashSet&#xff0c; 如果不能放&#xff0c;说明这个遍历过&#xff0c;这个就是环的地方 如果最后到遍历到null&#xff0c;说明没环 代码 /*** Definition for s…

5.3 克拉默法则、逆矩阵和体积

本节是使用代数而不是消元法来求解 A x b A\boldsymbol x\boldsymbol b Axb 和 A − 1 A^{-1} A−1。所有的公式都会除以 det ⁡ A \det A detA&#xff0c; A − 1 A^{-1} A−1 和 A − 1 b A^{-1}\boldsymbol b A−1b 中的每个元素都是一个行列式除以 A A A 的行列式。…

基于微信小程序的网上商城+ssm(lw+演示+源码+运行)

摘 要 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;微信小程序被用户普遍使用&#xff0c;为方便用户能够可以…

STM32F407单片机编程入门(二十七)以太网接口详解及实战含源码

文章目录 一.概要二.单片机以太网系统基本结构1.OSI 七层模型2.单片机实现以太网功能组成 三.STM32F407VET6单片机以太网内部结构1.MII接口介绍2.RMII接口介绍 四.LWIP TCP/IP协议栈介绍五.PHY收发器LAN8720介绍1.LAN8720内部框图2.LAN8720应用电路3.LAN8720以太网模块 六.Cube…

在公司网络环境下,无法访问公共网络时,可在插件端配置网络代理后使用通义灵码

在公司网络环境下&#xff0c;无法访问公共网络时&#xff0c;可在插件端配置网络代理后使用通义灵码。 通义灵码插件下载&#xff1a;通义灵码_智能编码助手_AI编程-阿里云 配置网络代理 公司网络通常使用 HTTP 代理服务器在网络流量发送到目标位置之前进行拦截&#xff0c;以…

6--苍穹外卖-SpringBoot项目中菜品管理 详解(二)

目录 菜品分页查询 需求分析和设计 代码开发 设计DTO类 设计VO类 Controller层 Service层接口 Service层实现类 Mapper层 功能测试 删除菜品 需求设计和分析 代码开发 Controller层 Service层接口 Service层实现类 Mapper层 功能测试 修改菜品 需求分析和设…

Spring--boot自动配置原理案例--阿里云--starter

Spring–boot自动配置原理案例–阿里云–starter 定义这个starter的作用是它可以将阿里云的工具类自动放入IOC容器中&#xff0c;供人使用。 我们看一看构建starter的过程&#xff0c;其实就是在atuoconfigure模块中加入工具类&#xff0c;然后写一个配置类在其中将工具类放入…

【ChromeDriver安装】爬虫必备

以下是安装和配置 chromedriver 的步骤&#xff1a; 1. 确认 Chrome 浏览器版本 打开 Chrome 浏览器&#xff0c;点击右上角的菜单按钮&#xff08;三个点&#xff09;&#xff0c;选择“帮助” > “关于 Google Chrome”。 2. 下载 Chromedriver 根据你的 Chrome 版本&…

【研赛A题成品论文】24华为杯数学建模研赛A题成品论文+可运行代码丨免费分享

2024华为杯研究生数学建模竞赛A题精品成品论文已出&#xff01; A题 风电场有功功率优化分配 一、问题分析 A题是一道工程建模与优化类问题&#xff0c;其目的是根据题目所给的附件数据资料分析风机主轴及塔架疲劳损伤程度&#xff0c;以及建立优化模型求解最优有功功率分配…

哪些AI软件能轻松搞定你的文案、总结、论文、计划书?

大家好&#xff01;在我们每天紧张忙碌的生活中&#xff0c;有时候一天结束时&#xff0c;我们还有一堆事情等着处理。 图片 但别担心&#xff0c;今天我要为大家介绍几款AI软件&#xff0c;它们可以在你忙碌的一天结束后&#xff0c;成为你的得力助手&#xff0c;帮你轻松管…

初识Tomcat

Tomcat是一款可以运行javaWebAPP的服务器软件。 一个服务器想要执行java代码&#xff0c;则需要JRE&#xff08;jvm、java运行环境等&#xff09;&#xff0c;但是需要执行javaWEB项目则还需要服务器软件&#xff0c;Tomacat就是其中很流行的一款。因为一个javaWEB项目会有很多…

Accelerate单卡,多卡config文件配置

依赖库 from accelerate import Accelerator from accelerate import DistributedDataParallelKwargs ddp_kwargs DistributedDataParallelKwargs(find_unused_parametersTrue) accelerator Accelerator(kwargs_handlers[ddp_kwargs]) 代码中删除所有的.cuda() 或者to(devic…

Xshell连接服务器

一、Xshell-7.0.0164p、Xftp 7下载 1.1、文件下载 通过网盘分享的文件&#xff1a;xshell 链接: https://pan.baidu.com/s/1qc0CPv4Hkl19hI9tyvYZkQ 提取码: 5snq –来自百度网盘超级会员v2的分享 1.2、ip连接 下shell和xftp操作一样&#xff1a;找到文件—》新建—》名称随…

【英特尔IA-32架构软件开发者开发手册第3卷:系统编程指南】2001年版翻译,1-1

文件下载与邀请翻译者 学习英特尔开发手册&#xff0c;最好手里这个手册文件。原版是PDF文件。点击下方链接了解下载方法。 讲解下载英特尔开发手册的文章 翻译英特尔开发手册&#xff0c;会是一件耗时费力的工作。如果有愿意和我一起来做这件事的&#xff0c;那么&#xff…

论文不同写作风格下的ChatGPT提示词分享

学境思源&#xff0c;一键生成论文初稿&#xff1a; AcademicIdeas - 学境思源AI论文写作 在学术论文写作中&#xff0c;不同的写作风格能显著影响文章的表达效果与读者的理解。无论是描述性、分析性、论证性&#xff0c;还是批判性写作风格&#xff0c;合理选择和运用恰当的写…

生成模型小结

突然发现之前整理的makedown有必要放在博客里面,这样不同的设备之间可以直接观看达到复习的效果. GAN G和D不断的博弈提高自己。GAN的优点是保真度比较高&#xff0c;缺点是多样性比较低。 (auto-encoder)AE&#xff0c;DAE、VAE、VQVAE 输入x&#xff0c;经过编码器生成&…

Elasticsearch学习笔记(2)

索引库操作 在Elasticsearch中&#xff0c;Mapping是定义文档字段及其属性的重要机制。 Mapping映射属性 type&#xff1a;字段数据类型 1、字符串&#xff1a; text&#xff1a;可分词的文本&#xff0c;适用于需要全文检索的情况。keyword&#xff1a;用于存储精确值&am…

二阶低通滤波器(Simulink仿真)

1、如何将S域传递函数转为Z域传递函数 传递函数如何转化为差分方程_非差分方程转成差分方程-CSDN博客文章浏览阅读4.1k次,点赞4次,收藏50次。本文介绍了如何将传递函数转化为差分方程,主要适用于PLC和嵌入式系统。通过MATLAB的系统辨识工具箱获取传递函数,并探讨了离散化方…

OpenCV第十二章——人脸识别

1.人脸跟踪 1.1 级联分类器 OpenCV中的级联分类器是一种基于AdaBoost算法的多级分类器&#xff0c;主要用于在图像中检测目标对象。以下是对其简单而全面的解释&#xff1a; 一、基本概念 级联分类器&#xff1a;是一种由多个简单分类器&#xff08;弱分类器&#xff09;级联组…