pytorch 笔记:torch.nn.init

news2024/11/17 20:40:51

这个模块中的所有函数都是用来初始化神经网络参数的,所以它们都在torch.no_grad()模式下运行,不会被autograd所考虑

1 计算gain value

1.1 介绍

这个在后面的一些nn.init初始化中会用到

 1.2 用法

torch.nn.init.calculate_gain(nonlinearity, param=None)
import torch
torch.nn.init.calculate_gain('sigmoid')
#1

torch.nn.init.calculate_gain('tanh')
#1.6666666666666667

torch.nn.init.calculate_gain('leaky_relu',0.1)
#1.4071950894605838

torch.nn.init.calculate_gain('conv3d')
#1

2 初始化汇总

2.1 均匀分布

以均匀分布U(a,b)填充tensor

torch.nn.init.uniform_(tensor, a=0.0, b=1.0)
a=torch.Tensor(3,5)
a
'''
tensor([[9.8265e-39, 9.4592e-39, 1.0561e-38, 7.3470e-39, 1.0653e-38],
        [1.0194e-38, 1.0929e-38, 1.0102e-38, 1.0561e-38, 1.0561e-38],
        [1.0561e-38, 1.0745e-38, 1.0561e-38, 8.7245e-39, 9.6429e-39]])
'''

torch.nn.init.uniform_(a,3,5)
a
'''
tensor([[3.2886, 3.5971, 3.3080, 4.5271, 4.3113],
        [4.3634, 4.1311, 3.4466, 3.3745, 3.9957],
        [4.7776, 4.4654, 4.7397, 3.5465, 4.5716]])
'''

2.2 正态分布

N(mean,std^2)初始化tensor

torch.nn.init.normal_(tensor, mean=0.0, std=1.0)
a=torch.Tensor(3,5)
a
'''
tensor([[9.8265e-39, 9.4592e-39, 1.0561e-38, 7.3470e-39, 1.0653e-38],
        [1.0194e-38, 1.0929e-38, 1.0102e-38, 1.0561e-38, 1.0561e-38],
        [1.0561e-38, 1.0745e-38, 1.0561e-38, 8.7245e-39, 9.6429e-39]])
'''

torch.nn.init.normal_(a,0,5)
a
'''
tensor([[-9.6473, -0.8678, -7.0850, -1.3568, -6.1306],
        [-5.5031, -1.6662,  9.8144, -6.5255, -6.2179],
        [-0.6455, -1.7757,  7.7232, -1.2374, -1.2551]])
'''

2.3 定值

以定值初始化

torch.nn.init.constant_(tensor, val)
a=torch.Tensor(3,5)
a
'''
tensor([[9.8265e-39, 9.4592e-39, 1.0561e-38, 7.3470e-39, 1.0653e-38],
        [1.0194e-38, 1.0929e-38, 1.0102e-38, 1.0561e-38, 1.0561e-38],
        [1.0561e-38, 1.0745e-38, 1.0561e-38, 8.7245e-39, 9.6429e-39]])
'''

torch.nn.init.constant_(a,5)
a
'''
tensor([[5., 5., 5., 5., 5.],
        [5., 5., 5., 5., 5.],
        [5., 5., 5., 5., 5.]])
'''

 2.4 填充1

用定值1初始化

torch.nn.init.ones_(tensor)
a=torch.Tensor(3,5)
a
'''
tensor([[9.8265e-39, 9.4592e-39, 1.0561e-38, 7.3470e-39, 1.0653e-38],
        [1.0194e-38, 1.0929e-38, 1.0102e-38, 1.0561e-38, 1.0561e-38],
        [1.0561e-38, 1.0745e-38, 1.0561e-38, 8.7245e-39, 9.6429e-39]])
'''

torch.nn.init.ones_(a)
a
'''
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
'''

2.5 填充0

用定值0初始化

torch.nn.init.zeros_(tensor)
​
a=torch.Tensor(3,5)
a
'''
tensor([[9.8265e-39, 9.4592e-39, 1.0561e-38, 7.3470e-39, 1.0653e-38],
        [1.0194e-38, 1.0929e-38, 1.0102e-38, 1.0561e-38, 1.0561e-38],
        [1.0561e-38, 1.0745e-38, 1.0561e-38, 8.7245e-39, 9.6429e-39]])
'''

