【Python】科研代码学习:七 TrainingArguments,Trainer

news2024/11/18 8:32:33

【Python】科研代码学习:七 TrainingArguments,Trainer

  • TrainingArguments
    • 重要的方法
  • Trainer
    • 重要的方法
    • 使用 Trainer 的简单例子

TrainingArguments

  • HF官网API:Training
    众所周知,推理是一个大头,训练是另一个大头
    之前的很多内容,都是为训练这里做了一个小铺垫
    如何快速有效地调用代码,训练大模型,才是重中之重(不然学那么多HF库感觉怪吃苦的)
  • 首先看训练参数,再看训练器吧。
    首先,它的头文件是 transformers.TrainingArguments
    再看它源码的参数,我勒个去,太多了吧。
    ※ 我这里挑重要的讲解,全部请看API去。
  • output_dir (str)设置模型输出预测,或者中继点 (checkpoints) 的输出目录。模型训练到一半,肯定需要有中继点文件的嘛,就相当于游戏存档有很多一样,防止跑一半直接程序炸了,还要从头训练
  • overwrite_output_dir (bool, optional, defaults to False):把这个参数设置成 True,就会覆盖其中 output_dir 中的文档。一般在从中继点继续训练时需要这么用
  • do_train (bool, optional, defaults to False):指明我在做训练集的训练任务
  • do_eval (bool, optional):指明我在做验证集的评估任务
  • do_predict (bool, optional, defaults to False) :指明我在做测试集的预测任务
  • evaluation_strategy :评估策略:训练时不评估 / 每 eval_steps 步评估,或者每 epoch 评估
"no": No evaluation is done during training.
"steps": Evaluation is done (and logged) every eval_steps.
"epoch": Evaluation is done at the end of each epoch.
  • per_device_train_batch_size :训练时每张卡的batch大小,默认为8
    per_device_eval_batch_size :评估时每张卡的batch大小,默认为8
  • learning_rate (float, optional, defaults to 5e-5):学习率,里面使用的是 AdamW optimizer
    其他相应的 AdamW Optimizer 的参数还有:
    weight_decay adam_beta1adam_beta2adam_epsilon
  • num_train_epochs:训练的 epoch 个数,默认为3,可以设置小数。
  • lr_scheduler_type:具体作用要查看 transformers 里的 Scheduler 是干什么用的
  • warmup_ratiowarmup_steps :让一开始的学习率从0逐渐升到 learning_rate 用的
  • logging_dir :设置 logging 输出的文档
    除此之外还有一些和 logging相关的参数:
    logging_strategy ,logging_first_step ,logging_steps ,logging_nan_inf_filter 设置日志的策略
  • 与保存模型中继文件相关的参数:
    save_strategy :不保存中继文件 / 每 epoch 保存 / 每 save_steps 步保存
"no": No save is done during training.
"epoch": Save is done at the end of each epoch.
"steps": Save is done every save_steps.

save_steps :如果是整数,表示多少步保存一次;小数,则是按照总训练步,多少比例之后保存一次
save_total_limit :最多中继文件的保存上限,如果超过上限,会先把最旧的那个中继文件删了再保存新的
save_safetensors :使用 savetensor来存储和加载 tensors,默认为 True
push_to_hub :是否保存到 HF hub

  • use_cpu (bool, optional, defaults to False):是否用 cpu 训练
  • seed (int, optional, defaults to 42) :训练的种子,方便复现和可重复实验
  • data_seed :数据采样的种子
  • 数据精读相关的一些参数:
    FP32、TF32、FP16、BF16、FP8、FP4、NF4、INT8
    bf16 (bool, optional, defaults to False)fp16 (bool, optional, defaults to False)
    tf32 (bool, optional)
  • run_name :展示在 wandb and mlflow logging 中的描述
  • load_best_model_at_end (bool, optional, defaults to False):是否保存效果最好的中继点作为最终模型,与 save_total_limit 有些交互操作
    如果上述设置成 True 的话,考虑 metric_for_best_model ,即如何评估效果最好。默认为 loss 即损失最小
    如果你修改了 metric_for_best_model 的话,考虑 greater_is_better ,即指标越大越好还是越小越好
  • 一些加速相关的参数,貌似都比较麻烦
    fsdp
    fsdp_config
    deepspeed
    accelerator_config
  • optim :设置 optimizer,默认为 adamw_torch
    也可以设置成 adamw_hf, adamw_torch, adamw_torch_fused, adamw_apex_fused, adamw_anyprecision or adafactor.
  • resume_from_checkpoint :传入中继点文件的目录,从中继点继续训练

重要的方法

  • ※ 那我怎么访问或者修改上述参数呢?
    由于这个需要实例化,所以我们需要使用OO的方法修改
    下面讲一下其中重要的方法
  • set_dataloader:设置 dataloader
    在这里插入图片描述
