动手学深度学习—批量规范化(代码详解)

news2025/1/11 20:06:26

批量规范化

      • 1. 训练深层网络
      • 2. 批量规范化层
        • 2.1 全连接层
        • 2.2 卷积层
      • 3. 从零实现批量规范化层
      • 4. 使用批量规范化层的 LeNet

批量规范化(batch normalization),可持续加速深层网络的收敛速度。

1. 训练深层网络

  1. 数据预处理的方式通常会对最终结果产生巨大影响。
  2. 对于典型的多层感知机或卷积神经网络。训练时,中间层中的变量(例如,多层感知机中的仿射变换输出)可能具有更广的变化范围。
  3. 更深层的网络很复杂,容易过拟合。 这意味着正则化变得更加重要。

批量规范化:

  • 每次训练迭代中,首先规范输入,即减去均值并除以其标准差,其中两者均基于当前小批量处理。
  • 接下来,应用比例系数和偏移系数。
  • 因为是基于批量统计的标准化,才有了批量规范化的名称。
    在这里插入图片描述
    γ和β是需要与其他模型参数一起学习的参数
    均值和方差如下图公式所示
    在这里插入图片描述

2. 批量规范化层

全连接层和卷积层的批量规范化实现略有不同。

2.1 全连接层

对于全连接层,将批量规范化层置于全连接层中的仿射变换和激活函数之间。
在这里插入图片描述

2.2 卷积层

对于卷积层,在卷积层之后和非线性激活函数之前应用批量规范化。

假设我们的小批量包含m个样本,并且对于每个通道,卷积的输出具有高度p和宽度q。
那么对于卷积层,我们在每个输出通道的mpq个元素上同时执行每个批量规范化。

3. 从零实现批量规范化层

从头开始实现一个具有张量的批量规范化层。

import torch
from torch import nn
from d2l import torch as d2l


def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # 通过is_grad_enabled方法来判断当前模式是训练模式还是预测模式
    if not torch.is_grad_enabled():
        # eps->方差估计值添加一个小的常量ε>0,以确保永远不会尝试除以0
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 卷积层:(batch_size, in_channels, height, weight)
            # 使用卷积层的情况,计算通道维上(axis=1)的均值和方差
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # 训练模式下,用当前的均值和标准差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta # 缩放和移位
    return Y, moving_mean.data, moving_var.data
class BatchNorm(nn.Module):
    # num_features:全连接层的输出数量或卷积层的输出通道数
    # num_dims:2表示全连接层,4表示卷积层
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # 参与求梯度和迭代的拉伸参数和偏移参数,其分别初始化成1和0
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 非模型参数的变量初始化为0和1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)
        
    def forward(self, X):
        # 如果X不在内存上,将moving_mean和moving_var
        # 复制到X所在的显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
            # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean, 
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y

4. 使用批量规范化层的 LeNet

在LeNet模型上使用BatchNorm。

# 将其应用于LeNet模型
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
    nn.Linear(84, 10))

定义精度评估函数

