动手学深度学习(Pytorch版)代码实践 -卷积神经网络-28批量规范化

news2024/11/26 2:56:17

28批量规范化

"""可持续加速深层网络的收敛速度"""
import torch
from torch import nn
import liliPytorch as lp
import matplotlib.pyplot as plt

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """实现一个具有张量的批量规范化层。"""
    # 如果启用了梯度计算,torch.is_grad_enabled() 返回 True;否则返回 False。
    if not torch.is_grad_enabled():
        # torch.no_grad() 是一个上下文管理器,用于临时禁用梯度计算
        # torch.enable_grad() 是一个上下文管理器,用于在禁用梯度计算的上下文中重新启用梯度计算。
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0) # 计算张量 X 沿着第 0 维的平均值
            # 维度 0 代表样本数量,即沿着每个特征计算所有样本的平均值。
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # 训练模式下,用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)

        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    # gamma 和 beta 的更新是通过反向传播和优化器自动完成的
    Y = gamma * X_hat + beta # 缩放和移位
    return Y, moving_mean.data, moving_var.data

class BatchNorm(nn.Module):
    # num_features:完全连接层的输出数量或卷积层的输出通道数。
    # num_dims:2表示完全连接层,4表示卷积层
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        
        # 非模型参数的变量初始化为0和1
        # 经过归一化处理后的数据均值接近于零。因此,将滑动均值初始化为0,是对数据初始均值的一种合理假设。
        self.moving_mean = torch.zeros(shape)
        # 方差表示数据的离散程度。将滑动方差初始化为1,意味着假设数据的初始方差为1,
        # 即数据分布接近标准正态分布。这样初始化可以避免初始阶段的数值不稳定。
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        # 如果X不在内存上,将moving_mean和moving_var
        # 复制到X所在GPU上                              
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        
         # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y

#使用批量规范化层的 LeNet
net = nn.Sequential(
    nn.Conv2d(1, 6,  kernel_size=5, padding=2), # 卷积层1:输入通道数1,输出通道数6,卷积核大小5x5,填充2
    BatchNorm(num_features=6, num_dims=4),
    nn.ReLU(), # 激活函数
    nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层1:池化窗口大小2x2,步幅2

    nn.Conv2d(6, 16, kernel_size=5), # 卷积层2:输入通道数6,输出通道数16,卷积核大小5x5
    BatchNorm(num_features=16, num_dims=4),
    nn.ReLU(), 
    nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层2:池化窗口大小2x2,步幅2

    nn.Flatten(), # 展平层:将多维输入展平为1维
    nn.Linear(16 * 5 * 5, 120), # 全连接层1:输入节点数16*5*5,输出节点数120
    BatchNorm(num_features=120, num_dims=2),
    nn.ReLU(),
    nn.Linear(120, 84), # 全连接层2:输入节点数120,输出节点数84
    BatchNorm(num_features=84, num_dims=2),
    nn.ReLU(), 
    nn.Linear(84, 10) # 全连接层3:输入节点数84,输出节点数10(对应10个分类)
)

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = lp.loda_data_fashion_mnist(batch_size)
# lp.train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu())
# plt.show()

# loss 0.200, train acc 0.925, test acc 0.812
# 34957.3 examples/sec on cuda:0

# loss 0.189, train acc 0.928, test acc 0.894
# 33471.2 examples/sec on cuda:0


#简明实现
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(256, 120), nn.BatchNorm1d(120), nn.ReLU(),
    nn.Linear(120, 84), nn.BatchNorm1d(84), nn.ReLU(),
    nn.Linear(84, 10)
)
lp.train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu())
plt.show()

# nn.Sigmoid()
# loss 0.263, train acc 0.902, test acc 0.833
# 46935.0 examples/sec on cuda:0

# nn.ReLU()
# loss 0.224, train acc 0.914, test acc 0.874
# 44479.2 examples/sec on cuda:0
"""
通常高级API变体运行速度快得多,因为它的代码已编译为C++或CUDA,而我们的自定义代码由Python实现。
"""

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1852360.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ultralytics官方更新 | 添加YOLOv10到ultralytics

💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡 专栏目录:《YOLOv8改进有效涨点》专栏介绍 & 专栏目录 | 目前已有40篇内容,内含各种Head检测头、损失函数Loss、…

【idea-jdk1.8】使用Spring Initializr 创建 Spring Boot项目没有JDK8

信息差真可怕! 很久没创建springboot项目,今天使用idea的Spring Initializr 创建 Spring Boot项目时,发现java版本里,无法选择jdk1.8,只有17、21、22;前段时间也听说过,springboot将放弃java8&a…

【深度学习驱动流体力学】计算流体力学openfoam-paraview与python3交互

目的1:配置 ParaView 中的 Python Shell 和 Python 交互环境 ParaView 提供了强大的 Python 接口,允许用户通过 Python 脚本来控制和操作其可视化功能。在 ParaView 中,可以通过 View > Python Shell 菜单打开 Python Shell 窗口,用于执行 Python 代码。要确保正确配置 …

GitLab配置免密登录之后仍然需要Git登录的解决办法

GitLab配置免密登录之后仍然需要Git登录的解决办法 因为实习工作需要,要在本地拉取gitlab上的代码,设置了密钥之后连接的时候还需要登录的token,摸索之后有了下面的解决办法。 方法一: 根据报错的提示,去网站上设置个人…

