PyTorch AMP 混合精度中grad_scaler.py的scale函数解析

news2025/1/5 12:34:28

PyTorch AMP 混合精度中的 scale 函数解析

混合精度训练(AMP, Automatic Mixed Precision)是深度学习中常用的技术,用于提升训练效率并减少显存占用。在 PyTorch 的 AMP 模块中,GradScaler 类负责动态调整和管理损失缩放因子,以解决 FP16 运算中的数值精度问题。而 scale 函数是 GradScaler 的一个重要方法,用于将输出的张量按当前缩放因子进行缩放。

本文将详细解析 scale 函数的作用、代码逻辑,以及 apply_scale 子函数的递归作用。


函数代码回顾

以下是 scale 函数的完整代码:
Source: anaconda3/envs/xxx/lib/python3.10/site-packages/torch/amp/grad_scaler.py

torch 2.4.0+cu121版本

def scale(
    self,
    outputs: Union[torch.Tensor, Iterable[torch.Tensor]],
) -> Union[torch.Tensor, Iterable[torch.Tensor]]:
    """
    Multiplies ('scales') a tensor or list of tensors by the scale factor.

    Returns scaled outputs.  If this instance of :class:`GradScaler` is not enabled, outputs are returned
    unmodified.

    Args:
        outputs (Tensor or iterable of Tensors):  Outputs to scale.
    """
    if not self._enabled:
        return outputs

    # Short-circuit for the common case.
    if isinstance(outputs, torch.Tensor):
        if self._scale is None:
            self._lazy_init_scale_growth_tracker(outputs.device)
        assert self._scale is not None
        return outputs * self._scale.to(device=outputs.device, non_blocking=True)

    # Invoke the more complex machinery only if we're treating multiple outputs.
    stash: List[
        _MultiDeviceReplicator
    ] = []  # holds a reference that can be overwritten by apply_scale

    def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
        if isinstance(val, torch.Tensor):
            if len(stash) == 0:
                if self._scale is None:
                    self._lazy_init_scale_growth_tracker(val.device)
                assert self._scale is not None
                stash.append(_MultiDeviceReplicator(self._scale))
            return val * stash[0].get(val.device)
        if isinstance(val, abc.Iterable):
            iterable = map(apply_scale, val)
            if isinstance(val, (list, tuple)):
                return type(val)(iterable)
            return iterable
        raise ValueError("outputs must be a Tensor or an iterable of Tensors")

    return apply_scale(outputs)

1. 函数作用

scale 函数的主要作用是将输出张量(outputs)按当前的缩放因子(self._scale)进行缩放。它支持以下两种输入:

  1. 单个张量:直接将缩放因子乘以张量。
  2. 张量的可迭代对象(如列表或元组):递归地对每个张量进行缩放。

当 AMP 功能未启用时(即 self._enabledFalse),scale 函数会直接返回原始的 outputs,不执行任何缩放操作。

使用场景

  • 放大梯度:在反向传播之前,放大输出张量的数值,以减少数值舍入误差对 FP16 计算的影响。
  • 支持多设备:通过 _MultiDeviceReplicator 支持张量分布在多个设备(如多 GPU)的场景。

2. 核心代码解析

(1) 短路处理单个张量

当输入 outputs 是单个张量(torch.Tensor)时,函数直接对其进行缩放:

if isinstance(outputs, torch.Tensor):
    if self._scale is None:
        self._lazy_init_scale_growth_tracker(outputs.device)
    assert self._scale is not None
    return outputs * self._scale.to(device=outputs.device, non_blocking=True)
逻辑解析:
  1. 如果缩放因子 self._scale 尚未初始化,则调用 _lazy_init_scale_growth_tracker 方法在指定设备上初始化缩放因子。
  2. 使用 outputs * self._scale 对张量进行缩放。这里使用了 to(device=outputs.device) 确保缩放因子与张量在同一设备上。

这是单个张量输入的快速路径处理。


(2) 多张量递归处理逻辑

当输入为张量的可迭代对象(如列表或元组)时,函数调用子函数 apply_scale 进行递归缩放:

stash: List[_MultiDeviceReplicator] = []  # 用于存储缩放因子对象

def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
    if isinstance(val, torch.Tensor):
        if len(stash) == 0:
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(val.device)
            assert self._scale is not None
            stash.append(_MultiDeviceReplicator(self._scale))
        return val * stash[0].get(val.device)
    if isinstance(val, abc.Iterable):
        iterable = map(apply_scale, val)
        if isinstance(val, (list, tuple)):
            return type(val)(iterable)
        return iterable
    raise ValueError("outputs must be a Tensor or an iterable of Tensors")

