PyTorch深度学习实战(8)——批归一化

news2024/9/22 21:21:37

PyTorch深度学习实战(8)——批归一化

    • 0. 前言
    • 1. 批归一化原理
    • 2. 批归一化优势
    • 3. 批归一化对模型训练的影响
      • 3.1 未使用批归一化,且输入值较小
      • 3.2 使用批归一化,且输入值较小
      • 3.3 使用批归一化,且输入值较大
    • 小结
    • 系列链接

0. 前言

批归一化( Batch Normalization )是一种常用的神经网络优化技术,用于在神经网络的训练过程中对每批输入进行归一化操作。它的主要目的是缓解梯度消失或梯度爆炸的问题,并且加速模型的收敛。在本节中,首先介绍批归一化的基本原理,然后通过实验观察其在网络训练过程中的重要作用。

1. 批归一化原理

我们已经了解到,如果不缩放输入数据,则权重优化的速度很慢。这是由于当面临以下情况时,隐藏层的值可能会很高:

  • 输入数据值高
  • 权重值高
  • 权重和输入的乘积很高

任何一种情况都可能导致隐藏层具有较大输出值。隐藏层可以视为输出层的输入层。因此,当隐藏层值也很大时,同样会导致网络优化缓慢。接下来,我们考虑当输入值非常小,Sigmoid 输出随权重的变化情况:

输入权重Sigmoid 输出
0.010.000010.500
0.010.00010.500
0.010.0010.500
0.010.010.500
0.010.10.500
0.010.20.500
0.010.30.501
0.010.40.501
0.010.50.501
0.010.60.501
0.010.70.502
0.010.80.502
0.010.90.502
0.0110.502

当输入值非常小时,Sigmoid 输出的变化幅度较小,从而会对权重值产生较大变化。此外,我们已经看到较大的输入值会对训练准确率有负面影响,这表明输入值既不能过小,也不能过大值。
除了输入值外,当前网络层可以视为下一网络层的输入层,因此同样可能会出现过大或过小的情况,从而导致网络优化缓慢。我们已经了解到,当输入值很高时,我们将执行缩放以减小输入值。批归一化是用于在神经网络中对每个批次的输入数据进行归一化处理,它可以加速模型的训练,提高模型的稳定性和泛化能力。批归一化的算法流程如下:

  • 对于每个批次的输入数据,计算其均值和方差。可以通过求取批次内样本的均值和方差的无偏估计来计算
  • 对于每个特征,将其均值归一化为 0,方差归一化为 1。这可以通过减去均值并除以标准差来完成
  • 引入可学习的参数 γ \gamma γ β \beta β,用于缩放和平移归一化后的数据。这两个参数允许模型自动学习适当的尺度和偏移,以便更好地拟合数据

批归一化使用以下公式缩放每个批次的输入数据值:
B a t c h   m e a n   μ B = 1 m ∑ i = 1 m x i B a t c h   V a r i a n c e   σ 2 B = 1 m ∑ i = 1 m ( x i − μ B ) 2 N o r m a l i z e d   i n p u t   x ‾ i = ( x i − μ B ) σ B 2 + ϵ B a t c h   n o r m a l i z e d   i n p u t = γ x ‾ i + β Batch\ mean\ \mu_B=\frac 1 m\sum_{i=1}^mx_i \\ Batch\ Variance\ \sigma_2^B=\frac 1m\sum_{i=1}^m(x_i-\mu_B)^2 \\ Normalized\ input\ \overline x_i=\frac {(x_i-\mu_B)}{\sqrt {\sigma_B^2+\epsilon}}\\ Batch\ normalized\ input=\gamma \overline x_i+\beta Batch mean μB=m1i=1mxiBatch Variance σ2B=m1i=1m(xiμB)2Normalized input xi=σB2+ϵ (xiμB)Batch normalized input=γxi+β
通过每个输入数据减去批数据输入的平均值,然后将其除以批数据方差,可以将一个节点处所有批数据点归一化到一个固定范围,通过引入 γ γ γ β β β 参数,可以让网络识别最佳归一化参数。
批归一化在训练和预测阶段的计算方式是不同的,在训练阶段,批归一化使用当前批次的均值和方差进行归一化;而在预测阶段,使用训练过程中计算得到的整体均值和方差来归一化输入数据。

2. 批归一化优势

批归一化的具有以下优势:

  • 加速模型收敛:通过减少内部协变量偏移( Internal Covariate Shift),即每层输入分布的变化,批归一化有助于加快模型的收敛速度。这意味着我们可以使用更大的学习率,并在更短的时间内达到更好的性能。
  • 提高模型稳定性:批归一化可以减少网络对输入数据中小的批量统计变化的敏感性,从而使得模型更加稳定。这有助于减轻梯度消失或梯度爆炸等训练过程中的问题。
  • 增强模型泛化能力:批归一化具有正则化的效果,类似于 Dropout 等常用的正则化技术。它可以稍微减少对其他正则化方法的依赖,并提高模型的泛化能力。

