神经网络中BN层简介及位置分析

news2024/9/20 20:30:13

1. 简介

Batch Normalization是深度学习中常用的技巧,Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift (Ioffe and Szegedy, 2015) 第一次介绍了这个方法。

这个方法的命名,明明是Standardization, 非要叫Normalization, 把本来就混用、意义不明的两个词更加搅得一团糟。那standardization 和 Normalization有什么区别呢?

一般是下面这样(X是输入数据集):

  • normalization(也叫 min-max scaling),一般译做 “归一化”:

  • standardization,一般译做 “标准化”:

Batch-Norm 是一个网络层,对中间结果作上面说的 standardization 操作。实际上 standardization 也可以叫做 Z-score normalization。所以可以这样理解,standardization 是一种特殊的 normalization。normalization 作为一个 scaling 的大类,包括 min-max scaling,standardization 等。

2. BatchNorm

对输入进行标准化的时候,计算每个特征在样本集合中的均值、方差;然后将每个样本的每个特征减去该特征的均值,并除以它的方差。用数学公式表示,即:

而所谓的BatchNorm, 就是神经网络中间,在一小撮batch样本中进行标准化。具体如下(B: batch size)

注意,BatchNorm作为神经网络的一层,是有两个参数\left ( \gamma, \beta \right )要训练的,分别称为拉伸和偏移参数。可能你会有疑问,既然已经对 u_{b} 作了标准化得到了\hat{u_{b}}  ,为什么还要用 \gamma, \beta将它“还原”呢?

实际上,设置这两个参数是为了给神经网络足够的自由度。如果经过训练\gamma \approx \hat{\sigma} _{batch},\beta \approx \hat{\mu} _{batch}, 说明神经网络认为,不需要进行批标准化即可使loss function最小化,我们也充分“尊重”它的选择。

3. BN 的特点

  • 使用 BatchNorm,我们可以尝试更大的学习率,从而加速收敛,但一般不会改变模型的精度;
  • BatchNorm 的效果依赖于 Batch size;一般需要较大的 Batch size(>16)才能有好的效果
  • 和 Dropout 一样,BatchNorm 在训练和推理时有不同的行为:训练时,它基于每个 batch 计算均值和方差,因此 batch size 必须足够大才能较好反映统计性质;推理时,BatchNorm 则直接用训练集整体的均值和方差进行标准化

训练集整体的均值和方差如何得到?——在每个batch的均值和方差计算中,通过移动平均估算得到。

4. BN 的位置

BatchNorm 究竟应该放在哪,现在还存在争议。很多人说应该放在激活函数之前,但也有声音说应该放在激活函数之后。思考一下,两种说法都有道理。举个简单的例子。

前一种说法是要对 \omega x+b作BatchNorm,这样可以保证 \omega x+b 在0附近, {\sigma}'(\omega x+b) 不至于太小;后一种说法 BatchNorm 的作用对象则直接是x ,这样可以控制梯度 \frac{\partial y}{\partial w} 在合理的范围内,不会因为 x 的极端取值而波动过大。

但现在看来,前一种声音是占上风的:将 BatchNorm 作用在全连接层和卷积层的输出上,激活函数之前。在全连接网络中,顺序是:线性组合+BatchNorm+Activation

对于全连接层,BatchNorm 作用在特征维上。假设输入矩阵大小是 m×n —— m 等于 batch size,即这个小批量中的样本数, n 表示特征数。我们要在每个特征上计算 m 个样本的均值和方差,也就是对每一列做计算。

在卷积神经网络中,顺序是:卷积层+BatchNorm+Activation+池化+全连接。要注意一点是,如果卷积层有K个卷积核(即K个通道),要对每个通道的输出分别做批标准化,且每个通道都拥有独立的拉伸和偏移参数。

对于卷积层,BatchNorm 作用在通道维上。我们先考虑一个 1×1 的卷积层,通道数为 k 。它其实就等价于神经元个数为 k 的全连接层。图片中每个像素点都由一个 k 维的向量表示,可以看作是像素点的 k 个特征。同一批量各个图片的各个像素点,就是不同的样本,共有 m×p×q 个样本, m,p,q 分别为 batch size、高、宽。

类比全连接层 BatchNorm 作用在特征维上,要在每个通道(即每个特征)上计算 m×p×q 个样本的均值和方差