return apply_scale(outputs)
apply_scale 子函数的作用
  1. 张量处理

    • 如果 val 是单个张量,检查 stash 是否为空。
    • 如果为空,初始化缩放因子对象 _MultiDeviceReplicator,并存储在 stash 中。
    • 使用 stash[0].get(val.device) 获取对应设备上的缩放因子,并对张量进行缩放。
  2. 递归处理可迭代对象

    • 如果 val 是一个可迭代对象,调用 map(apply_scale, val),对其中的每个元素递归地调用 apply_scale
    • 如果输入是 listtuple,则保持其原始类型。
  3. 类型检查

    • 如果 val 既不是张量也不是可迭代对象,抛出错误。

3. apply_scale 是递归函数吗?

是的,apply_scale 是一个递归函数。

递归逻辑

  • 当输入为嵌套结构(如张量的列表或列表中的列表)时,apply_scale 会递归调用自身,将缩放因子应用到最底层的张量。
  • 递归的终止条件是 val 为单个张量(torch.Tensor)。
示例:

假设输入为嵌套张量列表:

outputs = [torch.tensor([1.0, 2.0]), [torch.tensor([3.0]), torch.tensor([4.0, 5.0])]]
scaled_outputs = scaler.scale(outputs)

递归处理过程如下:

  1. outputs 调用 apply_scale

    • 第一个元素是张量 torch.tensor([1.0, 2.0]),直接缩放。
    • 第二个元素是列表,递归调用 apply_scale
  2. 进入嵌套列表 [torch.tensor([3.0]), torch.tensor([4.0, 5.0])]

    • 第一个元素是张量 torch.tensor([3.0]),缩放。
    • 第二个元素是张量 torch.tensor([4.0, 5.0]),缩放。

4. _MultiDeviceReplicator 的作用

_MultiDeviceReplicator 是一个工具类,用于在多设备场景下管理缩放因子对象的复用。它根据张量所在的设备返回正确的缩放因子。

  • 当张量分布在多个设备(如 GPU)时,_MultiDeviceReplicator 可以高效地为每个设备提供所需的缩放因子,避免重复初始化。

总结

scale 函数是 AMP 混合精度训练中用于梯度缩放的重要方法,其作用是将输出张量按当前缩放因子进行缩放。通过递归函数 apply_scale,该函数能够处理嵌套的张量结构,同时支持多设备场景。

关键点总结:

  1. 快速路径:单张量输入的情况下,直接进行缩放。
  2. 递归处理:对于张量的嵌套结构,递归地对每个张量进行缩放。
  3. 设备管理:通过 _MultiDeviceReplicator 支持多设备场景。

通过 scale 函数,PyTorch 的 AMP 模块能够高效地调整梯度数值范围,提升混合精度训练的稳定性和效率。

后记

2025年1月2日15点47分于上海,在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2270752.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java项目实战II基于微信小程序的家庭大厨(开发文档+数据库+源码)

目录 一、前言 二、技术介绍 三、系统实现 四、核心代码 五、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。 一、前言 在快节奏的生活中,家庭聚餐成为了连接亲情…

Github拉取项目报错解决

前言 昨天在拉取github上面的项目报错了,有好几个月没用github了,命令如下: git clone gitgithub.com:zhszstudy/git-test.git报错信息: ssh: connect to host github.com port 22: Connection timed out fatal: Could not rea…

学AI编程的Prompt工程,豆包Marscode

学习链接:Datawhale-AI活动https://www.datawhale.cn/activity/116/23/95?rankingPage1 目录 一、如何使用 二、编写游戏 2.1 创意输入与代码生成 2.2 项目初始化与应用 2.3 创意优化与迭代 三、效果展示 一、如何使用 建议在在vscode上安装marscode插件&a…

NLP CH3复习

CH3 3.1 几种损失函数 3.2 激活函数性质 3.3 哪几种激活函数会发生梯度消失 3.4 为什么会梯度消失 3.5 如何解决梯度消失和过拟合 3.6 梯度下降的区别 3.6.1 梯度下降(GD) 全批量:在每次迭代中使用全部数据来计算损失函数的梯度。计算成本…

计算机网络 (19)扩展的以太网

前言 以太网(Ethernet)是一种局域网(LAN)技术,它规定了包括物理层的连线、电子信号和介质访问层协议的内容。以太网技术不断演进,从最初的10Mbps到如今的10Gbps、25Gbps、40Gbps、100Gbps等,已成…

JavaVue-Get请求 数组参数(qs格式化前端数据)

前言 现在管理系统,像若依,表格查询一般会用Get请求,把页面的查询条件传递给后台。其中大部分页面会有日期时间范围查询这时候,为了解决请求参数中的数组文件,前台就会在请求前拦截参数中的日期数组数据,然…

.e01, ..., .e0n的分卷压缩包怎么解压

用BandiZip,这些分卷压缩中还有一个.exe的文件,这个不是可执行文件,是一个解压缩的开头。 安装好bandiZip后,右键这个.exe文件 点击打开就是开始解压了: 最后解压后是这些。然后一个个再次解压.

