模型初始化

news2024/11/25 10:51:06

在深度学习模型训练中,权重初始值极为重要,一个好的初始值会使得模型收敛速度提高,使模型准确率更准确,一般情况下,我们不使用全零初始值训练网络,为了利于训练和减少收敛时间,我们需要对模型进行合理的初始化, P y t o r c h Pytorch Pytorch也在 t o r c h . n n . i n i t torch.nn.init torch.nn.init中为我们提供了常用的初始化方法,通过本章学习,你将学习到以下内容。

  • 常见的初始化函数。
  • 初始化函数的使用。

torch.nn.init的内容

我们发现初始化模块提供了以下的初始化方法:

  • torch.nn.init.uniform_(tensor, a=0.0, b=1.0)
  • torch.nn.init.normal_(tensor, mean=0.0, std=1.0) 3 . *
  • torch.nn.init.constant_(tensor, val) 4 .
  • torch.nn.init.ones_(tensor) 5
  • torch.nn.init.zeros_(tensor)
  • torch.nn.init.eye_(tensor)
  • torch.nn.init.dirac_(tensor, groups=1)
  • torch.nn.init.xavier_uniform_(tensor, gain=1.0)
  • torch.nn.init.xavier_normal_(tensor, gain=1.0)
  • torch.nn.init.kaiming_uniform_(tensor, a=0, mode=‘fan__in’, nonlinearity=‘leaky_relu’)
  • torch.nn.init.kaiming_normal_(tensor, a=0, mode=‘fan_in’, nonlinearity=‘leaky_relu’)
  • torch.nn.init.orthogonal_(tensor, gain=1)
  • torch.nn.init.sparse_(tensor, sparsity, std=0.01)
  • torch.nn.init.calculate_gain(nonlinearity, param=None)

具体分布解释

均匀分布

  • torch.nn.init.uniform_(tensor, a=0.0, b=1.0)
    • tensor – an n-dimensional torch.Tensor
    • a – the lower bound of the uniform distribution
    • b – the upper bound of the uniform distribution

高斯分布

  • torch.nn.init.normal_(tensor, mean=0.0, std=1.0)
    * tensor – an n-dimensional torch.Tensor
    * mean – the mean of the normal distribution
    * std – the standard deviation of the normal distribution

初始化为常数

  • torch.nn.init.constant_(tensor, val)
    • tensor – an n-dimensional torch.Tensor
    • val – the value to fill the tensor with

初始化全为1

  • torch.nn.init.ones_(tensor)
    • tensor – an n-dimensional torch.Tensor

初始化全为0

  • torch.nn.init.zeros_(tensor)
    • tensor – an n-dimensional torch.Tensor

初始化为对角单位矩阵

  • torch.nn.init.eye_(tensor)
    • tensor – a 2-dimensional torch.Tensor

.Xavier 均匀分布

  • torch.nn.init.xavier_uniform_(tensor, gain=1.0)
    • tensor – an n-dimensional torch.Tensor
    • gain – an optional scaling factor

.Xavier 高斯分布

在这里插入图片描述

  • torch.nn.init.xavier_normal_(tensor, gain=1.0)
    • tensor – an n-dimensional torch.Tensor
    • gain – an optional scaling factor

He 均匀分布

