深度学习——批标准化Batch Normalization

news2024/11/24 2:33:40

什么是批标准化?

批标准化(Batch Normalization)是深度学习中常用的一种技术,旨在加速神经网络的训练过程并提高模型的收敛速度。
批标准化通过在神经网络的每一层中对输入数据进行标准化来实现。具体而言,对于每个输入样本,在每一层的前向传播过程中,都会计算其均值和方差,并使用批量内的均值和方差对输入进行标准化。标准化后的数据会经过缩放和平移操作,使得网络可以学习到适合当前任务的特定数据分布。这样做的好处包括:
1.收敛速度更快:批标准化有助于避免梯度消失和梯度爆炸问题,使得神经网络在训练过程中更快地收敛。
2.允许更高的学习率:标准化输入可以使学习率的选择更加宽松,使得学习过程更加稳定。
3.正则化作用:批标准化在一定程度上具有正则化的效果,有助于防止过拟合。
4.不那么依赖初始化:由于标准化的存在,对网络的初始权重设置并不像传统网络那样敏感,这简化了网络的初始化过程。

对比使用批标准化和不使用批标准化

import torch
from torch import nn
from torch.nn import init
import torch.utils.data as Data
import matplotlib.pyplot as plt
import numpy as np

# 用于可复现
# torch.manual_seed(1)    # reproducible
# np.random.seed(1)

# Hyper parameters
# 样本点
N_SAMPLES = 2000
# 批大小
BATCH_SIZE = 64
# 轮次
EPOCH = 12
# 学习率
LR = 0.03
# 隐藏层层数
N_HIDDEN = 8
# 激活函数
ACTIVATION = torch.tanh
B_INIT = -0.2   # use a bad bias constant initializer

# training data
# 生成-7到10之间的N_SAMPLES个值的等差数列,并将其转化为一个二维列向量
x = np.linspace(-7, 10, N_SAMPLES)[:, np.newaxis]
# 生成一个均值为0,标准差为2的和x相同形状的噪声数据
noise = np.random.normal(0, 2, x.shape)
# 生成x对应的y值
y = np.square(x) - 5 + noise

# test data
test_x = np.linspace(-7, 10, 200)[:, np.newaxis]
noise = np.random.normal(0, 2, test_x.shape)
test_y = np.square(test_x) - 5 + noise

train_x = torch.from_numpy(x).float()
train_y = torch.from_numpy(y).float()
test_x = torch.from_numpy(test_x).float()
test_y = torch.from_numpy(test_y).float()

train_dataset = Data.TensorDataset(train_x, train_y)
train_loader = Data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,)

# show data
# plt.scatter(train_x.numpy(), train_y.numpy(), c='#FF9359', s=50, alpha=0.2, label='train')
# plt.scatter(test_x.numpy(), test_y.numpy(), c='blue', s=50, alpha=0.2, label='test')
# plt.legend(loc='best')



class Net(nn.Module):
    def __init__(self, batch_normalization=False):
        super(Net, self).__init__()
        # 是否进行批标准化
        self.do_bn = batch_normalization
        # 全连接层的列表
        self.fcs = []
        # 批标准化层的列表
        self.bns = []
        self.bn_input = nn.BatchNorm1d(1, momentum=0.5)   # for input data

        for i in range(N_HIDDEN):               # build hidden layers and BN layers
            # 如果是第一层,输入神经元个数为1,其余为10个
            input_size = 1 if i == 0 else 10
            # 全连接层
            fc = nn.Linear(input_size, 10)
            # 将全连接层重新命名然后设置为类属性
            setattr(self, 'fc%i' % i, fc)       # IMPORTANT set layer to the Module
            # 对全连接层的参数进行初始化
            self._set_init(fc)                  # parameters initialization
            # 添加到列表中
            self.fcs.append(fc)
            if self.do_bn:
                bn = nn.BatchNorm1d(10, momentum=0.5)
                setattr(self, 'bn%i' % i, bn)   # IMPORTANT set layer to the Module
                self.bns.append(bn)

        self.predict = nn.Linear(10, 1)         # output layer
        self._set_init(self.predict)            # parameters initialization

    def _set_init(self, layer):
        init.normal_(layer.weight, mean=0., std=.1)
        init.constant_(layer.bias, B_INIT)

    # 前向传播
    def forward(self, x):
        pre_activation = [x]
        if self.do_bn:
            x = self.bn_input(x)     # input batch normalization
        layer_input = [x]
        for i in range(N_HIDDEN):
            x = self.fcs[i](x)
            pre_activation.append(x)
            if self.do_bn: x = self.bns[i](x)   # batch normalization
            x = ACTIVATION(x)
            layer_input.append(x)
        out = self.predict(x)
        # 返回预测值、每个隐藏层的输入、激活函数的输出
        return out, layer_input, pre_activation


