Pytroch 模型权重初始化

news2024/11/24 17:07:46

目录

1 概念

2 权值初始化方法

2.1 常数初始化

2.2 均匀分布初始化

2.3 正态分布初始化

2.4 Xavier 均匀分布

2.5 Xavier 正态分布

2.6 kaiming 均匀分布

2.7 kaiming 正态分布

2.8 单位矩阵初始化

2.9 正交初始化

2.10 稀疏初始化

2.11 狄拉克δ函数初始化

3 python示例


1 概念

        权值初始化是指在网络模型训练之前,对各节点的权值和偏置初始化的过程,正确的初始化会加快模型的收敛,从而加快模型的训练速度,而不恰当的初始化可能会导致梯度消失或梯度爆炸,最终导致模型无法训练。

        如上图所示的一个基本的CNN网络结构,数据在网络结构中流动时,会有如下的公式(默认没有偏置):

 

        在反向传播的过程中,由于是复合函数的求导,根据链式求导法则,会有两组导数,一个是损失函数Cost对Z的导数,一个是损失函数对W的导数,

1、损失函数关于状态Z的梯度:

2、损失函数关于W的梯度:

        可以看出,在网络结构中,一个参数的初始化关系到网络能否训练出好的结果或者以多快的速度收敛。所以对权值初始化有如下的要求:

  1. 参数不能全部初始化为0或全为1,也不能全部初始化同一个值;(反向传播时梯度的更新值一样)
  2. 参数的初始化值不能太大;(针对激活函数是sigmoid和tanh,参数太大,梯度会消失)
  3. 参数的初始化值也不能太小;(针对激活函数是relu和sigmoid,参数太小,梯度会消失)

一、所有的参数初始化为0或者相同的数

        最简单的初始化方法是将所有的参数初始化为0或者一个常数,但是使用这种特征会使网络中的所有神经元学习到的是相同的特征。
假设神经网络中只有一个有2个神经元的隐藏层,现在将偏置参数初始化为:bias=0,权值矩阵初始化为一个常数α。 网络的输入为(x1,x2),隐藏层使用的激活函数为ReLU,则隐藏层的每个神经元的输出都是relu(αx1+αx2)。 这就导致,对于loss function的值来说,两个神经元的影响是一样的,在反向传播的过程中对应参数的梯度值也是一样,也就说在训练的过程中,两个神经元的参数一直保持一致,其学习到的特征也就一样,相当于整个网络只有一个神经元。

二、过大或过小的初始化

        如果权值的初始值过大,则会导致梯度爆炸,使得网络不收敛;过小的权值初始值,则会导致梯度消失,会导致网络收敛缓慢或者收敛到局部极小值。如果权值的初始值过大,则loss function相对于权值参数的梯度值很大,每次利用梯度下降更新参数时,参数更新的幅度也会很大,这就导致loss function的值在其最小值附近震荡。而过小的初始值则相反,loss关于权值参数的梯度很小,每次更新参数时,更新的幅度也很小,着就会导致loss的收敛很缓慢,或者在收敛到最小值前在某个局部的极小值收敛了。

2 权值初始化方法

2.1 常数初始化

        用val的值填充输入的张量或变量

torch.nn.init.constant_(tensor, val)

参数:

tensor – n维的torch.Tensor或autograd.Variable
val – 用来填充张量的值

使用:

w = torch.empty(3, 5)
nn.init.constant_(w, 0.3)

2.2 均匀分布初始化

        从均匀分布U(a, b)中生成值,填充输入的张量或变量

torch.nn.init.uniform_(tensor, a=0, b=1)

参数:

tensor - n维的torch.Tensor
a - 均匀分布的下界
b - 均匀分布的上界

2.3 正态分布初始化

        从给定均值和标准差的正态分布N(mean, std)中生成值,填充输入的张量或变量

torch.nn.init.normal_(tensor, mean=0, std=1)

参数:

tensor – n维的torch.Tensor
mean – 正态分布的均值
std – 正态分布的标准差

