6.5 Batch Normalization

news2024/11/26 8:21:27

在训练神经网络时,往往需要标准化(normalization)输入数据,使得网络的训练更加快速和有效。

然而SGD(随机梯度下降)等学习算法会在训练中不断改变网络的参数,隐藏层的激活值的分布会因此发生变化,而这一种变化就称为内协变量偏移(Internal Covariate Shift,ICS)。

为了解决ICS问题,批标准化(Batch Normalization)固定激活函数的输入变量的均值和方差,使得网络的训练更快。

除了加速训练这一优势,Batch Normalization还具备其他功能:

①应用了Batch Normalization的神经网络在反向传播中有着非常好的梯度流

这样,神经网络对权重的初值和尺度依赖减少,能够使用更高的学习率,还降低了不收敛的风险。

②Batch Normalization还具有正则化的作用,Dropout也就不再需要了。

③Batch Normalization让深度神经网络使用饱和非线性函数成为可能。

一、Batch Normalization的实现方式

Batch Normalization在训练时,用当前训练批次的数据单独的估计每一激活值 x⁽ᴷ⁾  的均值和方差。为了方便,我们接下来只关注某一个激活值 x⁽ᴷ⁾ ,并将 k 省略掉,现定义当前批次为具有 m 个激活值的 β:

β = Xi (i=1,...,m)

首先,计算当前批次激活值的均值和方差:

然后用计算好的均值 \mu _{\beta } 和 方差 δ_β ² 标准化这一批次的激活值 x_{i},得到\hat{x}_{i} ,为了避免除0,\epsilon 被设置为一个非常小的数字,在PyTorch中,默认设置为 le - 5:

这样,我们就固定了当前批次 β 的分布,使得其服从均值为0、方差为1的高斯分布

但是标准化有可能会降低模型的表达能力,因为网络中的某些隐藏层很有可能就是需要输入数据是非标准化分布的,所以Batch Normalization对标准化的变量 x_{i} 加了一步防射变化

  y _{i} = y\hat{x}_{i} + β

添加的两个参数 \gamma 和 β 用于恢复网络的表示能力,它们和网络原本的权重一起训练。

在PyTorch中,β 初始化为0,而 \gamma 则从均匀分布随机采样。当时,标准化的激活值则完全恢复成原始值,这完全由训练中的网络自行决定。

训练完毕后,\gamma 和 β 作为中间状态被保存下来。

在PyTorch的实现中,Batch Normalization在训练时还会计算移动平均化的均值和方差

running_mean = (1 - momentum)· running_mean + momentum · \mu _{\beta }

running_var = (1-momentum)· running_var + momentum ·  δ_β ²

momentum默认为0.1,running_mean 和 running_var在训练完毕后保留,用于模型验证。

Batch Normalization在训练完毕后,保留了两个参数 β 和 \gamma,以及两个变量running_mean和running_var。

在模型做验证时,做如下变换:

二、Batch Normalization的使用方法

在PyTorch中,nn.BatchNorm1d 提供了Batch Normalization的实现,同样地,它也被当作神经网络中的层使用。

它有两个十分关键的参数

num_features确定特征的数量

affine决定Batch Normalization是否使用仿射映射

1、代码示例:

(1)实例化一个BatchNorm1d对象,它接收特征数量 num_features = 5 的数据,所以模型的两个中间变量 running_mean 和 running_var 就会被初始化为5维的向量,用于统计移动平均化的均值和方差。

(2)输出这两个变量的数据,可以很直观地看到它们的初始化方式

(3)从标准高斯分布采样了一些数据然后提供给Batch Normalization层。

(4)输出变化后的 running_mean 和 running_var,可以发现它们的数值发生了一些变化但是基本维持了标准高斯分布的均值方差数值。

(5)验证了如果我们将模型设置为eval模式,这两个变量不会发生任何变化

输出:

上面代码(1)设置了affine = False,也就是不对标准化后的数据采用仿射变化,关于仿射变换的两个参数  β 和 y 在 BatchNorm1d 中称为 weight 和 bias。

代码(2)输出了这两个变量,显然因为我们关闭了仿射变化,所以这两个变量被设置为None

2、代码示例

