几种常用的权重初始化方法

news2024/9/22 7:34:45

来源:投稿 作者:175
编辑:学姐

在深度学习中,权重的初始值非常重要,权重初始化方法甚至关系到模型能否收敛。本文主要介绍两种权重初始化方法。

为什么需要随机初始值

我们知道,神经网络一般在初始化权重时都是采用随机值。如果不用随机值,全部设成一样的值会发生什么呢?

极端情况,假设全部设成0。显然,如果某层的权重全部初始化为0,类似该层的神经元全部被丢弃(dropout)了,就不会有信息传播到下一层。

如果全部设成同样的非零值,那么在反向传播中,所有的权重都会进行相同的更新,权重被更新为相同的值,并拥有了对称(重复)的值。不管怎样进行迭代(sgd),都不会打破这种对称性,隐藏层好像只有一个神经元,我们无法实现神经网络的表达能力。只有我们前面介绍的Dropout可以打破这种对称性。

为了打破权重的对称结构,必须随机生成初始值。

隐藏层激活值的分布

观察隐藏层激活值的分布,可以获得一些启发。

这里通过一个实验来看权重初始值是如何影响隐藏层的激活值分布的。

向一个5层神经网络(激活函数使用sigmoid函数)传入随机生成的输入数据,用直方图绘制各层激活值的数据分布。

# coding: utf-8
import numpy as np
import matplotlib.pyplot as plt


def sigmoid(x):
    return 1 / (1 + np.exp(-x))
    
def ReLU(x):
    return np.maximum(0, x)
    
def tanh(x):
    return np.tanh(x)
    
input_data = np.random.randn(1000, 100)  # 1000个数据
node_num = 100  # 各隐藏层的节点(神经元)数
hidden_layer_size = 5  # 隐藏层有5层
activations = {}  # 激活值的结果保存在这里

x = input_data

for i in range(hidden_layer_size):
    if i != 0:
        x = activations[i-1]

    # 改变初始值进行实验!
    w = np.random.randn(node_num, node_num) * 1
    # w = np.random.randn(node_num, node_num) * 0.01
    # w = np.random.randn(node_num, node_num) * np.sqrt(1.0 / node_num)
    # w = np.random.randn(node_num, node_num) * np.sqrt(2.0 / node_num)

    a = np.dot(x, w)
    # 将激活函数的种类也改变,来进行实验!
    z = sigmoid(a)
    # z = ReLU(a)
    # z = tanh(a)

    activations[i] = z

# 绘制直方图
for i, a in activations.items():
    plt.subplot(1, len(activations), i+1)
    plt.title(str(i+1) + "-layer")
    if i != 0: plt.yticks([], [])
    plt.hist(a.flatten(), 30, range=(0,1))
plt.show()

这里假设神经网络有5层,每层有100个单元。然后,用高斯分布随机生成1000个数据作为输入数据,并把它们传给5层神经网络。这里权重的初始化也通过均值为0方差为1的高斯分布。

各层的激活值呈偏向0和1的分布。这里使用的sigmoid函数是S型函数,随着输出不断地靠近0(或者靠近1,在S线的两端),它的梯度逐渐接近0。因此,偏向于0或1的数据分布会造成反向传播中梯度的值不断变小,最后消失。

我们知道,在L2正则化话会使权重参数变小,那么我们这里在初始参数的时候,直接设定一个较小的值会不会好一点。我们只要改下上面代码的27/28行。

# w = np.random.randn(node_num, node_num) * 1
w = np.random.randn(node_num, node_num) * 0.01

这次虽然没有偏向0和1,不会发生梯度消失的问题。但是激活值的分布有所偏向,这里集中于0.5附近。这样模型的表现力会大打折扣。

下面我们来了解比较常用的Xavier初始值和He初始值,看它们会对激活值的分布产生什么影响。

Xavier初始化

Xavier初始化的思想很简单,即尽可能保持所有层之间输入输出的方差一致。

结论是在初始化时从正态分布中随机采样来构成初始权重,其中和分别代表输入和输出的维度。

直接给出最终结果很简单,但是它是如何推导出来的呢?

下面就来推导看看,我们来看第l层的公式,假设激活函数是tanh:

对于两独立随机变量有:

如果同时X,Y的均值为零,有:

基于以上条件,那么:

其中,基于假设1有:

类似地,有:

和:

基于以上,我们有:

其中可以看成是输入信号,是输出信号。

也就是说,输入信号的方差经过该层后放大或缩小了倍。为了使得在经过多层网络后,输入信号不被过分放大或过分减弱,我们尽可能保持每个神经元的输入和输出方差一致,这样,需要有,即

如果我们考虑整个网络,并用L代表输出层的话。那么输出层的方差与输入层方差的关系为:

从这可以看出,我们输出和输入的方差变化取决于:

上面是正向传播过程,下面我们考虑反向传播过程。

网上大多数只有正向传播的证明,反向传播稍微复杂一点,但也不是无法证明的。

为了简化表示,我们引入一个记号:

其中C为损失。

先看第l层:

我们知道第层有个神经元,如下简单神经网络示意图所示,第二层的第j 个神经元影响了下一层的所有神经元,在计算反向传播时,需要进行梯度累积。

即为了让和上的梯度方差保持一致,需要有

该层的输出数量。

如果我们考虑整个网络,并用L代表输出层,x代表输入向量。那么输出的梯度和输入的梯度的关系为:

证明完毕。

为了简单(公式不好敲),后面用和分别表示某层的输入和输出大小。

若从均匀分布中生成权重参数,那么这里的

因为均匀分布的方差为:

令方差等于上面的调和平均数有:

虽然在上面的推理中,我们假设激活函数为恒等函数(“不存在非线性”)在神经网络中很容易被违反, 但Xavier初始化方法在实践中被证明是有效的。

继续上面的实验,我们采用Xavier初始化方法,这里输入和输出大小一致,因此取就可以了:

w = np.random.randn(node_num, node_num) * np.sqrt(1.0 / node_num)

可以看到,输出值在5层之后依然保持良好的分布。我们这里使用的激活函数为sigmoid,那如果换成ReLU会怎样呢?

 z = ReLU(a)

前面几层看起来还可以,随着层数的加深,偏向一点点变大。当层加深后,激活值的偏向变大,就容易出现梯度消失的问题。

那么怎么办呢?Kaiming初始化的提出就是为了解决这个问题。

Kaiming初始化

Kaiming初始化是由何凯明大神提出的,又称为He初始化。主要针对ReLU激活函数:

基于上面公式(4)(6),有:

再根据方差的公式:

公式(26)可以转换为:

上式最后几步基于W的均值为零,所以。

所以,由公式(27)有。

把x,y用原来的式子表示,并将(29)代入式(28)得:

为了让和的方差一致,需要有:

类似的,计算反向传播(注意要考虑ReLU的导数)可以得到

但是Kaiming初始化没有像Xaiver初始化那样取两者的调和平均数,而是根据需要任取一个即可,就像Pytorch的实现那样根据需要取输入还是输出大小。

同理如果采用均匀分布的话,那么,这里n要么是输入大小,要么是输出大小。

继续上面的实验,采用He初始化方法:

w = np.random.randn(node_num, node_num) * np.sqrt(2.0 / node_num)

而当初始值为He初始值时,各层中分布的广度相同。由于即便层加深,数据的广度也能保持不变,因此逆向传播时,也会传递合适的值。

代码实现

代码实现就很简单了,代码地址:

👉 https://github.com/nlp-greyfoss/metagrad

class Linear(Module):
    r"""
         对给定的输入进行线性变换: :math:`y=xA^T + b`

        Args:
            in_features: 每个输入样本的大小
            out_features: 每个输出样本的大小
            bias: 是否含有偏置,默认 ``True``
        Shape:
            - Input: `(*, H_in)` 其中 `*` 表示任意维度,包括none,这里 `H_{in} = in_features`
            - Output: :math:`(*, H_out)` 除了最后一个维度外,所有维度的形状都与输入相同,这里H_out = out_features`
        Attributes:
            weight: 可学习的权重,形状为 `(out_features, in_features)`.
            bias:   可学习的偏置,形状 `(out_features)`.
        """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weight = Parameter(Tensor.empty((out_features, in_features)))
        if bias:
            self.bias = Parameter(Tensor.zeros(out_features))
        else:
            self.bias = None
        self.reset_parameters()

    def reset_parameters(self) -> None:
        init.kaiming_normal_(self.weight)  # 默认采用kaiming初始化

    def forward(self, input: Tensor) -> Tensor:
        x = input @ self.weight.T
        if self.bias is not None:
            x = x + self.bias

        return x

