【深度学习】5-3 与学习相关的技巧 - Batch Normalization

news2024/12/26 21:05:39

如果为了使各层拥有适当的广度,“强制性”地调整激活值的分布会怎样呢?实际上,Batch Normalization 方法就是基于这个想法而产生的

为什么Batch Norm这么惹人注目呢?因为Batch Norm有以下优点:

  • 可以使学习快速进行(可以增大学习率)。
  • 不那么依赖初始值(对于初始值不用那么神经质) 。
  • 抑制过拟合(降低Dropout等的必要性)。

Batch Norm的思路是调整各层的激活值分布使其拥有适当的广度。为此,要向神经网络中插入对数据分布进行正规化的层,即Batch Normalization层(下文简称Batch Norm层)
在这里插入图片描述
Batch Norm,顾名思义,以进行学习时的mini-batch为单位,按mini-batch进行正规化。具体而言,就是进行使数据分布的均值为0、方差为1的正规化。用数学式表示的话,如下:
在这里插入图片描述

这里对mini-batch的m个输人数据的集合B求均值方差。然后,对输人数据进行均值为0、方差为1(合适的分布)的正规化。
这个式子所做的是将mini-batch的输人数据变换为均值为0,方差为1的数据。通过将这个处理插入到激活函数的前面(或者后面),可以减少数据分布的偏向
接着,Batch Norm层会对正规化后的数据进行缩放和平移的变换用数学式可以如下表示。
在这里插入图片描述

这里,γ和β是参数。一开始γ=1,β=0,然后再通过学习调整到合适的值。
上面就是Batch Norm的算法。这个算法是神经网络上的正向传播。

用计算图表示如下:
在这里插入图片描述

Batch Norm的反向传播
Batch Norm实现类

class BatchNormalization:
    """
    http://arxiv.org/abs/1502.03167
    """
    def __init__(self, gamma, beta, momentum=0.9, running_mean=None, running_var=None):
        self.gamma = gamma
        self.beta = beta
        self.momentum = momentum
        self.input_shape = None # Conv层的情况下为4维,全连接层的情况下为2维  

        # 测试时使用的平均值和方差
        self.running_mean = running_mean
        self.running_var = running_var  
        
        # backward时使用的中间数据
        self.batch_size = None
        self.xc = None
        self.std = None
        self.dgamma = None
        self.dbeta = None

    def forward(self, x, train_flg=True):
        self.input_shape = x.shape
        if x.ndim != 2:
            N, C, H, W = x.shape
            x = x.reshape(N, -1)

        out = self.__forward(x, train_flg)
        
        return out.reshape(*self.input_shape)
            
    def __forward(self, x, train_flg):
        if self.running_mean is None:
            N, D = x.shape
            self.running_mean = np.zeros(D)
            self.running_var = np.zeros(D)
                        
        if train_flg:
            mu = x.mean(axis=0)
            xc = x - mu
            var = np.mean(xc**2, axis=0)
            std = np.sqrt(var + 10e-7)
            xn = xc / std
            
            self.batch_size = x.shape[0]
            self.xc = xc
            self.xn = xn
            self.std = std
            self.running_mean = self.momentum * self.running_mean + (1-self.momentum) * mu
            self.running_var = self.momentum * self.running_var + (1-self.momentum) * var            
        else:
        	# 算法实现
            xc = x - self.running_mean
            xn = xc / ((np.sqrt(self.running_var + 10e-7)))
            
        out = self.gamma * xn + self.beta 
        return out

    def backward(self, dout):
        if dout.ndim != 2:
            N, C, H, W = dout.shape
            dout = dout.reshape(N, -1)

        dx = self.__backward(dout)

        dx = dx.reshape(*self.input_shape)
        return dx

	# 反向传播
    def __backward(self, dout):
        dbeta = dout.sum(axis=0)
        dgamma = np.sum(self.xn * dout, axis=0)
        dxn = self.gamma * dout
        dxc = dxn / self.std
        dstd = -np.sum((dxn * self.xc) / (self.std * self.std), axis=0)
        dvar = 0.5 * dstd / self.std
        dxc += (2.0 / self.batch_size) * self.xc * dvar
        dmu = np.sum(dxc, axis=0)
        dx = dxc - dmu / self.batch_size
        
        self.dgamma = dgamma
        self.dbeta = dbeta
        
        return dx

