使用DeepSpeed加速大型模型训练(二)

news2025/1/18 20:18:15

使用DeepSpeed加速大型模型训练

在这篇文章中,我们将了解如何利用Accelerate库来训练大型模型,从而使用户能够利用DeeSpeed的 ZeRO 功能。

简介

尝试训练大型模型时是否厌倦了内存不足 (OOM) 错误?我们已经为您提供了保障。大型模型性能非常好[1],但很难使用可用的硬件进行训练。为了充分利用可用硬件来训练大型模型,可以使用 ZeRO(零冗余优化器)[2] 来利用数据并行性。

下面是使用 ZeRO 的数据并行性的简短描述以及此博客文章中的图表

a. 第 1 阶段:分片优化器状态跨数据并行工作器/GPUs的状态
b. 第 2 阶段:分片优化器状态+梯度跨数据并行工作器/GPUs
c. 第 3 阶段:分片优化器状态+梯度+模型参数跨数据并行工作器/GPUs
d. 优化器卸载(Optimizer Offload):将梯度+优化器状态卸载到ZERO Stage 2之上的CPU/磁盘构建
e. 参数卸载(Param Offload):将模型参数卸载到ZERO Stage 3之上的CPU/磁盘构建

在这篇博文中,我们将看看如何使用ZeRO和Accelerate来利用数据并行性。DeepSpeed, FairScale和PyTorch fullyshardeddataparlparallel (FSDP)实现了ZERO论文的核心思想。伴随博客,通过DeepSpeed和FairScale 使用ZeRO适应更多和训练更快[4]和使用PyTorch完全分片并行数据加速大型模型训练[5], 这些已经集成在transformer Trainer和accelerate中。背后的解释可以看这些博客,主要集中在使用Accelerate来利用DeepSpeed ZeRO。

Accelerate :利用DeepSpeed ZeRO没有任何代码的变化

我们将研究只微调编码器模型用于文本分类的任务。我们将使用预训练microsoft/deberta-v2-xlarge-mnli (900M params) 用于对MRPC GLUE数据集进行微调。
代码可以在这里找到 run_cls_no_trainer.py它类似于这里的官方文本分类示例,只是增加了计算训练和验证时间的逻辑。让我们比较分布式数据并行(DDP)和DeepSpeed ZeRO Stage-2在多gpu设置中的性能。

要启用DeepSpeed ZeRO Stage-2而无需任何代码更改,请运行加速配置并利用 Accelerate DeepSpeed Plugin

ZeRO Stage-2 DeepSpeed Plugin 示例

配置

compute_environment: LOCAL_MACHINE
deepspeed_config:
 gradient_accumulation_steps: 1
 gradient_clipping: 1.0
 offload_optimizer_device: none
 offload_param_device: none
 zero3_init_flag: false
 zero_stage: 2
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
use_cpu: false

现在,运行下面的命令进行训练

accelerate launch run_cls_no_trainer.py \
  --model_name_or_path "microsoft/deberta-v2-xlarge-mnli" \
  --task_name "mrpc" \
  --ignore_mismatched_sizes \
  --max_length 128 \
  --per_device_train_batch_size 40 \
  --learning_rate 2e-5 \
  --num_train_epochs 3 \
  --output_dir "/tmp/mrpc/deepspeed_stage2/" \
  --with_tracking \
  --report_to "wandb" \

在我们的单节点多gpu设置中,DDP支持的无OOM错误的最大批处理大小是8。相比之下,DeepSpeed Zero-Stage 2允许批量大小为40,而不会出现OOM错误。因此,与DDP相比,DeepSpeed使每个GPU能够容纳5倍以上的数据。下面是wandb运行的图表快照,以及比较DDP和DeepSpeed的基准测试表。



表1:DeepSpeed ZeRO Stage-2在DeBERTa-XL (900M)模型上的基准测试
使用更大的批处理大小,我们观察到总训练时间加快了3.5倍,而性能指标没有下降,所有这些都没有改变任何代码。

为了能够调整更多选项,您将需要使用DeepSpeed配置文件和最小的代码更改。我们来看看怎么做。

Accelerate :利用 DeepSpeed 配置文件调整更多选项

