深度学习笔记:神经网络权重确定初始值方法

news2024/11/17 21:49:01

神经网络权重不可为相同的值,比如都为0,因为如果这样网络正向传播输出和反向传播结果对于各权重都完全一样,导致设置多个权重和设一个权重毫无区别。我们需要使用随机数作为网络权重

实验程序

在以下实验中,我们使用5层神经网络,每层神经元个数100,使用sigmoid作为激活函数,向网络传入1000个正态分布随机数,测试使用不同的随机数对网络权重的影响。

# coding: utf-8
import numpy as np
import matplotlib.pyplot as plt


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def ReLU(x):
    return np.maximum(0, x)


def tanh(x):
    return np.tanh(x)
    
input_data = np.random.randn(1000, 100)  # 1000个数据
node_num = 100  # 各隐藏层的节点(神经元)数
hidden_layer_size = 5  # 隐藏层有5层
activations = {}  # 激活值的结果保存在这里

x = input_data

for i in range(hidden_layer_size):
    if i != 0:
        x = activations[i-1]

    # 改变初始值进行实验!
    w = np.random.randn(node_num, node_num) * 1
    # w = np.random.randn(node_num, node_num) * 0.01
    # w = np.random.randn(node_num, node_num) * np.sqrt(1.0 / node_num)
    # w = np.random.randn(node_num, node_num) * np.sqrt(2.0 / node_num)


    a = np.dot(x, w)


    # 将激活函数的种类也改变,来进行实验!
    z = sigmoid(a)
    # z = ReLU(a)
    # z = tanh(a)

    activations[i] = z

# 绘制直方图
for i, a in activations.items():
    plt.subplot(1, len(activations), i+1)
    plt.title(str(i+1) + "-layer")
    if i != 0: plt.yticks([], [])
    # plt.xlim(0.1, 1)
    # plt.ylim(0, 7000)
    plt.hist(a.flatten(), 30, range=(0,1))
plt.show()

1 标准差为1随机正态
在这里插入图片描述
在这一情况下,权重值主要集中于0和1.由于sigmoid在接近0和1时导数趋于0,这一数据分别会导致反向传播中梯度逐渐减小,这一现象称为梯度消失

2 标准差为0.01随机正态
在这里插入图片描述
这时神经网络权重集中在0.5附近,此时不会出现梯度消失,但是由于值集中在同一区间,多个神经网络会输出几乎相同的值,使得神经网络表现能力受限(如开头所说)

3 使用Xavier初始值

Xavier初始值为保证各层权重值具有足够广度设计。其推导出的最优初始值为每一层初始权重值是1/√N,其中N为上一层权重个数

使用sigmoid激活函数和Xavier初始值结果:
在这里插入图片描述
可以看到此时权重初始值的值域明显大于了之前的取值。Xavier初始值是基于激活函数为线性函数的假设推导出的。sigmoid函数关于(0, 0.5)对称,其在原点附近还不是完美的线性。而tanh函数关于原点对称,在原点附近可以基本近似于直线,其使用Xavier应该会产生更理想的参数值

使用tanh激活函数和Xavier初始值:
在这里插入图片描述
ReLU函数的权重设置

ReLU函数有自己独特的默认权重设置,称为He初始值,其公式为2/√N标准差的随机数,N为上一次神经元个数。
在这里插入图片描述
在该分布中,各层广度分布基本相同,这使得即使层数加深,也不容易出现梯度消失问题

使用mnist数据集对不同初始化权重方法进行测试:

该程序使用0.01随机正态,Xavier + sigmoid,He + ReLU进行2000轮反向传播,并绘制总损失关于迭代次数图象

# coding: utf-8
import os
import sys

sys.path.append("D:\AI learning source code")  # 为了导入父目录的文件而进行的设定
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.util import smooth_curve
from common.multi_layer_net import MultiLayerNet
from common.optimizer import SGD


# 0:读入MNIST数据==========
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)

train_size = x_train.shape[0]
batch_size = 128
max_iterations = 2000


# 1:进行实验的设置==========
weight_init_types = {'std=0.01': 0.01, 'Xavier': 'sigmoid', 'He': 'relu'}
optimizer = SGD(lr=0.01)

networks = {}
train_loss = {}
for key, weight_type in weight_init_types.items():
    networks[key] = MultiLayerNet(input_size=784, hidden_size_list=[100, 100, 100, 100],
                                  output_size=10, weight_init_std=weight_type)
    train_loss[key] = []


