pytorch06:权重初始化

news2024/9/30 1:32:11

在这里插入图片描述

目录

  • 一、梯度消失和梯度爆炸
    • 1.1相关概念
    • 1.2 代码实现
    • 1.3 实验结果
    • 1.4 方差计算
    • 1.5 标准差计算
    • 1.6 控制网络层输出标准差为1
    • 1.7 带有激活函数的权重初始化
  • 二、Xavier方法与Kaiming方法
    • 2.1 Xavier初始化
    • 2.2 Kaiming初始化
    • 2.3 常见的初始化方法
  • 三、nn.init.calculate_gain

一、梯度消失和梯度爆炸

1.1相关概念

一个简易三层全连接神经网络图和神经元计算如下:
在这里插入图片描述
观察第二个隐藏层的权值的梯度是如何求取的,根据链式法则,可以得到如下计算公式,会发现w2的梯度依赖上一层的输出值H1;
在这里插入图片描述
当H1趋近于0的时候,W2的梯度也趋近于0;—>梯度消失
当H1趋近于无穷的时候,W2的梯度也趋近于无穷;—>梯度爆炸
在这里插入图片描述
一旦出现梯度消失或者梯度爆炸就会导致模型无法训练;

1.2 代码实现

import os
import torch
import random
import numpy as np
import torch.nn as nn
from common_tools import set_seed

set_seed(1)  # 设置随机种子


class MLP(nn.Module):
    def __init__(self, neural_num, layers):
        super(MLP, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
        self.neural_num = neural_num

    def forward(self, x):
        for (i, linear) in enumerate(self.linears):
            x = linear(x)
            # x = torch.relu(x)
            # x = torch.tanh(x)

            print("layer:{}, std:{}".format(i, x.std()))  # 打印当前值的标准差
            if torch.isnan(x.std()):  # 判断是什么时候标准差为nan
                print("output is nan in {} layers".format(i))
                break

        return x

    # 权值初始化函数
    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):  # 判断当前网络层是否是线性层,如果是就进行权值初始化
                nn.init.normal_(m.weight.data)  # normal: mean=0, 控制标准差std在1左右
                # nn.init.normal_(m.weight.data, std=np.sqrt(1 / self.neural_num))

                # =======这段代码的目的是通过均匀分布初始化并结合tanh激活函数的特性,为神经网络的某一层(线性层)初始化合适的权重
                # a = np.sqrt(6 / (self.neural_num + self.neural_num))
                # tanh_gain = nn.init.calculate_gain('tanh')
                # a *= tanh_gain
                # nn.init.uniform_(m.weight.data, -a, a)
                # 将权重矩阵的值初始化为在 [-a, a] 范围内均匀分布的随机数。这个范围是通过之前的计算和调整得到的,目的是使得权重初始化在一个合适的范围内

                # nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain)

                # ================凯明初始化方法================
                # nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))  # 适合relu激活函数初始化 凯明初始化手动计算方法
                # nn.init.kaiming_normal_(m.weight.data)


# flag = 0
flag = 1

if flag:
    layer_nums = 100  # 100层线性层
    neural_nums = 256  # 每增加一层网络 标准差扩大根号n倍
    batch_size = 16

    net = MLP(neural_nums, layer_nums)
    print(net)
    net.initialize()

    inputs = torch.randn((batch_size, neural_nums))  # normal: mean=0, std=1

    output = net(inputs)
    print(output)

1.3 实验结果

这里的初始化使用的是标准正态分布normal: mean=0, 控制标准差std在1左右的方法;
在这里插入图片描述
当输出层达到33层后就会出现梯度爆炸,超出了数据精度可以表示的范围。

1.4 方差计算

在这里插入图片描述
1.期望的计算公式
2,3.是方差的计算公式
根据1,2,3,可以得出,x,y的方差计算公式,当x,y的期望值都为0的时候,x,y的方差等于x的方差乘以y的方差。

