Transformers微调BERT模型实现文本分类任务(colab)

news2024/10/7 14:30:47

1. 数据准备

使用colab进行实验
在这里插入图片描述

左上角上传数据,到当前实验室
右上角设置GPU选择

查看GPU

! nvidia-sm

在这里插入图片描述
安装需要的库

!pip install datasets
!pip install transformers[torch]
!pip install torchkeras

1.1 读取数据

import pandas as pd
data = pd.read_csv("/content/news.csv")
data

在这里插入图片描述

1.2 数据处理

我们看到label是中文字符串,训练时需要转换成数值型,如下

{'教育': 0,
 '体育': 1,
 '科技': 2,
 '时尚': 3,
 '房产': 4,
 '家居': 5,
 '财经': 6,
 '时政': 7,
 '娱乐': 8,
 '游戏': 9}

遍历一下就可以,并将全数据转换为data frame

#处理标签
def label_dic(data,label):
    d = {}
    labels = data[label].unique()
    for i,v in enumerate(labels):
        d[v] = i     
    return d
#数据整理
def get_train_data(data,col_x,col_y,label_dic):
    content = data[col_x]
    label = []
    for i in data[col_y]:
        label.append(label_dic.get(i))
    return content,label
label_dic = label_dic(data,"label")
content,label  = get_train_data(data,"text","label",label_dic)

1.3 数据转换

将数据转换成可以进行训练的数据

data = pd.DataFrame({"content":content,"label":label})
data = shuffle(data)

1.4 创建分词

from transformers import AutoTokenizer #BertTokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
tokenizer