# 2:开始训练==========
for i in range(max_iterations):
    batch_mask = np.random.choice(train_size, batch_size)
    x_batch = x_train[batch_mask]
    t_batch = t_train[batch_mask]
    
    for key in weight_init_types.keys():
        grads = networks[key].gradient(x_batch, t_batch)
        optimizer.update(networks[key].params, grads)
    
        loss = networks[key].loss(x_batch, t_batch)
        train_loss[key].append(loss)
    
    if i % 100 == 0:
        print("===========" + "iteration:" + str(i) + "===========")
        for key in weight_init_types.keys():
            loss = networks[key].loss(x_batch, t_batch)
            print(key + ":" + str(loss))


# 3.绘制图形==========
markers = {'std=0.01': 'o', 'Xavier': 's', 'He': 'D'}
x = np.arange(max_iterations)
for key in weight_init_types.keys():
    plt.plot(x, smooth_curve(train_loss[key]), marker=markers[key], markevery=100, label=key)
plt.xlabel("iterations")
plt.ylabel("loss")
plt.ylim(0, 2.5)
plt.legend()
plt.show()

在这里插入图片描述
在该图象中可以看到,0.01随机正态由于梯度丢失问题,权重更新速率极慢,在2000次迭代中总损失基本没有变化。Xavier和He都正常进行了反向传播得到了更准确的网络参数,其中He似乎学习速率更快一些

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/380908.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

设计模式-服务定位器模式

设计模式-服务定位器模式一、背景1.1 服务定位模式1.2 策略模式二、代码实战2.1 服务定位器2.2 配置ServiceLocatorFactoryBean2.3 定义一个支付的接口2.4 根据不同类型处理Bean2.5 controller层三、项目结构及测试结果3.1 测试结果3.2 项目结构及源码(欢迎star)四、参考资料一…

进程管理(进程概念、查看进程、关闭进程)

1.进程 程序运行在操作系统中,是被操作系统所管理的。 为管理运行的程序,每一个程序运行的时候,便被操作系统注册为系统中的一个:进程 并会为每一个进程都分配一个独有的:进程ID(进程号) 2. 查…

Paddle OCR Win 11下的安装和简单使用教程

Paddle OCR Win 11下的安装和简单使用教程 对于中文的识别,可以考虑直接使用Paddle OCR,识别准确率和部署都相对比较方便。 环境搭建 目前PaddlePaddle 发布到v2.4,先下载paddlepaddle,再下载paddleocr。根据自己设备操作系统进…

代码随想录算法训练营第四十天 | 343. 整数拆分,96.不同的二叉搜索树

一、参考资料整数拆分https://programmercarl.com/0343.%E6%95%B4%E6%95%B0%E6%8B%86%E5%88%86.html 视频讲解:https://www.bilibili.com/video/BV1Mg411q7YJ不同的二叉搜索树https://programmercarl.com/0096.%E4%B8%8D%E5%90%8C%E7%9A%84%E4%BA%8C%E5%8F%89%E6%90…

Win10任务栏卡死的几个处理方法 附小工具

问题:Win10任务栏卡死 最近经常碰到用户系统任务栏卡死的现象,桌面上的图标可以正常打开,点击任务栏就疯狂的转圈,感觉像死机状态,等半天都没啥用,一般只能强制关机重启,我不建议这样操作&…

机器学习算法原理——感知机

