LLM 分布式训练框架 | DeepSpeed与Accelerate

news2025/1/16 1:02:51

🚀 简单记录下根据网上资料(如Reference中所列)所学到的一些知识,这里主要介绍的是deepspeed分布式训练框架相关概念。

😄小日记:今天太舒服了,早上跑了6km,晚上吃了养生菌菇火锅~

在这里插入图片描述

文章目录

  • 1、Accelerate和deepspeed的联系
  • 2、基本概念
  • 3、通信策略
  • 4、Zero(ZeRO-Stage3、ZeRO-Offload)
    • 4.1、ZeRO中不同stage的区别
    • 4.2、ZeRO-Offload
  • 5、deepspeed中的混合精度
  • 6、gradient checkpoint
  • 7、DeepSpeed的推理优化
  • 8、deepspeed搭配transformers库使用
  • Reference

1、Accelerate和deepspeed的联系

  • Accelerate是PyTorch官方提供的分布式训练工具,而deepspeed是由Microsoft提供的分布式训练工具。
  • 最主要的区别在于支持的模型规模不同,deepspeed支持更大规模的模型。 deepspeed还提供了更多的优化策略和工具,例如ZeRO和Offload等。
  • Accelerate更加稳定和易于使用,适合中小规模的训练任务。
  • ⭐注意:在这里插入代码片Accelerate只支持nvlink,而T4,3090这类显卡是PIX ,检测方式:nvidia-smi topo -m。【我用3090ti跑这个项目:LLaMA-Factory,就跑不了会报错,我怀疑就是这个问题。用deepspeed是ok的】

😄 我更推荐用deepspeed,deepspeed方便了我们在机器有限的情况下来训练、微调大模型,同时它也有很多优秀的性能优化。
目前主流的训练LLM的方式: PyTorch + GPU + DeepSpeed + LLM训练框架
优势:

  • 存储效率:DeepSpeed提供了一种Zero的新型解决方案来减少训练显存的占用,它与传统的数据并行不同,它将模型状态和梯度进行分区来节省大量的显存;
  • 可扩展性:DeepSpeed支持高效的数据并行、模型并行、pipeline并行以及它们的组合,这里也称3D并行;
  • 易用性: 在训练阶段,只需要修改几行代码就可以使pytorch模型使用DeepSpeed。

2、基本概念

在分布式计算环境中,有几个非常基础的概念需要理解:

  • 节点编号(node_rank:):分配给系统中每个节点的唯一标识符,用于区分不同计算机之间的通信。
  • 全局进程编号(rank):分配给整个系统中的每个进程的唯一标识符,用于区分不同进程之间的通信。
  • 局部进程编号(local_rank):分配给单个节点内的每个进程的唯一标识符,用于区分同一节点内的不同进程之间的通信。
  • 全局总进程数(word_size):在整个系统中运行的所有进程的总数,用于确定可以并行完成多少工作以及需要完成任务所需的资源数量。
  • 主节点(master_ip+master_port):在分布式计算环境中,主节点负责协调所有其他节点和进程的工作,为了确定主节点,我们需要知道它的IP地址和端口号。主节点还负责监控系统状态、处理任务分配和结果汇总等任务,因此是整个系统的关键部分。

3、通信策略

deepspeed 还提供了 mpi、gloo 和 nccl 等通信策略,可以根据具体情况进行选择和配置。

  • mpi 是一种跨节点通信库,常用于 CPU 集群上的分布式训练;
  • gloo 是一种高性能的分布式训练框架,支持 CPU 和 GPU 上的分布式训练;
  • nccl 是 NVIDIA 提供的 GPU 专用通信库,被广泛应用于 GPU 上的分布式训练。

在使用 DeepSpeed 进行分布式训练时,可以根据具体情况选择合适的通信库。通常情况下,如果是在 CPU 集群上进行分布式训练,可以选择 mpi 和 gloo;如果是在 GPU 上进行分布式训练,可以选择 nccl。

4、Zero(ZeRO-Stage3、ZeRO-Offload)

在DeepSpeed下,ZeRO训练支持了完整的ZeRO Stages1, 2和3,以及支持将优化器状态、梯度和模型参数从GPU显存下沉到CPU内存或者硬盘上,实现不同程度的显存节省,以便训练更大的模型。

