Pytorch深度学习实战2-1:详细推导Xavier参数初始化(附Python实现)

news2024/12/25 0:09:16

目录

  • 1 参数初始化
  • 2 Xavier参数初始化原理
    • 2.1 前向传播阶段
    • 2.2 反向传播阶段
    • 2.3 可视化思考
  • 3 Python实现

1 参数初始化

参数初始化在深度学习中起着重要的作用。在神经网络中,参数初始化是指为模型中的权重和偏置项设置初始值的过程。合适的参数初始化可以帮助模型更好地学习和收敛到最优解。参数初始化的目标是使模型具有良好的初始状态,以便在训练过程中快速且稳定地收敛。错误的参数初始化可能导致模型无法正常学习,梯度消失或梯度爆炸等问题。

常见的参数初始化方法包括随机初始化、零初始化、正态分布初始化和均匀分布初始化等。这些方法根据不同的分布特性和模型结构选择合适的初始值。在某些情况下,不同层或不同类型的参数可能需要不同的初始化策略。例如使用预训练模型时,可以采用迁移学习的方法,将预训练模型的参数作为初始值,从而加速收敛并提高性能。

除了设置初始值外,参数初始化还可以与其他优化技术相结合,如学习率调整、正则化和批归一化等,以进一步提高模型的性能和稳定性

举例而言,如图所示是在 t a n h ( ⋅ ) \rm{tanh(\cdot)} tanh()下九层神经网络各层激活输出,可以看到在网络深层激活输出逐渐衰减或保持不变

在这里插入图片描述

2 Xavier参数初始化原理

Xavier初始化的核心原理是保证各层网络的前向传播激活值和反向传播梯度值方差保持一致。Xavier初始化基于如下假设:

  • 输入样本独立同分布采样,且各个特征维度方差相等;
  • 激活函数 σ ( ⋅ ) \sigma \left( \cdot \right) σ()对称且近似线性区间满足 σ ( z ) ≈ z ⇔ σ ′ ( z ) ≈ 1 \sigma \left( \boldsymbol{z} \right) \approx \boldsymbol{z}\Leftrightarrow \sigma '\left( \boldsymbol{z} \right) \approx 1 σ(z)zσ(z)1
  • 激活输入 z \boldsymbol{z} z处于激活函数的线性区间

2.1 前向传播阶段

根据

a l = σ ( z l ) = σ ( W l a l − 1 − b l ) \boldsymbol{a}^l=\sigma \left( \boldsymbol{z}^l \right) =\sigma \left( \boldsymbol{W}^l\boldsymbol{a}^{l-1}-\boldsymbol{b}^l \right) al=σ(zl)=σ(Wlal1bl)

可得

v a r [ a l ] ≈ v a r [ z l ] = v a r [ W l a l − 1 − b l ] \mathrm{var}\left[ \boldsymbol{a}^l \right] \approx \mathrm{var}\left[ \boldsymbol{z}^l \right] =\mathrm{var}\left[ \boldsymbol{W}^l\boldsymbol{a}^{l-1}-\boldsymbol{b}^l \right] var[al]var[zl]=var[Wlal1bl]

初始阶段第 l l l层的网络权重 W l \boldsymbol{W}^l Wl的各个元素独立采样自某个分布 P P P,即

[ z 1 l z 2 l ⋮ z n l l ] = [ w 1 , 1 l w 1 , 2 l ⋯ w 1 , n l − 1 l w 2 , 1 l w 2 , 2 l ⋯ w 2 , n l − 1 l ⋮ ⋮ ⋱ ⋮ w n l , 1 l w n l , 2 l ⋯ w n l , n l − 1 l ] [ a 1 l − 1 a 2 l − 1 ⋮ a n l − 1 l − 1 ] ⇒ v a r [ z i l ] = v a r [ ∑ k = 1 n l − 1 w 1 , k l a k l − 1 ] \left[ \begin{array}{c} z_{1}^{l}\\ z_{2}^{l}\\ \vdots\\ z_{n_l}^{l}\\\end{array} \right] =\left[ \begin{matrix} w_{1,1}^{l}& w_{1,2}^{l}& \cdots& w_{1,n_{l-1}}^{l}\\ w_{2,1}^{l}& w_{2,2}^{l}& \cdots& w_{2,n_{l-1}}^{l}\\ \vdots& \vdots& \ddots& \vdots\\ w_{n_l,1}^{l}& w_{n_l,2}^{l}& \cdots& w_{n_l,n_{l-1}}^{l}\\\end{matrix} \right] \left[ \begin{array}{c} a_{1}^{l-1}\\ a_{2}^{l-1}\\ \vdots\\ a_{n_{l-1}}^{l-1}\\\end{array} \right] \Rightarrow \mathrm{var}\left[ z_{i}^{l} \right] =\mathrm{var}\left[ \sum_{k=1}^{n_{l-1}}{w_{1,k}^{l}a_{k}^{l-1}} \right] z1lz2lznll = w1,1lw2,1lwnl,1lw1,2lw2,2lwnl,2lw1,nl1lw2,nl1lwnl,nl1l a1l1a2l1anl1l1 var[zil]=var[k=1nl1w1,klakl1]

