AI大模型探索之路-训练篇8:大语言模型Transformer库-预训练流程编码体验

news2025/1/15 6:40:07

系列篇章💥

AI大模型探索之路-训练篇1:大语言模型微调基础认知
AI大模型探索之路-训练篇2:大语言模型预训练基础认知
AI大模型探索之路-训练篇3:大语言模型全景解读
AI大模型探索之路-训练篇4:大语言模型训练数据集概览
AI大模型探索之路-训练篇5:大语言模型预训练数据准备-词元化
AI大模型探索之路-训练篇6:大语言模型预训练数据准备-预处理
AI大模型探索之路-训练篇7:大语言模型Transformer库之HuggingFace介绍


目录

  • 系列篇章💥
  • 前言
  • 案例场景
  • 准备工作
    • 1)学术加速
    • 2)安装LFS
    • 3)下载数据集(原始语料库)
    • 4)下载模型到本地
  • 步骤1:导入相关依赖
  • 步骤2:获取数据集
  • 步骤3:构建数据集
  • 步骤4:划分数据集
  • 步骤5:创建DataLoader
  • 步骤6:创建模型及其优化器
  • 步骤7:训练与验证
  • 步骤8:模型预测
  • 总结


前言

在深入探索Transformer库及其高级组件之前,我们先手工编写一个预训练流程代码。这一过程不仅有助于理解预训练的步骤和复杂性,而且能让您体会到后续引入高级组件所带来的开发便利性。通过实践,我们将构建一个情感分类模型,该模型能够接收文本评价并预测其是正面还是负面的情感倾向。

案例场景

想象一下,我们有一个原始数据集,其中包含了酒店顾客的评价文本。我们的目标是训练一个模型,当输入类似“昨天我在酒店睡觉发现被子有一股霉味。”的评价时,模型能够预测出“差评”。
在这里插入图片描述

准备工作

本次仍是采用云服务器autodl调试运行

1)学术加速

source /etc/network_turbo

在这里插入图片描述

2)安装LFS

从 Hugging Face Hub 下载模型需要先安装Git LFS
安装git-lfs是为了确保从Hugging Face拉取模型时能够高效且完整地下载所有相关文件,尤其是那些大型的模型文件。
Ubuntu系统操作命令:
curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
在这里插入图片描述

sudo apt-get install git-lfs
在这里插入图片描述

Centos命令参考:

curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | sudo bash
sudo yum install git-lfs

执行:git lfs install
在这里插入图片描述

3)下载数据集(原始语料库)

创建一个pretrains目录,将数据集下载到这个目录,下载到本地后可以提高执行效率
git clone https://huggingface.co/datasets/dirtycomputer/ChnSentiCorp_htl_all
在这里插入图片描述

注意!重要!!:下载后请记得和Huggingface上的文件对比,尤其是大文件,确保下载完整

4)下载模型到本地

git clone https://huggingface.co/hfl/rbt3
下载到本地后,从本地加载执行效率更高
在这里插入图片描述
注意!重要!!:下载后请记得和Huggingface上的文件对比,尤其是大文件,确保下载完整

步骤1:导入相关依赖

首先,我们需要设置Python环境,并导入必要的库

from transformers import AutoTokenizer, AutoModelForSequenceClassification

步骤2:获取数据集

获取数据集是预训练中关键一步。我们使用前面从Huggingface下载的包含酒店评价的文本数据集。
1)加载本地的数据集,查看读取内容

import pandas as pd
data = pd.read_csv("/root/pretrains/ChnSentiCorp_htl_all/ChnSentiCorp_htl_all.csv")
data.dropna()
data

执行输出如下:
在这里插入图片描述

步骤3:构建数据集

创建一个自定义的数据集类,它将负责读取原始数据,可以执行必要的预处理步骤(例如清洗、分词、向量化),并将数据划分为训练集和验证集。

from torch.utils.data import Dataset

import pandas as pd

class MyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        self.data = pd.read_csv("/root/pretrains/ChnSentiCorp_htl_all/ChnSentiCorp_htl_all.csv")
        self.data = self.data.dropna()
    def __getitem__(self,index):
        return self.data.iloc[index]["review"], self.data.iloc[index]["label"]
    def __len__(self):
        return len(self.data) 

dataset = MyDataset()
for i in range(5):
    print(dataset[i])
('距离川沙公路较近,但是公交指示不对,如果是"蔡陆线"的话,会非常麻烦.建议用别的路线.房间较为简单.', 1)
('商务大床房,房间很大,床有2M宽,整体感觉经济实惠不错!', 1)
('早餐太差,无论去多少人,那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。', 1)
('宾馆在小街道上,不大好找,但还好北京热心同胞很多~宾馆设施跟介绍的差不多,房间很小,确实挺小,但加上低价位因素,还是无超所值的;环境不错,就在小胡同内,安静整洁,暖气好足-_-||。。。呵还有一大优势就是从宾馆出发,步行不到十分钟就可以到梅兰芳故居等等,京味小胡同,北海距离好近呢。总之,不错。推荐给节约消费的自助游朋友~比较划算,附近特色小吃很多~', 1)
('CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风', 1)

步骤4:划分数据集

对数据集进行划分,语料库中90%作为预训练数据,10%作为验证数据;这确保了模型在未见过的数据上进行验证和测试。

from torch.utils.data import  random_split

trainset, validset = random_split(dataset,lengths=[0.9,0.1])
len(trainset),len(validset)

输出:(6989, 776)

步骤5:创建DataLoader

1)加载数据集
利用分词器进行数据加载(即将文本数据转化为机器能识别的数字序列矩阵)
为了高效地加载数据,采用批量的方式加载预训练数据和校验数据,加载时最大长度为128,多了会进行截取,少了会自动补0

import torch
from torch.utils.data import DataLoader

tokenizer = AutoTokenizer.from_pretrained("/root/pretrains/rbt3")

def collate_func(batch):
    texts,labels=[],[]
    for item in batch:
        texts.append(item[0])
        labels.append(item[1])
        ## return_tensors="pt" 返回的是pytorch tensor类型。
        ## 吃葡萄不吐葡萄皮
        ## 不吃葡萄到吐葡萄皮
    inputs = tokenizer(texts,max_length=128,padding="max_length",truncation=True, return_tensors="pt")
    inputs["labels"] = torch.tensor(labels)
    return inputs
## dataloader中设置shuffle值为True,表示每次加载的数据都是随机的,将输入数据的顺序打乱。shuffle值为False,
## 表示输入数据顺序固定。

trainloader = DataLoader(trainset,batch_size=32,shuffle=True,collate_fn=collate_func)
validloader = DataLoader(validset,batch_size=64,shuffle=False,collate_fn=collate_func)

next(enumerate(validloader))[1]

输出如下:(下面tensor就是转化后的序列矩阵)
在这里插入图片描述

步骤6:创建模型及其优化器

根据本地下载的模型地址,创建模型对象
基于Transformer架构,定义一个情感分类模型。选择合适的优化器(如AdamW或RMSprop)以调整模型权重,从而最小化损失函数。

from torch.optim import Adam

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("/root/pretrains/rbt3")

if torch.cuda.is_available():
    model = model.cuda()
"""
当我们训练一个机器学习模型时,我们需要选择一个优化算法来帮助我们找到模型参数的最佳值。这个优化算法就是优化器(optimizer)。

在这行代码中,我们选择了一种叫做Adam的优化算法作为我们的优化器。Adam算法是一种常用的优化算法,
它根据每个参数的梯度(即参数的变化率)和学习率(lr)来更新参数的值。

"model.parameters()"表示我们要优化的是模型的参数。模型的参数是模型中需要学习的权重和偏置等变量。

"lr=2e-5"表示学习率的值被设置为2e-5(即0.00002)。学习率是控制模型在每次迭代中更新参数的步长。较大的学习率可能导致模型无法收敛,
而较小的学习率可能需要更长的训练时间
"""
optimizer = Adam(model.parameters(), lr=2e-5)

