《动手学深度学习(PyTorch版)》笔记7.5

news2025/1/12 9:58:15

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过,同时对于书上部分章节也做了整合。

Chapter7 Modern Convolutional Neural Networks

7.5 Batch Normalization

批量规范化应用于单个可选层(也可以应用到所有层),其原理如下:
在每次训练迭代中,首先规范化输入,即通过减去其均值并除以其标准差(两者均基于当前小批量处理)。接下来应用比例系数和比例偏移。请注意,如果使用大小为1的小批量应用批量规范化,将无法学到任何东西,这是因为在减去均值之后,每个隐藏单元将为0,所以只有使用足够大的小批量,批量规范化才有效且稳定。用 x ∈ B \mathbf{x} \in \mathcal{B} xB表示一个来自小批量 B \mathcal{B} B的输入,批量规范化 B N \mathrm{BN} BN根据下式转换 x \mathbf{x} x

B N ( x ) = γ ⊙ x − μ ^ B σ ^ B + β . \mathrm{BN}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \hat{\boldsymbol{\mu}}_\mathcal{B}}{\hat{\boldsymbol{\sigma}}_\mathcal{B}} + \boldsymbol{\beta}. BN(x)=γσ^Bxμ^B+β.

在上式中, μ ^ B \hat{\boldsymbol{\mu}}_\mathcal{B} μ^B是小批量 B \mathcal{B} B的样本均值, σ ^ B \hat{\boldsymbol{\sigma}}_\mathcal{B} σ^B是小批量 B \mathcal{B} B的样本标准差。应用标准化后,生成的小批量的平均值为0和方差为1。由于单位方差是一个主观选择的结果,因此通常包含拉伸参数(scale) γ \boldsymbol{\gamma} γ偏移参数(shift) β \boldsymbol{\beta} β,它们的形状与 x \mathbf{x} x相同,且是需要与其他模型参数一起学习的参数。

μ ^ B = 1 ∣ B ∣ ∑ x ∈ B x , σ ^ B 2 = 1 ∣ B ∣ ∑ x ∈ B ( x − μ ^ B ) 2 + ϵ . \begin{aligned} \hat{\boldsymbol{\mu}}_\mathcal{B} &= \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} \mathbf{x},\\ \hat{\boldsymbol{\sigma}}_\mathcal{B}^2 &= \frac{1}{|\mathcal{B}|} \sum_{\mathbf{x} \in \mathcal{B}} (\mathbf{x} - \hat{\boldsymbol{\mu}}_{\mathcal{B}})^2 + \epsilon.\end{aligned} μ^Bσ^B2=B1xBx,=B1xB(xμ^B)2+ϵ.

我们在方差估计值中添加一个小的常量 ϵ > 0 \epsilon > 0 ϵ>0,以确保永远不会除以零。此外,估计值 μ ^ B \hat{\boldsymbol{\mu}}_\mathcal{B} μ^B σ ^ B {\hat{\boldsymbol{\sigma}}_\mathcal{B}} σ^B还会通过使用关于平均值和方差的噪声(noise)来抵消缩放问题。乍看起来,这种噪声是一个问题,而事实上它是有益的。

7.5.1 Batch Normalize Layer

批量规范化层和其他层之间的一个关键区别是:由于批量规范化在完整的小批量上运行,因此我们不能像以前在引入其他层时那样忽略批量大小。我们在下面讨论全连接层和卷积层的批量规范化实现。

7.5.1.1 Fully Connected Layer

我们通常将批量规范化层置于全连接层中的仿射变换和激活函数之间。设全连接层的输入为x,权重参数和偏置参数分别为 W \mathbf{W} W b \mathbf{b} b,激活函数为 ϕ \phi ϕ,批量规范化的运算符为 B N \mathrm{BN} BN,那么使用批量规范化的全连接层的输出的计算公式如下:
h = ϕ ( B N ( W x + b ) ) . \mathbf{h} = \phi(\mathrm{BN}(\mathbf{W}\mathbf{x} + \mathbf{b}) ). h=ϕ(BN(Wx+b)).

7.5.1.2 Convolutional Layer

对于卷积层,可以在卷积操作之后和非线性激活函数之前应用批量规范化。当卷积有多个输出通道时,我们需要对这些通道的每个输出执行批量规范化。每个通道都有自己的拉伸(scale)和偏移(shift)参数,两者都是标量。假设小批量包含 m m m个样本,并且对于每个通道,卷积的输出具有高度 p p p和宽度 q q q,那么对于卷积层,在每个输出通道的 m ⋅ p ⋅ q m \cdot p \cdot q mpq个元素上同时执行每个批量规范化。因此在计算平均值和方差时,我们会收集所有空间位置的值,然后在给定通道内应用相同的均值和方差,以便在每个空间位置对值进行规范化。

