PyTorch 源码解读之 torch.cuda.amp: 自动混合精度详解

news2025/1/11 5:54:35

PyTorch 源码解读之 torch.cuda.amp: 自动混合精度详解

Nvidia 在 Volta 架构中引入 Tensor Core 单元,来支持 FP32 和 FP16 混合精度计算。也在 2018 年提出一个 PyTorch 拓展 apex,来支持模型参数自动混合精度训练。自动混合精度(Automatic Mixed Precision, AMP)训练,是在训练一个数值精度 FP32 的模型,一部分算子的操作时,数值精度为 FP16,其余算子的操作精度是 FP32,而具体哪些算子用 FP16,哪些用 FP32,不需要用户关心,amp 自动给它们都安排好了。这样在不改变模型、不降低模型训练精度的前提下,可以缩短训练时间,降低存储需求,因而能支持更多的 batch size、更大模型和尺寸更大的输入进行训练。PyTorch 从 1.6 以后(在此之前 OpenMMLab 已经支持混合精度训练,即 Fp16OptimizerHook),开始原生支持 amp,即torch.cuda.amp module。2020 ECCV,英伟达官方做了一个 tutorial 推广 amp。从官方各种文档网页 claim 的结果来看,amp 在分类、检测、图像生成、3D CNNs、LSTM,以及 NLP 中机器翻译、语义识别等应用中,都在没有降低模型训练精度都前提下,加速了模型的训练速度。

本文是对torch.cuda.amp工作机制,和 module 中接口使用方法介绍,以及在算法角度上对 amp 不掉点原因进行分析,最后补充一点对 amp 存储消耗的解释。

1. 混合精度训练机制

torch.cuda.amp 给用户提供了较为方便的混合精度训练机制,“方便”体现在两个方面:

用户不需要手动对模型参数 dtype 转换,amp 会自动为算子选择合适的数值精度
对于反向传播的时候,FP16 的梯度数值溢出的问题,amp 提供了梯度 scaling 操作,而且在优化器更新参数前,会自动对梯度 unscaling,所以,对用于模型优化的超参数不会有任何影响
以上两点,分别是通过使用amp.autocast和amp.GradScaler来实现的。

autocast可以作为 Python 上下文管理器和装饰器来使用,用来指定脚本中某个区域、或者某些函数,按照自动混合精度来运行。混合精度在操作的时候,是先将 FP32 的模型的参数拷贝一份,拷贝的参数转换成 FP16,而 amp 规定了的 FP16 的算子(例如卷积、全连接),对 FP16 的数值进行操作;FP32 的算子(例如涉及 reduction 的算子,BatchNormalize,softmax…),输入和输出是 FP16,计算的精度是 FP32。在反向传播时,依然是混合精度计算,得到数值精度为 FP16 的梯度。最后,由于 GPU 中的 Tensor Core 天然支持 FP16 乘积的结果与 FP32 的累加(Tensor Core math),优化器的操作是利用 FP16 的梯度对 FP32 的参数进行更新。

在这里插入图片描述
对于 FP16 不可避免的问题就是:表示的范围较窄,如下图所示,大量非 0 梯度会遇到溢出问题。解决办法是:对梯度乘一个 2**N 的系数,称为 scale factor,把梯度 shift 到 FP16 的表示范围。

在这里插入图片描述
GradScaler的工作就是在反向传播前给 loss 乘一个 scale factor,所以之后反向传播得到的梯度都乘了相同的 scale factor。并且为了不影响学习率,在梯度更新前将梯度unscale。总结amp的基本训练流程:

  1. 维护一个 FP32 数值精度模型的副本;

  2. 在每个iteration。

    1. 拷贝并且转换成 FP16 模型;
    2. 前向传播(FP16 的模型参数);
    3. loss 乘 scale factor s;
    4. 反向传播(FP16 的模型参数和参数梯度);
    5. 参数梯度乘 1/s;
    6. 利用 FP16 的梯度更新 FP32 的模型参数。
      但是,这里会有一个问题,scale factor 应该如何选取?选一个常量显然是不合适的,因为 loss 和梯度的数值在变,scale factor 需要跟随 loss 动态变化。健康的 loss 是振荡中下降,因此GradScaler设计的 scale factor 每隔 N 个 iteration 乘一个大于 1 的系数,再 scale loss;并且每次更新前检查溢出问题(检查梯度中有没有inf和nan),如果有,scale factor 乘一个小于 1 的系数并跳过该 iteration 的参数更新环节,如果没有,就正常更新参数。动态更新 scale factor 是 amp 实际操作中的流程。总结 amp 动态 scale factor 的训练流程:
  3. 维护一个 FP32 数值精度模型的副本;

  4. 初始化 s;

  5. 在每个 iteration + a 拷贝并且转换成FP16模型 + b 前向传播(FP16 的模型参数) + c loss 乘 scale factor s + d 反向传播(FP16 的模型参数和参数梯度) + e 检查有没有inf或者nan的参数梯度 + 如果有:降低 s,回到步骤a + f 参数梯度乘 1/s + g 利用 FP16 的梯度更新 FP32 的模型参数。