ZeRO(Zero Redundancy Optimizer)是一种用于大规模训练优化的技术,主要是用来减少内存占用。在大规模训练中,内存占用可以分为 Model StatesActivation两部分 (Activation估计是前向传播时保存下来的各神经元的激活值吧,因为反向传播时计算梯度需要用到) ,而 ZeRO 主要是为了解决 Model States 的内存占用问题。

ZeRO 将模型参数分成了三个部分:Optimizer StatesGradientModel Parameter

  • Optimizer States 是 Optimizer 在进行梯度更新时所需要用到的数据,例如 SGD 中的 Momentum。
  • Gradient 是在反向传播后所产生的梯度信息,其决定了参数的更新方向。
  • Model Parameter 则是模型参数,也就是我们在整个过程中通过数据“学习”的信息。

4.1、ZeRO中不同stage的区别

  • ZeRO-0:禁用所有类型的分片,仅使用 DeepSpeed 作为 DDP (Distributed Data Parallel)

  • ZeRO-1:把优化器状态(optimizer states)分片到每个数据并行的工作进程(每个GPU)下;

  • ZeRO-2:把 优化器状态(optimizer states) + 梯度(gradients) 分片到每个数据并行的工作进程(每个GPU)下;

  • ZeRO-3:把优化器状态(optimizer states) + 梯度(gradients) + 模型参数(parameters) 分片到每个数据并行的工作进程(每个GPU)下。内存减少与数据并行度呈线性关系。例如,在64个GPU(Nd=64)之间进行拆分将产生64倍的内存缩减。通信量有50%的适度增长。

  • ZeRO-Infinity是ZeRO-3的拓展。允许通过使用 NVMe 固态硬盘扩展 GPU 和 CPU 内存来训练大型模型。ZeRO-Infinity 需要启用 ZeRO-3。

⭐ ZeRO-Stage3在deepspeed中通过zero_optimization.stage=0/1/2/3 设置,ZeRO-Offload通过zero_optimization.offload_optimizer.device设置。
⭐ 备注:优化器状态 一般包含FP32 Gradient、FP32 Variance、FP32 Momentum、FP32 Parameters。梯度和模型参数 一般会用FP16就够了,所以占用大头一般是优化器相关的。
所以根据实际硬件资源,选择适合Stage策略即可。如果遇到要跑更大的模型,比如想在3090 24GB下跑13B模型,可能Stage3也OOM跑不起来,此时可以开启Optimizer Offload和Param Offload即可跑起来,但相应的性能会受影响。

4.2、ZeRO-Offload

ZeRO-Offload: offload指将数据、梯度、优化器状态等下沉到CPU内存或硬盘上。

  • Optimizer Offload: 在Stage2的基础上,把梯度和优化器状态下沉到CPU内存或硬盘。
  • Param Offload: 在Stage3的基础上,把模型参数下沉到CPU内存或硬盘上。

5、deepspeed中的混合精度

混合精度训练是指在训练过程中同时使用FP16(半精度浮点数)和FP32(单精度浮点数)两种精度的技术。使用FP16可以大大减少内存占用,从而可以训练更大规模的模型。但是,由于FP16的精度较低,训练过程中可能会出现梯度消失和模型不稳定的问题。因此,需要使用一些技术来解决这些问题,例如动态精度缩放(Dynamic Loss Scaling)和混合精度优化器(Mixed Precision Optimizer)等。

  • deepspeed提供了混合精度训练的支持,可以通过在配置文件中设置"fp16.enabled": true来启用混合精度训练。在训练过程中,deepspeed会自动将一部分操作转换为FP16格式,并根据需要动态调整精度缩放因子,从而保证训练的稳定性和精度。

  • ⭐ 在使用混合精度训练时,需要注意一些问题,例如梯度裁剪(Gradient Clipping)和学习率调整(Learning Rate Schedule)等。梯度裁剪可以防止梯度爆炸,学习率调整可以帮助模型更好地收敛。因此,在设置混合精度训练时,需要根据具体情况进行选择和配置。

BF16和FP16都是混合精度训练中使用的浮点数表示格式。

  • BF16是一种Brain Floating Point格式,由英特尔提出,可以提供更好的数值稳定性和更高的精度,但需要更多的存储空间。在混合精度训练中,BF16可以作为一种精度更高的替代品,用于一些关键的计算操作,例如梯度累加和权重更新等。使用BF16和FP16一样可以提高模型的训练速度和精度,并减少内存占用。

  • 在 DeepSpeed 中,可以通过在配置文件中设置 "bf16.enabled": true 来启用 BF16 混合精度训练。这将会将一部分操作转换为 BF16 格式,并根据需要动态调整精度缩放因子,从而提高模型的训练速度和精度,并减少内存占用。

