《动手学深度学习 Pytorch版》 7.5 批量规范化

news2025/1/12 12:09:54

7.5.1 训练深层网络

训练神经网络的实际问题:

  • 数据预处理的方式会对最终结果产生巨大影响。

  • 训练时,多层感知机的中间层变量可能具有更广的变化范围。

  • 更深层的网络很复杂容易过拟合。

批量规范化对小批量的大小有要求,只有批量大小足够大时批量规范化才是有效的。

x ∈ B \boldsymbol{x}\in B xB 表示一个来自小批量 B B B 的输入;$\hat{\boldsymbol{\mu}}_B $ 表示小批量 B B B 的样本均值; σ ^ B \hat{\boldsymbol{\sigma}}_B σ^B 表示小批量 B B B 的样本标准差;批量规范化 BN 根据以下表达式转换 x \boldsymbol{x} x

B N ( x ) = γ ⊙ x + μ ^ B σ ^ B + β BN(\boldsymbol{x})=\gamma\odot\frac{\boldsymbol{x}+\hat{\boldsymbol{\mu}}_B}{\hat{\boldsymbol{\sigma}}_B}+\beta BN(x)=γσ^Bx+μ^B+β

应用标准化后生成的小批量的均值为 0,单位方差为 1。此外,其中还包含与 x \boldsymbol{x} x 形状相同的拉伸参数 γ \gamma γ 和偏移参数 β \beta β。需要注意的是, γ \gamma γ β \beta β 是需要与其他模型一起参与学习的参数。

从形式上看,可以计算出上式中的 $\hat{\boldsymbol{\mu}}_B $ 和 σ ^ B \hat{\boldsymbol{\sigma}}_B σ^B

μ ^ B = 1 ∣ B ∣ ∑ x ∈ B x σ ^ B = 1 ∣ B ∣ ∑ x ∈ B ( x − μ ^ B ) 2 + ϵ \begin{align} \hat{\boldsymbol{\mu}}_B &= \frac{1}{\left|B\right|}\sum_{\boldsymbol{x}\in B}\boldsymbol{x}\\ \hat{\boldsymbol{\sigma}}_B &= \frac{1}{\left|B\right|}\sum_{\boldsymbol{x}\in B}(\boldsymbol{x}-\hat{\boldsymbol{\mu}}_B)^2+\epsilon \end{align} μ^Bσ^B=B1xBx=B1xB(xμ^B)2+ϵ

式中添加的大于零的常量 ϵ \epsilon ϵ 可以保证不会发生除数为零的错误。

7.5.2 批量规范化层

全连接层和卷积层需要两种略有不同的批量规范化策略:

  • 全连接层

    通常,我们将批量规范化层置于全连接层中的仿射变换和激活函数之间。 设全连接层的输入为 x x x,权重参数和偏置参数分别为 W \boldsymbol{W} W b b b,激活函数为 ϕ \phi ϕ,批量规范化的运算符为 B N BN BN。那么,使用批量规范化的全连接层的输出的计算详情如下:

    h = ϕ ( B N ( W x + b ) ) \boldsymbol{h}=\phi(BN(\boldsymbol{W}x+b)) h=ϕ(BN(Wx+b))

  • 卷积层

    对于卷积层,可以在卷积层之后和非线性激活函数之前应用批量规范化。而且需要对多个输出通道中的每个输出执行批量规范化,每个通道都有自己的标量参数:拉伸和偏移参数。

  • 预测过程中的批量规范化

    批量规范化在训练模式和预测模式下的行为通常不同。

7.5.3 从零实现

import torch
from torch import nn
from d2l import torch as d2l
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    if not torch.is_grad_enabled():  # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:  # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)  # 按行求均值
            var = ((X - mean) ** 2).mean(dim=0)  # 按行求方差
        else:  # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            mean = X.mean(dim=(0, 2, 3), keepdim=True)  # 保持X的形状(即第1维,输出通道数)以便后面可以做广播运算,结果的形状是1*n*1*1
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        X_hat = (X - mean) / torch.sqrt(var + eps)  # 训练模式下,用当前的均值和方差做标准化
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # 缩放和移位
    return Y, moving_mean.data, moving_var.data
class BatchNorm(nn.Module):
    # num_features:完全连接层的输出数量或卷积层的输出通道数。
    # num_dims:2表示完全连接层,4表示卷积层
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 非模型参数的变量初始化为0和1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        # 如果X不在内存上,将moving_mean和moving_var
        # 复制到X所在显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y

7.5.4 使用批量规范化层的 LeNet

net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
    nn.Linear(84, 10))

学习率拉的好大。

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.262, train acc 0.902, test acc 0.879
20495.7 examples/sec on cuda:0

在这里插入图片描述

7.5.5 简明实现

net1 = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
    nn.Linear(84, 10))
