BatchNormalization和Layer Normalization解析

news2024/12/26 10:50:34

Batch Normalization

是google团队2015年提出的,能够加速网络的收敛并提升准确率

1.Batch Normalization原理

图像预处理过程中通常会对图像进行标准化处理,能够加速网络的收敛,如下图所示,对于Conv1来说输入的就是满足某一分布的特征矩阵,但对于Conv2而言输入的feature map就不一定满足某一分布规律了(注意这里所说满足某一分布规律并不是指某一个feature map的数据要满足分布规律,理论上是指整个训练样本集所对应的feature map的数据要满足分布规律)。而我们BN的目的就是使feature map满足均值为0,方差为1的分布规律。

对于一个拥有d维的输入x,我们将对它的每一个维度进行标准化处理。假设我们输入的x是RGB三通道的彩色图像,那么这里的d就是输入图像的channels即d=3,其中x^1就代表我们的R通道所对应的特征矩阵,依次类推。标准化处理也就是分别对R通道,G通道,B通道进行处理。

让feature map满足某一分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律,也就是说要计算出整个训练集的feature map然后再进行标准化处理,对于一个大型的数据集明显是不可能的,所以论文中说的BN,也就是计算一个Batch数据的feature map然后进行标准化(batch越大越接近整个数据集的分布,效果越好)。

上图展示了一个batch size为2(两张图片)的Batch Normalization的计算过程,假设feature1、feature2分别是由image1、image2经过一系列卷积池化后得到的特征矩阵,feature的channel为2,那么x^1代表batch的所有的feature的channel1的数据。然后分别计算x^1和x^2的均值和方差。然后再根据标准差计算公式分别计算每个channel 的值(\varepsilon是很小的常量,放置分母为0的情况)。在训练过程中要去不断地计算每个batch的均值和方差,并使用移动平均(moving average)的方法记录统计的均值和方差,在训练完后我们可以近似认为所统计的均值和方差就等于整个训练集的均值和方差。然后再我们的验证以及预测过程中,就使用统计得到的均值和方差进行标准化处理。

\gamma是用来调整数值分布的方差大小,默认为1,\beta是用来调节数值均值的位置,默认值为0。这两个参数实在反向传播过程中学习到的。

2.使用Pytorch进行实验

在训练过程中,均值和方差是同通过计算当前批次数据得到的记录为\mu _{now},\delta_{now} ^{2},而我们的验证以及预测过程中使用的均值方差是一个统计量为\mu _{statistic},\delta _{statistic}^{2}。具体更新策略如下,其中momentum默认取0.1:

\mu _{statistic+1} = 0.9*\mu _{statistic}+0.1*\mu _{now}\\ \delta _{statistic+1}^{2} = 0.9*\delta _{statistic}^{2}+0.1*\delta _{now}^{2}

(1)bn_process函数是自定义的bn处理方法验证是否和使用官方bn处理方法结果一致。在bn_process中计算输入batch数据的每个维度(这里的维度是channel维度)的均值和标准差(标准差等于方差开平方),然后通过计算得到的均值和总体标准差对feature每个维度进行标准化,然后使用均值和样本标准差更新统计均值和标准差。

(2)初始化统计均值是一个元素为0的向量,元素个数等于channel深度;初始化统计方差是一个元素为1的向量,元素个数等于channel深度,初始化\beta=0,\gamma=1。

import numpy as np
import torch.nn as nn
import torch

def bn_process(feature, mean, var):
    feature_shape = feature.shape
    for i in range(feature_shape[1]):
        # [batch,channel, height, weight]
        feature_t = feature[:, i, :, :]
        mean_t = feature_t.mean()
        #总体标准差
        std_t1 = feature_t.std()
        #样本标准差
        std_t2 = feature_t.std(ddof = 1)

        #bn process
        #这里记得加上eps和pytorch保持一致
        feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / np.sqrt(std_t1 ** 2+ 1e-5)
        #更新计算均值
        mean[i]  = mean[i]*0.9 + mean_t * 0.1
        var[i] = var[i] * 0.9 + (std_t2 ** 2) * 0.1
    print(feature)