下面设置 affine = True,然后输出 m_affine.weight、m_affine.bias,可以看到,y 从均匀分布      U(0,1)随机采样,而 β 被初始化为0。

输出:

应当注意,m_affine.weight 和 m_affine.bias的类型均为Parameter

也就是说它们和线性模型的权重是一种类型,参与模型的训练,而running_mean 和 running_var 的类型为Tensor,这样的变量在PyTorch中称为buffer。

buffer不影响模型的训练,仅作为中间变量更新和保存。

四、代码

import torch
from torch import nn

m = nn.BatchNorm1d(num_features=5,affine=False)
print("BEFORE:")
print("running_mean:",m.running_mean)
print("running_var:",m.running_var)

for _ in range(100):
    input = torch.randn(20,5)
    output= m(input)

print("AFTER:")
print("running_mean:",m.running_mean)
print("running_var:",m.running_var)

m.eval()
for _ in range(100):
    input = torch.randn(20,5)
    output = m(input)

print("EVAL:")
print("running_mean:",m.running_mean)
print("running_var:",m.running_var)

#########################################

print("no affine,gamma:",m.weight)
print("no affine,beta:",m.bias)

m_affine = nn.BatchNorm1d(num_features=5,affine=True)
print("")
print("with affine,gamma:",m_affine.weight,type(m_affine.weight))
print("with affine,beta:",m_affine.bias,type(m_affine.bias))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1561261.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL故障排查与生产环境优化

一、MySQL单实例常见故障 1.逻辑架构图 MySQL逻辑架构图客户端和连接服务核心服务功能存储引擎层数据存储层 2.故障一 故障现象 ERROR 2002 (HY000): Cant connect to local MySQL server through socket/data/mysql/mysql.sock(2) 问题分析 数据库未启动或者数据库端口…

淘宝优惠券去哪里领取隐藏内部券?

淘宝优惠券去哪里领取隐藏内部券? 1、手机安装「草柴」APP,打开手机淘宝挑选要购买的商品,并点击分享复制链接; 2、将复制的商品链接,粘贴到草柴APP,并点击立即查询该商品的优惠券和约返利; 3、…

java题目15:从键盘输入n个数,求这n个数中的最大数与最小数并输出(MaxAndMin15)

每日小语 你是否有资格摆脱身上的枷锁呢?有许多人一旦获得解放,他的最后一点价值也就会跟着丧失。 ——尼采 自己敲写 它不按我想的来。。。 //从键盘输入n个数,求这n个数中的最大数与最小数并输出 import java.util.Scanner; public clas…

软件测评师教程之软件测试基础

🔥 交流讨论:欢迎加入我们一起学习! 🔥 资源分享:耗时200小时精选的「软件测试」资料包 🔥 教程推荐:火遍全网的《软件测试》教程 📢欢迎点赞 👍 收藏 ⭐留言 &#x1…

Mysql数据库:故障分析与配置优化

目录 前言 一、Mysql逻辑架构图 二、Mysql单实例常见故障 1、无法通过套接字连接到本地MySQL服务器 2、用户rootlocalhost访问被拒绝 3、远程连接数据库时连接很慢 4、无法打开以MYI结尾的索引文件 5、超出最大连接错误数量限制 6、连接过多 7、配置文件/etc/my.cnf权…

全栈开发与测试定向培养班

Python全栈开发与测试 什么是软件测试? 对于测试行业来说,行业普遍会把职位分为测试工程师和测试开发工程师两个岗位。软件测试工程师就是常规意义上了解到的功能测试岗位,以功能测试为主,会有少量的自动化测试。测试能力要求:熟…

键盘输入与屏幕输出——单个字符的输入和输出

