神经网络数据的批量归一化—BN

news2024/9/9 4:47:55

文章目录

  • 1、简介
  • 2、批量归一化公式
  • 3、BN 层的接口
  • 4、代码示例
  • 5、小结

🍃作者介绍:双非本科大三网络工程专业在读,阿里云专家博主,专注于Java领域学习,擅长web应用开发、数据结构和算法,初步涉猎人工智能和前端开发。
🦅个人主页:@逐梦苍穹
📕所属专栏:人工智能
🌻gitee地址:xzl的人工智能代码仓库
✈ 您的一键三连,是我创作的最大动力🌹

1、简介

在神经网络的搭建过程中,Batch Normalization (批量归一化)是经常使用一个网络层,其主要的作用是控制数据的分布,加快网络的收敛。
我们知道,神经网络的学习其实在学习数据的分布,随着网络的深度增加、网络复杂度增加,一般流经网络的数据都是一个 mini batch,每个 mini batch 之间的数据分布变化非常剧烈,这就使得网络参数频繁的进行大的调整以适应流经网络的不同分布的数据,给模型训练带来非常大的不稳定性,使得模型难以收敛。
如果我们对每一个 mini batch 的数据进行标准化之后,数据分布就变得稳定,参数的梯度变化也变得稳定,有助于加快模型的收敛。

2、批量归一化公式

f ( x ) = λ ⋅ x − E ( x ) Var ( x ) + ϵ + β f(x) = \lambda \cdot \frac{x - \mathbb{E}(x)}{\sqrt{\text{Var}(x) + \epsilon}} + \beta f(x)=λVar(x)+ϵ xE(x)+β

  1. λ 和 β 是可学习的参数,它相当于对标准化后的值做了一个线性变换,λ 为系数,β 为偏置;
  2. ϵ \epsilon ϵ 通常指为 1e-5,避免分母为 0;
  3. E(x) 表示变量的均值;
  4. Var(x) 表示变量的方差;

BN层是指“批量归一化层”(Batch Normalization Layer),它是在神经网络中用来进行批量归一化操作的层。
批量归一化层的主要目的是通过归一化每一层的输入,使得每一层的输入分布更加稳定,从而加速训练过程并提高模型性能。

数据在经过 BN 层之后,无论数据以前的分布是什么,都会被归一化成均值为 β,标准差为 γ 的分布。
注意:

  1. BN 层不会改变输入数据的维度,只改变输入数据的的分布
  2. 在实际使用过程中,BN 常常和卷积神经网络结合使用,卷积层的输出结果后接 BN 层。

3、BN 层的接口

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True)
  1. 由于每次使用的 mini batch 的数据集,所以 BN 使用移动加权平均来近似计算均值和方差,而 momentum 参数则调节移动加权平均值的计算;
  2. affine = False 表示 γ=1,β=0,反之,则表示 γ 和 β 要进行学习;
  3. BatchNorm2d 适用于输入的数据为 4D,输入数据的形状 [N,C,H,W]

其中:N 表示批次,C 代表通道数,H 代表高度,W 代表宽度
由于每次输入到网络中的时小批量的样本,我们使用指数加权平均来近似表示整体的样本的均值和方法,其更新公式如下:

running_mean = momentum * running_mean + (1.0 – momentum) * batch_mean
running_var = momentum * running_var + (1.0 – momentum) * batch_var

上面的式子中,batch_mean 和 batch_var 表示当前批次的均值和方差。而 running_mean 和 running_var 是近似的整体的均值和方差的表示。当我们进行评估时,可以使用该均值和方差对输入数据进行归一化。‘

4、代码示例

(代码即注释)

# -*- coding: utf-8 -*-
# @Author: CSDN@逐梦苍穹
# @Time: 2024/7/29 17:11

import torch  # 导入PyTorch库
import torch.nn as nn  # 导入PyTorch中的神经网络模块
import torch.optim as optim  # 导入PyTorch中的优化器模块
import torch.nn.functional as F  # 导入PyTorch中的函数模块

# 设置随机种子以确保结果可重复
torch.manual_seed(0)


