深度学习模型数值稳定性——梯度衰减和梯度爆炸的说明

news2024/11/17 16:40:19

文章目录

      • 0. 前言
      • 1. 为什么会出现梯度衰减和梯度爆炸?
      • 2. 如何提高数值稳定性?
        • 2.1 随机初始化模型参数
        • 2.2 梯度裁剪(Gradient Clipping)
        • 2.3 正则化
        • 2.4 Batch Normalization
        • 2.5 LSTM?Short Cut!

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文的主旨是说明深度学习网络模型中关于数值稳定性的常见问题:梯度衰减(vanishing)和爆炸(explosion),以及常见的解决方法。

本文的部分内容、观点及配图借鉴了多伦多大学计算机科学学院讲座——Lecture 15: Exploding and Vanishing Gradients内容,以及Dive into deep learning第3.15章节《数值稳定性和模型初始化》。

1. 为什么会出现梯度衰减和梯度爆炸?

用下面简化的全连接神经元网络讲解,这个全连接神经元网络每层只有一个神经元,可以看作是一串神经元连接而成的网络。

在这里插入图片描述

在前向传播中,由于数值的传递需要经过非线性的激活函数 σ ( ) \sigma() σ()(例如Sigmoid、Tanh函数),其数值大小被限制住了,因此前向传播一般不存在数值稳定性的问题

