PyTorch: 权值初始化

news2024/9/22 1:41:34

文章和代码已经归档至【Github仓库:https://github.com/timerring/dive-into-AI 】或者公众号【AIShareLab】回复 pytorch教程 也可获取。

文章目录

  • Pytorch:权值初始化
    • 梯度消失与梯度爆炸
  • Xavier 方法与 Kaiming 方法
    • Xavier 方法
    • nn.init.calculate_gain()
    • Kaiming 方法
  • 常用初始化方法

Pytorch:权值初始化

在搭建好网络模型之后,首先需要对网络模型中的权值进行初始化。权值初始化的作用有很多,通常,一个好的权值初始化将会加快模型的收敛,而比较差的权值初始化将会引发梯度爆炸或者梯度消失。下面将具体解释其中的原因:

梯度消失与梯度爆炸

考虑一个 3 层的全连接网络。

H 1 = X × W 1 H_{1}=X \times W_{1} H1=X×W1 H 2 = H 1 × W 2 H_{2}=H_{1} \times W_{2} H2=H1×W2 O u t = H 2 × W 3 Out=H_{2} \times W_{3} Out=H2×W3,如下图所示,

image-20221128213547257

其中第 2 层的权重梯度如下:

H 2 = H 1 ∗   W 2 Δ W 2 = ∂  Loss  ∂ W 2 = ∂ L o s s ∂  out  ∗ ∂  out  ∂ H 2 ∗ ∂ H 2 ∂ W 2 = ∂  Loss  ∂  out  ∗ ∂  out  ∂ H 2 ∗ H 1 \begin{array}{l} \mathrm{H}_{2}=\mathrm{H}_{1} * \mathrm{~W}_{\mathbf{2}} \\ \Delta \mathrm{W}_{\mathbf{2}}=\frac{\partial \text { Loss }}{\partial \mathrm{W}_{2}}=\frac{\partial \mathrm{Loss}}{\partial \text { out }} * \frac{\partial \text { out }}{\partial \mathrm{H}_{2}} * \frac{\partial \mathrm{H}_{2}}{\partial \mathrm{W}_{2}} \\ =\frac{\partial \text { Loss }}{\partial \text { out }} * \frac{\partial \text { out }}{\partial \mathrm{H}_{2}} * \mathrm{H}_{1} \\ \end{array} H2=H1 W2ΔW2=W2 Loss = out LossH2 out W2H2= out  Loss H2 out H1

由上式化简可知,如果H_1发生以下变化,那么对应的梯度也就会发生变化:

  • 梯度消失: H 1 → 0 ⇒ Δ W 2 → 0 \mathrm{H}_{1} \rightarrow 0 \Rightarrow \Delta \mathrm{W}_{2} \rightarrow 0 H10ΔW20
  • 梯度爆炸: $\mathrm{H}{1} \rightarrow \infty \Rightarrow \Delta \mathrm{W}{2} \rightarrow \infty $

因此,为了避免以上两种情况,就必须严格控制网络层输出的数值范围。

具体可以通过构建 100 层全连接网络,先不使用非线性激活函数,每层的权重初始化为服从 N ( 0 , 1 ) N(0,1) N(0,1) 的正态分布,输出数据使用随机初始化的数据,这样的例子来直观地感受影响:

import torch
import torch.nn as nn
from common_tools import set_seed

set_seed(1)  # 设置随机种子

class MLP(nn.Module):
    def __init__(self, neural_num, layers):
        super(MLP, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
        self.neural_num = neural_num

    def forward(self, x):
        for (i, linear) in enumerate(self.linears):
            x = linear(x)
        return x

    def initialize(self):
        for m in self.modules():
            # 判断这一层是否为线性层,如果为线性层则初始化权值
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data)    # normal: mean=0, std=1
# 网络的层数
layer_nums = 100
# 神经元的个数
neural_nums = 256
batch_size = 16

net = MLP(neural_nums, layer_nums)
net.initialize()
# 设置随机初始化的输入
inputs = torch.randn((batch_size, neural_nums))  # normal: mean=0, std=1

output = net(inputs)
print(output)

输出为:

tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<MmBackward>)

通过输出可知,输出值均为nan,即非数字类型,原因可能是数据太大(梯度爆炸)或者太小(梯度消失)。

为了具体知道是在哪一层开始出现nan的,我们可以在forward函数中添加判断得知,查看每一次前向转播的标准差是否是nan,若是,则停止前向传播并输出。