7.5.2 Batch Normalization During Prediction

批量规范化在训练模式和预测模式下的行为和结果通常不同。在训练过程中,我们无法得知使用整个数据集来估计平均值和方差,所以只能根据每个小批次的平均值和方差不断训练模型,一种常用的方法是通过移动平均(具体方法在batch_norm()函数中定义)估算整个训练数据集的样本均值和方差,并在预测时使用它们得到确定的输出。而在将训练好的模型用于预测时,我们可以根据整个数据集精确计算批量规范化所需的平均值和方差,不再需要样本均值中的噪声以及在微批次上估计每个小批次产生的样本方差了。

import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
    if not torch.is_grad_enabled():
        # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # 训练模式下,用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # 缩放和移位
    return Y, moving_mean.data, moving_var.data

class BatchNorm(nn.Module):
    # num_features:完全连接层的输出数量或卷积层的输出通道数。
    # num_dims:2表示完全连接层,4表示卷积层
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        #参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
        #这些参数是nn.Parameter类型,因此会被PyTorch的优化器所管理,梯度更新是自动进行的
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 非模型参数的变量初始化为0和1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        # 如果X不在内存上,就将moving_mean和moving_var复制到X所在显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y
    
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
    nn.Linear(84, 10))

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

print(net[1].gamma.reshape((-1,)), net[1].beta.reshape((-1,)))

#简洁实现
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
    nn.Linear(84, 10))

d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
plt.show()

训练结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1435714.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电商开放API商品采集接口、关键字搜索接口,获取商品ID、商品主图接口

API是application programming interface(应用程序接口)的简称,是一些预先定义的函数,目的是提供应用程序与开发人员基于某软件或硬件的以访问一组例程的能力,而又无需访问源码,或理解内部工作机制的细节。…

2024年【A特种设备相关管理(电梯)】报名考试及A特种设备相关管理(电梯)免费试题

题库来源:安全生产模拟考试一点通公众号小程序 A特种设备相关管理(电梯)报名考试是安全生产模拟考试一点通总题库中生成的一套A特种设备相关管理(电梯)免费试题,安全生产模拟考试一点通上A特种设备相关管理…

版本管理git及其命令介绍-附带详细操作

前言 在版本管理时代之前,人们写软件的方式如下图1所示 图1 无版本管理的代码 其坏处就是软件版本随着时间越来越多,每个版本修改了什么内容,修改了哪些文件,如果没有详细记录也不知道。这样久会导致如果我们想回退到某个版本内…

C语言——联合体类型

📝前言: 在前面两篇文章:C语言——结构体类型(一)和C语言——结构体(二)中,我们讲述了C语言中重要的数据类型之一:结构体类型,今天我们来介绍一下C语言中的另…

BVH动画绑骨蒙皮并在Unity上展示

文章目录 Blender绑定骨骼Blender蒙皮Blender中导入bvh文件将FBX导入Unity Blender绑定骨骼 先左上角红框进入model模式,选中要绑定的模型,然后进入Edit模式把骨骼和关节对齐。 (选中骨骼,G移动,R旋转) 为…

苹果手机如何录屏?这里告诉你答案!

苹果公司的iPhone以其卓越的性能和用户体验受到了全球消费者的喜爱,而录屏功能作为手机的一项重要功能,能够帮助我们记录手机屏幕上的操作,分享游戏技巧、制作教程视频等。本文将为您介绍苹果手机如何录屏,帮助您更好地掌握录屏技…

零售新业态,让老牧区焕发新生命

敦煌老马一声魔性“浇给”勾起了无数人对羊肉的食欲,而当大家集体涌入餐厅或者在网上下单,都想要尝一尝网红同款的时候,可能并没有想过这样一个问题——为什么在今天,即便是远离牧区的现代大城市,草原羊肉却一样能触手…

双向链表的插入、删除、按位置增删改查、栈和队列区别、什么是内存泄漏

2024年2月4日 1.请编程实现双向链表的头插&#xff0c;头删、尾插、尾删 头文件&#xff1a; #ifndef __HEAD_H__ #define __HEAD_H__ #include<stdio.h> #include<stdlib.h> #include<string.h> typedef int datatype; enum{FALSE-1,SUCCSE}; typedef str…

【宝藏系列】嵌入式入门概念大全