1.5 标准差计算

在这里插入图片描述
通过计算可以得出每增加一层网络,标准差增加 n \sqrt{n} n ,n也就是神经元的个数;
代码展示:

if flag:
    layer_nums = 100  # 100层线性层
    neural_nums = 256  # 神经元个数 每增加一层网络 标准差扩大根号n倍
    batch_size = 16

执行结果:
可以看出第一层标准差是15.95,第二次标准差在上一层的基础上再乘以 256 \sqrt{256} 256
在这里插入图片描述

1.6 控制网络层输出标准差为1

从1.5可以看出D(H)的大小有三个因素决定,分别是n、D(X)、D(w),所以只要保证这三者乘积为1,就可以保证D(H)的值为1;
在这里插入图片描述
当我们权值的标准差为 1 / n \sqrt{1/n} 1/n ,那么就能保证网络层每一层的输出标准差都为1;

代码实现:
在这里插入图片描述

输出结果:
在这里插入图片描述
通过输出结果可以发现,几乎每一层网络输出的标准差都为1.

1.7 带有激活函数的权重初始化

在forward函数里面添加tanh激活函数
在这里插入图片描述
执行结果:
增加tanh激活函数之后,随着网络层的增加,标准差越来越小,从而会导致梯度消失的现象,下面将说明Xavier方法与Kaiming方法是如何解决该问题。
在这里插入图片描述

二、Xavier方法与Kaiming方法

2.1 Xavier初始化

方差一致性:保持数据尺度维持在恰当范围,通常方差为1
激活函数:饱和函数,如Sigmoid,Tanh
Xavier初始化公式如下:
在这里插入图片描述

代码实现:
手动代码实现
在这里插入图片描述

直接使用pytorch提供的xavier_uniform_函数方法

nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain)

执行结果:
在这里插入图片描述
可以看到,每一层的网络输出标准差都在0.6左右

2.2 Kaiming初始化

当我们使用带有权值初始化的relu激活函数时,输出结果如下,会发现标准差随着网络层的增加逐渐减小,Kaiming初始化解决了这一问题。
在这里插入图片描述
在这里插入图片描述

方差一致性:保持数据尺度维持在恰当范围,通常方差为1
激活函数:ReLU及其变种
公式如下:
在这里插入图片描述

代码实现:

# ================凯明初始化方法================
nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))  # 适合relu激活函数初始化 凯明初始化手动计算方法
# nn.init.kaiming_normal_(m.weight.data)  # 使用pytorch自带方法

输出结果:
在这里插入图片描述

2.3 常见的初始化方法

  1. Xavier均匀分布
  2. Xavier正态分布
  3. Kaiming均匀分布
  4. Kaiming正态分布
  5. 均匀分布
  6. 正态分布
  7. 常数分布
  8. 正交矩阵初始化
  9. 单位矩阵初始化
  10. 稀疏矩阵初始化

三、nn.init.calculate_gain

主要功能:计算激活函数的方差变化尺度(也就是输入数据的方差/经过激活函数之后的方差)
主要参数
• nonlinearity: 激活函数名称
• param: 激活函数的参数,如Leaky ReLU的negative_slop

代码实现:

flag = 1

if flag:
    x = torch.randn(10000)
    out = torch.tanh(x)

    gain = x.std() / out.std()  # 手动计算
    print('gain:{}'.format(gain))

    tanh_gain = nn.init.calculate_gain('tanh')  # pytorch自带函数
    print('tanh_gain in PyTorch:', tanh_gain)

输出结果:
在这里插入图片描述
总结:任何数据在经过tanh激活函数之后,方差缩小大约1.6倍。感兴趣的话也可以使用relu进行实验,最后我的到的结果方差尺度大约是1.4左右。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1363257.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

多线程高级知识点