Batch Normalization的评估

现在我们使用Batch Norm层进行实验。首先,使用MNIST数据集,观察使用Batch Norm层和不使用Batch Norm层时学习的过程会如何变化,
代码如下:

# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from common.multi_layer_net_extend import MultiLayerNetExtend
from common.optimizer import SGD, Adam

(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)

# 减少学习数据
x_train = x_train[:1000]
t_train = t_train[:1000]

max_epochs = 20
train_size = x_train.shape[0]
batch_size = 100
learning_rate = 0.01


def __train(weight_init_std):
    bn_network = MultiLayerNetExtend(input_size=784, hidden_size_list=[100, 100, 100, 100, 100], output_size=10, 
                                    weight_init_std=weight_init_std, use_batchnorm=True)
    network = MultiLayerNetExtend(input_size=784, hidden_size_list=[100, 100, 100, 100, 100], output_size=10,
                                weight_init_std=weight_init_std)
    optimizer = SGD(lr=learning_rate)
    
    train_acc_list = []
    bn_train_acc_list = []
    
    iter_per_epoch = max(train_size / batch_size, 1)
    epoch_cnt = 0
    
    for i in range(1000000000):
        batch_mask = np.random.choice(train_size, batch_size)
        x_batch = x_train[batch_mask]
        t_batch = t_train[batch_mask]
    
        for _network in (bn_network, network):
            grads = _network.gradient(x_batch, t_batch)
            optimizer.update(_network.params, grads)
    
        if i % iter_per_epoch == 0:
            train_acc = network.accuracy(x_train, t_train)
            bn_train_acc = bn_network.accuracy(x_train, t_train)
            train_acc_list.append(train_acc)
            bn_train_acc_list.append(bn_train_acc)
    
            print("epoch:" + str(epoch_cnt) + " | " + str(train_acc) + " - " + str(bn_train_acc))
    
            epoch_cnt += 1
            if epoch_cnt >= max_epochs:
                break
                
    return train_acc_list, bn_train_acc_list


# 3.绘制图形==========
weight_scale_list = np.logspace(0, -4, num=16)
x = np.arange(max_epochs)

for i, w in enumerate(weight_scale_list):
    print( "============== " + str(i+1) + "/16" + " ==============")
    train_acc_list, bn_train_acc_list = __train(w)
    
    plt.subplot(4,4,i+1)
    plt.title("W:" + str(w))
    if i == 15:
        plt.plot(x, bn_train_acc_list, label='Batch Normalization', markevery=2)
        plt.plot(x, train_acc_list, linestyle = "--", label='Normal(without BatchNorm)', markevery=2)
    else:
        plt.plot(x, bn_train_acc_list, markevery=2)
        plt.plot(x, train_acc_list, linestyle="--", markevery=2)

    plt.ylim(0, 1.0)
    if i % 4:
        plt.yticks([])
    else:
        plt.ylabel("accuracy")
    if i < 12:
        plt.xticks([])
    else:
        plt.xlabel("epochs")
    plt.legend(loc='lower right')
    
plt.show()

运行结果如下:
在这里插入图片描述

从运行结果可以看到使用Batch Norm后,学习进行得更快了。
综上,通过使用Batch Norm,可以推动学习的进行。并且,对权重初始值变得健壮(表示不那么依初始值) Batch Norm具备如此优良的性质,一定能应用在更多场合中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/670551.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

广工赢清华,炸裂!