"""
    定义精度评估函数:
    1、将数据集复制到显存中
    2、通过调用accuracy计算数据集的精度
"""
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save
    # 判断net是否属于torch.nn.Module类
    if isinstance(net, nn.Module):
        net.eval()
        
        # 如果不在参数选定的设备,将其传输到设备中
        if not device:
            device = next(iter(net.parameters())).device
    
    # Accumulator是累加器,定义两个变量:正确预测的数量,总预测的数量。
    metric = d2l.Accumulator(2)
    with torch.no_grad():
        for X, y in data_iter:
            # 将X, y复制到设备中
            if isinstance(X, list):
                # BERT微调所需的(之后将介绍)
                X = [x.to(device) for x in X]
            else:
                X = X.to(device)
            y = y.to(device)
            
            # 计算正确预测的数量,总预测的数量,并存储到metric中
            metric.add(d2l.accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

定义GPU训练函数

"""
    定义GPU训练函数:
    1、为了使用gpu,首先需要将每一小批量数据移动到指定的设备(例如GPU)上;
    2、使用Xavier随机初始化模型参数;
    3、使用交叉熵损失函数和小批量随机梯度下降。
"""
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练模型(在第六章定义)"""
    # 定义初始化参数,对线性层和卷积层生效
    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    
    # 在设备device上进行训练
    print('training on', device)
    net.to(device)
    
    # 优化器:随机梯度下降
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    
    # 损失函数:交叉熵损失函数
    loss = nn.CrossEntropyLoss()
    
    # Animator为绘图函数
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                            legend=['train loss', 'train acc', 'test acc'])
    
    # 调用Timer函数统计时间
    timer, num_batches = d2l.Timer(), len(train_iter)
    
    for epoch in range(num_epochs):
        
        # Accumulator(3)定义3个变量:损失值,正确预测的数量,总预测的数量
        metric = d2l.Accumulator(3)
        net.train()
        
        # enumerate() 函数用于将一个可遍历的数据对象
        for i, (X, y) in enumerate(train_iter):
            timer.start() # 进行计时
            optimizer.zero_grad() # 梯度清零
            X, y = X.to(device), y.to(device) # 将特征和标签转移到device
            y_hat = net(X)
            l = loss(y_hat, y) # 交叉熵损失
            l.backward() # 进行梯度传递返回
            optimizer.step()
            with torch.no_grad():
                # 统计损失、预测正确数和样本数
                metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
            timer.stop() # 计时结束
            train_l = metric[0] / metric[2] # 计算损失
            train_acc = metric[1] / metric[2] # 计算精度
            
            # 进行绘图
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (train_l, train_acc, None))
                
        # 测试精度
        test_acc = evaluate_accuracy_gpu(net, test_iter) 
        animator.add(epoch + 1, (None, None, test_acc))
        
    # 输出损失值、训练精度、测试精度
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f},'
          f'test acc {test_acc:.3f}')
    
    # 设备的计算能力
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec'
          f'on {str(device)}')

在这里插入图片描述

训练LeNet模型

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述
在模型训练过程中,批量规范化利用小批量的均值和标准差,不断调整神经网络的中间输出,使整个神经网络各层的中间输出值更加稳定。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1132458.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Stable Diffusion AI绘图

提示词: masterpiece, best quality, 1girl, (anime), (manga), (2D), half body, perfect eyes, both eyes are the same, Global illumination, soft light, dream light, digital painting, extremely detailed CGI anime, hd, 2k, 4k background 反向提示词&…

微机原理:汇编指令集——调用传送指令、算术运算指令、转移类指令(详解)

文章目录 一、通用传送类指令1、数据传送指令2、堆栈操作指令 二、算术运算指令1、总图2、加减运算指令2.1 例子2.2 INC/DEC指令 3、比较指令 三、转移类指令1、无条件转移2、有条件转移2.1 无符号数条件转移指令2.2 有符号数条件转移指令2.3 例题一2.4 循环控制指令&#xff0…

【golang】Go中的切片slice和操作笔记,垃圾回收机制,重组 reslice ,复制和追加,内存结构

切片 文章目录 切片将切片传递给函数make() 创建一个切片new() 和 make()的区别多维切片bytes包for-range切片重组 reslice切片的复制和追加 字符串、数组和切片的应用获取字符串的某一部分字符串和切片的内存结构修改字符串中的某个字符字节数组对比函数搜索及排序切片和数组a…

非侵入式负荷检测与分解:电力数据挖掘新视角

电力数据挖掘 概述案例背景分析目标分析过程数据准备数据探索缺失值处理 属性构造设备数据周波数据模型训练 性能度量推荐阅读 主页传送门:📀 传送 概述 摘要:本案例将根据已收集到的电力数据,深度挖掘各电力设备的电流、电压和功…

全网最全面最深入 剖析华为“五看三定”战略神器中的“五看”(即市场洞察)(长文干货,建议收藏)

添加图片注释,不超过 140 字(可选) (本文摘自谢宁专著《华为战略管理法:DSTE实战体系》,欢迎购买) 兵法有云:胜兵先胜而后求战,败兵先战而后求胜,所谓胜兵先…

对被测软件来说,需要多少测试就足够了?

相信每位测试人员或者测试团队都曾遇到这样的问题“需要多少测试才能确保软件成功发布”。这个答案很难回答,在很大程度上,这取决于被测软件的类型、用途和目标受众。所有的测试人员都希望用一种比测试手电筒的应用程序更严格的方法来测试其他软件。然而…

JavaScript异步编程:提升性能与用户体验

目录 什么是异步编程? 回调函数 Promise Async/Await 总结 在Web开发中,处理耗时操作是一项重要的任务。如果我们在执行这些操作时阻塞了主线程,会导致页面失去响应,用户体验下降。JavaScript异步编程则可以解决这个问题&…

睿趣科技:抖音开网店多久回本

随着互联网的发展,越来越多的人选择在抖音上开设网店。然而,开店容易,经营难。许多人关心的问题是:抖音开网店多久能回本? 首先,我们需要明确一点,抖音开网店的回本时间并不是固定的,它受到许多…

经典卷积神经网络 - NIN

网络中的网络,NIN。 AlexNet和VGG都是先由卷积层构成的模块充分抽取空间特征,再由全连接层构成的模块来输出分类结果。但是其中的全连接层的参数量过于巨大,因此NiN提出用1*1卷积代替全连接层,串联多个由卷积层和“全连接”层构成…

C语言:杨氏矩阵、杨氏三角、单身狗1与单身狗2

下面介绍四道题目和解法 1.杨氏矩阵 算法:右上角计算 题目:有一个数字矩阵,矩阵的每行从左到右是递增的,矩阵从上到下是递增的,请编写程序在这样的矩阵中查找某个数字是否存在。 要求:时间复杂度小于O(N…

react笔记基础部分(组件生命周期路由)

注意点&#xff1a; class是一个关键字&#xff0c; 类。 所以react 写class, 用classname &#xff0c;会自动编译替换class 点击方法&#xff1a; <button onClick {this.sendData}>给父元素传值</button>常用的插件&#xff1a; 需要引入才能使用的&#xf…

ubuntu执行普通用户或root用户执行apt-get update时报错Couldn‘t create temporary file /tmp/...

apt-get update无法更新&#xff0c;报错&#xff1a; Couldnt create temporary file /tmp/apt.conf.GSzv74 for passing config to&#xff0c;&#xff0c;&#xff0c; 这是由于/tmp目录没有权限导致的&#xff0c;解决办法&#xff1a; chmod 777 /tmp

额定电压输出电流:电源性能测试指标之一

额定电压和额定电流是电源设计生产时需要考虑的两个重要参数&#xff0c;额定电压是电源输出的电压标准&#xff0c;额定电流是电源能够提供的最大电流容量。这两个参数是评估电源性能的重要指标之一&#xff0c;指导着电气设备的正常工作运行。 额定电压输出电流测试方法 额定…

上门家政维修多城市代理多商户师傅入驻小程序开源版开发

上门家政维修多城市代理多商户师傅入驻小程序开源版开发 用户登录/注册&#xff1a;用户可以使用手机号或第三方账号登录或注册小程序。 服务分类&#xff1a;在主页上显示不同的服务分类&#xff0c;例如电器维修、家具拆装、管道疏通、清洁保洁等。 城市选择&#xff1a;用…

C++反转链表递归

文章目录 题目描述解题思路代码复杂度分析 题目描述 LCR 024. 反转链表 - 力扣&#xff08;LeetCode&#xff09; 给定单链表的头节点 head &#xff0c;请反转链表&#xff0c;并返回反转后的链表的头节点。 解题思路 这里我们采用递归的思路来解决首先我们分为两个视角来查看…

竞赛选题 深度学习卫星遥感图像检测与识别 -opencv python 目标检测

文章目录 0 前言1 课题背景2 实现效果3 Yolov5算法4 数据处理和训练5 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; **深度学习卫星遥感图像检测与识别 ** 该项目较为新颖&#xff0c;适合作为竞赛课题方向&#xff0c;学长非常推荐…

超好用的数据可视化工具推荐,小白也适用!

Excel、Tableau……可以做数据可视化的工具不少&#xff0c;但简单、好用又高效&#xff0c;甚至连无SQL基础的小白也能轻松使用的就真没几个。奥威BI数据可视化工具是少有的操作难度低、成本支出低、灵活自助分析能力强的BI工具。 1、操作难度低 奥威BI数据可视化工具的操作…

图片放大镜效果

安装&#xff1a; vueuse 插件 npm i vueuse/core 搜索&#xff1a; useMouseInElement 方法 <template><div ref"target"><h1>Hello world</h1></div> </template><script> import { ref } from vue import { useM…

图纸管理制度《三》

一、目的和使用范围 为了更好的规范设备及设计图纸的保管、发放和使用&#xff0c;根据业主仅提供四套图纸的实际情况&#xff0c;本着施工图纸服务施工的第一原则&#xff0c;合理利用有限的图纸资源&#xff0c;将《管理制度汇编》中的图纸管理制度进行细化&#xff0c;制定本…

视频与png图片批量分类技巧:轻松管理文件

在我们的日常工作中&#xff0c;经常会遇到需要处理大量文件的情况&#xff0c;其中就包括视频和png图片。这些文件数量繁多&#xff0c;如果一个个手动分类&#xff0c;不仅耗时而且容易出错。因此&#xff0c;掌握批量分类技巧成为了高效管理文件的关键。本文将为您运用云炫文…