from transformers import TrainingArguments

args = TrainingArguments("working_dir")
args = args.set_dataloader(train_batch_size=16, eval_batch_size=64)
args.per_device_train_batch_size
  • 设置 logging 相关的参数
    在这里插入图片描述
  • 设置 optimizer
    在这里插入图片描述
  • 设置保存策略
    在这里插入图片描述
  • 设置训练策略
    在这里插入图片描述
  • 设置评估策略
    在这里插入图片描述
  • 设置测试策略
    在这里插入图片描述

Trainer

  • 终于到大头了。Trainer 是主要用 pt 训练的,主要支持 GPUs (NVIDIA GPUs, AMD GPUs)/ TPUs
  • 看下源码,它要的东西不少,讲下重要参数:
  • model:要么是 transformers.PretrainedModel 类型的,要么是简单的 torch.nn.Module 类型的
  • argsTrainingArguments 类型的训练参数。如果不提供的话,默认使用 output_dir/tmp_trainer 里面的那个训练参数
  • data_collator DataCollator 类型参数,给训练集或验证集做数据分批和预处理用的,如果没有tokenizer默认使用 default_data_collator,否则默认使用 DataCollatorWithPadding (Will default to default_data_collator() if no tokenizer is provided, an instance of DataCollatorWithPadding otherwise.)
  • train_dataset (torch.utils.data.Dataset or torch.utils.data.IterableDataset, optional) :提供训练的数据集,当然也可以是 Datasets 类型的数据
  • eval_dataset :类似的验证集的数据集
  • tokenizer :提供 tokenizer 分词器
  • compute_metrics :验证集使用时候的计算指标,具体得参考 EvalPrediction 类型
  • optimizers :可以提供 Tuple(optimizer, scheduler)。默认使用 AdamW 以及 get_linear_schedule_with_warmup() controlled by args
    在这里插入图片描述

重要的方法

  • compute_loss:设置如何计算损失
    在这里插入图片描述
  • train:设置训练集训练任务,第一个参数可以设置是否从中继点开始训练
    在这里插入图片描述
  • evaluate:设置验证集评估任务,需要提供验证集
    在这里插入图片描述
  • predict:设置测试集任务
    在这里插入图片描述
  • save_model:保存模型参数到 output_dir在这里插入图片描述
  • training_step:设置每一个训练的 step,把一个batch的输入经过了何种操作,得到一个 torch.Tensor
    在这里插入图片描述

使用 Trainer 的简单例子

  • 主要就是加载一些参数,传进去即可
    模型、训练参数、训练集、验证集、计算指标
    调用训练方法 .train()
    最后保存模型即可 .save_model()
from transformers import (
    Trainer,
    )
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()
trainer.save_model(outputdir="./xxx")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1506096.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux 理解进程

目录 一、基本概念 二、描述进程-PCB 1、task_struct-PCB的一种 2、task_ struct内容分类 三、组织进程 四、查看进程 1、ps指令 2、top命令 3、/proc文件系统 4、在/proc文件中查看指定进程 5、进程的工作目录 五、通过系统调用获取进程标示符 1、getpid()/get…

空间复杂度(数据结构)

概念: 空间复杂度也是一个数学表达式,是对一个算法在运行过程中临时占用存储空间大小的量度 。 空间复杂度不是程序占用了多少bytes的空间,因为这个也没太大意义,所以空间复杂度算的是变量的个数。空间复杂度计算规则基本跟实践复…

nicegui学习使用

https://www.douyin.com/shipin/7283814177230178363 python轻量级高自由度web框架 - NiceGUI (6) - 知乎 python做界面,为什么我会强烈推荐nicegui 秒杀官方实现,python界面库,去掉90%事件代码的nicegui python web GUI框架-NiceGUI 教程…

EI级 | Matlab实现PCA-GCN主成分降维结合图卷积神经网络的数据多特征分类预测

EI级 | Matlab实现PCA-GCN主成分降维结合图卷积神经网络的数据多特征分类预测 目录 EI级 | Matlab实现PCA-GCN主成分降维结合图卷积神经网络的数据多特征分类预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Matlab实现PCA-GCN主成分降维结合图卷积神经网络的数据多…

用conda创建虚拟环境

下载好conda之后,在跑代码之前,可以用conda来创建虚拟环境,然后在虚拟环境中下载包pip之类的。 创建步骤如下: 1.conda create --name hhh 其中hhh为我的虚拟环境的名字,之后选择y即yes即可继续创建 可以看到&#…

LVS集群 ----------------(直接路由 )DR模式部署 (二)

一、LVS集群的三种工作模式 lvs-nat:修改请求报文的目标IP,多目标IP的DNAT lvs-dr:操纵封装新的MAC地址(直接路由) lvs-tun:隧道模式 lvs-dr 是 LVS集群的 默认工作模式 NAT通过网络地址转换实现的虚拟服务器&…

