AMP 混合精度训练中的动态缩放机制: grad_scaler.py函数解析( torch._amp_update_scale_)

news2025/1/7 18:36:11

AMP 混合精度训练中的动态缩放机制

在深度学习中,混合精度训练(AMP, Automatic Mixed Precision)是一种常用的技术,它利用半精度浮点(FP16)计算来加速训练,同时使用单精度浮点(FP32)来保持数值稳定性。为了在混合精度训练中避免数值溢出,PyTorch 提供了一种动态缩放机制来调整 “loss scale”(损失缩放值)。本文将详细解析动态缩放机制的实现原理,并通过代码展示其内部逻辑。


动态缩放机制简介

动态缩放机制的核心思想是通过一个可动态调整的缩放因子(scale factor)放大 FP16 的梯度,从而降低舍入误差对训练的影响。当检测到数值不稳定(例如 NaN 或无穷大)时,缩放因子会被降低;当连续多步未检测到数值问题时,缩放因子会被提高。其调整策略基于以下两个参数:

  • growth_factor: 连续成功步骤后用于增加缩放因子的乘数(通常大于 1,如 2.0)。
  • backoff_factor: 检测到数值溢出时用于减少缩放因子的乘数(通常小于 1,如 0.5)。

此外,动态缩放还使用 growth_interval 参数控制连续成功步骤的计数阈值。当达到这个阈值时,缩放因子才会增加。


AMP 缩放更新核心代码解析

PyTorch 实现了一个用于更新缩放因子的 CUDA 核函数以及相关的 Python 包装函数。以下是核心代码解析:

CUDA 核函数实现

// amp_update_scale_cuda_kernel 核函数实现
__global__ void amp_update_scale_cuda_kernel(float* current_scale,
                                             int* growth_tracker,
                                             const float* found_inf,
                                             double growth_factor,
                                             double backoff_factor,
                                             int growth_interval) {
  if (*found_inf) {
    // 如果发现梯度中存在 NaN 或 Inf,缩放因子乘以 backoff_factor,并重置 growth_tracker。
    *current_scale = (*current_scale) * backoff_factor;
    *growth_tracker = 0;
  } else {
    // 未发现数值问题,增加 growth_tracker 的计数。
    auto successful = (*growth_tracker) + 1;
    if (successful == growth_interval) {
      // 当 growth_tracker 达到 growth_interval,尝试增长缩放因子。
      auto new_scale = static_cast<float>((*current_scale) * growth_factor);
      if (isfinite_ensure_cuda_math(new_scale)) {
        *current_scale = new_scale;
      }
      *growth_tracker = 0;
    } else {
      *growth_tracker = successful;
    }
  }
}
核函数逻辑
  1. 发现数值溢出(found_inf > 0):

    • 缩放因子 current_scale 乘以 backoff_factor
    • 重置成功计数器 growth_tracker 为 0。
  2. 未发现数值溢出:

    • 增加成功计数器 growth_tracker
    • 如果 growth_tracker 达到 growth_interval,则将缩放因子乘以 growth_factor
    • 保证缩放因子不会超过 FP32 的数值上限。

C++ 包装函数实现

在 PyTorch 中,这一 CUDA 核函数通过 C++ 包装函数 _amp_update_scale_cuda_ 被调用。以下是实现代码:

Tensor& _amp_update_scale_cuda_(Tensor& current_scale,
                                Tensor& growth_tracker,
                                const Tensor& found_inf,
                                double growth_factor,
                                double backoff_factor,
                                int64_t growth_interval) {
  TORCH_CHECK(growth_tracker.is_cuda(), "growth_tracker must be a CUDA tensor.");
  TORCH_CHECK(current_scale.is_cuda(), "current_scale must be a CUDA tensor.");
  TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");
  
  // 核函数调用
  amp_update_scale_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
    current_scale.mutable_data_ptr<float>(),
    growth_tracker.mutable_data_ptr<int>(),
    found_inf.const_data_ptr<float>(),
    growth_factor,
    backoff_factor,
    growth_interval);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return current_scale;
}

Python 调用入口

AMP 的 GradScaler 类通过 _amp_update_scale_ 函数更新缩放因子,以下是相关代码:
代码来源:anaconda3/envs/xxxx/lib/python3.10/site-packages/torch/amp/grad_scaler.py

具体调用过程可以参考笔者的另一篇博文:PyTorch到C++再到 CUDA 的调用链(C++ ATen 层) :以torch._amp_update_scale_调用为例