库伦值自动化功耗测试工具

1. 功能介绍 PlatformPower工具可以自动化测试不同场景的功耗电流,并可导出为excel文件便于测试结果分析查看。测试同时便于后续根据需求拓展其他自动化测试用例。 主要原理:基于文件节点 coulomb_count 实现,计算公式:电流&…

大模型 LangChain 开发框架:Runable 与 LCEL 初探

大模型 LangChain 开发框架:Runable 与 LCEL 初探 一、引言 在大模型开发领域,LangChain 作为一款强大的开发框架,为开发者提供了丰富的工具和功能。其中,Runnable 接口和 LangChain 表达式语言(LCEL)是构…

【Jboss/Windows】Tomcat 8 + JDK 8 升级为 Jboss eap 7 + JDK8

文章目录 下载Jboss eap 7安装包执行standalone.bat修改jdk8不兼容的一些内存空间参数查看端口是否被占用解决端口占用环境变量配置修改项目中的pom文件配置Jboos启动项本地localhost启动测试 更多相关内容可查看 下载Jboss eap 7安装包 Jboss EAP:JBoss Enterpris…

aardio —— 改变按钮文本颜色

import win.ui; /*DSG{{*/ var winform win.form(text"改变按钮颜色示例";right279;bottom239;composited1) winform.add( button{cls"button";text"点这里1";left16;top104;right261;bottom159;fontLOGFONT(h-14);z1}; button2{cls"butto…

Elasticsearch操作笔记版

文章目录 1.ES索引库操作(CRUD)1.mapping常见属性(前提)2.创建索引库3.查询,删除索引库4.修改索引库 2.ES文档操作(CRUD)1.新增文档2.查询、删除文档查询返回的数据解读: 3.修改文档 3.RestClient操作(索引库/文档)(CRUD)1.什么是RestClient2.需要考虑前…

【狂热算法篇】解锁数据潜能:探秘前沿 LIS 算法

嘿,各位编程爱好者们!今天带来的 LIS 算法简直太赞啦 无论你是刚入门的小白,还是经验丰富的大神,都能从这里找到算法的奇妙之处哦!这里不仅有清晰易懂的 C 代码实现,还有超详细的算法讲解,让你轻…

【漫话机器学习系列】033.决策树回归(Decision Tree Regression)

决策树回归(Decision Tree Regression) 决策树回归是一种基于树状结构进行回归分析的监督学习方法。它将输入空间递归地划分为多个区域,并在每个区域内拟合一个简单的常数值,从而对目标变量进行预测。 决策树回归的原理 树的构建…

Vue3中使用 Vue Flow 流程图方法

效果图: 最近项目开发时有一个流程图的功能,需要做流程节点的展示,就搜到了 Vue Flow 这个插件,这个插件总得来说还可以,简单已使用,下边就总结一下使用的方法: Vue Flow官网:https…

ArcGIS JSAPI 高级教程 - 通过RenderNode实现视频融合效果(不借助三方工具)

ArcGIS JSAPI 高级教程 - 通过RenderNode实现视频融合效果(不借助三方工具) 核心代码完整代码在线示例地球中展示视频可以通过替换纹理的方式实现,但是随着摄像头和无人机的流行,需要视频和场景深度融合,简单的实现方式则不能满足需求。 三维视频融合技术将视频资源与三维…

Appllo学习

补充学习: Apollo管理多环境下的配置和踩坑实践 - 简书 Apollo-阿波罗配置中心超详细教程_apllo-CSDN博客 springboot本地local配置覆盖远程Apollo配置(含Apollo配置加载顺序说明)_本地覆盖apollo配置-CSDN博客 Apollo 配置中心详细教程 - 简书 (包含…

React18路由和Vue3路由进行对比

本文将深入比较 React 18 和 Vue 3 路由的不同之处,帮助你更好地理解如何在这两个框架中进行路由管理。希望能对于从 Vue 3 迁移到 React 的开发者,理解这些差异,帮助你更高效地切换框架和构建应用。 1. 路由配置 React 18 的路由配置 Rea…

Windows系统下载、部署Node.js与npm环境的方法

本文介绍在Windows电脑中,下载、安装并配置Node.js环境与npm包管理工具的方法。 Node.js是一个基于Chrome V8引擎的JavaScript运行时环境,其允许开发者使用JavaScript编写命令行工具和服务器端脚本。而npm(Node Package Manager)则…

浏览器选中文字样式

效果 学习 Chrome: 支持 ::selection。Firefox: 支持 :-moz-selection 和 ::selection。Safari: 支持 ::selection。Internet Explorer: 支持 :-ms-selection。Microsoft Edge: 支持 ::-ms-selection 和 ::selection。 代码 <!DOCTYPE html> <html lang"en&qu…