#随机生成一个batch为2,channel为2,height=width=2的特征向量
#[batch, channel, height, width]
feature1 = torch.randn(2, 2, 2, 2)
#初始化统计均值和方差
calculate_mean = [0.0, 0.0]
calculate_var = [1.0, 1.0]
#print(feature1.numpy())

#注意要使用copy()深拷贝
bn_process(feature1.numpy().copy(), calculate_mean, calculate_var)

bn = nn.BatchNorm2d(2, eps =  1e-5)
output = bn(feature1)
print(output)

 

3.使用BN时需要注意的问题

(1)训练时要将training采纳数设置为True,在验证时将training参数设置为False。在Pytorch中了可以通过创建模型的model.train()和model.eval()方法控制。

(2)batch size尽可能设置大点,设置小后表现很糟糕,设置的越大求的均值和方差越接近整个训练集的均值和方差。

(3)建议将bn层放在卷积层和激活层之间,且卷积层不要使用偏置bias,因为没有用,参考下图推理,及时使用了偏置bias求出的结果也是一样的。

 


Layer Normalization

Layer Normalization针对自然语言处理提出的,为什么不用BN呢,因为在RNN这类时序网络中,时序的长度并不是一个定值(网络深度不一定相同),比如每句话的长短都不一定相同,所以很难去使用BN,所以作者提出了Layer Normalization(图像处理领域BN比LN更有效),但现在很多人将自然语言领域的模型用来处理图像,比如Vision Transformer,此时会涉及到LN。

直接看Pytorch 官方给出的关于LayerNorm 的介绍。不同的是,BN是对一个batch数据的每个channel进行Norm处理,一个for循环,但LN是对单个数据的制定维度进行Norm处理与batch无关而且BN中训练时是需要累计moving_mean和moving_var两个变量的(所以BN中有4个参数moving_mean,moving_var,\beta ,\gamma),但LN不需要累计只有\beta ,\gamma两个参数。

在Pytorch的LayerNorm类中有个normalized_shape参数,可以指定要Norm的维度(注意,函数说明中the last certain number of dimensions,指定的维度必须是从最后一维开始)。比如我们的数据shape是[4,2,3],那么normalized_shape可以是[3](最后一维进行Norm处理),也可以是[2,3](Norm最后两个维度),也可以是整个维度[4,2,3],但不能是[2]或者[4,2],否则会报错。

y = \frac{x-E[X]}{\sqrt{Var[x]+\varepsilon}}*\gamma +\beta

import torch
import torch.nn as nn

def layer_norm_process(feature:torch.Tensor, beta=0.,gamma = 1.,eps=1e-5):
    var_mean = torch.var_mean(feature, dim = -1, unbiased = False)
    #均值
    mean = var_mean[1]
    #方差
    var = var_mean[0]

    #layer norm process
    feature  = (feature - mean[..., None]) / torch.sqrt(var[..., None] + eps)
    feature = feature*gamma+beta

    return feature

def main():
    t = torch.randn(4, 2, 3)
    print(t)
    #仅在最后一个维度上做norm处理
    norm = nn.LayerNorm(normalized_shape= t.shape[-1], eps = 1e-5)
    #官方layer norm处理
    t1 = norm(t)
    #自己实现的layer norm处理
    t2 = layer_norm_process(t, eps = 1e-5)
    print("t1:\n",t1)
    print("t2:\n",t2)

if __name__ == '__main__':
    main()
tensor([[[ 0.8512,  0.4201, -0.3457],
         [ 0.4701, -0.0647,  0.0733]],

        [[-0.9950, -0.4634,  0.0540],
         [ 0.4096,  0.4037, -0.0914]],

        [[-2.3165,  1.3059,  0.3183],
         [-0.9716,  0.4956,  0.4524]],

        [[-0.6209, -0.5958,  0.3212],
         [-0.8762,  0.3176, -0.5427]]])
t1:
 tensor([[[ 1.0963,  0.2254, -1.3218],
         [ 1.3697, -0.9893, -0.3804]],

        [[-1.2302,  0.0110,  1.2192],
         [ 0.7198,  0.6942, -1.4140]],

        [[-1.3642,  1.0050,  0.3591],
         [-1.4137,  0.7385,  0.6752]],

        [[-0.7355, -0.6783,  1.4138],
         [-1.0123,  1.3614, -0.3490]]], grad_fn=<NativeLayerNormBackward0>)
