28 批量归一化【李沐动手学深度学习v2课程笔记】(备注:这一节讲的很迷惑,很乱)

news2025/1/10 22:25:56

目录

1.批量归一化

1.1训练神经网络时出现的挑战

1.2核心思想

1.3原理

2.批量规范化层

2.1 全连接层

2.2 卷积层

2.3 总结

3. 代码实现

4. 使用批量规范化层的LeNet

5. 简明实现


1.批量归一化

现在主流的卷积神经网络几乎都使用了批量归一化
批量归一化是一种流行且有效的技术,它可以持续加速深层网络的收敛速度

1.1训练神经网络时出现的挑战

1、数据预处理的方式通常会对最终结果产生巨大影响

使用真实数据时,第一步是标准化输入特征(使其均值为0,方差为1),这种标准化可以很好地与优化器配合使用(可以将参数的量级进行统一)


2、对于典型的多层感知机或卷积神经网络,在训练时中间层中的变量可能具有更广的变化范围

不论是沿着从输入到输出的层、跨同一层中的单元、或是随着时间的推移,模型参数的随着训练更新变化莫测
归一化假设变量分布中的不规则的偏移可能会阻碍网络的收敛


3、更深层的网络很复杂,容易过拟合

这就意味着正则化变得更加重要 作者:如果我是泡橘子 

当神经网络比较深的时候会发现:数据在下面,损失函数在上面,这样会出现什么问题?

正向传递的时候,数据是从下往上一步一步往上传递;反向传递的时候,数据是从上面往下传递,这时候就会出现问题:梯度在上面的时候比较大,越到下面就越容易变小(因为是n个很小的数进行相乘,越到后面结果就越小,也就是说越靠近数据的,层的梯度就越小)


上面的梯度比较大,那么每次更新的时候上面的层就会不断地更新;但是下面层因为梯度比较小,所以对权重地更新就比较少,这样的话就会导致上面的收敛比较快,而下面的收敛比较慢,这样就会导致底层靠近数据的内容(网络所尝试抽取的网络底层的特征:简单的局部边缘、纹理等信息)变化比较慢,上层靠近损失的内容(高层语义信息)收敛比较快,所以每一次底层发生变化,所有的层都得跟着变(底层的信息发生变化就导致上层的权重全部白学了),这样就会导致模型的收敛比较慢。


所以提出了假设:能不能在改变底部信息的时候,避免顶部不断的重新训练?(这也是批量归一化所考虑的问题) 

1.2核心思想

为什么会变?因为方差和均值整个分布会在不同层之间变化

所以假设将分布固定,假设每一层的输出、梯度都符合某一分布,相对来说就是比较稳定的(具体分布可以做细微的调整,但是整体保持基本一致,这样的话,在学习细微的变动时也比较容易)

批量归一化:将不同层的不同位置的小批量(mini-batch)输出的均值和方差固定,均值和方差的计算方法如下图所示

在这个基础上做额外的调整,如下图所示 

1.3原理

2.批量规范化层

回想一下,批量规范化和其他层之间的一个关键区别是,由于批量规范化在完整的小批量上运行,因此我们不能像以前在引入其他层时那样忽略批量大小。 我们在下面讨论这两种情况:全连接层和卷积层,他们的批量规范化实现略有不同。

2.1 全连接层

2.2 卷积层

2.3 总结

3. 代码实现

下面,我们从头开始实现一个具有张量的批量规范化层

import torch
from torch import nn
from d2l import torch as d2l


def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
    if not torch.is_grad_enabled():
        # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # 训练模式下,用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # 缩放和移位
    return Y, moving_mean.data, moving_var.data

我们现在可以创建一个正确的BatchNorm层。 这个层将保持适当的参数:拉伸gamma和偏移beta,这两个参数将在训练过程中更新。 此外,我们的层将保存均值和方差的移动平均值,以便在模型预测期间随后使用。

撇开算法细节,注意我们实现层的基础设计模式。 通常情况下,我们用一个单独的函数定义其数学原理,比如说batch_norm。 然后,我们将此功能集成到一个自定义层中,其代码主要处理数据移动到训练设备(如GPU)、分配和初始化任何必需的变量、跟踪移动平均线(此处为均值和方差)等问题。 为了方便起见,我们并不担心在这里自动推断输入形状,因此我们需要指定整个特征的数量。 不用担心,深度学习框架中的批量规范化API将为我们解决上述问题,我们稍后将展示这一点。

class BatchNorm(nn.Module):
    # num_features:完全连接层的输出数量或卷积层的输出通道数。
    # num_dims:2表示完全连接层,4表示卷积层
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 非模型参数的变量初始化为0和1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        # 如果X不在内存上,将moving_mean和moving_var
        # 复制到X所在显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y

