权值初始化

news2025/3/1 23:08:24

一、梯度消失与爆炸

在神经网络中,梯度消失和梯度爆炸是训练过程中常见的问题。

梯度消失指的是在反向传播过程中,梯度逐渐变小,导致较远处的层对参数的更新影响较小甚至无法更新。这通常发生在深层网络中,特别是使用某些激活函数(如sigmoid函数)时。当梯度消失发生时,较浅层的权重更新较大,而较深层的权重更新较小,使得深层网络的训练变得困难。

梯度爆炸指的是在反向传播过程中,梯度逐渐变大,导致权重更新过大,网络无法收敛。这通常发生在网络层数较多,权重初始化过大,或者激活函数的导数值较大时。

为了解决梯度消失和梯度爆炸问题,可以采取以下方法:

  • 权重初始化:合适的权重初始化可以缓解梯度消失和梯度爆炸问题。常用的方法包括Xavier初始化和He初始化。
  • 使用恰当的激活函数:某些激活函数(如ReLU、LeakyReLU)可以缓解梯度消失问题,因为它们在正半轴具有非零导数。
  • 批归一化(Batch Normalization):通过在每个批次的输入上进行归一化,可以加速网络的收敛,并减少梯度消失和梯度爆炸的问题。
  • 梯度裁剪(Gradient Clipping):设置梯度的上限,防止梯度爆炸。
  • 减少网络深度:减少网络的层数,可以降低梯度消失和梯度爆炸的风险。

综上所述,梯度消失和梯度爆炸是神经网络中常见的问题,可以通过合适的权重初始化、激活函数选择、批归一化、梯度裁剪和减少网络深度等方法来缓解这些问题。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

二、Xavier初始化

对于具有饱和函数(如Sigmoid、Tanh)的激活函数和方差一致性的要求,可以推导出权重矩阵的初始化范围。
假设输入的维度为 n_in,权重矩阵为 W,我们希望满足方差一致性的要求:
在这里插入图片描述

方差一致性:

保持数据尺度维持在恰当范围,通常方差为1

激活函数:ReLU及其变种
在这里插入图片描述


三、十种初始化方法

以下是常用的权重初始化方法:

  1. Xavier均匀分布(Xavier Uniform Distribution):根据输入和输出的维度,从均匀分布中采样权重,范围为 [-a, a],其中 a = sqrt(6 / (n_in + n_out))。适用于具有饱和函数(如Sigmoid、Tanh)的激活函数。
  2. Xavier正态分布(Xavier Normal Distribution):根据输入和输出的维度,从正态分布中采样权重,均值为 0,标准差为 sqrt(2 / (n_in + n_out))。适用于具有饱和函数的激活函数。
  3. Kaiming均匀分布(Kaiming Uniform Distribution):根据输入维度,从均匀分布中采样权重,范围为 [-a, a],其中 a = sqrt(6 / n_in)。适用于具有ReLU激活函数的网络。
  4. Kaiming正态分布(Kaiming Normal Distribution):根据输入维度,从正态分布中采样权重,均值为 0,标准差为 sqrt(2 / n_in)。适用于具有ReLU激活函数的网络。
  5. 均匀分布(Uniform Distribution):从均匀分布中采样权重,范围为 [-a, a],其中 a 是一个常数。
  6. 正态分布(Normal Distribution):从正态分布中采样权重,均值为 0,标准差为 std。
  7. 常数分布(Constant Distribution):将权重初始化为常数。
  8. 正交矩阵初始化(Orthogonal Matrix Initialization):通过QR分解或SVD分解等方法,初始化权重为正交矩阵。
  9. 单位矩阵初始化(Identity Matrix Initialization):将权重初始化为单位矩阵。
  10. 稀疏矩阵初始化(Sparse Matrix Initialization):将权重初始化为稀疏矩阵,其中只有少数非零元素。

不同的初始化方法适用于不同的网络结构和激活函数,选择合适的初始化方法可以帮助网络更好地进行训练和收敛。

nn.init.calculate_gain

nn.init.calculate_gain 是 PyTorch 中用于计算激活函数的方差变化尺度的函数。方差变化尺度是指激活函数输出值方差相对于输入值方差的比例。这个比例对于初始化神经网络的权重非常重要,可以影响网络的训练和性能。

主要参数如下:

  • nonlinearity:激活函数的名称,用字符串表示,比如 ‘relu’、‘leaky_relu’、‘tanh’ 等。
  • param:激活函数的参数,这是一个可选参数,用于指定激活函数的特定参数,比如 Leaky ReLU 的 negative_slope

这个函数的返回值是一个标量,表示激活函数的方差变化尺度。在初始化网络权重时,可以使用这个尺度来缩放权重,以确保网络在训练过程中具有良好的数值稳定性。