2. amp模块的API

用户使用混合精度训练基本操作:

#amp依赖Tensor core架构,所以model参数必须是cuda tensor类型
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
#GradScaler对象用来自动做梯度缩放
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        # 在autocast enable 区域运行forward
        with autocast():
            # model做一个FP16的副本,forward
            output = model(input)
            loss = loss_fn(output, target)
        # 用scaler,scale loss(FP16),backward得到scaled的梯度(FP16)
        scaler.scale(loss).backward()
        # scaler 更新参数,会先自动unscale梯度
        # 如果有nan或inf,自动跳过
        scaler.step(optimizer)
        # scaler factor更新
        scaler.update()

2.1 autocast类

``autocast(enable=True)`` 可以作为上下文管理器和装饰器来使用,给算子自动安排按照 FP16 或者 FP32 的数值精度来操作。

2.1.1 autocast算子

PyTorch中,只有 CUDA 算子有资格被 autocast,而且只有 “out-of-place” 才可以被 autocast,例如:a.addmm(b, c)是可以被 autocast,但是a.addmm_(b, c)和a.addmm(b, c, out=d)不可以 autocast。amp autocast 成 FP16 的算子有:
在这里插入图片描述
autocast 成 FP32 的算子:
剩下没有列出的算子,像dot,add,cat…都是按数据中较大的数值精度,进行操作,即有 FP32 参与计算,就按 FP32,全是 FP16 参与计算,就是 FP16。

2.1.2 MisMatch error

作为上下文管理器使用时,混合精度计算 enable 区域得到的 FP16 数值精度的变量在 enable 区域外需要显式的转成 FP32:

# Creates some tensors in default dtype (here assumed to be float32)
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")

with autocast():
    # torch.mm is on autocast's list of ops that should run in float16.
    # eable 区域内 fp32 自动转化成 fp16
    e_float16 = torch.mm(a_float32, b_float32)   
    # Also handles mixed input types
    f_float16 = torch.mm(d_float32, e_float16)

# After exiting autocast, calls f_float16.float() to use with d_float32
g_float32 = torch.mm(d_float32, f_float16.float()) # eable 区域外 fp16 显示转化成 fp32
2.1.3 autocast 嵌套使用
# Creates some tensors in default dtype (here assumed to be float32)
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")
with autocast():
    e_float16 = torch.mm(a_float32, b_float32)
    with autocast(enabled=False): # enabled=False 关闭fp16
        f_float32 = torch.mm(c_float32, e_float16.float()) # 需要使用fp32进行计算
    g_float16 = torch.mm(d_float32, f_float32)
2.1.4 autocast 作为装饰器

这种情况一般用于 data parallel 的模型的,autocast 设计为 “thread local” 的,所以只在 main thread 上设 autocast 区域是不 work 的:

model = MyModel() 
dp_model = nn.DataParallel(model)

with autocast():     # dp_model's internal threads won't autocast.
     #The main thread's autocast state has no effect.     
     output = dp_model(input)     # loss_fn still autocasts, but it's too late...
     loss = loss_fn(output) 

正确姿势是对 forward 装饰:

MyModel(nn.Module):
    ...
    @autocast()
    def forward(self, input):
       ...

另一个正确姿势是在 forward 的里面设 autocast 区域:

MyModel(nn.Module):
    ...
    def forward(self, input):
        with autocast():
            ...

forward 函数处理之后,在 main thread 里 autocast:

model = MyModel()
dp_model = nn.DataParallel(model)

with autocast():
    output = dp_model(input)
    loss = loss_fn(output)
2.1.5 autocast 自定义函数

对于用户自定义的 autograd 函数,需要用amp.custom_fwd装饰 forward 函数,amp.custom_bwd装饰 backward 函数:

class MyMM(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a.mm(b)
    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        a, b = ctx.saved_tensors
        return grad.mm(b.t()), a.t().mm(grad)

调用时再 autocast

mymm = MyMM.apply
with autocast():
    output = mymm(input1, input2)
2.1.6 源码分析

autocast主要实现接口有:
A. enter

def __enter__(self):
    self.prev = torch.is_autocast_enabled()
    torch.set_autocast_enabled(self._enabled)
    torch.autocast_increment_nesting()

B. exit

def __exit__(self, *args):

    if torch.autocast_decrement_nesting() == 0:
        torch.clear_autocast_cache()
    torch.set_autocast_enabled(self.prev)
    return False

C. call

def __call__(self, func):
    @functools.wraps(func)
    def decorate_autocast(*args, **kwargs):
        with self:
            return func(*args, **kwargs)
    return decorate_autocast

其中torch.autocast函数是在 pytorch/aten/src/ATen/autocast_mode.cpp 里实现。PyTorch ATen 是 A TENsor library for C++11,ATen 部分有大量的代码是来声明和定义 Tensor 运算相关的逻辑的。autocast_mode.cpp 实现策略是 “ cache fp16 casts of fp32 model weights”。

2.2 GradScaler 类torch.cuda.amp.GradScaler(init_scale=65536.0, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True)用于动态 scale 梯度

  1. init_scale: scale factor 的初始值;
  2. growth_factor: 每次 scale factor 的增长系数;
  3. backoff_factor: scale factor 下降系数;
  4. growth_interval: 每隔多个 interval 增长 scale factor;
  5. enabled: 是否做 scale。
2.2.1 scale(output)方法

对outputs乘 scale factor,并返回,如果enabled=False就原样返回。

2.2.3 step(optimizer, *args, **kwargs)方法

step 方法在做两件事情:

  1. 对梯度 unscale,如果之前没有手动调用unscale方法的话;
  2. 检查梯度溢出,如果没有nan/inf,就执行 optimizer 的 step,如果有就跳过
    注意:GradScaler的step不支持传 closure。
2.2.4 update(new_scale=None)方法

update方法在每个 iteration 结束前都需要调用,如果参数更新跳过,会给 scale factor 乘backoff_factor,或者到了该增长的 iteration,就给 scale factor 乘growth_factor。也可以用new_scale直接更新 scale factor。

2.3 举例

2.3.1 Gradient clipping
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        # unscale 梯度,可以不影响clip的threshold
        scaler.unscale_(optimizer)
        # clip梯度
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

        # unscale_()已经被显式调用了,scaler正常执行step更新参数,有nan/inf也会跳过
        scaler.step(optimizer)
        scaler.update()
2.3.2 Gradient accumulation
scaler = GradScaler()

for epoch in epochs:
    for i, (input, target) in enumerate(data):
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
            # loss 根据 累加的次数归一一下
            loss = loss / iters_to_accumulate

        # scale 归一的loss 并backward  
        scaler.scale(loss).backward()

        if (i + 1) % iters_to_accumulate == 0:
            # may unscale_ here if desired 
            # (e.g., to allow clipping unscaled gradients)

            # step() and update() proceed as usual.
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
2.3.3. Gradient penalty
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
        # 防止溢出,在不是autocast 区域,先用scaled loss 得到 scaled 梯度
        scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss),
                                                 inputs=model.parameters(),
                                                 create_graph=True)
        # 梯度unscale
        inv_scale = 1./scaler.get_scale()
        grad_params = [p * inv_scale for p in scaled_grad_params]
        # 在autocast 区域,loss 加上梯度惩罚项
        with autocast():
            grad_norm = 0
            for grad in grad_params:
                grad_norm += grad.pow(2).sum()
            grad_norm = grad_norm.sqrt()
            loss = loss + grad_norm

        scaler.scale(loss).backward()

        # may unscale_ here if desired 
        # (e.g., to allow clipping unscaled gradients)

        # step() and update() proceed as usual.
        scaler.step(optimizer)
        scaler.update()
2.3.4. Multiple models

scaler 一个就够,但 scale(loss) 和 step(optimizer) 要分别执行

