[PyTorch][chapter 35][Batch Normalize]

news2024/11/16 12:05:45

前言:

          Batch Norm 是深度学习里面常用的技术之一,主要作用是

把指定维度参数约束到x \sim (r,\sigma^2)范围内,有效的解决了梯度弥散

现象 ,有助于加速模型的训练速度。

          


  1. 问题解释

  2. 特征缩放 Feature Scaling

  3. Batch Normalization

  4. Torch API 


 一  问题解释

    

     如上图,输入范围如下

     x_1 \in (1,\sigma^2)

     x_2 \sim (100,\sigma^2)

  所以   沿w_1 方向 搜索Loss等高线变化慢,

            沿w_2 方向搜索 Loss 等高线变化快

 1.1 左图 :

    没有做Batch Normalize,进行训练时,如果初始点为A点,搜素比较曲折,收敛的速度很慢

1.2 右图:

             做完 Batch Normlize后, 

            x_1 \sim(1,\sigma^2),x_2\sim (1,\sigma^2),无论沿着A点或者B点,

收敛速度都很快


二  特征缩放 Feature Scaling

    在卷积神经网络常用的方案如下:

     2.1 Image Normalization

       比如一张图片有R,G,B 三个通道,我们可以通过

       

 normalize = transform .Normalize(mean = [0.485, 0.456, 0.406], 
                         std = [0.229, 0.224, 0.225]),

             x_r = \frac{x_r-0.485}{0.229}
             x_g = \frac{x_r-0.456}{0.224}

             x_b=\frac{x_b-0.406}{0.225}

2.2  Batch Normalization

    有四种,主要区别是在不同的维度上对输入进行Normalization 

    输入 [N,C,W,H] 图片张数,维度, 图片宽,图片高

   以输入[6,3,28,28]为例

    Batch Norm:

                  把宽高相乘得到 [6,3,784]

                 分别对 channel 上的r,g,b通道 6张图片,求均值,方差

然后Normalization ,得到[3]个实例,r,g,b 通道上新的标准化值

   Layer Norm:

            把宽高相乘得到 [6,3,784]

            对6张图片分别求均值,方差,然后Normalization,

           得到[6]个实例的标准化值

   Instance Norm

              对 [6,3] 单独图片,指定的r,g,b维度分别求均值方差,

             得到[6,3] 个实例的标准化值


三 Batch Normalization

      Batch Normalization(BN) 技术是2015年由Sergey Ioffe 以及 Christian Szegedy 团队提出,通过在每一层神经网络中加入Batch Normalization层,是输入到改成的小批次的数据在训练前进行标准化,作用:
     有助于加速模型的训练速度,
     降低模型的训练过程收受初始权重影响程度,
     模型更稳定,更加有效的收敛,
     提高模型的泛化能力

以一个图片数据集为例 [6,3,28,28],在rgb通道上做BN

如上图为对一个mini-batch做BN层处理的流程。一共分为四步:

1. 计算当前mini-batch所有样本的均值;[6,784]

2. 计算当前mini-batch所有样本的方差;[6,784]

3. 对当前mini-batch内每个样本用前面的均值和方差做归一化;

     使得r,g,b channel 上的分布服从 x \sim N(0,1)

4. 对归一化后的样本,乘以一个*缩放*系数,再做一次*平移*;

       使得r,g,b channel 上的分布 服从x \sim N(\beta,\gamma^2)

      这两个参数需要通过训练时候学习

前面的三步,都是直接为了稳定x的分布,缓解ICS而做的归一化处理。

参考代码


四  PyTorch API 函数

4.1 nn.BatchNorm1d

   N * d --> N * d

BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True, device=None, dtype=None)

参数

 

含义
 num_features:也就是数据的特征维度
eps分母上加的一个值,是为了防止分母为0的情况
affine

 是仿射变化,将,分别初始化为1和0;

# -*- coding: utf-8 -*-
"""
Created on Tue May 23 16:10:52 2023

@author: chengxf2
"""

import torch
import torch.nn as nn


def BN():
    
     x = torch.rand(2,3,4)
     
     # num_features=3: 输入维度,也就是数据的特征维度;
     layer = nn.BatchNorm1d(3)
     
     out = layer(x)
     
     print("\n batch 均值",layer.running_mean)
     print("\n batch 方差",layer.running_var)
     
     print("\n input ",x)
     print("\n out",out)
     
BN()

4.2   BatchNorm2d

BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True, device=None, dtype=None)

        主要作用在特征上,比如输入维度为B*C*H*W, B代表batchsize大小,C代表channel,H代表图片的高度维度,W代表图片的宽度维度;

      nn.BatchNorm2d是对channel做归一化处理,也就是对批次内的特征进行归一化;