# 定义一个简单的卷积神经网络模型,使用批量归一化
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # TODO 定义卷积层1,输入通道数为3,输出通道数为16,卷积核大小为3x3,步长为1,填充为1
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)  
        self.bn1 = nn.BatchNorm2d(16)  # 对卷积层1进行批量归一化
        # TODO 定义卷积层2,输入通道数为16,输出通道数为32,卷积核大小为3x3,步长为1,填充为1
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)  
        self.bn2 = nn.BatchNorm2d(32)  # 对卷积层2进行批量归一化
        self.fc1 = nn.Linear(32 * 8 * 8, 10)  # 定义全连接层,输入尺寸为32*8*8,输出尺寸为10

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))  # 卷积层1后进行批量归一化,然后通过ReLU激活函数
        x = F.max_pool2d(x, 2)  # 最大池化层,池化核大小为2x2
        x = F.relu(self.bn2(self.conv2(x)))  # 卷积层2后进行批量归一化,然后通过ReLU激活函数
        x = F.max_pool2d(x, 2)  # 最大池化层,池化核大小为2x2
        x = x.view(x.size(0), -1)  # 将x展平为一维
        x = self.fc1(x)  # 通过全连接层
        return x


# 模拟训练数据
batch_size = 16  # 批量大小为16
num_channels = 3  # 通道数为3
height, width = 32, 32  # 图像高度和宽度为32
num_classes = 10  # 类别数为10

# 创建一个简单的数据集
x = torch.randn(batch_size, num_channels, height, width)  # 生成随机输入张量,形状为(batch_size, num_channels, height, width)
y = torch.randint(0, num_classes, (batch_size,))  # 生成随机标签,形状为(batch_size,)

# 初始化模型、损失函数和优化器
model = SimpleCNN()  # 实例化模型
criterion = nn.CrossEntropyLoss()  # 定义交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # 定义优化器,使用随机梯度下降法,学习率为0.01,动量为0.9

# 前向传播
outputs = model(x)  # 计算模型输出
loss = criterion(outputs, y)  # 计算损失

# 反向传播和优化
optimizer.zero_grad()  # 清零梯度
loss.backward()  # 反向传播计算梯度
optimizer.step()  # 更新模型参数

# 打印损失和批量归一化层的运行均值和方差
print(f'Loss: {loss.item()}')  # 打印损失值
print(f'Running mean of first BN layer: {model.bn1.running_mean}')  # 打印第一层批量归一化的运行均值
print(f'Running var of first BN layer: {model.bn1.running_var}')  # 打印第一层批量归一化的运行方差
print(f'Running mean of second BN layer: {model.bn2.running_mean}')  # 打印第二层批量归一化的运行均值
print(f'Running var of second BN layer: {model.bn2.running_var}')  # 打印第二层批量归一化的运行方差

运行结果:

"C:\Program Files\Python39\python.exe" D:\Python\AI\神经网络\10-批量归一化.py 
Loss: 2.573316812515259
Running mean of first BN layer: tensor([-0.0056,  0.0143, -0.0037, -0.0068,  0.0059,  0.0054,  0.0151,  0.0005,
        -0.0021,  0.0096, -0.0178, -0.0070, -0.0003, -0.0143,  0.0096, -0.0145])
Running var of first BN layer: tensor([0.9206, 0.9367, 0.9336, 0.9363, 0.9488, 0.9425, 0.9182, 0.9328, 0.9261,
        0.9360, 0.9302, 0.9286, 0.9404, 0.9280, 0.9289, 0.9302])
Running mean of second BN layer: tensor([ 0.0248,  0.0364, -0.0138,  0.0471, -0.0746,  0.0464,  0.0405,  0.0621,
         0.0327, -0.0391, -0.0227,  0.0368, -0.0107,  0.0006,  0.0374, -0.0079,
        -0.0669,  0.1178, -0.0184,  0.1392, -0.0039, -0.0067,  0.1190, -0.0364,
         0.0417, -0.0583, -0.0099,  0.0407, -0.0307,  0.1127, -0.0287, -0.0231])
Running var of second BN layer: tensor([0.9157, 0.9151, 0.9154, 0.9156, 0.9185, 0.9150, 0.9183, 0.9199, 0.9129,
        0.9151, 0.9160, 0.9133, 0.9173, 0.9184, 0.9150, 0.9123, 0.9164, 0.9212,
        0.9128, 0.9205, 0.9133, 0.9138, 0.9194, 0.9156, 0.9149, 0.9158, 0.9163,
        0.9151, 0.9144, 0.9187, 0.9223, 0.9124])

