# NLP-transformer学习:(5)Bert 实战

news2024/9/17 7:27:15

NLP-transformer学习:(5)模型训练和预测

在这里插入图片描述

基于 NLP-transformer学习:(2,3,4),这里对transformer 更近一步,学习尝试使用其中的bert


文章目录

  • NLP-transformer学习:(5)模型训练和预测
    • @[TOC](文章目录)
  • 1 上节补充:什么是 BERT
  • 2 RBT模型调用与实践
    • 2.1 数据准备
    • 2.2 创建 dataset,读取csv
    • 2.3 创建训练
    • 2.4 模型预测

提示:以下是本篇文章正文内容,下面案例可供参考

1 上节补充:什么是 BERT

BERT: Bidirectional Encoder Representation from Transformers
是一个预训练的语言表征模型。它强调了不再像以往一样采用传统的单向语言模型或者把两个单向语言模型进行浅层拼接的方法进行预训练,而是采用新的masked language model(MLM),以致能生成深度的双向语言表征。BERT论文发表时提及在11个NLP(Natural Language Processing,自然语言处理)任务中获得了新的state-of-the-art的结果
推荐链接:
https://blog.csdn.net/SMith7412/article/details/88755019
其实说白了就是 transformer可以进行堆叠

2 RBT模型调用与实践

2.1 数据准备

url:https://github.com/SophonPlus/ChineseNlpCorpus
下载这个数据集
在这里插入图片描述

2.2 创建 dataset,读取csv

from transformers import AutoTokenizer, AutoModelForSequenceClassification  # tokenizer and model will help text classiffication
import pandas as pd # used for dealing with csv data
from torch.utils.data import Dataset
#from torch.utils.data import random_split # my torch version is 1.12.0 this functions is added at lest 1.13, so I add it in utils.py
import torch
import utils as ut

from torch.utils.data import DataLoader

class TxtDataset(Dataset):
    # return None type value import code readability
    def __init__(self) -> None:
        super().__init__()
        # read the data
        self.data = pd.read_csv("./01-Getting_Started/04-model/ChnSentiCorp_htl_all.csv")
        # remove null data
        self.data = self.data.dropna()

    def __getitem__(self, index): # return one label with data in one time
        return self.data.iloc[index]["review"], self.data.iloc[index]["label"]

    def __len__(self): #return the size of current dataset 
        return len(self.data)
   
if __name__ == "__main__":

    txtdataset = TxtDataset()
    for i in range(5):
        # print the sentense with label, 1 shows positive comment
        print(txtdataset[i])
    
    print("len txtdataset:" + str(len(txtdataset)))

    # length means the propotion, 
    # for this case, train dataset is 0.9 and valid dataset is 0.1, 
    # validset + tran must equals 1.0
    trainset, validset = ut.random_split(txtdataset, [.9, .1], generator=torch.Generator().manual_seed(42))
    print("len trainset:" + str(len(trainset)))
    print("len validset:" + str(len(validset)))

    for i in range(10):
        print(trainset[i])

    # 
    tokenizer = AutoTokenizer.from_pretrained("./rbt3")

运行结果:
在这里插入图片描述
可以看到数据被正常读取,length可以读到,并且数据集中有结尾为1的postive评价和为0的negtive评价

2.3 创建训练

训练代码,其中要说明的是这里做的是迁移学习,ii因此学习率可以小一些。
在训练伊始需要对optimizer 归零。
加上之前小结以及给你建立好的数据和,我们可以这样
代码:

from transformers import AutoTokenizer, AutoModelForSequenceClassification  # tokenizer and model will help text classiffication
import pandas as pd # used for dealing with csv data
from torch.utils.data import Dataset
#from torch.utils.data import random_split # my torch version is 1.12.0 this functions is added at lest 1.13, so I add it in utils.py
import torch
import utils as ut

from torch.utils.data import DataLoader


class TxtDataset(Dataset):
    # return None type value import code readability
    def __init__(self) -> None:
        super().__init__()
        # read the data
        self.data = pd.read_csv("/home/mex/Desktop/learn_transformer/mexwayne_transformers_NLP/01-Getting_Started/04-model/ChnSentiCorp_htl_all.csv")
        # remove null data
        self.data = self.data.dropna()

    def __getitem__(self, index): # return one label with data in one time
        return self.data.iloc[index]["review"], self.data.iloc[index]["label"]

    def __len__(self): #return the size of current dataset 
        return len(self.data)
    