# -*- coding: utf-8 -*-
"""
Created on Tue May 23 16:10:52 2023

@author: chengxf2
"""

import torch
import torch.nn as nn


def BN():
    
     x = torch.rand(1,16,7,7)
     
     layer = nn.BatchNorm2d(16)
     
     out = layer(x)
     
     print("\n batch 均值",layer.weight)
     print("\n batch 方差",layer.bias)
     
     print("\n input ",vars(layer))
  
     
BN()

上面的

weight : 对应 \gamma

bias: 对应 \beta

在测试的时候,用的是全局的均值,方差,所以要用到eval( 不启动 Batch Normalization 和 Dropout,并且不会保存中间变量、计算图)

 BatchNorm 即批规范化,是为了将每个batch的数据规范化为统一的分布,帮助网络训练, 对输入数据做规范化,称为Covariate shift;

        数据经过一层层网络计算后,数据的分布也在发生着变化,因为每一次参数迭代更新后,上一层网络输出数据,经过这一层网络参数的计算,数据的分布会发生变化,这就为下一层网络的学习带来困难 -- 也就是在每一层都进行批规范化(Internal Covariate shift),方便网络训练,因为神经网络本身就是要学习数据的分布;
 

参考:

高等数学学习笔记——第七十讲——方向导数与梯度_方向导数的几何意义_预见未来to50的博客-CSDN博客

方向导数与梯度_Young__Fan的博客-CSDN博客

Transformer中的归一化(一):什么是归一化&为什么要归一化 - 知乎

Transformer中的归一化(二):机器学习中的特征归一化方法 - 知乎

Transformer中的归一化(三):特征归一化在深度神经网络的作用 - 知乎

Transformer中的归一化(四):BatchNormalization的原理、作用和实现 - 知乎

Batch-Normalization层原理与分析 - 知乎

科学网—Pytorch中nn.Conv1d、Conv2D与BatchNorm1d、BatchNorm2d函数 - 张伟的博文

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/564051.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《开箱元宇宙》爱心熊通过 The Sandbox 与粉丝建立更紧密的联系

你们有没有想过 The Sandbox 如何融入世界上最具标志性的品牌和名人的战略?在本期《开箱元宇宙》系列中,我们与 Cloudco Entertainment 的数字内容顾问 Derek Roberto 聊天,了解为什么爱心熊决定在 The Sandbox 中试验 web3,以及他…

Grpc 整合 Nacos SpringBoot 日常使用(Java版本)包括 Jwt 认证

前言 最近感到有点子迷茫,天天写业务代码有点麻木,趁着有点空闲时间去了解了下 Grpc 这个框架,一方面是听说他很火,支持多种语言。另一方面也是为了将来可能需要用到他,未雨绸缪一下,当然了本文只是基于使用…

Python数据可视化入门教程

什么是数据可视化? 数据可视化是为了使得数据更高效地反应数据情况,便于让读者更高效阅读,通过数据可视化突出数据背后的规律,以此突出数据中的重要因素,如果使用Python做数据可视化,建议学好如下这四个Pyt…

数据可视化是什么?怎么做?看这篇文章就够了

