【Pytorch】梯度裁剪——torch.nn.utils.clip_grad_norm_的原理及计算过程

news2025/1/13 19:47:18

文章目录

  • 一、torch.nn.utils.clip_grad_norm_
  • 二、计算过程
  • 三、确定max_norm


众所周知,梯度裁剪是为了防止梯度爆炸。在训练FCOS算法时,因为训练过程出现了损失为NaN的情况,在github issue有很多都是这种训练过程出现loss为NaN,作者也提出要调整梯度裁剪的超参数,于是理了理梯度裁剪函数torch.nn.utils.clip_grad_norm_ 的计算过程,方便调参。


一、torch.nn.utils.clip_grad_norm_

torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type) ,这个梯度裁剪函数一般来说只需要调整max_normnorm_type这两个参数。
parameters参数是需要进行梯度裁剪的参数列表。通常是模型的参数列表,即model.parameters()
max_norm参数可以理解为梯度(默认是L2 范数)范数的最大阈值
norm_type参数可以理解为指定范数的类型,比如norm_type=1 表示使用L1 范数,norm_type=2 表示使用L2 范数。
同时,torch.nn.utils.clip_grad_norm_torch.nn.utils.clip_grad_norm(该函数已被弃用)的区别就是前者是直接修改原Tensor,后者不会(在Pytorch中有很多这样的函数对均是如此,在函数最后多了下划线一般都是表示直接在原Tensor上进行操作)。

import torch

# 构造两个Tensor
x = torch.tensor([99.0, 108.0], requires_grad=True)
y = torch.tensor([45.0, 75.0], requires_grad=True)

# 模拟网络计算过程
z = x ** 2 + y ** 3
z = z.sum()

# 反向传播
z.backward()

# 得到梯度
print(f"gradient of x is:{x.grad}") 
print(f"gradient of y is:{y.grad}") 

# 梯度裁剪
torch.nn.utils.clip_grad_norm_([x, y], max_norm=100, norm_type=2)


# 再次打印裁剪后的梯度
# 直接修改了原x.grad的值
print("---clip_grad---")
print(f"clip_grad of x is:{x.grad}") 
print(f"clip_grad of y is:{y.grad}") 


# 输出如下
"""
gradient of x is:tensor([198., 216.])
gradient of y is:tensor([ 6075., 16875.])
---clip_grad---
clip_grad of x is:tensor([1.1038, 1.2042])
clip_grad of y is:tensor([33.8674, 94.0762])
"""

上例中可以看出,裁剪后的梯度远小于原来的梯度。一开始变量x的梯度是tensor([198., 216.]),这个很好计算,就是求zx的偏导,也就是2*x 。变量y同理。裁剪后的梯度远小于原来的梯度,所以可以缓解梯度爆炸的问题。

二、计算过程

梯度裁剪的计算过程参考源码是不难的,
SOURCE CODE FOR TORCH.NN.UTILS.CLIP_GRAD
结合代码转换成数学公式,计算过程如下:
第一步:依然以上面的代码为例,构造Tensor反向传播,得到参数xy 的梯度,也就是torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type) 中的parameters参数。
在这里插入图片描述

import torch

# 构造两个Tensor
x = torch.tensor([99.0, 108.0], requires_grad=True)
y = torch.tensor([45.0, 75.0], requires_grad=True)

# 模拟网络计算过程
z = x ** 2 + y ** 3
z = z.sum()

# 反向传播
z.backward()

# 得到梯度
print(f"gradient of x is:{x.grad}") 
print(f"gradient of y is:{y.grad}") 

# 输出
"""
gradient of x is:tensor([198., 216.])
gradient of y is:tensor([ 6075., 16875.])
"""

第二步:计算每个变量梯度的L2 范数(以L2 范数为例)
在这里插入图片描述

这段代码的意思就是定义一个空列表norms ,用来存储每个参数梯度的L2 范数。随后利用torch.stacknorms中的所有Tensor(计算所得的L2 范数)合并成一个Tensor,最后又再求合并后Tensor的L2 范数得到total_norm(总范数)。

norm_type是inf即无穷范数时,total_norm会直接取参数梯度最大的那一个

# 相当于把x_L2norm、y_L2norm放入代码中的norms空列表

# x的梯度的L2 范数
x_L2norm = torch.sum(x.grad ** 2) ** 0.5
# y的梯度的L2 范数
y_L2norm = torch.sum(y.grad ** 2) ** 0.5

# 相当于遍历norms列表合并成一个Tensor
total_norm = torch.sum(torch.stack([x_L2norm, y_L2norm]) ** 2) ** 0.5

"""
等价过程
x_L2norm = sum([198 ** 2, 216 ** 2]) ** 0.5
y_L2norm = sum([6075 ** 2, 16875 ** 2]) ** 0.5
total_norm = sum([x_L2norm ** 2, y_L2norm ** 2]) ** 0.5
"""

第三步:计算梯度裁剪系数
在这里插入图片描述