3. 批归一化对模型训练的影响

为了了解批归一化对模型训练的影响,观察使用以下设定时,训练和验证数据集的损失和准确率值,以及隐藏层值的分布:

  • 未使用批归一化,且输入值较小
  • 使用批归一化,且输入值较小

3.1 未使用批归一化,且输入值较小

我们通常将输入数据缩放到 01,在本节中,我们将更进一步将其缩放到 00.0001 之间,以便了解缩放数据的影响。我们已经知道,即使权重值变化很大,小的输入值也无法改变 Sigmoid 值。
为了缩放输入数据集,我们通常在 FMNISTDataset 类中执行缩放操作,通过将输入像素值除以 (255*10000) 来缩小输入像素值的范围,将输入值的范围缩放至 00.0001

class FMNISTDataset(Dataset):
    def __init__(self, x, y):
        x = x.float()/(255.*10000)
        x = x.view(-1,28*28)
        self.x, self.y = x, y 
    def __getitem__(self, ix):
        x, y = self.x[ix], self.y[ix] 
        return x.to(device), y.to(device)
    def __len__(self): 
        return len(self.x)

重新定义 get_model() 函数,以便获取模型的预测及隐藏层的值:

def get_model():
    class neuralnet(nn.Module):
        def __init__(self):
            super().__init__()
            self.input_to_hidden_layer = nn.Linear(28*28,1000)
            self.hidden_layer_activation = nn.ReLU()
            self.hidden_to_output_layer = nn.Linear(1000,10)
        def forward(self, x):
            x = self.input_to_hidden_layer(x)
            x1 = self.hidden_layer_activation(x)
            x2= self.hidden_to_output_layer(x1)
            return x2, x1
    model = neuralnet().to(device)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=1e-3)
    return model, loss_fn, optimizer

在以上代码中,定义了神经网络类,它返回输出层值 (x2) 和隐藏层的激活值 (x1)。
由于修改后的 get_model() 会返回两个输出,我们同样需要修改 train_batch()val_loss() 函数,在这两个函数中我们只需要获取输出层的值,而无需隐藏层值。由于输出层值位于模型返回的第 0 个索引中,我们需要修改函数使其仅获取第 0 个预测索引:

def train_batch(x, y, model, optimizer, loss_fn):
    prediction = model(x)[0]
    batch_loss = loss_fn(prediction, y)
    batch_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return batch_loss.item()

def accuracy(x, y, model):
    with torch.no_grad():
        prediction = model(x)[0]
    max_values, argmaxes = prediction.max(-1)
    is_correct = argmaxes == y
    return is_correct.cpu().numpy().tolist()

训练模型后,得到训练和验证数据集中的准确率和损失值在训练过程中的变化:

准确率和损失值变化
可以看到,即使在 100epoch 之后,模型也没有获得优异性能(只有大约 85% 的验证准确率),而在以上部分中,模型在 10 个 epoch 内的验证数据集上可以获得大约 90% 的准确率。

通过探索隐藏值的分布以及参数分布来了解当输入值范围较小时,模型性能不佳的原因:

隐藏值以及参数分布
第一个分布表示隐藏层中值的分布(可以看到这些值的范围非常小),此外,由于输入和隐藏层值的范围都非常小,权重(包括将输入连接到隐藏层的权重和将隐藏层连接到输出层的权重)必须有大幅度的变化。
我们已经了解了当输入值的范围非常小时,网络就不能很好地训练,接下来我们将学习批归一化如何帮助增大隐藏层内的值范围。

3.2 使用批归一化,且输入值较小

在本节中,我们在上一小节的代码基础上增加批归一化,修改后的 get_model() 函数如下:

from torch.optim import SGD, Adam
def get_model():
    class neuralnet(nn.Module):
        def __init__(self):
            super().__init__()
            self.input_to_hidden_layer = nn.Linear(784,1000)
            self.batch_norm = nn.BatchNorm1d(1000)
            self.hidden_layer_activation = nn.ReLU()
            self.hidden_to_output_layer = nn.Linear(1000,10)
        def forward(self, x):
            x = self.input_to_hidden_layer(x)
            x0 = self.batch_norm(x)
            x1 = self.hidden_layer_activation(x0)
            x2= self.hidden_to_output_layer(x1)
            return x2, x1
    model = neuralnet().to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=1e-3)
    return model, loss_fn, optimizer