【宝藏系列】嵌入式入门概念大全 0️⃣1️⃣操作系统&#xff08;Operating System&#xff0c;OS&#xff09; 是管理计算机硬件与软件资源的系统软件&#xff0c;同时也是计算机系统的内核与基石。操作系统需要处理管理与配置内存、决定系统资源供需的优先次序、控制输入与输…

nvm安装node后,npm无效

类似报这种问题&#xff0c;是因为去github下载npm时下载失败&#xff0c; Please visit https://github.com/npm/cli/releases/tag/v6.14.17 to download npm. 第一种方法&#xff1a;需要复制这里面的地址爬梯子去下载&#xff08;github有时不用梯子能直接下载&#xff0c;有…

百面嵌入式专栏(技能篇)嵌入式技能树详解

沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇我们将介绍嵌入式重点知识。 一、C语言 C语言这一块的高频考点有预处理、关键字、数据类型、指针与内存管理。 预处理有文件包含、宏定义、条件编译,其中最重要的是宏定义,通常考核宏定义的语法、宏替换与函数的区…

在每个地方都应该添加 memo 吗?

文章概叙 本文主要讲的是React中memo的使用&#xff0c;以及考虑是否使用memo的判断依据 memo介绍 memo 允许你的组件在 props 没有改变的情况下跳过重新渲染。 在使用memo将组件包装起来之后&#xff0c;我们可以‍获得该组件的一个 记忆化 版本。通常情况下&#xff0c;只要…

14.0 Zookeeper环球锁实现原理

全局锁是控制全局系统之间同步访问共享资源的一种方式。 下面介绍zookeeper如何实现全民锁&#xff0c;讲解他锁和共享锁两类全民锁。 排他锁 排他锁&#xff08;Exclusive Locks&#xff09;&#xff0c;又被称为写锁或独占锁&#xff0c;如果事务T1对数据对象O1加上排他锁…

可解释性AI(XAI)的主要实现方法和研究方向

文章目录 每日一句正能量前言主要实现方法可解释模型模型可解释技术 未来研究方向后记 每日一句正能量 当你还不能对自己说今天学到了什么东西时&#xff0c;你就不要去睡觉。 前言 随着人工智能的迅速发展&#xff0c;越来越多的决策和任务交给了AI系统来完成。然而&#xff…

问题:下列关于海关统计项目的表述,正确的有:A.进出境货物的统计重量和数量应以报关单位申报的重量和数 #笔记#职场发展#媒体

问题&#xff1a;下列关于海关统计项目的表述&#xff0c;正确的有&#xff1a;A&#xff0e;进出境货物的统计重量和数量应以报关单位申报的重量和数 下列关于海关统计项目的表述&#xff0c;正确的有&#xff1a; A&#xff0e;进出境货物的统计重量和数量应以报关单位申报的…

【初识爬虫+requests模块】

爬虫又称网络蜘蛛、网络机器人。本质就是程序模拟人使用浏览器访问网站&#xff0c;并将需要的数据抓取下来。爬虫不仅能够使用在搜索引擎领域&#xff0c;在数据分析、商业领域都得到了大规模的应用。 URL 每一个URL指向一个资源&#xff0c;可以是一个html页面&#xff0c;一…

Pandas教程11:关于pd.DataFrame.shift(1)数据下移的示例用法

---------------pandas数据分析集合--------------- Python教程71&#xff1a;学习Pandas中一维数组Series Python教程74&#xff1a;Pandas中DataFrame数据创建方法及缺失值与重复值处理 Pandas数据化分析&#xff0c;DataFrame行列索引数据的选取&#xff0c;增加&#xff0c…

Oracle Vagrant Box 扩展根文件系统

需求 默认的Oracle Database 19c Vagrant Box的磁盘为34GB。 最近在做数据库升级实验&#xff0c;加之导入AWR dump数据&#xff0c;导致空间不够。 因此需要对磁盘进行扩容。 扩容方法1&#xff1a;预先扩容 此方法参考文档Vagrant, how to specify the disk size?。 指…

Postman发送带登录信息的请求

环境&#xff1a;win10Postman10.17.7 假设有个请求是这样的&#xff1a; RequiresPermissions("tool:add") PostMapping(value"/predict") ResponseBody /** * xxx * param seqOrderJson json格式的参数 * return */ public String predictSampleIds(Req…

蓝桥杯省赛无忧 课件125 线段树-标记永久化

前置知识 线段树 01 什么是标记永久化 02 如何标记永久化 03 标记永久化的优势