梯度裁剪:torch.nn.utils.clip_grad_norm_详解

news2024/11/26 19:42:59

梯度裁剪是为了防止梯度爆炸。在训练FCOS算法时,因为训练过程出现了损失为NaN的情况,在github issue有很多都是这种训练过程出现loss为NaN,使用torch.nn.utils.clip_grad_norm_梯度裁剪函数,可以有效预防梯度爆炸的情况发生。

1 clip_grad_norm_介绍

1.1 函数原型

def clip_grad_norm_(
        parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
        error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor:
  • parameters:需要进行梯度裁剪的参数列表。通常是模型的参数列表,即model.parameters()
  • max_norm:可以理解为梯度(默认是L2 范数)范数的最大阈值
  • norm_type:可以理解为指定范数的类型,比如norm_type=1 表示使用L1 范数,norm_type=2 表示使用L2 范数。

这个梯度裁剪函数一般来说只需要调整max_normnorm_type这两个参数。clip_grad_norm_最后就是对所有的梯度乘以一个clip_coef,而且乘的前提是clip_coef一定是小于1的,所以,按照这个情况:clip_grad_norm只解决梯度爆炸问题,不解决梯度消失问题

torch.nn.utils.clip_grad_norm_和torch.nn.utils.clip_grad_norm(已弃用)的区别就是前者是直接修改原Tensor,后者不会。在Pytorch中有很多这样的函数对均是如此,在函数最后多了下划线一般都是表示直接在原Tensor上进行操作。

1.2 参数的选择

clip_coef的公式为:

max_norm的取值:

假定忽略clip_coef > 1的情况,则可以根据公式推断出:

  • clip_coef越小,则对梯度的裁剪越厉害,即,使梯度的值缩小的越多
  • max_norm越小,clip_coef越小,所以,max_norm越大,对于梯度爆炸的解决越柔和,max_norm越小,对梯度爆炸的解决越狠

    max_norm可以取小数

total_norm受梯度大小和norm_type的影响:

  • 梯度越大,total_norm值越大,进而导致clip_coef的值越小,最终也会导致对梯度的裁剪越厉害,很合理
  • norm_type不管取多少,对于total_norm的影响不是太大,所以可以直接取默认值2
  • norm_type越大,total_norm越小

1.3 函数执行的操作

  • 对所有需要进行梯度计算的参数,收集所有参数的梯度的指定范数(通过参数norm_type进行设置,1表示绝对值,2表示二阶范数也就是平方和开根号)
  • 计算所有参数的梯度范数总和(一个标量)和设定的max_norm的比值。如果max_norm/total_norm>1, 所有参数的梯度不变,可以直接反向传播。如果比值小于1,说明参数梯度需要被缩减,缩减比率为rate= max_norm/total_norm,所有反向传播的梯度变为原本的rate倍。

这样就避免权重梯度爆炸导致模型训练困难,对于大梯度的缩小,小梯度的不变。

1.4 梯度裁剪存在的问题

  • 参数原本的分布很不均匀,有的梯度大有的梯度小;
  • 而梯度的总体范数值大于阈值,那么所有的梯度都会被同比例缩小。

2 clip_grad_norm_使用

import torch

# 构造两个Tensor
x = torch.tensor([102.0, 155.0], requires_grad=True)
y = torch.tensor([201.0, 221.0], requires_grad=True)

# 模拟网络计算过程
z = x ** 3 + y ** 4
z = z.sum()

# 反向传播
z.backward()

# 得到梯度
print(f"gradient of x is:{x.grad}")
print(f"gradient of y is:{y.grad}")

# 梯度裁剪
torch.nn.utils.clip_grad_norm_([x, y], max_norm=200, norm_type=2)


# 再次打印裁剪后的梯度
# 直接修改了原x.grad的值
print("---clip_grad---")
print(f"clip_grad of x is:{x.grad}")
print(f"clip_grad of y is:{y.grad}")

运行结果显示如下:

gradient of x is:tensor([31212., 72075.])
gradient of y is:tensor([32482404., 43175444.])
---clip_grad---
clip_grad of x is:tensor([0.1155, 0.2668])
clip_grad of y is:tensor([120.2386, 159.8205])

上例中可以看出,裁剪后的梯度远小于原来的梯度。一开始变量x的梯度是tensor([31212., 72075.]),就是求zx的偏导,变量y同理。裁剪后的梯度远小于原来的梯度,所以可以缓解梯度爆炸的问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1070586.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OpenCV4(C++)—— 图像阈值化

文章目录 前言一、固定阈值化 —— threshold二、自适应阈值化 —— adaptiveThreshold三、LUT查找表 前言 图像阈值化在许多计算机视觉和图像处理任务中都是一个重要的预处理步骤。在边缘检测过程中,通过将图像转换为二值图像,可以突出图像中的边缘信息…

Qt:多语言支持,构建全面应用程序“

Qt:强大API、简化框架、多语言支持,构建全面应用程序" 强大的API:Qt提供了丰富的API,包括250多个C类,基于模板的集合、序列化、文件操作、IO设备、目录管理、日期/时间等功能。还包括正则表达式处理和支持2D/3D…

制作长图海报的详细指南,制作长图海报的5个步骤

制作长图海报是宣传活动、产品或服务的重要方式之一。乔拓云后台提供了丰富的海报模板,让你轻松制作出专业级的长图海报。下面将介绍如何使用乔拓云后台制作长图海报的技巧。 一、选择模板 首先,注册并登录乔拓云后台,进入云设计页面。在选择…

C语言学生成绩录入系统

一、系统概述 该系统是一个由链表创建主菜单的框架,旨在快速创建学生成绩录入系统的主菜单结构。其主要任务包括: 实现链表的创建、插入和遍历功能,用于存储和展示学生成绩录入系统各个模块的菜单项。 2. 提供用户友好的主菜单界面&#xf…

证件照换底色详细教程

说到证件照的底色更改,我想对大部分朋友来说是蛮头疼的事情,由于我们不论是在生活还是学习中,有时候总会要上传一些证件照,而当你手上有证件照准备上传时,发现底色不对,是不是很抓狂,现在&#…

SpringCloud学习笔记-Eureka的服务拉取

假设是OrderService里面拉取Eureka的服务之一User Service 1.依然需要在该服务里面引入依赖 <dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-netflix-eureka-client</artifactId> </dependenc…

水波纹文字效果动画

效果展示 CSS 知识点 text-shadow 属性绘制立体文字clip-path 属性来绘制水波纹 工具网站 CSS clip-path maker 效果编辑器 页面整体结构实现 使用多个 H2 标签来实现水波纹的效果实现&#xff0c;然后使用clip-path结合动画属性一起来进行波浪的起伏动画实现。 <div …

分析性质+dp计数:1007T4

http://cplusoj.com/d/senior/p/SS231007D 分析题目性质&#xff0c;有&#xff1a; 按编号顺序最短路必然为连续段边只会在连续段内和相邻连续段之间连 i i i 段 连 i 1 i1 i1 段&#xff0c; i 1 i1 i1 段中每个点恰有1条来自 i i i 的边 然后肯定是考虑 f ( l , r )…

Edge浏览器下载文件被保存为 .crdownload 文件的问题小记

问题 近期使用Edge浏览器下载文件时&#xff0c;文件都被保存为 .crdownload 格式的文件了&#xff0c;不确定从哪个版本开始的。除非下载未完成导致文件不完整&#xff0c;否则不会被保存为 .crdownload 格式的文件&#xff1b;实际上文件已完成了下载&#xff0c;且手工修改…

Day4:Linux系统编程1-60P

我的学习方法是&#xff1a;Linux系统编程&#xff08;看pdf笔记&#xff09; Linux网络编程 WebServer 01P-17P Linux相关命令及操作 cp -a dirname1 dirname2 复制目录 cp -r dirname1 dirname2 递归复制目录 1 到目录 2 这里-a 和-r 的差别在于&#xff0c;-a 是完全复制…

深入了解 GMP

视频链接地址:Golang深入理解GPM模型_哔哩哔哩_bilibili 一、Golang“调度器”的由来? (1) 单进程时代不需要调度器 我们知道,一切的软件都是跑在操作系统上,真正用来干活(计算)的是CPU。早期的操作系统每个程序就是一个进程,直到一个程序运行完,才能进行下一个进程,就是…

大数据Doris(七):Doris安装与部署规划

文章目录 Doris安装与部署规划 一、软硬件需求 二、​​​​​​​资源规划

[Java] 服务端消息推送汇总

前言&#xff1a;当构建实时消息推送功能时&#xff0c;选择适合的方案对于开发高效的实时应用至关重要。消息的推送无非就推、拉两种数据模型。本文将介绍四种常见的消息实时推送方案&#xff1a;短轮询&#xff08;拉&#xff09;、长轮训&#xff08;拉&#xff09;、SSE&am…

详解C语言指针(二)

文章目录 1. 字符指针2. 指针数组3. 数组指针3.1 什么是数组指针&#xff1f;3.2 &数组名 VS 数组名 4. 数组参数4.1 一维数组传参4.2 二维数组传参 5. 函数指针6. 函数指针数组7. 指向函数指针数组的指针8. 回调函数 1. 字符指针 字符指针是指针类型的变量&#xff0c;其…

文本自动输入/删除的加载动画效果

效果展示 CSS 知识点 绕矩形四周跑的光柱动画实现animation 属性的 steps 属性值运用 页面基础结构实现 <div class"loader"><!-- span 标签是围绕矩形四周的光柱 --><span></span><span></span><span></span>&l…

Git 学习笔记 | 使用码云

Git 学习笔记 | 使用码云 Git 学习笔记 | 使用码云注册登录码云&#xff0c;完善个人信息设置本机绑定SSH公钥&#xff0c;实现免密码登录创建远程仓库 Git 学习笔记 | 使用码云 注册登录码云&#xff0c;完善个人信息 网址&#xff1a;https://gitee.com/ 可以使用微信&…

SpringBoot结合dev-tool 实现IDEA项目热部署

什么是热部署&#xff1f; 应用正在运行的时候升级功能, 不需要重新启动应用对于Java应用程序来说, 热部署就是在运行时更新Java类文件 通俗的来讲&#xff0c;应用在运行状态下&#xff0c;修改项目源码后&#xff0c;不用重启应用&#xff0c;会把编译的内容部署到服务器上…

【Acwing1010】拦截导弹(LIS+贪心)题解

题目描述 思路分析 本题有两问&#xff0c;第一问直接用lis的模板即可&#xff0c;下面重点看第二问 思路是贪心&#xff1a; 贪心流程&#xff1a; 从前往后扫描每一个数&#xff0c;对于每个数&#xff1a; 情况一&#xff1a;如果现有的子序列的结尾都小于当前的数&…

stm32的GPIO寄存器操作以及GPIO外部中断,串口中断

一、学习参考资料 &#xff08;1&#xff09;正点原子的寄存器源码。 &#xff08;2&#xff09;STM32F103最小系统板开发指南-寄存器版本_V1.1&#xff08;正点&#xff09; &#xff08;3&#xff09;STM32F103最小系统板开发指南-库函数版本_V1.1&#xff08;正点&a…

【重拾C语言】七、指针(一)指针与变量、指针操作、指向指针的指针

目录 前言 七、指针 7.1 指针与变量 7.1.1 指针类型和指针变量 7.1.2 指针所指变量 7.1.3 空指针、无效指针 7.2 指针操作 7.2.1 指针的算术运算 7.2.2 指针的比较 7.2.3 指针的递增和递减 7.3 指向指针的指针 前言 指针是C语言中一个重要的概念正确灵活运用指针 可…