torch.nn.init.zeros_(a)
a
'''
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
'''

​

2.6 使用单位矩阵进行初始化

torch.nn.init.eye_(tensor)
​
a=torch.Tensor(3,5)
a
'''
tensor([[9.8265e-39, 9.4592e-39, 1.0561e-38, 7.3470e-39, 1.0653e-38],
        [1.0194e-38, 1.0929e-38, 1.0102e-38, 1.0561e-38, 1.0561e-38],
        [1.0561e-38, 1.0745e-38, 1.0561e-38, 8.7245e-39, 9.6429e-39]])
'''

torch.nn.init.eye_(a)
a
'''
tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.]])
'''

​

2.7 Xavier 均匀初始化

torch.nn.init.xavier_uniform_(tensor, gain=1.0)

根据《Understanding the difficulty of training deep feedforward neural networks》,使用U(-a,a)进行初始化,其中

这里的gain就是 torch.nn.init.calculate_gain输出的内容

​
a=torch.Tensor(3,5)
a
'''
tensor([[9.8265e-39, 9.4592e-39, 1.0561e-38, 7.3470e-39, 1.0653e-38],
        [1.0194e-38, 1.0929e-38, 1.0102e-38, 1.0561e-38, 1.0561e-38],
        [1.0561e-38, 1.0745e-38, 1.0561e-38, 8.7245e-39, 9.6429e-39]])
'''

torch.nn.init.xavier_uniform_(a,
                              gain=torch.nn.init.calculate_gain('relu'))
a
'''
tensor([[-1.0399, -0.5018,  0.2838,  1.1071,  0.0897],
        [-0.9356,  0.9661, -0.6718, -1.0132,  0.9140],
        [ 0.9704,  0.8222,  0.2229, -1.1519,  0.4566]])
'''

2.8 Xavier 正态初始化

torch.nn.init.xavier_normal_(tensor, gain=1.0)

根据《Understanding the difficulty of training deep feedforward neural networks》,使用N(0,std^2)进行初始化,其中

这里的gain就是 torch.nn.init.calculate_gain输出的内容

​
a=torch.Tensor(3,5)
a
'''
tensor([[9.8265e-39, 9.4592e-39, 1.0561e-38, 7.3470e-39, 1.0653e-38],
        [1.0194e-38, 1.0929e-38, 1.0102e-38, 1.0561e-38, 1.0561e-38],
        [1.0561e-38, 1.0745e-38, 1.0561e-38, 8.7245e-39, 9.6429e-39]])
'''

torch.nn.init.xavier_uniform_(a,
                              gain=torch.nn.init.calculate_gain('relu'))
a
'''
tensor([[-1.0399, -0.5018,  0.2838,  1.1071,  0.0897],
        [-0.9356,  0.9661, -0.6718, -1.0132,  0.9140],
        [ 0.9704,  0.8222,  0.2229, -1.1519,  0.4566]])
'''

 2.9 Kaiming 均匀

根据《Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification》,使用U(-bound,bound)

其中

torch.nn.init.kaiming_uniform_(tensor, 
                        a=0, 
                        mode='fan_in',
                        nonlinearity='leaky_relu')

只有当nonlinearity为leaky_relu的时候,a有意义(表示负的那一部分的斜率)

a=torch.Tensor(3,5)
a
'''
tensor([[9.2755e-39, 8.9082e-39, 9.9184e-39, 8.4490e-39, 9.6429e-39],
        [1.0653e-38, 1.0469e-38, 4.2246e-39, 1.0378e-38, 9.6429e-39],
        [9.2755e-39, 9.7346e-39, 1.0745e-38, 1.0102e-38, 9.9184e-39]])
'''

torch.nn.init.kaiming_uniform_(a,
                              mode='fan_out',
                              nonlinearity='relu') 
