使用 PyTorch 完全分片数据并行技术加速大模型训练

news2025/1/12 2:50:02

本文,我们将了解如何基于 PyTorch 最新的 完全分片数据并行 (Fully Sharded Data Parallel,FSDP) 功能用 Accelerate 库来训练大模型。

动机 🤗

随着机器学习 (ML) 模型的规模、大小和参数量的不断增加,ML 从业者发现在自己的硬件上训练甚至加载如此大的模型变得越来越难。 一方面,人们发现大模型与较小的模型相比,学习速度更快 (数据和计算效率更高) 且会有显著的提升 [1]; 另一方面,在大多数硬件上训练此类模型变得令人望而却步。

分布式训练是训练这些机器学习大模型的关键。大规模分布式训练 领域最近取得了不少重大进展,我们将其中一些最突出的进展总结如下:

  1. 使用 ZeRO 数据并行 - 零冗余优化器 [2]

  2. 阶段 1: 跨数据并行进程 / GPU 对优化器状态 进行分片

  3. 阶段 2: 跨数据并行进程/ GPU 对优化器状态 + 梯度 进行分片

  4. 阶段 3: 跨数据并行进程 / GPU 对优化器状态 + 梯度 + 模型参数 进行分片

  5. CPU 卸载: 进一步将 ZeRO 阶段 2 的优化器状态 + 梯度 卸载到 CPU 上 [3]

  6. 张量并行 [4]: 模型并行的一种形式,通过对各层参数进行精巧的跨加速器 / GPU 分片,在实现并行计算的同时避免了昂贵的通信同步开销。

  7. 流水线并行 [5]: 模型并行的另一种形式,其将模型的不同层放在不同的加速器 / GPU 上,并利用流水线来保持所有加速器同时运行。举个例子,在第 2 个加速器 / GPU 对第 1 个 micro batch 进行计算的同时,第 1 个加速器 / GPU 对第 2 个 micro batch 进行计算。

  8. 3D 并行 [3]: 采用 ZeRO 数据并行 + 张量并行 + 流水线并行 的方式来训练数百亿参数的大模型。例如,BigScience 176B 语言模型就采用了该并行方式 [6]。

本文我们主要关注 ZeRO 数据并行,更具体地讲是 PyTorch 最新的 完全分片数据并行 (Fully Sharded Data Parallel,FSDP) 功能。DeepSpeedFairScale 实现了 ZeRO 论文的核心思想。我们已经将其集成到了 transformersTrainer 中,详见博文 通过 DeepSpeed 和 FairScale 使用 ZeRO 进行更大更快的训练[10]。最近,PyTorch 已正式将 Fairscale FSDP 整合进其 Distributed 模块中,并增加了更多的优化。

Accelerate 🚀: 无需更改任何代码即可使用 PyTorch FSDP

我们以基于 GPT-2 的 Large (762M) 和 XL (1.5B) 模型的因果语言建模任务为例。

以下是预训练 GPT-2 模型的代码。其与 此处 的官方因果语言建模示例相似,仅增加了 2 个参数 n_train (2000) 和 n_val (500) 以防止对整个数据集进行预处理/训练,从而支持更快地进行概念验证。

run_clm_no_trainer.py

运行 accelerate config 命令后得到的 FSDP 配置示例如下:

compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: FSDP
fsdp_config:
  min_num_params: 2000
  offload_params: false
  sharding_strategy: 1
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
use_cpu: false

多 GPU FSDP

本文我们使用单节点多 GPU 上作为实验平台。我们比较了分布式数据并行 (DDP) 和 FSDP 在各种不同配置下的性能。我们可以看到,对 GPT-2 Large(762M) 模型而言,DDP 尚能够支持其中某些 batch size 而不会引起内存不足 (OOM) 错误。但当使用 GPT-2 XL (1.5B) 时,即使 batch size 为 1,DDP 也会失败并出现 OOM 错误。同时,我们看到,FSDP 可以支持以更大的 batch size 训练 GPT-2 Large 模型,同时它还可以使用较大的 batch size 训练 DDP 训练不了的 GPT-2 XL 模型。

硬件配置: 2 张 24GB 英伟达 Titan RTX GPU。

GPT-2 Large 模型 (762M 参数) 的训练命令如下:

export BS=#`try with different batch sizes till you don't get OOM error,
#i.e., start with larger batch size and go on decreasing till it fits on GPU`

time accelerate launch run_clm_no_trainer.py \
--model_name_or_path gpt2-large \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--per_device_train_batch_size $BS
--per_device_eval_batch_size $BS
--num_train_epochs 1
--block_size 12