考虑到 w i , j l w_{i,j}^{l} wi,jl与前一层激活值 a l − 1 \boldsymbol{a}^{l-1} al1独立,所以

v a r [ z i l ] = v a r [ ∑ k = 1 n l − 1 w i , k l a k l − 1 ] = ∑ k = 1 n l − 1 v a r [ w i , k l a k l − 1 ] = ∑ k = 1 n l − 1 ( v a r [ w i , k l ] v a r [ a k l − 1 ] + v a r [ w i , k l ] E 2 [ a k l − 1 ] + v a r [ a k l − 1 ] E 2 [ w i , k l ] ) \begin{aligned}\mathrm{var}\left[ z_{i}^{l} \right] &=\mathrm{var}\left[ \sum_{k=1}^{n_{l-1}}{w_{i,k}^{l}a_{k}^{l-1}} \right]\\& =\sum_{k=1}^{n_{l-1}}{\mathrm{var}\left[ w_{i,k}^{l}a_{k}^{l-1} \right]}\\&=\sum_{k=1}^{n_{l-1}}{\left( \mathrm{var}\left[ w_{i,k}^{l} \right] \mathrm{var}\left[ a_{k}^{l-1} \right] +\mathrm{var}\left[ w_{i,k}^{l} \right] \mathbb{E} ^2\left[ a_{k}^{l-1} \right] +\mathrm{var}\left[ a_{k}^{l-1} \right] \mathbb{E} ^2\left[ w_{i,k}^{l} \right] \right)}\end{aligned} var[zil]=var[k=1nl1wi,klakl1]=k=1nl1var[wi,klakl1]=k=1nl1(var[wi,kl]var[akl1]+var[wi,kl]E2[akl1]+var[akl1]E2[wi,kl])

根据激活函数对称性,可令 W l \boldsymbol{W}^l Wl a l − 1 \boldsymbol{a}^{l-1} al1均值为0,根据假设中的方差关系

