神经网络系列---权重初始化方法

news2024/9/22 19:27:41

文章目录

    • 权重初始化方法
      • Xavier初始化(Xavier initialization)
      • Kaiming初始化,也称为He初始化
      • LeCun 初始化
      • 正态分布与均匀分布
      • Orthogonal Initialization
      • Sparse Initialization
      • n_in和n_out
      • 代码实现


权重初始化方法

Xavier初始化(Xavier initialization)

是一种用于初始化神经网络权重的方法,也称为Glorot初始化。更有效地传播信号并减少梯度消失或梯度爆炸的问题。适用于激活函数为tanh或sigmoid的情况。

Xavier初始化的计算方法如下:

  1. Glorot(或 Xavier)初始化
    • 适用于激活函数如sigmoid和tanh。
    • 初始化公式: σ = 2 n in + n out \sigma = \sqrt{\frac{2}{n_{\text{in}} + n_{\text{out}}}} σ=nin+nout2
      其中, n in n_{\text{in}} nin 是输入单元数, n out n_{\text{out}} nout 是输出单元数。

对于单个神经元的权重w,从均匀分布正态分布中随机采样,具体取决于所选择的激活函数:

在这里插入图片描述

  1. 如果使用tanh激活函数,从均匀分布采样:
    • 采样范围:[-sqrt(6 / (n_in + n_out)), sqrt(6 / (n_in + n_out))]
    • 其中n_in是上一层的输入节点数量,n_out是当前层的输出节点数量。

在这里插入图片描述

  1. 如果使用sigmoid激活函数,从正态分布采样:
    • 均值:0
    • 方差:sqrt(2 / (n_in + n_out))
    • 其中n_in是上一层的输入节点数量,n_out是当前层的输出节点数量。

Kaiming初始化,也称为He初始化

  1. He 初始化
    • 适用于ReLU及其变种(如LeakyReLU)激活函数。
    • 初始化公式: σ = 2 n in \sigma = \sqrt{\frac{2}{n_{\text{in}}}} σ=nin2

这种初始化方法主要用于修正线性单元(Rectified Linear Units,ReLU)激活函数的神经网络。

与Xavier初始化适用于tanh和sigmoid等S型激活函数不同,Kaiming初始化专门针对ReLU激活函数的特性进行优化。ReLU是一个常用的非线性激活函数,它在输入大于零时保持不变,在输入小于等于零时输出为零。

Kaiming初始化的计算方法如下:

对于单个神经元的权重w,从均匀分布或正态分布中随机采样,具体取决于所选择的激活函数:

  1. 如果使用ReLU激活函数,从正态分布采样:

    • 均值:0
    • 方差:sqrt(2 / n_in)
    • 其中n_in是上一层的输入节点数量。
  2. 对于带有ReLU激活的卷积层,可以使用相同的初始化方法,只是需要考虑卷积层的输入通道数量(即n_in)。

LeCun 初始化

  • 适用于Sigmoid激活函数。
  • 初始化公式: σ = 1 n in \sigma = \sqrt{\frac{1}{n_{\text{in}}}} σ=nin1

正态分布与均匀分布

  • 使用较小的标准差(如0.01)从正态分布中采样权重。
  • 使用较小的范围(如-0.01到0.01)从均匀分布中采样权重。

Orthogonal Initialization

  • 使用正交矩阵初始化权重。这种初始化方法对于某些任务和模型架构可能很有益。

Sparse Initialization

  • 将大部分权重初始化为0,只初始化一小部分非零的权重。

n_in和n_out

n_inn_out分别表示神经网络层的输入节点数量和输出节点数量。这些节点也称为神经元,它们是网络的基本组成部分。

  • n_in:代表上一层(前一层)的节点数量,也就是当前层的输入数量。在神经网络中,每个神经元都会接收来自上一层所有节点的输入,这些输入被加权和后传递给当前神经元的激活函数。因此,n_in指的是上一层与当前层之间的连接数量。

  • n_out:代表当前层的节点数量,也就是当前层的输出数量。每个神经元会将经过激活函数处理后的结果传递给下一层所有节点,形成下一层的输入。因此,n_out指的是当前层与下一层之间的连接数量。

代码实现

#include <iostream>
#include <Eigen/Dense>
#include <random>
#include <cmath>

