【AI大模型】Transformers大模型库(十一):Trainer训练类

news2024/9/20 0:52:00

 

目录

一、引言 

二、Trainer训练类

2.1 概述

2.2 使用示例

三、总结


一、引言 

 这里的Transformers指的是huggingface开发的大模型库,为huggingface上数以万计的预训练大模型提供预测、训练等服务。

🤗 Transformers 提供了数以千计的预训练模型,支持 100 多种语言的文本分类、信息抽取、问答、摘要、翻译、文本生成。它的宗旨是让最先进的 NLP 技术人人易用。
🤗 Transformers 提供了便于快速下载和使用的API,让你可以把预训练模型用在给定文本、在你的数据集上微调然后通过 model hub 与社区共享。同时,每个定义的 Python 模块均完全独立,方便修改和快速研究实验。
🤗 Transformers 支持三个最热门的深度学习库: Jax, PyTorch 以及 TensorFlow — 并与之无缝整合。你可以直接使用一个框架训练你的模型然后用另一个加载和推理。

本文重点介绍Trainer训练类

二、Trainer训练类

2.1 概述

2.2 使用示例

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset

# 1. 加载数据集
# 假设我们使用的是Hugging Face的内置数据集,例如SST-2
dataset = load_dataset('sst2')  # 或者使用你自己的数据集

# 2. 数据预处理,可能需要根据模型进行Tokenization
# 以BERT为例,使用AutoTokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def tokenize_function(examples):
    return tokenizer(examples["sentence"], truncation=True)
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# 3. 准备训练参数
training_args = TrainingArguments(
    output_dir='./results',          # 输出目录
    num_train_epochs=3,              # 总的训练轮数
    per_device_train_batch_size=16,  # 每个GPU的训练批次大小
    per_device_eval_batch_size=64,   # 每个GPU的评估批次大小
    warmup_steps=500,                # 预热步数
    weight_decay=0.01,               # 权重衰减
    logging_dir='./logs',            # 日志目录
    logging_steps=10,
)

# 4. 准备模型
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

# 5. 创建Trainer并开始训练
trainer = Trainer(
    model=model,                         # 要训练的模型
    args=training_args,                  # 训练参数
    train_dataset=tokenized_datasets['train'],  # 训练数据集
    eval_dataset=tokenized_datasets['validation'], # 验证数据集
)

# 开始训练
trainer.train()

整个流程是机器学习项目中的标准流程:数据准备、模型选择、参数设置、训练与评估。每个步骤都是为了确保模型能够高效、正确地训练,以解决特定的机器学习任务:

  • 加载数据集 (load_dataset('sst2')):这行代码是使用Hugging Face的datasets库加载SST-2数据集,这是一个情感分析任务的数据集。如果你使用自定义数据集,需要相应地处理和加载数据。
  • 数据预处理 (tokenizer(examples["sentence"], truncation=True)):在训练模型之前,需要将文本数据转换为模型可以理解的格式。这里使用AutoTokenizer对文本进行分词(Tokenization),truncation=True意味着如果句子超过模型的最大输入长度,将截断超出部分。这一步是将文本转换为模型输入的张量格式。
  • 训练参数 (TrainingArguments):这部分定义了训练过程的配置,包括训练轮数(num_train_epochs)、每个设备的训练和评估批次大小、预热步数(warmup_steps)、权重衰减(weight_decay)等。这些参数对训练效率和模型性能有重要影响。
  • 准备模型 (AutoModelForSequenceClassification.from_pretrained()):这里选择或初始化模型,AutoModelForSequenceClassification是用于序列分类任务的模型,from_pretrained方法加载预训练的模型权重。选择的模型(如BERT的“bert-base-uncased”)是基于任务需求的。
  • 创建Trainer (Trainer):Trainer是Transformers库中的核心类,它负责模型的训练和评估流程。它接收模型、训练参数、训练数据集和评估数据集作为输入。Trainer自动处理了训练循环、损失计算、优化器更新、评估、日志记录等复杂操作,使得训练过程更加简洁和高效。
  • 开始训练 (trainer.train()):调用此方法开始模型的训练过程。Trainer会根据之前设定的参数和数据进行模型训练,并在每个指定的步骤打印日志,训练完成后,模型的权重会保存到指定的输出目录。

三、总结

本文对transformers训练类Trainer进行讲述并赋予应用代码,希望可以帮到大家!

如果您还有时间,可以看看我的其他文章:

《AI—工程篇》

AI智能体研发之路-工程篇(一):Docker助力AI智能体开发提效

AI智能体研发之路-工程篇(二):Dify智能体开发平台一键部署

AI智能体研发之路-工程篇(三):大模型推理服务框架Ollama一键部署