Process finished with exit code 0

5、小结

批量归一化层,该层的作用主要是用来控制每层数据的流动时的均值和方差,防止训练过程出现剧烈的波动,模型难以收敛,或者收敛较慢。
批量归一化层在计算机视觉领域使用较多。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1961313.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Java】字符串String类(011)

目录 ♦️API和API帮助文档 ♦️创建String 🎏直接赋值类 🎏new类 🐡空参类 构造方法: 举例代码: 🐡有参类 构造方法: 举例代码: 🐡字符数组类 构造方法&…

如何借助逻辑数据编织平台实现“数据优先堆栈( DFS )”

一、什么是面向“数据优先”的数据研发平台? 企业在数字化转型的浪潮中,愈发认知到数据作为核心战略资产的重要性。然而,要充分利用数据的价值并非易事。一方面,企业需要投入大量资源来建设和维护复杂的数据基础设施;另…

ref函数

Vue2 中的ref 首先我们回顾一下 Vue2 中的 ref。 ref 被用来给元素或子组件注册引用信息。引用信息将会注册在父组件的 $refs 对象上。如果在普通的 DOM 元素上使用,引用指向的就是 DOM 元素;如果用在子组件上,引用就指向组件实例&#xff1…

计算机基础(day1)

1.什么是内存泄漏?什么是内存溢出?二者有什么区别? 2.了解的操作系统有哪些? Windows,Unix,Linux,Mac 3. 什么是局域网,广域网? 4.10M 兆宽带是什么意思?理论…

【2024最新华为OD-C/D卷试题汇总】[支持在线评测] 整数数组按个位数字排序(100分) - 三语言AC题解(Python/Java/Cpp)

🍭 大家好这里是清隆Coding ,一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为OD-C/D卷的三语言AC题解 👏 感谢大家的订阅➕ 和 喜欢💗 🍿 最新华为OD机试D卷目录,全、新、准,题目覆盖率达 95% 以上,支持题目在线评测,专栏文章质量平均 93 分 最新华为OD机试目录…

使用大型语言模型进行文档解析

动机 多年来,正则表达式一直是我解析文档的首选工具,我相信对于许多技术人员和行业也是如此。尽管正则表达式在某些情况下非常强大,但它们常常在面对真实世界文档的复杂性和多样性时缺少灵活性。 另一方面,大型语言模型提供了一…

Mysql输出今年1月至当前月份日期序列

#今日2024-07-29SELECTDATE_FORMAT( DATE_ADD( NOW(), INTERVAL -(CAST( help_topic_id AS SIGNED INTEGER )) MONTH ), %Y-%m ) monthsFROMmysql.help_topicWHEREhelp_topic_id < TIMESTAMPDIFF(MONTH, CONCAT(DATE_FORMAT(CURDATE(), "%Y-01-01")),CONCAT(STR_…

《动手做科研 》| 03. 如何阅读人工智能研究论文

地址链接:《动手做科研》03. 如何阅读人工智能研究论文 导读: 在刚迈入科研时&#xff0c;人人都说读论文很重要&#xff0c;但是很少有人能完整地教你应该如何读论文。论文不仅揭示了行业的最新进展和趋势&#xff0c;而且为我们提供了改进技术和解决复杂问题的思路。然而&…

你知道缓存的这个问题到底把多少程序员坑惨了吗?

在现代系统中&#xff0c;缓存可以极大地提升性能&#xff0c;减少数据库的压力。 然而&#xff0c;一旦缓存和数据库的数据不一致&#xff0c;就会引发各种诡异的问题。 我们来看看几种常见的解决缓存与数据库不一致的方案&#xff0c;每种方案都有各自的优缺点 先更新缓存&…

探索NSL-KDD数据集:入侵检测的起点

引言 在信息安全的世界里&#xff0c;数据集是我们最宝贵的资源。就像厨师离不开食材&#xff0c;数据科学家也离不开数据集。对于入侵检测系统&#xff08;IDS&#xff09;而言&#xff0c;NSL-KDD数据集无疑是一个经典的选择。今天&#xff0c;我们将深入探讨这个数据集&…