例如,可以在初始化网络权重时使用 nn.init.xavier_uniform_nn.init.xavier_normal_,并通过 calculate_gain 函数计算激活函数的方差变化尺度,将其作为相应初始化方法的参数。这样可以根据激活函数的特性来调整权重的初始化范围,有助于更好地训练神经网络。

小案例

import os
import torch
import random
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # 设置随机种子


class MLP(nn.Module):
    def __init__(self, neural_num, layers):
        super(MLP, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
        self.neural_num = neural_num

    def forward(self, x):
        for (i, linear) in enumerate(self.linears):
            x = linear(x)
            x = torch.relu(x)

            print("layer:{}, std:{}".format(i, x.std()))
            if torch.isnan(x.std()):
                print("output is nan in {} layers".format(i))
                break

        return x

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # nn.init.normal_(m.weight.data, std=np.sqrt(1/self.neural_num))    # normal: mean=0, std=1

                # a = np.sqrt(6 / (self.neural_num + self.neural_num))
                #
                # tanh_gain = nn.init.calculate_gain('tanh')
                # a *= tanh_gain
                #
                # nn.init.uniform_(m.weight.data, -a, a)

                # nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain)

                # nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))
                nn.init.kaiming_normal_(m.weight.data)

flag = 0
# flag = 1

if flag:
    layer_nums = 100
    neural_nums = 256
    batch_size = 16

    net = MLP(neural_nums, layer_nums)
    net.initialize()

    inputs = torch.randn((batch_size, neural_nums))  # normal: mean=0, std=1

    output = net(inputs)
    print(output)

# ======================================= calculate gain =======================================

# flag = 0
flag = 1

if flag:
    # 生成随机张量并通过tanh激活函数计算输出
    x = torch.randn(10000)
    out = torch.tanh(x)

    # 计算激活函数增益
    gain = x.std() / out.std()
    print('gain:{}'.format(gain))

    # 使用PyTorch提供的calculate_gain函数计算tanh激活函数的增益
    tanh_gain = nn.init.calculate_gain('tanh')
    print('tanh_gain in PyTorch:', tanh_gain)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1381099.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于SpringBoot+Thymeleaf的医院挂号管理系统(有文档、Java毕业设计)

大家好,我是DeBug,很高兴你能来阅读!作为一名热爱编程的程序员,我希望通过这些教学笔记与大家分享我的编程经验和知识。在这里,我将会结合实际项目经验,分享编程技巧、最佳实践以及解决问题的方法。无论你是…

【React】TS项目配置Redux

前提条件 在React中使用Redux,官方要求安装两个插件,Redux Toolkit 和 react-redux Redux Toolkit(RTK): 官方推荐编写Redux逻辑的方式,是一套工具的集合集,简化书写方式。 简化 store 的配置方…

第133期 为什么一些场景下Oracle很难被替换掉(20240113)

数据库管理133期 2024-01-13 第133期 为什么一些场景下Oracle很难被替换掉(20240113)1 数据量2 架构3 应用改造4 Exadata和融合数据库总结 第133期 为什么一些场景下Oracle很难被替换掉(20240113) 今天在薛首席的群里&#xff0c…

Jmeter 性能-监控服务器

Jmeter监控Linux需要三个文件 JMeterPlugins-Extras.jar (包:JMeterPlugins-Extras-1.4.0.zip) JMeterPlugins-Standard.jar (包:JMeterPlugins-Standard-1.4.0.zip) ServerAgent-2.2.3.zip 1、Jemter 安装插件 在插件管理中心的搜索Servers Perform…

day17 平衡二叉树 二叉树的所有路径 左叶子之和

题目1:110 平衡二叉树 题目链接:110 平衡二叉树 题意 判断二叉树是否为平衡二叉树(每个节点的左右两个子树的高度差绝对值不超过1) 递归遍历 递归三部曲 1)确定递归函数的参数和返回值 2)确定终止条…

基于ubuntu2204使用kubeadm部署k8s集群

部署k8s集群 基础环境配置安装container安装runc安装CNI插件部署1.24版本k8s集群(flannel)安装crictl使用kubeadm部署集群节点加入集群部署flannel网络配置dashboard 本集群基于ubuntu2204系统使用kubeadm工具部署1.24版本k8s,容器运行时使用…

倍福PLC控制器开发环境介绍

倍福PLC控制器是一款功能强大、易于使用的可编程逻辑控制器,广泛应用于各种工业自动化控制系统中。为了充分发挥倍福PLC控制器的功能,需要使用合适的开发环境。下面将介绍倍福PLC控制器的开发环境,主要包括软件安装与配置、工程创建与管理、编…

matlab中any()函数用法

一、帮助文档中的介绍 B any(A) 沿着大小不等于 1 的数组 A 的第一维测试所有元素为非零数字还是逻辑值 1 (true)。实际上,any 是逻辑 OR 运算符的原生扩展。 二、解读 分两步走: ①确定维度;②确定运算规则 以下面二维数组为例 >>…