AI智能体研发之路-工程篇(四):大模型推理服务框架Xinference一键部署

AI智能体研发之路-工程篇(五):大模型推理服务框架LocalAI一键部署

《AI—模型篇》

AI智能体研发之路-模型篇(一):大模型训练框架LLaMA-Factory在国内网络环境下的安装、部署及使用

AI智能体研发之路-模型篇(二):DeepSeek-V2-Chat 训练与推理实战

AI智能体研发之路-模型篇(三):中文大模型开、闭源之争

AI智能体研发之路-模型篇(四):一文入门pytorch开发

AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比

AI智能体研发之路-模型篇(六):【机器学习】基于tensorflow实现你的第一个DNN网络

AI智能体研发之路-模型篇(七):【机器学习】基于YOLOv10实现你的第一个视觉AI大模型

AI智能体研发之路-模型篇(八):【机器学习】Qwen1.5-14B-Chat大模型训练与推理实战

AI智能体研发之路-模型篇(九):【机器学习】GLM4-9B-Chat大模型/GLM-4V-9B多模态大模型概述、原理及推理实战

《AI—Transformers应用》

【AI大模型】Transformers大模型库(一):Tokenizer

【AI大模型】Transformers大模型库(二):AutoModelForCausalLM

【AI大模型】Transformers大模型库(三):特殊标记(special tokens)

【AI大模型】Transformers大模型库(四):AutoTokenizer

【AI大模型】Transformers大模型库(五):AutoModel、Model Head及查看模型结构

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1860136.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于FreeRTOS+STM32CubeMX+LCD1602+MCP4152(SPI接口)的数字电位器Proteus仿真