# when we train the model, we will make traindata with batchsize and pack them with into one tensor
# so we need implement a function like this
# the text and label should be store.    
def collate_fun(batch):
    texts, labels = [], []
    for item in batch:
        texts.append(item[0])
        labels.append(item[1])
    # the datat which is too long should be truncated with max_length, and the short data should be 
    # padd all the data into same length
    inputs = tokenizer(texts, max_length=128, padding="max_length", truncation=True, return_tensors="pt")
    # if the label in the data, the loss will caculate by model like brt did so
    inputs["labels"] = torch.tensor(labels)
    return inputs

def evaluate():
    model.eval()
    acc_num = 0
    with torch.inference_mode():
        for batch in validloader:
            if torch.cuda.is_available():
                batch = {k: v.cuda() for k, v in batch.items()}
            output = model(**batch) # input the valid data into model and get the reuslts
            pred = torch.argmax(output.logits, dim=-1)
            acc_num += (pred.long() == batch["labels"].long()).float().sum() # the value after == is type of bool, we need to change it into int
    return acc_num / len(validset) # finaly we count the negtive and positive value 


def train(epoch=3, log_step=100):
    global_step = 0
    for ep in range(epoch):
        model.train() # need to open the train mode for model
        for batch in trainloader:
            if torch.cuda.is_available():
                batch = {k: v.cuda() for k, v in batch.items()}
            optimizer.zero_grad() # the optimizer should be set zero first before used
            output = model(**batch) # we want to put all the key into it, so we use double start
            print(output)
            output.loss.backward()
            optimizer.step() # update the model
            if global_step % log_step == 0:
                print(f"ep: {ep}, global_step: {global_step}, loss: {output.loss.item()}")
            global_step += 1 # gobale step need increase
        acc = evaluate() # check the accuracy
        print(f"ep: {ep}, acc: {acc}")




if __name__ == "__main__":

    txtdataset = TxtDataset()
    for i in range(5):
        # print the sentense with label, 1 shows positive comment
        print(txtdataset[i])
    
    print("len txtdataset:" + str(len(txtdataset)))

    # length means the propotion, 
    # for this case, train dataset is 0.9 and valid dataset is 0.1, 
    # validset + tran must equals 1.0
    trainset, validset = ut.random_split(txtdataset, [.9, .1], generator=torch.Generator().manual_seed(42))
    print("len trainset:" + str(len(trainset)))
    print("len validset:" + str(len(validset)))

    for i in range(10):
        print(trainset[i])

    tokenizer = AutoTokenizer.from_pretrained("hfl/rbt3")
    trainloader = DataLoader(trainset, batch_size=32, shuffle=True,  collate_fn=collate_fun)
    validloader = DataLoader(validset, batch_size=64, shuffle=False, collate_fn=collate_fun)

    print(type(trainloader))


    from torch.optim import Adam # define the trainer, use adam gradiant descent
    model = AutoModelForSequenceClassification.from_pretrained("hfl/rbt3")
    if torch.cuda.is_available(): # model should be set on the gpu
        model = model.cuda()
    optimizer = Adam(model.parameters(), lr=2e-5)

    train()

因为做的是迁移训练,因此这个迭代很较少,步长也很大,因此很快就结束
在这里插入图片描述

2.4 模型预测

2.3 中我们可以将训练好的 代码进行预测
代码:

    # 前面的省略
    from torch.optim import Adam # define the trainer, use adam gradiant descent
    model = AutoModelForSequenceClassification.from_pretrained("hfl/rbt3")
    if torch.cuda.is_available(): # model should be set on the gpu
        model = model.cuda()
    optimizer = Adam(model.parameters(), lr=2e-5)

    train()

    
    sen = "我觉得这家羊肉泡沫馆子不错,做的羊肉很好吃!"
    id2_label = {0: "差评!", 1: "好评!"}
    model.eval()
    with torch.inference_mode():
        inputs = tokenizer(sen, return_tensors="pt")
        inputs = {k: v.cuda() for k, v in inputs.items()}
        logits = model(**inputs).logits
        pred = torch.argmax(logits, dim=-1)
        print(f"输入:{sen}\n模型预测结果:{id2_label.get(pred.item())}")