在以上代码中,我们声明了一个执行批归一化 (nn.BatchNorm1d) 的变量 (batch_norm),由于隐藏层中每个图像的输出维度为 1,000,因此执行 nn.BatchNorm1d(1000)。此外,在前向传播方法 forward() 中,在 ReLU 激活之前通过批归一化传递隐藏层值的输出。
训练和验证数据集的准确率和损失随时间的变化如下:

准确率和损失变化

可以看到模型训练过程与输入值范围不是很小时的训练过程非常相似。观察隐藏层值的分布和权重分布:

请添加图片描述

可以看到,进行批归一化时,隐藏层值的分布更大,而连接隐藏层和输出层的权重分布更小,模型的训练结果与较优。批归一化在训练深度神经网络时极为有效,它可以帮助模型避免因梯度变得太小而无法更新权重的问题。

3.3 使用批归一化,且输入值较大

为了进一步了解批归一化,在上一小节的代码基础上修改输入数据范围,修改后如下:

class FMNISTDataset(Dataset):
    def __init__(self, x, y):
        x = x.float()/(255.)
        x = x.view(-1,28*28)
        self.x, self.y = x, y 
    def __getitem__(self, ix):
        x, y = self.x[ix], self.y[ix]        
        return x.to(device), y.to(device)
    def __len__(self): 
        return len(self.x)

训练模型后,得到训练和验证数据集中的准确率和损失值在训练过程中的变化:

准确率和损失值变化
观察隐藏层值的分布和权重分布:

隐藏层值的分布和权重分布

小结

批归一化是一种通过标准化神经网络层的输入数据,加速模型训练并提高泛化能力的技术。它在深度学习中广泛应用,是构建高效、稳定的神经网络模型的重要方法。本节介绍了批归一化的基本概念及其优点,并通过实战了解了批归一化对模型训练的影响。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/823871.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis—环境搭建

Redis—环境搭建 🔎Centos 安装 Redis5创建符号链接修改配置文件启动 Redis停止 Redis 🔎Centos 安装 Redis5 Centos8 安装 Redis5 yum install -y redisCentos7 安装 Redis5 Centos7 中 yum 源提供的 Redis 版本是 Redis3(有点老), 因此先安装 scl 源 …

算法综合篇专题二:滑动窗口

“在混沌想法中&#xff0c;最不可理喻念头。” 1、长度最小的子数组 (1) 题目解析 (2) 算法原理 class Solution { public:int minSubArrayLen(int target, vector<int>& nums) {int n nums.size();int sum 0;int len INT_MAX;for(int left0,r…

mysql进阶-用户的创建_修改_删除

1. 使用mysql单次查询 [rootVM-4-6-centos /]# mysql -h localhost -P 3306 -p mytest -e "select * from book1"; Enter password: ------------------------------------------- | id | category_id | book_name | num | ----------------------------…

数据结构 | 基本数据结构——队列

目录 一、何谓队列 二、队列抽象数据类型 三、用Python实现队列 四、模拟&#xff1a;传土豆 五、模拟&#xff1a;打印任务 5.1 主要模拟步骤 5.2 Python实现 一、何谓队列 队列是有序集合&#xff0c;添加操作发生在“尾部”&#xff0c;移除操作则发生在“头部”。新…

【Javascript】基础知识

文章目录 01 变量的声明02 数据类型字符串型boolean类型undefined null类型symbol类型超大整数 bigint数组类型普通对象 01 变量的声明 02 数据类型 复习: 声明 ​ 声明变量关键词 ​ let ​ const ​ 变量名 >变量命名规范 ​ 英文 数字 _ $不要以数字开头 ​ 见名知意 ​…

深度学习之tensorboard可视化工具

(1)什么是tensorboard tensorboard是TensorFlow 的一个可视化工具包&#xff0c;提供机器学习实验所需的可视化和工具&#xff0c;该工具的功能如下&#xff1a; 跟踪和可视化指标&#xff0c;例如损失和精度可视化模型图&#xff08;操作和层&#xff09;查看权重、偏差或其…

【Java多线程学习4】volatile关键字及其作用

说说对于volatile关键字的理解&#xff0c;及的作用 概述 1、我们知道要想线程安全&#xff0c;就需要保证三大特性&#xff1a;原子性&#xff0c;有序性&#xff0c;可见性。 2、被volatile关键字修饰的变量&#xff0c;可以保证其可见性和有序性&#xff0c;但是volatile…

uniApp 对接安卓平板刷卡器, 读取串口数据

背景: 设备: 鸿合 电子班牌 刷卡对接 WS-B22CS, 安卓11; 需求: 将刷卡器的数据传递到自己的App中, 作为上下岗信息使用, 以完成业务; 对接方式: 1. 厂家技术首先推荐使用 接收自定义广播的方式来获取, 参考代码如下 对应到uniApp 中的实现如下 <template><view c…