{ ∀ i    v a r [ a i l ] = v a r [ a l ] ∀ i , j    v a r [ w i , j l ] = v a r [ W l ] \begin{cases} \forall i\,\,\mathrm{var}\left[ a_{i}^{l} \right] =\mathrm{var}\left[ \boldsymbol{a}^l \right]\\ \forall i,j\,\,\mathrm{var}\left[ w_{i,j}^{l} \right] =\mathrm{var}\left[ \boldsymbol{W}^l \right]\\\end{cases} {ivar[ail]=var[al]i,jvar[wi,jl]=var[Wl]

上式可简化为 v a r [ z i l ] = n l − 1 v a r [ w i , 1 l ] v a r [ a 1 l − 1 ] \mathrm{var}\left[ z_{i}^{l} \right] =n_{l-1}\mathrm{var}\left[ w_{i,1}^{l} \right] \mathrm{var}\left[ a_{1}^{l-1} \right] var[zil]=nl1var[wi,1l]var[a1l1],改写成矩阵形式

v a r [ a l ] ≈ n l − 1 v a r [ W l ] v a r [ a l − 1 ] \mathrm{var}\left[ \boldsymbol{a}^l \right] \approx n_{l-1}\mathrm{var}\left[ \boldsymbol{W}^l \right] \mathrm{var}\left[ \boldsymbol{a}^{l-1} \right] var[al]nl1var[Wl]var[al1]

结合 a 0 = x \boldsymbol{a}^0=\boldsymbol{x} a0=x可递推得到

v a r [ a l ] ≈ v a r [ x ] ∏ k = 1 l n k − 1 v a r [ W k ] {\mathrm{var}\left[ \boldsymbol{a}^l \right] \approx \mathrm{var}\left[ \boldsymbol{x} \right] \prod_{k=1}^l{n_{k-1}\mathrm{var}\left[ \boldsymbol{W}^k \right]}} var[al]var[x]k=1lnk1var[Wk]

2.2 反向传播阶段

根据 δ l = ( W l + 1 ) T δ l + 1 ⊙ σ ′ ( z l ) \boldsymbol{\delta }^l=\left( \boldsymbol{W}^{l+1} \right) ^T\boldsymbol{\delta }^{l+1}\odot \sigma '\left( \boldsymbol{z}^l \right) δl=(Wl+1)Tδl+1σ(zl)可得

v a r [ δ l ] ≈ n l + 1 v a r [ W l + 1 ] v a r [ δ l + 1 ] \mathrm{var}\left[ \boldsymbol{\delta }^l \right] \approx n_{l+1}\mathrm{var}\left[ \boldsymbol{W}^{l+1} \right] \mathrm{var}\left[ \boldsymbol{\delta }^{l+1} \right] var[δl]nl+1var[Wl+1]var[δl+1]

结合 δ L = ∇ y ~ E ⊙ σ ′ ( z L ) ≈ ∇ y ~ E \boldsymbol{\delta }^L=\nabla _{\boldsymbol{\tilde{y}}}E\odot \sigma '\left( \boldsymbol{z}^L \right) \approx \nabla _{\boldsymbol{\tilde{y}}}E δL=y~Eσ(zL)y~E可递推得到

v a r [ δ l ] ≈ ∇ y ~ E ∏ k = l + 1 L n k v a r [ W k ] {\mathrm{var}\left[ \boldsymbol{\delta }^l \right] \approx \nabla _{\boldsymbol{\tilde{y}}}E\prod_{k=l+1}^L{n_k\mathrm{var}\left[ \boldsymbol{W}^k \right]}} var[δl]y~Ek=l+1Lnkvar[Wk]

为保证前向传播激活和反向传播梯度在网络中顺利流动,应保持各层参数方差相等,即满足

{ n l v a r [ W l ] = 1 n l − 1 v a r [ W l ] = 1 \begin{cases} n_l\mathrm{var}\left[ \boldsymbol{W}^l \right] =1\\ n_{l-1}\mathrm{var}\left[ \boldsymbol{W}^l \right] =1\\\end{cases} nlvar[Wl]=1nl1var[Wl]=1

由于第 l l l层的输入神经元个数 n l − 1 n_{l-1} nl1和输出神经元个数 n l n_l nl一般不相等,故取折中

v a r [ W l ] = 2 n l − 1 + n l \mathrm{var}\left[ \boldsymbol{W}^l \right] =\frac{2}{n_{l-1}+n_l} var[Wl]=nl1+nl2

所以网络连接权采样自服从方差满足上式的分布即可,例如

W ∼ N ( 0 , 2 n l − 1 + n l )    o r W ∼ U ( − 6 n l − 1 + n l , 6 n l − 1 + n l ) \boldsymbol{W}\sim \mathcal{N} \left( 0,\frac{2}{n_{l-1}+n_l} \right) \,\, \mathrm{or} \boldsymbol{W}\sim U\left( -\sqrt{\frac{6}{n_{l-1}+n_l}},\sqrt{\frac{6}{n_{l-1}+n_l}} \right) WN(0,nl1+nl2)orWU(nl1+nl6 ,nl1+nl6 )

2.3 可视化思考

如图所示,经过Xavier初始化后网络各层前向和反向传播时的方差保持一致

在这里插入图片描述

如图所示,经过Xavier初始化后的测试误差通常更小

在这里插入图片描述

Xavier进一步指出:观察层与层之间传播的激活值和梯度有利于理解深层网络的训练复杂度;保持层与层之间激活值和梯度的良好流动对学习效果非常重要。尽管在Xavier初始化做出了比较苛刻的假设,且在工程上很容易被违反,但其在实践中被证明是有效的,已经成为很多深度学习框架的默认初始化方法之一。

3 Python实现

简单实现一下Xavier初始化

def initialize_parameters_xavier(layers_dims):
    parameters = {}
    L = len(layers_dims)
    for l in range(1, L):
        mu = 0
        sigma = np.sqrt(2.0 / (layers_dims[l - 1] + layers_dims[l]))
        parameters['W' + str(l)] = np.random.normal(loc=mu, scale=sigma, size=(layers_dims[l], layers_dims[l - 1]))
        parameters['b' + str(l)] = np.zeros((layers_dims[l], 1))
    return parameters

可视化

for l in range(1, num_layers):
	A_pre = A
	W = parameters['W' + str(l)]
	b = parameters['b' + str(l)]
	z = np.dot(W, A_pre) + b # z = Wx + b
	
	A = tanh(z)
	
	print(A)
	plt.subplot(1, 8, l)
	plt.hist(A.flatten(), facecolor='g')
	plt.xlim([-2, 2])
	plt.ylim([0, 1000000])
	plt.yticks([])
plt.show()

如下所示
在这里插入图片描述
可以看出各层输出方差基本一致,实现了良好的初始化效果

完整工程代码请联系下方博主名片获取


🔥 更多精彩专栏

  • 《ROS从入门到精通》
  • 《Pytorch深度学习实战》
  • 《机器学习强基计划》
  • 《运动规划实战精讲》

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1258929.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

阿里云服务器部署node和npm

目录 1.链接服务器2.找到node 下载地址3获取链接地址4下载到linux5.解压6.重命名 解压后的文件7.配置环境变量7.1复制当前的bin目录7.2vim /etc/profile7.3在按下ESC按键 8.重启环境变量9.输入node10.npm配置加速镜像 1.链接服务器 2.找到node 下载地址 https://nodejs.org/d…

C++——解锁string常用接口

本篇的内容是记录使用string接口的测试与使用,方便后续使用时查阅使用 首先介绍 string::npos; size_t(无符号整型)的最大值。NPOS 是一个静态成员常量值,具有 size_t 类型元素的最大可能值。当此值用作字符串成员函数中 len&am…

希宝猫罐头怎么样?专业人士告诉你性价比高的猫罐头推荐

作为一家经营猫咖店已有6年的店长,我在这段时间里接触过不少于30种不同的猫罐头。在猫罐头上我还是有话语权的。通过本文,我将与大家分享值得购买的猫罐头,分享猫罐头喂养的技巧。那么希宝猫罐头表现怎么样呢? 希宝猫罐头可是采用…

Python GUI 图形用户界面程序设计,Python自带 tkinter 库

文章目录 前言GUI介绍简单操作tkinter组件介绍向窗体中添加按钮控件使用文本框控件使用菜单控件使用标签控件使用单选按钮和复选按钮组件使用绘图组件关于Python技术储备一、Python所有方向的学习路线二、Python基础学习视频三、精品Python学习书籍四、Python工具包项目源码合集…

【多线程】-- 04 静态代理模式

多线程 3 静态代理 这里以一个现实生活中的例子来解释并实现所谓的静态代理模式,即结婚者雇用婚庆公司来帮助自己完成整个婚礼过程: package com.duo.lambda;interface Marry {void HappyMarry();//人生四大乐事:久旱逢甘霖;他…

自动化测试|我为什么从Cypress转到了Playwright?

以下为作者观点: 早在2019年,我就开始使用Cypress ,当时我所在的公司决定在新项目中放弃Protractor 。当时,我使用的框架是Angular,并且有机会实施Cypress PoC。最近,我换了工作,现在正在使用R…

springboot自定义更换启动banner动画

springboot自定义更换启动banner动画 文章目录 springboot自定义更换启动banner动画 📕1.新建banner🖥️2.启动项目🔖3.自动生成工具🧣4.彩蛋 🖊️最后总结 📕1.新建banner 在resources中新建banner.txt文…

C#——多线程之异步调用容易出现的问题

C#——多线程之异步调用容易出现的问题 Q1:For中异步调用函数且函数输入具有实时性 Q1:For中异步调用函数且函数输入具有实时性 在项目进行过程中,发现For中用异步调用带有输入参数的函数时,会由于闭包特性,以及Task.…

ssm+vue的公司安全生产考试系统(有报告)。Javaee项目,ssm vue前后端分离项目。

演示视频: ssmvue的公司安全生产考试系统(有报告)。Javaee项目,ssm vue前后端分离项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结…

C#从零搭建微信机器人(二)分词匹配组件【jieba】的使用

上篇文章我们讲解了微信机器人的环境搭建及演示,这期我们来说一下其中在模糊匹配搜索时用到的Segement子项目,也就是其中的中文分词匹配器。 一、原理介绍: 其实这个子项目中的分词插件和solr的IK分词器类似,都是可以支持将一句…

TP4066L是一款完整的单节锂离子电池采用恒定电流/恒定电压线性充电器。

TP4066L 采用ESOP8/DFN2*2-8封装 1A线性锂离子电池充电器 描述: TP4066L是一款完整的单节锂离子电池采用恒定电流/恒定电压线性充电器。其底部带有散热片的ESOP8DFN2*2-8封装与较少的外部元件数目使得TP4066L成为便携式应用的理想选择。TP4066L可以适合USB电源和适…

桐庐县数据资源管理局领导一行莅临美创科技并带来感谢信

11月23日,浙江桐庐县数据资源管理局党组成员、副局长朱勃一行到访美创科技总部参观交流,并带来感谢信,对美创圆满完成护航亚运政务外网数据网站安全保障工作表示充分肯定。美创科技联合创始人、副总裁胡江涛等进行热情接待并开展交流座谈。 图…

微信ipad协议8.0.37/8.0.40新版本

功能如下,如有定制功能请在官网联系我们。 登录 创建新设备 获取登录er维码 执行登录 注销登录 消息 消息回调 消息撤回 发送app类型消息 发送小程序 发送CDN文件 发送CDN图片 发送CDN视频 发送emoji 发送文件 发送图片 发送链接 发送消息 发送视频 发送语音 …

2021年全国a级景区数据,shp+csv数据均有

大家好~这周将和大家分享关于文化旅游和城乡建设相关的数据,希望大家喜欢~ 今天分享的是2021年全国a级景区数据,数据格式有shpcsv,几何类型为点,已经经过清洗加工,可直接使用,以下为部分字段列表&#xff…

无人机遥控器方案定制_MTK平台无人设备手持遥控终端PCB板开发

随着科技的不断发展和无人机技术的逐步成熟,无人机越来越受到人们的关注。作为一种高新技术,无人机的应用范围不断拓展,包括农业、环境监测、城市规划、运输物流等领域。同时,无人机的飞行控制技术也得到了不断的优化和提升。 早…

前端管理制度

数据运营中心的管理形式: 数据运营中心的管理形式 竖向是各小组 横向是项目管理 负责人的定位: 只是工作的内容不同,没有上下级之分 帮助组员找到适合的位置,帮助大家解决问题,给大家提供资源 前端组的工作形式&am…

Doris-Routine Load(二十七)

例行导入(Routine Load)功能为用户提供了一种自动从指定数据源进行数据导入的功能。 适用场景 当前仅支持从 Kafka 系统进行例行导入,使用限制: (1)支持无认证的 Kafka 访问,以及通过 SSL 方…

二维码智慧门牌管理系统:实现高效信息管理

文章目录 前言一、 功能升级优势 前言 随着科技的飞速发展和人们生活节奏的加快,传统的门牌管理系统已经不再适应现代社会的需求。为了解决这一问题,全新的二维码智慧门牌管理系统升级解决方案应运而生,为用户带来前所未有的便捷与高效。 一…

1m照片手机怎么拍?一分钟解决!

我们都知道现在的手机像素特别好,随便拍一张照片都是2-3MB,有时候上课或者会议要拍很多照片,这些照片其实又不需要如此清晰,就会特别占内存,下面就向大家介绍三种好用的办法。 方法一:拍完照后手机截图进行…

【开源】基于JAVA的海南旅游景点推荐系统

项目编号: S 023 ,文末获取源码。 \color{red}{项目编号:S023,文末获取源码。} 项目编号:S023,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 用户端2.2 管理员端 三、系统展示四…