FSDP 运行截屏:

e08a91b1c71639b487adf63a4d110487.png
FSDP 运行截屏

cf9eae5538e60d6e6ee6d6584bdf3497.png

表 1: GPT-2 Large (762M) 模型 FSDP 训练性能基准测试

从表 1 中我们可以看到,相对于 DDP 而言,FSDP 支持更大的 batch size,在不使用和使用 CPU 卸载设置的情况下 FSDP 支持的最大 batch size 分别可达 DDP 的 2 倍及 3 倍。从训练时间来看,混合精度的 DDP 最快,其后是分别使用 ZeRO 阶段 2 和阶段 3 的 FSDP。由于因果语言建模的任务的上下文序列长度 ( --block_size ) 是固定的,因此 FSDP 在训练时间上加速还不是太高。对于动态 batch size 的应用而言,支持更大 batch size 的 FSDP 可能会在训练时间方面有更大的加速。目前,FSDP 的混合精度支持在 transformers 上还存在一些 问题。一旦问题解决,训练时间将会进一步显著缩短。

使用 CPU 卸载来支持放不进 GPU 显存的大模型训练

训练 GPT-2 XL (1.5B) 模型的命令如下:

export BS=#`try with different batch sizes till you don't get OOM error,
#i.e., start with larger batch size and go on decreasing till it fits on GPU`

time accelerate launch run_clm_no_trainer.py \
--model_name_or_path gpt2-xl \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--per_device_train_batch_size $BS
--per_device_eval_batch_size $BS
--num_train_epochs 1
--block_size 12

886f460a66a49a0d6e9a2a5caf5bcf5c.png

表 2: GPT-2 XL (1.5B) 模型上的 FSDP 基准测试

从表 2 中,我们可以观察到 DDP (带和不带 fp16) 甚至在 batch size 为 1 的情况下就会出现 CUDA OOM 错误,从而无法运行。而开启了 ZeRO- 阶段 3 的 FSDP 能够以 batch size 为 5 (总 batch size = 10 (5 2) ) 在 2 个 GPU 上运行。当使用 2 个 GPU 时,开启了 CPU 卸载的 FSDP 还能将最大 batch size 进一步增加到每 GPU 14。开启了 CPU 卸载的 FSDP 可以在单个 GPU 上训练 GPT-2 1.5B 模型,batch size 为 10。这使得机器学习从业者能够用最少的计算资源来训练大模型,从而助力大模型训练民主化。

Accelerate 的 FSDP 集成的功能和限制

下面,我们深入了解以下 Accelerate 对 FSDP 的集成中,支持了那些功能,有什么已知的限制。

支持 FSDP 所需的 PyTorch 版本: PyTorch Nightly 或 1.12.0 之后的版本。

命令行支持的配置:

  1. 分片策略: [1] FULL_SHARD, [2] SHARD_GRAD_OP

  2. Min Num Params: FSDP 默认自动包装的最小参数量。

  3. Offload Params: 是否将参数和梯度卸载到 CPU。

如果想要对更多的控制参数进行配置,用户可以利用 FullyShardedDataParallelPlugin ,其可以指定 auto_wrap_policybackward_prefetch 以及 ignored_modules

创建该类的实例后,用户可以在创建 Accelerator 对象时把该实例传进去。

有关这些选项的更多信息,请参阅 PyTorch FullyShardedDataParallel 代码。

接下来,我们体会下 min_num_params 配置的重要性。以下内容摘自 [8],它详细说明了 FSDP 自动包装策略的重要性。

e6b0247df6d55bfccbbe851dbda9632f.png
FSDP 自动包装策略的重要性

(图源: 链接)

当使用 default_auto_wrap_policy 时,如果该层的参数量超过 min_num_params ,则该层将被包装在一个 FSDP 模块中。官方有一个在 GLUE MRPC 任务上微调 BERT-Large (330M) 模型的示例代码,其完整地展示了如何正确使用 FSDP 功能,其中还包含了用于跟踪峰值内存使用情况的代码。

fsdp_with_peak_mem_tracking.py

我们利用 Accelerate 的跟踪功能来记录训练和评估期间的峰值内存使用情况以及模型准确率指标。下图展示了 wandb 实验台 页面的截图。

9a3680fd0bda035679b75eb865196444.png
wandb 实验台

我们可以看到,DDP 占用的内存是使用了自动模型包装功能的 FSDP 的两倍。不带自动模型包装的 FSDP 比带自动模型包装的 FSDP 的内存占用更多,但比 DDP 少得多。与 min_num_params=1M 时相比, min_num_params=2k 时带自动模型包装的 FSDP 占用的内存略少。这凸显了 FSDP 自动模型包装策略的重要性,用户应该调整 min_num_params 以找到能显著节省内存又不会导致大量通信开销的设置。如 [8] 中所述,PyTorch 团队也在为此开发自动配置调优工具。

