基于pytorch本地部署微调bert模型(yelp文本分类数据集)

news2025/1/11 19:45:39

项目介绍

本项目使用hugging face上提供的Bert模型API,基于yelp数据集,在本地部署微调Bert模型,官方的文档链接为https://huggingface.co/docs/transformers/quicktour,但是在官方介绍中出现了太多的API调用接口,无法在真正意义上做到本地微调部署,本项目致力于只通过Bert模型的接口获得Bert模型,其他包括数据集预处理、损失函数定义、模型训练以后后续模型的调试部署都在本地进行,让微调的过程清晰化和透明化。


BERT模型

BERT(Bidirectional Encoder Representations from Transformers)是谷歌于2018年发布的自然语言处理模型。其核心创新在于双向上下文理解,允许模型同时考虑上下文中的前后词,从而提升对文本含义的理解。BERT的训练过程采用了无监督学习,使用大规模的文本数据进行预训练,然后通过微调适应具体任务,如问答或情感分析。这个模型的发布极大推动了NLP领域的发展,成为许多后续模型的基础。

BERT模型的预训练过程分为两个目标任务:

  • 将训练数据集文本中的内容按照一定的比例挖空一些单词(或文字),BERT模型通过挖空单词的上下文本内容与语义复现出该位置上应该出现的单词;
  • 将多个句子组合成一个句子组,让BERT模型判断句子之间是否存在上下语句的关系。

yelp数据集

Yelp文本分类数据集是一个用于自然语言处理(NLP)任务的公开数据集,主要用于训练和评估文本分类模型。该数据集包含来自Yelp网站的用户评论,通常包括以下几个关键特征:

  1. 评论文本:用户对商家的评论内容,通常包含对服务、食品质量、环境等方面的评价。

  2. 星级评分:用户根据他们的体验给出的评分,通常是1到5颗星。

  3. 商家信息:评论关联的商家信息,包括商家名称、类别和位置等。


代码实现

依赖环境

基于pytorch架构,显存最好大于或者等于4GB

import torch
import torch.nn as nn
from tqdm.auto import tqdm
from statistics import mean
from torch.optim import AdamW
import matplotlib.pyplot as plt
from transformers import get_scheduler
from torch.utils.data import DataLoader
from prepare_dataset_model import bert_dataset, model_bert,get_dataset_list
from transformers import AutoModel,AutoTokenizer,AutoModelForSequenceClassification

# from datasets import load_dataset 
# 可以直接从huggingface上下载并预处理训练数据集
# 感兴趣的朋友可以自行查阅函数说明文档,本项目使用自定义的数据集预处理类与函数

prepare_dataset_model是定义在另一个脚本里的数据预处理函数,接下来先展开这一部分

数据集预处理

从官网上可以下载yelp对应的数据集,本项目选择的是csv格式的数据,使用pandas可以非常轻松的对csv格式的数据进行操作与处理

下面是放在prepare_dataset_model.py脚本中的代码

数据列表的获取

def get_dataset_list(data_path,simple_num, rate=0.8):
    data = pd.read_csv(data_path)
    data_list = []
    print('loading dataset...')
    show_bar = tqdm(range(simple_num))

    for index,item in data.iterrows():
        data_list.append({"text":item['text'], "label":item['label']})
        show_bar.update(1)
        if index==simple_num:
            break

    lenght = len(data_list)
    train_data_list = data_list[:int(lenght*rate)]
    test_data_list = data_list[int(lenght*rate):]
    return train_data_list,test_data_list

ylep数据集有几万条数据,一次性全部读出来在大多数情况下显得不现实,于是定义simple_num参数进行传入我们想要处理的数据数,rate是训练集与测试集的比例,最后返回按照比例的训练集列表和测试集列表

数据集类的定义

class bert_dataset(Dataset):
    def __init__(self, data_list, tokenizer):
        self.dataset = data_list
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        item = self.dataset[idx]
        text = item["text"]
        label = item["label"]

        inputs = self.tokenizer(text,padding="max_length",truncation=True,return_tensors='pt')
        # inputs["label"] = label
        return inputs,label
    
    def __len__(self):
        return len(self.dataset)