这里单独说下logits:
(1)未激活的原始分数:logits 是模型前向传播的输出之一,它代表模型对每个类别的信心值。这些信心值并没有经过激活函数(如 softmax 或 sigmoid)的处理,因此它们可以是任意实数(正、负、或零)。
(2)分类任务中的核心输出:在分类任务中,logits 是关键的输出,因为它们在激活函数处理后会转换为概率分布。概率分布中的最大值通常对应于模型最终的预测类别。
于是我们得到预测结果:
在这里插入图片描述
可以看到我用了十分小众的羊肉泡馍,以及关键词牛逼作为评价,得到的结果和预期保持一致

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2077266.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

在 Debian 上安装 IntelliJ IDEA 笔记

在 Debian💩 上安装 IntelliJ IDEA 💡 笔记 下载安装 JDK17安装 IntelliJ IDEA Community添加创建桌面启动项(快捷方式) 参考资料 下载 两个包已经下好了,一个JDK17,一个IntelliJ IDEA Community 使用 wge…

【Liunx入门】Liunx软件包管理器

文章目录 前言一、什么是软件包二、网络相关指令三、Ubuntu包管理软件apt1.查看软件包2.sudo权限3.软件安装4.卸载软件5.软件更新6.升级软件包 总结 前言 Linux软件包管理器是Linux系统中用于安装、升级和卸载软件包的工具。它们提供了一个方便的方式来管理软件包,…

c++习题25-大整数加法

目录 一,题目 二,思路 三,代码 一,题目 描述 求两个不超过200位的非负整数的和。 输入 有两行,每行是一个不超过200位的非负整数,可能有多余的前导0。 输出 一行,即相加后的结果。结果里不…

Physics of Language Models学习小结

1.概述 Physics of Language Models 参考:https://zhuanlan.zhihu.com/p/711391378 这是一系列论文和一个新的LLM研究方向,官网的概述如下。 苹果掉落,盒子移动,但重力和惯性等普遍规律对技术进步至关重要。虽然GPT-5或LLaMA-…

Threejs学习-三维坐标系、相机控件

坐标系: Three.js 使用的是右手坐标系,x 轴朝右,y 轴朝上,z 轴朝向自己。 相机控件轨道控制器 相机控件OrbitControls 通过相机控件OrbitControls实现旋转缩放预览效果。 // 设置相机控件轨道控制器OrbitControls const contr…

fastjson漏洞分析与复现

一、基础知识 Fastjson介绍: fastjson是阿里巴巴开源的JSON解析库,它可以解析JSON格式的字符串,支持将Java Bean序列化为JSON字符串,也可以从JSON字符串反序列化到JavaBean。即fastjson的主要功能就是将Java Bean序列化成JSON字…

IDEA插件支持API调试、接口用例支持一键同步API变更,MeterSphere开源持续测试工具v3.2.0版本发布

2024年8月26日,MeterSphere开源持续测试工具正式发布v3.2.0版本。 在这一版本中,接口测试方面,MeterSphere API Debugger插件支持API调试,接口用例支持一键同步API变更;测试管理方面,在“测试用例”模块中…

牛客笔试训练