步骤7:训练与验证

定义一个训练和评估的函数
设定训练循环,包括前向传播、计算损失、反向传播和权重更新。同时,定期在验证集上检查模型性能,以监控过拟合情况并及时停止训练。
def evaluate():
    ## 将模型设置为评估模式
    model.eval()
    acc_num=0
    #将训练模型转化为推理模型,模型将使用转换后的推理模式进行评估
    with torch.inference_mode():
        for batch in validloader:
            ## 检查是否有可用的GPU,如果有,则将数据批次转移到GPU上进行加速
            if torch.cuda.is_available():
                batch = {k: v.cuda() for k,v in batch.items()}
            ##对数据批次进行前向传播,得到模型的输出
            output = model(**batch)
            ## 对模型输出进行预测,通过torch.argmax选择概率最高的类别。
            pred = torch.argmax(output.logits,dim=-1)
            ## 计算正确预测的数量,将预测值与标签进行比较,并使用.float()将比较结果转换为浮点数,使用.sum()进行求和操作
            acc_num += (pred.long() == batch["labels"].long()).float().sum()
    ## 返回正确预测数量与验证集样本数量的比值,这表示模型在验证集上的准确率
    return acc_num / len(validset)

def train(epoch=3,log_sep=100):
    global_step = 0
    for ep in range(epoch):
        ## 开启训练模式
        model.train()
        for batch in trainloader:
            if torch.cuda.is_available():
                batch = {k: v.cuda() for k, v in batch.items()}
            ## 梯度归0
            optimizer.zero_grad()
            ## 对数据批次进行前向传播,得到模型的输出
            output=model(**batch)
            ## 计算损失函数梯度并进行反向传播
            output.loss.backward()
            ## 优化器更新
            optimizer.step()
            if(global_step % log_sep == 0):
                print(f"ep:{ep},global_step:{global_step},loss:{output.loss.item()}")
            global_step += 1
        ## 准确率
        acc = evaluate()
        ## 第几轮
        print(f"ep:{ep},acc:{acc}")

# 训练
train()

输出3轮训练结果,准确率在88%-89%左右
在这里插入图片描述

步骤8:模型预测

完成训练后,利用训练好的模型对新输入的评价进行情感分类。展示模型如何接收新文本,并输出预测结果。

#sen = "我昨晚在酒店里睡得非常好"
sen ="昨天我在酒店睡觉发现被子有一股霉味"

id2label = {0:"差评",1:"好评"}
## 将模型设置为评估模式
model.eval

 #将训练模型转化为推理模型,模型将使用转换后的推理模式进行评估
with torch.inference_mode():
    ## 分词&&向量化
    inputs = tokenizer(sen,return_tensors = "pt")
    ## GPU加速
    inputs = {k:v.cuda() for k,v in inputs.items()}
    ## 进行预测
    logits=model(**inputs).logits
    ## 在logits的最后一个维度上找到最大值,并返回其所在的索引。这相当于选择模型认为最有可能的类别
    pred = torch.argmax(logits, dim = -1)
    
    print(f"输入:{sen} \n模型的预测结果:{id2label.get(pred.item())}")

1)第1次预测:(sen =“昨天我在酒店睡觉发现被子有一股霉味”)
输入:昨天我在酒店睡觉发现被子有一股霉味
模型的预测结果:差评
2)第2次预测:(sen =“我昨晚在酒店里睡得非常好”)
输入:我昨晚在酒店里睡得非常好
模型的预测结果:好评

总结