scaler = torch.cuda.amp.GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer0.zero_grad()
        optimizer1.zero_grad()
        with autocast():
            output0 = model0(input)
            output1 = model1(input)
            loss0 = loss_fn(2 * output0 + 3 * output1, target)
            loss1 = loss_fn(3 * output0 - 5 * output1, target)

        # (retain_graph here is unrelated to amp, it's present because in this
        # example, both backward() calls share some sections of graph.)
        scaler.scale(loss0).backward(retain_graph=True)
        scaler.scale(loss1).backward()

        # You can choose which optimizers receive explicit unscaling, if you
        # want to inspect or modify the gradients of the params they own.
        scaler.unscale_(optimizer0)

        scaler.step(optimizer0)
        scaler.step(optimizer1)

        scaler.update()
2.3.5. Multiple GPUs

torch DDP 和 torch DP model 的处理方式一样

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1508538.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2024.03.11作业

1. 提示并输入一个字符串&#xff0c;统计该字符串中大写小写字母个数&#xff0c;数字个数&#xff0c;空格个数以及其他字符个数&#xff0c;要求使用c风格字符串完成 #include <iostream> #include <string>using namespace std;int main() {cout << &qu…

蓝桥杯2023年第十四届Java省赛真题-矩形总面积

题目描述 平面上有个两个矩形 R1 和 R2&#xff0c;它们各边都与坐标轴平行。设 (x1, y1) 和(x2, y2) 依次是 R1 的左下角和右上角坐标&#xff0c;(x3, y3) 和 (x4, y4) 依次是 R2 的左下角和右上角坐标&#xff0c;请你计算 R1 和 R2 的总面积是多少&#xff1f; 注意&…

设计模式深度解析:工厂方法模式与抽象工厂模式的深度对比

​&#x1f308; 个人主页&#xff1a;danci_ &#x1f525; 系列专栏&#xff1a;《设计模式》 &#x1f4aa;&#x1f3fb; 制定明确可量化的目标&#xff0c;坚持默默的做事。 探索设计模式的魅力&#xff1a;工厂方法模式文章浏览阅读17k次&#xff0c;点赞105次&#xff0…

根据xlsx文件第一列的网址爬虫(selenium)

seleniumXpath 在与该ipynb文件同文件下新增一个111.xlsx&#xff0c;第一列放一堆需要爬虫的同样式网页 然后使用seleniumXpath爬虫 from selenium import webdriver from selenium.webdriver.common.by import By import openpyxl import timedef crawl_data(driver, url)…

2024年零基础自学网络安全/Web安全,看这一篇就够了

作为一个安全从业人员&#xff0c;我自知web安全的概念太过于宽泛&#xff0c;我本人了解的也并不够精深&#xff0c;还需要继续学习。 但又不想新入行的人走弯路&#xff0c;所以今天随手写写关于web安全的内容&#xff0c;希望对初次遇到web安全问题的同学提供帮助&#xff…

334.递增的三元子序列

题目&#xff1a;给你一个整数数组 nums &#xff0c;判断这个数组中是否存在长度为 3 的递增子序列。 如果存在这样的三元组下标 (i, j, k) 且满足 i < j < k &#xff0c;使得 nums[i] < nums[j] < nums[k] &#xff0c;返回 true &#xff1b;否则&#xff0c;…

Nginx+keepalived实现七层的负载均衡的高可用

目录 Nginxkeepalived实现七层的负载均衡的高可用 一、准备服务器 1、主机清单 2、配置安装nginx 所有的机器&#xff0c;关闭防火墙和selinux 3.安装nginx&#xff0c; 全部4台 二、部署负载均衡 1、修改nginx的配置文件&#xff0c;添加以下内容&#xff0c; 2、重启n…

APP自动化测试-Appium Inspector入门操作指南

上一篇博客APP自动化测试-入门示例-CSDN博客介绍了APP自动化测试的入门示例,下面详细介绍下Appium 实现的页面元素查看器工具:Appium Inspector的使用方法。 Appium Inspector简介 Appium Inspector 是 Appium 测试框架中的一个工具,用于可视化和调试移动应用程序的 UI 结…

污水处理厂重金属废水深度处理CH-90树脂处理系统

项目名称 广东某工业污水处理厂重金属废水深度处理工程项目 工艺选择 科海思重金属深度处理工艺 工艺原理 离子交换吸附 项目背景 随着环保要求不断提高&#xff0c;工业废水处理已成为众多企业的必修课。然而在工业生产中&#xff0c;如何有效处理含有重金属的废水成为…

