Pytorch自动混合精度的计算:torch.cuda.amp.autocast

news2025/1/12 6:15:12

1 autocast介绍

1.1 什么是AMP?

默认情况下,大多数深度学习框架都采用32位浮点算法进行训练。2017年,NVIDIA研究了一种用于混合精度训练的方法,该方法在训练网络时将单精度(FP32)与半精度(FP16)结合在一起,并使用相同的超参数实现了与FP32几乎相同的精度。

FP16也即半精度是一种计算机使用的二进制浮点数据类型,使用2字节存储。而FLOAT就是FP32。

1.2 autocast作用

torch.cuda.amp.autocast是PyTorch中一种混合精度的技术(仅在GPU上训练时可使用),可在保持数值精度的情况下提高训练速度和减少显存占用。

    def __init__(self, enabled : bool = True, dtype : torch.dtype = torch.float16, cache_enabled : bool = True):

它是一个自动类型转换器,可以根据输入数据的类型自动选择合适的精度进行计算,从而使得计算速度更快,同时也能够节省显存的使用。使用autocast可以避免在模型训练过程中手动进行类型转换,减少了代码实现的复杂性。

在深度学习中,通常会使用浮点数进行计算,但是浮点数需要占用更多的显存,而低精度数值可以在减少精度的同时,减少缓存使用量。因此,对于正向传播和反向传播中的大多数计算,可以使用低精度型的数值,提高内存使用效率,进而提高模型的训练速度。

1.3 autocast原理

autocast的要做的事情,简单来说就是:在进入算子计算之前,选择性的对输入进行cast操作。为了做到这点,在PyTorch1.9版本的架构上,可以分解为如下两步:

  • 在PyTorch算子调用栈上某一层插入处理函数
  • 在处理函数中对算子的输入进行必要操作

核心代码:autocast_mode.cpp

2 autocast优缺点

PyTorch中的autocast功能是一个性能优化工具,它可以自动调整某些操作的数据类型以提高效率。具体来说,它允许自动将数据类型从32位浮点(float32)转换为16位浮点(float16),这通常在使用深度学习模型进行训练时使用。

2.1 autocast优点

  • 提高性能:使用16位浮点数(half precision)进行计算可以在支持的硬件上显著提高性能,特别是在最新的GPU上。

  • 减少内存占用:16位浮点数占用的内存比32位少,这意味着在相同的内存限制下可以训练更大的模型或使用更大的批量大小。

  • 自动管理autocast能够自动管理何时使用16位浮点数,何时使用32位浮点数,这降低了手动管理数据类型的复杂性。

  • 保持精度:尽管使用了较低的精度,但autocast通常能够维持足够的数值精度,对最终模型的准确度影响不大。

2.2 autocast缺点

  • 硬件要求:并非所有的GPU都支持16位浮点数的高效运算。在不支持或优化不足的硬件上,使用autocast可能不会带来性能提升。

  • 精度问题:虽然在大多数情况下精度损失不显著,但在某些应用中,尤其是涉及到小数值或非常大的数值范围时,降低精度可能会导致问题。

  • 调试复杂性:由于autocast在模型的不同部分自动切换数据类型,这可能会在调试时增加额外的复杂性。

  • 算法限制:某些特定的算法或操作可能不适合在16位精度下运行,或者在半精度下的实现可能还不成熟。

  • 兼容性问题:某些PyTorch的特性或第三方库可能还不完全支持半精度运算。

在实际应用中,是否使用autocast通常取决于特定任务的需求、所使用的硬件以及对性能和精度的权衡。通常,对于大多数现代深度学习应用,特别是在使用最新的GPU时,使用autocast可以带来显著的性能优势。

3 使用示例

3.1 autocast混合精度计算

with autocast(): 语句块内的代码会自动进行混合精度计算,也就是根据输入数据的类型自动选择合适的精度进行计算,并且这里使用了GPU进行加速。使用示例如下:

# 导入相关库
import torch
from torch.cuda.amp import autocast

# 定义一个模型
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = torch.nn.Linear(10, 1)

    def forward(self, x):
        with autocast():
            x = self.linear(x)
        return x

# 初始化数据和模型
x = torch.randn(1, 10).cuda()
model = MyModel().cuda()

# 进行前向传播
with autocast():
    output = model(x)

# 计算损失
loss = output.sum()