通过上述步骤,我们手工完成了基于Transformer库的情感分类模型预训练流程。虽然这个过程涉及了大量细节和代码编写,但它为我们提供了宝贵的洞见,让我们了解了从原始数据处理到模型训练和验证的整个流程。在后续篇章中,我们将引入更多的Transformer组件,这些高级工具将显著简化我们的开发流程,使我们能够更快捷、更高效地进行模型开发和实验

在这里插入图片描述

🎯🔖更多专栏系列文章:AIGC-AI大模型探索之路

如果文章内容对您有所触动,别忘了点赞、⭐关注,收藏!加入我,让我们携手同行AI的探索之旅,一起开启智能时代的大门!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1635263.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Codigger数据篇(下):数据安全的全方位保障

在数字化浪潮中,数据已成为现代企业的核心财富。Codigger作为领先的数据服务平台,深知数据安全对于用户的重要性,因此在深挖数据价值的同时,我们始终坚守数据安全防线。 一、双重加密技术保障 Codigger平台运用先进的加密通信和…

【C/C++】动态内存管理(C:malloc,realloc,calloc,free || C++:new,delete)

🔥个人主页: Forcible Bug Maker 🔥专栏: C | | C语言 目录 前言C/C内存分布C语言中的动态内存管理:malloc/realloc/realloc/freemallocrealloccallocfree C中的动态内存管理:new/deletenew和delete操作内…

android studio 编译一直显示Download maven-metadata.xml

今天打开之前的项目的时候遇到这个问题:android studio 编译一直显示Download maven-metadata.xml, AI 查询 报错问题:"android studio 编译一直显示Download maven-metadata.xml" 解释: 这个错误通常表示Android Studio在尝试从Maven仓库…

用Python Turtle画一个中国结

中国结,作为中华民族传统文化的象征之一,以其独特的编织技艺和深厚的文化内涵,深受人们喜爱。今天,我们就来用Python的turtle模块,尝试绘制一个充满韵味的中国结。 我们先来看看整个中国结生成的过程: 中国…

机器学习 | 准确率、召回率、精准率、特异度傻傻分不清?ROC曲线怎么看?一篇文章帮你搞定

一、真正类、假负类、假正类与真负类 二、准确率、召回率、精准率、特异度与假正率 1. 准确率 (Accuracy) 准确率表明成功预测(预测为负或为正)的结果占总样本的百分比。 准确率 , 2. 召回率/查全率/灵敏度/真正率(Recall&a…

C语言进阶|双链表

✈链表的分类 链表的结构非常多样,以下情况组合起来就有8种(2x22)链表结构: 虽然有这么多的链表的结构,但是我们实际中最常用还是两种结构:单链表和双向带头循环链表 1.无头单向非循环链表:结构简单&…

springboot基于SpringBoot的网上订餐系统开题报告+1w字文档+ppt

项目演示视频: 【源码免费送】基于SpringBoot的网上订餐系统录像 摘 要 随着我国经济的飞速发展,人们的生活速度明显加快,在餐厅吃饭排队的情况到处可见,近年来由于新兴IT行业的空前发展,它与传统餐饮行业也进行了新旧的结合&…

实验八智能手机互联网程序设计(微信程序方向)实验报告

请在上一次实验的基础之上完成“手机快速注册”页面、“企业用户注册”页面,并实现点击手机快速注册和企业用户注册后转跳至该页面在“手机快速注册”页面,输入框内输入内容并失去焦点后,下方的按钮会变化 在企业用户注册页面,用户…

Anomalib:用于异常检测的深度学习库!

大家好,今天给大家介绍了一个用于无监督异常检测和定位的新型库:anomalib,Github链接:https://github.com/openvinotoolkit/anomalib 简介 考虑到可重复性和模块化,这个开源库提供了文献中的算法和一组工具,以通过即插即用的方法设计自定义异常检测算法。 Anomalib 包…

# 从浅入深 学习 SpringCloud 微服务架构(七)Hystrix(1)

从浅入深 学习 SpringCloud 微服务架构(七)Hystrix(1) 一、Hystrix:基于 RestTemplate 的熔断配置 1、Hystrix 介绍: 1)Hystrix 是由 Netflix 开源的一个延迟和容错库, 用于隔离访…