这里判断是否为nan时采用了 torch.isnan 函数

    def forward(self, x):
        for (i, linear) in enumerate(self.linears):
            x = linear(x)

            print("layer:{}, std:{}".format(i, x.std()))
            if torch.isnan(x.std()):
                print("output is nan in {} layers".format(i))
                break

        return x

输出如下:

layer:0, std:15.959932327270508
layer:1, std:256.6237487792969
layer:2, std:4107.24560546875
.
.
.
layer:29, std:1.322983152787379e+36
layer:30, std:2.0786820453988485e+37
layer:31, std:nan
output is nan in 31 layers

可见,之际上输出的标准差是逐层递增的,具体为什么会导致这种情况:

  • E ( X × Y ) = E ( X ) × E ( Y ) E(X \times Y)=E(X) \times E(Y) E(X×Y)=E(X)×E(Y):两个相互独立的随机变量的乘积的期望等于它们的期望的乘积。
  • D ( X ) = E ( X 2 ) − [ E ( X ) ] 2 D(X)=E(X^{2}) - [E(X)]^{2} D(X)=E(X2)[E(X)]2:一个随机变量的方差等于它的平方的期望减去期望的平方
  • D ( X + Y ) = D ( X ) + D ( Y ) D(X+Y)=D(X)+D(Y) D(X+Y)=D(X)+D(Y):两个相互独立的随机变量之和的方差等于它们的方差的和。

可以推导出两个随机变量的乘积的方差如下:

D ( X × Y ) = E [ ( X Y ) 2 ] − [ E ( X Y ) ] 2 = D ( X ) × D ( Y ) + D ( X ) × [ E ( Y ) ] 2 + D ( Y ) × [ E ( X ) ] 2 D(X \times Y)=E[(XY)^{2}] - [E(XY)]^{2}=D(X) \times D(Y) + D(X) \times [E(Y)]^{2} + D(Y) \times [E(X)]^{2} D(X×Y)=E[(XY)2][E(XY)]2=D(X)×D(Y)+D(X)×[E(Y)]2+D(Y)×[E(X)]2

又由于输入变量是符合标准的正态分布的,因此 E ( X ) = 0 E(X)=0 E(X)=0 E ( Y ) = 0 E(Y)=0 E(Y)=0,可知 D ( X × Y ) = D ( X ) × D ( Y ) D(X \times Y)=D(X) \times D(Y) D(X×Y)=D(X)×D(Y)

我们以输入层第一个神经元为例:

H 11 = ∑ i = 0 n X i ∗ W 1 i \mathrm{H}_{11}=\sum_{i=0}^{n} X_{i} * W_{1 i} H11=i=0nXiW1i

其中输入 X 和权值 W 都是服从 N ( 0 , 1 ) N(0,1) N(0,1) 的正态分布,且由公式 D ( X × Y ) = D ( X ) × D ( Y ) D(X \times Y)=D(X) \times D(Y) D(X×Y)=D(X)×D(Y), 因此这个神经元的方差为:

D ( H 11 ) = ∑ i = 0 n D ( X i ) ∗ D ( W 1 i ) = n ∗ ( 1 ∗ 1 ) = n \begin{aligned} \mathbf{D}\left(\mathrm{H}_{11}\right) &=\sum_{i=0}^{n} D\left(X_{i}\right) * D\left(W_{1 i}\right) \\ &=n *(1 * 1) \\ &=n \end{aligned} D(H11)=i=0nD(Xi)D(W1i)=n(11)=n

可以求其标准差: std ⁡ ( H 11 ) = D ( H 11 ) = n \operatorname{std}\left(\mathrm{H}_{11}\right)=\sqrt{\mathrm{D}\left(\mathrm{H}_{11}\right)}=\sqrt{n} std(H11)=D(H11) =n

可见,经过第一层网络,方差就会扩大 n 倍,标准差就扩大 n \sqrt{n} n 倍,n 为每层神经元个数,直到超出数值表示范围。

从前面的输出中也可以看出来,n = 256,因此每一层的标准差输出都是16倍。再由公式可知,每一层网络输出的方差与神经元个数、输入数据的方差、权值方差有关(见上式),通过观察可知,比较好改变的是权值的方差 D ( W ) D(W) D(W),要控制每一层输出的方差仍然为1左右,因此需要 D ( W ) = 1 n D(W)= \frac{1}{n} D(W)=n1,可知标准差为 s t d ( W ) = 1 n std(W)=\sqrt\frac{1}{n} std(W)=n1 。因此修改权值初始化代码为nn.init.normal_(m.weight.data, std=np.sqrt(1/self.neural_num))

再次输出时,结果如下:

layer:0, std:0.9974957704544067
layer:1, std:1.0024365186691284
layer:2, std:1.002745509147644
.
.
.
layer:94, std:1.031973123550415
layer:95, std:1.0413124561309814
layer:96, std:1.0817031860351562

修改之后,没有出现梯度消失或者梯度爆炸的情况,每层神经元输出的方差均在 1 左右。通过恰当的权值初始化,可以保持权值在更新过程中维持在一定范围之内

但是上述的实验前提为未使用非线性函数的前提下,如果在forward()中添加非线性变换例如tanh,每一层的输出方差会越来越小,会导致梯度消失。

为了解决这个问题,进一步有了著名的 Xavier 初始化与 Kaiming 初始化。

Xavier 方法与 Kaiming 方法

Xavier 方法

Xavier 是 2010 年提出的,针对有非线性激活函数时的权值初始化方法。

  • 目标是保持数据的方差维持在 1 左右
  • 针对饱和激活函数如 sigmoid 和 tanh 等。

同时考虑前向传播和反向传播,需要满足两个等式

n i ∗ D ( W ) = 1 n i + 1 ∗ D ( W ) = 1 \begin{array}{l} \boldsymbol{n}_{\boldsymbol{i}} * \boldsymbol{D}(\boldsymbol{W})=\mathbf{1} \\ \boldsymbol{n}_{\boldsymbol{i}+\mathbf{1}} * \boldsymbol{D}(\boldsymbol{W})=\mathbf{1} \\ \end{array} niD(W)=1ni+1D(W)=1

通过计算可知: D ( W ) = 2 n i + n i + 1 D(W)=\frac{2}{n_{i}+n_{i+1}} D(W)=ni+ni+12

为了使 Xavier 方法初始化的权值服从均匀分布,假设 W W W 服从均匀分布 U [ − a , a ] U[-a, a] U[a,a],那么方差 D ( W ) = ( − a − a ) 2 12 = ( 2 a ) 2 12 = a 2 3 D(W)=\frac{(-a-a)^{2}}{12}=\frac{(2 a){2}}{12}=\frac{a{2}}{3} D(W)=12(aa)2=12(2a)2=3a2,令 2 n i + n i + 1 = a 2 3 \frac{2}{n_{i}+n_{i+1}}=\frac{a^{2}}{3} ni+ni+12=3a2,解得: a = 6 n i + n i + 1 \boldsymbol{a}=\frac{\sqrt{6}}{\sqrt{n_{i}+n_{i+1}}} a=ni+ni+1 6 ,所以 W W W 服从分布 U [ − 6 n i + n i + 1 , 6 n i + n i + 1 ] U\left[-\frac{\sqrt{6}}{\sqrt{n_{i}+n_{i+1}}}, \frac{\sqrt{6}}{\sqrt{n_{i}+n_{i+1}}}\right] U[ni+ni+1 6 ,ni+ni+1 6 ]

所以初始化方法改为:

a = np.sqrt(6 / (self.neural_num + self.neural_num))
# 把 a 变换到 tanh,计算增益
tanh_gain = nn.init.calculate_gain('tanh')
a *= tanh_gain

nn.init.uniform_(m.weight.data, -a, a)

并且每一层的激活函数都使用 tanh,输出如下:

layer:0, std:0.7571136355400085
layer:1, std:0.6924336552619934
layer:2, std:0.6677976846694946
.
.
.
layer:97, std:0.6426210403442383
layer:98, std:0.6407480835914612
layer:99, std:0.6442216038703918

可以看到每层输出的方差都维持在 0.6 左右。

也可以直接调用PyTorch 中 Xavier 初始化方法:

tanh_gain = nn.init.calculate_gain('tanh')
nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain)

nn.init.calculate_gain()

这里重点介绍一下nn.init.calculate_gain(nonlinearity,param=**None**)方法。

主要功能是经过一个分布的方差经过激活函数后的变化尺度,主要有两个参数:

  • nonlinearity:激活函数名称
  • param:激活函数的参数,如 Leaky ReLU 的 negative_slop等等。

下面是计算标准差经过激活函数的变化尺度的代码。

x = torch.randn(10000)
out = torch.tanh(x)
# 计算变化尺度(也可以称为变化倍数)
gain = x.std() / out.std()
print('gain:{}'.format(gain))

tanh_gain = nn.init.calculate_gain('tanh')
print('tanh_gain in PyTorch:', tanh_gain)

输出如下:

gain:1.5982500314712524
tanh_gain in PyTorch: 1.6666666666666667