d2l.train_ch6(net1, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.263, train acc 0.903, test acc 0.870
36208.4 examples/sec on cuda:0

在这里插入图片描述

7.5.6 争议

这个东西就是玄学,有效但是不知大为什么有效。作者给出的解释是“减少内部协变量偏移”,但是也是处于直觉而不是证明。

练习

(1)在使用批量规范化之前,我们是否可以从全连接层或者卷积层中删除偏置函数?为什么?

我认为可以,偏置会在减去均值时消去,此外,BN 中也是带偏移参数的。


(2)比较 LeNet 在使用和不使用批量规范化情况下的学习率。

a. 绘制训练和测试精准度的提高。

b. 学习率有多高?

学习率相同的话,使用批量规范化的收敛速度会非常快。


(3)我们是否需要在每个层中进行批量规范化?尝试一下?

net2 = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(256, 120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Linear(84, 10))

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net2, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.349, train acc 0.871, test acc 0.856
37741.0 examples/sec on cuda:0

在这里插入图片描述

去掉后面两个之后曲线稳多了。


(4)可以通过批量规范化来替换暂退法吗?行为会如何改变?

看来还是批量规范化好些

net3 = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(256, 120), nn.Sigmoid(),
    nn.Dropout(p=0.1),
    nn.Linear(120, 84), nn.Sigmoid(),
    nn.Dropout(p=0.1),
    nn.Linear(84, 10))

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net3, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.541, train acc 0.790, test acc 0.748
40642.6 examples/sec on cuda:0

在这里插入图片描述


(5) 确定参数 gamma 和 beta,并观察和分析结果。

net[1].gamma.reshape((-1,)), net[1].beta.reshape((-1,))
(tensor([3.1800, 1.6709, 4.0375, 3.4801, 2.6182, 2.3103], device='cuda:0',
        grad_fn=<ReshapeAliasBackward0>),
 tensor([ 3.5415,  1.6295,  1.8926, -1.5510, -2.4556,  1.1020], device='cuda:0',
        grad_fn=<ReshapeAliasBackward0>))

(6)查看高级 API 中关于 BatchNorm 的在线文档,以了解其他批量规范化的应用。

略。


(7)研究思路:可以应用的其他“规范化”变换有哪些,可以应用概率积分变换吗,全秩协方差估计呢?

略。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1032312.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CTF--攻防世界--杂项基础

多做几道基础题就会发现这东西真的跟智障题一样&#xff0c;开始嘲笑当初的自己了。 就当是学习笔记了【根据题库给的顺序&#xff0c;随便写几道】 用那个隐写工具一步完事 这一题我是真的理解了杂项的概念&#xff0c;这玩意是真杂啊&#xff0c;死活都没想到居然还有这种题…

网络编程day05(IO多路复用)

今日任务&#xff1a; TCP多路复用的客户端、服务端&#xff1a; 服务端代码&#xff1a; #include <stdio.h> #include <sys/types.h> #include <sys/socket.h> #include <arpa/inet.h> #include <netinet/in.h> #include <unistd.h> …

若依使用及源码解析(前后端分离版)

部署环境 JDK > 1.8 MYSQL > 5.7 Maven > 3.0 Node > 12 Redis > 3 运行若依项目 下载若依源码 若依官网 若依项目源码(前后端分离) 运行后端项目 ruoyi-ui就是vue项目&#xff08;这里使用vscode打开&#xff09; 整体用idea打开 1.配置数据库(sq…

千兆以太网传输层 UDP 协议原理与 FPGA 实现

文章目录 前言心得体会一、UDP 协议介绍二、UDP 数据报格式三、UDP 数据发送测试四、Verilog实现UDP 数据发送1、IP 头部检验 IPchecksun 的计算2、以太网报文的校验字段 FCS 的计算3、以太网报文发送模块实现五、以太网数据发送测试六、仿真代码七、仿真波形展示八、上板测试九…

婚礼策划展示小程序制作全程解析

随着互联网的发展&#xff0c;小程序已成为各行各业所钟爱的一种数字化工具。对于婚礼策划师来说&#xff0c;一款专为自己业务打造的小程序能够更好地展示婚礼策划方案&#xff0c;提升服务质量&#xff0c;加强与客户的沟通。以下就是制作婚礼策划展示小程序的全程解析。 一、…

【Linux网络编程】gdb调试技巧

这篇博客主要要记录一下自己在Linux操作系统Ubuntu下使用gbd调试程序的一些指令&#xff0c;以及使用过程中的一些心得。 使用方法 可以使用如下代码 gcc -g test.c -o test 或者 gcc test.c -o test ​ -g的选项最好添加&#xff0c;如果不添加&#xff0c;l指令无法被识别 …

机试(2017 cs se)

2017计算机系夏令营 题解参考&#xff1a; 2017 华东师范计算机系暑期夏令营机考 A. 不等式 Problem #3304 - ECNU Online Judge 有点像贪心算法 选一个刚刚好在条件范围里的b[i]作为候选&#xff0c;【这个“刚刚好”是指选一个符合这个条件的最极限的值】 代码 #in…

力扣669 补9.16