设小批量中有m个样本。在单个通道上,假设卷积计算输出的高和宽分别为p和q。我们需要对该通道中m×p×q个元素同时标准化:对这些元素做标准化计算时,我们使用相同的均值和方差,即该通道中m×p×q个元素的均值和方差。——卷积神经网络之Batch Normalization(一)

5. BN的理解与延伸

BN 效果好是因为 BN 的存在会引入 mini-batch 内其他样本的信息,就会导致预测一个独立样本时,其他样本信息相当于正则项,使得 loss 曲面变得更加平滑,更容易找到最优解。相当于一次独立样本预测可以看多个样本,学到的特征泛化性更强,更加 general

Conv+BN+Relu 是卷积网络的一个常见组合。在模型推理时,BN 层的参数已经固定下来,本质就是一个线性变换。我们可以把 Conv+BN+Relu 进行算子融合,以加速模型推理

除了BN层,还有GN(Group Normalization)、LN(Layer Normalization、IN(Instance Normalization)这些个标准化方法,每个标注化方法都适用于不同的任务。

这个图很好地说明了BatchNorm、LayerNorm、InstanceNorm、GroupNorm的区别。N代表batch size;C代表卷积核个数(通道个数);H,W代表卷积结果的高和宽。

BatchNorm: 计算均值和方差时,考虑N * H * W 个元素;对每个通道分别做标准化

LayerNorm:计算均值和方差时,考虑C * H * W 个元素;对batch中的每个instance分别做标准化

InstanceNorm:计算均值和方差时,考虑H * W 个元素;对每个通道、batch中的每个instance分别做标准化

GroupNorm:介于LayerNorm和InstanceNorm二者之间,将C个通道分组,然后进行标准化。

直觉上来讲,GroupNorm把提取到类似特征的不同卷积核分到同一个group中。对这些卷积核进行标准化,确实make sense. 而且GroupNorm摆脱了对batch size的依赖。

GN在训练集上表现最好,在测试集上稍逊于BN(引自 Group Normalization (Yuxin & Kaiming, 2018))

6. BN vs LN

Transformer模型中用到了LayerNorm,着重对比一下LayerNorm和BatchNorm。

对于一个输入序列 (x1,x2,...,xn) ,每一个 xi 都是 d 维的向量。譬如输入序列是一个句子,每个单词 xi 都用一个 d 维的向量表示。

X轴是序列长度(n),Y轴是特征个数(d),Z轴是Batch size

此时BatchNorm是对图中蓝色框作标准化处理,就像我们上面说的——对每个特征分别做标准化;而LayerNorm针对每一个输入序列,对图中黄色框作标准化处理。总结来说,BatchNorm盯住每一个特征;而LayerNorm盯住的是每一个样本。

那么为什么Transformer模型要用LayerNorm而不是BatchNorm呢?

实际上,序列模型的背景下,BatchNorm有一个天然的硬伤,这使得它在所有序列模型中都不吃香:输入序列的长度(n)可能不一致。一般来说,我们会规定一个最长的序列长度,长度不够的序列用0填充。譬如下图这样,Batch中的序列长短不一。

如果用BatchNorm,以一个feature为例,它的标准化有效范围是蓝色的图,其余用0填充;如果是LayerNorm,对于4个序列,它们的标准化有效范围是黄色的图。

直觉上来说,对于BatchNorm的计算方法,当Batch中序列长度差距过大时,均值和方差的波动也会很大

但这个问题对于LayerNorm来说并不存在,因为它是在每一个序列内部计算均值和方差的。

这样,我们可以直观地理解,为什么BatchNorm对于序列模型并不好用;为什么Transformer要采用LayerNorm

7. BN代码实现

我们翻一翻常见的backbone的结构。可以看到在官方Pytorch的resnet.pyclass BasicBlock中,forward时的基本结构是Conv+BN+Relu:

# 省略了一些地方
class BasicBlock(nn.Module):
    def __init__(self,...) -> None:
        ...
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x
        # 常见的Conv+BN+Relu
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        # 又是Conv+BN+relu
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)

        return out

resnet作为我们常见的万年青backbone不是没有理由的,效果好速度快方便部署。当然还有很多其他优秀的backbone,这些backbone的内部结构也多为Conv+BN+Relu或者Conv+BN的结构。