多线程高级知识点 1.ThreadLocal 1.1 什么是 ThreadLocal? ​ ThreadLocal 叫做本地线程变量,意思是说,ThreadLocal 中填充的的是当前线程的变量,该变量对其他线程而言是封闭且隔离的,ThreadLocal 为变量在每个线程…

高性能NVMe Host Controller IP

NVMe Host Controller IP 介绍 NVMe Host Controller IP可以连接高速存储PCIe SSD,无需CPU和外部存储器,自动加速处理所有的NVMe协议命令,具备独立的数据写入AXI4-Stream/FIFO接口和数据读取AXI4-Stream/FIFO接口,非常适合于超高…

插槽slot涉及到的样式污染问题

1. 前言 本次我们主要结合一些案例研究一下vue的插槽中样式污染问题。在这篇文章中&#xff0c;我们主要关注以下两点: 父组件的样式是否会影响子组件的样式&#xff1f;子组件的样式是否会影响父组件定义的插槽部分的样式&#xff1f; 2. 准备代码 2.1 父组件代码 <te…

mysql基础-数据操作之增删改

目录 1.新增数据 1.1单条数据新增 1.2多条数据新增 1.3查询数据新增 2.更新 2.1单值更新 2.2多值更新 2.3批量更新 2.3.1 批量-单条件更新 2.3.2批量-多条件更新 2.4 插入或更新 2.5 联表更新 3.删除 本次分享一下数据库的DML操作语言。 操作表的数据结构&#xf…

《计算机科学中的建模技术》复习点

0 考试题型 题型&#xff1a;选择、填空、大题&#xff08;综合题&#xff09; 分值&#xff1a;选择填空30分&#xff0c;综合70分 填空&#xff1a;基本概念题 第 1 章&#xff1a;计算机科学基本问题与数学建模概要 1.1 科学计算的基本概念 科学计算是指利用计算机来完成…

Transformer架构和对照代码详解

1、英文架构图 下面图中展示了Transformer的英文架构&#xff0c;英文架构中的模块名称和具体代码一一对应&#xff0c;方便大家对照代码、理解和使用。 2、编码器 2.1 编码器介绍 从宏观⻆度来看&#xff0c;Transformer的编码器是由多个相同的层叠加⽽ 成的&#xff0c;每个…

【数据结构】二叉树的概念及堆

前言 我们已经学过了顺序表、链表、栈和队列这些属于线性结构的数据结构&#xff0c;那么下面我们就要学习我们第一个非线性结构&#xff0c;非线性结构又有哪些值得我们使用的呢&#xff1f;那么接下来我们就将谈谈树的概念了。 1.树的概念与结构 1.1树的概念 树是一种非线性…

2.C++的编译:命令行、makefile和CMake

1. 命令行编译 命令行编译是指直接在命令行中输入以下指令&#xff1a; 预处理&#xff1a;gcc -E main.c -o main.i 编译&#xff1a;gcc -S main.i -o main.s 汇编&#xff1a;gcc -c main.s -o main.o 链接&#xff1a;gcc main.o -o main 命令汇总&#xff1a;gcc main.c …

【ZooKeeper高手实战】ZAB协议:ZooKeeper分布式一致性的基石

&#x1f308;&#x1f308;&#x1f308;&#x1f308;&#x1f308;&#x1f308;&#x1f308;&#x1f308; 欢迎关注公众号&#xff08;通过文章导读关注&#xff1a;【11来了】&#xff09;&#xff0c;及时收到 AI 前沿项目工具及新技术 的推送 发送 资料 可领取 深入理…

第17课 为rtsp流加入移动检测功能

在上节课&#xff0c;我们成功拿到了rtsp视频和音频流&#xff0c;在第13课&#xff0c;我们为普通的usb摄像头加上了移动检测功能&#xff0c;那能不能给rtsp摄像头也加上移动检测功能以实现一些好玩的应用呢&#xff1f;答案是肯定的&#xff0c;在usb摄像头检测中&#xff0…

BetaFlight开源代码之电压校准