6、gradient checkpoint

😄 大型模型在静态和动态方面都很耗资源。首先,它们很难适配 GPU,而且哪怕你把它们放到了设备上,也很难训练,因为批处理大小被迫限制的太小而无法收敛。所以用gradient checkpoint相当于时间换空间。

gradient checkpoint的意思是在反向传播时重新计算深度神经网络的中间值(而通常情况是在前向传播时存储的)。这个策略是用时间(重新计算这些值两次的时间成本)来换空间(提前存储这些值的内存成本)。

  • 具体工作原理是从计算图中省略一些激活值(由前向传播产生,其中这里的”一些“是指可以只省略模型中的部分激活值,折中时间和空间,即前向传播的时候存一个节点释放一个节点,空的那个等需要用的时候再backword的时候重新计算)。这减少了计算图使用的内存,降低了总体内存压力(并允许在处理过程中使用更大的批次大小)。
  • pytorch中对应的函数与实现原理:PyTorch 通过 torch.utils.checkpoint.checkpoint 和 torch.utils.checkpoint.checkpoint_sequential 提供梯度检查点,根据官方文档的 notes,它实现了以下功能,在前向传播时,PyTorch 将保存模型中的每个函数的输入元组。在反向传播过程中,对于每个函数,输入元组和函数的组合以实时的方式重新计算,插入到每个需要它的函数的梯度公式中,然后丢弃(显存中只保存输入数据和函数)。网络计算开销大致相当于每个样本通过模型前向传播开销的两倍。

❓注:神经网络使用的总内存基本上是两个部分的总和,包括静态内存动态内存

  • 静态内存:尽管 PyTorch 模型中内置了一些固定开销,但总的来说几乎完全由模型权重决定。而如今,在生产中使用的现代深度学习模型的总参数在100万到10亿之间。作为参考,一个带 16GB GPU 内存的 NVIDIA T4 的实际限制大约在1-1.5亿个参数之间。

  • 动态内存:在训练模式下,每次通过神经网络的前向传播都为网络中的每个神经元计算一个激活值,这个值随后被存储在所谓的计算图中。必须为批次中的每个单个训练样本存储一个值,因此数量会迅速的累积起来。总成本取决于模型大小和批处理大小,并设置适用于您的GPU内存的最大批处理大小的限制。一开始存储激活的原因是,在反向传播期间计算梯度时需要用到激活。

7、DeepSpeed的推理优化

除了训练优化,deepspeed还可以推理优化。
如下图,红色虚线框是以该单位为优化Kernel,对应的数字是优化的效率倍数。
在这里插入图片描述

8、deepspeed搭配transformers库使用

下面只是简单举例讲讲大致用法,更多使用讲解请参考参考文献[4]

pip install deepspeed

例如我们可以在transformers的trainer中加入args时,在args里将定义好deepsped的config文件路径传入:
ds_config.json的配置根据实际情况定义,如以下一个ZeRO-2的例子:

{
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "zero_allow_untested_optimizer": true,
    "fp16": {
      "enabled": "auto",
      "loss_scale": 0,
      "initial_scale_power": 16,
      "loss_scale_window": 1000,
      "hysteresis": 2,
      "min_loss_scale": 1
    },  
    "zero_optimization": {
      "stage": 2,
      "allgather_partitions": true,
      "allgather_bucket_size": 5e8,
      "reduce_scatter": true,
      "reduce_bucket_size": 5e8,
      "overlap_comm": false,
      "contiguous_gradients": true
    }
  }

train_bash.py如下:

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    TrainingArguments,
    Trainer
)
from peft import (
    LoraConfig,
    TaskType,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict
)

# 设置训练参数
deepspeed_config = "./ds_config.json" # deepspeed配置文件
training_args = TrainingArguments(
	...
    deepspeed=deepspeed_config, # deepspeed配置文件的位置
)



model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16, device_map=device_map)
# LoRA训练配置,转换模型
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=LORA_R, # LoRA中低秩近似的秩
    lora_alpha=LORA_ALPHA, # 见上文中的低秩矩阵缩放超参数
    lora_dropout=LORA_DROPOUT, # LoRA层的dropout
)
# 
model = get_peft_model(model, lora_config)
model.config.use_cache = False
old_state_dict = model.state_dict
model.state_dict = (
    lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
).__get__(model, type(model))
# 打印模型中的可训练参数
model.print_trainable_parameters()