# 反向传播
loss.backward()

3.2 autocast与GradScaler一起使用

因为autocast会损失部分精度,从而导致梯度消失的问题,并且经过中间层时可能计算得到inf导致最终loss出现nan。所以我们通常将GradScaler与autocast配合使用来对梯度值进行一些放缩,来缓解上述的一些问题。

from torch.cuda.amp import autocast, GradScaler

dataloader = ...
model = Model.cuda(0)
optimizer = ...
scheduler = ...
scaler = GradScaler()  # 新建GradScale对象,用于放缩
for epoch_idx in range(epochs):
    for batch_idx, (dataset) in enumerate(dataloader):
        optimizer.zero_grad()
        dataset = dataset.cuda(0)
        with autocast():  # 自动混精度
            logits = model(dataset)
            loss = ...
        scaler.scale(loss).backward()  # scaler实现的反向误差传播
        scaler.step(optimizer)  # 优化器中的值也需要放缩
        scaler.update()  # 更新scaler
  scheduler.step()
...

4 可能出现的问题

使用autocast技术进行混精度训练时loss经常会出现'nan',有以下三种可能原因:

  • 精度损失,有效位数减少,导致输出时数据末位的值被省去,最终出现nan的现象。该情况可以使用GradScaler(上文所示)来解决。
  • 损失函数中使用了log等形式的函数,或是变量出现在了分母中,并且训练时,该数值变得非常小时,混精度可能会让该值更接近0或是等于0,导致了数学上的log(0)或是x/0的情况出现,从而出现'inf'或'nan'的问题。这种时候需要针对该问题设置一个确定值。例如:当log(x)出现-inf的时候,我们直接将输出中该位置的-inf设置为-100,即可解决这一问题。
  • 模型内部存在的问题,比如模型过深,本身梯度回传时值已经非常小。这种问题难以解决。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1208005.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

致刘家窑中医院龚洪海医生:患者的感谢与敬意

你们好!我曾经是咱们这的一名患者,我叫李刚,今年45岁,不知道你们还有印象吗?我曾去过一些医院进行就诊,但都没有得到恰当的治疗,症状一直没有消失。得了这个病之后对我的生活以及工作打击都十分的大。经朋友介绍说刘家…

Realistic fault detection of li-ion battery via dynamical deep learning

昇科能源、清华大学欧阳明高院士团队等的最新研究成果《动态深度学习实现锂离子电池异常检测》,用已经处理的整车充电段数据,分析车辆当前或近期是否存在故障。 思想步骤: 用正常电池的充电片段数据构造训练集,用如下的方式构造…

SwiftUI - 界面布局知识点

前言 SwiftUI采用的布局方式是和Flutter一样是弹性布局,而不是iOS之前的坐标轴的方式布局,不用准确的设置出位置大小,只需要设置当前视图大小及视图间排布的方式。灵活性增强,布局操作简便,SwiftUI与Flutter布局原理一…

基于SSM的供电所档案管理系统

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

交换机基础知识之安全配置

交换机在网络基础设施中扮演着重要角色,它促进了设备之间数据包的流动。正因此,采取适当的安全措施来保护网络免受未经授权的访问和潜在攻击至关重要。本文将全面解读交换机基础安全配置知识,并提供实践方案,以保证安全的网络环境…

Istio学习笔记-体验istio

参考Istio 入门(三):体验 Istio、微服务部署、可观测性 - 痴者工良 - 博客园 (cnblogs.com) 在本章中,我们将会学习到如何部署一套微服务、如何使用 Istio 暴露服务到集群外,并且如何使用可观测性组件监测流量和系统指标。 本章教程示例使用…

Vue CLI脚手架安装、搭建、配置 和 CLI项目分析

目录 一、CLI快速入门 1. 官方介绍 : 2.安装Vue CLI : 3.搭建Vue CLI : 4.IDEA配置Vue CLI : 二、Vue CLI项目分析 1.结构分析 : 1.1 config 1.2 node_modules 1.3 src 1.4 static 2.流程分析 : 2.1 main.js 2.2 router/index.js 2.3 components/HelloWorld.vue 2.4 A…

Vue3源码reactive和readonly对象嵌套转换,及实现shallowReadonly

