Adam 和 AdamW 优化器详解及其训练显存需求分析:以LLaMA-2 7B为例(中英双语)

news2025/10/24 23:20:04

中文版

Adam 和 AdamW 优化器详解及其显存需求分析

在训练大规模神经网络时,优化器的选择和其在显存中的消耗是至关重要的,特别是像 LLaMA-2 7B 这样的大模型。今天我们将详细分析 Adam 优化器AdamW 优化器,并结合 float32bfloat16 精度的情况,探讨它们在显存消耗方面的表现。

1. Adam 优化器简介

Adam(Adaptive Moment Estimation)是一种常用的优化算法,它结合了动量(Momentum)和RMSProp优化器的优点。Adam通过维护每个参数的一阶矩估计(动量)和二阶矩估计(梯度平方的指数加权平均)来对参数进行更新。Adam优化器通过以下公式更新模型的权重:

1.1 Adam 优化器数学公式

假设我们有一个损失函数 ( L ( θ ) L(\theta) L(θ) ) 和参数向量 ( θ \theta θ ),则 Adam 优化器的更新规则如下:

  1. 计算梯度
    g t = ∇ θ L ( θ t ) g_t = \nabla_\theta L(\theta_t) gt=θL(θt)
    其中 ( g t g_t gt ) 是当前时间步 ( t t t ) 的梯度。

  2. 一阶矩估计 (动量)
    m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t mt=β1mt1+(1β1)gt
    其中 ( β 1 \beta_1 β1 ) 是一阶矩的衰减率,通常取值接近 1(如 0.9)。

  3. 二阶矩估计
    v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 vt=β2vt1+(1β2)gt2
    其中 ( β 2 \beta_2 β2 ) 是二阶矩的衰减率,通常取值接近 1(如 0.999)。

  4. 偏差校正(为了修正初始时刻 ( m_t ) 和 ( v_t ) 的偏差):
    m ^ t = m t 1 − β 1 t , v ^ t = v t 1 − β 2 t \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} m^t=1β1tmt,v^t=1β2tvt

  5. 更新参数
    θ t + 1 = θ t − α v ^ t + ϵ m ^ t \theta_{t+1} = \theta_t - \frac{\alpha}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t θt+1=θtv^t +ϵαm^t
    其中 ( α \alpha α ) 是学习率,( ϵ \epsilon ϵ ) 是防止除零的常数。

1.2 Adam 优化器的内存消耗

Adam 优化器的内存消耗比传统的 SGD 优化器更高,因为它需要为每个参数维护 一阶矩二阶矩,即两个额外的变量。假设 LLaMA-2 7B 模型有 70 亿个参数,每个参数需要存储 两个额外的矩,因此优化器的内存需求是模型参数内存的两倍。

  • 参数内存:假设使用 float32 精度,每个参数占用 4 字节,那么 7B 模型的参数内存为:
    参数内存 = 7 × 1 0 9 × 4   字节 = 28   GB \text{参数内存} = 7 \times 10^9 \times 4 \, \text{字节} = 28 \, \text{GB} 参数内存=7×109×4字节=28GB
  • 优化器内存:由于需要为每个参数维护两个矩,因此优化器的内存需求是模型参数的两倍:
    优化器内存 = 2 × 28   GB = 56   GB \text{优化器内存} = 2 \times 28 \, \text{GB} = 56 \, \text{GB} 优化器内存=2×28GB=56GB

总的来说,Adam 优化器在 LLaMA-2 7B 模型的训练中,总显存消耗大约为 模型参数显存 + 梯度显存 + 优化器显存,即 28 GB + 28 GB + 56 GB = 112 GB,对于 24GB 显卡显然是无法容纳的。

下面的python代码测试7B大模型本身的参数量:以float32计算。进位采用1024,计算得出:7B大模型的参数量为26.08 GB;当进位采用1000时,计算得出28.00 GB。为什么尝试1000,是因为在其他博文中看到28GB这个数字,自己测试一下,发现他们是在以1000为进位的时候测试得出的。参考文章:https://cuiyuhao.com/posts/c87c0f5d/

# 定义参数
num_parameters = 7 * 10**9  # 70 亿个参数
bytes_per_param = 4  # 每个参数占用 4 字节(32 位浮动数)

# 计算显存需求(单位:字节)
memory_in_bytes = num_parameters * bytes_per_param

