大模型训练显存优化推理加速方案

news2025/1/10 20:33:35

当前的深度学习框架大都采用的都是fp32来进行权重参数的存储,比如Python float的类型为双精度浮点数fp64,pytorch Tensor的默认类型为单精度浮点数fp32。随着模型越来越大,加速训练模型的需求就产生了。在深度学习模型中使用fp32主要存在几个问题,第一模型尺寸大,训练的时候对显卡的显存要求高;第二模型训练速度慢;第三模型推理速度慢。其解决方案就是使用低精度计算对模型进行优化。本文主要讲解几种优化显存存储的方法。

1. fp32、fp16、bf16混合精度训练

  • FP32 是单精度浮点数,1位符号位,8位指数,23位表示小数,总共32位
  • BF16 是对FP32单精度浮点数截断数据,即用8bit 表示指数,7bit 表示小数
  • FP16 半精度浮点数,用5bit 表示指数,10bit 表示小数;
    请添加图片描述
    与32位相比,采用BF16/FP16吞吐量可以翻倍,内存需求可以减半。但是这两者精度上差异不一样,BF16 可表示的整数范围更广泛,但是尾数精度较小;FP16 表示整数范围较小,但是尾数精度较高。

1.1 混合精度训练

直接使用半精度进行计算会导致的两个问题的处理:舍入误差(Rounding Error)和溢出错误(Grad Overflow / Underflow)

  • 舍入误差
    float16 的最大舍入误差约为   2 − 10 ~2 ^{-10}  210,比 float32 的最大舍入误差   2 − 23 ~2 ^{-23}  223 要大不少。 对足够小的浮点数执行的任何操作都会将该值四舍五入到零。在反向传播中很多梯度更新值都非常小,但不为零,在反向传播中舍入误差累积可以把这些数字变成0或者nan, 这会导致不准确的梯度更新,影响网络的收敛

  • 溢出错误
    由于 float16 的有效的动态范围(正数部分,负数部分与正数对应)约为 5.96 × 1 0 − 8 ∼ 6.55 × 10 4 5.96\times10^{-8} \sim 6.55\times10{^4} 5.96×1086.55×104,比单精度的 float32 的动态范围 1.4 × 1 0 − 45 ∼ 1.7 × 1 0 38 1.4\times10^{-45} \sim 1.7 \times10^{38} 1.4×10451.7×1038要狭窄很多,精度下降会导致得到的值大于或者小于fp16的有效动态范围,也就是上溢出或者下溢出。在深度学习中,由于激活函数的的梯度往往要比权重梯度小,更易出现下溢出的情况

针对以上两种情况的解决方法是混合精度训练(Mixed Precision)和损失缩放(Loss Scaling)

  • 混合精度训练
    混合精度训练是一种通过在FP16上执行尽可能多的操作来大幅度减少神经网络训练时间的技术,在像线性层或是卷积操作上,FP16运算较快,但像Reduction运算又需要 FP32的动态范围。通过混合精度训练的方式,便可以在部分运算操作使用FP16,另一部分则使用 FP32,混合精度功能会尝试为每个运算使用相匹配的数据类型,在内存中用FP16做储存和乘法从而加速计算,用FP32做累加避免舍入误差。这样在权重更新的时候就不会出现舍入误差导致更新失败,混合精度训练的策略有效地缓解了舍入误差的问题

  • 损失缩放
    尽管使用了混合精度训练,还是会存在无法收敛的情况,原因是激活梯度的值太小,造成了下溢出。损失缩放是指在执行反向传播之前,将损失函数的输出乘以某个标量数(论文建议从8开始)。 乘性增加的损失值产生乘性增加的梯度更新值,提升许多梯度更新值到超过FP16的安全阈值2^-24。 只要确保在应用梯度更新之前撤消缩放,并且不要选择一个太大的缩放以至于产生inf权重更新(上溢出) ,从而导致网络向相反的方向发散

bf16/fp32 混合训练因为两种格式在范围上对齐了,并且 bf16 比 fp16 的范围更大,所以要比 fp16/fp32 混合训练稳定性更高

2. gradient checkpointing