前言 官方文档中对reactive的描述: 响应式转换是“深层”的:它会影响到所有嵌套的属性。一个响应式对象也将深层地解包任何 ref 属性,同时保持响应性。 官方文档中对readonly的描述: 只读代理是深层的:对任何嵌套属性的访问都将是…

城市内涝监测仪的作用有哪些?

城市内涝近几年愈发频繁,它的出现不仅仅会导致财产损失,还可能危及公共安全。所以对路面积水进行实时监测刻不容缓。内涝积水监测仪的早期警报系统,有助于提高城市的紧急响应能力。政府远程监控城市路面水位,实现精准的系统化管理…

SPL机制与使用,组件化技术核心点打法

什么是SPI SPI ,全称为 Service Provider Interface,是一种服务发现机制。它通过在ClassPath路径下的META-INF/services文件夹查找文件,自动加载文件里所定义的类。 SPI 的本质是将接口实现类的全限定名配置在文件中,并由服务加…

MySQL索引下推:提升数据库性能的关键优化技术

文章目录 前言索引下推原理MySQL 基础架构传统查询过程ICP 查询过程 使用场景限制参数配置索引下推开启状态查询索引下推开启和关闭 一些问题只有联合索引才能使用索引下推?下面的查询为什么不走索引下推 参考 前言 大家好,我是 Lorin ,今天…

教对象写代码

之前对象工作中需要获取地图上的一些数据, 手工找寻复制 费时费力, 逢此契机, 准备使用代码尽可能简化机械重复操作, 力图一劳永逸. 首选简洁易入门的Python. 下文就是对流程的总结, 及简述每步的意义. 并不Hack,重在感受编程的用途和基本工具的使用. 以百度地图为例,需求如下:…

大模型时代的机器人研究

机器人研究的一个长期目标是开发能够在物理上不同的环境中执行无数任务的“多面手”机器人。对语言和视觉领域而言,大量的原始数据可以训练这些模型,而且有虚拟应用程序可用于应用这些模型。与上述两个领域不同,机器人技术由于被锚定在物理世…

hive更改表结构的时候报错

现象 FAILED: ParseException line 1:48 cannot recognize input near ADD COLUMN compete_company_id in alter table statement 23/11/14 17:59:27 ERROR org.apache.hadoop.hive.ql.Driver: FAILED: ParseException line 1:48 cannot recognize input near ADD COLUMN compe…

身份证照片怎么弄成200k以内?超级好用!

一些网站为了限制大的文件上传,提出了一些大小限制的要求,那么身份证如何弄成200k呢?下面介绍三种方法。 方法一: 使用嗨格式压缩大师 1、在电脑上打开安装好的软件,在首界面中点击“图片压缩”。 2、进入后上传需要…

【MongoDB】索引 – 通配符索引

一、准备工作 这里准备一些数据 db.books.drop();db.books.insert({_id: 1, name: "Java", alias: "java 入门", description: "入门图书" }); db.books.insert({_id: 2, name: "C", alias: "c", description: "C 入…

OpenCV颜色识别及应用

OpenCV是一个开源计算机视觉库,提供了丰富的图像处理和计算机视觉算法,其中包括颜色识别。本文首先介绍了OpenCV库,然后着重描述了颜色识别的基本原理和方法,包括颜色空间的转换、阈值处理、颜色检测等技术。接下来详细探讨了Open…

【ccf-csp题解】第11次csp认证-第三题-Json查询超详细讲解

此题思路来源于acwing ccfcsp认证辅导课 题目描述 思路分析 此题的难点在于对输入的内容进行解析,题目中说除了保证字符串内容不会有空格存在之外,其它的任意地方都可能出现空格,甚至在某些地方还会出现空行,这样的话&#xff0…

spring-cloud-alibaba-nacos

spring cloud nacos 安装和启动nacos # 解压nacos安装包 # tar -zvxf nacos-server-1.4.1.tar.gz# nacos默认是以集群的模式启动,此处先用单机模式 # cd /usr/local/mysoft/nacos/bin # sh startup.sh -m standalone# nacos 日志 # tail -f /usr/local/mysoft/na…

reactive和effect,依赖收集触发依赖

通过上一篇文章已经初始化项目,集成了ts和jest。本篇实现Vue3中响应式模块里的reactive方法。 前置知识要求 如果你熟练掌握Map, Set, Proxy, Reflect,可直接跳过这部分。 Map Map是一种用于存储键值对的集合,并且能够记住键的原始插入顺…