结果表示,原有数据分布的方差经过 tanh 之后,标准差会变小 1.6 倍左右。

Kaiming 方法

虽然 Xavier 方法提出了针对饱和激活函数的权值初始化方法,但是 AlexNet 出现后,大量网络开始使用非饱和的激活函数如 ReLU 等,这时 Xavier 方法不再适用。2015 年针对 ReLU 及其变种等激活函数提出了 Kaiming 初始化方法。

针对 ReLU,方差应该满足: D ( W ) = 2 n i \mathrm{D}(W)=\frac{2}{n_{i}} D(W)=ni2

针对 ReLu 的变种,方差应该满足: D ( W ) = 2 ( 1 + a 2 ) ∗ n i D(W)=\frac{2}{\left(1+\mathrm{a}^{2}\right) * n_{i}} D(W)=(1+a2)ni2,a 表示负半轴的斜率,如 PReLU 方法,标准差满足 std ⁡ ( W ) = 2 ( 1 + a 2 ) ∗ n i \operatorname{std}(W)=\sqrt{\frac{2}{\left(1+a^{2}\right) * n_{i}}} std(W)=(1+a2)ni2

代码如下:nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num)),或者使用 PyTorch 提供的初始化方法:nn.init.kaiming_normal_(m.weight.data)

常用初始化方法

PyTorch 中提供了 10 中初始化方法

  1. Xavier 均匀分布
  2. Xavier 正态分布
  3. Kaiming 均匀分布
  4. Kaiming 正态分布
  5. 均匀分布
  6. 正态分布
  7. 常数分布
  8. 正交矩阵初始化
  9. 单位矩阵初始化
  10. 稀疏矩阵初始化

综上, 常用初始化的目标就是要保证每一层输出的方差不能太大,也不能太小,维持在一个稳定的范围内。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/759537.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spark3新特性

Spark AQE 自适应查询优化&#xff1a; 实现运行时优化&#xff0c;纠正因统计信息不准确导致生成的逻辑计划不完善或有误的问题 动态调整JOIN策略&#xff1a;类似于mapjoin优化&#xff0c;将sortMergejoin转换成broadcasthashjoin&#xff0c;也就是将小表当作广播变量分发…

基于深度学习的高精度线路板瑕疵目标检测系统(PyTorch+Pyside6+YOLOv5模型)

摘要&#xff1a;基于深度学习的高精度线路板瑕疵目标检测系统可用于日常生活中来检测与定位线路板瑕疵目标&#xff0c;利用深度学习算法可实现图片、视频、摄像头等方式的线路板瑕疵目标检测识别&#xff0c;另外支持结果可视化与图片或视频检测结果的导出。本系统采用YOLOv5…

SpringBoot+Vue实现的高校图书馆管理系统

项目描述&#xff1a;这是一个基于SpringBootVue框架开发的高校图书馆管理系统。首先&#xff0c;这是一个前后端分离的项目&#xff0c;代码简洁规范&#xff0c;注释说明详细&#xff0c;易于理解和学习。其次&#xff0c;这项目功能丰富&#xff0c;具有一个高校图书馆管理系…

外包软件定制开发中知识保护和安全性问题及解决方案

引言 外包软件定制开发在当今的商业环境中越来越普遍&#xff0c;它为企业提供了灵活性和成本效益。然而&#xff0c;与外包合作也带来了一些风险&#xff0c;其中最重要的就是知识保护和安全性问题。在外包软件定制开发过程中&#xff0c;共享敏感信息和知识产权是不可避免的…

redis基本操作

string数据类型的命令操作 设置键值 使用append 命令设置键值&#xff0c;后面跟键的名字&#xff0c;可以先判断该建是否存在&#xff0c;存在将值追加在后面&#xff0c;不存在自动添加该建 append mykey hello读取键值 get mykey数值类型自减1 数值类型自加1 查看值的…

记录C#知识点(二)21-40

目录 21.性能优化 22.动态dynamic使用 23.中文乱码 24.启动项目之前&#xff0c;执行文件 25.深拷贝-反射实现 26.丢弃运算符 _ 27.winform程序使用管理员运行 28.wpf程序使用管理员运行 21.性能优化 1.检查空字符串&#xff1a;使用string.Empty 2.更改类型转换&…

Java设计模式之行为型-访问者模式(UML类图+案例分析)

目录 一、基础概念 二、UML类图 三、角色设计 四、案例分析 五、总结 一、基础概念 访问者模式是一种对象行为型设计模式&#xff0c;它能够在不修改已有对象结构的前提下&#xff0c;为对象结构中的每个对象提供新的操作。 访问者模式的主要作用是把对元素对象的操作抽…