4. 使用批量规范化层的LeNet

为了更好理解如何应用BatchNorm,下面我们将其应用于LeNet模型( 6.6节)。 回想一下,批量规范化是在卷积层或全连接层之后、相应的激活函数之前应用的。

net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
    nn.Linear(84, 10))

和以前一样,我们将在Fashion-MNIST数据集上训练网络。 这个代码与我们第一次训练LeNet( 6.6节)时几乎完全相同,主要区别在于学习率大得多。

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

让我们来看看从第一个批量规范化层中学到的拉伸参数gamma和偏移参数beta

net[1].gamma.reshape((-1,)), net[1].beta.reshape((-1,))
(tensor([0.4863, 2.8573, 2.3190, 4.3188, 3.8588, 1.7942], device='cuda:0',
        grad_fn=<ReshapeAliasBackward0>),
 tensor([-0.0124,  1.4839, -1.7753,  2.3564, -3.8801, -2.1589], device='cuda:0',
        grad_fn=<ReshapeAliasBackward0>))

5. 简明实现

除了使用我们刚刚定义的BatchNorm,我们也可以直接使用深度学习框架中定义的BatchNorm。 该代码看起来几乎与我们上面的代码相同。

net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
    nn.Linear(84, 10))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1511838.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

面向对象【static关键字】

文章目录 Java中的static关键字1. 静态变量2. 静态方法的特点3. 静态块4. 静态导入5. 单例模式中的应用 Java中的static关键字 在Java中&#xff0c;static是一个关键字&#xff0c;用于定义类级别的成员&#xff0c;这些成员与类的实例无关。static成员属于类而不是类的实例&…

怎么查看电脑是不是固态硬盘?简单几个步骤判断

随着科技的发展&#xff0c;固态硬盘&#xff08;Solid State Drive&#xff0c;简称SSD&#xff09;已成为现代电脑的标配。相较于传统的机械硬盘&#xff0c;固态硬盘在读写速度、稳定性和耐用性等方面都有显著优势。但是&#xff0c;对于不熟悉电脑硬件的用户来说&#xff0…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的水下目标检测系统(深度学习模型+UI界面+训练数据集)

摘要&#xff1a;本研究详述了一种采用深度学习技术的水下目标检测系统&#xff0c;该系统集成了最新的YOLOv8算法&#xff0c;并与YOLOv7、YOLOv6、YOLOv5等早期算法进行了性能评估对比。该系统能够在各种媒介——包括图像、视频文件、实时视频流及批量文件中——准确地识别水…

唯众物联网+地理科学交付云南师范大学地理学部教学实验室项目

近日&#xff0c;云南师范大学地理学部教学实验室建设项目顺利交付。该项目的成功落地&#xff0c;标志着物联网技术与地理科学教育的深度融合&#xff0c;为云南师范大学的地理教学提供了全新的教学平台与资源。该项目以物联网技术为核心&#xff0c;结合地理科学的特点&#…

新一代实时数据集成框架 Flink CDC 3.0 —— 核心技术架构解析

本文整理自阿里云开源大数据平台吕宴全关于新一代实时数据集成框架 Flink CDC 3.0 的核心技术架构解析&#xff0c;内容主要分为以下四部分&#xff1a; Flink CDC 演进历程 Flink CDC 3.0 的架构设计 Flink CDC 3.0 的核心实现 未来规划 一、Flink CDC 演进历程 Flink CD…

Web——HTML

一.HTML概述 超文本标记语言&#xff08;英语&#xff1a;HyperText Markup Language&#xff0c;简称&#xff1a;HTML&#xff09;是一种用于创建网页的标准标记语言。可以使用 HTML 来建立自己的 WEB 站点&#xff0c;HTML 运行在浏览器上&#xff0c;由浏览器来解析。 二.…

相机模型Omnidirectional Camera(全方位摄像机)

1. 背景 大多数商用相机都可以描述为针孔相机&#xff0c;通过透视投影进行建模。然而&#xff0c;有些投影系统的几何结构无法使用传统针孔模型来描述&#xff0c;因为成像设备引入了非常高的失真。其中一些系统就是全方位摄像机。 有几种方法可以制作全向相机。屈光照相机(D…

【Linux】yum及vim

目录 在Linux中安装软件&#xff08;以centos7为例&#xff09; Linux中的安装方式 安装的理解 什么是软件包&#xff1f; yum的一般使用流程 Linux开发工具 vi/vim 什么是vi/vim&#xff1f; vim的三大主要模式&#xff08;实际上还有很多模式&#xff0c;此篇不做介绍…

一款好用的AI工具——边界AICHAT(三)

