大模型LLM训练显存消耗详解

news2025/1/20 3:52:13

参考论文:ZeRO: Memory Optimizations Toward Training Trillion Parameter Models

大模型的显存消耗一直都是面试常见的问题,这次我就彻彻底底的根据论文ZeRO中的调研和分析做一次分析

显存消耗的两个部分:Model States(跟模型的参数量和优化器相关)Residual Memory Consumption(跟训练时的batchsize,序列长度有关)

接下来,我就从这两个部分详细分析:


Model States

在这里插入图片描述

一个模型在显存消耗上,分为三个部分

  1. Optimizer States
  2. Gradients States
  3. Parameters States

更加具体的说,对于一个模型参数(Parameters)我们需要维护维护三个不同方面的参数
我们假设:模型的参数量大小为ModelSize

Parameters States

故名思义就是模型本身的权重参数,对于一个使用Float32存储的参数,我们需要32/8=4byte进行存储。

Gradients States

记录参数的梯度,对于一个使用Float32存储的参数,我们同样需要一个相同大小的梯度(4byte)保存它的梯度。

Optimizer States

对于最常用的Adam优化器以及其变体,对于一个使用Float32存储的参数需要维护两个额外的参数momentumvariance,也就是需要2*4=8byte进行保存


总的来说,对于Float32保存的模型来说,我们显存消耗是16(4+4+8)* ModelSize byte

但是对于半精度保存的模型(Float16),每个参数Parameters StatesGradients States的显存消耗都是2byte。在训练时,我们仍然需要保存其Float32的Parameters States用以加速运算,同时Adam优化器的两个参数momentumvariance同样也是Float32形式保存的,每个参数消耗的即为4+4+4=12 byte。所以半精度保存的模型,计算时的显存消耗仍然为16(2+2+12)* ModelSize byte


Residual Memory Consumption

剩下的显存消耗跟我们训练时的配置有关
主要有三个部分

  1. Activations
  2. Temporary buffers
  3. Memory Fragmentation

Activations

对于一个transformer based的模型来说,Activations的显存消耗和如下公式是成比例的:

number of transformer layers × hidden dimensions × sequence length × batch size

对于GPT2来说,这个比例大约为12

Temporary buffers 和 Memory Fragmentation

这两个参数不容易具体量化,Temporary buffers是多卡训练过程中为了提升梯度计算的效率,通常会执行一些类似于gradient all-reducegradient norm computation等操作,把数据集合到一个临时的缓存区中,这个临时区也会占用相当数量的显存

Memory Fragmentation,内存碎片的产生会导致内存空间的利用效率低下,即使有空余空间但是不足以分配给一个新的内存请求。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1455273.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【机构vip教程】Android SDK手机测试环境搭建

Android SDK 的安装和环境变量的配置 前置条件:需已安装 jdk1.8及 以上版本 1、下载Android SDK,解压后即可(全英文路径);下载地址:http://tools.android-studio.org/index.php/sdk 2、新建一个环境变量&…

linux内核模块module_put()函数详解--03

对应module_put()函数详细用法分享。 第一:函数简介 //函数原型 void module_put(struct module * module) //函数功能 该函数功能是将一个特定模块module的引用计数减一 这样当一个模块的引用计数不为0而不能被内核卸载的 时候,可以调用该函数一次或多…

这样用TVS管

对于工程师来说,浪涌保护不仅仅是选择合适的电源板或者拔下几根电缆,主要涉及在 PCB 布局中放置瞬态保护组件并应用明确的接地策略。 TVS 二极管是用于保护PCB布局中组件的常用组件,这些组件放置在数据线上,一旦电路中接收到ESD脉…

出生医学证明档案管理系统

出生医学证明档案管理系统是一种用于管理和维护出生医学证明档案的软件系统。该系统可以帮助医院、出生登记机构和其他相关部门有效地管理和存储出生医学证明档案,提高工作效率和数据安全性。 专久智能出生医学证明档案管理系统的核心功能包括: 1. 档案管…

Linux超详细笔记