感知机 输入空间:X⊆Rn\mathcal X\subseteq{\bf R^n}X⊆Rn ;输入:x(x(1),x(2),⋅⋅⋅,x(n))T∈Xx\left(x^{(1)},x^{(2)},\cdot\cdot\cdot,x^{(n)}\right)^{T}\in{\mathcal{X}}x(x(1),x(2),⋅⋅⋅,x(n))T∈X 输出空间:Y{1,−1}{\…

杂谈:created中两次数据修改,会触发几次页面更新?

面试题&#xff1a;created生命周期中两次修改数据&#xff0c;会触发几次页面更新&#xff1f; 一、同步的 先举个简单的同步的例子&#xff1a; new Vue({el: "#app",template: <div><div>{{count}}</div></div>,data() {return {count…

[架构之路-124]-《软考-系统架构设计师》-操作系统-3-操作系统原理 - IO设备、微内核、嵌入式系统

第11章 操作系统第5节 设备管理/文件管理&#xff1a;IO5.1 文件管理5.2 IO设备管理&#xff08;内存与IO设备之间&#xff09;数据传输控制是指如何在内存和IO硬件设备之间传输数据&#xff0c;即&#xff1a;设备何时空闲&#xff1f;设备何时完成数据的传输&#xff1f;SPOO…

vue实战-深入响应式数据原理

本文将带大家快速过一遍Vue数据响应式原理&#xff0c;解析源码&#xff0c;学习设计思路&#xff0c;循序渐进。 数据初始化 _init 在我们执行new Vue创建实例时&#xff0c;会调用如下构造函数&#xff0c;在该函数内部调用this._init(options)。 import { initMixin } f…

代码随想录算法训练营第一天| 704. 二分查找、27. 移除元素

Leetcode 704 二分查找题目链接&#xff1a;704二分查找介绍给定一个 n 个元素有序的&#xff08;升序&#xff09;整型数组 nums 和一个目标值 target &#xff0c;写一个函数搜索 nums 中的 target&#xff0c;如果目标值存在返回下标&#xff0c;否则返回 -1。思路先看看一个…

MyBatis源码分析(三)SqlSession的执行主流程

文章目录一、熟悉主要接口二、SqlSession的获取1、通过数据源获取SqlSession三、Mapper的获取与代理1、从SqlSession获取Mapper2、执行Mapper方法前准备逻辑3、SqlCommand的创建4、构造MethodSignature四、执行Mapper的核心方法1、执行Mapper的方法逻辑五、简单SELECT处理过程1…

【蓝桥杯试题】 递归实现指数型枚举例题

&#x1f483;&#x1f3fc; 本人简介&#xff1a;男 &#x1f476;&#x1f3fc; 年龄&#xff1a;18 &#x1f91e; 作者&#xff1a;那就叫我亮亮叭 &#x1f4d5; 专栏&#xff1a;蓝桥杯试题 文章目录1. 题目描述2. 思路解释2.1 时间复杂度2.2 递归3. 代码展示最后&#x…

超级简单又功能强大还免费的电路仿真软件

设计电路的时候经常需要进行一些电路仿真。常用的仿真软件很多&#xff0c;由于大学里经常使用Multisim作为教学软件&#xff0c;所以基本上所有从事硬件开发的人都听过或者用过Multisim这个软件。这个软件最大的好处就是简单直观&#xff0c;可以在自己的PC上搭建电路并使用软…

gdb常用命令详解

gdb常用调试命令概览和说明 run命令 在默认情况下&#xff0c;gdbfilename只是attach到一个调试文件&#xff0c;并没有启动这个程序&#xff0c;我们需要输入run命令启动这个程序&#xff08;run命令被简写成r&#xff09;。如果程序已经启动&#xff0c;则再次输入 run 命令…

从面试官角度告诉你高级性能测试工程师面试必问的十大问题

目录 1、介绍下最近做过的项目&#xff0c;背景、预期指标、系统架构、场景设计及遇到的性能问题&#xff0c;定位分析及优化&#xff1b; 2、项目处于什么阶段适合性能测试介入&#xff0c;原因是什么&#xff1f; 3、性能测试场景设计要考虑哪些因素&#xff1f; 4、对于一…

SAP MM学习笔记4-在库类型都有哪些,在库类型有哪些控制点

SAP MM模块中的在库类型有3种&#xff1a; 1&#xff0c;利用可能在库 (非限制使用库存) 2&#xff0c;品质检查中在库 &#xff08;质检库存&#xff09; 3&#xff0c;保留在库&#xff08;已冻结库存&#xff09; 这3种在库标识该物料的状态&#xff0c;是否可用。 这3种…

bugku 安全加固1

js劫持 根据题目所给出的ip访问原本应该进入一个学院的二手交易网站 但是实际进入了一个博客 flag需要去除最后的斜杆 黑客首次webshell密码 利用所给的账户密码进行登录进入www目录并且进行备份 #我们对网站进行备份 cd /var/www && tar -czvf /tmp/html.tgz html …

Kubernetes之存储管理(上)

数据持久化的主要方式简介 pod是临时的&#xff0c;pod中的数据随着pod生命周期的结束也会被一起删除。 pod想实现数据持久化主要有以下几种方式&#xff1a; emptyDir&#xff1a;类似于docker run –v /xx&#xff0c;在物理机里随机产生一个目录(这个目录其实挂载的是物理…

墨天轮2022年度数据库获奖名单

2022年&#xff0c;国家相继从高位部署、省级试点布局、地市重点深入三个维度&#xff0c;颁布了多项中国数据库行业发展的利好政策。但是我们也能清晰地看到&#xff0c;中国数据库行业发展之路道阻且长&#xff0c;而道路上的“拦路虎”之一则是生态。中国数据库的发展需要多…

如何创建发布新品上市新闻稿

推出新产品对任何企业来说都是一个激动人心的时刻&#xff0c;但向潜在客户宣传并围绕您的新产品引起轰动也可能是一个挑战。最有效的方法之一就是通过发布新品上市新闻稿。精心制作的新闻稿可以帮助我们通过媒体报道、吸引并在目标受众中引起关注。下面&#xff0c;我们将讲述…