批量规范化

news2025/1/12 23:00:11

  ✨✨✨

感谢优秀的你打开了小白的文章

希望在看文章的你今天又进步了一点点,生活更加美好!🌈🌈🌈

 目录

1.批量规范化基本原理 

2.批量规范化的使用

2.1对于全连接层 

2.2对于卷积层

3.代码实现

3.1方式一 

3.2方式二


1.批量规范化基本原理 

        仅仅对原始输入数据进行标准化是不充分的,因为虽然这种做法可以保证原始输入数据的质量,但它却无法保证隐藏层输入数据的质量。浅层参数的微弱变化经过多层线性变换与激活函数后被放大,改变了每一层的输入分布,造成深层的网络需要不断调整以适应这些分布变化,最终导致模型难以训练收敛。 

        神经网络的发展正朝着越来越大,越来越深的地方发展,更深层的网络很复杂,容易过拟合。 这意味着正则化变得更加重要。批量规范化应用于单个可选层(也可以应用到所有层)批量规范化是一种流行且有效的技术,可持续加速深层网络的收敛速度,但一般不改变模型精度。批量规范化应用于单个可选层(也可以应用到所有层),其原理如下:在每次训练迭代中,我们首先规范化输入,即通过减去其均值并除以其标准差,其中两者均基于当前小批量处理。

        如果我们尝试使用大小为1的小批量应用批量规范化,我们将无法学到任何东西。 这是因为在减去均值之后,每个隐藏单元将为0。 所以,只有使用足够大的小批量,批量规范化这种方法才是有效且稳定的。在应用批量规范化时,批量大小的选择可能比没有批量规范化时更重要。

2.批量规范化的使用

        批量规范化和其他层之间的一个关键区别是,由于批量规范化在完整的小批量上运行,因此我们不能像以前在引入其他层时那样忽略批量大小。 一般会比较考虑与全连接层和卷积层的关系,它们的批量规范化实现略有不同。

2.1对于全连接层 

2.2对于卷积层

        同样,对于卷积层,我们可以在卷积层之后和非线性激活函数之前应用批量规范化。 当卷积有多个输出通道时,我们需要对这些通道的“每个”输出执行批量规范化,每个通道都有自己的拉伸(scale)和偏移(shift)参数,这两个参数都是标量。 假设我们的小批量包含m个样本,并且对于每个通道,卷积的输出具有高度p和宽度q。 那么对于卷积层,我们在每个输出通道的m*p*q个元素上同时执行每个批量规范化。 因此,在计算平均值和方差时,我们会收集所有空间位置的值,然后在给定通道内应用相同的均值和方差,以便在每个空间位置对值进行规范化。 

3.代码实现

3.1方式一 

实现一个具有张量的批量规范化层 

import torch
from torch import nn
from d2l import torch as d2l


def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
    if not torch.is_grad_enabled():
        # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # 训练模式下,用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # 缩放和移位
    return Y, moving_mean.data, moving_var.data

拉伸gamma和偏移beta,这两个参数将在训练过程中更新。 将保存均值和方差的移动平均值,以便在模型预测期间随后使用。 

接下来将自定义层:

class BatchNorm(nn.Module):
    # num_features:完全连接层的输出数量或卷积层的输出通道数。
    # num_dims:2表示完全连接层,4表示卷积层
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 非模型参数的变量初始化为0和1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        # 如果X不在内存上,将moving_mean和moving_var
        # 复制到X所在显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y

将其使用在LeNet中:

net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
    nn.Linear(84, 10))

进行训练,并得到训练结果 

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

3.2方式二

也可以直接使用深度学习框架中定义的BatchNorm 

net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
    nn.Linear(84, 10))
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/720546.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

26488-24-4,Cyclo(D-Phe-L-Pro),具有良好的生物相容性

资料编辑|陕西新研博美生物科技有限公司小编MISSwu​ 【产品描述】 Cyclo(D-Phe-L-Pro)环(D-苯丙氨酸-L-脯氨酸),环二肽是由两个氨基酸通过肽键环合形成,是自然界中小的环状肽。由于其存在两个酰胺键即四个可以形成氢键的位点,环二肽可以在氢…

商业海外社交媒体营销10步指南 [2023]

如今,社交媒体是任何成功商业战略的重要组成部分。这不是奢侈品,而是必需品。全球有超过 36 亿人使用社交媒体,它是企业展示其产品和服务、建立品牌知名度以及与客户联系的数字游乐场。 但这不仅仅是娱乐和游戏。要在社交媒体上取得成功&…

Golang每日一练(leetDay0114) 矩阵中的最长递增路径、按要求补齐数组

目录 329. 矩阵中的最长递增路径 Longest Increasing Path In A Matrix 🌟🌟 330. 按要求补齐数组 Patching Array 🌟🌟 🌟 每日一练刷题专栏 🌟 Rust每日一练 专栏 Golang每日一练 专栏 Python每日…

数据结构--二叉树的性质

数据结构–二叉树的性质 二叉树常考性质 常见考点1: 设非空二叉树中度为0、1和2的结点个数分别为 n 0 、 n 1 和 n 2 ,则 n 0 n 2 1 n_0、n_1和n_2,则n_0 n_2 1 n0​、n1​和n2​,则n0​n2​1 n 0 n 2 1 \color{red}n_0 n_2 1 n0​n2​…

