深度学习——批量归一化(笔记)

news2024/12/22 21:00:03

主流的卷积网络基本都设计批量归一化这个层

1.为什么要批量归一化?

 ①网络层很深,数据在底层,损失函数在最顶层。反向传播后,顶层的梯度大,所以顶层训练的较快。数据在底层,底层的梯度小,底层训练的慢。(学习率不变)

②数据在最底部,底部层的训练较慢。

Ⅰ底部层一变化,所有的层都要跟着变化

Ⅱ顶部的层需要重新学习很多次

Ⅲ导致收敛变慢

③批量归一化可以在学习底部层的时候避免变化顶部层。

2.批量归一化的思想:

①固定小批量里面的均值和方差(因为方差和均值的分布在不同的层有变化 )

 ②可学习参数:如果均值为0,方差为1的分布不是很适合的话,可以学习一个均值和方差使得对网络更好一些。

 

3.批量归一化层

①可学习的参数γ和β

②作用在:

Ⅰ全连接层和卷积层输出之后,(批量归一化层这里:线性变换,均值方差拉的比较好)激活函数之前

Ⅱ全连接层和卷积层的输入上

③对全连接层,作用在特征维 ( 二维的输入,每一行就是样本,每一列就是特征。全连接层对每一个特征计算标量的均值和方差。不一样的是每一个全连接层的输入和输出都做这件事,不只是做在数据上面。而是重新用学到的γ和β,对方差和均值做校验)

④对卷积层,作用在通道维(1*1卷积等价于全连接层,对于每一个像素有多通道。比如有一个像素对应通道是100,这个像素有100维的向量。向量就是这个像素的特征。所以在输入的时候,每一个像素就是一个样本。

卷积层输入:批量大小*高*宽*通道数  样本数:批量大小*高*宽。所以就是所有的像素当作样本,一个像素对应的所有通道当作特征)

 4.批量归一化的作用是什么

①最初的论文减少内部协变量的转移

②后续论文,通过每个小批量里加入噪音控制模型复杂度

 加入随机偏移和方差,然后通过学到的稳定的均值方差,使得变化不剧烈

③没必要和丢弃法一起

【总结】

①批量归一化固定小批量中的均值和方差,然后学习出适合偏移和缩放

②可以加速收敛速度,学习率可以变大,但一般不改变模型的精度

【代码实现】

import torch
from torch import nn
from d2l import torch as d2l


# X是输入 ,gamma-beta可学习参数,moving_mean,moving_var全局的均值和方差,是在预测的时候使用,eps避免分母出现0,momentum更新移动的平均和方差通常取0.9
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
    if not torch.is_grad_enabled():
        # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:  # 训练模式下
        assert len(X.shape) in (2, 4)  # 2是全连接层(两个维度,batch和全连接层的大小) 4是卷积层(batch,通道数,高,宽)
        if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)  # 按行求出来 对同一列元素求均值
            var = ((X - mean) ** 2).mean(dim=0)  # 方差
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)  # 按通道数求均值
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)  # 按通道数求方差
        # 训练模式下,用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # 缩放和移位
    return Y, moving_mean.data, moving_var.data


class BatchNorm(nn.Module):
    # num_features:完全连接层的输出数量或卷积层的输出通道数。
    # num_dims:2表示完全连接层,4表示卷积层
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 非模型参数的变量初始化为0和1
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.ones(shape)

    def forward(self, X):
        # 如果X不在内存上,将moving_mean和moving_var
        # 复制到X所在显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y


# 应用 LeNet模型
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(16 * 4 * 4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
    nn.Linear(84, 10))

# 训练
lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
d2l.plt.show()

简易实现

import torch
from torch import nn
from d2l import torch as d2l

net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
    nn.Linear(84, 10))

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
d2l.plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/72026.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【开源项目】震惊JDBC查询比MyBatis查询慢

震惊JDBC查询比MyBatis查询快? 文章编写起始原因,在编写项目的时候碰到一个深坑,JDBC获取5000条数据,居然耗时261s,MyBatis同样的操作,耗时12s左右,震惊。看到这里下巴都快掉下来了。不是网上都…

Pyqt5 Key value动态创建 QTreeWidget

在自己的应用上,需要根据读取的 值来创建 目录与子页,并打开对应的界面 实现思路 1、定义数组 存放 {(Key value index ).....(Key_n value_n index_n )} 2、获取相关数据&#x…

【Java开发】 Spring 09 :Spring Data REST 实现并访问简单的超媒体服务

Spring Data REST 是提供一个灵活和可配置的机制来编写可以通过HTTP公开的简单服务,简单来说,而且可以省去大部分controller和services的逻辑,因为Spring Data REST 已经为你都做好了,目前支持JPA、MongoDB、Neo4j、Solr、Cassand…

Ribbon负载均衡

Ribbon负载均衡 Ribbon是微服务架构中,可以作为负载均衡的技术实现,如下图所示 Ribbon负载均衡 1、消费者发起请求2、被负载均衡拦截器拦截3、将请求信息交给RibbonLoadBanlancerClient4、获取url的服务id5、DynamicServerListLoadBalancer拿到id去eur…

java基础巩固-宇宙第一AiYWM:为了维持生计,架构知识+分布式微服务+高并发高可用高性能知识序幕就此拉开(二:网关balabala)~整起

上集,在架构知识分布式微服务高并发高可用高性能知识序幕就此拉开(一:总览篇)中,说到了 当用户请求过来时,这个请求或者说URL先到服务调用端【咱们之前的项目中的Controller其实就算是一个服务调用方&#…