gradient checkpointing(梯度检查点)的工作原理是在反向传播时重新计算深度神经网络的中间值(通常情况是在前向传播时存储的)。这个策略是用时间(重新计算这些值两次的时间成本)来换空间(提前存储这些值的内存成本)

3. Xformers

Xformers 应该是目前社区知名度最高的优化加速方案了,所谓 Xformers 指的是该库将各种transformer 架构的模型囊括其中。注意,该库仅适用于N卡,特点是加速图片生成并降低显存占用,代价是输出图像不稳定,有可能比不开Xformers略差。各种transformer变体可以参考 A Survey of Transformers.

参考

  • 彻底搞懂float16与float32的计算方式
  • pytorch模型训练之fp16、apm、多GPU模型、梯度检查点(gradient checkpointing)显存优化等
  • facebookresearch/xformers

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1032788.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

R语言贝叶斯MCMC:GLM逻辑回归、Rstan线性回归、Metropolis Hastings与Gibbs采样算法实例...

原文链接:http://tecdat.cn/?p23236 在频率学派中,观察样本是随机的,而参数是固定的、未知的数量(点击文末“阅读原文”获取完整代码数据)。 相关视频 什么是频率学派? 概率被解释为一个随机过程的许多观测…

Spark SQL【电商购买数据分析】