springcloud第3季 consul服务发现注册,配置中心2

一 consul的作用 1.1 为何使用注册中心 为何要用注册中心? 1.A服务调用B服务,使用ip和端口,如果B服务的ip或者端口发生变化,服务A需要进行改动; 2.如果在分布式集群中,部署多个服务B,多个服…

【开发工具】认识Git | 认识工作区、暂存区、版本库

文章目录 一、Git初识git本质上是一个版本控制器 二、Git的安装 - CentOS三、Git基本操作1. 创建Git本地仓库2. 配置Git3. 认识工作区、暂存区、版本库4. 版本回退5. 撤销修改情况1:对于工作区的代码,还没有add情况二:已经add ,但…

有哪些平台可以赚些零花钱?分享7个副业兼职平台

正规可靠的兼职副业平台有很多,以下是一些常见的平台: 1,微头条 微头条是一种短文本分享平台,通过精简和优化文字,以吸引读者的注意力。需要在有限的字数内表达清晰明了的观点,关键词的准确使用是关键。例…

不允许你不知道Python作用域

在Python中,变量的作用域限制非常重要。根据作用域分类,有局部、全局、函数和内建作用域。无作用域限制的变量可以在分支语句和循环中定义,并在外部直接访问。不同的作用域决定了变量的可访问范围,访问权限取决于变量的位置。 1.…

面试经典150题 -- 图的广度优先遍历 (总结)

总的链接 面试经典 150 题 - 学习计划 - 力扣(LeetCode)全球极客挚爱的技术成长平台 909 . 蛇梯棋 链接 : . - 力扣(LeetCode) 题意 : 直接bfs就好了 , 题意难以理解 : class Solution:def snakesA…

虚拟机中安装Win98

文章目录 一、下载Win98二、制作可启动光盘三、VMware中安装Win98四、Qemu中安装Win981. Qemu的安装2. 安装Win98 Win98是微软于1998年发布的16位与32位混合的操作系统,也是一代经典的操作系统,期间出现了不少经典的软件与游戏,还是值得怀念的…

office办公软件太贵了 Microsoft的Word为什么要买 Microsoft365家庭版多少钱 Microsoft365密钥

Microsoft office是一个被广泛使用的办公软件,它包括了 Word、Excel、PowerPoint 等多种常用的应用程序,已成为许多企业、机构和个人必备的工具。 首先,要理解 Microsoft Office 的价格,我们需要考虑到它的功能和市场需求。Micro…

Pycharm使用教程

1.设置字体型号与大小 file->setting->editor->font(字型),size(大小) 2.设置背景颜色 file->setting->editor->color scheme->scheme 3.注释/取消注释 ctrl/ 选中需要注释的部分,双击ctrl/ 取消注释则选…

揭秘数据中心幕后:从电力消耗到温度调控的策略

建设并运营数据中心并非简单的连接硬盘、通电和联网就可以,而是涉及复杂的硬件集成、能源管理、散热设计以及适应不断增长的数据处理和存储需求等诸多挑战。随着全球互联网的普及和AI技术的快速发展,数据中心的规模和能耗需求都在急剧增加。尤其是在电力…

Vue.js计算属性:实现数据驱动的利器

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

新雀优化算法NOA求解机器人栅格地图最短路径规划,可以自定义地图(提供MATLAB代码)

一、星雀优化算法 星雀优化算法(Nutcracker optimizer algorithm,NOA)由Mohamed Abdel-Basset等人于2023年提出,该算法模拟星雀的两种行为,即:在夏秋季节收集并储存食物,在春冬季节搜索食物的存储位置。CEC2005:星雀优化算法(Nut…

判断链表回文

题目&#xff1a; //方法一&#xff0c;空间复杂度O(n) class Solution { public:bool isPalindrome(ListNode* head) {vector<int> nums; //放进数组后用双指针判断ListNode* cur head;while(cur){nums.emplace_back(cur->val);cur cur->next;}for(int i0…

Spring MVC 全局异常处理器

如果不加以异常处理&#xff0c;错误信息肯定会抛在浏览器页面上&#xff0c;这样很不友好&#xff0c;所以必须进行异常处理。 1.异常处理思路 系统的dao、service、controller出现都通过throws Exception向上抛出&#xff0c;最后由springmvc前端控制器交由异常处理器进行异…

【玩转Linux】有关Linux权限

目录 一.Linux权限的概念 1. 权限的本质 2.Linux中的用户 3.Linux中的权限管理 (1)文件访问者的分类 (2)文件类型和访问权限&#xff08;事物属性&#xff09; ①文件基本权限 ②文件权限值的表示方法 (3)文件访问权限的相关设置方法 ① 用 户 表 示 符 / - 权 …