# 1e-6防止分母为0
# clip_coef = max_norm / (total_norm + 1e-6)
max_norm = 100
clip_coef = max_norm / total_norm 

第四步:将原始梯度乘以梯度裁剪系数得到裁剪后的梯度,这与函数计算的结果是一致的。
在这里插入图片描述

print(f"clip_grad of x is: is {x.grad * clip_coef }")
print(f"clip_grad of x is: is {y.grad * clip_coef }")

# 输出
"""
clip_grad of x is: is tensor([1.1038, 1.2042])
clip_grad of x is: is tensor([33.8674, 94.0762])
"""

整合一下代码:

import torch

# 构造两个Tensor
x = torch.tensor([99.0, 108.0], requires_grad=True)
y = torch.tensor([45.0, 75.0], requires_grad=True)

# 模拟网络计算过程
z = x ** 2 + y ** 3
z = z.sum()

# 反向传播
z.backward()

# 得到梯度
print(f"gradient of x is:{x.grad}") 
print(f"gradient of y is:{y.grad}") 

x_L2norm = torch.sum(x.grad ** 2) ** 0.5
y_L2norm = torch.sum(y.grad ** 2) ** 0.5
total_norm = torch.sum(torch.stack([x_L2norm, y_L2norm]) ** 2) ** 0.5

max_norm = 100
clip_coef = max_norm / total_norm 

print(f"clip_grad of x is: is {x.grad * clip_coef }")
print(f"clip_grad of x is: is {y.grad * clip_coef }")

# 输出如下
"""
gradient of x is:tensor([198., 216.])
gradient of y is:tensor([ 6075., 16875.])
clip_grad of x is:tensor([1.1038, 1.2042])
clip_grad of y is:tensor([33.8674, 94.0762])
"""

三、确定max_norm

根据上述计算过程,梯度裁剪最主要的就是计算出裁剪系数得出裁剪后的梯度。clip_coef = max_norm / total_norm 公式中,clip_coef 越小,裁剪的梯度越大。
max_norm越小,裁剪的梯度越大,得到的梯度就越小,防止梯度爆炸的效果越明显。

在训练模型时,我们可以根据total_norm的值大概确定max_norm的一个取值范围。调用torch.nn.utils.clip_grad_norm_([x, y], max_norm=100, norm_type=2)函数时,该函数会返回total_norm的值。比如最开始参数x、y,这时的total_norm17937.5879,值非常大,那么为了防止梯度爆炸我们就可以把max_norm设置得稍微小一些。

# 以最开始的x、y为例
total_norm = torch.nn.utils.clip_grad_norm_([x, y], max_norm=100, norm_type=2)
print(total_norm)

#输出
# tensor(17937.5879) 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/687328.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RISCV Reader笔记_3 RISCV汇编

RISC-V 汇编语言 函数调用的步骤在计算机组成与设计中也有过涉及: 指定寄存器存入参数;跳转到函数开始位置(jal);在callee中按需保存寄存器;执行函数;恢复保存的寄存器;把返回值存入…

使用传统图像处理算法+机器学习进行shadow detection

前言 阴影是图像中常见的现象,它们对于场景理解和分析非常重要。由于阴影区域通常比较暗淡,而且与周围物体区别较大,因此在图像处理和计算机视觉领域中,阴影检测是一个重要的研究方向。传统的阴影检测算法通常基于阈值或边缘检测…

深入理解 kernel panic 的流程

我们在项目开发过程中,很多时候会出现由于某种原因经常会导致手机系统死机重启的情况(重启分Android重启跟kernel重启,而我们这里只讨论kernel重启也就是 kernel panic 的情况),死机重启基本算是影响最严重的系统问题了…

180_Power BI 新卡片图计算组与同环比应用

180_Power BI 新卡片图计算组与同环比应用 一、背景 在 2023 年 6 月,Power BI 更新了新的视觉对象:Card(new) 。 当前还需要在预览功能中将其打开。 我们在实际的应用中将新卡片图做了一些应用,先来看看具体效果。 Power BI 公共 web 效果…

安全区域内活动UWB标签,高精度UWB定位监测,室内厘米级测距应用

随着人们对于室内安全和定位需求的增加,相应的技术应运而生,超宽带(UWB)标签定位技术应用于室内定位领域,并获得了快速的发展和应用。 UWB技术是一种基于极窄脉冲的无线技术,它的主要特点是无载波&#xf…

软件测试技能,JMeter压力测试教程,setUp线程组批量登录(九)

前言 前面一篇已经实现了在 setUp 线程组实现单个用户先登录后提取token给其它线程组使用,在压测的时候,单个用户登录很显然不能满足我们的压测需求 我们在压测接口的时候,需批量获取多个用户登录后返回的token值,那么在setUp 线…

RabbitMQ消息队列高级特性