参考资料:BatchNorm and its variants - 知乎normalization 和 standardization 到底什么区别?_为什么batch normalization使用standardization而不是normaliz-CSDN博客不论是训练还是部署都会让你踩坑的Batch Normalization - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1238476.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PowerQuery领域的经典之作“猴子书“中文版来啦!

与数据打交道,还在纠结于Excel、SQL、VBA、Python?数据处理领域经典之作PowerQuery"猴子书"让你用更聪明的方法处理数据。学完这本书,你就掌握了Power Query的一切,想要学Power Query,只需要这一本就够啦&am…

提升企业人效,从精细化考勤管理开始

过去,许多企业提到考勤管理,只能关联到打卡、请假、算薪这些简单的事务性流程。随着越来越多企业希望通过数字化转型来提升运营效率,实现精细化人员管理。考勤数据的作用也不再仅限于算薪,而是成为了企业分析人效的关键因子。因此…

飞瓜数据B站丨B站UP主11月第3周榜单排行榜榜单(B站平台)发布!

飞瓜轻数发布2023年11月13日-11月19日飞瓜数据UP主排行榜(B站平台),通过充电数、涨粉数、成长指数、带货数据等维度来体现UP主账号成长的情况,为用户提供B站号综合价值的数据参考,根据UP主成长情况用户能够快速找到运营…

LangChain: 类似 Flask/FastAPI 之于 Django,LangServe 就是「LangChain 自己的 FastAPI」

原文:LangChain: 类似 Flask/FastAPI 之于 Django,LangServe 就是「LangChain 自己的 FastAPI」 - 知乎 说明:LangServe代替 langchainserver 成为新的langchain 部署工具 官网资料:🦜️🏓 LangServe | &…

智慧物流仓储仓库温湿度管理采集器钡铼技术远程终端RTU的使用

智慧物流仓储是当今物流行业的一个重要发展方向,它通过应用先进的技术和设备,实现对仓储环境的监控和管理。在智慧物流仓储中,温湿度管理是十分关键的一项工作。为了解决温湿度管理的问题,采集器钡铼技术远程终端RTU被广泛应用于仓…

未来制造业的新引擎:工业机器人控制解决方案

制造业正经历着一场革命性的变革 在这个变革的浪潮中,工业机器人成为推动制造业高效生产的关键力量。然而,要发挥机器人的最大潜力,一个强大而智能的控制系统是必不可少的。在这个领域,新一代的工业机器人控制解决方案正崭露头角&…

Linux:进度条(小程序)以及git三板斧

Linux小程序&#xff1a;进度条 在实现小程序前我们要弄清楚&#xff1a; 1.缓冲区&#xff1b; 2.回车与换行。 缓冲区&#xff1a; 分别用gcc来编译下面两个程序&#xff1a; 程序一&#xff1a; #include <stdio.h> int main() { printf("hello Makefil…

【云原生-Kurbernetes篇】 玩转K8S不得不会的HELM

Helm 一、Helm1.1 使用背景1.2 Helm简介1.3 Helm的几个概念1.4 helm2 和 helm3 的区别1.5 chart包的关键组成 二、Helm相关命令2.1 应用管理操作2.2 Helm repository仓库管理命令2.2 Helm chart包管理命令2.3 Helm release(实例) 管理命令2.4 Helm私有仓库管理命令 三、部署He…

代码混淆不再愁:一篇掌握核心技巧

​ 1. 概述 代码混淆是将计算机程序的代码转换成一种功能上等价&#xff0c;但是难以阅读和理解的形式。 对于软件开发者来说&#xff0c;代码混淆可以在一定程度上保护程序免被逆向。 对于逆向工程师来说&#xff0c;学习代码混淆可以帮助我们研究反混淆技术。 2. 常见混淆…

vue2使用el-tag自定义菜单导航标签

需求&#xff1a;使用el-tag写个菜单导航栏&#xff0c;点击路由的时候就添加 功能&#xff1a; 设置鼠标横向滚动并且不展示滚动条添加关闭其他、关闭左侧、关闭右侧、全部关闭标签功能单个标签删除功能添加&#xff0c;固定标签不可删除右键点击展开操作菜单栏设置个默认固定…