VMware ESXi 8.0 SLIC Unlocker 集成网卡驱动和 NVMe 驱动 (集成驱动版)

发布 ESXi 8.0 集成驱动版,在个人电脑上运行企业级工作负载 请访问原文链接:VMware ESXi 8.0 SLIC & Unlocker 集成网卡驱动和 NVMe 驱动 (集成驱动版),查看最新版。原创作品,转载请保留出处。 作者主页:www.sysi…

【Pytorch】第 3 章 :进行数值估计的蒙特卡洛方法

🔎大家好,我是Sonhhxg_柒,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流🔎 📝个人主页-Sonhhxg_柒的博客_CSDN博客 📃 🎁欢迎各位→点赞…

Java ssh框架 mysql实现的进销存管理系统源码+运行教程+文档

今天给大家演示一下一款由sshmysql实现的进销存管理系统,其中struts版本是struts2,这个系统的功能非常完善,简直可以说是牛逼,到了可以用于企业直接商用的地步,此外该项目还带有完整的论文,是Java学习者及广…

Spark 初识

文章目录Spark 初识Spark是什么Apache Spark演变为什么使用Spark全快Spark组件Spark CoreSpark SQLSpark StreamingSpark MLlibSpark GraphXSparkRpySparkspark 在数仓的应用总结Spark 初识 从今天开始我们进入数据仓库的查询引擎篇了,前面我们已经写了大量的文章介…

三分钟了解LAP编程框架

针对Java开发者的灵魂拷问: 1、梳理的流程,关键逻辑是否有遗漏,理解一致吗? 2、设计时,如何更方便的与产品沟通?原有的设计是否有不合理的?绘制的流程图大家都能理解吗? 3、测试时&a…

316页11万字AI赋能智慧水利大数据信息化平台建设和运营解决方案

第一章 系统综述 1.1 项目背景 1.2 系统概述 1.3 需求分析 1.3.1 中心管控需求 1.3.2 前端监测需求 1.4 建设目标 1.5 设计原则 1.6 设计依据 第二章 系统总体设计 2.1 总体设计思路 2.2 架构设计 2.2.1 逻辑架构 2.2.2 系统架构 2.3 关键技术应用 2.4 系统特色…

代码随想录刷题|LeetCode 647. 回文子串 516.最长回文子序列

647. 回文子串 题目链接:https://leetcode.cn/problems/palindromic-substrings/ 思路 动态规划思路 1、确定dp数组 布尔类型的dp[i][j]:表示区间范围[i,j] (注意是左闭右闭)的子串是否是回文子串,如果是dp[i][j]为tr…

【真的?】用 ChatGPT 写一篇 Python 翻译库博客,可以打 9 分

今天来个大的实践项目,用 ChatGPT 写一篇博客,冲击一下热榜! 从零开始玩 ChatGPT⛳️ ChatGPT 亮点⛳️ 账号篇⛳️ 第一次使用⛳️ 用 Python 实现一个英汉互译的小程序⛳️ googletrans 库核心用法⛳️ 再补充一些知识点⛳️ googletrans 和…

功率放大电路和电压放大电路的区别是什么意思

功率放大电路和电压放大电路都属于模拟电路,是工程师日常经常用到的比较常见的模拟电路,很多小白工程师对于功率放大电路和电压放大电路的区别都很好奇,下面就来看看区别有哪些。 图:功率放大电路与电压放大电路对比 1、功能和基本…

docker之网络配置

目录一、网络模式1.bridge模式(默认模式)2.host模式3.初识网络模式二、bridge模式三、host模式四、自定义网络一、网络模式 Docker在创建容器时有四种网络模式:bridge/host/container/none,bridge为默认不需要用–net去指定,其他三种模式需要…

微服务框架 SpringCloud微服务架构 19 文档操作 19.2 修改文档

微服务框架 【SpringCloudRabbitMQDockerRedis搜索分布式,系统详解springcloud微服务技术栈课程|黑马程序员Java微服务】 SpringCloud微服务架构 文章目录微服务框架SpringCloud微服务架构19 文档操作19.2 修改文档19.2.1 修改文档19.2.2 总结19 文档操作 19.2 修…

推荐一款超级好用的工具:uTools详解使用

介绍 uTools 是什么?下载并安装uTools 能做什么?一切皆插件超级面板 uTools 是什么? uTools 是一个极简、插件化、跨平台的现代桌面软件。通过自由选配丰富的插件,打造你得心应手的工具集合。 通过快捷键(默认 alt…

红队隧道应用篇之CS正反向连接突破内网(二)

正向连接 环境拓扑图 操作步骤 在CS客户端新建一个TCP协议的监听, 监听端口为4444 创建无状态木马(Windows Executable(S)), 选择上述建立的TCP监听器, 随后将无状态木马放到不出网的内网主机中去运行, 运行后内网主机就会监听本机的4444端口 在web服务器的beacon命令行输入:…

EMQX安装与使用

EMQX文档:https://www.emqx.io/docs/zh/v5.0/ 1.安装 https://www.emqx.io/zh/downloads 找到自己合适的平台和版本安装 ①:后台启动 emqx start启动成功后可以使用 emqx ping 命令检测节点运行状态,返回 pong 则表示正常运行: …

pmp 证书到底有什么用处?

PMP 证书最重要的两个用处:一个是岗位招聘要求,一个是项目招标要求。 一、PMP证书的应用 1、PMP 证书的敲门砖作用 前面说的,PMP 作为项目管理领域的一个权威公认证书,很多行业要求项目管理岗位人才都会加一条"具备PMP 等证…