def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None:
    """更新缩放因子"""
    if not self._enabled:
        return

    _scale, _growth_tracker = self._check_scale_growth_tracker("update")

    if new_scale is not None:
        # 设置用户定义的新缩放因子。
        self._scale.fill_(new_scale)
    else:
        # 收集所有优化器中的 found_inf 数据。
        found_infs = [
            found_inf.to(device=_scale.device, non_blocking=True)
            for state in self._per_optimizer_states.values()
            for found_inf in state["found_inf_per_device"].values()
        ]

        found_inf_combined = found_infs[0]
        if len(found_infs) > 1:
            for i in range(1, len(found_infs)):
                found_inf_combined += found_infs[i]

        # 更新缩放因子。
        torch._amp_update_scale_(
            _scale,
            _growth_tracker,
            found_inf_combined,
            self._growth_factor,
            self._backoff_factor,
            self._growth_interval,
        )

总结

PyTorch 的动态缩放机制通过 CUDA 核函数和 Python 包装函数协作完成。其核心逻辑是:

  1. 检测数值不稳定(如 NaN 或 Inf),通过缩小缩放因子提高数值稳定性。
  2. 当连续多次未出现数值不稳定时,逐步增大缩放因子以充分利用 FP16 的动态范围。
  3. 所有更新操作都在 GPU 上异步完成,最大限度地减少同步开销。

通过动态调整缩放因子,AMP 有效地加速了深度学习模型的训练,同时避免了梯度溢出等数值问题。


推荐阅读

  • PyTorch 官方文档
  • 混合精度训练介绍

后记

2025年1月2日15点38分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2272080.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【python因果库实战16】双重稳健模型1

这里写目录标题 双重稳健模型数据简单双重稳健模型双重稳健 IP 特征模型 双重稳健模型 基本上&#xff0c;这些是利用加权模型增强结果模型的不同集合模型。 本笔记展示了不同的结果模型和倾向性模型组合方式&#xff0c; 但由于可能的组合非常多&#xff0c;本笔记并不打算展…

如何恢复已删除的 Telegram 消息 [iOSamp;Android]

Telegram 是一款功能强大的消息应用程序&#xff0c;因其易用性、隐私保护和众多炫酷功能而深受用户喜爱。然而&#xff0c;有时我们会不小心删除重要的消息。在这种情况下你应该做什么&#xff1f; 本文将为您提供简单有效的解决方案来恢复 Telegram 上已删除的消息&#xff…

第431场周赛:最长乘积等价子数组、计算字符串的镜像分数、收集连续 K 个袋子可以获得的最多硬币数量、不重叠区间的最大得分

Q1、最长乘积等价子数组 1、题目描述 给你一个由 正整数 组成的数组 nums。 如果一个数组 arr 满足 prod(arr) lcm(arr) * gcd(arr)&#xff0c;则称其为 乘积等价数组 &#xff0c;其中&#xff1a; prod(arr) 表示 arr 中所有元素的乘积。gcd(arr) 表示 arr 中所有元素的…

【微服务】2、网关

Spring Cloud微服务网关技术介绍 单体项目拆分微服务后的问题 服务地址问题&#xff1a;单体项目端口固定&#xff08;如黑马商城为8080&#xff09;&#xff0c;拆分微服务后端口各异&#xff08;如购物车808、商品8081、支付8086等&#xff09;且可能变化&#xff0c;前端难…

使用JMeter玩转tidb压测

作者&#xff1a; du拉松 原文来源&#xff1a; https://tidb.net/blog/3f1ada39 一、前言 tidb是mysql协议的&#xff0c;所以在使用过程中使用tidb的相关工具连接即可。因为jmeter是java开发的相关工具&#xff0c;直接使用mysql的jdbc驱动包即可。 二、linux下安装jmet…

2024网络安全运营方案概述(附实践资料合集)

以下是网络安全运营方案的详细内容&#xff1a; 一、目标与原则 目标&#xff1a;建立一套安全高效、灵活性强的网络安全运营体系&#xff0c;实现对网络安全的全面监控、防护和应急响应。原则&#xff1a; 全员参与&#xff1a;网络安全是全员共同的责任&#xff0c;所有员工…

使用Python进行图像裁剪和直方图分析

一、简介 在数字图像处理领域&#xff0c;裁剪和分析图像的直方图是两个非常基本且重要的操作。本文将通过一个简单的Python项目&#xff0c;展示如何使用skimage和matplotlib库来裁剪图像并分析其RGB通道的直方图。 二、环境准备 在开始之前&#xff0c;请确保你已经安装了以…

vue3-dom-diff算法

vue3diff算法 什么是vue3diff算法 Vue3中的diff算法是一种用于比较虚拟DOM树之间差异的算法&#xff0c;其目的是为了高效地更新真实DOM&#xff0c;减少不必要的重渲染 主要过程 整个过程主要分为以下五步 前置预处理后置预处理仅处理新增仅处理后置处理包含新增、卸载、…

【U8+】用友U8软件中,出入库流水输出excel的时候提示报表输出引擎错误。