我们通过调用实现的kaiming_normal_就可以采用Kaiming初始化了。

References

https://www.deeplearning.ai/ai-notes/initialization/

关注下方《学姐带你玩AI》🚀🚀🚀

深度学习220+篇必读论文免费领取

码字不易,欢迎大家点赞评论收藏!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/171969.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【EasyExcel】在Java中操作Excel 完成数据的导入导出

快速入门 引入依赖 构建实体类 数据导出 参数 WriteWorkbook WriteSheet WriteTable 测试 数据导入 测试 EasyExcel是阿里巴巴开源的一个excel处理框架,以使用简单、节省内存著称。EasyExcel能大大减少占用内存的主要原因是在解析Excel时没有将文件数据一…

【Ajax】form表单

一、form表单的基本使用什么是表单表单在网页中主要负责数据采集功能。HTML中的<form>标签&#xff0c;就是用于采集用户输入的信息&#xff0c;并通过<form>标签的提交操作&#xff0c;把采集到的信息提交到服务器端进行处理。2. 表单的组成部分<!-- 表单标签 …

Android 深入系统完全讲解(27)

讲完了这块&#xff0c;我们来说下相机相关的&#xff0c;再说之前一定记得&#xff0c;先要有框架思维&#xff0c;这点一直是我 强调的。 相机是什么&#xff0c;硬件采集数据上来&#xff0c;解析完成&#xff0c;上层绘制&#xff0c;在绘制的时候&#xff0c;同步可以做特…

iOS 国际化(多语言)

一、应用程序国际化 包括app名称和各种权限的提示文字。 1.1 创建工程&#xff0c;再在“PROJECT”的“Info”里面&#xff0c;添加所需语言。 1.2 从代码中分离出文本 创建一个 “.strings” 扩展名的文件 来本地化字符串&#xff0c;需要把这些字符串全部放在一个单独的文…

【医学数据融合文本方向 思路整理】

Scalable and accurate deep learning for electronic health records【2018】 本论文在于介绍 Google Medical Brain 项目的目标、方法和规划。 思路&#xff1a; 用病情描述&#xff0c;预测疾病诊断&#xff0c;预测死亡率 用病情描述加治疗方案&#xff0c;预测复诊率和住院…

Elasticsearch7.8.0版本高级查询—— 查询所有文档

目录一、初始化文档数据二、查询所有文档示例一、初始化文档数据 在 Postman 中&#xff0c;向 ES 服务器发 POST 请求 &#xff1a;http://localhost:9200/user/_doc/1&#xff0c;请求体内容为&#xff1a; {"name":"张三","age":22,"sex…

Zookeeper 【下载与安装,基本使用】

目录 1. 什么是zookeeper 2. zookeeper下载与安装 3. Zookeeper 测试 1. 什么是zookeeper zookeeper实际上是yahoo开发的&#xff0c;用于分布式中一致性处理的框架。最初其作为研发Hadoop时的副产品。 由于分布式系统中一致性处理较为困难&#xff0c;其他的分布式系统没有…

SAP MTO/MTS操作步骤及月末结算

一、MTO/MTS操作步骤 【MTO核算方式】 是以销售订单触发生产的方式。 创建销售订单 VA01 运行物料需求计划 MD01 查询物料需求 MD04 计划订单转化为生产订单 MD04/CO01 生产订单成本计算以及下达 CO02 生产订单发料 MB1A 生产报工 CO11N 完成品入库 MB31 非限制库存转移到销售…

C# opencv多模板匹配实战应用例程

C# 多模板匹配例程 最近在做项目的时候为了检测某一种物品的齐套性&#xff0c;以及为了和写c#的软件负责人配合自己研究了一下opnencv C# 版的模板匹配&#xff0c;对基础的例程做了一下改进&#xff0c;留一份例程。 因为工作性质原因不能直接放项目的实际图片我用visio简单…

我的个人微信也迅速接入了 ChatGPT