2.4 Xavier 均匀分布

        用一个均匀分布生成值,填充输入的张量或变量。结果张量中的值采样自U(-a, a),其中a= gain * sqrt( 2/(fan_in + fan_out))* sqrt(3). 该方法也被称为Glorot initialisation

torch.nn.init.xavier_uniform_(tensor, gain=1)

参数:

tensor – n维的torch.Tensor
gain - 可选的缩放因子

2.5 Xavier 正态分布

        用一个正态分布生成值,填充输入的张量或变量。结果张量中的值采样自均值为0,标准差为gain * sqrt(2/(fan_in + fan_out))的正态分布。也被称为Glorot initialisation.

torch.nn.init.xavier_normal_(tensor, gain=1)

参数:

tensor – n维的torch.Tensor
gain - 可选的缩放因子

2.6 kaiming 均匀分布

        用一个均匀分布生成值,填充输入的张量或变量。结果张量中的值采样自U(-bound, bound),其中

torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

参数:

tensor – n维的torch.Tensor
a -这层之后使用的rectifier的负斜率系数(ReLU的默认值为0)
mode - fan_in 保留前向传播时权值方差的大小,fan_out 保留反向传播时的大小。默认:fan_in
nonlinearity –非线性函数,推荐使用relu和leaky_relu,默认leaky_relu

2.7 kaiming 正态分布

        用一个正态分布生成值,填充输入的张量或变量。结果张量中的值采样自的正态分布。

torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

参数:

tensor – n维的torch.Tensor
a -这层之后使用的rectifier的负斜率系数(ReLU的默认值为0)
mode - fan_in 保留前向传播时权值方差的大小,fan_out 保留反向传播时的大小。默认:fan_in
nonlinearity –非线性函数,推荐使用relu和leaky_relu,默认leaky_relu

2.8 单位矩阵初始化

        用单位矩阵来填充2维输入张量或变量。在线性层尽可能多的保存输入特性。

torch.nn.init.eye_(tensor)

参数:

tensor – 2维的torch.Tensor

2.9 正交初始化

        用(半)正交矩阵填充输入的张量或变量。

torch.nn.init.orthogonal_(tensor, gain=1)

参数:

tensor – n维的torch.Tensor或 autograd.Variable,其中n>=2
gain -可选

2.10 稀疏初始化

        将2维的输入张量或变量当做稀疏矩阵填充,其中非零元素根据一个均值为0,标准差为std的正态分布生成。

torch.nn.init.sparse_(tensor, sparsity, std=0.01)

参数:

tensor – n维的torch.Tensor或autograd.Variable
sparsity - 每列中需要被设置成零的元素比例
std - 用于生成非零值的正态分布的标准差

2.11 狄拉克δ函数初始化

使用狄拉克δ函数填充输入的torch.Tensor。

torch.nn.init.dirac_(tensor)

参数:

tensor – {3, 4, 5}维的torch.Tensor

3 python示例

import torch.nn as nn
import torch

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 3, (3, 3), stride=(1, 1), padding=1)
        self.bn1 = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()
    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))

        return x
# init.uniform  a, b均值分布的上下限
class model_param_init(nn.Module):
    def __init__(self, model):
        super().__init__()
        assert isinstance(model, nn.Module), 'model not a class nn.Module'
        self.net = model
        self.initParam()
    def initParam(self):
        for param in self.net.parameters():
            # nn.init.zeros_(param)
            # nn.init.ones_(param)
            # nn.init.normal_(param, mean=0, std=1)
            # nn.init.uniform_(param, a=0, b=1)
            # nn.init.constant_(param, val=1)   # 将所有权重初始化为1
            # nn.init.eye_(param)  # 只能将二维的tensor初始化为单位矩阵
            # nn.init.xavier_uniform_(param, gain=1)  # Glorot初始化  得到的张量是从-a——a中采用的
            # nn.init.xavier_normal_(param, gain=1)   # 得到的张量是从0-std采样的
            nn.init.kaiming_normal_(param, a=0, mode='fan_in', nonlinearity='leaky_relu') # he初始化方法
            # nn.init.kaiming_uniform_(param)