厦门某智慧社区的智慧排水监测系统实施落地

厦门某智慧社区的智慧排水监测系统实施落地 智慧社区的排水系统是一种高度智能化、高效且环保的排水解决方案&#xff0c;它结合了自动化控制系统、计算机网络技术、传感监测技术以及环保理念等多个领域的知识。其主要作用是确保社区的排水系统能够高效、稳定、环保地运行&…

Go并发编程学习-class1

class1. Mutex 解决资源并发访问 基础概念 临界区概念&#xff1a;一个被共享的资源&#xff0c;可以被并发访问。通过Mutex互斥锁&#xff0c;可以限定临界区只能由一个线程获取。 根据不同情况&#xff0c;不同适用场景 ●共享资源。并发地读写共享资源&#xff0c;会出现…

规划类3d全景线上云展馆帮助企业轻松拓展海外市场

科技3D线上云展馆作为一种基于VR虚拟现实和互联网技术的新一代展览平台。可以在线上虚拟空间中模拟真实的展馆&#xff0c;让观众无需亲自到场&#xff0c;即可获得沉浸式的参观体验。通过这个展馆&#xff0c;您可以充分、全面、立体展示您的产品、服务以及各种创意作品&#…

OpenAI董事会秒反悔!奥特曼被求重返CEO职位

明敏 丰色 发自 凹非寺 量子位 | 公众号 QbitAI 1天时间&#xff0c;OpenAI董事会大变脸。 最新消息&#xff0c;他们意在让奥特曼重返CEO职位。 多方消息显示&#xff0c;因为“投资人的怒火”&#xff0c;OpenAI董事会才在一天时间里来了个大反转。 微软CEO纳德拉被曝在得…

Imaris 卡退,是不是缓存盘没有设置好?

必须记录一下&#xff0c;从Imaris哔哩哔哩官方视频上学到的&#xff0c;如何设置缓存位置&#xff0c;尤其是做3D视频的时候。 但是隔一段时间就忘记&#xff0c;找不到当时的哔哩哔哩视频 这里记一下 如果是空间比较小的C盘&#xff0c;可以改成一个空间大一点的位置。 把缓…

用Stable Diffusion帮助进行卡通风格渲染

用Stable Diffusion帮助进行卡通风格渲染 正常风格渲染卡通风格贴图增加涅斐尔边缘高光效果 正常风格渲染 正常的动物写实模型 卡通风格贴图 用Stable Diffusion可以帮助我们将写实贴图转化为卡通风格&#xff08;具体参数可以自己调试&#xff0c;总体上是将提示词强度和图…

Python中控制台如何展示进度条——tqdm库使用

在 Python 中可以使用特定的库来创建控制台进度条&#xff0c;其中 tqdm 是一个常用的选择&#xff0c;它能够方便地显示进度条并跟踪迭代的进度。你可以通过 pip 安装 tqdm 库&#xff1a; pip install tqdm包装迭代器&#xff1a; 使用 tqdm 来包装你的迭代器&#xff0c;比…

外卖配送小程序商城的效果如何

线下餐饮店非常多&#xff0c;主要以同城生意为主&#xff0c;在线上电商和外卖平台的冲击下&#xff0c;传统商家仅通过传统方式经营很难宣传拓客及转化等&#xff0c;线上是必要的渠道&#xff0c;但入驻第三方平台又会有各种困扰&#xff0c;抽成/佣金/流量费/激烈竞争等。 …

C++ MiniZip实现目录压缩与解压

Zlib是一个开源的数据压缩库&#xff0c;提供了一种通用的数据压缩和解压缩算法。它最初由Jean-Loup Gailly和Mark Adler开发&#xff0c;旨在成为一个高效、轻量级的压缩库&#xff0c;其被广泛应用于许多领域&#xff0c;包括网络通信、文件压缩、数据库系统等。其压缩算法是…

pyqt5 窗口调用网页高德地图kpi,进行实时地图导航

作为主项目功能的一部分&#xff0c;这部分我想单独记录下来 一&#xff0c;注册高德kpi【进行实名认证】 高德开放平台 | 高德地图API (amap.com) 二&#xff0c;申请Key 三&#xff0c;进入路径规划-API文档-开发指南-Web服务 API|高德地图API (amap.com) 找到你需要的路径…