一、仿真原理图: 二、仿真效果: 三、软件部分: 1)、时钟配置初始化: void SystemClock_Config(void) { RCC_OscInitTypeDef RCC_OscInitStruct = {0}; RCC_ClkInitTypeDef RCC_ClkInitStruct = {0}; /** Initializes the CPU, AHB and APB busses clocks */ RCC…

同城购物优惠联盟返现系统小程序源码

:省钱购物新体验 🎉一、同城优惠,一网打尽 在繁华的都市生活中,你是否总是为寻找各种优惠而费尽心思?现在,有了“同城优惠联盟返现小程序”,你可以轻松掌握同城各类优惠信息。无论是餐饮、购物…

解题思路:LeetCode 第 209 题 “Minimum Size Subarray Sum“

解题思路:LeetCode 第 209 题 “Minimum Size Subarray Sum” 在这篇博文中,我们将探讨如何使用 Swift 解决 LeetCode 第 209 题 “Minimum Size Subarray Sum”。我们会讨论两种方法:暴力法和滑动窗口法,并对这两种方法的时间复…

Arduino - 串行绘图仪

Arduino - Serial Plotter Arduino - 串行绘图仪 In this tutorial, we will learn how to use the Serial Plotter on Arduino IDE, how to plot the multiple graphs. 在本教程中,我们将学习如何在Arduino IDE上使用串行绘图仪,如何绘制多个图形。 A…

【软件工程】【22.04】p2

关键字: 软件开发分本质及涉及问题、需求规约与项目需求不同、用况图概念包含模型元素及其关系、创建系统的用况模型RUP进行活动、软件生存周期&软件生存周期模型&软件项目过程管理关系、CMMI基本思想 模块结构图:作用域、控制域;语…

vue2 antd 开关和首页门户样式,表格合计

1.首页门户样式 如图 1.关于圆圈颜色随机设置 <a-col :span"6" v-for"(item, index) in menuList" :key"index"><divclass"circle":style"{ borderColor: randomBorderColor() }"click"toMeRouter(item)&qu…

版本控制工具-git分支管理

目录 前言一、git分支管理基本命令1.1 基本命令2.1 实例 二、git分支合并冲突解决三、git merge命令与git rebase命令对比 前言 本篇文章介绍git分支管理的基本命令&#xff0c;并说明如何解决git分支合并冲突&#xff0c;最后说明git merge命令与git rebase命令的区别。 一、…

Python重拾

1.Python标识符规则 字母&#xff0c;下划线&#xff0c;数字&#xff1b;数字不开头&#xff1b;大小写区分&#xff1b;不能用保留字&#xff08;关键字&#xff09; 2.保留字有哪些 import keyword print(keyword.kwlist)[False, None, True, and,as, assert, async, await…

【AI兼职副业必看,行业分析+注意事项+具体应用,想要做点副业的小白必看!】

前言 随着AI技术的日新月异&#xff0c;它已悄然渗透到我们生活的每一个角落&#xff0c;成为了我们日常生活和工作中的得力助手。在当前经济下行的环境下&#xff0c;AI技术更是成为了提升工作效率、拓展业务领域的关键。对于我们普通人而言&#xff0c;有效利用AI工具&#…

应变计在工程中的角色:精准监测与安全保障的得力助手

在工程领域中&#xff0c;应变计作为一种重要的测量工具&#xff0c;扮演着精准监测与安全保障的得力助手的角色。它能够实时、准确地测量物体在受力作用下的变形情况&#xff0c;为工程师提供关键的数据支持&#xff0c;从而确保工程的稳定性与安全性。 应变计在工程中的应用范…

深度学习训练基于Pod和RDMA

目录 ​编辑 引言 RDMA技术概述 InfiniBand iWARP RoCE Pod和容器化环境 深度学习训练与RDMA结合 MPI和RDMA 深度学习框架与RDMA 实战&#xff1a;基于Pod和RDMA的深度学习训练 环境准备 步骤 YAML 性能和优势 结论 引言 随着深度学习在人工智能领域的快速发展…

2024数字孪生发展研究报告

来源&#xff1a;华为&ampamp中国信通院 近期历史回顾&#xff1a; 2023内蒙古畜牧业温室气体减排策略与路径研究报告-能源基金会.pdf 2023园区工商业配储项目储能系统技术方案.pdf 欧洲和美国储能市场盘点&#xff08;英文&#xff09;.pdf 2024年第1季度全球ESG监管政策…

Python爬取中国福彩网彩票数据并以图表形式显示

网页分析 首先打开中国福彩网&#xff0c;点击双色球&#xff0c;选择往期开奖栏目 进入栏目后&#xff0c;选定往期的奖金数目作为我们想要爬取的目标内容 明确目标后&#xff0c;开始寻找数据所在的位置 鼠标右击页面&#xff0c;打开网页源代码&#xff0c;在源代码中搜索…

B端系统:消息页面的设计要点

在B端系统中&#xff0c;消息页面的作用是为用户提供实时的通信和信息交流功能&#xff0c;以便用户能够及时获取和处理重要的业务消息和通知。设计一个好的消息页面可以提高用户的工作效率和沟通效果。 以下是一些建议来设计消息页面&#xff1a; 易于查看和管理&#xff1a;…

免费直播课程!6月30日

<面向人工智能领域的开发工程师&#xff0c;特别是机器学习/深度学习方向> 在这里报名听课&#xff1a; F学社-全球FPGA技术提升平台 (zzfpga.com) TIPS&#xff1a; 报名后将在页面内弹出「腾讯会议号和会议密码」&#xff0c;注意复制保存哦~

备考必备:NOC大赛2022图形化编程决赛真题与解析

为了大家备考2023-2024学年全国中小学信息技术创新与实践大赛&#xff08;NOC大赛&#xff09;&#xff0c;角逐恩欧希教育信息化发明创新奖&#xff0c;今天给大家分享2022年NOC大赛图形化编程决赛真题试卷。 下载&#xff1a;更多NOC大赛真题及其他资料在网盘-真题文件夹或者…

Java高级重点知识点-10-Object类

文章目录 Object类(java.lang) Object类(java.lang) Object类是Java语言中的根类&#xff0c;即所有类的父类 重点&#xff1a; public String toString()&#xff1a;返回该对象的字符串表示。 public class User {private String username;private String password;public…

JavaWeb系列十三: 服务器端渲染技术(JSP) 下

韩顺平 2. EL表达式2.1 EL表达式快速入门2.2 EL表达式输出形式2.3 el运算符2.4 empty运算2.5 EL获取四个域数据2.6 EL获取HTTP相关信息 3. JSTL标签库3.1 jstl 快速入门3.2 <c:set/>标签3.3 <c:if/>标签3.4 <c:choose/>标签3.5 <c:forEach/>标签3.6 作…

蓝牙技术|苹果iOS 18的第三方配件将支持AirPods / AirTag的配对体验

苹果公司在 iOS 18 系统中引入了名为 AccessorySetupKit 的新 API&#xff0c;用户不需要进入蓝牙设置和按下按钮&#xff0c;系统就能识别附近的配件&#xff0c;并提示用户进行配对。首次向配件制造商开放这种配对体验。 iPhone 用户升级 iOS 18、iPad 用户升级到 iPadOS 1…

SAP BC 修改 FINS_ACDOC_CUST116 ERROR 为 WARNING 信息

FI再改如下配置时报错了 消息号 FINS_ACDOC_CUST116 参考 SAP 消息控制_sap消息号更改w为e-CSDN博客 需要指出的是你必须注意做重要的三个表 T100:包含所有的message T100C:你定义的message通常将出现在此表 T100s:Configurable system messages顾名思义就是你能设置的消息…