深度学习:混合精度训练

news2024/11/24 14:02:18

深度学习:混合精度训练

  • 前言
  • 混合精度训练
    • 核心技术
      • 权重备份
      • 损失缩放
      • 梯度裁剪
      • 动态调整学习率
    • 优势与弊端
    • 代码示例
  • 参考文献

前言

浮点数据类型主要分为双精度Double(FP64)、单精度Float(FP32)和半精度Half(FP16)。FP64浮点数据采用8个字节共64位,来进行的编码存储的一种数据类型;FP32浮点数据采用4个字节共32位来表示;FP16浮点数据则采用2个字节共16位来表示。
在这里插入图片描述

默认情况下,大多数深度学习框架(比如Pytorch)都采用32位浮点算法进行训练,混合精度训练可以在神经网络训练过程中,针对不同的层,采用不同的数据精度(比如半精度16位)进行计算,从而实现降低显存和加快速度的目的。

混合精度训练

在此主要介绍了混合精度训练的核心技术、优缺点和代码示例。

核心技术

混合精度训练过程中,主要使用了权重备份和损失缩放两大方法,此外还可以引入梯度裁剪和动态调整学习率等相关技术,提高训练稳定性,从而发挥混合精度训练的优势,并尽可能避免混合精度训练的弊端。

在这里插入图片描述

权重备份

如果直接全部采用FP16的精度进行模型训练,由于梯度幅值本身非常小,参数更新的时候乘上学习率就更小了,容易导致 NAN 或者参数更新失败(无法更新)的问题,故模型参数更新需要采用 FP32 格式上操作,因此需要维护一份 FP32 的模型参数副本,并利用FP16的梯度更新FP32模型的参数。值得注意的是,这里备份的模型副本增加的主要是静态内存。只要动态内存的值基本都是使用FP16来进行存储,则最终模型与整网络使用FP32进行训练相比起来, 内存占用也基本能够减半。

损失缩放

使用损失缩放的原因是 FP16 的梯度表示范围比较窄,如果不做处理,大量非零梯度会遇到溢出问题,那么即使后续是采用 FP32 参数更新也是没用的。故需要设置一个缩放系数(loss scale),将前向传播得到的Loss进行放大,放大到 16 精度可表示范围,但是需要注意在反向传播后需要除以缩放系数,将权重缩小后更新模型的参数。

动态损失缩放:上面提到的损失缩放都是使用一个默认值对损失值进行缩放,为了充分利用FP16的动态范围,可以更好地缓解舍入误差,尽量使用比较大的放大倍数。总结动态损失缩放算法,就是每当梯度溢出时候减少损失缩放规模,并且间歇性地尝试增加损失规模,从而实现在不引起溢出的情况下使用最高损失缩放因子,更好地恢复精度。

梯度裁剪

由于半精度浮点数表示的梯度较小,容易出现数值溢出或数值过小的问题。为了解决这个问题,采用梯度裁剪的方法,限制梯度的范围,防止梯度消失或爆炸。

动态调整学习率

随着训练的进行,动态地调整学习率以适应使用半精度浮点数时可能出现的数值不稳定性。这有助于提高训练的稳定性和收敛速度。

优势与弊端

混合精度训练的核心思想是将神经网络中的参数和梯度使用更低位数的浮点数表示,通常是16位半精度浮点数。混合精度的优势在于主要在于减小显存占用和加快训练速度方面。

  • 减少显存占用:FP16的位宽是FP32的一半,因此权重等参数所占用的内存也是原来的一半,从而可以使用更大的模型或更多的数据进行训练。
  • 加快通讯效率:对于分布式训练,特别是在大模型训练的过程中,通讯的开销往往会增大训练时间,使用低精度的数据,由于较小的位宽可以提高通讯效率,从而加快模型训练。
  • 计算效率更高:使用低精度的数据,执行运算性能也更高,从而加快模型训练,特别是在支持混合精度的硬件上(如NVIDIA的Volta架构及以后的GPU)。

弊端在于:

  • 数据溢出:FP16数据类型的有效数据范围比FP32数据类型的有效数据范围小,使用FP16替换FP32就会出现上溢(Overflow)或下溢(Underflow),从而容易出现数值不稳定性的问题,需要采用一些技术手段来处理。
  • 精度损失:FP16和FP32的最小间隔(精确度)不同,从FP32转换到FP16就会出现强制舍入,从而带来一定的精度损失。
  • 训练不稳定:使用混合精度训练容易出现NAN和参数无法更新的问题,需要精心设计超参数,以提高训练的稳定性。
  • 硬件依赖: 混合精度训练的效果受到硬件支持的限制,只有支持半精度浮点数运算的硬件才能发挥其优势。