# 模型训练
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)
trainer.train()

单机多卡跑:
【多机多卡跑应该还有加一些指令,具体参考参考文献[2]】

deepspeed --num_gpus 2 --master_port=9901 ./train_bash.py \
    --deepspeed ds_config.json \
    ... # the arguments of train_bash.py 

Reference

[1] deepspeed github: https://github.com/microsoft/DeepSpeed
[2] deepspeed官网:https://www.deepspeed.ai/
[3] deepspeed官方文档:https://deepspeed.readthedocs.io/en/latest/index.html
[4] transformers与deepspeed的集成使用教程:https://huggingface.co/docs/transformers/main_classes/deepspeed
[5] 大模型训练之框架篇
[6] 用ZeRO训练大模型原理解析及参数含义解释

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1272168.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

分布编译和注释

文章目录 分布编译预处理编译汇编链接 注释单行注释多行注释预处理注释 总结 分布编译 上一节使用 gcc main.c就生成了a.exe的可执行文件,提到了将main.c文件生成a.exe实际上执行了以下四步: 预处理编译汇编链接   每一步都有单独的指令,而…

C++学习之继承中修改成员权限细节

看看下面的代码 这是错误的 class A { public:int x 10; }; class B :public A {using A::x;int x 100; };看看函数 class A { public:void fun(){cout << "uuuu" << endl;} }; class B :public A { public:using A::fun;void fun(){cout << …

每天学习一点点之 MySQL TINYINT

我已经不是第一次遇到关于 TINYINT 的问题了。在 MySQL 中&#xff0c;当我们将某个字段设置为 TINYINT&#xff0c;随着业务的扩展&#xff0c;我们可能会发现 TINYINT 的范围无法满足需求。这时需要修改字段属性。但如果表的数据量很大&#xff0c;或者由于分表导致涉及的表数…

AI虚拟数字人——营销宣传领域的新亮点

AI生活节即将到来&#xff0c;邀请消费者共同探索生活小妙趣&#xff0c;为美好生活注入新的想象。AI一词我们过去可能听的比较多&#xff0c;听到最多的可能就是AI虚拟数字人了。这年头&#xff0c;打造一个AI主播、虚拟数字人已经屡见不鲜了&#xff0c;因为AI数字人拥有强大…

数字孪生3D场景开发工具:弥补不足,开拓全新可能

随着数字化时代的来临&#xff0c;越来越多的企业和行业开始探索数字孪生技术的应用。数字孪生是指通过数字技术将现实世界中的物体、场景等复制到虚拟世界中&#xff0c;以实现实时监测、预测和优化。然而&#xff0c;在数字孪生的发展过程中&#xff0c;一些不足也逐渐浮现。…

AndroidStudio - 新版本 Logcat 使用详解

最近这俩天正好有时间给自己做一下减法&#xff0c;忘记是去年还是今年&#xff0c;在升级 AndroidStudio 后使用 Logcat查看日志的方式也发生了一些变化&#xff0c;虽然一直在使用&#xff0c;但每当看到之前还未关闭 Logcat 命令行工具额昂也&#xff0c;就感觉可能还存在知…

基于springboot的社区团购系统设计

摘 要 本课题是根据用户的需要以及网络的优势建立的一个社区团购系统&#xff0c;来满足用户团购的需求。 本社区团购系统应用Java技术&#xff0c;MYSQL数据库存储数据&#xff0c;基于Spring Boot框架开发。在网站的整个开发过程中&#xff0c;首先对系统进行了需求分析&…

手持机|三防智能手机_4寸/5寸/6寸安卓系统三防手机PDA手持终端方案

随着科技的不断发展&#xff0c;三防手持机作为一种多功能设备&#xff0c;正逐渐在各行业得到广泛应用。这款手持机采用高性能处理器&#xff0c;支持高精度北斗定位和工业本安防爆功能&#xff0c;并具备IP67级防水防尘性能和1.5米防跌落能力。因此&#xff0c;它在仓储管理、…

C语言进阶之笔试题详解(2)

前言 这里的内容包括二维数组笔试题和指针笔试题&#xff0c;供给读者对这部分知识进行加深和巩固。 ✨ 猪巴戒&#xff1a;个人主页✨ 所属专栏&#xff1a;《C语言进阶》 &#x1f388;跟着猪巴戒&#xff0c;一起学习C语言&#x1f388; 目录 前言 笔试题 二维数组 题目…

nvm 下载node时候下载不到npm包的解决方法

个人博客链接 公众号-nvm 下载node时候下载不到npm包的解决方法 求关注 可以跳过的背景 最近项目比较有空&#xff0c;所以就可以有时间写一些demo&#xff0c;主要测试下react的一些语法&#xff0c;毕竟自己上次写react已经是22年的7月份了,期间对于react-router等的hook…

差分阻抗90Ω±10%或者其他分别走什么阻抗

差分阻抗90Ω10%或者其他分别走什么阻抗 普通走线阻抗HDMI接口布线要求USB接口布线要求网口接口布线要求LCD 接口布线要求DDR3关键信号处理要点 普通走线阻抗 必须选择 PCB 走线阻抗来匹配使用中的所有逻辑系别的特性阻抗(对于 CMOS 和 TTL&#xff0c;特性阻抗的范围是 80~11…

Java 多线程循环打印

文章目录 一、标志变量 互斥锁二、标志变量 synchronized三、标志变量 互斥锁 条件变量四、原子变量五、信号量 一、标志变量 互斥锁 标志变量用于标识当前应该是哪个线程进行输出&#xff0c;互斥锁用于保证对标志变量的互斥访问。 public class Main {private static …

分享常见msvcp140.dll丢失的解决方法,msvcp140.dll修复的问题

在使用电脑的过程中可能会出现关于msvcp140.dll丢失的问题&#xff0c;通常出现这样的问题都会导致电脑中的程序出现不能正常运行的情况。并且如果不及时将msvcp140.dll修复的话可能还会导致电脑出现其他的问题。这篇文章就将给大家介绍关于msvcp140.dll丢失的解决方法。 一.常…

美国第三季度经济GDP数据亮眼,其增长率上修至近2年最快

KlipC报道&#xff1a;美国商务部公布美国第三季度GDP按年率增长5.2%&#xff0c;较首次预估数据上调了0.3%。也是近2年来最快增速。 KlipC的分析师表示&#xff1a;“相较于第二季度相比&#xff0c;第三季度的时机GDP主要反映了消费者指出和私人库存投资的加速和出口的上升。…

Prosys OPC Client连接OPC DA

Prosys OPC Client连接OPC DA Prosys OPC 客户端将帮助排除 OPC 连接故障并测试 OPC 服务器。 您可以读写数据、浏览服务器以及导出和导入地址空间。 OPC 客户端轻巧、快速且易于使用。 支持 OPC DA 1.0a 和 OPC DA 2.05a 官方地址: https://www.prosysopc.com/products/opc-…

XUbuntu22.04之安装OBS30.0强大录屏工具(一百九十五)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a;多媒…

颜色十六进制代码对照表

白色 #FFFFFF 2 红色 #FF0000 3 绿色 #00FF00 蓝色 #0000FF 5 牡丹红 #FF00FF 6 青色 #00FFFF 黄色 #FFFF00 8 黑色 #000000 9 海蓝 #70DB93 巧克力色 #5C3317 11 蓝紫色 #9F5F9F 12 黄铜色 #B5A642 亮金色 #D9D919 14 棕色 #A67D3D 15 青铜色 #8C7853 2号青铜色 #A67D3D 17 士…

基于SSM搭建系统

原理 SSM集成 SpringSpringMvcMybatis集成 框架集成核心&#xff0c;如果你的项目中&#xff0c;用到了Spring框架&#xff0c;那么其他框架主要就是和Spring集成&#xff1b; 和Spring集成的核心思路&#xff1a; 把当前框架的核心类&#xff0c;交给Spring管理&#xff08…

C++: string的模拟实现

C: string的模拟实现 一.前置说明1.模拟实现string容器的目的2.我们要实现的大致框架 二.默认成员函数1.构造函数2.拷贝构造函数1.传统写法2.现代写法 3.析构函数4.赋值运算符重载1.传统写法2.现代写法 三.遍历和访问1.operator[]运算符重载2.iterator迭代器 四.容量相关函数1.…

ant design vue3 处理 ant-card-head ant-tabs靠左边对齐之has选择器不生效

火狐浏览器是不支持has的。 解决方法&#xff1a;通过position来解决。