数据可视化是什么 数据可视化主要旨在借助于图形化手段,清晰有效地传达与沟通信息。也就是说可视化的存在是为了帮助我们更好的去传递信息。 我们需要对我们现有的数据进行分析,得出自己的结论,明确要表达的信息和主题(即你通过…

https 建立连接过程

从真实的抓包开始 根据抓包结果可以看到 从客户端发起https 请求开始,主要经过以下几个过程: 1、tcp 三次握手 2、浏览器发送 Client Hello 到服务器 3、服务器对Hello 进行响应 4、服务器发送Server Hello 、证书、证书状态、服务器密钥,到…

【Linux服务】web基础与HTTP协议

web基础与HTTP协议 一、域名概述1.1域名空间结构1.2域名注册 二、网页的概念三、HTML概述3.1HTML超文本标记语言 四、Web概述4.1Web1.0与Web2.04.2静态网页4.3动态网页 五、HTTP协议概述5.1HTTP协议版本5.2http请求方法5.3GET 和 POST 比较5.4HTTP状态码5.5HTTP请求流程 一、域…

无代码开发:让程序员更高效,让非编程人员也能参与

说起无代码开发,可能大多数人的第一反应就是:“我不知道!” 作为一种能快速实现复杂系统的软件开发模式,无代码开发目前还处于推广阶段。但在我们看来,无代码开发是一个很好的尝试,它能让程序员更高效&…

《汇编语言》- 读书笔记 - 第4章-第一个程序

《汇编语言》- 读书笔记 - 第4章-第一个程序 4.1 一个源程序从写出到执行的过程4.2 源程序程序 4.11. 伪指令1.1 segment ends 声明段1.2 end 结束标记1.3 assume 关联 2. 源程序中的“程序”3. 标号4. 程序的结构5. 程序返回6. 语法错误和逻辑错误 4.3 编辑源程序4.4 编译4.5 …

Electron 我与你,今天不谈技术谈感情!

目录 前言一、无知二、初见三、再见四、相遇五、行动总结 前言 今天不谈技术,谈谈我和 Electron 的缘分。可能有人觉得,或许有些人认为,和一个框架谈感情这不是疯了吗?但是,我相信每个开发者都会有同样的经历&#xf…

数字化浪潮下,运维绕不开的需求升级

伴随企业数据中心规模化、复杂度、设备多样性的发展,运维也迎来史无前例的巨大挑战,运维的重要性被推向高点,对运维平台而言无疑是最好的时代,充分利用大数据和人工智能技术融合来解决实际问题,建立数据要素全周期管理…

XSS基础环境及实验演示教程(适合新手)

目录 前言 环境说明: 1、轻量级 Web 服务器 PHP 2、易受XSS攻击的PHP程序 3、非持久性 XSS 攻击 4、窃取会话cookie 5 注入表单窃取密码 前言 花了一点时间,做了一个“XSS基础环境及实验演示教程”,当然教程很简单,适合刚接触和安…

Electron 如何创建模态窗口?

目录 前言一、模态窗口1.Web页面模态框2.Electron中的模态窗口3.区分父子窗口与模态窗口 二、实际案例使用总结 前言 模态框是一种常用的交互元素,无论是在 Web 网站、桌面应用还是移动 APP 中,都有其应用场景。模态框指的是一种弹出窗口,它…

leetcode 1383. Maximum Performance of a Team(团队的最大performance)

n个工程师,长度为n的speed数组和efficiency数组。 每次最多选k个工程师,取出k个对应的speed和efficiency数字。 performancesum(k个speed) ✖ min(k个efficiency) 可以理解为k个人一起干,效率按最慢的人算(一个环节干不完其他人都…

Linux——IO之系统接口+文件描述符详解

IO 文件再次理解系统接口文件操作理解文件描述符 fd 文件再次理解 文件 文件内容 文件属性 其中文件属性也是数据–>即便你创建一个空文件,其也是要占据磁盘攻坚的。 文件操作 文件内容的操作 文件属性的操作 有可能在操作文件的过程中即改变文件的内容&…

Linux---echo命令、反引号`、tail命令、重定向符

1. echo命令 可以使用echo命令在命令行内输出指定内容 语法:echo 输出的内容 无需选项,只有一个参数,表示要输出的内容,复杂内容可以用 ”” 包围 带有空格或 \ 等特殊符号,建议使用双引号包围。 如果不使用双引号…

华为OD机试真题 Java 实现【统计匹配的二元组个数】【2023Q2 200分】

一、题目描述 给定两个数组A和B,若数组A的某个元素A[i]与数组B中的某个元素B[j]满足 A[i] B[j],则寻找到一个值匹配的二元组(i, j)。 请统计在这两个数组A和B中,一共存在多少个这样的二元组。 二、输入描述 第一行输入数组A的长度M&…

复习之[ 查询帮助 ] 和 [ 输入输出管理 ]

1.查询命令用途--whatis # whatis 命令 : 查询命令的用法 -如果结果出现nothing , 有两种情况: (1)查询数据库没有更新,此时输入命令 mandb更新数据库即可。 (2)查询的命令不存在。 2.获得命令的简要帮…

想学渗透,如何入门?

首先 渗透是计算机技术应用的一种,脱离不了基础,您需要学会一门编程语言,任何与计算机相关的都是从学习编程语言开始的,让你对计算机有个初步的认识,将您认识的数字转化为用0和1表示的编码。这个阶段推荐学习Python&a…

​LeetCode解法汇总LCP 33. 蓄水

目录链接: 力扣编程题-解法汇总_分享记录-CSDN博客 GitHub同步刷题项目: https://github.com/September26/java-algorithms 原题链接:力扣 描述: 给定 N 个无限容量且初始均空的水缸,每个水缸配有一个水桶用来打水&…

华芯微特SWM34-IO速度优化

对比测试了一下IO翻转速度在各种函数调用的情况下的差异 CPU运行速度150Mhz,SDRAM开 直接调用翻转函数 while(1) {GPIO_InvBit(GPIOA, PIN0); }速度大约5Mhz,主要是因为函数调用开销和函数内部的移位和异或操作,增加了指令的运行数量。 vo…