首先,我们将研究微调序列到序列模型以训练我们自己的聊天机器人的任务。具体来说,我们将facebook/blenderbot-400M-distill在smangrul/MuDoConv(多域对话)数据集上进行微调。该数据集包含来自 10 个不同数据源的对话,涵盖角色、基于特定情感背景、目标导向(例如餐厅预订)和一般维基百科主题(例如板球)。

代码可以在这里找到run_seq2seq_no_trainer.py。目前有效衡量聊天机器人的参与度和人性化的做法是通过人工评估,这是昂贵的[6]。对于本例,跟踪的指标是BLEU分数(这不是理想的,但却是此类任务的常规指标)。如果您可以访问支持bfloat16精度的gpu,则可以调整代码以训练更大的T5模型,否则您将遇到NaN损失值。我们将在10000个训练样本和1000个评估样本上运行一个快速基准测试,因为我们对DeepSpeed和DDP感兴趣。我们将利用DeepSpeed Zero Stage-2 配置 zero2_config_accelerate.json(如下所示)进行训练。有关各种配置特性的详细信息,请参阅DeeSpeed文档。

{
    "fp16": {
        "enabled": "true",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 15,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": 1,
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

要使用上述配置启用DeepSpeed ZeRO Stage-2,请运行accelerate config并提供配置文件路径。更多详细信息,请参考accelerate官方文档中的DeepSpeed配置文件。

ZeRO Stage-2 DeepSpeed配置文件示例

compute_environment: LOCAL_MACHINE
deepspeed_config:
 deepspeed_config_file: /path/to/zero2_config_accelerate.json
 zero3_init_flag: false
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
use_cpu: false

现在,运行下面的命令进行训练

accelerate launch run_seq2seq_no_trainer.py \
    --dataset_name "smangrul/MuDoConv" \
    --max_source_length 128 \
    --source_prefix "chatbot: " \
    --max_target_length 64 \
    --val_max_target_length 64 \
    --val_min_target_length 20 \
    --n_val_batch_generations 5 \
    --n_train 10000 \
    --n_val 1000 \
    --pad_to_max_length \
    --num_beams 10 \
    --model_name_or_path "facebook/blenderbot-400M-distill" \
    --per_device_train_batch_size 200 \
    --per_device_eval_batch_size 100 \
    --learning_rate 1e-6 \
    --weight_decay 0.0 \
    --num_train_epochs 1 \
    --gradient_accumulation_steps 1 \
    --num_warmup_steps 100 \
    --output_dir "/tmp/deepspeed_zero_stage2_accelerate_test" \
    --seed 25 \
    --logging_steps 100 \
    --with_tracking \
    --report_to "wandb" \
    --report_name "blenderbot_400M_finetuning"

当使用DeepSpeed配置时,如果用户在配置中指定了优化器和调度器,用户将不得不使用accelerate.utils.DummyOptimaccelerate.utils.DummyScheduler。这些都是用户需要做的小改动。下面我们展示了使用DeepSpeed配置时所需的最小更改的示例

- optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=args.learning_rate)
+ optimizer = accelerate.utils.DummyOptim(optimizer_grouped_parameters, lr=args.learning_rate)

- lr_scheduler = get_scheduler(
-     name=args.lr_scheduler_type,
-     optimizer=optimizer,
-     num_warmup_steps=args.num_warmup_steps,
-     num_training_steps=args.max_train_steps,
- )

+ lr_scheduler = accelerate.utils.DummyScheduler(
+     optimizer, total_num_steps=args.max_train_steps, warmup_num_steps=args.num_warmup_steps
+ )


表2:DeepSpeed ZeRO Stage-2在BlenderBot (400M)模型上的基准测试
在我们的单节点多gpu设置中,DDP支持的无OOM错误的最大批处理大小是100。相比之下,DeepSpeed Zero-Stage 2允许批量大小为200,而不会出现OOM错误。因此,与DDP相比,DeepSpeed使每个GPU能够容纳2倍以上的数据。我们观察到训练加速了1.44倍,评估加速了1.23倍,因为我们能够在相同的可用硬件上容纳更多的数据。由于这个模型是中等大小,加速不是那么令人兴奋,但这将改善与更大的模型。你可以和使用Space smangrul/Chat-E上的全部数据训练的聊天机器人聊天。你可以给机器人一个角色,一个特定情感的对话,用于目标导向的任务或自由流动的方式。下面是与聊天机器人的有趣对话。您可以在这里找到使用不同上下文的更多对话的快照。

CPU/磁盘卸载,使训练庞大的模型将不适合GPU内存

在单个24GB NVIDIA Titan RTX GPU上,即使批量大小为1,也无法训练GPT-XL模型(1.5B参数)。我们将看看如何使用DeepSpeed ZeRO Stage-3与CPU卸载优化器状态,梯度和参数来训练GPT-XL模型。

我们将利用DeepSpeed Zero Stage-3 CPU卸载配置zero3_offload_config_accelerate.json (如下所示)进行训练。

{
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto"
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },
    "gradient_accumulation_steps": 1,
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

ZeRO Stage-3 CPU Offload DeepSpeed配置文件示例

compute_environment: LOCAL_MACHINE
deepspeed_config:
 deepspeed_config_file: /path/to/zero3_offload_config_accelerate.json
 zero3_init_flag: true
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
use_cpu: false

现在,运行下面的命令进行训练

表3:在GPT-XL (1.5B)模型上对DeepSpeed ZeRO Stage-3 CPU Offload进行基准测试
即使批大小为1,DDP也会导致OOM错误。另一方面,使用DeepSpeed ZeRO Stage-3 CPU卸载,我们可以以16个batch_size大小进行训练。

最后,请记住,Accelerate只集成了DeepSpeed,因此,如果您对DeepSpeed的使用有任何问题或疑问,请向DeepSpeed GitHub提交问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/997562.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

倒⽴摆闭环控制的设计与开发

倒立摆是一种典型的多变量、高阶次、非线性、强耦合、自然不稳定系统,常被用来检验新的控制理论和算法的正确性及其在实际应用中的有效性。 load func_ip_comp% Locations of Poles and Zeroes of Open-Loop Compensated Transfer Function in Complex Plane figur…

查看当前所有的数据库

MySQL从小白到总裁完整教程目录:https://blog.csdn.net/weixin_67859959/article/details/129334507?spm1001.2014.3001.5502 先看下服务有没有启动,看下我这个是什么意思 exit 退出MySQL环境回车后是这样的 重新进入MySQL环境 查看当前所有的数据库 show datebases; mysql&…

nrf523832 串口点LED

/* P0.06:串口发送TXD P0.08:串口接收RXD P0.05:串口RTS:发送请求,硬件流控开启时有效 P0.07:串口CTS:发送允许,硬件流控开启时有效 */ #define RX_PIN_NUMBER 8 #define TX_PIN_N…

学习Bootstrap 5的第十天

目录 卡片 基础的卡片 实例 页眉和页脚 实例 多种颜色卡片 实例 标题、文本和链接 实例 图片卡片 实例 卡片图像叠加 实例 下拉菜单 基础的下拉列表 实例 下拉列表分隔线 实例 下拉列表标题 实例 禁用的和活动的项目 实例 下拉列表位置 实例 下拉菜单…

SpringMVC之文件的上传下载(教你如何使用有关SpringMVC知识实现文件上传下载的超详细博客)

目录 前言 一、文件上传 1. 配置多功能视图解析器(spring-mvc.xml) 2. 添加文件上传页面(upload.jsp) upload.jsp 3.做硬盘网络路径映射 4. 编写一个处理页面跳转的类 PageController.java ClazzController.java 5. 初步模…

【数据分享】STRM 90米分辨率DEM地形数据(无需转发/全国/分省/分市)

地形数据是我们在各种设计、研究中都经常使用的基础数据!之前我们分享过12.5米精度的DEM地形数据、30米精度的DEM地形数据(均可查看之前的文章获悉详情)。 本次给大家带来的是90米分辨率的DEM地形数据——STRM 90m高程数据!该数据是由美国太…

学成在线-网站搭建

文章目录 代码素材来自b站pink老师 <!DOCTYPE html> <html lang="en"> <head><meta charset="UTF-8"><meta name="viewport" content="width=device-width, initial-scale=1.0"><title>学成在线首…

《PWA实战:如何为你的网站增加离线功能和推送通知》

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

Hbase解决ERROR: KeeperErrorCode = ConnectionLoss for /hbase/master报错

1、在单机模式中&#xff0c;要先修改一个文件&#xff1a;/usr/local/hbase/conf/hbase-site.xml hbase-site.xml内容&#xff1a; <configuration><property><name>hbase.rootdir</name><value>file:///usr/local/hbase/hbase-tmp</value…

dji uav建图导航系列()move_base

文章目录 1、导航框架2、move_base功能包3、amcl功能包4、代价地图的配置4.1、通用配置文件4.2、全局规划配置文件4.3、局部规划配置文件5、局部规划器配置6、launch文件1、导航框架 导航的关键是机器人定位和路径规划两大部分 move_base:实现机器人导航中的最优路径规划 am…

ES8生产实践——pod日志采集(Elastic Agent方案)

pod日志采集方案 方案选型 DaemonSetElastic Agent方案&#xff1a;使用DaemonSet控制器在每个kubernetes集群节点上运行elastic agent服务&#xff0c;业务容器日志目录统一挂载到节点指定目录下。在fleet中配置集成Custom Logs集成策略&#xff0c;指定日志采集目录和inges…

【C语言 C++ 源码】课程设计 学生成绩管理系统

文章目录 演示 对学生的信息、学科成绩进行管理&#xff0c;并进行统计。 对信息进行读写文件操作。自动保存数据到文件。 演示 此课设分为C语言和C两个版本。对于学生数据&#xff0c;会自动保存数据到本地&#xff0c;下次运行自动读取数据。 部分源码&#xff1a; printf…

OceanBase 来参加外滩大会了(内附干货PPT)

9 月 7 日至 9 日&#xff0c;2023 inclusion外滩大会在上海黄浦世博园区举办。8 日&#xff0c;由赛迪顾问与 OceanBase 联合主办的外滩大会“分布式数据库助力数实融合”见解论坛圆满落幕。 数字经济加速发展&#xff0c;数字化转型进入深水区&#xff0c;企业对海量数据的存…

大数据课程L5——网站流量项目的实时业务系统搭建

文章作者邮箱:yugongshiye@sina.cn 地址:广东惠州 ▲ 本章节目的 ⚪ 掌握网站流量项目的 Flume—>Kafka 连通; ⚪ 掌握网站流量项目的实时业务系统搭建; 一、Flume—>Kafka 连通 1. 实现步骤 1. 启动三台服务器。 2. 启动 Zookeeper 集群。 执行指…

I2C总线驱动:裸机版、应用层的使用、二级外设驱动三种方法

一、I2C总线背景知识 SOC芯片平台的外设分为&#xff1a; 一级外设&#xff1a;外设控制器集成在SOC芯片内部二级外设&#xff1a;外设控制器由另一块芯片负责&#xff0c;通过一些通讯总线与SOC芯片相连 Inter-Integrated Circuit&#xff1a; 字面意思是用于“集成电路之间…

解决absolute绝对定位带来的div穿透问题

首先来看症状&#xff1a; 按理说蓝色和红色div应该并排同行显示&#xff0c;但是很明显&#xff1a;两个元素重叠了 代码如下&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv&q…

RNN项目实战——文本输入与预测

在本篇博客文章中&#xff0c;我们将使用pytorch搭建RNN模型来生成文本。 文本输入 神经网络不像人类那样善于处理文本数据。因为绝大多数的NLP任务&#xff0c;文本数据都会先通过嵌入码&#xff08;Embedding code)&#xff0c;独热编码(One-hot encoding)等方式转为数字编码…

python DVWA命令注入POC练习

这里同样是抓包&#xff0c;访问DVWA低难度的命令注入 <?phpif( isset( $_POST[ Submit ] ) ) {// Get input$target $_REQUEST[ ip ];// Determine OS and execute the ping command.if( stristr( php_uname( s ), Windows NT ) ) {// Windows$cmd shell_exec( ping …

RHCSA-VM-Linux基础配置命令

1.代码命令 1.查看本机IP地址&#xff1a; ip addr 或者 ip a [foxbogon ~]$ ip addre [foxbogon ~]$ ip a 1&#xff1a;<Loopback,U,LOWER-UP> 为环回2网卡 2: ens160: <BROADCAST,MULTICAST,UP,LOWER_UP>为虚拟机自身网卡 2.测试网络联通性&#xff1a; [f…

浅谈OPenGL中的纹理过滤

纹理图像中的纹理单元和屏幕上的像素几乎从来不会形成一对一的对应关系。如果程序员足够细心&#xff0c;确实可以实现这个效果&#xff0c;但这需要在对几何图形进行纹理贴图时进行精心的计划&#xff0c;使出现在屏幕上的纹理单元和像素能够对齐&#xff08;实际上在用OPenGL…