(十一)数据归一化方法BN/LN/GN/IN

news2024/10/6 6:03:49

文章目录

    • 0. Introduction
    • 1.Batch Normalization
    • 3.Layer Normalization
    • 4.Group Normalization
    • 6.Instance Normalization
    • 参考资料


欢迎访问个人网络日志🌹🌹知行空间🌹🌹


0. Introduction

在神经网络的训练过程中,网络的收敛情况非常依赖于参数的初始化情况,使用Normalization的方法可以增强模型训练过程中的鲁棒性。目前常用的Normalization方法有Batch NormalizationLayer NormalizationGroup NormalizationInstance Normalization四种方法,具体分别是指在一个batch的数据上分别在不同维度上做Normalization。如下图:

在这里插入图片描述

图中N表示一个Batch的大小,WH表示特征图宽高方向resize到一起后的维度方向,C表示不同的特征通道,G表示在通道方向做Group Normalization时每组包含的通道数的大小。

1.Batch Normalization

Batch Normalization是谷歌的Sergey Ioffe等于2015年03月份提交的论文Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift中提出的。

在这里插入图片描述

其中 x i x_i xi是维度为 C C C的数据,分别求每个维度在 b a t c h batch batch方向的均值和方差,然后进行归一化。值得注意的是方程

y i ← γ x i ^ + β y_i \leftarrow\gamma \hat{x_i}+\beta yiγxi^+β

相当于对归一化后的数据做了线性变换,这里 γ \gamma γ β \beta β都是在网络训练过程中需要学习的参数。根据上述BN的计算方式可求得反向传播的链路图:

在这里插入图片描述

由此使用Batch Normalization Layer时,其对应的反向和前向推理代码为,参考自CS231N homework2:


## Forward
def batchnorm_forward(x, gamma, beta, bn_param):
    """Forward pass for batch normalization.

    During training the sample mean and (uncorrected) sample variance are
    computed from minibatch statistics and used to normalize the incoming data.
    During training we also keep an exponentially decaying running mean of the
    mean and variance of each feature, and these averages are used to normalize
    data at test-time.

    At each timestep we update the running averages for mean and variance using
    an exponential decay based on the momentum parameter:

    running_mean = momentum * running_mean + (1 - momentum) * sample_mean
    running_var = momentum * running_var + (1 - momentum) * sample_var

    Note that the batch normalization paper suggests a different test-time
    behavior: they compute sample mean and variance for each feature using a
    large number of training images rather than using a running average. For
    this implementation we have chosen to use running averages instead since
    they do not require an additional estimation step; the torch7
    implementation of batch normalization also uses running averages.

    Input:
    - x: Data of shape (N, D)
    - gamma: Scale parameter of shape (D,)
    - beta: Shift paremeter of shape (D,)
    - bn_param: Dictionary with the following keys:
      - mode: 'train' or 'test'; required
      - eps: Constant for numeric stability
      - momentum: Constant for running mean / variance.
      - running_mean: Array of shape (D,) giving running mean of features
      - running_var Array of shape (D,) giving running variance of features

    Returns a tuple of:
    - out: of shape (N, D)
    - cache: A tuple of values needed in the backward pass
    """
    mode = bn_param["mode"]
    eps = bn_param.get("eps", 1e-5)
    momentum = bn_param.get("momentum", 0.9)

    N, D = x.shape
    running_mean = bn_param.get("running_mean", np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get("running_var", np.zeros(D, dtype=x.dtype))

    out, cache = None, None
    if mode == "train":

        avg = x.mean(axis=0)
        var = x.var(axis=0)
        std = np.sqrt(var)
        x_hat = avg / (std + eps)
        out = x_hat * gamma + beta
        
        shape = bn_param.get("shape", (N, D))
        axis = bn_param.get("axis", 0)
        cache = x, avg, var, std, gamma, x_hat, beta, shape, axis

        if axis == 0:
          running_mean = running_mean * momentum + (1 - momentum) * avg
          running_var = running_var * momentum + (1 - momentum) * var
    elif mode == "test":

        x_hat = (x - running_mean) / (np.sqrt(running_var) + eps)
        out = x_hat * gamma + beta

    else:
        raise ValueError('Invalid forward batchnorm mode "%s"' % mode)

    # Store the updated running means back into bn_param
    bn_param["running_mean"] = running_mean
    bn_param["running_var"] = running_var

    return out, cache

## Backward
def batchnorm_backward_alt(dout, cache):
    """Alternative backward pass for batch normalization.

    For this implementation you should work out the derivatives for the batch
    normalizaton backward pass on paper and simplify as much as possible. You
    should be able to derive a simple expression for the backward pass.
    See the jupyter notebook for more hints.

    Note: This implementation should expect to receive the same cache variable
    as batchnorm_backward, but might not use all of the values in the cache.

    Inputs / outputs: Same as batchnorm_backward
    """
    dx, dgamma, dbeta = None, None, None
    _, _, _, std, gamma, x_hat, _, shape, axis = cache # expand cache
    S = lambda x: x.sum(axis=0)                     # helper function
    
    dbeta = dout.reshape(shape, order='F').sum(axis)            # derivative w.r.t. beta
    dgamma = (dout * x_hat).reshape(shape, order='F').sum(axis) # derivative w.r.t. gamma
    
    dx = dout * gamma / (len(dout) * std)          # temporarily initialize scale value
    dx = len(dout)*dx  - S(dx*x_hat)*x_hat - S(dx) # derivative w.r.t. unnormalized x

    return dx, dgamma, dbeta

在以上代码中,BatchNorm层在训练结束推理时使用的是训练时得到的running averagerunning variance,在反向传播梯度时是根据链式法则求出BN层整体的梯度公式来计算梯度,可以减少中间变量的存储和计算,减少运算量和内存占用。

3.Layer Normalization

Batch Normalization在使用过程中依赖batch size的大小,当模型比较复杂,占用内存过多时很难使用大的batch size进行网络训练,这时BN的效果会受到限制,2016Hinton等提出的LayerNormalization克服了这些问题,可以作为batch size 较小时Batch Normalization的一种替代方案。

在这里插入图片描述

其中, H H H表示当前层隐层单元的数量,当使用的是卷积神经网络时,Layer Normalization是作用在卷积核作用在输入上得到的输出的每个通道上,输出的每个通道算做一层,在该层上做Normalization

代码实现:

def layernorm_forward(x, gamma, beta, ln_param):
    """Forward pass for layer normalization.
    During both training and test-time, the incoming data is normalized per data-point,
    before being scaled by gamma and beta parameters identical to that of batch normalization.
    Note that in contrast to batch normalization, the behavior during train and test-time for
    layer normalization are identical, and we do not need to keep track of running averages
    of any sort.
    Input:
    - x: Data of shape (N, D)
    - gamma: Scale parameter of shape (D,)
    - beta: Shift paremeter of shape (D,)
    - ln_param: Dictionary with the following keys:
        - eps: Constant for numeric stability
    Returns a tuple of:
    - out: of shape (N, D)
    - cache: A tuple of values needed in the backward pass
    """
    out, cache = None, None
    eps = ln_param.get("eps", 1e-5)
    ln_param.setdefault('mode', 'train')       # same as batchnorm in train mode
    ln_param.setdefault('axis', 1)             # over which axis to sum for grad
    [gamma, beta] = np.atleast_2d(gamma, beta) # assure 2D to perform transpose

    out, cache = batchnorm_forward(x.T, gamma.T, beta.T, ln_param) # same as batchnorm
    out = out.T                                                    # transpose back
    return out, cache


def layernorm_backward(dout, cache):
    """Backward pass for layer normalization.
    For this implementation, you can heavily rely on the work you've done already
    for batch normalization.
    Inputs:
    - dout: Upstream derivatives, of shape (N, D)
    - cache: Variable of intermediates from layernorm_forward.
    Returns a tuple of:
    - dx: Gradient with respect to inputs x, of shape (N, D)
    - dgamma: Gradient with respect to scale parameter gamma, of shape (D,)
    - dbeta: Gradient with respect to shift parameter beta, of shape (D,)
    """
    dx, dgamma, dbeta = None, None, None
    dx, dgamma, dbeta = batchnorm_backward_alt(dout.T, cache) # same as batchnorm backprop
    dx = dx.T # transpose back dx
    return dx, dgamma, dbeta

从上面代码可以看到,Layer Normalization是在每个样本的每层输出上实现的,因此可以复用Batch Normalization的实现。

4.Group Normalization

Group Normalization是2018年06月份HeKaiMing等提出的论文中发表的方法,作为Batch Normalization的另一种替代。

在这里插入图片描述

## pytorch example
import torch
x = torch.randn(1, 4, 2, 2)
m = torch.nn.GroupNorm(2, 4)
output = m(x)
print(output)

# equal to 
gx1 = x[:, :2, :, :]
gx2 = x[:, 2:, :, :]
mu1 = torch.mean(gx1)
mu2 = torch.mean(gx2)
std1 = torch.sqrt(torch.var(gx1))
std2 = torch.sqrt(torch.var(gx2))
x[:, :2, :, :] = (x[:, :2, :, :] - mu1) / (std1  + 1e-05)
x[:, 2:, :, :] = (x[:, 2:, :, :] - mu2) / (std2 + 1e-05)
print(x)

6.Instance Normalization

Instance Normalization 是2017年1月份Dmitry Ulyanov等发表的论文Improved Texture Networks: Maximizing Quality and Diversity in Feed-forward Stylization and Texture Synthesis中的提出的方法,其作用在单个样本的一个通道上,相当于num_groups=1Group Normalization

在这里插入图片描述


欢迎访问个人网络日志🌹🌹知行空间🌹🌹


参考资料

  • 1.https://kratzert.github.io/2016/02/12/understanding-the-gradient-flow-through-the-batch-normalization-layer.html
  • 2.https://github.com/mantasu/cs231n/blob/master/assignment2/cs231n/layers.py
  • 3.Group Normalization in Pytorch (With Examples)
  • 4.GROUPNORM

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/84384.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习分类算法之逻辑回归

1、基础知识: 逻辑回归:logistic regression二分类:binary classification 类别一类别二noysefalsetrue01negative classpositive class 线性回归模型用于分类,效果一般;逻辑回归是最广泛使用的分类算法;…

main入口函数分析

在开始讲解之前,分享一些阅读 项目代码的经验。无论学习哪方面的知识,都是需要正反馈才能继续学下去。在学习开源项目的时候,如果不掌握一些比较好的方法,会比较难拿到正反馈,或者要坚持学习很久才能拿到正反馈。 我个…

JAVA毕业设计——基于Springboot的动漫论坛系统(源代码+数据库+ppt文档)

github代码地址 https://github.com/ynwynw/cartoonForum-public 毕业设计所有选题地址 https://github.com/ynwynw/allProject #动漫论坛系统 #java web #java #毕业设计 #课程设计 #JPa #Springboot #mysql #源代码 基于Springboot的动漫论坛系统(源代码数据库ppt文档)040 …

Python文件操作注意事项

今天继续给大家介绍Python相关知识,本文主要内容是Python文件操作注意事项。 一、文件操作流程注意事项 在文章Python文件操作详解(一)中,我们讲解过,文件操作的流程是打开文件——操作文件——关闭文件。如果我们在…

设计用于汽车和车身SPC58NH92C3RMI0X\SPC560B50L1B4E0X微控制器

SPC560B50x系列 32 位微控制器是集成汽车应用控制器的最新成就。它属于一个不断扩大的以汽车为中心的产品家族,旨在解决下一波汽车内部的车身电子应用。该汽车控制器系列的先进且经济高效的主机处理器核心符合 Power Architecture 嵌入式类别,仅实现 VLE…

Qt扫盲-QRadioButton理论总结

QRadioButton理论总结1. 简介2. 自动排外3. 信号槽4. 外观&快捷键1. 简介 QRadioButton是一个选项按钮,可以打开(选中)或关闭(未选中)。单选按钮通常为用户提供”众多”选项之一。在一组单选按钮中,一…

黑客隔空盗密码,你的账户安全吗?

一、NFC卡防互动,怎样才能更安全? 想知道黑客如何破解你的设备,盗取你的信息吗?这一黑科技设备将向你展示黑客是如何隔空盗取你银行卡的账号密码的。 模拟黑客使用一张RFID读卡器,近距离靠近你的银行卡时,…

java毕设_第172期ssm高校毕业生就业满意度调查统计系统_计算机毕业设计

java毕设_第172期ssm高校毕业生就业满意度调查统计系统_计算机毕业设计 【源码请到下载专栏下载】 今天分享的项目是《ssm高校毕业生就业满意度调查统计系统》 该项目分为2个角色,管理员和用户。 用户可以浏览前台,包含功能有:进行问卷提交、 就业咨询、试题列表进行…

Mentor-dft 学习笔记 day40-Saving Timing Patterns(1)

Timeplate Examples 例如,移位周期40ns,占空比为50%,timeplate所示: timeplate tp_shift force_pi 0; measure_po 5; pulse_clock 10 20; period 40; end;拉伸的timeplate可用于将时钟脉冲延迟40ns,同时保持相同的20…

计算机研究生就业方向之运营商(移动,联通,电信)

我一直跟学生们说你考计算机的研究生之前一定要想好你想干什么,如果你只是转码,那么你不一定要考研,至少以下几个职位研究生是没有啥优势的: 1,软件测试工程师(培训一下就行) 2,前…

[附源码]Nodejs计算机毕业设计基于Java网络游戏后台管理系统Express(程序+LW)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程。欢迎交流 项目运行 环境配置: Node.js Vscode Mysql5.7 HBuilderXNavicat11VueExpress。 项目技术: Express框架 Node.js Vue 等等组成,B/S模式 Vscode管理前后端分…

HackTheBox Precious CVE-2022-25765利用,YAML反序列化攻击提权

靶机网址: https://app.hackthebox.com/machines/Precious枚举 使用nmap枚举靶机 nmap -sC -sV 10.10.11.189发现域名,我们本地DNS解析一下 echo "10.10.11.189 precious.htb" >> /etc/hosts然后访问网站 CVE-2022-25765利用 他的功…

网络流量分析帮助企业提升OA应用性能(一)

需求简介 某外高桥公司的OA系统是其重要的业务系统,OA系统负责人表示,部分用户反馈,访问OA系统时比较慢。需要通过分析系统看一下实际情况。 信息部已对企业领导定义了独立的组,本次要主动分析OA使用体验快慢。如果OA系统存在访…

mybatis入门02:Mybatis核心文件配置

目录 2.1 MyBatis核心配置文件层级关系 2.2MyBatis常用配置解析 1.environments标签 2.mapper标签 3.properties标签 4.typeAlisases标签 2.3 Mybatis相应的API 1.SqlSessionFactory工厂构造器SqlSessionFactoryBuilder 2.SqlSession工厂对象SqlSessionFactory 3.SqlSe…

实时监控网络流量,精准辨别网络性能瓶颈

网络流量反映网络运作状态,是辨别网络运行是否正常的关键指标,通过对网络流量进行监测不仅能反映交换机、路由器等设备的工作状态,更能体现整个网络资源的运行性能。同时,用户在网络中的行为可以通过其承载的流量动态来展现&#…

使用个从版gitee时向远程库push修改后内容时报remote: error: File: xxx 129.03 MB, exceeds 100.00 MB

1、报错时截图如下(我以下所有命令都是在Git Bash中执行的): 这是先前git push报的提示,明显说LFS(即large file system,此处大文件应该就是指过超过100M的单一文件)仅仅针对企业版gitee用户才…

.Net 7 CLR和ILC编译函数过程

楔子 由于甲方的需求,随着研究深入,发现CLR编译函数与ILC编译是两种不同的截然方式,除了JIT部分编译一样,其它部分貌似完全不一。 本篇来梳理这些东西。QQ:676817308。wx公众号:江湖评谈 示例: 作为例子…

OH----基于RK3568的AB分区功能,bsp部分

1、背景: OH master 主线 ,RK3568平台添加AB分区功能,uboot部分完成对ab分区标志位的读取解析,并加载和进入对应的分区,如: kernel_a 或者 kernel_b 2、环境: rk3568 Uboot代码下载&#xff…

股票L2接口和L1接口有什么差距?

股票L2数据的主要特点是能看到资金流向和十档买卖盘,比L1数据更加清晰和全面。 但是就现在的股票市场而言,也不能全部听信L2数据。 很多数据也是庄家做出来的,就是为了给散户看,所以全面分析基本面和技术面才是最重要的。 而且…

[附源码]计算机毕业设计个人博客系统Springboot程序

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: Springboot mybatis MavenVue等等组成,B/S模式…