Eigen::MatrixXd glorotInitialize(int rows, int cols);
Eigen::MatrixXd heInitialize(int rows, int cols);
Eigen::MatrixXd lecunInitialize(int rows, int cols);
Eigen::MatrixXd normalDistributionInitialize(int rows, int cols, double std_dev=0.01);
Eigen::MatrixXd uniformDistributionInitialize(int rows, int cols, double limit=0.01);
Eigen::MatrixXd orthogonalInitialize(int rows, int cols);
// Sparse Initialization需要额外参数来确定稀疏度,这里我们使用一个简化版本,指定一个非零的权重数。
Eigen::MatrixXd sparseInitialize(int rows, int cols, int nonZeroCount);



//1. **Glorot (Xavier) Initialization**:

Eigen::MatrixXd glorotInitialize(int rows, int cols) {
    std::random_device rd;
    std::mt19937 gen(rd());
    double limit = sqrt(6.0 / (rows + cols));
    std::uniform_real_distribution<> dis(-limit, limit);

    Eigen::MatrixXd matrix(rows, cols);
    for(int i = 0; i < rows; i++) {
        for(int j = 0; j < cols; j++) {
            matrix(i, j) = dis(gen);
        }
    }
    return matrix;
}


//**He Initialization**:

Eigen::MatrixXd heInitialize(int rows, int cols) {
    std::random_device rd;
    std::mt19937 gen(rd());
    double std_dev = sqrt(2.0 / rows);
    std::normal_distribution<> dis(0, std_dev);

    Eigen::MatrixXd matrix(rows, cols);
    for(int i = 0; i < rows; i++) {
        for(int j = 0; j < cols; j++) {
            matrix(i, j) = dis(gen);
        }
    }
    return matrix;
}


//3. **LeCun Initialization**:

Eigen::MatrixXd lecunInitialize(int rows, int cols) {
    std::random_device rd;
    std::mt19937 gen(rd());
    double std_dev = sqrt(1.0 / rows);
    std::normal_distribution<> dis(0, std_dev);

    Eigen::MatrixXd matrix(rows, cols);
    for(int i = 0; i < rows; i++) {
        for(int j = 0; j < cols; j++) {
            matrix(i, j) = dis(gen);
        }
    }
    return matrix;
}


//4. **Normal Distribution Initialization**:

Eigen::MatrixXd normalDistributionInitialize(int rows, int cols, double std_dev) {
    std::random_device rd;
    std::mt19937 gen(rd());
    std::normal_distribution<> dis(0, std_dev);

    Eigen::MatrixXd matrix(rows, cols);
    for(int i = 0; i < rows; i++) {
        for(int j = 0; j < cols; j++) {
            matrix(i, j) = dis(gen);
        }
    }
    return matrix;
}

//5. **Uniform Distribution Initialization**:

Eigen::MatrixXd uniformDistributionInitialize(int rows, int cols, double limit) {
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<> dis(-limit, limit);

    Eigen::MatrixXd matrix(rows, cols);
    for(int i = 0; i < rows; i++) {
        for(int j = 0; j < cols; j++) {
            matrix(i, j) = dis(gen);
        }
    }
    return matrix;
}


//6. **Orthogonal Initialization**:
Eigen::MatrixXd orthogonalInitialize(int rows, int cols) {
    // 创建一个随机矩阵
    std::random_device rd;
    std::mt19937 gen(rd());
    std::normal_distribution<> dis(0, 1);

    Eigen::MatrixXd randomMatrix(rows, cols);
    for(int i = 0; i < rows; i++) {
        for(int j = 0; j < cols; j++) {
            randomMatrix(i, j) = dis(gen);
        }
    }

    // 使用QR分解获得正交矩阵
    Eigen::HouseholderQR<Eigen::MatrixXd> qr(randomMatrix);
    Eigen::MatrixXd orthogonalMatrix = qr.householderQ();

    // 如果您需要一个具有特定维度的正交矩阵(例如rows != cols),您可以选择一个子矩阵
    return orthogonalMatrix.block(0, 0, rows, cols);
}


//7. **Sparse Initialization**:

Eigen::MatrixXd sparseInitialize(int rows, int cols, int nonZeroCount) {
    Eigen::MatrixXd matrix = Eigen::MatrixXd::Zero(rows, cols);
    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<> dis(-1, 1);

    for(int i = 0; i < nonZeroCount; i++) {
        int r = rand() % rows;
        int c = rand() % cols;
        matrix(r, c) = dis(gen);
    }

    return matrix;
}
int main() {
    int rows = 5;
    int cols = 5;

    // Glorot Initialization
    Eigen::MatrixXd weights_glorot = glorotInitialize(rows, cols);
    std::cout << "Glorot Initialized Weights:" << std::endl << weights_glorot << std::endl << std::endl;

    // He Initialization
    Eigen::MatrixXd weights_he = heInitialize(rows, cols);
    std::cout << "He Initialized Weights:" << std::endl << weights_he << std::endl << std::endl;

    // LeCun Initialization
    Eigen::MatrixXd weights_lecun = lecunInitialize(rows, cols);
    std::cout << "LeCun Initialized Weights:" << std::endl << weights_lecun << std::endl << std::endl;

    // Normal Distribution Initialization
    Eigen::MatrixXd weights_normal = normalDistributionInitialize(rows, cols);
    std::cout << "Normal Distribution Initialized Weights:" << std::endl << weights_normal << std::endl << std::endl;

    // Uniform Distribution Initialization
    Eigen::MatrixXd weights_uniform = uniformDistributionInitialize(rows, cols);
    std::cout << "Uniform Distribution Initialized Weights:" << std::endl << weights_uniform << std::endl << std::endl;

    // Sparse Initialization
    int nonZeroCount = 10; // As an example, set 10 weights to non-zero values
    Eigen::MatrixXd weights_sparse = sparseInitialize(rows, cols, nonZeroCount);
    std::cout << "Sparse Initialized Weights with " << nonZeroCount << " non-zero values:" << std::endl << weights_sparse << std::endl;

    return 0;
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1468190.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java面试题之SpringMVC篇

Spring MVC的工作原理 Spring MVC的工作原理如下&#xff1a; DispatcherServlet 接收用户的请求找到用于处理request的 handler 和 Interceptors&#xff0c;构造成 HandlerExecutionChain 执行链找到 handler 相对应的 HandlerAdapter执行所有注册拦截器的preHandler方法调…

Java知识点一

hello&#xff0c;大家好&#xff01;我们今天开启Java语言的学习之路&#xff0c;与C语言的学习内容有些许异同&#xff0c;今天我们来简单了解一下Java的基础知识。 一、数据类型 分两种&#xff1a;基本数据类型 引用数据类型 &#xff08;1&#xff09;整型 八种基本数…

计算机网络-网络互联与互联网(一)

1.常用网络互联设备&#xff1a; 1层物理层&#xff1a;中继器、集线器2层链路层&#xff1a;网桥、交换机3层网络层&#xff1a;路由器、三层交换机4层以上高层&#xff1a;网关 2.网络互联设备&#xff1a; 中继器Repeater、集线器Hub&#xff08;又叫多端口中继器&#xf…

一分钟 由浅入深 学会Navigation

目录 1.官网正式概念 1.1 初认知 2.导入依赖 2.1 使用navigation 2.2 safe Args插件-> 传递数据时用 3.使用Navigation 3.1 搭建初始框架 3.2 确定action箭头的属性 3.3 为Activity添加NavHostFragment控件 3.4 NavController 管理应用导航的对象 3.5 数据传递(单…

leetcode单调栈

739. 每日温度 请根据每日 气温 列表&#xff0c;重新生成一个列表。对应位置的输出为&#xff1a;要想观测到更高的气温&#xff0c;至少需要等待的天数。如果气温在这之后都不会升高&#xff0c;请在该位置用 0 来代替。 例如&#xff0c;给定一个列表 temperatures [73, …

基于SpringBoot的停车场管理系统

基于SpringBootVue的停车场管理系统的设计与实现~ 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringBootMyBatis工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 前台首页 停车位 个人中心 管理员界面 摘要 摘要&#xff1a;随着城市化进程的…

基于django的购物商城系统

摘要 本文介绍了基于Django框架开发的购物商城系统。随着电子商务的兴起&#xff0c;购物商城系统成为了许多企业和个人创业者的首选。Django作为一个高效、稳定且易于扩展的Python web框架&#xff0c;为开发者提供了便捷的开发环境和丰富的功能模块&#xff0c;使得开发购物商…

Java零基础 - 条件运算符

哈喽&#xff0c;各位小伙伴们&#xff0c;你们好呀&#xff0c;我是喵手。 今天我要给大家分享一些自己日常学习到的一些知识点&#xff0c;并以文字的形式跟大家一起交流&#xff0c;互相学习&#xff0c;一个人虽可以走的更快&#xff0c;但一群人可以走的更远。 我是一名后…

Vue(学习笔记)

什么是Vue Vue是一套构建用户界面的渐进式框架 构建用户界面&#xff1a; 基于数据渲染出用户可以看到的界面 渐进式&#xff1a; 所谓渐进式就是循序渐进&#xff0c;不一定非得把Vue中的所有API都学完才能开发Vue&#xff0c;可以学一点开发一点 创建Vue实例 比如就上面…

k8s学习笔记-基础概念

&#xff08;作者&#xff1a;陈玓玏&#xff09; deployment特别的地方在于replica和selector&#xff0c;docker根据镜像起容器&#xff0c;pod控制容器&#xff0c;job、cronjob、deployment控制pod&#xff0c;job做离线任务&#xff0c;pod大多一次性的&#xff0c;cronj…

汽车常识网:电脑主机如何算功率的计算方法?

今天汽车知识网就给大家讲解一下如何计算一台主机的功率。 它还会解释如何计算计算机主机所需的功率&#xff1f; &#xff1f; &#xff08;如何计算电脑主机所需的功率&#xff09;进行说明。 如果它恰好解决了您现在面临的问题&#xff0c;请不要忘记关注本站。 让我们现在就…

vue3 vite 经纬度逆地址解析

在web端测试经纬度逆地址解析有2中方式&#xff0c;先准备好两个应用key 第一种&#xff0c;使用“浏览器端”应用类型 const address ref() const latitude ref() // 经度 const longitude ref() // 纬度 const ak 你的key // 浏览器端 function getAddressWeb() {// 创建…

【读博杂记】:近期日常240223

近期日常 最近莫名其妙&#xff0c;小导悄悄卷起来&#xff0c;说要早上八点半开始打卡&#xff0c;我感觉这是要针对我们在学校住的&#xff0c;想让我们自己妥协来这边租房子住&#xff0c;但我感觉这是在逼我养成规律作息啊&#xff01;现在基本上就是6~7点撤退&#xff0c;…

【Spring】 AOP面向切面编程

文章目录 AOP是什么&#xff1f;一、AOP术语名词介绍二、Spring AOP框架介绍和关系梳理三、Spring AOP基于注解方式实现和细节3.1 Spring AOP底层技术组成3.2 初步实现3.3 获取通知细节信息3.4 切点表达式语法3.5 重用&#xff08;提取&#xff09;切点表达式3.6 环绕通知3.7 切…

R语言入门笔记2.6

描述统计 分类数据与顺序数据的图表展示 为了下面代码便于看出颜色参数所对应的值&#xff0c;在这里先集中介绍&#xff0c; col1是黑色&#xff0c;2是粉红&#xff0c;3是绿色&#xff0c;4是天蓝&#xff0c;5是浅蓝&#xff0c;6是紫红&#xff0c;7是黄色&#xff0c;…

前沿科技速递——YOLOv9

随着YOLO系列的不断迭代更新&#xff0c;前几天&#xff0c;YOLO系列也迎来了第九个大型号的更新&#xff01;YOLOv9正式推出了&#xff01;附上原论文链接。 arxiv.org/pdf/2402.13616.pdf 同样是使用MS COCO数据集进行对比比较&#xff0c;通过折线图可看出AP曲线在全方面都…

一、系统架构师考试介绍

一、系统架构设计师介绍 系统架构设计师在软考体系中&#xff0c;属于高级资格。(不需要先考中级可以直接报考高级&#xff0c;我之前不知道还考了软件设计师T.T不如当初直接考系统架构师) 考试时间: 每年11月份的第二个周六 报名方式: 网上报名 报名网址 http://wwwruankao.…

C++常见问题

C常见问题 引用模板STLvector原理移动语义与右值引用New delete与malloc freeinlineconststaticexplicit 的作用lambda 表达式友元public、protected、private的区别封装继承多态虚函数重载、重写、隐藏的区别智能指针C 11新特性深拷贝与浅拷贝虚拟内存内存对齐及内存泄漏C内存…

解决ubuntu系统cannot find -lc++abi: No such file or directory

随着CentOS的没落&#xff0c;使用ubuntu的越来越多&#xff0c;而且国外貌似也比较流行使用ubuntu&#xff0c;像LLVM/Clang就有专门针对ubuntu编译二进制发布文件&#xff1a; ubuntu本身也可以直接通过apt install命令来安装编译好的clang编译器。不过目前22.04版本下最高…

【尚硅谷】MybatisPlus 学习笔记(下)

目录 六、插件 6.1、分页插件 6.1.1、添加配置类 6.1.2、测试 6.2、xml自定义分页 6.2.1、UserMapper中定义接口方法 6.2.2、UserMapper.xml中编写SQL 6.2.3、测试 6.3、乐观锁 6.3.1、场景 6.3.2、乐观锁与悲观锁 6.3.3、模拟修改冲突 数据库中增加商品表 添加数…