结构化思维助力Prompt创作:专业化技术讲解和实践案例

最早接触 Prompt engineering 时, 学到的 Prompt 技巧都是: 你是一个 XX 角色… 你是一个有着 X 年经验的 XX 角色… 你会 XX, 不要 YY.. 对于你不会的东西, 不要瞎说!…对比什么技巧都不用, 直接像使用搜索引擎一样提问, 上面的技巧对于回复的效果确实有着 明显提升. 在看了 N…

【CSS面试题】外边距折叠的原因和解决

参考文章 什么时候出现外边距塌陷 外边距塌陷&#xff0c;也叫外边距折叠&#xff0c;在普通文档流中&#xff0c;在垂直方向上的2个或多个相邻的块级元素&#xff08;父子或者兄弟&#xff09;外边距合并成一个外边距的现象&#xff0c;不过只有上下外边距才会有塌陷&#x…

Xinstall CPA结算系统:精准追踪,轻松提升广告ROI

在如今的移动互联网时代&#xff0c;App推广已经成为各大企业获取用户、扩大市场份额的重要手段。然而&#xff0c;随着推广渠道的多样化&#xff0c;如何精准评估各渠道的效果、优化广告投放策略&#xff0c;以及提升用户体验&#xff0c;成为了摆在推广者面前的难题。 这时…

R语言绘制桑基图教程

原文链接&#xff1a;R语言绘制桑基图教程 写在前面 在昨天3月10日&#xff0c;我们在知乎、B站等分享了功能富集桑基气泡图的绘制教程。相关链接&#xff1a;NC|高颜值功能富集桑基气泡图&#xff0c;桑基气泡组合图。 确实&#xff0c;目前这个图在文章中出现的频率相对比较…

YOLOv8模型改进4【增加注意力机制GAM-Attention(超越CBAM,不计成本地提高精度)】

一、GAM-Attention注意力机制简介 GAM全称:Global Attention Mechanism。它被推出的时候有一个响亮的口号叫做:超越CBAM,不计成本地提高精度。由此可见,它的主要作用是为了目标检测精度的提高。 但是,大家都明白,具体效果怎么样,还得看具体的任务,我浅浅地试了一下,…

SpringBoot +WebSocket应用

我们今天不研究原理&#xff0c;只看应用。 什么是WebSocket WebSocket是一种在单个TCP连接上进行全双工通信的协议。WebSocket通信协议于2011年被IETF定为标准RFC 6455&#xff0c;并由RFC7936补充规范。WebSocket API也被W3C定为标准。 WebSocket使得客户端和服务器之间的数…

微信小程序开发系列(二十)·wxml语法·setData()修改对象类型数据、ES6 提供的展开运算符、delete和rest的用法

目录 1. 新增单个、多个属性 1.1 新增单个属性 1.2 新增多个属性 2. 修改单个、多个属性 2.1 修改单个属性 2.2 修改多个属性 3. 优化 3.1 ES6 提供的展开运算符 3.2 Object.assign()将多个对象合并为一个对象 4. 删除单个、多个属性 4.1 删除单个属性 …

Spring揭秘:Environment接口应用场景及实现原理!

内容概要 Environment接口提供了强大且灵活的环境属性管理能力&#xff0c;通过它&#xff0c;开发者能轻松地访问和配置应用程序运行时的各种属性&#xff0c;如系统属性、环境变量等。 同时&#xff0c;Environment接口还支持属性源的定制和扩展&#xff0c;使得开发者能根…

20240309web前端_第一周作业_完成电子汇款单

作业二&#xff1a;完成电子汇款单 成果展示: 完整代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"…

算法-状数组与线段树-1264. 动态求连续区间和

题目 思路 线段树&树状数组 - AcWing算法学习笔记(14): 线段树 - 知乎 (zhihu.com) 代码 Python超时版 def calculate_subarray_sum(nums, a, b):return sum(nums[a-1:b])n, m map(int, input().split()) nums list(map(int, input().split()))for _ in range(m):op,…

社交媒体革新者:揭秘Facebook对在线互动的影响

1. Facebook的兴起与发展 Facebook由马克扎克伯格在哈佛大学宿舍创建&#xff0c;最初只是服务于哈佛大学学生的社交网络。然而&#xff0c;其后快速扩张到其他大学和全球&#xff0c;成为了全球最大的社交媒体平台之一。其发展历程不仅是数字时代的典范&#xff0c;也是创业成…