t2:
 tensor([[[ 1.0963,  0.2254, -1.3218],
         [ 1.3697, -0.9893, -0.3804]],

        [[-1.2302,  0.0110,  1.2192],
         [ 0.7198,  0.6942, -1.4140]],

        [[-1.3642,  1.0050,  0.3591],
         [-1.4137,  0.7385,  0.6752]],

        [[-0.7355, -0.6783,  1.4138],
         [-1.0123,  1.3614, -0.3490]]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1832013.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python 数据持久化:使用 SQLite3 进行简单而强大的数据存储

&#x1f340; 前言 博客地址&#xff1a; CSDN&#xff1a;https://blog.csdn.net/powerbiubiu &#x1f44b; 简介 SQLite3是一种轻量级嵌入式数据库引擎&#xff0c;它在Python中被广泛使用。SQLite3通常已经包含在Python标准库中&#xff0c;无需额外安装。你只需导入 s…

antd的表格组件错乱问题

环境 react&#xff1a;17.0.2 antd&#xff1a;3.26.20 问题 表格头列宽度和表格体列宽度不一致&#xff0c;表格错乱 解决 针对这个问题官方github仓库里面有专门的issues https://github.com/ant-design/ant-design/issues/13825 里面给出了几种解决方案&#xff1a…

31、matlab卷积运算:卷积运算、二维卷积、N维卷积

1、conv 卷积和多项式乘法 语法 语法1&#xff1a;w conv(u,v) 返回向量 u 和 v 的卷积。 语法2&#xff1a;w conv(u,v,shape) 返回如 shape 指定的卷积的分段。 参数 u,v — 输入向量 shape — 卷积的分段 full (默认) | same | valid full&#xff1a;全卷积 ‘same…

Compose 可组合项 - DatePicker、DatePickerDialog

一、概念 一般是以对话框的形式呼出&#xff0c;DatePickerDialog 就是对 DatePicker 的一个简单对话框封装。 Composable fun DatePicker( state: DatePickerState, modifier: Modifier Modifier, dateFormatter: DatePickerFormatter remember { DatePickerFor…

15.编写自动化测试(下)

标题 三、控制测试流程3.1 添加测试参数3.2 并行或连续运行测试3.3 显示函数输出3.4 指定/过滤测试用例名称3.5 忽略某些测试用例3.6 只运行被忽略的测试 四、测试的组织结构4.1 概念引入4.2 测试私有函数4.2 单元测试4.3 集成测试4.4 集成测试中的子模块4.5 二进制crate的集成…

【漏洞复现】畅捷通T+ keyEdit SQL注入漏洞

免责声明&#xff1a; 本文内容旨在提供有关特定漏洞或安全漏洞的信息&#xff0c;以帮助用户更好地了解可能存在的风险。公布此类信息的目的在于促进网络安全意识和技术进步&#xff0c;并非出于任何恶意目的。阅读者应该明白&#xff0c;在利用本文提到的漏洞信息或进行相关测…

Confluence安装

Confluence安装 1.安装 #下载confluence版本&#xff08;8.5.11&#xff09; https://www.atlassian.com/software/confluence/download-archives #修改权限 chmod x atlassian-confluence-8.5.11-x64.bin #执行安装 ./atlassian-confluence-8.5.11-x64.bin按照以下提示输入&…

SD-WAN在教育行业的应用及优势解析

随着教育领域的数字化转型&#xff0c;网络技术的需求变得愈发迫切。作为一种前沿的网络解决方案&#xff0c;SD-WAN正在为教育行业提供强有力的支持。本文将详细探讨SD-WAN在教育行业的应用&#xff0c;并分析其为教育行业带来的众多优势。 实现多校区高效互联 教育机构通常拥…

稳了?L3规模化落地在即,激光雷达公司成首批赢家

作者 | 芦苇 编辑 | 德新 在中国&#xff0c;距L3级自动驾驶的规模化落地&#xff0c;又近了一步。 随着国内试点政策刷新&#xff0c;越来越多的车企在部分市域获得了自动驾驶测试牌照&#xff0c;能上路测试的L3级自动驾驶车辆正在快速增加。 其中一个重要节点是&#xf…

Python基础用法 之 转义字符

将两个字符进⾏转义 表示⼀个特殊的字符 \n ---> 换⾏&#xff0c;回⻋ \t ---> 制表符, tab键 注意&#xff1a; print( end\n)&#xff1a; print() 函数中默认有⼀个 end\n, 所以,每个 print 结束之后, 都会输出⼀ 个换行。 未完待续。

Java数据类型及运算符及数组(与C语言对比)

Java和C语言在数据类型大部分相同&#xff0c;但是也有不同 1.新增了byte类型&#xff08;相当于C语言中把char用作整数一样&#xff09; 2.然后就是char类型的大小改为了2字节。 3.布尔型改名为boolean而不是bool,且大小没有明确规定&#xff0c;方便进行不同平台之间的移…

使用dev_dbg调试

首先内核要使能两个配置才可以使用。一般内核都是打开的。 CONFIG_DEBUG_FSy CONFIG_DYNAMIC_DEBUGy 当编译选项CONFIG_DYNAMIC_DEBUG打开的时候&#xff0c;在编译阶段&#xff0c;kernel会把所有使用dev_dbg()的信息记录在一个table中&#xff0c;这些信息我们可以从/sys/k…

在线预览多类型文件_全栈

目录 一、下载运行项目 二、项目功能 三、前端项目引用 四、文件预览样式更改 在做项目时经常用到在线预览文件&#xff0c;给大家介绍一个好用的在线预览文件项目。使用技术是后端Java&#xff0c;前端Freemarker模板。 FreeMarker 特别适应与 MVC 模式的 Web 应用&#x…

从“产品的RFM分析”看如何探索“职业方向”

我们在做产品分析时&#xff0c;经常会用到一种方法“产品的RFM分析”&#xff0c;它是一种客户细分和价值评估的常用方法&#xff0c;广泛应用于电子商务、零售和其他众多行业&#xff0c;它可以帮助企业和产品团队更好地理解用户行为&#xff0c;优化营销策略&#xff0c;提升…

解禁日大涨,爱玛科技的投资前景值得信任吗?

6月17日&#xff0c;爱玛迎来6.28亿股、金额超190亿元的解禁&#xff0c;占总股本72.91%。不过&#xff0c;爱玛股价在巨量解禁中反而迎来涨势&#xff0c;因为这部分股票中&#xff0c;创始人张剑持有的限售股数量几乎就占了爱玛总股本的七成。某种意义上&#xff0c;市场认为…

【产品经理】订单处理4-拆单策略

上次讲解了订单的促销策略&#xff0c;本次讲解下订单处理过程中的拆单策略。 订单拆单策略分为自动拆单、手动拆单&#xff0c;拆单时机也分为订单未被审核前拆单、订单审核后因仓库/快递情况的拆单&#xff0c;本次主要讲解订单未被审核前拆单、订单审核后快递超重的拆单&am…

ollama模型CPU轻量化部署

一、定义 ollama 定义环境部署demo加载本地模型方法基本指令关闭开启ollamaollama 如何同时 运行多个模型, 多进程ollama 如何分配gpu修改模型的存储路径 二、实现 ollama 定义 ollama 是llama-cpp 的进一步封装&#xff0c;更加简单易用&#xff0c;类似于docker. 模型网址…

SFNC —— 标准特征命名约定(一)

系列文章目录 SFNC —— 标准特征命名约定&#xff08;一&#xff09; 文章目录 系列文章目录1、介绍1.1 约定&#xff08;Conventions&#xff09;功能名称和接口&#xff08;Feature Name and Interface&#xff09;功能类别&#xff08;Feature Category&#xff09;功能级别…

菜单栏(骆驼书)

代码如下&#xff1a; 效果图&#xff1a;

使用宝塔面板部署Django应用(不成功Kill Me!)

使用宝塔面板部署Django应用 文章目录 使用宝塔面板部署Django应用 本地操作宝塔面板部署可能部署失败的情况 本地操作 备份数据库 # 备份数据库 mysqldump -u root -p blog > blog.sql创建requirements # 创建requirements.txt pip freeze > requirements.txt将本项目…