Web3的可持续性:构建环境友好的去中心化系统

引言 随着全球对可持续发展和环境问题的日益关注,Web3技术作为一种新型的互联网模式,也开始受到社区和开发者的关注。但很少有人关注到Web3对环境可持续性的潜在影响。本文将探讨Web3如何构建一个环境友好的去中心化系统,以及这如何促进一个…

Shopee怎么选品成功率高达80%?请学

电商圈内流传着一句话:三分靠运营,七分靠选品。 选品在电商项目中至关重要,也是一个非常考验技巧和经验的环节。选品选择得好,后续的每一步都会变得相对轻松。 那么要怎么在众多商品中脱颖而出,提高在Shopee平台上选…

第三节课,功能2:开发后端用户的管理接口5min(用户的查询/状态更改)【4】【9开始--本人】

一、代码任务 【录个屏】 二、写代码 2.1 代码文件位置 2.2 代码如下: 2.3 官方文档: 网址: 逻辑删除 | MyBatis-Plus (baomidou.com) 三、代码有bug,没有鉴权,表里添加一个字段。role 管理员 3.1 判断操作的人&am…

了解 Postman:这个 API 工具的功能和用途是什么?

在软件开发中,经常听到 Postman 这个软件名。但其实很多新手开发者只知道这是软件开发常用的软件,并不知道实际是一个什么样工具,不知道具体的作用是什么。那今天就跟大家好好唠唠 Postman 这个软件。想要学习更多关于 Postman 的知识&#x…

call、apply、bind能用来干点啥(接上文)

上文我们了解了call、apply、bind的使用规则,学以致用,我们要在平时的搬砖中怎么使用呢? 其实好些人平时也用不到这三货,但是在框架底层,这三货可是经常被用到的啊,现在我们来了解了解吧 1、处理伪数组 假使,在html页面中有多个名为“c-container”的容器,现在我们来获取他…

[机缘参悟-166] :周期论:万物的周期现象是这个世界有序性和稳定性保障;超越周期:在轮回中,把握周期节奏。

目录 前言:超越周期 一、周期是大自然和宇宙的规律,是天道 1.1 概述 1.2 万物的周期规律的现象 1.3 电磁波的周期 二、计算机世界中的周期性 三、佛家的生命轮回规律 四、人类社会发展的周期规律 五、经济活动的周期规律 5.1 概述 5.2 股市的…

分享一个网站实现永久免费HTTPS访问的方法

免费SSL证书作为一种基础的网络安全工具,以其零成本的优势吸引了不少网站管理员的青睐。要实现免费HTTPS访问,您可以按照以下步骤操作: 一、 选择免费SSL证书提供商 选择一个提供免费SSL证书的服务商。如JoySSL,他们是国内为数不…

ArgoCD集成部署到Kubernetes

1:环境 kubernetes1.23.3ArgoCD2.3.3 2:ArgoCD介绍 Argo CD is a declarative, GitOps continuous delivery tool for Kubernetes. Argo CD是一个基于Kubernetes的声明式的GitOps工具。 那么,什么是GitOps呢? GitOps是以Git为基…

ROS 2边学边练(36)-- 添加一个坐标系(C++)

前言 此篇将会在之前已存在的几个坐标系(/world、/turtle1、/turtle2)的基础上再增加一个坐标系,相对来说,难度不大,主要是理解一些概念(脑子里面有3D场景的想象),比如一个小车机器人处在世界坐标系&#x…

春秋云镜 CVE-2023-50563

靶标介绍: SEMCMS是一套支持多种语言的外贸网站内容管理系统(CMS)。SEMCMS v4.8版本存在SQLI,该漏洞源于SEMCMS_Function.php 中的 AID 参数包含 SQL 注入 开启靶场: 开始实验: 1、使用后台扫描工具&…