a
'''
tensor([[ 0.7745, -1.0520, -0.3770,  0.7101,  0.9383],
        [ 1.0138,  0.6069, -0.5126, -0.3454,  1.2242],
        [ 0.3531,  0.2758,  0.3740, -0.8026,  1.1270]])
'''

2.10 kaiming正态 

根据《Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification》,使用N(0,std^2)进行初始化,其中

 

​
a=torch.Tensor(3,5)
a
'''
tensor([[9.8265e-39, 9.4592e-39, 1.0561e-38, 7.3470e-39, 1.0653e-38],
        [1.0194e-38, 1.0929e-38, 1.0102e-38, 1.0561e-38, 1.0561e-38],
        [1.0561e-38, 1.0745e-38, 1.0561e-38, 8.7245e-39, 9.6429e-39]])
'''

torch.nn.init.kaiming_normal_(a,
                              mode='fan_out',
                              nonlinearity='relu') 
a
'''
tensor([[ 1.1192, -0.6108, -1.2601,  0.4863,  0.4850],
        [ 0.8790, -0.1947,  0.3900, -0.1621,  0.0261],
        [-0.5602, -2.0269,  0.1730, -1.4321,  0.1675]])
'''

2.11 截断正态分布 

torch.nn.init.trunc_normal_(tensor, mean=0.0, std=1.0, a=- 2.0, b=2.0)

如果初始化的某一些元素不在[a,b]之间,那么就重新随机选取这个值 

​
a=torch.Tensor(3,5)
a
'''
tensor([[9.8265e-39, 9.4592e-39, 1.0561e-38, 7.3470e-39, 1.0653e-38],
        [1.0194e-38, 1.0929e-38, 1.0102e-38, 1.0561e-38, 1.0561e-38],
        [1.0561e-38, 1.0745e-38, 1.0561e-38, 8.7245e-39, 9.6429e-39]])
'''

torch.nn.init.trunc_normal_(a,
                           a=-0.2,
                           b=0.8) 
a
'''
tensor([[ 0.4685,  0.7272,  0.1331, -0.0746,  0.4909],
        [-0.1088,  0.4126,  0.4549,  0.0990,  0.3314],
        [ 0.4176,  0.0785,  0.3213,  0.5305,  0.5663]])
'''

2.12 初始化稀疏矩阵

torch.nn.init.sparse_(tensor, sparsity, std=0.01)

 sparsity表示每一列多少比例的元素是0

std表示每一列以N(0,std^2)的方式选择非负值

​
a=torch.Tensor(3,5)
a
'''
tensor([[9.8265e-39, 9.4592e-39, 1.0561e-38, 7.3470e-39, 1.0653e-38],
        [1.0194e-38, 1.0929e-38, 1.0102e-38, 1.0561e-38, 1.0561e-38],
        [1.0561e-38, 1.0745e-38, 1.0561e-38, 8.7245e-39, 9.6429e-39]])
'''

torch.nn.init.sparse_(a,sparsity=0.3)
a
'''
tensor([[ 0.0000,  0.0074, -0.0044, -0.0046,  0.0000],
        [-0.0091,  0.0000, -0.0111, -0.0024,  0.0047],
        [-0.0004,  0.0037,  0.0000,  0.0000,  0.0007]])
'''

3 fan_in 与 fan_out

下面是kaiming 初始化中对fan_mode的说法

  • "fan_in"可以保留前向计算中权重方差的大小。
    • Linear的输入维度
    • Conv2d:in\_channel*kernel\_width*kernel\_height
  • "fan_out"将保留后向传播的方差大小。 
    • Linear的输出维度
    • Conv2d:out\_channel*kernel\_width*kernel\_height

3.1 Pytorch的计算方式

Linear:

net=torch.nn.Linear(3,5)
net
#Linear(in_features=3, out_features=5, bias=True)

torch.nn.init._calculate_fan_in_and_fan_out(net.weight)
#(3,5)

torch.nn.init._calculate_correct_fan(net.weight,
                                    mode='fan_in')
#3

torch.nn.init._calculate_correct_fan(net.weight,
                                    mode='fan_out')
#5

Conv2d

net=torch.nn.Conv2d(kernel_size=(3,5),
                    in_channels=2,
                    out_channels=10)