tokenizer
BertTokenizerFast(name_or_path=‘bert-base-chinese’, vocab_size=21128, model_max_length=512, is_fast=True, padding_side=‘right’, truncation_side=‘right’, special_tokens={‘unk_token’: ‘[UNK]’, ‘sep_token’: ‘[SEP]’, ‘pad_token’: ‘[PAD]’, ‘cls_token’: ‘[CLS]’, ‘mask_token’: ‘[MASK]’}, clean_up_tokenization_spaces=True), added_tokens_decoder={
0: AddedToken(“[PAD]”, rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
100: AddedToken(“[UNK]”, rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
101: AddedToken(“[CLS]”, rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
102: AddedToken(“[SEP]”, rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
103: AddedToken(“[MASK]”, rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

训练集取80%

train_len = round(len(data)*0.8)
train_data = tokenizer(data.content.to_list()[:train_len], padding = "max_length", max_length = 128, truncation=True ,return_tensors = "pt")
train_label = data.label.to_list()[:train_len]

2. 模型训练

2.1 导入模型

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=10)

在这里插入图片描述
在AutoModelForSequenceClassification.from_pretrained(“bert-base-chinese”, num_labels=10) 这个函数中,transformer 已经帮你定义了损失函数,既10个分类的交叉熵损失,所以下方我们只需要自己定义优化器和学习率即可。

2.2 定义优化器和学习率

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
batch_size = 16
train = TensorDataset(train_data["input_ids"], train_data["attention_mask"], torch.tensor(train_label))
train_sampler = RandomSampler(train)
train_dataloader = DataLoader(train, sampler=train_sampler, batch_size=batch_size)
  • train_data 是一个包含输入数据的字典,其中 “input_ids” 是模型输入的token ID,“attention_mask” 是用于标识输入序列中哪些位置是有效的前景tokens,“labels” 是序列分类任务的标签。我们可以自己打印下我们前面定义好的训练数据,如下
    在这里插入图片描述

  • TensorDataset 将数据转换为一个PyTorch张量数据集,其中每个样本是一个包含input_ids、attention_mask和label的元组。

  • RandomSampler 从数据集中随机抽取样本进行训练,这对于避免过拟合和获得更具代表性的训练集是有益的。

  • DataLoader 负责将数据集划分为批次,并为训练提供迭代器。

#定义优化器
from torch.optim import AdamW
optimizer = AdamW(model.parameters(), lr=1e-4)
#定义学习率和训练轮数
num_epochs = 1
from transformers import get_scheduler
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)
  • model 是需要训练的模型,我们前面导入的预训练模型。
  • AdamW 是用于优化模型参数的优化器,它是一种改进的Adam优化器,常用于深度学习。
  • lr_scheduler 是学习率调度器,它用于在训练过程中调整学习率。在这个例子中,使用的是"linear"调度器,它会在训练开始时逐渐增加学习率,然后逐渐减少。
  • num_epochs 定义了训练的轮数。
  • num_training_steps 定义了训练的总步数,它是 epochs 乘以训练数据集的总步数。
  • get_scheduler 函数用于根据提供的参数创建学习率调度器。

2.3 开始训练

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

检查是否有可用的GPU,如果有,则将device设置为cuda;否则,设置为cpu。模型移动到选定的设备上。

循环

for epoch in range(num_epochs):
    total_loss = 0
    model.train()
    for step, batch in enumerate(train_dataloader):
        if step % 10 == 0 and not step == 0:
            print("step: ",step, "  loss:",total_loss/(step*batch_size))
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)
        model.zero_grad()        
        outputs = model(b_input_ids, 
                    token_type_ids=None, 
                    attention_mask=b_input_mask, 
                    labels=b_labels)

        loss = outputs.loss       
        total_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        lr_scheduler.step()
    avg_train_loss = total_loss / len(train_dataloader)      
    print("avg_loss:",avg_train_loss)

-b_input_ids、b_input_mask、b_labels:这些是从批次中提取的输入ID、掩码和标签,并移动到device上。

  • model.zero_grad():清除模型的梯度。
  • outputs = model(…):使用模型进行前向传播。
  • loss = outputs.loss:从输出中提取损失。
  • total_loss += loss.item():累加损失。
  • loss.backward():进行反向传播,计算损失关于模型参数的梯度。
  • torch.nn.utils.clip_grad_norm_(…):对梯度进行裁剪,以防止梯度爆炸。
  • optimizer.step():更新模型的参数。
  • lr_scheduler.step():更新学习率。
    在这里插入图片描述

3. 模型预测

inp = "专家指导:参加SSAT考试读美国优质高中(图)SSAT考试的全称是Secondary SchoolAdmission Test),是美国(微博)中学入学测试,相当于中国的中考,近年来,越来越多的中国学生通过参加SSAT申请美国高中,然后一步步进入世界一流大学。"

import numpy as np
test = tokenizer(inp,return_tensors="pt",padding="max_length",max_length=128)

model.eval()
with torch.no_grad():  
    test["input_ids"] = test["input_ids"].to(device)
    test["attention_mask"] = test["attention_mask"].to(device)
    outputs = model(test["input_ids"], 
                    token_type_ids=None, 
                    attention_mask=test["attention_mask"])
pred_flat = np.argmax(outputs["logits"].cpu(),axis=1).numpy().squeeze()
pred_flat.tolist() 
#0

我们也可以把标签进行一个映射

id2label_dic={}
for k,v in label_dic.items():
    id2label_dic[v] = k
id2label_dic[pred_flat.tolist()]
#教育

4. 模型保存

model.config.id2label = id2label_dic

model.save_pretrained("./bert0207")
tokenizer.save_pretrained("./bert0207")

在这里插入图片描述
调取模型预测

from transformers import pipeline 
classifier = pipeline("text-classification",model="./bert0207")

classifier(inp)
#[{'label': '教育', 'score': 0.9345001578330994}]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1437946.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

thinkphp6入门(19)-- 中间件向控制器传参

可以通过给请求对象赋值的方式传参给控制器&#xff08;或者其它地方&#xff09;&#xff0c;例如 <?phpnamespace app\middleware;class Hello {public function handle($request, \Closure $next){$request->hello ThinkPHP;return $next($request);} } 然后在控制…

Android7.0-Fiddler证书问题

一、将Fiddler的证书导出到电脑&#xff0c;点击Tools -> Options -> HTTPS -> Actions -> Export Root Certificate to Desktop 二、下载Window版openssl&#xff0c; 点击这里打开页面&#xff0c;下拉到下面&#xff0c;选择最上面的64位EXE点击下载安装即可 安…

第十五篇【传奇开心果系列】Python的OpenCV库技术点案例示例:图像配准

传奇开心果短博文系列 系列短博文目录Python的OpenCV库技术点案例示例系列短博文目录前言一、常见的图像配准任务介绍二、图像配准任务:图像拼接介绍和示例代码三、图像配准任务:图像校正介绍和示例代码四、图像配准任务:图像配准介绍和示例代码五、基于特征点的配准方法介绍…

typecho 在文章中添加 bilibili 视频

一、获取视频来源&#xff1a; 可以有2种方式来定位一个 bilibili 视频&#xff1a; 第一种是使用 bvid 参数定位第二种是使用 aid 参数定位 如何获取这两个参数&#xff1f; 首先我们可以看看 bilibili 网站中的视频页面链接其实可以分为两种&#xff1a; 第一种是类似&a…

自动化测试 —— Web自动化三大报错

Web自动化三大报错有哪些呢&#xff1f;接下来给大家讲讲。 Web自动化三大报错&#xff08;Exception&#xff09; 1. Exception1&#xff1a;no such element&#xff08;没有在页面上找到这个元素&#xff09; reason1&#xff1a;元素延迟加载了 solution&#xff1a; …

手把手教你激活BetterZip for Mac免费下载(附注册码) v5.3.4

软件介绍 BetterZip for Mac是一款广受欢迎的文件解压缩工具&#xff0c;支持Mac以及Windows等多个平台&#xff0c;能够生成被Win和Mac支持的压缩包&#xff0c;让用户可以在Mac和Windows电脑之间使用一种通用压缩包&#xff0c;用户可以更快捷地向压缩文件中添加和删除文件&…

设计模式-行为型模式(下)

1.访问者模式 访问者模式在实际开发中使用的非常少,因为它比较难以实现并且应用该模式肯能会导致代码的可读性变差,可维护性变差,在没有特别必要的情况下,不建议使用访问者模式. 访问者模式(Visitor Pattern) 的原始定义是&#xff1a; 允许在运行时将一个或多个操作应用于一…

有哪些方法可以配置并发服务器?

通过合理配置并发服务器&#xff0c;可以提高服务器的处理能力和响应速度&#xff0c;从而更好地满足用户需求。本文将介绍一些常见的并发服务器配置方法&#xff0c;以帮助您更好地实现服务器的高效运行。 一、选择合适的操作系统 操作系统的选择是并发服务器配置的重要环节…

泛娱乐社交出海洞察,Flat Ads解锁海外增长新思路

摘要:解读泛娱乐社交应用出海现状与趋势,解锁“掘金”泛娱乐社交出海赛道新思路。 根据全球舆情监测机构 Meltwater 和社交媒体机构We are Social最新发布数据显示,全球社交媒体活跃用户数量已突破50亿,约占世界人口总数62.5%。庞大的用户数量意味着广阔的增量空间,目前,随着全…

【NLP 自然语言处理(一)---词向量】

文章目录 什么是NLP自然语言处理发展历程自然语言处理模型模型能识别单词的方法词向量分词 一个向量vector表示一个词词向量的表示-one-hot多维词嵌入word embeding词向量的训练方法 CBOW Skip-gram词嵌入的理论依据 一个vector&#xff08;向量&#xff09;表示短语或者文章ve…

力扣精选算法100道——和为 K 的子数组[前缀和专题]

和为K的子数组链接 目录 第一步&#xff1a;了解题意​编辑 第二步&#xff1a;算法原理 第三步&#xff1a;代码 第一步&#xff1a;了解题意 数组中和为k的连续子数组&#xff0c;我们主要关注的是连续的&#xff0c; 比如[1,1,1],和为2的子数组有俩个&#xff0c;比如第…

Springboot简单设计两级缓存

两级缓存相比单纯使用远程缓存&#xff0c;具有什么优势呢&#xff1f; 本地缓存基于本地环境的内存&#xff0c;访问速度非常快&#xff0c;对于一些变更频率低、实时性要求低的数据&#xff0c;可以放在本地缓存中&#xff0c;提升访问速度 使用本地缓存能够减少和Redis类的远…

JWT令牌 | 一个区别于cookie/session的更安全的校验技术

目录 1、简介 2、组成成分 3、应用场景 4、生成和校验 5、登录下发令牌 &#x1f343;作者介绍&#xff1a;双非本科大三网络工程专业在读&#xff0c;阿里云专家博主&#xff0c;专注于Java领域学习&#xff0c;擅长web应用开发、数据结构和算法&#xff0c;初步涉猎Pyth…

波纹扩散效果

<!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0" /><title>波纹扩散</title><style>body {disp…

在 vite + ts 中,使用require

因为 require 属于 Webpack 的方法&#xff0c;所以 Vite 项目中是不能使用require的&#xff0c;所以控制台会给你报错&#xff0c;如下 解决办法如下&#xff1a;就是很不情愿的办法&#xff0c;没招啊 第一步、安装插件 npm i vite-plugin-require-transform --save-de…

【蓝桥杯冲冲冲】[NOIP2017 提高组] 宝藏

蓝桥杯备赛 | 洛谷做题打卡day29 文章目录 蓝桥杯备赛 | 洛谷做题打卡day29[NOIP2017 提高组] 宝藏题目背景题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1样例 #2样例输入 #2样例输出 #2提示题解代码我的一些话[NOIP2017 提高组] 宝藏 题目背景 NOIP2017 D2T2 题目描…

服务器安装Docker (centOS)

1. 卸载旧版本的Docker&#xff08;如果有&#xff09; 首先&#xff0c;如果您的系统上安装了旧版本的Docker&#xff0c;需要将其卸载。Docker的旧版本称为docker或docker-engine。使用以下命令来卸载旧版本&#xff1a; sudo yum remove docker \ docker-client \ docker-…

vCenterServer部署

一、硬件配置 vCenterServer本身最低的硬件要求是14GB&#xff0c;而vCenterServer则是以虚拟机的形式安装在ESXi中的虚拟机&#xff0c;所以ESXi的最低硬件要求是15.5GB&#xff0c;就是15872MB 二、安装vCenterServer 直接解压VMware-VCSA-all-8.0.0-20920323.iso&#xf…

使用No-SQL数据库支持连接查询用例的讨论

简介 在本文中&#xff0c;我们将简单介绍什么是No-SQL数据库。然后我们会讨论一种使用关系数据库比较容易实现的查询&#xff0c;即连接查询&#xff0c;怎么样使用No-SQL来实现。 什么是No-SQL数据库 与No-SQL数据库相对应的是传统的关系数据库&#xff08;RDBMS&#xff…

Leetcode刷题-(11~15)-Java+Python+JavaScript

算法是程序员的基本功&#xff0c;也是各个大厂必考察的重点&#xff0c;让我们一起坚持写算法题吧 遇事不决&#xff0c;可问春风&#xff0c;春风不语&#xff0c;即是本心。 我们在我们能力范围内&#xff0c;做好我们该做的事&#xff0c;然后相信一切都事最好的安排就可…