Spark 数据分析 (Scala) import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.{SparkConf, SparkContext}import java.io.{File, PrintWriter}object Taobao {case class Info(userId: Lo…

最该考的高含金量计算机证书盘点(文末领资料)

谈到大学规划,不少过来人都会建议萌新们在课余时间多多考证,俗话说的好“证多不压身”,今天我们就来聊一聊,计算机相关专业的大学生,有哪些证书可以考? 首先,不得不提的就是全国计算机二级考试…

web:[ACTF2020 新生赛]Exec

背景知识 命令执行漏洞 linux命令 题目 打开题目,页面显示的是一个ping 尝试一下 查看源代码发现 尝试ping一下百度 由题目名可知这道题关于exec(命令执行),这里需要联想到可以多条命令执行 输入baidu.com;ls 尝试;号是否能够…

从统计语言模型到预训练语言模型---预训练语言模型(Transformer)

预训练模型的概念在计算机视觉领域并不陌生, 通常我们可以在大规模图像数据集上预先训练出一个通用 模型, 之后再迁移到类似的具体任务上去, 这样在减少对图像样本需求的同时, 也加速了模型的开发速度。计 算机视觉领域采用 Image…

互联网医院系统|互联网医院软件功能与广阔应用领域

随着科技的不断进步和人们对健康需求的提高,互联网医院已经成为当今医疗领域的热点话题。作为一种融合了互联网和医疗服务的创新模式,互联网医院带来了许多便利和改变。本文将详细介绍互联网医院的软件功能、应用范围以及未来的发展趋势。 互联网医院通过…

【计算机毕业设计】基于SpringBoot+Vue电影在线订票系统的开发与实现

博主主页:一季春秋博主简介:专注Java技术领域和毕业设计项目实战、Java、微信小程序、安卓等技术开发,远程调试部署、代码讲解、文档指导、ppt制作等技术指导。主要内容:毕业设计(Java项目、小程序等)、简历模板、学习资料、面试题…

机器学习笔记:概念对比——损失函数,代价函数,目标函数

损失函数 Loss Function 通常是针对单个训练样本而言 给定一个模型输出 和一个真实值y ,损失函数是 代价函数 Cost Function 通常是针对整个训练集(或者在使用 mini-batch gradient descent 时一个 mini-batch)的总损失 目标函数 Objec…

备考cisp拿证,收藏这一篇就够了

为什么要考CISP 认证机构:中国信息安全测评中心,是中央批准成立的国家权威信息安全测评机构,CISP是当之无愧的国家级认证,是国内对信息安全从业人员资质能力的最高认可。 持证人数:在信息安全行业,持有CI…

多维数据可视化技术,Radviz可视化原理,向量化的 Radviz(vectorized Radviz,简称 VRV)

目录 多维数据可视化技术 Radviz可视化原理 向量化的 Radviz(vectorized Radviz,简称 VRV) 多维数据可视化技术 多维和高维数据普遍存在于我们的日常生活和科学研究中 . 比如 , 手机就包括品牌、型号、尺寸、重量、 生产日期、屏幕尺寸和电池容量等几十个属性; 又如 , 生物…

Pygame中Sprite类的使用3

在Pygame中Sprite类的使用2_棉猴的博客-CSDN博客中提到了通过派生自pygame.sprite.Sprite类的自定义类Zombie,可以实现一个僵尸的移动。可以通过pygame.sprite.Group类实现对多个Zombie类实例的管理,即可以实现多个僵尸的移动。 1 pygame.sprite.Group类…

一文彻底理解synchronized(通俗易懂的synchronized)

目录 一、什么是synchronized 二、synchronized的四种用法 2.1、修饰一个代码块 2.2、修饰一个方法 2.3、修饰一个静态的方法 2.4、修饰一个类 三、使用案例分析 3.1、修饰一个代码块 3.2、修饰一个方法 3.3、修饰一个静态的方法 3.4、修饰一个类 3.5 经典用法&…

#循循渐进学51单片机#UART串口通信#not.10

1、能够理解UART串口通信的基本原理和通信过程。 1)串行通信的初步认识 并行通信:通信时数据的各个位同时传送,可以实现字节为单位通信,但是通信线占用资源太多,成本高。 串行通信:一次只能发送一位&…

debian终端快捷键设置

为了方便使用图形化debian,快捷调出shell终端是提升工作学习效率的最重要的一步。 1.首先点击右上角,选择设置 2.点击键盘,选择快捷键,并创建自定义快捷键 3.点击添加快捷键 4.根据图中提示创建快捷键 Name: Terminal Command…

软考网络工程师华为配置考点总结

华为交换机配置基础 1.vlan的配置 华为设备中划分VLAN的方式有: 静态的划分:基于接口动态划分:基于MAC地址、基于IP子网、基于协议、基于策略(MAC地址、Ip地址)。 其中基于接口划分VLAN,是最简单&#x…

Arduino程序设计(十一)8×8 共阳极LED点阵显示(74HC595)

88 共阳极LED点阵显示 前言一、74HC595点阵模块1、74HC595介绍2、74HC595工作原理3、1088BS介绍4、74HC595点阵模块 二、点阵显示实验1、点阵显示初探2、点阵显示进阶3、点阵显示高阶3.1 点阵显示汉字(方法1)3.2 点阵显示汉字(方法2&#xff…

不用addEventListener(‘resize‘, this.resize),用新的Web API ResizeObserver监听DIV元素尺寸的变化

响应式设计指的是根据屏幕视口尺寸的不同,对 Web 页面的布局、外观进行调整,以便更加有效地进行信息的展示。我们日常生活中接触的很多应用都遵循响应式的设计。 响应式设计如今也成为 web 应用的基本需求,而现在很多 web 应用都已经组件化&a…

华为云云耀云服务器L实例评测 |云服务器选购

华为云耀云服务器 L 实例是一款轻量级云服务器,开通选择实例即可立刻使用,不需要用户再对服务器进行基础配置。新用户还有专享优惠,2 核心 2G 内存 3M 带宽的服务器只要 89 元/年,可以点击华为云云耀云服务器 L 实例购买地址去购买…

如何在新浪、搜狐、腾讯、网易、人民网等知名媒体网站上投稿

网络通稿成本低、投入小,软文宣传成为了众多企业的宣传选择,一篇优质的稿件更是能带来惊人的效果。越知名的网站传播效果越好,像新浪、搜狐、腾讯、网易、人民网等,那么如果找到这些网站投稿呢,本期盒子分享&#xff0…

[Go疑难杂症]为什么nil不等于nil

现象 在日常开发中,可能一不小心就会掉进 Go 语言的某些陷阱里,而本文要介绍的 nil ≠ nil 问题,便是其中一个,初看起来会让人觉得很诡异,摸不着头脑。 先来看个例子: type CustomizedError struct {Err…