加速鸿蒙生态共建,蚂蚁mPaaS助力鸿蒙原生应用开发创新

6月21日-23日,2024华为开发者大会(HDC 2024)如期举行。在22日的【鸿蒙生态伙伴SDK】分论坛中,正式发布了【鸿蒙生态伙伴SDK市场】,其中蚂蚁数科旗下移动开发平台mPaaS(以下简称:蚂蚁mPaaS&#…

How to use ModelSim

How to use ModelSim These are all written by a robot Remember, you can only simulate tb files.

SD卡无法读取?原因分析与数据恢复策略

一、SD卡无法读取的困境 SD卡作为便携式的存储介质,广泛应用于手机、相机、平板等多种电子设备中。然而,在使用过程中,我们可能会遭遇SD卡无法读取的困扰。当我们将SD卡插入设备时,设备无法识别SD卡,或者虽然识别了SD…

学习使用js和jquery修改css路径,实现html页面主题切换功能

学习使用js和jquery修改css路径&#xff0c;实现html页面主题切换功能 效果图html代码jquery切换css关键代码js切换css关键代码 效果图 html代码 <!DOCTYPE html> <html> <head><meta charset"utf-8"><title>修改css路径</title&g…

通俗解释resultType和resultMap的区别

【 1 对于单表而言&#xff1a; 注&#xff1a;以下都是摘抄过来的&#xff0c;做了让自己更能理解的版本 如果数据库返回结果的列名和要封装的实体的属性名完全一致的话用 resultType 属性 如果数据库返回结果的列名&#xff08;起了别名&#xff09;和要封装的实体的属性名…

PHP题目

一.编写函数change($str)实现字符串转换功能&#xff0c;例如“str_replace”转换成“str%replace”、“arr_var”转换成“arr%var”。 <?php function change($str){$astr_replace(_,%,$str);return $a; } echo change(str_replace); ?> 运行结果&#xff1a; 二.通…

spring-依赖注入DI

Setter注入&#xff1a; 1、引用类型&#xff1a;在bean中定义引用类型属性并提供可访问的set方法&#xff0c;配置中使用property标签ref属性注入引用类型对象&#xff1b; 2、简单类型&#xff1a;在bean中定义引用类型属性并提供可访问的set方法&#xff0c;在配置中使用pr…

NavicatforMySQL11.0软件下载-NavicatMySQL11最新版下载附件详细安装步骤

我们必须承认Navicat for MySQL 支援 Unicode&#xff0c;以及本地或远程 MySQL 服务器多连线&#xff0c;使用者可浏览数据库、建立和删除数据库、编辑数据、建立或执行 SQL queries、管理使用者权限&#xff08;安全设定&#xff09;、将数据库备份/复原、汇入/汇出数据&…

md5在ida中的识别

ida中 识别md5 ,先右键转为hex 或者按h _DWORD *__fastcall MD5Init(_DWORD *result) {*result 0;result[1] 0;result[2] 1732584193;result[3] -271733879;result[4] -1732584194;result[5] 271733878;return result; }在ida中当然也可以使用搜索 search imdate-value …

分布式系统的演化(单机架构/应用符合和存储服务分离架构/应用服务集群架构/主从分离架构/冷热分离架构)

文章目录 单机架构应用服务和存储服务分离应用服务集群架构读写分离/主从分离架构冷热分离架构--引入缓存分库分表 单机架构 单机架构只有一台服务器&#xff0c;使用一台服务器负责所有的工作 举个例子&#xff1a;假设有以下电商网站&#xff0c;商品、用户、交易等功能服务…

实验六:三维图形修改器的综合应用

如果文章有写的不准确或需要改进的地方&#xff0c;还请各位大佬不吝赐教&#x1f49e;&#x1f49e;&#x1f49e;。朱七在此先感谢大家了。&#x1f618;&#x1f618;&#x1f618; &#x1f3e0;个人主页&#xff1a;语雀个人知识库 &#x1f9d1;个人简介&#xff1a;大家…

【STM32-启动文件 startup_stm32f103xe.s】

STM32-启动文件 startup_stm32f103xe.s ■ STM32-启动文件■ STM32-启动文件主要做了以下工作&#xff1a;■ STM32-启动文件指令■ STM32-启动文件代码详解■ 栈空间的开辟■ 栈空间大小 Stack_Size■ .map 文件的详细介绍■ 打开map文件 ■ 堆空间■ PRESERVE8 和 THUMB 指令…

OCC显示渲染结构剖析

1.Display显示 2.Drawer 3.Graphics 4.InteractiveContext 5.Render 6.Selection 7.View

探索计算机视觉(人工智能重要分支)的发展与应用

引言 在当今快速发展的科技时代&#xff0c;计算机视觉作为人工智能领域的重要分支&#xff0c;正日益成为各行各业不可或缺的关键技术。从简单的图像处理到复杂的智能系统&#xff0c;计算机视觉的发展不仅改变了我们看待世界的方式&#xff0c;也深刻影响着工业、医疗、交通等…

数据结构与算法引入(Python)

华子目录 引入第一次尝试第二次尝试 算法的概念算法的五大特性 算法效率衡量执行时间单靠时间值绝对可信吗&#xff1f; 时间复杂度与 "大O记法"如何理解 “大O记法” 最坏时间复杂度时间复杂度的几条基本计算规则 算法分析常见的时间复杂度常见时间复杂度之间的关系…