Python数据分析案例56——灰色预测、指数平滑预测人口数量,死亡率,出生率等

案例背景 时间序列的预测现在都是用神经网络&#xff0c;但是对于100条以内的小数据集&#xff0c;神经网络&#xff0c;机器学习这种方法效果表现不太好。 所以还是需要用上一些传统的统计学方法来进行预测&#xff0c;本次就使用灰色预测&#xff0c;指数平滑两大方法来分别…

MySQL学习(16):视图

视图是一种虚拟临时表&#xff0c;并不真正存储数据&#xff0c;它的作用就是方便用户查看实际表的内容或者部分内容 1.视图的使用语法 &#xff08;1&#xff09;创建 create view 视图名称 as select语句; #视图形成的虚拟表就来自于select语句所查询的实际表&#xff0c;…

突破•指针四

听说这是目录哦 函数指针数组&#x1fae7;用途&#xff1a;转移表 回调函数&#x1fae7;能量站&#x1f61a; 函数指针数组&#x1fae7; 函数指针数组是存放函数地址的数组&#xff0c;例如int (*parr[5])()中parr先和[]结合&#xff0c;说明parr是可以存放5个函数地址【元…

IT运维必备神器!PsShutdown,定时关机重启一键搞定!

嘿&#xff0c;各位技术小能手们&#xff0c;小江湖今天要给大家安利一个宝贝——PsShutdown&#xff01;这可不是一般的关机小工具哦&#xff1b;当你坐在电脑前&#xff0c;手指轻轻敲几下键盘&#xff0c;就能实现定时任务&#xff0c;无论是关机、重启&#xff0c;还是注销…

Python 爬虫入门(四):使用 pandas 处理和分析数据 「详细介绍」

Python 爬虫入门&#xff08;四&#xff09;&#xff1a;使用 pandas 处理和分析数据 「详细介绍」 前言1. pandas简介1.1 什么是pandas?1.2 为什么要使用pandas?1.3 安装 Pandas 2. pandas的核心概念2.1 Series2.2 DataFrame2.3 索引 3. 数据导入和导出3.1 从CSV文件读取数据…

uniapp app跳小程序详细配置

应用场景 app跳微信小程序&#xff0c;支付等 前提配置 1.1微信开放平台申请移动应用 1.2关键&#xff1a;开放平台的移动应用的app的包名和签名必须和uniapp app的包名一致 1.3查看unaipp app的包的签名 下载工具&#xff1a;GenSignature&#xff0c;模拟器安装工具 ht…

iframe嵌套项目后,接口跳出登入页面(会出现画中画的场景)

iframe嵌套项目后&#xff0c;接口跳出登入页面&#xff08;会出现画中画的场景&#xff09; JavaScript 跳出iframe框架 window.top top 属性返回最顶层的先辈窗口。该属性返回对一个顶级窗口的只读引用。如果窗口本身就是一个顶级窗口&#xff0c;top 属性存放对窗口自身的…

使用DTW算法简单实现曲线的相似度计算

相对接近产品交付形态的实现&#xff1a;基于DTW距离的KNN算法实现股票高相似筛选案例-CSDN博客 一、问题背景和思路 问题背景&#xff1a;如果你有历史股票的K线图&#xff0c;怎么从众多股票K线图中提取出TopN相似的几支股票&#xff0c;用来提供给投资者或专家做分析、决策…

任意空间平面点云旋转至与水平面平行(python)

1、背景介绍 将三维空间中位于任意平面上的点云数据&#xff0c;通过一系列的坐标变换&#xff08;平移旋转&#xff09;&#xff0c;使其投影到与XOY平面平行&#xff0c;同时点云形状保持不变。具体效果如下&#xff0c;对于原始点集&#xff08;蓝色点集&#xff09;&#x…

关于 AGGLIGATOR(猛禽)网络宽频聚合器

AGGLIGATOR 是一个用于多个链路UDP/IP带宽聚合的工具软件&#xff0c;类似MTCP的作用&#xff0c;不过它是针对UDP/IP宽频聚合的。 举个例子&#xff1a; 中国大陆有三台公网服务器&#xff0c;中国香港有一台大带宽服务器。 那么&#xff1a; AGGLIGATOR 允许中国大陆的客户…