net
#Conv2d(2, 10, kernel_size=(3, 5), stride=(1, 1))

torch.nn.init._calculate_fan_in_and_fan_out(net.weight)
#(30,150)



torch.nn.init._calculate_correct_fan(net.weight,
                                    mode='fan_in')
#30 (2*3*5)


torch.nn.init._calculate_correct_fan(net.weight,
                                    mode='fan_out')
#150 (10*3*5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/174357.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【LeetCode】回溯算法总结

回溯法解决的问题 回溯法模板 返回值:一般为void参数:先写逻辑,用到啥参数,再填啥参数终止条件:到达叶子节点,保存当前结果,返回遍历过程:回溯法一般在集合中递归搜索,集…

使用DiskGenius进行硬盘数据迁移

克隆硬盘 - DiskGenius 1.迁移磁盘 选择自己想要迁移的磁盘,点击工具-克隆磁盘 首先选择源硬盘,点击确定 之后选择想要迁移到的硬盘,点击确定 检查一下原硬盘和目标硬盘是否正确,此外还可以对这个空间进行二次调整。最终如果没有…

Android 中关于 FileObserver类监听文件状态的实践

文章目录需求背景走进源码实现示例参考需求背景 当某一个目录的文件发生变化(创建、修改、删除、移动)时,需要给一个回调事件给其他端调用。 其他场景:阅后即焚等等。 比如在 Android 的 VR 设备中,有一个用于部署的文…

Oracle P6 Professional专业版 22.12 中的热门新功能

目录 并排查看项目 在复制与 WBS 元素的关系时具有更大的灵活性 更轻松地确定要分配的正确基线 复制并粘贴电子表格中的单元格区域 更好地控制导入数据 检查 P6 专业版中提供的时间表报告 在排序对话框中排列字段顺序 创建导入和导出模板的副本 指定完成日期筛选器如何…

光流估计(一) 光流的简介与操作

今天是大年29,明天要贴春联了!算是在年前赶出来一篇文章发(太长时间没发东西了O。o),也算是自己在光流估计深度学习部分研究的开始~ 明年开学就是研二下学期了,时间过得飞快,毕设、实习、工作等…

MyBatis | 使用插件better-mybatis-generator自动生成dao、pojo

0️⃣简介🗼简介在我们编写MyBatis的项目时,常常需要为数据表编写大量的SQL语句以及dao类。better-mybatis-generator作为一款IDEA插件,可以自动为我们生成所需要的pojo类、dao类,并提供相当多的SQL单表查询操作。利用该插件&…

Python小技巧:富比较方法的妙用,__lt__、__le__、__eq__、__ne__、__gt__、__ge__。。。

前言 这里是Python小技巧的系列文章。这是第二篇&#xff0c;富比较方法的妙用。 在 Python中&#xff0c;富比较方法共6个&#xff0c;如下表所示&#xff1a; 见名知意&#xff0c;富比较主要用于比较。 富比较方法使用释义释义object.__lt__(self, other)x.__lt__(y)x<…

Springboot+mybatis使用PageHelper实现vue前端分页

Springbootmybatis使用PageHelper实现vue前端分页1、未分页前的vue前端效果图2、Springbootmybatis使用PageHelper分页逻辑&#xff1a;&#xff08;1&#xff09;Springboot、mybatis、PageHelper的版本&#xff1a;&#xff08;2&#xff09;yml文件配置pagehelper&#xff1…

带你了解docker是什么----初始篇

docker容器docker简介docker、虚拟环境与虚拟机docker 的核心概念Docker 镜像Docker 仓库Docker容器镜像、容器、仓库&#xff0c;三者之间的联系容器 容器一词的英文是container&#xff0c;其实container还有集装箱的意思&#xff0c;集装箱绝对是商业史上了不起的一项发明&…

11.3 关联容器操作

文章目录关联容器迭代器关键字成员不可修改&#xff0c;值可修改关于泛型算法添加元素向set插入元素向map插入数据insert操作总结检测insert的返回值展开递增语句向multiset和multimap添加元素删除元素map下标操作访问元素类型别名&#xff1a;类型别名说明key_type关键字类型&…