nets = [Net(batch_normalization=False), Net(batch_normalization=True)]

# print(*nets)    # print net architecture

# 优化器
opts = [torch.optim.Adam(net.parameters(), lr=LR) for net in nets]
# MSE作为损失函数
loss_func = torch.nn.MSELoss()


def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):
    for i, (ax_pa, ax_pa_bn, ax, ax_bn) in enumerate(zip(axs[0, :], axs[1, :], axs[2, :], axs[3, :])):
        [a.clear() for a in [ax_pa, ax_pa_bn, ax, ax_bn]]
        if i == 0:
            p_range = (-7, 10);the_range = (-7, 10)
        else:
            p_range = (-4, 4);the_range = (-1, 1)
        ax_pa.set_title('L' + str(i))
        ax_pa.hist(pre_ac[i].data.numpy().ravel(), bins=10, range=p_range, color='#FF9359', alpha=0.5);ax_pa_bn.hist(pre_ac_bn[i].data.numpy().ravel(), bins=10, range=p_range, color='#74BCFF', alpha=0.5)
        ax.hist(l_in[i].data.numpy().ravel(), bins=10, range=the_range, color='#FF9359');ax_bn.hist(l_in_bn[i].data.numpy().ravel(), bins=10, range=the_range, color='#74BCFF')
        for a in [ax_pa, ax, ax_pa_bn, ax_bn]: a.set_yticks(());a.set_xticks(())
        ax_pa_bn.set_xticks(p_range);ax_bn.set_xticks(the_range)
        axs[0, 0].set_ylabel('PreAct');axs[1, 0].set_ylabel('BN PreAct');axs[2, 0].set_ylabel('Act');axs[3, 0].set_ylabel('BN Act')
    plt.pause(0.01)


if __name__ == "__main__":
    f, axs = plt.subplots(4, N_HIDDEN + 1, figsize=(10, 5))
    # 开启动态绘制
    plt.ion()  # something about plotting
    plt.show()

    # training
    losses = [[], []]  # recode loss for two networks

    for epoch in range(EPOCH):
        print('Epoch: ', epoch)
        layer_inputs, pre_acts = [], []
        # 训练两个网络
        for net, l in zip(nets, losses):
            net.eval()              # set eval mode to fix moving_mean and moving_var
            pred, layer_input, pre_act = net(test_x)
            l.append(loss_func(pred, test_y).data.item())
            layer_inputs.append(layer_input)
            pre_acts.append(pre_act)
            net.train()             # free moving_mean and moving_var
        plot_histogram(*layer_inputs, *pre_acts)     # plot histogram

        for step, (b_x, b_y) in enumerate(train_loader):
            for net, opt in zip(nets, opts):     # train for each network
                # 获取到预测值
                pred, _, _ = net(b_x)
                # 计算loss
                loss = loss_func(pred, b_y)
                # 梯度清零
                opt.zero_grad()
                # 误差反向传播
                loss.backward()
                # 逐步优化网络参数
                opt.step()    # it will also learns the parameters in Batch Normalization

    # 关闭动态绘制
    plt.ioff()

    # plot training loss
    # 绘制loss图
    plt.figure(2)
    plt.plot(losses[0], c='#FF9359', lw=3, label='Original')
    plt.plot(losses[1], c='#74BCFF', lw=3, label='Batch Normalization')
    plt.xlabel('step')
    plt.ylabel('test loss')
    plt.ylim((0, 2000))
    plt.legend(loc='best')

    # evaluation
    # set net to eval mode to freeze the parameters in batch normalization layers
    [net.eval() for net in nets]    # set eval mode to fix moving_mean and moving_var
    preds = [net(test_x)[0] for net in nets]
    plt.figure(3)
    # 测试拟合效果
    plt.plot(test_x.data.numpy(), preds[0].data.numpy(), c='#FF9359', lw=4, label='Original')
    plt.plot(test_x.data.numpy(), preds[1].data.numpy(), c='#74BCFF', lw=4, label='Batch Normalization')
    plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='r', s=50, alpha=0.2, label='train')
    plt.legend(loc='best')
    plt.show()