进程通信与信号

1.管道 匿名管道&#xff1a;匿名管道用于进程间通信&#xff0c;且仅限于本地父子进程之间的通信 管道符号 | 进程间通信的本质就是&#xff0c;让不同的进程看到同一份资源&#xff0c;使用匿名管道实现父子进程间通信的原理就是&#xff0c;让两个父子进程先看到同一份被打…

【云原生】Docker跨主机网络Overlay与Macvlan的区别

跨主机网络通信解决方案 docker原生的overlay和macvlan 第三方的flannel&#xff0c;weave&#xff0c;calico 1.overlay网络 在Docker中&#xff0c;Overlay网络是一种容器网络驱动程序&#xff0c;它允许在多个Docker主机上创建一个虚拟网络&#xff0c;使得容器可以通过这…

Python 最优传输工具箱(Python Optimal Transport)

最近在研究最优传输的相关理论&#xff0c;博主使用的是python编程语言&#xff0c;在这里给大家推荐一个Python最优传输工具箱&#xff1a;Python Optimal Transport&#xff08;pot)与geomloss 其中geomloss是针对pytorch张量的&#xff0c;ot是针对numpy数组的&#xff1b;g…

装饰器模式揭秘:我用装饰器给手机集成了ChatGPT

在平时的开发过程中&#xff0c;我们经常会遇到需要给一个类增加额外功能的需求&#xff0c;但又不想破坏类的原有结构。这时候&#xff0c;装饰器模式就能大显神威了&#xff01;接下来&#xff0c;我将带你深入了解装饰器模式的原理、优缺点、适用场景以及如何在实际开发中巧…

无法找到docker.sock

os环境&#xff1a;麒麟v10(申威) 问题描述&#xff1a; systemctl start docker 然后无法使用docker [rootnode2 ~]# systemctl restart docker [rootnode2 ~]# docker ps Cannot connect to the Docker daemon at unix:///var/run/docker.sock. Is the docker daemon r…

4、应用层https27

https协议加密流程&#xff1a;使用ssl加密。 一、HTTPS协议 对HTTP协议进行加密后的一个新的协议。 1、加密概念 单说数据加密过去狭义&#xff0c;更多的是防止数据被监听劫持。 加密包含俩个方面&#xff1a;身份验证&#xff0c;加密传输。 1.1身份验证 验证对端的身…

四、传播

文章目录 1、草药迷阵问题2、时序回溯搜索3、传播搜索THE END 1、草药迷阵问题 \qquad 有一个10*10的百草药柜&#xff0c;每一个抽屉里都有5种不同属性的草药&#xff0c;依次打开抽屉来长出草药迷阵&#xff0c;要求寻找一种神奇的药方&#xff0c;满足&#xff1a; 横行&am…

数据结构——C++无锁队列

数据结构——C无锁队列 贺志国 2023.7.11 上一篇博客给出了最简单的C数据结构——堆栈的几种无锁实现方法。队列的挑战与栈的有些不同&#xff0c;因为Push()和Pop()函数在队列中操作的不是同一个地方。因此同步的需求就不一样。需要保证对一端的修改是正确的&#xff0c;且对…

(中等)LeetCode 3. 无重复字符到的最长子串 Java

滑动窗口 以示例一为例&#xff0c;找出从每一个字符开始的&#xff0c;不包含重复字符的最长子串&#xff0c;那么&#xff0c;其中最长的那个字符串即为答案。 当我们一次递增地枚举子串的起始位置&#xff0c;会发现子串的结束位置也是递增的&#xff0c;原因在于&#xf…

Django项目创建

Django项目创建 文章目录 Django项目创建&#x1f468;‍&#x1f3eb;方式一&#xff1a;终端命令行方式&#x1f468;‍&#x1f52c;方式二&#xff1a;Pycharm创建 &#x1f468;‍&#x1f3eb;方式一&#xff1a;终端命令行方式 1️⃣cmd打开终端&#xff0c;切换到指定目…

WebSell管理工具--中国蚁剑安装教程以及初始化

简介&#xff1a;中国蚁剑是一款开源的跨平台WebShell网站管理工具 蚁剑的下载安装&#xff1a; GitHub项目地址&#xff1a;https://github.com/AntSwordProject/ Windows下载安装&#xff1a; 百度网盘下载链接&#xff1a;链接&#xff1a;https://pan.baidu.com/s/1A5wK…

超细整理,性能测试-性能指标监控命令详细实战,一篇速通

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 性能监控命令&…