目录 字符常量 字符型变量 单个字符的输入输出 两种输入输出方法的比较 字符常量 字符常量是用单引号括起来的一个符号 *’3’表示一个数字字符,而3则表示一个整数数值 转义字符(Escape Character) *一些特殊字符(无法从键盘…

Dimitra:基于区块链、AI 等前沿技术重塑传统农业

根据 2023 年联合国粮食及农业组织(FAO)、国际农业发展基金(IFAD)等组织联合发布的《世界粮食安全和营养状况》报告显示,目前全球约有 7.35 亿饥饿人口,远高于 2019 年的 6.13 亿,这意味着农业仍…

【Linux C | 多线程编程】线程的连接、分离,资源销毁情况

😁博客主页😁:🚀https://blog.csdn.net/wkd_007🚀 🤑博客内容🤑:🍭嵌入式开发、Linux、C语言、C、数据结构、音视频🍭 ⏰发布时间⏰:2024-04-01 1…

讲解pwngdb的用法,以csapp的bomb lab phase_1为例

参考资料 Guide to Faster, Less Frustrating Debugging 什么情况下会使用gbd 需要逆向ELF文件时(掌握gdb的使用,是二进制安全的基本功)开发程序时,程序执行结果不符合预期 动态调试ELF文件可以使用另外一种方法:IDA的远程linux动态调试。个…

探索 Redis 数据库:一款高性能的内存键值存储系统

目录 引言 一、非关系型数据库 (一)什么是非关系型数据库 (二)非关系型数据库的主要特征 (三)关系数据库与非关系型数据库的区别 二、Redis 简介 (一)基本信息 (…

栈————顺序栈和链式栈

目录 栈 顺序栈 1、初始化顺序栈 2、判栈空 3、进栈 4、出栈 5、读栈顶元素 6、遍历 链式栈 1、初始化链式栈 2、断链式栈是否为空判 3、入栈(插入) ​编辑​编辑 4、出栈(删除) 5、读取栈顶元素 6、输出链式栈中各个节点的值(遍历) 栈 …

LeetCode-240. 搜索二维矩阵 II【数组 二分查找 分治 矩阵】

LeetCode-240. 搜索二维矩阵 II【数组 二分查找 分治 矩阵】 题目描述:解题思路一:从左下角或者右上角元素出发,来寻找target。解题思路二:右上角元素,代码解题思路三:暴力也能过解题思路四:二分…

【小呆的力学笔记】弹塑性力学的初步认知六:后继屈服条件

文章目录 4. 后继屈服条件4.1 后继屈服条件4.2 强化模型4.2.1 等向强化模型4.2.2 随动强化模型4.2.3 两种强化模型的讨论 4. 后继屈服条件 4.1 后继屈服条件 上一章节的屈服条件是在当材料未经受任何塑性变形时且在载荷作用下材料第一次进入屈服应该满足的条件(也…

Vscode + PlatformIO + Arduino 搭建EPS32开发环境

Vscode PlatformIO Arduino 搭建EPS32开发环境 文章目录 Vscode PlatformIO Arduino 搭建EPS32开发环境1. Vscode插件安装2. 使用PlatformIO新建工程3.工程文件的基本结构4.一个基本的测试用例Reference 1. Vscode插件安装 如何下载vscode这里不再赘述,完成基本…

LeetCode-热题100:160. 相交链表

给你两个单链表的头节点 headA 和 headB ,请你找出并返回两个单链表相交的起始节点。如果两个链表不存在相交节点,返回 null 。 图示两个链表在节点 c1 开始相交: 题目数据 保证 整个链式结构中不存在环。 注意,函数返回结果后&…

异常,Lambda表达式

文章目录 异常介绍存在形式程序中异常发生后的第一反应体系JVM的默认处理方案处理方式声明 throws概述格式抛出 throw格式注意意义 throws和throw的区别 捕获 try,catch介绍格式执行方式多异常捕获处理意义 如何选择用哪个 Throwable类介绍常用方法 自定义异常概述实现步骤范例…

2_3.Linux系统中的日志管理

# 1.journald # 服务名称:systemd-journald.service journalctl 默认日志存放路径: /run/log (1) journalctl命令的用法 journalctl -n 3 ##日志的最新3条--since "2020-05-01 11:00:00" ##显示11:00后的日…

Mysql的高级语句3

目录 一、子查询 注意:子语句可以与主语句所查询的表相同,但是也可以是不同表。 1、select in 1.1 相同表查询 1.2 多表查询 2、not in 取反,就是将子查询结果,进行取反处理 3、insert into in 4、update…

LeetCode226:反转二叉树

题目描述 给你一棵二叉树的根节点 root ,翻转这棵二叉树,并返回其根节点。 解题思想 使用前序遍历和后序遍历比较方便 代码 class Solution { public:TreeNode* invertTree(TreeNode* root) {if (root nullptr) return root;swap(root->left, root…