运行结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/784282.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

我在VScode学Python(Python函数,Python模块导入)

我的个人博客主页:如果’真能转义1️⃣说1️⃣的博客主页 (1)关于Python基本语法学习---->可以参考我的这篇博客《我在VScode学Python》 (2)pip是必须的在我们学习python这门语言的过程中Python ---->&a…

fl studio 20如何设置中文汉化汇总及flstudio21水果language选项中文设置方法

fl studio这是一个编曲软件,它有中文和英文两种语言供大家选择,对我们来说,中文版肯定更方便。fl studio如何设置中文?事实上,只需在设置中切换中文即可。 我们一起 fl studio 20如何设置中文一些方法 一、fl studio手…

Angular:动态依赖注入和静态依赖注入

问题描述: 自己写的服务依赖注入到组件时候是直接在构造器内初始化的。 直到看见代码中某大哥写的 private injector: Injector 动态依赖注入和静态依赖注入 在 Angular 中,使用构造函数注入的方式将服务注入到组件中是一种静态依赖注入的方式。这种方…

docker中搭建lnmp

目录 一:项目环境 1、主机ip需求 2、 任务需求 二:多级构建Dockerfile实验部署 lnmp 1、先部署一个有所有依赖包的镜像 2、搭建nginx 3、搭建mysql 4、搭建php 三:一级构建安装lnmp 1、构建自定义docker网络 2、构建nginx容器&#x…

办公室安全升级,如何保障人身财产安全?

视频监控,一种常见的安全措施,以监视和记录办公室内的活动。这项技术为企业提供了许多优势,包括保障员工和财产安全、帮助调查犯罪事件、提高业务管理效率以及应对突发事件。 因此,在合理范围内应用视频监控,将为企业提…

LINUX中的myaql(一)安装

目录 前言 一、概述 二、数据库类型 三、数据库模型 四、MYSQL的安装 (一)yum安装MYSQL (二)rpm安装MYSQL 五、MYSQL本地登录 rpm安装MYSQL本地登录 六、重置密码 总结 前言 MySQL是一种常用的开源关系型数据库管理系统&#xff…

泛微OA客戶管理融合呼叫中心系统功能

泛微OA,全程数字化运营平台。用户因业务管理需要,crm客户管理流程中,需要通过语音与客户沟通,对话务沟通过程实现流程和过程化管理。 泛微OA中,无需过多编程代码。通过呼叫中心接口快速开发,实现点击拨打&…

【计算机网络】第 3 课 - 计算机网络体系结构

欢迎来到博主 Apeiron 的博客,祝您旅程愉快 ! 时止则止,时行则行。动静不失其时,其道光明。 目录 1、常见的计算机网络体系结构 2、计算机网络体系结构分层的必要性 2.1、物理层 2.2、数据链路层 2.3、网路层 2.4、运输层 2…