这一步是经典的自定义数据集操作,值得一提的是,假如直接把label字段传入inputs的字典,在微调的模型中同样也能够接受,并且在模型中可以直接返回按照交叉熵损失函数计算的loss值;也可以将label字段分开,另外定义损失函数进行计算,本项目选择的是第二种

完整代码

import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from tqdm.auto import tqdm
from torch.utils.data import Dataset

class bert_dataset(Dataset):
    def __init__(self, data_list, tokenizer):
        self.dataset = data_list
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        item = self.dataset[idx]
        text = item["text"]
        label = item["label"]

        inputs = self.tokenizer(text,padding="max_length",truncation=True,return_tensors='pt')
        # inputs["label"] = label
        return inputs,label
    
    def __len__(self):
        return len(self.dataset)

class model_bert(nn.Module):
    def __init__(self, bert):
        super(model_bert,self).__init__()
        self.bert = bert
        # self.out = nn.Linear(5,1)

    def forward(self, input_ids=None, token_type_ids=None, attention_mask=None):
        bert_out = self.bert(input_ids,token_type_ids,attention_mask)
        out = F.softmax(bert_out.logits, dim=-1)
        return out
    


def get_dataset_list(data_path,simple_num, rate=0.8):
    data = pd.read_csv(data_path)
    data_list = []
    print('loading dataset...')
    show_bar = tqdm(range(simple_num))

    for index,item in data.iterrows():
        data_list.append({"text":item['text'], "label":item['label']})
        show_bar.update(1)
        if index==simple_num:
            break

    lenght = len(data_list)
    train_data_list = data_list[:int(lenght*rate)]
    test_data_list = data_list[int(lenght*rate):]
    return train_data_list,test_data_list

if __name__ == '__main__':
    data_path = r'your data path'
    train_list,test_list = get_dataset_list(data_path,5000)
    print(f'len of trian:{len(train_list)}')
    print(f'len of test:{len(test_list)}')

        

这里对模型也进行了调整,在输出的最后加一个softmax激活函数,最后输出分类的值,假如采用这种方式训练,损失函数的选择也应该进行对应的改变,在后面的代码中没有使用到这个模型,读者们可以自行对比一下使用原模型以及使用调整后的模型最后的训练效果


模型搭建

具体Bert模型的搭建可以使用huggingface提供的API接口

tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
model = AutoModelForSequenceClassification.from_pretrained('google-bert/bert-base-uncased')

具体的函数调用仍可以参考huggingface官方的说明文档,运行上面的语句后终端会显示一些下载的进度条,模型和token会被下载到C盘的.cache缓存文件夹中,以我的电脑为例,保存路径为

C:\Users\29278\.cache\huggingface

下载完成之后,以后的每一次调用都可以从这个路径上调用(除非在huggingface上的模型有更新),假如我们想直接调用本地的模型,可以通过save_pretrained语句把模型保存到指定的路径中(详见官网介绍),再通过from_pretrained语句进行调用

如果想对Bert模型的结构进行进一步的调整,可以参考上一模块的model_bert类型的架构进行定义

model_use = model_bert(model)

 损失函数,优化器和学习率优化

criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(),lr = 5e-5)
lr_scheduler = get_scheduler(name="linear",optimizer=optimizer,num_warmup_steps=0,num_training_steps=num_trainStep)

这里的学习率选择使用全连接的方式进行优化,num_trainStep定义为对每一个批次的训练之后进行学习率的调整

模型训练代码

import torch
import torch.nn as nn
from tqdm.auto import tqdm
from statistics import mean
from torch.optim import AdamW
import matplotlib.pyplot as plt
# from datasets import load_dataset
from transformers import get_scheduler
from torch.utils.data import DataLoader
from prepare_dataset_model import bert_dataset, model_bert,get_dataset_list
from transformers import AutoModel,AutoTokenizer,AutoModelForSequenceClassification

device = torch.device('cuda:0' if torch.cuda.is_available else 'cpu')
print(f'using {device}...')
epoch = 10
batch_size = 8

model_path = r"your model path".replace('\\', '/')
token_path = r"your token path".replace('\\','/')
data_path = r"your data path"
tokenizer = AutoTokenizer.from_pretrained(token_path)