第一个Spring、第一个SpringBoot、Spring-Mybatis整合、SpringBoot-Mybatis整合

目录一、第一个Spring程序二、第一个SpringBoot三、Spring-Mybatis整合四、SpringBoot-Mybatis整合第一个程序一、第一个Spring程序 添加依赖——用以支持spring <dependency><groupId>org.springframework</groupId><artifactId>spring-webmvc</a…

线程池的简单介绍以及实现一个线程池

文章目录1、线程池存在的意义2、什么是线程池&#xff1f;3、线程池的使用2、java标准库中的线程池3、认识一下不同的线程池&#xff1a;4、认识一下线程池里的参数&#xff1a;4、实现一个简单的线程池1、线程池存在的意义 线程存在的意义&#xff1a;使用进程来实现并发编程…

鼠标右键没有git bash here(图文详解)

升级Win11后突然发现右键没有git bash here了解决&#xff1a;1. winr键&#xff0c;打开命令窗口,输入regedit打开注册表2. 在注册表中按照路径打开\HKEY_CLASSES_ROOT\Directory\Background\shell\3. 在shell上右键新建项&#xff0c;取名Git Bash Here&#xff0c;再点击Git…

SpringCloudConsul

上篇文章注册中心选出了Consul 和 K8S&#xff0c;现在我需要把他们集成到SpringCloud里&#xff0c;体验一下他们的服务注册发现、动态配置与权限分配难易 问题&#xff0c;以便选出更适合我们的。SpringCloudConsul首先用Docker搭建出Consul集群&#xff0c;这一步忽略了&…

8、Ubuntu22.4Server安装MariaDB10.10初始化密码Navicat远程登录

安装MariaDB10.10 查找源 apt search mariadb 在Ubuntu系统上从MariaDB存储库安装MariaDB10.10时&#xff0c;需要运行以下命令 sudo apt-get install apt-transport-https curl sudo curl -o /etc/apt/trusted.gpg.d/mariadb_release_signing_key.asc https://mariadb.org…

【微服务】Feign远程调用

本系列介绍的是Spring Cloud中涉及的知识点&#xff0c;如有错误欢迎指出~ 一.引子 我们以前基于RestTemplate发起的http请求远程调用服务&#xff1a; 存在下面的问题&#xff1a; 代码可读性差&#xff0c;编程体验不统一 参数复杂URL难以维护&#xff0c;字符串拼接硬编码…

逆卷积(ConvTranspose2d)是什么?

上图是一个卷积操作&#xff08;蓝色为输入&#xff0c;绿色为输出&#xff09;。 输入的特征图为x&#xff1a;( 4&#xff0c;4 &#xff0c;channels_in&#xff09;其中channels_in表示通道数。 卷积核设置&#xff1a;无padding&#xff0c; kernel size为3*3&#xff0c…

<关键字(1)>——《C语言深度剖析》

目录 关键字 - 第一讲 1.关键字分类 2.定义与声明 2.1 什么是变量(是什么) 2.2如何定义变量(怎么用) 2.3为什么要定义变量(为什么) 2.4 变量定义的本质 2.5 变量声明的本质 3. 最宽宏大量的关键字 - auto 3.1 变量的分类 3.2 变量的作用域 3.3 变量的生命周期 …

汇编语言(第四版)第八章 实验7 习题解答

Power idea 公司从1975年成立一直到1995年的基本情况如下&#xff1a; 下面的程序中&#xff0c;已经定义好了这些数据&#xff1a; assume cs:codesg,ds:datasgdatasg segmentdb 1975,1976,1977,1978,1979,1980,1981,1982,1983db 1984,1985,1986,1987,1988,1989,1990,1991,19…

【12】C语言_几个循环的经典练习

目录 1. 打印n的阶乘; 2、计算 1!2!3!......10! 3、用二分查找在一个有序数组中查找一个数 4、打印如下 5、输入三次密码 6、写一个猜数字游戏 7、如题 8、打印1到100之间 3的倍数 9、给两个数&#xff0c;求出最大公约数 10、找出从1000到2000之间的闰年 11、找出10…