# 将字节转换为 GB
memory_in_GB = memory_in_bytes / (1024 ** 3)  # 转换为 GB, 可调为1000

print(f"模型需要的显存为: {memory_in_GB:.2f} GB")

以bf16为例,由于它是float32的一半,所以它的参数量为 26.08GB / 2 = 13.04GB (以1024为进位),当以1000进位的时候,28GB / 2 = 14GB

2. AdamW 优化器简介

AdamW 是 Adam 优化器的一种变体,它对权重衰减(weight decay)做了改进。Adam 在更新参数时直接将权重衰减项添加到梯度中,而 AdamW 通过将衰减项从一阶矩和二阶矩的更新中分离出来,使得优化过程更加稳定。AdamW 的更新公式与 Adam 类似,但衰减项被单独处理:

2.1 AdamW 优化器数学公式
  1. 权重衰减(weight decay)项
    θ t + 1 = θ t − α v ^ t + ϵ m ^ t − λ θ t \theta_{t+1} = \theta_t - \frac{\alpha}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t - \lambda \theta_t θt+1=θtv^t +ϵαm^tλθt
    其中 ( λ \lambda λ ) 是权重衰减系数。
2.2 AdamW 优化器的内存消耗

与 Adam 优化器一样,AdamW 也需要为每个参数维护 一阶矩二阶矩。因此,它的内存消耗与 Adam 优化器相同,差异主要体现在梯度更新时的计算过程,而不是内存需求。

所以,AdamW 优化器对显存的占用与 Adam 优化器是相同的,依然是 模型参数的两倍

  • 参数内存:28 GB
  • 优化器内存:56 GB

因此,AdamW 优化器在 LLaMA-2 7B 模型上的显存消耗与 Adam 优化器一致。

3. 在 float32 和 bfloat16 下的显存需求

3.1 使用 float32 精度
  • 模型参数内存:28 GB
  • 梯度内存:28 GB
  • 优化器内存:56 GB
  • 总显存需求:28 GB + 28 GB + 56 GB = 112 GB
3.2 使用 bfloat16 精度

使用 bfloat16 精度时,每个参数、梯度和优化器状态的内存需求将减半。假设 LLaMA-2 7B 使用 bfloat16 精度,则:

  • 模型参数内存:14 GB
  • 梯度内存:14 GB
  • 优化器内存:28 GB
  • 总显存需求:14 GB + 14 GB + 28 GB = 56 GB

因此,使用 bfloat16 精度时,显存需求比使用 float32 精度时减少了约一半。

4. 总结

  • Adam 优化器AdamW 优化器 都需要为每个参数维护一阶矩和二阶矩,因此它们的内存消耗是 模型参数内存的两倍
  • float32 精度:在使用 float32 精度时,LLaMA-2 7B 模型的总显存需求大约为 112 GB
  • bfloat16 精度:在使用 bfloat16 精度时,LLaMA-2 7B 模型的总显存需求为 56 GB

通过选择合适的优化器和精度,尤其是在资源有限的情况下,可以大大减少显存消耗,确保大模型的训练可以在较小的 GPU 上完成。

英文版

Detailed Analysis of Adam and AdamW Optimizers and Their Memory Consumption with float32 and bfloat16 Precision

When training large-scale neural networks, especially models like LLaMA-2 7B, the choice of optimizer and the associated memory consumption are crucial factors. In this post, we’ll delve into the Adam optimizer and AdamW optimizer, explain their memory consumption in both float32 and bfloat16 precision, and provide a detailed example using the LLaMA-2 7B model.

1. Adam Optimizer Overview

Adam (Adaptive Moment Estimation) is a widely used optimizer that combines the benefits of both momentum and RMSProp optimizers. It maintains first-order momentum (moving averages of gradients) and second-order momentum (moving averages of squared gradients) to adaptively adjust the learning rate for each parameter. The update rule for Adam is as follows:

1.1 Adam Optimizer Mathematical Formulas