model = AutoModelForSequenceClassification.from_pretrained(model_path,num_labels=5).to(device)
optimizer = AdamW(model.parameters(),lr = 5e-5)
criterion = nn.CrossEntropyLoss()

# dataset_raw = load_dataset("csv",data_files=r"your data path")
# dataset_list = dataset_raw['train']
# print(type(dataset_list))

dataset_list,dataset_list_test = get_dataset_list(data_path,10)

dataset_class = bert_dataset(dataset_list,tokenizer)
dataset_class_test = bert_dataset(dataset_list_test,tokenizer)

dataset_input = DataLoader(dataset_class,batch_size=batch_size)
dataset_input_test = DataLoader(dataset_class,batch_size=batch_size,shuffle=True)

num_trainStep = epoch*len(dataset_input)
show_num_step = epoch*(len(dataset_input)+len(dataset_input_test))
lr_scheduler = get_scheduler(name="linear",optimizer=optimizer,num_warmup_steps=0,num_training_steps=num_trainStep)

loss_list = []
correct_list = []

loss_list_test = []
correct_list_test = []

process_bar = tqdm(range(show_num_step))
for step in range(epoch):
    loss_list_everyEpoch = []
    correct_list_everyEpoch = []
    model.train()
    for token, label in dataset_input:
        token = {k_in:v_in.squeeze(1).to(device) for k_in,v_in in token.items()}
        label = label.to(device)

        output = model(**token)
        out = output.logits
        # print(out)
        # print("loss:",out.loss)

        optimizer.zero_grad()
        loss = criterion(out,label)  
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        out_class = torch.argmax(out,dim=-1)

        correct_num = (out_class==label).sum()
        correct_rate = correct_num/label.shape[0]
        loss_list_everyEpoch.append(loss.item())
        correct_list_everyEpoch.append(correct_rate.item())
        process_bar.update(1)

    loss_list.append(mean(loss_list_everyEpoch))
    correct_list.append(mean(correct_list_everyEpoch))
    # print("train correct rate=",mean(correct_list_everyEpoch))
    show_correct_rate = mean(correct_list_everyEpoch)
    tqdm.write(f"train correct rate={show_correct_rate:2f}")

    loss_list_everyEpoch_test = []
    correct_list_everyEpoch_test = []
    model.eval()
    for token_test, label_test in dataset_input_test:
        token_test = {k_in:v_in.squeeze(1).to(device) for k_in,v_in in token_test.items()}
        label_test = label_test.to(device)

        output_test = model(**token_test)
        out_test = output_test.logits
        out_class_test = torch.argmax(out_test, dim=-1)
        # print(out_test)

        loss_test = criterion(out_test,label_test)
        loss_list_everyEpoch_test.append(loss_test.item())

        correct_num_test = (out_class_test==label_test).sum()
        correct_rate_test = correct_num_test/label_test.shape[0]
        correct_list_everyEpoch_test.append(correct_rate_test.item())
        process_bar.update(1)

    loss_list_test.append(mean(loss_list_everyEpoch_test))
    correct_list_test.append(mean(correct_list_everyEpoch_test))
    # print(f"test correct={mean(correct_list_everyEpoch_test)}")
    tqdm.write(f"test correct={mean(correct_list_everyEpoch_test):2f}")
    



print("train loss list:", loss_list)
print("test loss list:", loss_list_test)
print("train correct rate:", correct_list)
print("test correct rate:", correct_list_test)

结果

上述的代码仅提供一个demo,只去抽取了数据集中的10条数据进行训练,迭代10次,目的是让读者能够较快速地进行代码调试,再以此为基础对微调做更加针对性的操作

 [06:15<00:00, 18.93s/it]train loss list: [1.7557673454284668, 1.3130085468292236, 1.1637543439865112, 1.1429853439331055, 1.0526454448699951, 1.0120376348495483, 0.9493937492370605, 0.8610967397689819, 0.8940214514732361, 0.8735412955284119]
test loss list: [1.3395963907241821, 1.1231462955474854, 1.0310029983520508, 0.9932389855384827, 0.9472938776016235, 0.9150838851928711, 0.8861185908317566, 0.8644040822982788, 0.845656156539917, 0.8328518271446228]
train correct rate: [0.25, 0.625, 0.75, 0.625, 0.75, 0.625, 1.0, 1.0, 0.75, 0.875]
test correct rate: [0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 0.875, 0.875, 0.875, 0.875]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [06:15<00:00, 18.77s/it]