【问题现象】 通过天联高级版客户端登录拥有U8后&#xff0c; 将出入库流水输出excel的时候&#xff0c;提示报表输出引擎错误。 进行报表输出时出现错误&#xff0c;错误信息&#xff1a;找不到“fd6eea8b-fb40-4ce4-8ab4-cddbd9462981.htm”。 如果您正试图从最近使用的文件列…

[SMARTFORMS] 创建样式模板

通过事务码SMARTFORMS创建样式模板 选择样式&#xff0c;自定义样式模板名称ZST_DEMO_2025 点击"创建"按钮&#xff0c;跳转至样式模板详情页面&#xff0c;我们可以在该页面上设置SMARTFORMS表单相关的样式 在段落样式处&#xff0c;右键选择创建节点&#xff0c;输…

基于51单片机和DS3231时钟模块、LCD1602(I2C通信)模块的可调时钟+温度测量+计时+闹钟

目录 系列文章目录前言一、效果展示二、原理分析三、各模块代码1、延时函数2、定时器03、定时器14、独立按键5、DS3231时钟模块6、LCD1602模块&#xff08;PCF8574T驱动&#xff09; 四、主函数总结 系列文章目录 前言 之前做过一个类似的&#xff0c;用到了很多外设&#xff…

通义视觉推理大模型QVQ-72B-preview重磅上线

Qwen团队推出了新成员QVQ-72B-preview&#xff0c;这是一个专注于提升视觉推理能力的实验性研究模型。提升了视觉表示的效率和准确性。它在多模态评测集如MMMU、MathVista和MathVision上表现出色&#xff0c;尤其在数学推理任务中取得了显著进步。尽管如此&#xff0c;该模型仍…

企业级Nosql数据库和Redis集群

一、关系数据库和Nosql数据库 关系数据库 定义&#xff1a;关系数据库是建立在关系模型基础上的数据库。它使用表格&#xff08;关系&#xff09;来存储数据&#xff0c;通过行和列的形式组织信息。例如&#xff0c;一个简单的学生信息表可能有 “学号”“姓名”“年龄”“班级…

Ant Design中Flex布局、Grid布局和Layout布局详解

好的&#xff0c;我们来更详细地探讨 Ant Design 中的 Flex布局、Grid布局 和 Layout布局 的特点、用法、适用场景&#xff0c;以及如何灵活运用它们来构建页面。下面将从各个方面进行更深入的分析&#xff0c;并提供具体的实例。 VueFlex布局实现响应式布局 1. Flex布局 概念…

基于FPGA的SNN脉冲神经网络之IM神经元verilog实现,包含testbench

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 (完整程序运行后无水印) 2.算法运行软件版本 vivado2019.2 3.部分核心程序 &#xff08;完整版代码包含详细中文注释和操作步骤视频&#xff0…

健身房管理系统多身份

本文结尾处获取源码。 本文结尾处获取源码。 本文结尾处获取源码。 一、相关技术 后端&#xff1a;Java、JavaWeb / Springboot。前端&#xff1a;Vue、HTML / CSS / Javascript 等。数据库&#xff1a;MySQL 二、相关软件&#xff08;列出的软件其一均可运行&#xff09; I…

三甲医院等级评审八维数据分析应用(四)--数据质量管理篇

一、引言 1.1 研究背景与意义 在医疗卫生领域,医院评审是衡量医院综合实力、保障医疗服务质量的重要手段。其中,三甲评审作为我国医院评审体系中的最高级别,对医院的管理、医疗技术、服务质量等各方面都设定了严格标准。医务科作为医院医疗质量管理的核心部门,肩负着协调…

Solidity合约编写(一)

Solidity IDE地址&#xff1a;Remix - Ethereum IDE 点击进入后在contract文件夹下创建合约 合约代码如下&#xff1a; // SPDX-License-Identifier: MIT pragma solidity ^0.8.26;contract SimpleStorage{bool hasFavorNumtrue;uint256 favorNum5;string favorNums"fiv…

嵌入式系统(将软件嵌入到硬件里面)

目录 Linux起源 查看操作系统的版本 查看内核的版本&#xff1a; 内核系统架构 系统关机或重启命令 关机&#xff1a; 重启&#xff1a; linux下的软件安装 两种软件包管理机制&#xff1a; deb软件包分为两种&#xff1a; 软件包的管理工具&#xff1a;dpkg apt 1…

会员制电商创新:开源 AI 智能名片与 2+1 链动模式的协同赋能

摘要&#xff1a;本文聚焦于电商领域会员制的关键作用&#xff0c;深入探讨在传统交易模式向数字化转型过程中&#xff0c;如何借助开源 AI 智能名片以及 21 链动模式商城小程序&#xff0c;实现对会员数据的精准挖掘与高效利用&#xff0c;进而提升企业的营销效能与客户洞察能…