使用组合框QComboBox模拟购物车

1.组合框: QComboBox 组合框:QComboBox 用于存放一些列表项 实例化 //实例化QComboBox* comboBox new QComboBox(this);1.1 代码实现 1.1.1 组合框的基本函数 QComboBox dialog.cpp #include "dialog.h" #include "ui_dialog.h"Dialog::Dialog…

Qt QListWidget列表框控件

文章目录 1 属性和方法1.1 外观1.2 添加条目1.3 删除条目1.4 信号和槽 2 实例2.1 布局2.2 代码实现 Qt中的列表框控件,对应的类是QListWidget 它用于显示多个列表项,列表项对应的类是QListWidgetitem 1 属性和方法 QListWidget有很多属性和方法&#xf…

linux磁盘清理_docker/overlay2爆满

问题:无意间发现linux服务器登陆有问题,使用df命令发现目录满了。 1. 确定哪里占用了大量内存。 cd / du -sh * | sort -rh经过一段时间后,显示如下: // 474G home // 230G var // 40G usr // 10G snap // --- 根据实际情…

Java内存模型之可见性

文章目录 1.什么是可见性问题2.为什么会有可见性问题3.JMM的抽象:主内存和本地内存3.1 什么是主内存和本地内存3.2 主内存和本地内存的关系 4.Happens-Before原则4.1 什么是Happens-Before4.2 什么不是Happens-Before4.3 Happens-Before规则有哪些4.4 演示&#xff…

Kafka的核心原理

Topic的分区和副本机制 分区有什么用呢? 作用: 1- 避免单台服务器容量的限制: 每台服务器的磁盘存储空间是有上限。Topic分成多个Partition分区,可以避免单个Partition的数据大小过大,导致服务器无法存储。利用多台服务器的存储能力&#…

时序预测 | Matlab实现EEMD-SSA-BiLSTM、EEMD-BiLSTM、SSA-BiLSTM、BiLSTM时序预测对比

时序预测 | Matlab实现EEMD-SSA-BiLSTM、EEMD-BiLSTM、SSA-BiLSTM、BiLSTM时间序列预测对比 目录 时序预测 | Matlab实现EEMD-SSA-BiLSTM、EEMD-BiLSTM、SSA-BiLSTM、BiLSTM时间序列预测对比预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.Matlab实现EEMD-SSA-BiLSTM、…

HTML5+CSS3+JS小实例:音频可视化

实例:音频可视化 技术栈:HTML+CSS+JS 效果: 源码: 【HTML】 <!DOCTYPE html> <html lang="zh-CN"> <head><meta charset="UTF-8"><meta http-equiv="X-UA-Compatible" content="IE=edge"><m…

Rust-模式解构

match 首先&#xff0c;我们看看使用match的最简单的示例&#xff1a; exhaustive 有些时候我们不想把每种情况一一列出&#xff0c;可以用一个下划线来表达“除了列出来的那些之外的其他情况”&#xff1a; 下划线 下划线还能用在模式匹配的各种地方&#xff0c;用来表示…

【STM32】STM32学习笔记-USART串口手法HEX和文本数据包(29)

00. 目录 文章目录 00. 目录01. 串口简介02. 串口收发HEX数据包接线图03. 串口收发HEX数据包示例104. 串口收发HEX数据包示例205. 串口收发文本数据包接线图06. 串口收发文本数据包示例07. 程序示例下载08. 附录 01. 串口简介 串口通讯(Serial Communication)是一种设备间非常…

高精度彩色3D相机:开启崭新的彩色3D成像时代

3D成像的新时代 近年来&#xff0c;机器人技术的快速发展促使对3D相机技术的需求不断增加&#xff0c;原因在于&#xff0c;相机在提高机器人的性能和实现多种功能方面发挥了决定性作用。然而&#xff0c;其中许多应用所需的解决方案更复杂&#xff0c;仅提供环境的深度信息是…

【LeetCode】142. 环形链表 II(中等)——代码随想录算法训练营Day04

题目链接&#xff1a;142. 环形链表 II 题目描述 给定一个链表的头节点 head &#xff0c;返回链表开始入环的第一个节点。 如果链表无环&#xff0c;则返回 null。 如果链表中有某个节点&#xff0c;可以通过连续跟踪 next 指针再次到达&#xff0c;则链表中存在环。 为了…

Docker 介绍 及 支持的操作系统

Docker组成&#xff1a; Docker主机(Host)&#xff1a; 一个物理机或虚拟机, 用于运行Docker服务进程和容器, 也成为宿主机, node节点。 Docker服务器端(Server)&#xff1a; Docker守护进程, 运行Docker容器。 Docker客户端(Client)&#xff1a; 客户端使用docker命令或其他工…