可以看到在训练集和验证集上,可以明显提现出模型的微调训练效果,微调大模型的发挥空间有很多,欢迎大家一起讨论交流~


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2155487.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

React 中的延迟加载

延迟加载是 Web 开发中的一种有效的性能优化技术&#xff0c;尤其是对于 React 等库和框架。它涉及仅在需要时加载组件或资源&#xff0c;无论是响应用户操作还是当元素即将在屏幕上显示时。这可以减少应用程序的初始加载时间&#xff0c;减少资源消耗&#xff0c;并改善用户体…

ETLCloud:新一代ETL数据抽取工具的定义与革新

数据集成、数据治理已经成为推动企业数字化转型的核心动力&#xff0c;现在的企业比任何时候都需要一个更为强大的新一代数据集成工具来处理、整合并转化多种数据源。 而ETL&#xff08;数据提取、转换、加载&#xff09;作为数据管理的关键步骤&#xff0c;已在企业数据架构中…

串口助手的qt实现思路

要求实现如下功能&#xff1a; 获取串口号&#xff1a; foreach (const QSerialPortInfo &serialPortInfo, QSerialPortInfo::availablePorts()) {qDebug() << "Port: " << serialPortInfo.portName(); // e.g. "COM1"qDebug() <<…

【JavaEE】——线程的安全问题和解决方式

阿华代码&#xff0c;不是逆风&#xff0c;就是我疯&#xff0c;你们的点赞收藏是我前进最大的动力&#xff01;&#xff01;希望本文内容能够帮助到你&#xff01; 目录 一&#xff1a;问题引入 二&#xff1a;问题深入 1&#xff1a;举例说明 2&#xff1a;图解双线程计算…

SwiftUI 实现关键帧动画

实现一个扫描二维码的动画效果&#xff0c;然而SwiftUI中没有提供CABasicAnimation 动画方法&#xff0c;该如何实现这种效果&#xff1f;先弄清楚什么关键帧动画&#xff0c;简单的说就是指视图从起点至终点的状态变化&#xff0c;可以是形状、位置、透明度等等 本文提供了一…

(done) 声音信号处理基础知识(3) (一个TODO: modulation 和 timbre 的关联)(强度、响度、音色)

来源&#xff1a;https://www.youtube.com/watch?vJkoysm1fHUw sound power 通常可以被认为是能量传输的速率 声源往所有方向传输的每时间单位能量 用 瓦特(W) 作为单位测量 Sound intensity 声音强度&#xff0c;每单位面积的 sound power W/m^2 人类实际上能听到非常小强…

八. 实战:CUDA-BEVFusion部署分析-coordTrans Precomputation

目录 前言0. 简述1. 案例运行2. coordTrans3. Precomputation总结下载链接参考 前言 自动驾驶之心推出的 《CUDA与TensorRT部署实战课程》&#xff0c;链接。记录下个人学习笔记&#xff0c;仅供自己参考 本次课程我们来学习下课程第八章—实战&#xff1a;CUDA-BEVFusion部署分…

Python Selenium 自动化爬虫 + Charles Proxy 抓包

一、场景介绍 我们平常会遇到一些需要根据省、市、区查询信息的网站。 1、省市查询 比如这种&#xff0c;因为全国的省市比较多&#xff0c;手动查询工作量还是不小。 2、接口签名 有时候我们用python直接查询后台接口的话&#xff0c;会发现接口是加签名的。 而签名算法我…

keil5 MDK 最新版本官网下载(v5.40为例) ARM单片机环境搭建安装教程(STM32系列为例)

正所谓授之以鱼不如授之以渔。本文将细讲从官网下载keil5MDK来保证keil5为最新版本的实时性 &#xff08;注意新老版本可能出现版本兼容问题&#xff0c;若不放心&#xff0c;跟着老弟我一起下载5.40版本即可&#xff09; 目录 一、下载keil5 MDK 方法①:CSDN下载&#xff0…

计算机毕业设计 基于 Hadoop平台的岗位推荐系统 SpringBoot+Vue 前后端分离 附源码 讲解 文档