代码示例

为了演示混合精度训练的流程,下面是Pytorch官方代码示例,供参考:

# Creates model and optimizer in default precision 
model = Net().cuda() 
optimizer = optim.SGD(model.parameters(), ...) 
 
# Creates a GradScaler once at the beginning of training. 
scaler = GradScaler() 
 
for epoch in epochs: 
 for input, target in data: 
        optimizer.zero_grad() 
 
 # Runs the forward pass with autocasting. 
 with autocast(): 
            output = model(input) 
            loss = loss_fn(output, target) 
 
 # Scales loss.  Calls backward() on scaled loss to create scaled gradients. 
 # Backward passes under autocast are not recommended. 
 # Backward ops run in the same dtype autocast chose for corresponding forward ops. 
        scaler.scale(loss).backward() 
 
 # scaler.step() first unscales the gradients of the optimizer's assigned params. 
 # If these gradients do not contain infs or NaNs, optimizer.step() is then called, 
 # otherwise, optimizer.step() is skipped. 
        scaler.step(optimizer) 
 
 # Updates the scale for next iteration. 
        scaler.update() 

参考文献

  1. https://blog.51cto.com/u_16099268/6696537
  2. https://zhuanlan.zhihu.com/p/375224982

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1316792.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何在页面中加入百度地图

官方文档&#xff1a;jspopularGL | 百度地图API SDK (baidu.com) 添加一下代码就可以实现 <!DOCTYPE html> <html> <head><meta name"viewport" content"initial-scale1.0, user-scalableno"/><meta http-equiv"Conten…

基于Springboot的高校教学评价系统的设计与实现(源码+调试)

项目描述 临近学期结束&#xff0c;还是毕业设计&#xff0c;你还在做java程序网络编程&#xff0c;期末作业&#xff0c;老师的作业要求觉得大了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等。今天给大家介绍一篇基于Springboot的高校教…

idea2023解决右键没有Servlet的问题

复制Servlet Class.java中的文件。 回到文件&#xff0c;然后点击小加号 然后输入刚刚复制的东西&#xff1a; 3. 此时右键有servlet。 4. 然后他让你输入下面两个框&#xff1a; JAVAEE TYPE中输入Servlet Class Name 表示你要创建的Servlet类的名称是什么。自己起名字。然后…

EIS(防抖):meshflow算法

视频防抖的应用 对视频防抖的需求在许多领域都有。 这在消费者和专业摄像中是极其重要的。因此&#xff0c;存在许多不同的机械、光学和算法解决方案。即使在静态图像拍摄中&#xff0c;防抖技术也可以帮助拍摄长时间曝光的手持照片。 在内窥镜和结肠镜等医疗诊断应用中&…

用代码写uml并在线生成uml图

可以用PlantUml写uml,并在线生成uml图。 startuml start:登录系统; if (用户名和密码正确?) then (yes):进入系统首页;:展示主菜单; else (no):显示登录错误;stop endif:选择模块; partition "课程信息" {:查看课程列表;:查看课程详情; } partition "课程签到…

uniapp的uni-im 即时通信使用教程【用户与商家对话、聊天 / 最新 / 最全 / 带源码 / 教程】

目录 使用场景用户图片商家图片 官方文档官方文档地址插件地址 项目创建uniCloud开发环境申请开发环境申请完后 概括开始使用步骤1App.vue 步骤2找到软件登录图片找到软件登录接口登录源码如下 步骤3找到软件注册图片注册源码如下 步骤4找到index.vue首页图片 index.vue源码如下…

[robot_state_publisher-3] Error: Error document empty.

出现这个问题&#xff0c;我这里遇到的是&#xff1a;指定的urdf文件路径无效&#xff0c;而产生这个的根本原因是没有在CMakelists.txt中添加如下代码&#xff1a; install( DIRECTORY urdf DESTINATION share/${PROJECT_NAME} )把urdf文件夹添加到指定的share/${PROJEC…

第15章 《乐趣》Page305~311, 代码精简以后,讨论一下引用含义的问题

将Page305~311的代码精简了一下&#xff0c;讨论一下引用含义的问题&#xff0c;精简之后的代码如下&#xff1a; #include <iostream> #include <SDL2/SDL.h>using namespace std;namespace sdl2 {char const* last_error() {return SDL_GetError(); }struct Ini…

贪心算法:买卖股票的最佳时机II 跳跃游戏 跳跃游戏II

