pytorch笔记:自动混合精度(AMP)

news2024/10/5 15:26:18

1 理论部分

1.1 FP16 VS FP32

  • FP32具有八个指数位和23个小数位,而FP16具有五个指数位和十个小数位
  • Tensor内核支持混合精度数学,即输入为半精度(FP16),输出为全精度(FP32)

1.1.1 使用FP16的优缺点

  • 优点
    • FP16需要较少的内存,因此更易于训练和部署大型神经网络,同时还减少了数据移动(同时可以使用更大的batch)
    • 数学运算的运行速度大大降低了
      • NVIDIA提供的Volta GPU的确切数量是:FP16中为125 TFlops,而FP32中为15.7 TFlops(加速8倍)
  • 缺点:
    • 从FP32转到FP16时,必然会降低精度
      • 但有的时候,这个精度的降低可以忽略不计
      • FP16实际上可以很好地表示大多数权重和渐变。
      • ——>拥有存储和使用FP32所需的所有这些额外位只是浪费。
    • 溢出错误
      • 由于FP16的动态范围比FP32位的狭窄很多,因此,在计算过程中很容易出现上溢出和下溢出
      • 溢出之后就会出现"NaN"的问题

1.2 解决上述FP16的问题

1.2.1 混合精度训练

  • 用FP16做储存和乘法,而用FP32做累加避免舍入误差
  • ——>混合精度训练的策略有效地缓解了舍入误差的问题

1.2.2 损失放大(Loss scaling)

  • 即使使用了混合精度训练,还是存在无法收敛的情况
    • 原因是激活梯度的值太小,造成了溢出。
  • ——>通过使用torch.cuda.amp.GradScaler,通过放大loss的值来防止梯度的下溢出
    • 只在BP时传递梯度信息使用,真正更新权重时还是要把放大的梯度再unscale回去
      • 反向传播前,将损失变化手动增大2^k倍

        • 因此反向传播时得到的中间变量(激活函数梯度)不会溢出;

      • 反向传播后,将权重梯度缩小2^k倍,恢复正常值。

2 torch.cuda.amp

  • AMP(自动混合精度)的关键词有两个:
    • 自动
      • Tensor的dtype类型会自动变化,框架按需自动调整tensor的dtype,当然有些地方还需手动干预
    • 混合精度
      • 采用不止一种精度的Tensor,torch.FloatTensor和torch.HalfTensor

2.1 Pytorch中不同类型的tensor

类型名称位数
torch.DoubleTensor64bit
torch.LongTensor64bit
torch.FloatTensor(默认)32bit
torch.IntTensor32bit
torch.HalfTensor16bit
torch.BFloat16Tensor16bit
torch.ShortTensor16bit
torch.ByteTensor(无符号)8bit
torch.CharTensor8bit
torch.BoolTensorBoolean

2.2 在AMP上下文中,被自动转化为半精度浮点型的参数:

__matmul__
addbmm
addmm
addmv
addr
baddbmm
bmm
chain_matmul
conv1d
conv2d
conv3d
conv_transpose1d
conv_transpose2d
conv_transpose3d
linear
matmul
mm
mv
prelu

2.3 autocast

from torch.cuda.amp import autocast as autocast


model = Net().cuda()
#首先初始化一个网络模型Net(),并使用.cuda()方法将模型移至GPU上以利用GPU加速
#Net中的参数默认是torch.FloatTensor

optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)
    '''
    自动混合精度环境

    包含了前向过程(模型的输出)和loss的计算

    把支持参数对应tensor的dtype转换为半精度浮点型,从而在不损失训练精度的情况下加快运算

    进入autocast的上下文时,tensor可以是任何类型
        不需要在model或者input上手工调用.half() ,框架会自动做
    '''

    
    loss.backward()
    optimizer.step()
    # 反向传播在autocast上下文之外

 2.4 GradScaler

在2.3的基础上增加,反向传播时增加梯度,以防止下溢出

from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScaler


model = Net().cuda()
#首先初始化一个网络模型Net(),并使用.cuda()方法将模型移至GPU上以利用GPU加速
#Net中的参数默认是torch.FloatTensor

optimizer = optim.SGD(model.parameters(), ...)