&#x1f34a;作者&#xff1a;计算机编程-吉哥 &#x1f34a;简介&#xff1a;专业从事JavaWeb程序开发&#xff0c;微信小程序开发&#xff0c;定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事&#xff0c;生活就是快乐的。 &#x1f34a;心愿&#xff1a;点…

【深入学习Redis丨第六篇】Redis哨兵模式与操作详解

〇、前言 哨兵是一个分布式系统&#xff0c;你可以在一个架构中运行多个哨兵进程&#xff0c;这些进程使用流言协议来接收关于Master主服务器是否下线的信息&#xff0c;并使用投票协议来决定是否执行自动故障迁移&#xff0c;以及选择哪个Slave作为新的Master。 文章目录 〇、…

Django 5 学习笔记 2024版

1. 官方中文文档 Django 文档 | Django 文档 | Django (djangoproject.com) 2. 第一个应用 博客 总目录 <1>依赖安装: pip install django <2> 创建 工程 myapp django-admin startproject myapp cd myapp <3>创建 应用 app > python manage.py s…

算法-排序算法(冒泡选择插入希尔快速归并堆计算)

1.算法概述 1.1什么是算法 算法是特定问题的求解步骤的描述&#xff0c;是独立存在的一种解决问题的思想和方法。对于算法而言计算机编程语言并不重要&#xff0c;可以用任何计算机编程语言来编写算法。 程序数据结构算法 1.2数据结构和算法的区别和联系 数据结构只是静态…

CentOS 7 YUM源不可用

CentOS 7 操作系统在2024年6月30日后将停止官方维护&#xff0c;并且官方提供的YUM源将不再可用。 修改&#xff1a;nano /etc/yum.repos.d/CentOS-Base.repo # CentOS-Base.repo [base] nameCentOS-$releasever - Base baseurlhttp://mirrors.aliyun.com/centos/$rel…

数据库管理-第243期 云栖有感:AI?AI!(20240922)

数据库管理243期 2024-09-22 数据库管理-第243期 云栖有感&#xff1a;AI&#xff1f;AI&#xff01;&#xff08;20240922&#xff09;1 AI2 干货3 数据库总结 数据库管理-第243期 云栖有感&#xff1a;AI&#xff1f;AI&#xff01;&#xff08;20240922&#xff09; 作者&am…

Apache 中间件漏洞

CVE-2021-41773 环境搭建 docker pull blueteamsteve/cve-2021-41773:no-cgid 访问172.16.1.4:8080 使⽤curl http://172.16.1.4:8080/cgi-bin/.%2e/.%2e/.%2e/.%2e/etc/passwd

Linux中的调度算法

nice值的范围有限&#xff0c;即为[-20, 19]&#xff0c;也就是40个数字&#xff0c;优先级为[60, 99]即一共40个优先级 目前谈论的Linux操作系统叫做分时操作系统&#xff0c;调度的时候主要强调公平&#xff0c;还有一种是实时操作系统&#xff0c;比如智能汽车里面必须装有这…

网站设计中安全方面都需要有哪些考虑

网站设计中的安全性是一个多方面的问题&#xff0c;需要从多个角度进行考虑和实施。以下是一些关键的安全考虑因素&#xff1a; 数据加密&#xff1a; 使用SSL&#xff08;安全套接字层&#xff09;证书来建立加密连接&#xff0c;确保数据在传输过程中不被截获。定期更新SSL证…

学习IEC 62055付费系统标准

1.IEC 62055 国际标准 IEC 62055 是目前关于付费系统的唯一国际标准&#xff0c;涵盖了付费系统、CIS 用户信息系统、售电系统、传输介质、数据传输标准、预付费电能表以及接口标准等内容。 IEC 62055-21 标准化架构IEC 62055-31 1 级和 2 级有功预付费电能表IEC 62055-41 STS…

【重学 MySQL】三十七、聚合函数

【重学 MySQL】三十七、聚合函数 基本概念5大常用的聚合函数COUNT()SUM()AVG()MAX()MIN() 使用场景注意事项示例查询 聚合函数&#xff08;Aggregate Functions&#xff09;在数据库查询中扮演着至关重要的角色&#xff0c;特别是在处理大量数据时。它们能够对一组值执行计算&a…