去年2022年广工对阵清华&#xff0c;我在知乎写了文章 清华赢球靠的是广东第一高中生邹阳和2022届CBA状元王岚嵚。 比分焦灼的第四节关键时刻&#xff0c;邹阳在左角底线持球高高举起篮球&#xff0c;那个球的弧度非常高&#xff0c;皮球以稳稳的抛物线弧度掉入篮筐。 之后&…

Python基础(21)——Python函数实战、递归、lambda、高阶函数

Python基础&#xff08;21&#xff09;——Python函数实战、递归、lambda、高阶函数 文章目录 Python基础&#xff08;21&#xff09;——Python函数实战、递归、lambda、高阶函数目标一. 应用&#xff1a;学员管理系统1.1 系统简介1.2 步骤分析1.3 需求实现1.3.1 显示功能界面…

Streamlit基础教程

streamlit是什么 streamlit是一个开源的python库&#xff0c;它能够快速的帮助我们创建定制化的web应用&#xff0c;而且还非常便于和他人分享&#xff0c;特别是在机器学习和数据科学领域。整个过程不需要你了解任何前端的知识&#xff0c;包括html、css、javascript等&#x…

Vue3 计算属性和侦听器实战(computed、watch)——简易点餐页面

文章目录 &#x1f4cb;前言&#x1f3af;项目介绍&#x1f3af;项目创建&#x1f3af;代码分析&#x1f3af;完整代码&#xff08;含 CSS 代码&#xff09;&#x1f4dd;最后 &#x1f4cb;前言 这篇文章记录一下 Vue3 计算属性和侦听器 &#xff08;computed、watch&#xf…

网络安全自学能学会吗?网络安全如何学习

网络安全是近年来的热门工作&#xff0c;吸引了许多小伙伴开始学习网络安全知识。那么我们应该如何学习网络安全呢&#xff1f;这是一个很多人都在考虑的问题。网络安全可以自学吗&#xff1f;自学网络安全能不能学会&#xff1f; 无论什么知识都是自学的&#xff0c;只是说每…

数学物理学家心中的十大最美方程

“你认为最美的数学、物理方程是什么&#xff1f;”当代十位大数学家、物理学家给出了他们自己的回答。这些回答构成了大雅之美&#xff08;The Concinitas Project&#xff09;的十篇文章。我们为读者带来这些大师对自己眼中最美方程的精彩解读。 1.指标定理 撰文 阿蒂亚爵士…

机器翻译与自动文摘评价指标 BLEU 和 ROUGE

机器翻译与自动文摘评价指标 BLEU 和 ROUGE 在机器翻译任务中&#xff0c;BLEU 和 ROUGE 是两个常用的评价指标&#xff0c;BLEU 根据精确率(Precision)衡量翻译的质量&#xff0c;而 ROUGE 根据召回率(Recall)衡量翻译的质量。 1.机器翻译评价指标 使用机器学习的方法生成文…

安科瑞AWT100无线数据采集通信终端

安科瑞AWT100无线数据采集通信终端 安科瑞 崔丽洁

js \d正则匹配数字失败问题记录

记录一次的正则匹配数字失败的问题 在一次开发中&#xff0c;需要匹配卡号&#xff0c;正则表达式较为复杂&#xff0c;想通过元字符进行简化&#xff0c;便由&#xff1a; new RegExp(^622(12[6-9]|1[3-9][0-9]|[2-8][0-9]{2}|9[01][0-9]|92[0-5])[0-9]{10,}$)变成&#xff…

EMQ的使用和介绍

首先先了解一下底层的协议: 1. MQTT MQTT&#xff08;Message Queuing Telemetry Transport&#xff0c;消息队列遥测传输协议&#xff09;&#xff0c;是一种基于发布/订阅 &#xff08;publish/subscribe&#xff09;模式的"轻量级"通讯协议&#xff0c;该协议构建…

【盘点】百家量子企业正展露头角