本文主要来聊聊如何快速使用个人微信接入 ChatGPT&#xff0c;欢迎 xdm 尝试起来&#xff0c;仅供学习参考&#xff0c;切莫用于做不正当的事情 关于 ChatGPT 我们每个人都可以简单的使用上&#xff0c;不需要你有很强的技术背景&#xff0c;不需要你有公众号&#xff0c;甚至…

Chat GPT 创建APP: 开发人员要被替代了吗?

我们又要被人工智能取代了吗&#xff1f;GitHub Copilot 于 2021 年 10 月发布&#xff0c;整个开发社区都为之疯狂。有些人发表言论说我们很快就会失业&#xff0c;而其他人&#xff08;比如我&#xff09;&#xff0c;认为虽然这个工具很有趣&#xff0c;但距离替代人工还很远…

【Django框架】——25 Django视图 07 状态保持Session

文章目录1.session流程图2.session语法与案例3.session配置cookie不安全&#xff0c;会把所有敏感数据放到浏览器保存。 session是把敏感数据存到自己的服务器中给浏览器一把钥匙就行了&#xff08;是基于cookie完成的&#xff09;。 Django 提供对匿名会话(session)的完全支…

Cisco Packet Tracer 8.2.x Crack

Cisco Packet Tracer 是一个网络模拟器。有了这款功能强大的软件&#xff0c;用户可以在模拟和安全的环境中学习所有网络主题&#xff0c;而无需花费很多钱。它是网络主题模拟和培训领域中最受欢迎的应用程序之一&#xff0c;因为它提供了这样做所需的所有功能。Packet Tricer …

Java方法(函数)

文章目录Java方法(函数)一、方法介绍二、方法的定义和调用格式1. 快速入门2. Debug查看方法的执行流程3. 方法调用内存图解4. 带参数方法的定义和调用1&#xff09;定义和调用格式2&#xff09;形参和实参5. 带返回值方法的定义和调用6. 方法通用定义格式三、方法常见问题四、方…

MIPI 摄像头的原理

1. 摄像头sensor 的原理 定时脉冲生成器会生成clock&#xff0c;用于访问image sensor 阵列中的行&#xff0c;预充电&#xff0c;并且按顺序采样像素阵列中的所有行。在一个行的预充电和采样的时间段里&#xff0c;像素的电荷量会随着曝光时间而逐渐减少。这就是快门结构中的曝…

擎创技术流 | ClickHouse实用工具—ckman教程(10)

一、前言 哈喽~友友们&#xff0c;转眼农历新年就在眼前&#xff0c;ckman系列也终于迎来了最后一期&#xff0c;非常感谢大家的喜欢&#xff0c;让up主有动力做完这个系列&#xff0c;也感谢一路走来&#xff0c;大家给予的反馈&#xff0c;让这个系列越做越好。 接下来&…

4-Spring使用

目录 1.存储Bean对象到Spring容器中 1.1.创建Bean 1.2.将Bean注册到Spring容器中 1.2.1.第一次存储Bean&#xff08;可选&#xff0c;如果是第二次及以后&#xff0c;此步骤忽略&#xff09; 1.2.2.添加Bean标签 2.从Spring容器中获取并使用Bean对象 2.1.创建Spring上下…

剑指 Offer 04. 二维数组中的查找 [C语言]

目录题目思路代码结果该文章只是用于记录考研复试刷题题目 在一个 n * m 的二维数组中&#xff0c;每一行都按照从左到右 非递减 的顺序排序&#xff0c;每一列都按照从上到下 非递减 的顺序排序。请完成一个高效的函数&#xff0c;输入这样的一个二维数组和一个整数&#xff…

[leetcode 1723] 完成所有工作的最短时间

题目 题目&#xff1a;https://leetcode.cn/problems/find-minimum-time-to-finish-all-jobs/description/ 该题和 [leetcode 2305] 公平分发饼干 完全相同。 解法 回溯剪枝 感觉和 [leetcode 198] 划分为k个相等的子集 有点相似&#xff0c;这题更像是划分为k个尽量相等的子…

easypoi 模板导入、导出合并excel单元格功能

easypoi 模板导入、导出合并单元格功能 参考&#xff1a; java使用poi读取跨行跨列excel springboot集成easypoi并使用其模板导出功能和遇到的坑 Easypoi Excel模板功能简要说明 easypoi 模板导出兼容合并单元格功能 ExcelUtil package com.yymt.utils;import cn.aftertu…