scaler = GradScaler()
# 在训练最开始之前实例化一个GradScaler对象

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()


        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
        '''
        自动混合精度环境

        包含了前向过程(模型的输出)和loss的计算

        把支持参数对应tensor的dtype转换为半精度浮点型,从而在不损失训练精度的情况下加快运算

        进入autocast的上下文时,tensor可以是任何类型
            不需要在model或者input上手工调用.half() ,框架会自动做
        '''

        
        scaler.scale(loss).backward()
        # Scales loss. 为了梯度放大,防止下溢出
        # 代替原来的loss.backward()
        
        scaler.step(optimizer)
        '''
        scaler.step() 首先把梯度的值unscale回来.
        
        如果梯度的值不是 infs 或者 NaNs, 那么调用optimizer.step()来更新权重,
            否则,忽略step调用,从而保证权重不更新(不被破坏)
        '''

        
        scaler.update()
        '''
        准备着,看是否要增大scaler

        '''
  •  scaler的大小在每次迭代中动态的估计
    • 为了尽可能的减少梯度underflow,scaler应该更大
    • 但是如果太大的话,半精度浮点型的tensor又容易overflow(变成inf或者NaN)。
  • ——>动态估计的原理就是在不出现inf或者NaN梯度值的情况下尽可能的增大scaler的值

3 一些tips

  • 为了保证计算不溢出,首先保证人工设定的常数不溢出。如epsilon,INF等
  • Dimension最好是8的倍数:维度是8的倍数,性能最好
  • 涉及sum的操作要小心,容易溢出
    • 比如softmax操作,建议用官方API,并定义成layer写在模型初始化里
  • 如果遇到以下的报错:
    • RuntimeError: expected scalar type float but found c10::Half
    • 需要手动在tensor上调用.float()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1794061.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【AR开发-开源框架】使用Sceneform-EQR快速开发AR应用,当前接入了AREngine、ORB-SLAM,可快速地适配不同的安卓设备

Sceneform-EQR Sceneform 概览 Sceneform是一个3D框架,具有基于物理的渲染器,针对移动设备进行了优化,使您可以轻松构建增强现实应用程序,而无需OpenGL。 借助 Sceneform,您可以轻松地在 AR 应用和非 AR 应用中渲染…

【C语言】10.C语言指针(4)

文章目录 1.回调函数是什么?2.qsort 使⽤举例2.1 使⽤qsort函数排序整型数据2.2 使⽤qsort排序结构数据 3.qsort函数的模拟实现 1.回调函数是什么? 回调函数就是一个通过函数指针调用的函数。 如果你把函数的指针(地址)作为参数…

如何为律师制作专业的商务名片?含电子名片二维码

律师关注细节,律师名片也不例外。它们不仅仅是身份的象征,更是律师专业形象的代表,传递专业知识和信任。今天就来和我们一起来看看制作律师商务名片的注意事项,以及如何制作商务名片上的电子名片二维码? 一、名片的主…

kafka学习笔记 @by_TWJ