在反向传播中,例如求解输出 y y y对权重 w 1 w_1 w1的偏导为:
∂ y ∂ w 1 = σ ′ ( z n ) w n ⋅ σ ′ ( z n − 1 ) w n − 1 ⋅ ⋅ ⋅ σ ′ ( z 1 ) x \frac{\partial y}{\partial w_1}=\sigma'(z_n)w_n · \sigma'(z_{n-1})w_{n-1} ··· \sigma'(z_{1})x w1y=σ(zn)wnσ(zn1)wn1⋅⋅⋅σ(z1)x
z n = { w n ⋅ h n − 1 + b n , n > 1 w 1 ⋅ x + b 1 , n = 1 z_n= \left \{\begin{array}{cc} w_n·h_{n-1}+b_n, & n>1\\ w_1·x+b_1, & n=1 \end{array} \right. zn={wnhn1+bn,w1x+b1,n>1n=1
这里就可以看出,如果权重 w n w_n wn的初始选择不合理,或者 w n w_n wn在逐渐优化过程中,出现导致 σ ′ ( z n ) w n \sigma'(z_n)w_n σ(zn)wn大部分或全部大于1或者小于1的情况,且网络足够深,就会导致反向传播的偏导出现数值不稳定——梯度衰减或者梯度爆炸。

再简化点理解,假设 σ ′ ( z n ) w n = 0.8 \sigma'(z_n)w_n=0.8 σ(zn)wn=0.8,有50层网络深度, 0. 8 50 = 0.000014 0.8^{50}=0.000014 0.850=0.000014;假设 σ ′ ( z n ) w n = 1.2 \sigma'(z_n)w_n=1.2 σ(zn)wn=1.2,有50层网络深度, 1. 2 50 = 9100 1.2^{50}=9100 1.250=9100

参考Lecture 15: Exploding and Vanishing Gradients的另一种解释数值稳定性的方法是:深度学习网络类似于非线性方程的迭代使用,例如 f ( x ) = 3.5 x ( 1 − x ) f(x)=3.5x(1-x) f(x)=3.5x(1x)经过多次迭代 y = f ( f ( ⋅ ⋅ ⋅ f ( x ) ) ) y=f(f(···f(x))) y=f(f(⋅⋅⋅f(x)))后的情况如下图:
在这里插入图片描述
可见,非线性函数再经历多次迭代后会呈现复杂且混沌的表现,在这个实例中仅经历6次迭代后就出现了偏导很大的情况(对应梯度爆炸)。

我们也应该注意到经历6次迭代后也出现了 ∂ y ∂ x ≈ 0 \frac{\partial y}{\partial x}≈0 xy0的区域(对应梯度衰减)。

2. 如何提高数值稳定性?

2.1 随机初始化模型参数

这是最简单、最常用的对抗梯度衰减和梯度爆炸的方法。上文已经说明: σ ′ ( z n ) w n \sigma'(z_n)w_n σ(zn)wn大部分或全部大于1或者小于1的情况,且网络足够深,就容易发生数值不稳定的情况。如果随机初始化模型参数,就会很大程度上避免因为 w n w_n wn的初始选择不合理导致的梯度衰减或爆炸。

Xavier随机初始化是一种常用的方法:假设某隐藏层输入个数为 a a a,输出个数为 b b b,Xavier随机初始化会将该层中的权重参数随机采样于 ( − 6 a + b , 6 a + b ) (-\sqrt{\frac{6}{a+b}},\sqrt{\frac{6}{a+b}}) (a+b6 ,a+b6 )

2.2 梯度裁剪(Gradient Clipping)

这是一种人为限制梯度过大或过小的方法,其思路是给原本的梯度 g g g加上一个系数,在 g g g的绝对值过大时对其进行缩小,反之亦然。这个系数为:
η ∣ ∣ g ∣ ∣ \frac{\eta}{||g||} ∣∣g∣∣η

其中 η \eta η为超参数, ∣ ∣ g ∣ ∣ ||g|| ∣∣g∣∣为梯度的二范数。

增加这个系数后虽然会导致这个结果并非是真正的损失函数对于权重的偏导数,但是能够维持数值稳定性。

2.3 正则化

这是一种抑制梯度爆炸的方法。我之前介绍过正则化方法:基于PyTorch实战权重衰减——L2范数正则化方法(附代码),其思想是在损失函数中增加权重的范数作为惩罚项:
l o s s = 1 n Σ ( y − y ^ ) 2 + λ 2 n ∣ ∣ w ∣ ∣ 2 loss = \dfrac{1}{n} \Sigma (y - \widehat{y})^2+ \dfrac{\lambda}{2n}||w||^2 loss=n1Σ(yy )2+2nλ∣∣w2
在深度学习模型不断地迭代(学习)过程中, l o s s loss loss越来越小导致权重的范数也越来越小,也就抑制了梯度爆炸。

2.4 Batch Normalization

Batch Normalization(批标准化)是基于Normalization(归一化)增加scaling和shifting的一种数据标准化处理方式,其具体作用原理可以参考:关于Batch Normalization的说明。

Batch Normalization能维持数值稳定性的基本原理与梯度裁剪类似:都是对数值人为增加缩放,维持数值保持在一个不大不小的合理范围内。两者的区别是梯度裁剪在反向传播过程中直接作用于损失函数对权重的偏导数;而Batch Normalization在正向传播中对某层的输出进行标准化处理,间接维持对权重偏导的稳定性。

这里需要指出的是:由于输入 x x x也参与了偏导的计算,如果 x x x是一个高维向量,那对于输入 x x x的Batch Normalization处理也是必要的。

2.5 LSTM?Short Cut!

很多文章说明LSTM(长短周期记忆)网络有助于维持数值稳定性,我最初看到这些文章时大为不解——因为我们是需要通用的方法来改进提高现有模型的数值稳定性,而不是直接替换成LSTM网络模型,况且LSTM也不是万能的深度学习模型,不可能遇到梯度衰减或者梯度爆炸就把模型替换成LSTM。

如果不知道LSTM是什么可以看下:LSTM(长短期记忆)网络的算法介绍及数学推导

后来我看到Lecture 15: Exploding and Vanishing Gradients明白了其中的误解:这篇文章通篇都在用RNN为例来说明数值稳定性。对于RNN来说,LSTM确实是一个改进的模型,因为其内部维持“长期记忆”的“门”结构确实有助于提升数值稳定性。

我想大部分把LSTM单列出来说明可以提升数值稳定性的文章都误会了。

而Short Cut这种结构才是提升数值稳定性的普适规则,LSTM仅是改善RNN的一个特例而已。
在这里插入图片描述

Short Cut的具体作用机理可以参考He Kaiming的原文:Deep Residual Learning for Image Recognition

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/933185.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

读取SD卡图片bin文件显示LCD上

读取SD卡bin文件显示图片 Coding 环境搭建: 硬件平台:STM32H750XBH6开发环境:STM32CubeMX V6.8.1KEIL V5.28.0.0STM32H750固件版本:package V1.11.0仿真下载驱动:ST-Link 前言:STM32H750XBH6 的flash只…

零基础学习正演的数值模拟(含代码)

摘要: 本贴从零开始学习正演的数值模拟方法. 包括相应的偏微分基础、声波方程、雷克子波、均匀速度场的模拟、一般速度场的模拟. 1. 偏微分基础 本小节仅涉及高等数学相关知识, 与领域无关. 1.1 导数 引例: 物体从一维坐标的原点开始移动, 在 t t t 时刻, 它在坐标轴的位置…

汤普森采样(Thompson sampling): 理论支持

目录 一、UCB与TS算法数学原理1、Upper Confidence Bounds 数学原理2、Thompson sampling 数学原理a、TS 基本数据原理1. beta 分布2. 共轭分布与共轭先验3. 采样的编程实现 b、TS 算法流程1. TS算法基础版本2. Batched Thompson Sampling 二、UCB与TS算法的优缺点1、TS算法的优…

Ubuntu释放VMware虚拟磁盘未使用空间

By: Ailson Jack Date: 2023.08.26 个人博客:http://www.only2fire.com/ 本文在我博客的地址是:http://www.only2fire.com/archives/152.html,排版更好,便于学习,也可以去我博客逛逛,兴许有你想要的内容呢。…

基于Java+SpringBoot+Vue前后端分离医院后台管理系统设计和实现

博主介绍:✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专…

Spring为什么要专门定义BeanDefinition ,有Class不行吗?

前言 创建一个Java Bean,大概是下面这个流程: 我们写的Java文件,会编译为Class文件,运行程序,类加载器会加载Class文件,放入JVM的方法区,我们就可以愉快的new对象了。 创建一个Spring Bean&am…

项目总结知识点记录(二)

1.拦截器实现验证用户是否登录: 拦截器类:实现HandlerInterception package com.yx.interceptor;import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.ModelAndView;import javax.servlet.http.HttpS…

Mybatis+MybatisPlus拦截器实战之数据的加解密和脱敏

文章目录 一、前言二、拦截器简介三、代码目录结构简介四、核心代码讲解4.1 application.yml文件4.2 自定义注解4.2.1 SensitiveEntity4.2.2 SensitiveData4.2.3 MaskedEntity4.2.4 MaskedField4.2.5 MaskedMethod 4.3 Mybatis-Plus 拦截器数据自动加密4.4 Mybatis 打印完整sql…

7年经验之谈 —— 如何实现高效的Web自动化测试?

随着互联网的快速发展,Web应用程序的重要性也日益凸显。为了保证Web应用程序的质量和稳定性,Web自动化测试成为必不可少的一环。然而,如何实现高效的Web自动化测试却是一个值得探讨的课题。 首先,选择合适的测试工具是关键。市面…

低通滤波器和高通滤波器

应用于图像低通滤波器和高通滤波器的实现 需要用到傅里叶变换 #include <opencv2/opencv.hpp> #include <Eigen> #include <iostream> #include <vector> #include <cmath> #include <complex>#define M_PI 3.14159265358979323846…

五、多表查询-3.4连接查询-联合查询union

一、概述 二、演示 【例】将薪资低于5000的员工&#xff0c;和 年龄大于50岁的 员工全部查询出来 1、查询薪资低于5000的员工 2、查询年龄大于50岁的员工 3、将薪资低于5000的员工&#xff0c;和 年龄大于50岁的 员工全部查询出来&#xff08;把上面两部分的结果集直接合并起…

最新敏感信息和目录收集技术

敏感信息和目录收集 目标域名可能存在较多的敏感目录和文件&#xff0c;这些敏感信息很可能存在目录穿越漏洞、文件上传漏洞&#xff0c;攻击者能通过这些漏洞直接下载网站源码。搜集这些信息对之后的渗透环节有帮助。通常&#xff0c;扫描检测方法有手动搜寻和自动工具查找两…

requestAnimationFrame(RAF)

1、RAF介绍 requestAnimateFrame&#xff08;RAF&#xff09;动画帧&#xff0c;是一个做动画的API。 如果想要一个动画流畅&#xff0c;就需要以60帧/s来更新视图&#xff0c;那么一次视图的更新就是16.67ms。 想要达到上述目标&#xff0c;可以通过setTimeout定时器来手动控…

JSON文件教程之【jsoncpp源码编译】

目录 1 数据下载(jsoncpp源码)2 文件编译内容: JSON文件的读取与保存可以使用jsoncpp库来实现,这里介绍该库的下载及编译方法。 1 数据下载(jsoncpp源码) 数据下载:Github地址 图1 github源码示意图 2 文件编译 2.1 点击Download ZIP,下载源码。 图2 压缩包数据 2.2 将压…

在 macOS 中安装 TensorFlow 1g

tensorflow 需要多大空间 pip install tensorflow pip install tensorflow Looking in indexes: https://pypi.douban.com/simple/ Collecting tensorflowDownloading https://pypi.doubanio.com/packages/1a/c1/9c14df0625836af8ba6628585c6d3c3bf8f1e1101cafa2435eb28a7764…

2022年06月 C/C++(四级)真题解析#中国电子学会#全国青少年软件编程等级考试

第1题&#xff1a;公共子序列 我们称序列Z < z1, z2, …, zk >是序列X < x1, x2, …, xm >的子序列当且仅当存在 严格上升 的序列< i1, i2, …, ik >&#xff0c;使得对j 1, 2, … ,k, 有xij zj。比如Z < a, b, f, c > 是X < a, b, c, f, b, …

软考:中级软件设计师:关系代数:中级软件设计师:关系代数,规范化理论函数依赖,它的价值和用途,键,范式,模式分解

软考&#xff1a;中级软件设计师:关系代数 提示&#xff1a;系列被面试官问的问题&#xff0c;我自己当时不会&#xff0c;所以下来自己复盘一下&#xff0c;认真学习和总结&#xff0c;以应对未来更多的可能性 关于互联网大厂的笔试面试&#xff0c;都是需要细心准备的 &…

一篇文章带你彻底了解Java常用的设计模式

文章目录 前言1. 工厂模式使用示例代码优势 2. 单例模式说明使用示例代码优势 3. 原型模式使用示例代码优势 4. 适配器模式使用示例代码优势 5. 观察者模式使用示例代码优势 6. 策略模式使用示例代码优势 7. 装饰者模式使用示例代码优势 8. 模板方法模式使用示例代码优势 总结 …

python-数据可视化-下载数据-CSV文件格式

数据以两种常见格式存储&#xff1a;CSV和JSON CSV文件格式 comma-separated values import csv filename sitka_weather_07-2018_simple.csv with open(filename) as f:reader csv.reader(f)header_row next(reader)print(header_row) # [USW00025333, SITKA AIRPORT, A…

YOLO目标检测——皮肤检测数据集下载分享

数据集点击下载&#xff1a;YOLO皮肤检测数据集Face-Dataset.rar