if __name__ == '__main__':
    net = Net()
    net = model_param_init(net)
    for param in net.parameters():
        print(param)
    # 按照参数更改权重
    for name, param in net.named_parameters():
        if name == net.conv1.weight:  # 指定更改某层的权重值
            nn.init.dirac_(param, groups=1) # 保留卷积层通道的值
        print(param)

PyTorch学习之十一种权重初始化方法_51CTO博客_pytorch权重初始化

 Pytroch进行模型权重初始化_pytorch权重初始化_像风一样自由的小周的博客-CSDN博客

 pytorch学习笔记九:权值初始化_Dear_林的博客-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/539544.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

STC15通过内部BandGap电压值测量ADC外部输入电压

STC15通过内部BandGap参考电压值测量ADC通道外部输入电压 📜内部 BandGap参考电压值获取方式: 🎬通过VOFA图形化显示ADC值 🔧vofa+工具下载地址:https://www.vofa.plus/🌿验证对象:IAP15F2K61S2🌿时钟频率:11.0592MHz🌿波特率:115200🔖在通过STC-ISP烧录程序…

XSS攻击以及java应对措施

文章目录 一. XSS攻击介绍1. 前端安全2. xss攻击简介3. xss的攻击方式 二. java应对xss攻击的解决方案1. 强制修改html敏感标签内容2. 利用过滤器过滤非法html标签 一. XSS攻击介绍 1. 前端安全 随着互联网的高速发展,信息安全问题已经成为企业最为关注的焦点之一…

大脑的故事

婴⼉的神经元是相互独⽴的、未连接的。在⼈⽣的头两年, 随着⼤脑细胞接收感觉信息,它们异常迅速地连接起来。 每⼀秒就有多达 200万个新连接(突触)在婴⼉的⼤脑⾥形成。两岁时,⼩孩⼦拥有超过 100万亿个突触&#xff…

版本发布 | 科东软件Intewell-Win V2.1.0 release版本正式发布

Intewell是由科东软件自主研发的工业嵌入式实时操作系统,具有高实时,确定性、高安全、高可靠、虚拟化等特点。Intewell系统源自于1990年诞生的“道”操作系统,至今已有30多年历史,已在多种严苛环境下运行检验,广泛商用…

springboot+java超市收银管理系统idea

考虑到实际生活中在超市 POS 收银管理方面的需要以及对该系统认真的分析,将系统权限按管理员和员工这两类涉及用户划分。 Spring Boot 是 Spring 家族中的一个全新的框架,它用来简化Spring应用程序的创建和开发过程。也可以说 Spring Boot 能简化我们之…

mmFormer:用于脑肿瘤分割的不完全多模态学习的多模态医学Transformer

文章目录 mmFormer: Multimodal Medical Transformer for Incomplete Multimodal Learning of Brain Tumor Segmentation摘要本文方法Hybrid Modality-Specific EncoderModality-Correlated EncoderConvolutional DecoderAuxiliary Regularizer 实验结果 mmFormer: Multimodal …

Kali-linux使用假冒令牌

使用假冒令牌可以假冒一个网络中的另一个用户进行各种操作,如提升用户权限、创建用户和组等。令牌包括登录会话的安全信息,如用户身份识别、用户组和用户权限。当一个用户登录Windows系统时,它被给定一个访问令牌作为它认证会话的一部分。例如…

师从英国两院院士|生物医学科研人员获CSC资助赴剑桥大学访学

L老师拟申报CSC公派访问学者项目,希望到欧洲TOP学校,师从知名教授,在自己的研究基础上取得进一步的进展和突破。最终我们获得世界名校剑桥大学的邀请函,导师是英国皇家科学学会及英国医学科学院两院院士,凭借这份硬气十…

【SAP Abap】X-DOC:SE18/19 - SAP第四代增强概念理解

【SAP Abap】X-DOC:SE18/19 - SAP第四代增强概念理解 1、Tcode2、概念3、增强选项类型4、增强实现类型5、增强操作方式6、增强选项与增强实现关系7、增强实施建议 1、Tcode SE18:Business Add-Ins: Definitions(增强点定义/查看)…