图层中大型数据集的分块处理思路

图层中大型数据集的分块处理思路 为改善要素叠加工具(如联合和相交)的性能和可伸缩性,软件采用了称为自适应细分处理的运算逻辑。当可用的物理内存不足以对数据进行处理时,就会触发系统使用此逻辑。由于保持在物理内存的可用范围…

助力企业完成等保2.0的重要工具

在当今数字化时代,企业面临着越来越多的网络安全威胁和数据泄露风险。为了保护敏感信息和维护业务的连续性,许多企业正在积极采取措施来实施等保2.0标准。在这一过程中,EventLog Analyzer作为一种全面的安全信息与事件管理解决方案&#xff0…

swagger2word使用(将swagger2转化为word)

开源项目地址 https://github.com/JMCuixy/swagger2word 项目使用 1、项目拉下来以后先修改application.xml配置文件红框内容,将其修改为要换为自己swagger文档的地址 2、运行项目后输入地址http://127.0.0.1:8080/toWord 即可下载word文档

结构体和数据结构--共用体

共用体,也称联合(Union),是将不同类型的数据组织在一起共同占用同一段内存的一种构造数据类型。共用体与结构体的类型声明方法类似,只是关键字变为了Union。 例题:演示共用体所占内存字节数的计算方法 #i…

如何用手机制作3D人物模型素材

3D人物模型素材是现代3D游戏和电影制作中必不可少的一部分。它们是数字艺术家和设计师们用来创造逼真世界的关键。3D人物模型素材是用计算机程序制作的虚拟人物,可以被用于电影、电视、游戏和虚拟现实应用中。它们可以被用来代替实际演员,也可以被用来创…

小程序蓝牙通信

蓝牙通信能力封装 一开始是根据uniapp提供的蓝牙api写的蓝牙方法,之后发现复用性,以及一些状态的监听存在缺陷,之后整理成了类。这样复用性以及状态监听的问题就解决了。 蓝牙组件 创建蓝牙组件的类 单例模式是为了保证蓝牙长连接&#xff0…

前端(一)——前端开发遇到的普遍问题以及解决策略

😄博主:小猫娃来啦 😄文章核心:前端开发遇到的普遍问题以及解决策略 前端十万个为什么? 有人说vue框架是基于mvvm实现的?这种说法对吗? mvc和mvvm的区别是什么? mvvm是否是mvc的升…

内容文本生成二维码用excel表格导出(java)

内容文本生成二维码用excel表格导出(java) //若有问题可留言 效果如下: import java.io.ByteArrayOutputStream; import java.io.FileOutputStream; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map;import org.apache.po…

大厂股权就是这么“坑”,150万股票到账前被优化,损失惨重

某网友发文称:“自己还有47天就可以解锁股权,到时就有150万的股票到账,结果接到公司裁员通知,实在是淌血,我能反抗吗?” 对这我只能说,公司卡的就是这个点。所以大家在找工作的时候,…

SparkJDBC性能优化指南

前言 本文以Mysql为例。Spark作为一种强大且广泛应用于大数据处理的分布式计算框架,有着出色的性能和可伸缩性。在使用Spark处理大规模数据时,往往需要与关系型数据库MySQL进行交互。然而,由于MySQL和Spark本身的特性之间存在一些差异,直接使用Spark读写MySQL的默认配置可…

SQL 查找重复的电子邮箱

SQL 182 查找重复的电子邮箱 SQL架构 表: Person -------------------- | Column Name | Type | -------------------- | id | int | | email | varchar | -------------------- id 是该表的主键列。 此表的每一行都包含一封电子邮件。电子邮件不包含大写字母。 编写一个 SQ…

线性DP-入门篇

目录 数字三角形: 最长上升子序列: 魔族密码: 编辑距离: 线性动态规划的主要特点是状态转移的推导是按照问题规模 从小到大依次推导,较大规模的问题的解依赖较小规模的问题的解。 数字三角形: [USA…

大模型是什么

在计算机领域,大模型’是一个近年来备受关注的词汇。这篇文章旨在带你遨游大模型的世界,了解它们的特点、优缺点,以及需如何有效地利用它们。我们还会探讨一些具体的大模型实例,并分析其对人类社会的影响。 首先,我们…

Android Studio实现内容丰富的安卓博客发布平台

如需源码可以添加q-------3290510686,也有演示视频演示具体功能,源码不免费,尊重创作,尊重劳动。 项目编号078 1.开发环境 android stuido jdk1.8 eclipse mysql tomcat 2.功能介绍 安卓端: 1.注册登录 2.查看博客列表…

@项目经理:写好简历其实只要2步,保证你offer拿到手软!

早上好,我是老原。 混职场,最重要的是什么?还是能赚到钱。 有人说,重要的是开心。这么说吧,我身边那些赚得多的,没几个不开心的。 很多人赚不到钱,归结为自己能力差,不够努力。 …

年度好用的8款AI绘画工具,第1款一定要看

本文总结了8款2023年年度好用的AI绘画工具,它们结合了最新的技术和创新的设计理念,能帮助设计师将创意变为创作,一起来看看吧! 1.即时AI灵感 即时AI灵感作为一款国产的AI绘图工具,采用了先进的自然语言处理和图像生成…