需要注意的一些事项

  • PyTorch FSDP 会自动对模型子模块进行包装、将参数摊平并对其进行原位分片。因此,在模型包装之前创建的任何优化器都会被破坏并导致更多的内存占用。因此,强烈建议在对模型调用 prepare 方法后再创建优化器,这样效率会更高。对单模型而言,如果没有按照顺序调用的话, Accelerate 会抛出以下告警信息,并自动帮你包装模型并创建优化器。

    FSDP Warning: When using FSDP, it is efficient and recommended to call prepare for the model before creating the optimizer

即使如此,我们还是推荐用户在使用 FSDP 时用以下方式显式准备模型和优化器:

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True)
+ model = accelerator.prepare(model)

optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)

- model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(model,
- optimizer, train_dataloader, eval_dataloader, lr_scheduler
- )

+ optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
+ optimizer, train_dataloader, eval_dataloader, lr_scheduler
+ )
  • 对单模型而言,如果你的模型有多组参数,而你想为它们设置不同优化器超参。此时,如果你对整个模型统一调用 prepare 方法,这些参数的组别信息会丢失,你会看到如下告警信息:

    FSDP Warning: When using FSDP, several parameter groups will be conflated into a single one due to nested module wrapping and parameter flattening.

告警信息表明,在使用 FSDP 对模型进行包装后,之前创建的参数组信息丢失了。因为 FSDP 会将嵌套式的模块参数摊平为一维数组 (一个数组可能包含多个子模块的参数)。举个例子,下面是 GPU 0 上 FSDP 模型的有名称的参数 (当使用 2 个 GPU 时,FSDP 会把第一个分片的参数给 GPU 0, 因此其一维数组中大约会有 55M (110M / 2) 个参数)。此时,如果我们在 FSDP 包装前将 BERT-Base 模型的 [bias, LayerNorm.weight] 参数的权重衰减设为 0,则在模型包装后,该设置将无效。原因是,你可以看到下面这些字符串中均已不含这俩参数的名字,这俩参数已经被并入了其他层。想要了解更多细节,可参阅本 问题 (其中写道: 原模型参数没有 .grads 属性意味着它们无法单独被优化器优化 (这就是我们为什么不能支持对多组参数设置不同的优化器超参) )。

{
'_fsdp_wrapped_module.flat_param': torch.Size([494209]),

'_fsdp_wrapped_module._fpw_module.bert.embeddings.word_embeddings._fsdp_wrapped_module.flat_param': torch.Size([11720448]),

'_fsdp_wrapped_module._fpw_module.bert.encoder._fsdp_wrapped_module.flat_param': torch.Size([42527232])
}
  • 如果是多模型情况,须在创建优化器之前调用模型 prepare 方法,否则会抛出错误。

  • FSDP 目前不支持混合精度,我们正在等待 PyTorch 修复对其的支持。

工作原理 📝

405eaa20c335cff37eaef89babfb2a9f.png
FSDP 工作流

(图源: 链接)

上述工作流概述了 FSDP 的幕后流程。我们先来了解一下 DDP 是如何工作的,然后再看 FSDP 是如何改进它的。在 DDP 中,每个工作进程 (加速器 / GPU) 都会保留一份模型的所有参数、梯度和优化器状态的副本。每个工作进程会获取不同的数据,这些数据会经过前向传播,计算损失,然后再反向传播以生成梯度。接着,执行 all-reduce 操作,此时每个工作进程从其余工作进程获取梯度并取平均。这样一轮下来,每个工作进程上的梯度都是相同的,且都是全局梯度,接着优化器再用这些梯度来更新模型参数。我们可以看到,每个 GPU 上都保留完整副本会消耗大量的显存,这限制了该方法所能支持的 batch size 以及模型尺寸。

FSDP 通过让各数据并行工作进程分片存储优化器状态、梯度和模型参数来解决这个问题。进一步地,还可以通过将这些张量卸载到 CPU 内存来支持那些 GPU 显存容纳不下的大模型。在具体运行时,与 DDP 类似,FSDP 的每个工作进程获取不同的数据。在前向传播过程中,如果启用了 CPU 卸载,则首先将本地分片的参数搬到 GPU/加速器。然后,每个工作进程对给定的 FSDP 包装模块/层执行 all-gather 操作以获取所需的参数,执行计算,然后释放/清空其他工作进程的参数分片。在对所有 FSDP 模块全部执行该操作后就是计算损失,然后是后向传播。在后向传播期间,再次执行 all-gather 操作以获取给定 FSDP 模块所需的所有参数,执行计算以获得局部梯度,然后再次释放其他工作进程的分片。最后,使用 reduce-scatter 操作对局部梯度进行平均并将相应分片给对应的工作进程,该操作使得每个工作进程都可以更新其本地分片的参数。如果启用了 CPU 卸载的话,梯度会传给 CPU,以便直接在 CPU 上更新参数。