Assume we have a loss function ( L ( θ ) L(\theta) L(θ) ) and parameter vector ( θ \theta θ ). The Adam optimizer’s update rule can be written as:

  1. Compute the gradient:
    g t = ∇ θ L ( θ t ) g_t = \nabla_\theta L(\theta_t) gt=θL(θt)
    where ( g t g_t gt ) is the gradient at time step ( t t t ).

  2. First moment estimate (momentum):
    m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t mt=β1mt1+(1β1)gt
    where ( β 1 \beta_1 β1 ) is the decay rate for the first moment, typically close to 1 (e.g., 0.9).

  3. Second moment estimate:
    v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 vt=β2vt1+(1β2)gt2
    where ( β 2 \beta_2 β2 ) is the decay rate for the second moment, typically close to 1 (e.g., 0.999).

  4. Bias correction (to adjust for the initial bias):
    m ^ t = m t 1 − β 1 t , v ^ t = v t 1 − β 2 t \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} m^t=1β1tmt,v^t=1β2tvt

  5. Parameter update:
    θ t + 1 = θ t − α v ^ t + ϵ m ^ t \theta_{t+1} = \theta_t - \frac{\alpha}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t θt+1=θtv^t +ϵαm^t
    where ( α \alpha α ) is the learning rate and ( ϵ \epsilon ϵ ) is a small constant to prevent division by zero.

1.2 Memory Consumption of Adam Optimizer

The Adam optimizer requires extra memory compared to traditional stochastic gradient descent (SGD) because it stores both first moment (m) and second moment ( v v v) for each parameter. So, for each parameter, Adam needs to store two additional variables.

For example, with the LLaMA-2 7B model having 7 billion parameters, the memory required for the optimizer would be twice the size of the model parameters.

  • Model Parameters Memory: Assuming float32 precision (4 bytes per parameter):
    Model Parameters Memory = 7 × 1 0 9 × 4   bytes = 28   GB \text{Model Parameters Memory} = 7 \times 10^9 \times 4 \, \text{bytes} = 28 \, \text{GB} Model Parameters Memory=7×109×4bytes=28GB
  • Optimizer Memory: Since Adam maintains two variables for each parameter, the memory needed for the optimizer is:
    Optimizer Memory = 2 × 28   GB = 56   GB \text{Optimizer Memory} = 2 \times 28 \, \text{GB} = 56 \, \text{GB} Optimizer Memory=2×28GB=56GB

In total, the memory required for training with Adam would be the sum of the model parameters, gradients, and optimizer states:

Total Memory = 28   GB (model) + 28   GB (gradients) + 56   GB (optimizer) = 112   GB \text{Total Memory} = 28 \, \text{GB (model)} + 28 \, \text{GB (gradients)} + 56 \, \text{GB (optimizer)} = 112 \, \text{GB} Total Memory=28GB (model)+28GB (gradients)+56GB (optimizer)=112GB

This total memory requirement makes it clear that for a 24GB GPU, training this model would not fit without further optimizations.

2. AdamW Optimizer Overview

AdamW is a variant of the Adam optimizer that decouples the weight decay from the gradient update. In the standard Adam optimizer, weight decay is incorporated into the gradient update, while in AdamW, the weight decay is applied separately to the parameters.

2.1 AdamW Optimizer Mathematical Formulas

The update rule for AdamW is similar to that of Adam but includes a decoupled weight decay term:

  1. Weight decay term:
    θ t + 1 = θ t − α v ^ t + ϵ m ^ t − λ θ t \theta_{t+1} = \theta_t - \frac{\alpha}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t - \lambda \theta_t θt+1=θtv^t +ϵαm^tλθt
    where ( λ \lambda λ ) is the weight decay coefficient.
2.2 Memory Consumption of AdamW Optimizer

Since AdamW maintains the same first and second moment estimates as Adam, its memory consumption is identical to Adam’s. Thus, the memory requirements for AdamW in terms of the model parameters and optimizer states are the same as those for Adam:

  • Model Parameters Memory: 28 GB
  • Optimizer Memory: 56 GB

Thus, AdamW’s memory consumption in LLaMA-2 7B training is also:

Total Memory = 28   GB (model) + 28   GB (gradients) + 56   GB (optimizer) = 112   GB \text{Total Memory} = 28 \, \text{GB (model)} + 28 \, \text{GB (gradients)} + 56 \, \text{GB (optimizer)} = 112 \, \text{GB} Total Memory=28GB (model)+28GB (gradients)+56GB (optimizer)=112GB

3. Memory Consumption with float32 and bfloat16 Precision

3.1 Using float32 Precision

In float32 precision, each parameter, gradient, and optimizer state requires 4 bytes of memory. Therefore, for LLaMA-2 7B, the memory consumption is:

  • Model Parameters: 28 GB
  • Gradients: 28 GB
  • Optimizer States: 56 GB
  • Total Memory: 28 GB + 28 GB + 56 GB = 112 GB