文章目录 Linux学习笔记操作系统Linux初识Linux的诞生Linux内核Linux发行版 虚拟机VMware安装远程连接Linux系统FinalShellFinalShell连接Linux WSL配置UbuntuLinux常用命令1.入门2.ls命令cd命令3.pwd命令4.相对路径和绝对路径5.mkdir命令6.文件操作命令(1&#xff…

【机构vip教程】Charles(1):Charles的介绍及安装

Charles Charles 是在 Mac (Charles是跨平台的 )下常用的网络封包截取工具,在做移动开发、测试时,我们为了调试与服务器端的网络通讯协议,常常需要截取网络封包来分析。Charles是一个HTTP代理服务器,HTTP监视器,反转代…

从零开始的 dbt 入门教程 (dbt core 开发进阶篇)

引 在上一篇文章中,我们花了专门的篇幅介绍了 dbt 更多实用的命令,那么我们继续按照之前的约定来聊 dbt 中你可能会遇到的疑惑以及有用的概念,如果你是 dbt 初学者,我相信如下知识点一定会对你有极大的帮助: 了解 db…

【Linux篇】Linux项目自动化构建工具-make/Makefile

💛不要有太大压力🧡 💛生活不是选择而是热爱🧡 💚文章目录💚 什么是make/Makefilemakefile认识makefilemakefile的编写伪目标 Linux下多程序编译 什么是make/Makefile 在实际工作中,一个项目可能…

在职阿里6年,一个28岁女软件测试工程师的心声

简单的先说一下,坐标杭州,16届本科毕业,算上年前在阿里巴巴的面试,一共有面试了有6家公司(因为不想请假,因此只是每个晚上去其他公司面试,所以面试的公司比较少) 其中成功的有4家&am…

芯片的分类

目录 通用处理器数字信号处理器专用处理器 通用处理器 我们常听说的中央处理器CPU就是一种典型的通用处理器(GPP)。这种处理器多使用片上系统(SoC)的设计理念,在处理器上集成各种功能模块,每一种功能都是用…

Python爬虫详解(一看就懂)

爬虫 爬虫是什么 爬虫简单的来说就是用程序获取网络上数据这个过程的一种名称。 爬虫的原理 如果要获取网络上数据,我们要给爬虫一个网址(程序中通常叫URL),爬虫发送一个HTTP请求给目标网页的服务器,服务器返回数据…

thinkphp+vue企业产品展示网站f7enu

本文首先介绍了企业产品展示网站管理技术的发展背景与发展现状,然后遵循软件常规开发流程,首先针对系统选取适用的语言和开发平台,根据需求分析制定模块并设计数据库结构,再根据系统总体功能模块的设计绘制系统的功能模块图&#…

三防平板电脑丨亿道工业三防平板丨三防平板定制丨机场维修应用

随着全球航空交通的增长和机场运营的扩展,机场维护的重要性日益凸显。为确保机场设施的安全和顺畅运行,采取适当的措施来加强机场维护至关重要。其中,三防平板是一种有效的工具,它可以提供持久耐用的表面保护,使机场维…

基于Java+Jsp的超市积分管理系统

🍅文末获取源码联系🍅 👇🏻 精彩项目推荐订阅👇🏻 不然下次找不到哟 感兴趣的可以先收藏起来,还有大家在毕设选题,项目以及论文编写等相关问题都可以给我留言咨询,希望帮…

300分钟吃透分布式缓存-01讲:业务数据访问性能太低怎么办?

这节课主要讲缓存的基本思想、缓存的优点、缓存的代价三个部分。 缓存的定义 先来看下缓存的定义。 & 缓存最初的含义,是指用于加速 CPU 数据交换的 RAM,即随机存取存储器,通常这种存储器使用更昂贵但快速的静态 RAM(SRAM&…

lv15 input子系统框架、外设驱动开发 5

一、input子系统基本框架 在我们日常的Linux系统中,存在大量的输入设备,例如按键、鼠标、键盘、触摸屏、摇杆等,他们本身就是字符设备,linux内核将这些字符设备的共同性抽象出来,简化驱动开发建立了一个input子系统。 …

Springboot+vue的物流管理系统(有报告)。Javaee项目,springboot vue前后端分离项目

演示视频: Springbootvue的物流管理系统(有报告)。Javaee项目,springboot vue前后端分离项目 项目介绍: 本文设计了一个基于Springbootvue的前后端分离的物流管理系统,采用M(model)…

【COMP337 LEC 5-6】

LEC 5 Perceptron &#xff1a; Binary Classification Algorithm 8 感应器是 单个神经元的模型 突触连接的强度取决于接受外部刺激的反应 X input W weights a x1*w1x2*w2....... > / < threshold Bias MaxIter is a hyperparameter 超参数 which has to be chosen…

Android系统app开发

Android系统app开发 系统app阔以使用很多系统源码中隐藏的api 首先先编译出jar包 整编源码后&#xff0c;在这个目录下&#xff0c;这个就是包含系统源码隐藏api的jar包 系统app的标志 拷贝jar文件到app模块下 在编译之前把这个jar添加到classpath路径里面去&#xff0c;这样…

【机构vip教程】Selenium(2):selenium IDE工具

Selenium IDE工具&#xff1a; 该工具是一个用于构建脚本的初级工具&#xff0c;其实是FireFox的一个插件&#xff0c;拥有一个易于使用的界面。它拥有记录功能&#xff0c;能够记录用户执行的操作&#xff0c;并可以导出为可重复使用的脚本。如果没有编程经验&#xff0c;也可…