云平台电子班牌系统源码

越来越多的教育单位加入了数字化校园和智慧校园建设行列。在不断探究、建设和实施的过程中,建立强大的、高扩展性的智慧教育管理平台被众多学校和教育单位所认同。智慧班牌是电子班牌信息发布系统的数据呈现端,也是智慧平台数据的采集工具之一。通过智慧…

EtherCAT运动控制器在数控加工手轮随动中的应用之C++

本文以正运动技术具备专用手轮接口的运动控制器ZMC408CE为例,介绍手轮、手轮的作用及原理、控制器手轮接口接线以及手轮程序配置。 上节讲解了使用正运动basic语言进行手轮应用配置,本节主要讲解C调用API函数库接口实现手轮配置。 01 手轮作用及原理 …

第一个gin程序

一、下载并安装gin go get -u github.com/gin-gonic/gin二、第一个gin程序 package mainimport "github.com/gin-gonic/gin"func sayHello(c *gin.Context) {// 返回给客户端一个JSON格式的数据,其中HTTP状态码为200,表示处理成功c.JSON(200…

成功的产品经理,应该了解一定的开发知识

产品经理在互联网产品开发中扮演着协调和推动的重要角色。然而,由于产品经理没有直接的实际权力,与开发团队合作时可能会遇到各种挑战。当你给开发人员分配任务时,他们可能会找各种借口推脱工作。 在项目开发中,所有成员必须共同…

【C++】详解STL中的list及其与vector的比较

目录 一、list的介绍及其使用1、list的介绍2、list的使用2.1 list的构造2.2 list iterator的使用3、list的元素访问接口4、list的调节器6、list的迭代器失效 二、list的模拟实现及反向迭代器1、模拟实现list2、list的反向迭代器 三、list和vector的比较 一、list的介绍及其使用…

142. 环形链表 II Python

文章目录 一、题目描述示例 1示例 2示例 3 二、代码三、解题思路 一、题目描述 给定一个链表的头节点 head ,返回链表开始入环的第一个节点。 如果链表无环,则返回 null。 如果链表中有某个节点,可以通过连续跟踪 next 指针再次到达&#x…

bug记录:遇到的tinycudann编译的N种错误

1. 编译成功,但是import tinycudann报错找不到DLL 编译成功,但是import tinycudann的时候,报错: 开始打断点,搜索电脑文件,发现_75_c.py应该是存在的,但就是读不到。 发现其所在的文件夹名称…

自定义组件间通信-2

目录 一、 父子组件间通信的3种方式 二、属性绑定,父-> 子 三、事件绑定,子-> 父 四、获取组件实例 一、 父子组件间通信的3种方式 属性绑定:用于父组件向子组件的指定属性设置设置数据,仅能设置JSON兼容的数据事件绑定&…

三分钟挖掘快速软件开发框架提高办公效率的秘诀

在科技日新月异的当今社会,学会利用快速软件开发框架,可以给企业带来更大的便利和市场价值。因为它拥有可视化设计、灵活简便、易操作、易上手等优势特点,在助推企业实现数字化转型的过程中有着举足轻重的作用。那么,快速软件开发…

自媒体品牌宣传策略注意哪些,是怎么种草的

众所周知,小红书平台有着极其强大的种草能力。不论新品牌孵化,还是大品牌扩张,都会将目光投注到这里,那么小红书的品牌宣传策略究竟是怎样的呢。 一、聚焦种草能力 前面已经提到了,小红书平台是一个以“种草”为特色的…

在 Python 中执行逐元素加法

文章目录 Python 中的逐元素加法在 Python 中使用 zip() 函数执行逐元素加法在 Python 中使用 map() 函数执行逐元素加法在 Python 中使用 NumPy 执行逐元素加法 我们将通过示例介绍在 Python 中按元素添加两个列表的不同方法。 Python 中的逐元素加法 在 Python 中使用列表时…