3.2 Using bfloat16 Precision

Using bfloat16 precision (16-bit floating point), each parameter, gradient, and optimizer state requires only 2 bytes of memory. For LLaMA-2 7B, the memory consumption with bfloat16 would be:

  • Model Parameters: ( 7 × 1 0 9 × 2   bytes = 14   GB 7 \times 10^9 \times 2 \, \text{bytes} = 14 \, \text{GB} 7×109×2bytes=14GB )
  • Gradients: 14 GB
  • Optimizer States: ( 2 × 14   GB = 28   GB 2 \times 14 \, \text{GB} = 28 \, \text{GB} 2×14GB=28GB )
  • Total Memory: 14 GB + 14 GB + 28 GB = 56 GB

By using bfloat16 precision, the memory consumption is reduced by half compared to float32, which is a significant advantage for training large models on GPUs with limited memory.

4. Summary

  • Adam Optimizer and AdamW Optimizer both require additional memory for maintaining the first and second moment estimates, leading to twice the memory requirement of model parameters for the optimizer.
  • float32 Precision: With float32 precision, the memory requirement for training LLaMA-2 7B with Adam or AdamW is approximately 112 GB.
  • bfloat16 Precision: With bfloat16 precision, the memory requirement is reduced to 56 GB.

By choosing the appropriate optimizer and precision, you can significantly reduce memory usage and ensure the training of large models on GPUs with limited memory, which is essential for scaling up deep learning experiments.

后记

2024年11月29日18点38分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2252523.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Qt入门6——Qt窗口

目录 1. QMenuBar 菜单栏 2. QToolBar 工具栏 3. QStatusBar 状态栏 4. QDockWidget 浮动窗口 5. QDialog 对话框 5.1 Qt内置对话框 1. QMessageBox 消息对话框 2. QColorDialog 颜色对话框 3. QFileDialog 文件对话框 4. QFontDialog 字体对话框 5. QInputDialo…

A058-基于Spring Boot的餐饮管理系统的设计与实现

🙊作者简介:在校研究生,拥有计算机专业的研究生开发团队,分享技术代码帮助学生学习,独立完成自己的网站项目。 代码可以查看项目链接获取⬇️,记得注明来意哦~🌹 赠送计算机毕业设计600个选题ex…

【VUE3】npm : 无法加载文件 D:\Program\nodejs\node_global\npm.ps1,因为在此系统上禁止运行脚本。

npm : 无法加载文件 D:\Program\nodejs\npm.ps1。未对文件 D:\Program\nodejs\npm.ps1 进行数字签名。无法在当前系统上运行该脚本。有关运行脚本和设置执行策略的详细信息,请参阅 https:/go.microsoft.com/fwlink/?LinkID135170 中的 about_ Execution_Policies。…

《JavaScript高级程序设计》读书笔记 17

感谢点赞、关注和收藏! 这一篇讲内存相关,主要是垃圾回收机制。 垃圾回收 JavaScript 是使用垃圾回收的语言,也就是说执行环境负责在代码执行时管理内存。在 C 和 C等语言中,内存如何管理是开发者来决定的。JavaScript通过自动内…

c语言——数组名该如何理解呢?

一般情况下,数组名表示首元素地址,以下2种除外: ①、sizeof(数组名) 表示整个数组 ※只有数组名的情况 sizeof(数组名i) 就不能表示整个数组 ②、&数组名 表示整个数组,取的是整个数…

IDL学习笔记(一)数据类型、基础运算、控制语句

近期,需要用到modis数据批量预处理,于是重新学习idl,感谢郭师兄推荐,以及张洋老师的详细教导。特以此为学习笔记,望学有所成。 IDL学习笔记(一) 数据类型数据类型创建数组类型转换函数代码输出print往文件…

数据结构——排序第三幕(深究快排(非递归实现)、快排的优化、内省排序,排序总结)超详细!!!!