牛客.过桥 在函数 public static int n;public static int[]arrnew int[2001];public static int bfs(){int left1;int right1;int ret0;while(left<right){ret;int rright;for(int ileft;i<right;i){rMath.max(r,arr[i]i);if(r>n){return ret;}}leftright1;rightr;}…

网络原理 TCP与UDP协议

博主主页: 码农派大星. 数据结构专栏:Java数据结构 数据库专栏:MySQL数据库 JavaEE专栏:JavaEE 关注博主带你了解更多数据结构知识 1.应用层 之前编写完了基本的 java socket &#xff0c;要知道&#xff0c;我们之前所写的所有代码都在应⽤层&#xff0c;都是为了 完成某项…

关键点检测——HRNet源码解析篇

&#x1f34a;作者简介&#xff1a;秃头小苏&#xff0c;致力于用最通俗的语言描述问题 &#x1f34a;专栏推荐&#xff1a;深度学习网络原理与实战 &#x1f34a;近期目标&#xff1a;写好专栏的每一篇文章 &#x1f34a;支持小苏&#xff1a;点赞&#x1f44d;&#x1f3fc;、…

linux下部署数据库总结

数据库 数据库主要分为两大类&#xff1a;关系型数据库与 NoSQL 数据库 关系型数据库&#xff0c;是建立在关系模型基础上的数据库&#xff0c;其借助于集合代数等数学概念和方法来处理数据库 中的数据主流的 MySQL、Oracle、MS SQL Server 和 DB2 都属于这类传统数据库。 NoSQ…

JVM理论篇(一)

一、类加载子系统 1.1 类加载子系统作用 类加载子系统负责从文件系统或者网络中加载Class文件&#xff0c;Class文件在文件开头有特定的文件标识。(CAFEBABE)ClassLoader只负责class文件的加载&#xff0c;至于它是否可以运行&#xff0c;则由Execution Engine 执行引擎决定。…

Spire.PDF for .NET【文档操作】演示:创建标记的 PDF 文档

带标签的 PDF&#xff08;也称为 PDF/UA&#xff09;是一种包含底层标签树&#xff08;类似于 HTML&#xff09;的 PDF&#xff0c;用于定义文档的结构。这些标签可以帮助屏幕阅读器浏览整个文档而不会丢失任何信息。本文介绍如何使用Spire.PDF for .NET在 C# 和 VB.NET 中从头…

Python中csv文件的操作3

在《Python中csv文件的操作2》中提到&#xff0c;with as语句可以自动关闭文件&#xff0c;而该语句可以和csv模块中的函数配合使用&#xff0c;达到读取和写入csv文件的目的。 1 csv文件的读取 使用csv模块中的函数读取csv文件的代码如图1所示。 图1 使用csv模块中的函数读取…

AI终于杀死了Leetcode!网友:面试神器已到位

家人们&#xff0c;今早起来 x 上一个帖子引起了奶茶的注意&#xff1a; 什么&#xff1f;奶茶以为自己没睡醒&#xff0c;揉了揉眼睛一看&#xff0c;没看错的话&#xff0c;这不就是AI结束了比赛吗。。。。 原文链接&#xff1a; https://www.reddit.com/r/leetcode/comments…

【ES6新特性】ES6新特性中Promise对象的概念,Async函数的使用以及Module语法

目录 1.Promise 对象 1.1 概念 1.2 使用 2.Async函数 2.1 同步和异步的区别 3.Mdule语法 1.Promise 对象 1.1 概念 Promise 是异步编程的一种解决方案&#xff0c;简单说就是一个容器&#xff0c;里面保存着某个未来才会结束 的事件&#xff08;通常是一个异步操作&#…

初识QT:从创建到认识

QT怎么安装这里就不说了&#xff0c;直接从使用开始 文章目录 1.QT项目的创建及介绍2.Hello QT&#xff01;2.1 图形化形式创建2.2 代码形式创建 3.对象树3.1 内存泄漏与对象树3.2 通过C类理解释放过程 4.乱码问题4.1 如何查看编码方式4.2 如何处理乱码 提示&#xff1a;QT项目…

arm 指令移位操作(11)

逻辑左移&#xff1a; 可以使寄存器也可以是 立即数 LSL &#xff1a; 字母缩写 举例&#xff1a; MOV R0&#xff0c;R1 &#xff0c;LSL #2 向左移位后&#xff0c;右面填0补充 逻辑右移&#xff1a; 可以使寄存器也可以是 立即数 LSR &#xff1a; 字母缩写 举例&…

10天速通Tkinter库——Day7:主菜单及图鉴

本篇博客我将介绍Tkinter实践项目《植物杂交实验室》中的杂交实验室主菜单、基础植物图鉴、杂交植物图鉴、杂交植物更多信息四个页面的制作。 它们作为主窗口的子页面实例&#xff0c;除了继承主窗口的基础设置&#xff08;如图标、标题、尺寸等等&#xff09;、还可以使用主窗…

《黑神话:悟空》游戏中的福建元素

《黑神话&#xff1a;悟空》作为一款深受玩家喜爱的动作角色扮演游戏&#xff0c;不仅在游戏剧情和角色设计上独具匠心&#xff0c;还巧妙地融入了丰富的中国传统文化元素&#xff0c;其中福建元素尤为突出。以下是对游戏中福建元素的详细解析&#xff1a; 一、地域文化与背景…