最近大三上四天有早八&#xff0c;真的是受不了了啊&#xff0c;欧嗨呦&#xff0c;早上困如狗&#xff0c;然后&#xff0c;下午困如狗&#xff0c;然后晚上困如狗&#xff0c;尤其我最近在晚上7点到10点这个时间段看力扣&#xff0c;看得我昏昏欲睡&#xff0c;不自觉就睡了1…

关于Python安装Scrapy库的常见报错解决

目录 1、关于pip3命令的报错 2、执行scrapy报错&#xff08;Python3下的OpenSSL模块出错&#xff09; 3、卸载pyopenssl时报错 由于Scrapy该库在Windows下会存在兼容问题&#xff0c;下面介绍的是在Linux系统进行安装。 1、关于pip3命令的报错 报错代码&#xff1a; error…

redis漏洞修复:(CNVD-2019-21763)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、漏洞内容二、镜像准备1.确认镜像版本2.下载镜像 三、配置文件准备1.获取配置文件2.修改配置文件 四、启动redis容器五、修改iptables文件总结 前言 漏扫发…

Java JVM分析利器JProfiler 结合IDEA使用详细教程

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、JProfiler是什么&#xff1f;二、我的环境三、安装步骤1.Idea安装JProfiler插件1.下载程序的安装包 四、启动 前言 对于我们Java程序员而言&#xff0c;肯…

博通强迫三星签不平等长约,被韩处罚1亿元 | 百能云芯

近日&#xff0c;博通&#xff08;Broadcom&#xff09;这家国际知名的半导体公司因其市场主导地位的滥用&#xff0c;遭到了韩国公平贸易委员会&#xff08;FTC&#xff09;的严厉制裁&#xff0c;罚款高达191亿韩元&#xff0c;约合人民币1.04亿元。这一惩罚背后的故事揭示了…

用AVR128单片机的音乐门铃

一、系统方案 1、使用按键控制蜂鸣器模拟发出“叮咚”的门铃声。 2、“叮”声对应声音频率714Hz&#xff0c;“咚”对应声音频率500Hz,这两种频率由ATmega128的定时器生成&#xff0c;定时器使用的工作模式自定&#xff0c;处理器使用内部4M时钟。“叮”声持续时间300ms&#x…

Linux新手教程||Linux vi/vim

所有的 Unix Like 系统都会内建 vi 文书编辑器&#xff0c;其他的文书编辑器则不一定会存在。 但是目前我们使用比较多的是 vim 编辑器。 vim 具有程序编辑的能力&#xff0c;可以主动的以字体颜色辨别语法的正确性&#xff0c;方便程序设计。 什么是 vim&#xff1f; Vim是…

链表oj3(Leetcode)——相交链表;环形链表

一&#xff0c;相交链表 相交链表&#xff08;Leetcode&#xff09; 1.1分析 看到这个我们首先想到的就是一个一个比较他们的值有相等的就是交点&#xff0c;但是如果a1和b2的值就相等呢&#xff1f;所以这个思路不行&#xff0c;第二种就是依次比较链表&#xff0c;但是这…

xdebug3开启profile和trace

【xdebug开启profiler】 https://xdebug.org/docs/profiler http://www.xdebug.org.cn/docs/profiler 1、php.ini添加下面配置然后重启php容器&#xff1a; xdebug.modeprofile ;这个目录保存profile和trace文件 xdebug.output_dir /var/tmp/xdebugPHP日志提示报错&#xff1a…

ultraEdit正则匹配多行(xml用)

在ultraEdit中&#xff0c;我想选取<channel到</channel>之间的多行&#xff08;进行删除&#xff09;。在perl模式下&#xff0c;命令为“<channel[\s\S]?</channel>”。下面是xml文件&#xff1a; <!--This XML file does not appear to have any sty…

版本控制系统git:一文了解git,以及它在生活中的应用,网站维护git代码,图导,自动化部署代码

目录 1.Git是什么 2.git在生活中的应用 2.1git自动化部署代码 3.网站维护git代码 3.1如何在Git代码托管平台等上创建一个仓库 3.2相关文章 4.ruby实现基础git 4.1.Git add 4.2 Git commit 4.3 Git log 1.Git是什么 Git是一个版本控制系统&#xff0c;它可以追踪文件的…

C++11的一些新特性|线程库|包装器|lambda表达式

文章目录 目录 文章目录 一、可变参数模板 1.可变参数模板 2.STL容器中emplace相关函数接口: 二、lambda表达式 1.c98中的一个例子 2.lambda表达式 三、包装器 1.fuction包装器 四、线程库 1.thread类简单介绍 2.并发和并行的区别 3.线程函数参数 4.原子性操作库…

关于Java多线程的那些事

多线程 多线程1. 关于多线程的理解1.1 进程和线程1.2 并行和并发1.3 线程调度 2. 创建多线程的方式创建线程有哪几种方式&#xff1f;2.1 通过继承Thread类来创建并启动线程的步骤如下&#xff1a;2.2 通过实现Runnable接口来创建并启动线程的步骤如下&#xff1a;2.3 通过实现…