融合正余弦和柯西变异的麻雀搜索算法优化CNN-BiLSTM,实现多输入单输出预测,MATLAB代码...

上期作者推出的融合正余弦和柯西变异的麻雀优化算法,效果着实不错,今天就用它来优化一下CNN-BiLSTM。CNN-BiLSTM的流程:将训练集数据输入CNN模型中,通过CNN的卷积层和池化 层的构建,用来特征提取,再经过BiL…

(20)操纵杆或游戏手柄

文章目录 前言 20.1 你将需要什么 20.2 校准 20.3 用任务规划器进行设置 20.4 飞行前测试控制装置 20.5 测试失控保护 20.6 减少控制的滞后性 前言 本文解释了如何用操纵杆或游戏手柄控制你的飞行器,使用任务计划器向飞行器发送"RC Override"消息…

【C++基础(六)】类和对象(中) --构造,析构函数

💓博主CSDN主页:杭电码农-NEO💓   ⏩专栏分类:C初阶之路⏪   🚚代码仓库:NEO的学习日记🚚   🌹关注我🫵带你学习C   🔝🔝 类和对象-中 1. 前言2. 构造函数3. 构造函数的特性4…

Opencv 细节补充

1.分辨率的解释 •像素:像素是分辨率的单位。像素是构成位图图像最基本的单元,每个像素都有自己的颜色。 •分辨率(解析度): a) 图像分辨率就是单位英寸内的像素点数。单位为PPI(Pixels Per Inch) b) PPI表示的是每英…

从零开始 Spring Cloud 7:Gateway

从零开始 Spring Cloud 7:Gateway 图源:laiketui.com Spring Cloud Gateway 是 Spring Cloud 的一个全新项目,该项目是基于 Spring 5.0,Spring Boot 2.0 和 Project Reactor 等响应式编程和事件流技术开发的网关,它旨…

系统集成|第三章(笔记)

目录 第三章 系统集成专业技术3.1 信息系统建设3.1.1 信息系统3.1.2 信息系统集成 3.2 信息系统设计3.3 软件工程3.4 面向对象系统分析与设计3.5 软件架构3.5.1 软件的架构模式3.5.2 软件中间件 3.6 典型应用集成技术3.6.1 数据库与数据仓库技术3.6.2 Web Services 技术3.6.3 J…

区间预测 | MATLAB实现基于QRF随机森林分位数回归时间序列区间预测模型

区间预测 | MATLAB实现基于QRF随机森林分位数回归时间序列区间预测模型 目录 区间预测 | MATLAB实现基于QRF随机森林分位数回归时间序列区间预测模型效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Matlab实现基于QRF随机森林分位数回归时间序列区间预测模型&#xff1…

SpringCloud学习路线(10)——分布式搜索ElasticSeach基础

一、初识ES (一)概念: ES是一款开源搜索引擎,结合数据可视化【Kibana】、数据抓取【Logstash、Beats】共同集成为ELK(Elastic Stack),ELK被广泛应用于日志数据分析和实时监控等领域&#xff0…

java——继承

🎄🎄🎄继承 🍆🍆继承的概念: 继承(inheritance)机制:是面向对象程序设计使代码可以复用的最重要的手段,它允许程序员在保持原有类(父类)特性的基础上进行扩展…

挑战css基础面试题

挑战css基础面试题一,看看你能做出来吗 文章目录 前言一、盒模型二、如何实现一个最大的正方形三、文本一行水平居中,多行居左四、画一个三角形五、BFC理解六、两栏布局,左边固定,右边自适应,左右不重叠最后 前言 本片…

springboot快速整合腾讯云COS对象存储

1、导入相关依赖 <!--腾讯云COS--><dependency><groupId>com.tencentcloudapi</groupId><artifactId>tencentcloud-sdk-java</artifactId><version>3.0.1</version></dependency><dependency><groupId>com…