目录 1. 消息重复消费怎么解决1.1. 确保相同的消息不会被重复发送(消费幂等性)1.2. 消息去重1.3. 消息重试机制1.4. kafka怎么保证消息的顺序性1.4.1. 利用分区的特征:1.4.2. 解决办法:1.4.3. 分区分配策略1.4.3.1. RangeAssignor (每组(Topi…

Windows下安装和配置Redis

目录 1、下载redis压缩包 2、解压redis文件 3、启动redis临时服务 4、打开Redis客户端进行连接 5、使用一些基础操作来测试 5.1、输入ping命令来检测redis服务器与redis客户端的连通性 5.2、使用set和get命令测试redis数据库进行数据存储和获取 5.3、在命令中通过shut…

vs中C++项目中没有QT(.pro)文件怎么生成翻译ts文件

目录 使用 CMake 生成翻译文件 1.创建 CMakeLists.txt 文件 2.添加翻译生成规则 3.运行 CMake 4.生成翻译文件 使用命令行工具生成翻译文件 1.运行 lupdate 2.编辑 .ts 文件 3.运行 lrelease 网络上说的情况都是一个qt程序在VS中打开,拥有.pro文件的情况&a…

解决远程服务器连接报错

最近使用服务器进行数据库连接和使用的时候出现了一个报错: Error response from daemon: Conflict. The container name “/mysql” is already in use by container “1bd3733123219372ea7c9377913da661bb621156d518b0306df93cdcceabb8c4”. You have to remove …

什么是WEB应用防火墙,云服务器有带吗

伴随着Web软件、SaaS应用及API交互的使用普及,网络安全威胁也随之而来。网络攻击者一直不断改进他们的方法,使用自动爬虫程序、僵尸网络和漏洞扫描器来发动多媒介攻击,试图瘫痪应用。而WAF防火墙可以帮助保护Web应用程序免受这些攻击&#xf…

10.爬虫---XPath插件安装并解析爬取数据

10.XPath插件安装并解析爬取数据 1.XPath简介2.XPath helper安装3.XPath 常用规则4.实例引入4.1 //匹配所有节点4.2 / 或 // 匹配子节点或子孙节点4.3 ..或 parent::匹配父节点4.4 匹配属性4.5 text()文本获取4.6 属性获取4.7 属性多值匹配 1.XPath简介 XPath是一门在XML文档中…

让你工作效率飞起的五款软件

🌟 No.1:亿可达 作为一款自动化工具,亿可达被誉为国内版的免费Zaiper。它允许用户无需编程知识即可将不同软件连接起来,构建自动化的工作流程。其界面设计清新且直观,描述语言简洁易懂,使得用户可以轻松上…

一个可以自动生成随机区组试验的excel VBA小程序

在作物品种区域试验时,通常会采用随机区组试验设计,特制作了一个可以自动生成随机区组试验的小程序。excel参数界面如下: 参数含义如下: 1、生成新表的名称:程序将新建表格,用于生成随机区组试验。若此处为…

探索通信技术的未来:2024中国通信技术和智能装备产业博览会

探索通信技术的未来:2024通信技术产业专场 随着信息技术的飞速发展,通信技术已成为现代社会不可或缺的基础设施。2024年10月11日至13日,青岛将迎来一场通信技术的盛会——2024中国军民两用智能装备与通信技术产业博览会。本次博览会不仅将展…

调试线上资源文件失效问题

之前的老项目,突然报红,为了定位问题,使用注入和文件替换的方式进行问题定位! 1.使用注入 但是刷新后就没有了,不是特别好用! const jqScript document.createElement(script); jqScript.src https://…

视频推广短信:新时代的营销利器(视频短信XML接口示例)

随着移动互联网的普及,短信已经不再是简单的文字信息传递工具,而是逐渐演变为一种有效的推广手段。特别是当视频与短信结合时,它所带来的营销效率更是令人瞩目。 一、视频推广短信的特点 1.直观性:与传统的文字短信相比&#xf…

MySQL存储引擎的区别和比较

MyISAM存储引擎 MyISAM基于ISAM存储引擎,并对其进行扩展。它是在Web、数据仓储和其他应用环境下最常使用的存储引擎之一。MyISAM拥有较高的插入、查询速度,但不支持事务。 MyISAM主要特性有: 1、大文件(达到63位文件长度&#x…

图片像素缩放,支持个性化自定义与精准比例调整,让图像处理更轻松便捷!

图片已经成为我们生活中不可或缺的一部分。无论是社交媒体的分享,还是工作文档的编辑,图片都扮演着至关重要的角色。然而,你是否曾经遇到过这样的问题:一张高清大图在上传时却受限于平台的大小要求,或者一张小图需要放…

PPT视频如何16倍速或者加速播放

有两种方式,一种是修改PPT本身,这种方式非常繁琐,不太推荐,还有一种就是修改视频本身,直接让视频是16倍速的视频即可。 如何让视频16倍速,我建议人生苦短,我用Python,几行代码&…

Kafka生产者消息异步发送并返回发送信息api编写教程

1.引入依赖&#xff08;pox.xml文件&#xff09; <dependencies> <dependency> <groupId>org.apache.kafka</groupId> <artifactId>kafka-clients</artifactId> <version>3.6.2</version> </dependency> </depende…

0基础学习区块链技术——推演猜想

大纲 去中心预防篡改付出代价方便存储 在《0基础学习区块链技术——入门》一文中&#xff0c;我们结合可视化工具&#xff0c;直观地感受了下区块的结构&#xff0c;以及链式的前后关系。 本文我们将抛弃之前的知识&#xff0c;从0开始思考和推演&#xff0c;区块链技术可能是如…

【必会面试题】快照读的原理

目录 前言知识点一个例子 前言 快照读&#xff08;Snapshot Read&#xff09;是数据库管理系统中一种特殊的读取机制&#xff0c;主要用于实现多版本并发控制&#xff08;MVCC, Multi-Version Concurrency Control&#xff09;策略&#xff0c;尤其是在MySQL的InnoDB存储引擎中…