在这里插入图片描述

  • torch.nn.init.kaiming_uniform_(tensor, a=0, mode=‘fan_in’, nonlinearity=‘leaky_relu’)
    • tensor – an n-dimensional torch.Tensor
    • a – the negative slope of the rectifier used after this layer (only used with ‘leaky_relu’
      在这里插入图片描述

He 高斯分布

  • torch.nn.init.kaiming_normal_(tensor, a=0, mode=‘fan_in’, nonlinearity=‘leaky_relu’)

初始化函数的使用

初始化函数的封装

def initialize_weights(self):
	for m in self.modules():
		# 判断是否属于Conv2d
		if isinstance(m, nn.Conv2d):
			torch.nn.init.xavier_normal_(m.weight.data)
			# 判断是否有偏置
			if m.bias is not None:
				torch.nn.init.constant_(m.bias.data,0.3)
		elif isinstance(m, nn.Linear):
			torch.nn.init.normal_(m.weight.data, 0.1)
			if m.bias is not None:
				torch.nn.init.zeros_(m.bias.data)
		elif isinstance(m, nn.BatchNorm2d):
			m.weight.data.fill_(1) 		 
			m.bias.data.zeros_()	

模型定义,调用初始化函数。

# 模型的定义
class MLP(nn.Module):
  # 声明带有模型参数的层,这里声明了两个全连接层
  def __init__(self, **kwargs):
    # 调用MLP父类Block的构造函数来进行必要的初始化。这样在构造实例时还可以指定其他函数
    super(MLP, self).__init__(**kwargs)
    self.hidden = nn.Conv2d(1,1,3)
    self.act = nn.ReLU()
    self.output = nn.Linear(10,1)
    
   # 定义模型的前向计算,即如何根据输入x计算返回所需要的模型输出
  def forward(self, x):
    o = self.act(self.hidden(x))
    return self.output(o)

mlp = MLP()
print(list(mlp.parameters()))
print("-------初始化-------")

initialize_weights(mlp)
print(list(mlp.parameters()))

总结

慢慢的会自己编写初始化模块,将初始化,全部都将其搞透彻,。
了解一个模型架构需要那些函数,以及那些模块,会自己将其搞清楚。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/116296.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从入门到项目实战 - Vue 计算属性用法解析

Vue 计算属性用法解析上一节:《Vue 监听器用法解析 》| 下一节:《Vue 样式绑定》jcLee95 邮箱 :291148484163.com CSDN 主页:https://blog.csdn.net/qq_28550263?spm1001.2101.3001.5343 本文地址:https://blog.…

衣服、商品、商城网站模板首页,仿U袋网,vue+elementui简洁实现(二)

一.前言 接上一遍博客:《衣服、商品、商城网站模板首页,仿U袋网,vueelementui简洁实现》 在此基础上增加了和完善一些页面: 商品分类筛选页面登录、注册、找回密码共用页面U袋学堂(视频专区,视频播放&am…

编译原理——参数传递—传名、传地址、得结果、传值

1.传名(替换操作) 把这种方式理解为替换操作,把P函数参数X、Y、Z和P函数内部的Y、Z替换为A、B,然后P函数对Y、Z的操作,其实就是对A、B的操作;需要注意这和传地址一样,上面对A造成的变化&#x…

制品仓库 Nexus 安装、配置、备份、使用

目录 1.1 Nexus 优点 1.2 Nexus 仓库类型 2. 安装 Nexus 2.1 设置持久化目录 2.2 拉取 Nexus docker 镜像 2.3 运行并启动 Nexus 3. 系统配置 3.1 配置管理员密码 3.2 配置 LDAP 3.3 配置 Email 服务器 4. 配置 Repository 4.1 添加 Blob Stores 4.2 添加 Reposit…

软考高级考哪个好?

软考高级一共5个科目,含金量都差不多,每个人考证的需求各不相同,合适自己情况的才是最有用的证书。看你自己的工作、专业与哪个更相近,再来深入学习备考的,当然自己也要对考试取证有一定的信心。 高级科目介绍&#x…

【LeetCode每日一题】——剑指 Offer II 072.求平方根

文章目录一【题目类别】二【题目难度】三【题目编号】四【题目描述】五【题目示例】六【解题思路】七【题目提示】八【题目注意】九【时间频度】十【代码实现】十一【提交结果】一【题目类别】 二分查找 二【题目难度】 简单 三【题目编号】 剑指 Offer II 072.求平方根 …

《图解TCP/IP》阅读笔记(第七章 7.5)—— OSPF 开放最短路径优先协议

7.5 OSPF OSPF(Open Shortest Path First,开放最短路径优先)是一种链路状态性的路由协议,即使网络中有环路,也可以进行稳定的路由控制。 另外,OSPF支持子网掩码,使得在RIP中无法实现的可变长度…

在简历上写了“精通自动化测试,阿里面试官跟我死磕后就给我发了高薪 offer

事情是这样的 前段时间面试了阿里,大家也都清楚,如果你在简历上面写着你精通 XX 技术,那面试官就会跟你死磕到底。 我就是在自己的简历上写了精通自动化测试,然后就开启了和阿里面试官的死磕之路,结果就是拿到了一份…

【Lilishop商城】No4-2.业务逻辑的代码开发,涉及到:会员B端第三方登录的开发-平台注册会员接口开发

仅涉及后端,全部目录看顶部专栏,代码、文档、接口路径在: 【Lilishop商城】记录一下B2B2C商城系统学习笔记~_清晨敲代码的博客-CSDN博客 全篇会结合业务介绍重点设计逻辑,其中重点包括接口类、业务类,具体的结合源代…

AMQP协议:消费者、生产者与RibbitMQ节点之间的交互流程,RibbitMQ的核心组成部分

原文链接 一、什么是AMQP协议? AMQP全称:Advanced Message Queuing Protocol(高级消息队列协议)。是应用层协议的一个开发标准,为面向消息的中间件设计。 基于此协议的客户端与消息中间件可以传递消息,不受客户端/中间件的不同产品…

小程序之首页搭建——Flex布局

目录 Flex布局简介 什么是flex布局? flex属性 学习地址: 视图层 View WXML 数据绑定 列表渲染 条件渲染 模板 WXSS 尺寸单位 样式导入 内联样式 选择器 全局样式与局部样式 WXS 示例 注意事项 页面渲染 数据处理 会议OA项目-首页 …

Python实现照片、视频一键压缩及备份源代码

代码 完整代码下载地址:Python实现照片、视频一键压缩及备份源代码 第一次运行前先编辑脚本,修改其中的主库位置、随库位置,保存。 此后要更新随库时,只要双击运行脚本即可。 运行结果示例: 主库位置:D…

用了这么多年的 SpringBoot 你知道什么是 SpringBoot 的 Web 类型推断吗?

用了这么多年的 SpringBoot 那么你知道什么是 SpringBoot 的 web 类型推断吗? 估计很多小伙伴都不知道,毕竟平时开发做项目的时候做的都是普通的 web 项目并不需要什么特别的了解,不过抱着学习的心态,阿粉今天带大家看一下什么是 …

jQuery库冲突

文章目录jQuery库冲突原因jQuery.noConflict()如还想使用$可以这么做jQuery库冲突 原因 在某些情况下,可能有必要在同一个页面中使用多个JavaScript库。但是很多库都使用了“$”这个符号(因为它简短方便),这时就需要用一种方式来…

Oracle中Null和无值的区别

以Leetcode第176题“第二高的薪水”为例。 首先说一下这道题最容易理解、最容易实现的解法&#xff0c;就是比最大值小的值里最大的值。 SELECT MAX(salary) AS SecondHighestSalary FROM Employee WHERE salary < (SELECT MAX(salary) FROM Employee)通过这道题目&#…

泛型------数据结构

泛型 问题:下面是一个简单的顺序表&#xff0c;我们在这里面实现的一个顺序表&#xff0c;是存放的数据类型只有int类型&#xff0c;这就会很不通用&#xff0c;如果我们想什么样的类型的数据都想要放进去,就要把这个数组的类型设置成Object类型 能不能啥样的类型都可以存放呢&…

ArcGIS中ArcMap栅格重采样操作与算法选择

本文介绍在ArcMap软件中&#xff0c;实现栅格图像重采样的具体操作&#xff0c;以及不同重采样方法的选择依据。 在文章Python中ArcPy实现栅格图像文件批量掩膜与批量重采样&#xff08;https://blog.csdn.net/zhebushibiaoshifu/article/details/124282764&#xff09;中&…

optimization问题的解决

目录 临界点critical point 基本介绍临界点两种情况的区分 g和H的举例介绍根据H区分Saddle Point和local minima 批次Batch batch大小的比较 时间的开销训练集和测试集的效果 训练集效果测试集效果 动量Momentum 一般的Gradient Descent带有动量的Gradient Descent 2021 -…

异步通信技术AJAX | 原理剖析、发送Ajax请求四步

目录 一&#xff1a;快速搞定AJAX&#xff08;第一篇&#xff09; 1、传统请求及缺点 2、AJAX请求原理剖析 3、AJAX概述 4、XMLHttpRequest对象 5、AJAX GET请求 6、AJAX GET请求提交数据 7、AJAX GET请求的缓存问题 8、AJAX POST请求及模拟表单提交数据 9、经典案例…

C语言基础--数组

文章目录一维数组一、一维数组的创建和初始化&#xff08;1&#xff09;一维数组的创建&#xff08;2&#xff09;一维数组的初始化1&#xff09;整形数组初始化2&#xff09;字符数组初始化3&#xff09;sizeof与strlen4&#xff09;总结二、一维数组的使用三、一维数组在内存…