目录 3.23、文档生成PPT演示3.24、AI文档翻译3.25、AI翻译3.26、论文模式3.27、文章批改3.28、文章纠正3.29、写作助手3.30、文言文翻译3.31、日报周报月报生成器3.32、OCR-DOC办公文档识别3.33、AI真人语音合成3.34、录音音频总结3.35、域方模型市场3.36、模型创建3.37、社区交…

焦点调制网络

摘要 https://arxiv.org/pdf/2203.11926.pdf 我们提出了焦点调制网络&#xff08;简称FocalNets&#xff09;&#xff0c;其中自注意力&#xff08;SA&#xff09;被焦点调制模块完全取代&#xff0c;用于在视觉中建模令牌交互。焦点调制包含三个组件&#xff1a;&#xff08;…

使用html+css制作一个发光立方体特效

使用htmlcss制作一个发光立方体特效 <!DOCTYPE html> <html lang"zh-CN"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0" /><title>Documen…

SpringController返回值和异常自动包装

今天遇到一个需求&#xff0c;在不改动原系统代码的情况下。将Controller的返回值和异常包装到一个统一的返回对象中去。 例如原系统的接口 public String myIp(ApiIgnore HttpServletRequest request);返回的只是一个IP字符串"0:0:0:0:0:0:0:1"&#xff0c;目前接口…

第五篇【传奇开心果系列】Python的自动化办公库技术点案例示例:深度解读Pandas在教育数据和研究数据处理领域的应用

传奇开心果博文系列 系列博文目录Python的自动化办公库技术点案例示例系列 博文目录前言一、Pandas 在教育和学术研究中的常见应用介绍二、数据清洗和预处理示例代码三、数据分析和统计示例代码四、数据可视化示例代码五、时间序列分析示例代码六、数据导入和导出示例代码七、数…

【数据挖掘】练习1:R入门

课后作业1&#xff1a;R入门 一&#xff1a;习题内容 1.要与R交互必须安装Rstudio&#xff0c;这种说法对不对&#xff1f; 不对。虽然RStudio是一个流行的R交互集成开发环境&#xff0c;但并不是与R交互的唯一方式。 与R交互可以采用以下几种方法&#xff1a; 使用R Conso…

Qt/C++音视频开发69-保存监控pcm音频数据到mp4文件/监控录像/录像存储和回放/264/265/aac/pcm等

一、前言 用ffmpeg做音视频保存到mp4文件&#xff0c;都会遇到一个问题&#xff0c;尤其是在视频监控行业&#xff0c;就是监控摄像头设置的音频是PCM/G711A/G711U&#xff0c;解码后对应的格式是pcm_s16be/pcm_alaw/pcm_mulaw&#xff0c;将这个原始的音频流保存到mp4文件是会…

【企业战略转型】某音响制造公司发展战略转型管理咨询项目纪实

案例&#xff1a;【客户评价】日本M汽车音响有限公司田总经理&#xff1a;受经济大环境的影响&#xff0c;我公司原有的依赖企业下订单的业务模式受到很大的影响&#xff0c;企业进入“不进则退”的重要转型阶段。当企业生存的关键因素&#xff0c;我们作为典型的OEM汽车音响代…

unity学习(57)——选择角色界面--删除角色2

1.客户端添加点击按钮所触发的事件&#xff0c;在selectMenu界面中增加myDelete函数&#xff0c;当点击“删除角色”按钮时触发该函数的内容。 public void myDelete() {string message nowPlayer.id;//string m Coding<StringDTO>.encode(message);NetWorkScript.get…

前端之用HTML做一个汇款单

例子 代码 里面注释是我我对运用到的知识的理解 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>工商银行电子汇款单</title> </head> <body><h3>工商银行电子汇款单</…

python疑难杂症(10)---Python函数def的定义分类,包括内置函数、外置函数、匿名函数、闭包函数、生成器函数等

本部分详细讲解Python函数的定义、常见的函数类型&#xff0c;尤其是特色函数包括内置函数、外置函数、匿名函数、闭包函数、生成器函数等以及用法。后续将对这类函数重点讲解使用方法。 函数定义&#xff1a; 函数是大多数编程语言使用的一个概念&#xff0c;函数是一段具有…

题目 2610: 第十二届省赛真题-杨辉三角形

题目描述: 下面的图形是著名的杨辉三角形&#xff1a; 如果我们按从上到下、从左到右的顺序把所有数排成一列&#xff0c;可以得到如下 数列&#xff1a; 1, 1, 1, 1, 2, 1, 1, 3, 3, 1, 1, 4, 6, 4, 1, ... 给定一个正整数 N&#xff0c;请你输出数列中第一次出现 N 是在第几…