如欲深入了解 PyTorch FSDP 工作原理以及相关实验及其结果,请参阅 [7,8,9]。

问题

如果在 accelerate 中使用 PyTorch FSDP 时遇到任何问题,请提交至 accelerate。

但如果你的问题是跟 PyTorch FSDP 配置和部署有关的 - 你需要提交相应的问题至 PyTorch。

参考文献

[1] Train Large, Then Compress: Rethinking Model Size for Efficient Training and Inference of Transformers

[2] ZeRO: Memory Optimizations Toward Training Trillion Parameter Models

[3] DeepSpeed: Extreme-scale model training for everyone - Microsoft Research

[4] Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism

[5] Introducing GPipe, an Open Source Library for Efficiently Training Large-scale Neural Network Models

[6] Which hardware do you need to train a 176B parameters model?

[7] Introducing PyTorch Fully Sharded Data Parallel (FSDP) API | PyTorch

[8] Getting Started with Fully Sharded Data Parallel(FSDP) — PyTorch Tutorials 1.11.0+cu102 documentation

[9] Training a 1 Trillion Parameter Model With PyTorch Fully Sharded Data Parallel on AWS | by PyTorch | PyTorch | Mar, 2022 | Medium

[10] Fit More and Train Faster With ZeRO via DeepSpeed and FairScale

🤗 宝子们可以戳 阅读原文 查看文中所有的外部链接哟!


英文原文: https://hf.co/blog/pytorch-fsdp

原文作者: Sourab Mangrulkar,Sylvain Gugger

译者: Matrix Yao (姚伟峰),英特尔深度学习工程师,工作方向为 transformer-family 模型在各模态数据上的应用及大规模模型的训练推理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1288334.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

harmony开发之image组件的使用

HarmonyOS是一款面向万物互联时代的、全新的分布式操作系统。 在传统的单设备系统能力基础上,HarmonyOS提出了基于同一套系统能力、适配多种终端形态的分布式理念,能够支持手机、平板、智能穿戴、智慧屏、车机、PC、智能音箱、耳机、AR/VR眼镜等多种终端…

在KeyarchOS上体验WildFly

一、浪潮信息KeyarchOS简单介绍 KeyarchOS具备稳定可靠、高效软硬协同、全天候运维、安全可信等特性,增强了对云计算、人工智能等场景的支持,性能稳定性领先,生态完善,符合金融、政务、能源、交通、通信、教育、医疗等关键行业的应用要求。具备非常广泛的应用。 官方地址&…

BACnet I/O模块在水利环境监测全自动控制系统中的应用:稳定、高效、实时

前言 “绿水青山就是金山银山”,水利环境一直是国际生态部门关注的重点。随着经济的发展、针对水利环境的监管也日趋严格,尤其是重点河、湖水系、水源地、城市内河等成为重点监管对象,监管力度也愈来愈严格,监测布点密度不断加大。…

linux交换分区管理SWAP

6.2.5 交换分区管理SWAP 6.2.5.1 概念 作用: ”提升“内存容量,防止OOM(out of memory,内存溢出)。 ​ 对应windows中的虚拟内存。 ​ 从功能上讲,交换分区主要是在内存不够用的时候,将部分内…

【Azure 架构师学习笔记】- Azure Databricks (2) -集群

本文属于【Azure 架构师学习笔记】系列。 本文属于【Azure Databricks】系列。 接上文 【Azure 架构师学习笔记】- Azure Databricks (1) - 环境搭建 前言 在上文中提到了ADB 的其中一个核心就是集群,所以这里专门研究一下ADB 的集群。 ADB 集群 首先了解一下ADB…

git 使用记录

远程仓库为空初始化 初始化本地仓库 git init 在本地仓库书写代码(这里可以编辑一个文本文件做测试,如hello.txt) 5)执行:git add 要让git管理的文件(git add hello.txt)>执行完此操作将我…

【Linux】Linux基础