BetaFlight开源代码之电压校准 1. 源由2. 分析数据流3. 采样电路3. 原理4. 示例5. 实测&转换数据6. 参考资料 1. 源由 既然复杂的BetaFlight开源代码之电流校准都过了一遍&#xff0c;电压相对来说是比较简单的&#xff0c;一起过一下 2. 分析数据流 电源路径1》采样电路…

【Spring实战】24 使用 Spring Boot Admin 管理和监控应用

文章目录 1. 定义2. 使用场景3. 主要功能4. 示例1&#xff09;[服务端] 添加依赖2&#xff09;[服务端] 相关配置3&#xff09;[服务端] 启动类4&#xff09;[服务端] 启动服务5&#xff09;[服务端] 浏览器访问6&#xff09;[客户端] 添加依赖7&#xff09;[客户端] 相关配置8…

双变量probit模型

1. Probit模型 1.1 模型含义 假设个体只有两种选择&#xff0c;y1或y0。影响选择的变量都包括在向量x中。即线性概率模型为 y值服从两点分布 被认为是连接函数&#xff0c;函数选择具有一定的灵活性。如果为标准正态的累积分布函数&#xff0c;则模型成为Probit模型&#xff…

网络嗅探器的设计与实现(2024)-转载

1.题目描述 参照 raw socket 编程例子&#xff0c;设计一个可以监视网络的状态、数据流动情况以及网络上传输 的信息的网络嗅探器。 2.运行结果 3.导入程序需要的库 请参考下面链接: 导入WinPcap到Clion (2024)-CSDN博客 4.参考代码 #define HAVE_REMOTE #define LINE_LEN …

【数据库原理】(11)SQL数据查询功能

基本格式 SELECT [ALL|DISTINCT]<目标列表达式>[,目标列表达式>]... FROM <表名或视图名>[,<表名或视图名>] ... [ WHERE <条件表达式>] [GROUP BY<列名 1>[HAVING <条件表达式>]] [ORDER BY <列名 2>[ASC DESC]];SELECT: 指定要…

WinForms中的UI卡死

WinForms中的UI卡死 WinForms中的UI卡死通常是由于长时间运行的操作阻塞了UI线程所导致的。在UI线程上执行的操作&#xff0c;例如数据访问、计算、文件读写等&#xff0c;如果耗时较长&#xff0c;会使得UI界面失去响应&#xff0c;甚至出现卡死的情况。 解决方法 为了避免…

061:vue中通过map修改一维数组,增加一些变量

第061个 查看专栏目录: VUE ------ element UI 专栏目标 在vue和element UI联合技术栈的操控下&#xff0c;本专栏提供行之有效的源代码示例和信息点介绍&#xff0c;做到灵活运用。 &#xff08;1&#xff09;提供vue2的一些基本操作&#xff1a;安装、引用&#xff0c;模板使…

系列二、GitHub中的Alpha、Beta、RC、GA、Release等各个版本

一、GitHub中的Alpha、Beta、RC、GA 1.1、概述 1.2、参考 https://www.cnblogs.com/huzhengyu/p/13905129.html

Qt——TCP UDP网络编程

目录 前言正文一、TCP二、UDP1、基本流程2、必备知识 三、代码层级1、UDP服务端 END、总结的知识与问题1、如何获取QByteArray中某一字节的数据&#xff0c;并将其转为十进制&#xff1f;2、如何以本年本月本日为基础&#xff0c;获取时间戳&#xff0c;而不以1970为基础&#…

Ps 滤镜:高反差保留

Ps菜单&#xff1a;滤镜/其它/高反差保留 Filter/Others/High Pass 高反差保留 High Pass滤镜常用于锐化、保护纹理、提取线条等图像编辑工作流程中。它的工作原理是&#xff1a;只保留显示图像中的高频信息&#xff08;即图像中的细节和边缘区域&#xff09;&#xff0c;而图像…