文章目录 前言一、非递归实现快排二、快排的优化版本三、内省排序四、排序算法复杂度以及稳定性的分析总结 前言 继上一篇博客基于递归的方式学习了快速排序和归并排序 今天我们来深究快速排序,使用栈的数据结构非递归实现快排,优化快排(三路…

【语音识别】Zipformer

Zipformer 是kaldi 团队于2024研发的序列建模模型。相比较于 Conformer、Squeezeformer、E-Branchformer等主流 ASR 模型,Zipformer 具有效果更好、计算更快、更省内存等优点。并在 LibriSpeech、Aishell-1 和 WenetSpeech 等常用数据集上取得了当时最好的 ASR 结果…

Python酷库之旅-第三方库Pandas(251)

目录 一、用法精讲 1186、pandas.tseries.offsets.BusinessMonthEnd.is_year_start方法 1186-1、语法 1186-2、参数 1186-3、功能 1186-4、返回值 1186-5、说明 1186-6、用法 1186-6-1、数据准备 1186-6-2、代码示例 1186-6-3、结果输出 1187、pandas.tseries.offs…

【06】Selenium+Python 定位动态ID

有时候页面元素的ID是动态变化的,这种变化的ID,无法通过By.ID来定位,也无法通过BY.XPATH的绝对路径来定位 比如此li标签的id,中间的数字部分就是变化的,刷新页面后,id中间部分的数字就会变化 刷新页面前ID:…

leetcode 之 二分查找(java)(2)

文章目录 74、搜索二维矩阵33、搜素旋转排序数组 74、搜索二维矩阵 题目描述: 给你一个满足下述两条属性的 m x n 整数矩阵: 每行中的整数从左到右按非严格递增顺序排列。每行的第一个整数大于前一行的最后一个整数。 给你一个整数 target &#xff…

16asm - 汇编介绍 和 debug使用

文章目录 前言硬件运行机制微机系统硬件组成计算机系统组成8086cpu组织架构dosbox安装配置debug debug使用R命令D命令E命令U命令T命令A命令标志寄存器 总结 前言 各位师傅大家好,我是qmx_07,今天给大家讲解 十六位汇编 和 debug调试器的使用 硬件运行…

UE4_材质节点_有关距离的_流体模拟

一、材质节点介绍: 特别注意:距离场需要独立显卡支持。 1、什么是距离场? 想象一下空间中只有两个实体, 一个球,一个圆柱. 空间由无数个点组成, 取其中任何一个点, 比如,它跟球面的最近距离是3, 跟圆柱面的最近距离是2, 那么这个点的值就…

win10系统安装docker-desktop

1、开启Hyper-v ———————————————— Hyper-V 是微软提供的一种虚拟化技术,它允许你在同一台物理计算机上运行多个独立的操作系统实例。这种技术主要用于开发、测试、以及服务器虚拟化等领域。 —————————————————————— &#…

【小白学机器学习39】如何用numpy生成总体,生成样本samples

目录 1 目的:研究 样本和总体之间的关系 2 先生成1个理论总体 2.0 下面是关于这一步的完整代码 2.1 一般情况下,我们先生成一个符合正态分布的总体 2.1.1 设置总体 ,或者说生成一个总体 2.2 为什么一定要是一个符合正态分布的总体&…

“指标管理系统”是什么?企业如何搭建指标管理系统?

在当今数字化时代,数据已成为企业决策的重要依据。然而,海量数据中如何筛选出关键指标,并对其进行有效管理,成为了众多企业面临的难题。为此,指标管理系统应运而生,它旨在帮助企业规范化定义、统一管理和高…

网际协议(IP)与其三大配套协议(ARP、ICMP、IGMP)

网际协议(Internet Protocol,IP),又称互联网协议。是OSI中的网络层通信协议,用于跨网络边界分组交换。它的路由功能实现了互联互通,并从本质上建立了互联网。网际协议IP是 TCP/IP 体系中两个最主要的协议之…

运维工作常用Shell脚本(Commonly Used Shell Scripts for Operation and Maintenance Work)

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 本人主要分享计算机核心技…

机器学习8-决策树CART原理与GBDT原理

Gini 系数 和Gini 系数增益 CART决策树算法流程举例 该篇文章对于CART的算法举例讲解,一看就懂。 决策树(Decision Tree)—CART算法 同时也可以观看视频 分类树 GBDT原理举例 可以看如下示例可以理解GBDT的计算原理 用通俗易懂的方式讲解: GBDT算法及…

oracle中删除指定前缀的表

近期接手做的项目,发觉数据库中有许多多余的表。究其原因,应该是同事贪图方便,将过去做过的项目复制粘贴,然后修修改改。包括数据库也是克隆过来的,然后又没有删除本项目多余的表,结果经过几个轮回&#xf…