文章目录 1.消息的可靠投递2.ConSumer ACK消费者确认接收消息3.消费者限流4.TTL过期时间5.死信队列6.延迟队列7.日志与监控8.消息追踪 1.消息的可靠投递 在线上生产环境中,RabbitMQ可能会产生消息丢失或者是投递失败的一个场景,RabbitMQ为了避免这种场景…

Redis慢查询分析

慢查询分析 谓慢查询日志就是系统在命令执行前后计算每条命令的执行时间,当超过预设阀值,就将这条命令的相关信息(例如:发生时间,耗时,命令的详细信息)记录下来。 执行一条命令分为如下4个部分…

【每日算法 数据结构(C++)】—— 05 | 判断单链表是否有环(解题思路、流程图、代码片段)

文章目录 01 | 👑 题目描述02 | 🔋 解题思路03 | 🧢 代码片段 The future belongs to those who believe in the beauty of their dreams. 未来属于那些相信梦想之美的人 01 | 👑 题目描述 给你一个单链表,请判断其中是…

Mac(M1)上安装Ubuntu虚拟机

Mac(M1)上安装Ubuntu虚拟机 0.下载资料汇总 VMware Fusionhttps://www.vmware.com/products/fusion/fusion-evaluation.htmlubuntu-desktop-arm64.isohttps://cdimage.ubuntu.com/jammy/daily-live/current/ 1.安装VMware Mac版本的VMware叫 VMware …

黑马程序员前端 Vue3 小兔鲜电商项目——(十一)支付页

文章目录 基础数据渲染封装接口数据渲染 支付功能实现支付携带参数支付宝沙箱账号信息 支付结果页展示模版代码绑定路由渲染数据 倒计时逻辑函数封装 支付页有俩个关键数据,一个是要支付的钱数,一个是倒计时数据(超时不支付商品释放)。 基础…

Tomcat项目更新Tomcat版本,重新配置conf,并在Idea运行项目,服务器替换SSL证书

Tomcat项目更新Tomcat版本,重新配置conf,并在Idea运行项目 1.下载Tomcat包2.Tomcat配置-ssi配置3.Idea运行 Tomcat 项目4.服务器Tomcat替换SSL证书4. Tomcat 项目重启 1.下载Tomcat包 Tomcat 官网 - https://tomcat.apache.org/ jdk1.8.0_191 我选择的…

如何在电脑、手机《酷游链接》录制屏幕?一看就会!也有剪辑录制视频的方法哦!

最近,我的生活中出现了许多需要录制电脑屏幕的场景! 『酷游链kw9㍠N͜E͜T提供』娜娜友善提醒,要自己输入才会显示出来!比如会议,教学等场景。这些场景我们可以通过Windows10的内建软体来解决。另外,也出现…

Python小白应该怎么学习字典

1.Python 字典 字典 字典是一个无序、可变和有索引的集合。在 Python 中,字典用花括号编写,拥有键和值。 例子:创建并打印字典 thisdict {"brand": "Porsche","model": "911","year": 1963 } pr…

MUR20100DC-ASEMI快恢复二极管MUR20100DC

编辑-Z MUR20100DC在TO-263封装里采用的2个芯片,其尺寸都是102MIL,是一款共阴极快恢复对管。MUR20100DC的浪涌电流Ifsm为200A,漏电流(Ir)为10uA,其工作时耐温度范围为-55~150摄氏度。MUR20100DC采用抗冲击硅芯片材质,…

技术管理第二板斧建团队-沟通

一、沟通的核心原则 我认为,沟通是内心想法和思考逻辑的外延,如果你有良好的沟通能力,可以在整个团队中营造公开透明的信任氛围,让信息透明的同时,也让团队成员愿意发出自己的声音。 但实际情况中,很多人…

TypeScript 中对【函数类型】的约束使用解读

概述 函数是JavaScript 中的 一等公民 概念:函数类型的概念是指给函数添加类型注解,本质上就是给函数的参数和返回值添加类型约束 声明式函数: 在 TypeScript 中,一个函数有输入和输出,需要对其进行约束,需要把输入和…

电力载波远程控制系统

随着电力技术的不断发展,电力载波远程控制系统成为了现代电力系统中的重要组成部分。电力载波远程控制系统是一种利用电力载波技术实现远程控制的系统,可以对电力系统中的各种设备进行实时监测、控制和管理,提高电力系统的安全性、可靠性和效…

Efficient Video Transformers with Spatial-Temporal Token Selection阅读笔记

摘要 Video Transformers在主要视频识别基准测试中取得了令人印象深刻的结果,但其计算成本很高。 在本文中,我们提出了 STTS,这是一种令牌选择框架,它根据输入视频样本在时间和空间维度上动态选择一些信息丰富的令牌。 具体来说&…

Qt/C++编写视频监控系统78-视频推流到流媒体服务器

一、前言 视频推流作为独立的模块,目前并没有集成到视频监控系统中,目前是可以搭配监控系统一起使用,一般是将添加好的摄像头通道视频流地址打开后,读取视频流重新推到流媒体服务器,然后第三方可以从流媒体服务器拉取…