python数据可视化Matplotlib

1.绘制简单的折线图 # -*- coding: utf-8 -*- import matplotlib.pyplot as pltinput_values [1, 2, 3, 4, 5] squares [1, 4, 9, 16, 25] plt.style.use(seaborn) fig, ax plt.subplots() ax.plot(input_values, squares, linewidth3) # 线条粗细# 设置图表标题并给坐标…

2023年第四届“华数杯”数学建模思路 - 复盘:光照强度计算的优化模型

文章目录 0 赛题思路1 问题要求2 假设约定3 符号约定4 建立模型5 模型求解6 实现代码 0 赛题思路 &#xff08;赛题出来以后第一时间在CSDN分享&#xff09; https://blog.csdn.net/dc_sinor?typeblog 1 问题要求 现在已知一个教室长为15米&#xff0c;宽为12米&#xff0…

less的使用

less的介绍&#xff1a; less使用 1、 less使用的第一种用法&#xff0c;起变量名&#xff0c;变量名区分大小写&#xff1a; 这里我们定义一个粉色变量 我想使用直接把变量拿过来就行 2、vscode使用插件&#xff0c;直接将Css文件转换less文件&#xff1a; 3、第二种用法&…

8.泛型

目录 1 基本使用 2 多个泛型 3 泛型约束 3.1 数组 3.2 extends约束 3.3 用泛型约束泛型 4 泛型接口 5 ts中的数组用的就是泛型 6 泛型类 7 常用泛型工具类型 7.1 让所有属性变为可选属性 Partial 7.2 将所有属性都变为只读属性 Readonly 7.3 从指定类…

git-版本控制器

集中式版本控制工具&#xff08;不常用&#xff09; 版本库集中于中央服务器&#xff0c;team要联网才能工作&#xff08;下载代码&#xff09; SVN CVS 分布式版本控制工具 每个电脑上都有一个完整的版本库&#xff0c;工作时无需联网&#xff0c;可以把修改推送给其他人来…

ThreadLocal有内存泄漏问题吗

对于ThreadLocal的原理不了解或者连Java中的引用类型都不了解的可以看一下我的之前的一篇文章Java中的引用和ThreadLocal_鱼跃鹰飞的博客-CSDN博客 我这里也简单总结一下: 1. 每个Thread里都存储着一个成员变量&#xff0c;ThreadLocalMap 2. ThreadLocal本身不存储数据&…

python爬虫(四)_urllib2库的基本使用

本篇我们将开始学习如何进行网页抓取&#xff0c;更多内容请参考:python学习指南 urllib2库的基本使用 所谓网页抓取&#xff0c;就是把URL地址中指定的网络资源从网络流中读取出来&#xff0c;保存到本地。在Python中有很多库可以用来抓取网页&#xff0c;我们先学习urllib2。…

docker minio安装

1.介绍 Minio是一款开源的对象存储服务&#xff0c;它可以在任何硬件或云平台上提供高性能、高可用性和高安全性的存储解决方案。Minio最新版是2021年11月发布的RELEASE.2021-11-24T23-19-33Z&#xff0c;它带来了以下几个方面的改进和新特性&#xff1a; - 支持S3 Select AP…

Allegro选择暗显模式仍然无法实现暗显模式的解决办法

Allegro选择暗显模式仍然无法实现暗显模式的解决办法 用Allegro进行PCB设计的时候,时常需要使用到暗显模式,让视图中未被高亮的图形暗显下去,如下图 左边是未高亮的网络,右边是已高亮的 但是有时候因为一些原因,导致无法暗显,如下图 下面介绍如何解决这个问题,具体操作…

CSPM认证的价值?

最近 CSPM 证书很热门&#xff0c;含金量高&#xff0c;CSPM证书虽然发起的时间不长&#xff0c;但获取 CSPM 证书也是目前发展的一个趋势。如果打算在项目管理领域发展的强烈建议尽快获取 CSPM&#xff0c;提前为自己积攒一些资本。 一、什么是 CSPM证书&#xff1f;跟PMP是什…

Java-API简析_java.io.FileWriter类(基于 Latest JDK)(浅析源码)

【版权声明】未经博主同意&#xff0c;谢绝转载&#xff01;&#xff08;请尊重原创&#xff0c;博主保留追究权&#xff09; https://blog.csdn.net/m0_69908381/article/details/132038909 出自【进步*于辰的博客】 因为我发现目前&#xff0c;我对Java-API的学习意识比较薄弱…

elasticsearch 将时间类型为时间戳保存格式的时间字段格式化返回

dsl查询用法如下&#xff1a; GET /your_index/_search {"_source": {"includes": ["timestamp", // Include the timestamp field in the search results// Other fields you want to include],"excludes": []},"query": …