122.买卖股票的最佳时机II 思路&#xff1a; 想要获得利润&#xff0c;至少要以两天为一个交易单元&#xff0c;因为两天才会有股价差。因此可以将最终利润进行分解&#xff0c;如prices[3] - prices[0] (prices[3] - prices[2]) (prices[2] - prices[1]) (prices[1] - pr…

07-抽象工厂

意图 提供一个创建一系列相关或相互依赖对象的接口&#xff0c;而无需指定它们具体的类。 适用性 在以下的情况可以选择使用抽象工厂模式&#xff1a; 一个系统要独立于它的产品的创建、组合和表示。一个系统要由多个产品系列中的一个来配置。要强调一系列相关的产品对象的…

Elasticsearch优化-04

Elasticsearch优化 1、优化-硬件选择 Elasticsearch 的基础是 Lucene&#xff0c;所有的索引和文档数据是存储在本地的磁盘中&#xff0c;具体的路径可在 ES 的配置文件…/config/elasticsearch.yml中配置&#xff0c;如下&#xff1a; # #Path to directory where to store …

Pytorch:Tensorboard简要学习

目录 一、TensorBoard简介二、TensorBoard的安装与启动Tensorboard的安装Tensorboard的启动 三、TensorBoard的简单使用3.1 SummaryWriter()3.2 add_scalar()和add_scalars()3.3 add_histogram()3.4 模型指标监控 四、总结参考博客 一、TensorBoard简介 TensorBoard 是Google开…

C#中的封装、继承和多态

1.引言 在面向对象的编程中&#xff0c;封装、继承和多态是三个重要的概念。它们是C#语言中的基本特性&#xff0c;用于设计和实现具有高内聚和低耦合的代码。本文将详细介绍C#中的封装、继承和多态的相关知识。 目录 1.引言2. 封装2.1 类2.2 访问修饰符 3. 继承4. 多态4.1 虚方…

36个校招网络原理面试题

1.如何理解 URI&#xff1f; URI, 全称为(Uniform Resource Identifier), 也就是统一资源标识符&#xff0c;它的作用很简单&#xff0c;就是区分互联网上不同的资源。但是&#xff0c;它并不是我们常说的网址, 网址指的是URL, 实际上URI包含了URN和URL两个部分&#xff0c;由…

如何从众多知识付费平台中正确选择属于自己的平台(明理信息科技知识付费平台)

在当今的知识付费市场中&#xff0c;用户面临的选择越来越多&#xff0c;如何从众多知识付费平台中正确选择属于自己的平台呢&#xff1f;下面&#xff0c;我们将为您介绍明理信息科技知识付费平台相比同行的优势&#xff0c;帮助您做出明智的选择。 一、创新的技术架构&#…

全套SpringBoot讲义01

hello&#xff0c;我是小索奇&#xff0c;全套SpringBoot教程~一起来学习叭 文章目录 SpringBoot文档更新日志前言课程内容说明课程前置知识说明 SpringBoot基础篇JC-1.快速上手SpringBootJC-1-1.SpringBoot入门程序制作&#xff08;一&#xff09;JC-1-2.SpringBoot入门程序制…

Qt之QNetworkAccessManager 从本地和内存中上传数据到Http服务器

简述 接连做了好几个服务器的项目&#xff0c;例如文件传输用的Ftp和对象存储服务器(Object Storage Service)&#xff0c;简单的信息传输用的WebServer&#xff0c;之前也有用过HttpServer不过都和WebServer一样简单的调用接口提交数据并没有上传过文件&#xff0c;正好趁这次…

人工智能导论习题集(2)

第三章&#xff1a;确定性推理 题1题2题3题4题5题6题7 题1 题2 题3 题4 题5 题6 题7

设计模式之结构型设计模式(二):工厂模式 抽象工厂模式 建造者模式

工厂模式 Factory 1、什么是工厂模式 工厂模式旨在提供一种统一的接口来创建对象&#xff0c;而将具体的对象实例化的过程延迟到子类或者具体实现中。有助于降低客户端代码与被创建对象之间的耦合度&#xff0c;提高代码的灵活性和可维护性。 定义了一个创建对象的接口&…

spring 笔记九 Spring AOP

Spring 的 AOP 简介 什么是AOP AOP 为Aspect Oriented Programming 的缩写&#xff0c;意思为面向切面编程&#xff0c;是通过预编译方式和运行期动态代理实现程序功能的统一维护的一种技术。 AOP 是OOP 的延续&#xff0c;是软件开发中的一个热点&#xff0c;也是Spring框架…