文章目录 学习目标操作系统不同应用领域的主流操作系统虚拟机 Linux系统的发展史Linux内核版和发行版 Linux系统下的文件和目录结构单用户操作系统vs多用户操作系统Windows和Linux文件系统区别 Linux终端命令格式终端命令格式查阅命令帮助信息 常用命令显示文件和目录切换工作目…

数据湖和中央数据仓库的设计

设计数据湖或中央数据仓库是许多大型组织的主要职能,这些组织每天处理数百万笔交易,并对这些交易进行进一步的报告、预测或机器学习项目分析。 为了将所有来自源系统(我们称之为“上游”)到其他业务应用(所谓“下游”&…

解决(error) ERR Errors trying to SHUTDOWN. Check logs.问题~

该问题出现在我在使用shutdown关闭redis服务器时,出现该问题的原因是由于配置文件的日志文件位置未配置或者缺少日志文件 我自己出现该问题是因为缺少日志文件,解决步骤如下所示: 第一步:在该目录下使用touch命令创建日志文件 第…

测绘资质测绘设备检定、校准管理制度

测绘设备检定、校准管理制度 建立健全测绘仪器设备检定、校准管理制度,明确测绘仪器设备的检定、校准、日常管理等要求

Kafka安装与配置-shell脚本一键安装配置(单机版)

文章目录 前言使用shell脚本一键安装1. 复制脚本2. 增加执行权限3. 执行脚本4. 加载用户环境变量5. 启动/停止Kafka内置zookeeper6. 启动/停止Kafka单机版 总结 前言 本文介绍了使用Shell脚本一键安装Kafka的方法。通过复制脚本并执行,可以自动下载、安装和配置Kaf…

【链表Linked List】力扣-24 两两交换链表中的节点

目录 题目描述 解题过程 题目描述 给你一个链表,两两交换其中相邻的节点,并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题(即,只能进行节点交换)。 示例 1: 输入:he…

webAPL用户注册倒计时

上一篇文章我已经留下这个题 以下是这道题代码&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><t…

机器学习---线性回归算法

1、什么是回归&#xff1f; 从大量的函数结果和自变量反推回函数表达式的过程就是回归。线性回归是利用数理统计中回归分析来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法。 2、一元线性回归 3、多元线性回归 如果回归分析中包括两个或两个以上的自变量&a…

文章解读与仿真程序复现思路——中国电机工程学报EI\CSCD\北大核心《基于自适应扩散核密度估计的时序相关概率最优潮流计算方法》

这个标题涉及到电力系统中潮流计算的一种方法&#xff0c;具体解读如下&#xff1a; 概述&#xff1a; 时序相关概率最优潮流计算方法&#xff1a; 这是一种电力系统中潮流计算的方法&#xff0c;其中考虑了时序&#xff08;时间序列&#xff09;和概率的相关性。概率最优潮流通…

MySQL 数据库如何实现 XA 规范?

本文我们来讨论 MySQL 的 XA 规范有哪些应用相关的内容。 MySQL 为我们提供了分布式事务解决方案&#xff0c;在前面的内容中提到过 binlog 的同步&#xff0c;其实是 MySQL XA 规范的一个应用&#xff0c;那么 XA 规范是如何定义的&#xff0c;具体又是如何应用的呢&#xff…

飞行汽车开发原理(上)

前言 小节的安排是由浅入深&#xff0c;要按顺序读&#xff1b;有电路知识基础的同学可跳到“计算机电路”一节开始。因为知识点之间有网状依赖&#xff0c;没办法按分类来讲。 为了避免过于深入、越讲越懵&#xff0c;很多描述仅为方便理解、不求严谨。 半导体特性 导体&a…

java学习part37定制排序和自然排序

150-常用类与基础API-使用Comparator接口实现定制排序及对比_哔哩哔哩_bilibili 1.自然排序 2.定制排序 对于一些排序方法&#xff0c;允许传入的话按定制的排序规则来&#xff0c;不传入默认按自然排序来。 匿名方式 3区别

5-Tornado入门、程序的原理图、tornado不能使用同步代码的演示

安装 pip install tornado第一个程序 from tornado import web from tornado import ioloop class IndexHandler(web.RequestHandler):def get(self):self.write(Hello Tornado!!123)if __name__ __main__:# 1.创建了app对象&#xff0c;设置路由,并开启debug模式app web.A…

17、迭代器模式(Iterator Pattern)

迭代器模式提供了顺序访问集合对象中的各种元素&#xff0c;而不暴露该对象内部结构的方法。如Java中遍历HashMap。 迭代器模式将遍历集合中所有元素的操作封装成迭代器类&#xff0c;其目的是在不暴露集合对象内部结构的情况下&#xff0c;对外提供统一访问集合的内部数据的方…