光子盒研究院 量子计算是一个可能彻底改变我们在金融、材料科学、密码学和药物发现等领域解决复杂问题的方式。过去十年左右&#xff0c;量子计算初创公司正迅速崛起。 现在&#xff0c;根据光子盒的量子企业数据库&#xff0c;全球大约有一千家公司直接参与到量子技术中&#…

反调试技术

文章目录 前言系统API实现方式IsDebuggerPresent (0x2)NtGlobalFlag&#xff08;0x68&#xff09;Heap flags&#xff08;0x18&#xff09;CheckRemoteDebuggerPresentNtQueryInformationProcessZwSetInformationThread 示例示例1比较明文字符串和输入字符串NtGlobalFlag时间差…

支持向量机SVM的原理和python实现

文章目录 1 SVM概述1.1 概念1.2 SVM的优缺点1.2.1 优点1.2.2 缺点 2 在python中使用SVM2.1 scikit-learn库2.2 SVM在scikit-learn库中的使用2.2.1 安装依赖库2.2.2 svm.SVC2.2.3 应用实例 总结 1 SVM概述 1.1 概念 支持向量机&#xff08;SVM&#xff09;是一类按监督学习方式…

CRM系统如何选择?哪些是必备功能?

CRM系统可以收集、整理并分析客户数据、优化企业销售流程、实现团队协作和共享&#xff0c;提高客户转化率&#xff0c;实现业绩增长。那么&#xff0c;如何选择CRM系统&#xff1f;CRM系统哪家好&#xff1f; 一、明确自己的业务需求 不同行业和规模的企业有不同的业务需求&…

JMU 软件工程经济学 复习总结

文章目录 碎碎念0. 基准收益率 i1. 现金流量图2. 净现值 NPV&#xff0c;内部收益率 IRR3. 单利&#xff0c;复利计算4. 等额年金NAV5. 动态回收期 P t ′ P_t Pt′​6. 固定资产折旧 [书P44]7. 增值税8. 软件行业增值税的即征即退9. 利息备付率 ICR&#xff0c;偿债备付率 DSC…

这6种最佳移动自动化测试工具你知道吗?

最好的移动自动化测试工具 在本文章关于移动应用程序测试的这一部分中&#xff0c;我们将研究 2023 年 6 种最佳移动自动化测试工具。 1、Appium Appium 是一个非常流行的开源自动化测试框架&#xff0c;支持各种操作系统的自动化。它可以与本机、混合和移动 Web 应用程序一…

微机原理基础知识

前言 微机原理期末复习的一些概念性的基础知识总结。 内容 &#xff08;1&#xff09;微处理器、微机与微机系统三者之间有什么异同&#xff1f; &#xff08;1&#xff09;把CPU&#xff08;运算器和控制器&#xff09;用大规模集成电路技术做在一个芯片上&#xff0c;即为微…

Vue实现Base64转png、jpg

method中写两个方法&#xff1a; 根据base64转图片的方法 根据转换出blob格式的文件导出的方法 //base64转pngbase64ImgtoFile(dataurl, filename file) {const arr dataurl.split(,)const mime arr[0].match(/:(.*?);/)[1]const suffix mime.split(/)[1]const bstr a…

Windows安装postgresql数据库图文教程

数据库使用排行榜&#xff1a;https://db-engines.com/en/ranking 目录 一、软件简介 二、软件下载 三、安装教程 四、启动教程 一、软件简介 PostgreSQL是一种特性非常齐全的自由软件的对象-关系型数据库管理系统&#xff08;ORDBMS&#xff09;&#xff0c;是以加州大学计…

Python采集二手房源数据信息并做可视化展示

目录标题 前言环境使用:模块使用:python技术实现: <基本流程步骤>代码展示尾语 前言 嗨喽~大家好呀&#xff0c;这里是魔王呐 ❤ ~! 环境使用: Python 3.8 jupyter --> pip install jupyter notebook pycharm 也可以 模块使用: requests >>> pip instal…