动手学深度学习28 批量归一化

news2025/1/14 0:52:15

动手学深度学习28 批量归一化

  • 批量归一化
  • 代码
    • 从零实现
    • 调包简洁实现
  • QA

https://www.bilibili.com/video/BV1X44y1r77r/?spm_id_from=autoNext&vd_source=eb04c9a33e87ceba9c9a2e5f09752ef8

批量归一化

在这里插入图片描述
n个比较小的数相乘,值会越来越小。
批量归一化:在改变底部网络特征的时候,避免顶部网络不断的训练。
尝试把小批量在不同层输入数据的均值和方差固定住。
B:小批量的大小。
在这里插入图片描述
批量归一化带来的东西:

  1. 是个线性变换 让均值方差比较好,让变化不那么剧烈

  2. 把 批量大小* 高 * 宽=样本大小,那通道维就是特征维, 1*1的卷积也是这样做的。
    在这里插入图片描述
    认为从数据中计算的均值方差是噪音。是控制模型复杂度的一个方法。
    在这里插入图片描述

  3. 偏移和缩放拟合均值和方差。

  4. 加速收敛,允许用更大的学习率做训练。把每一层的输入均值方差都在一个层面,学习率可以用较大的。不会因为学习率过大导致梯度爆炸【上层网络】,学习率过小而无法收敛的问题【底层网络】。
    在这里插入图片描述

代码

从零实现

import torch
from torch import nn
from d2l import torch as d2l


# gamma, beta 可以学的
# moving_mean, moving_var 整个数据集的【全局】均值方差 推理时用
# eps 一个很小的值,避免除0的东西, 一般固定值尽量不要改动 常见设定 eps=1e-5
# momentum 用于更新moving_mean, moving_var的变量 0.9
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
  # 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
  if not torch.is_grad_enabled():
    # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
    # 不用批量而是用全局的均值方差,是因为一个批量可能是一个样本没有均值方差。
    # 推理用全局的mean var
    X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
  else:
  	# 输入shape是二维或者是四维
    assert len(X.shape) in (2, 4)
    if len(X.shape) == 2:
      # 使用全连接层的情况,计算特征维上的均值和方差
      mean = X.mean(dim=0) # 按行求均值  每一列求均值  均值是一个1*n的行向量 
      var = ((X - mean) ** 2).mean(dim=0) # 按行求方差 方差是一个1*n的行向量
    else:
      # 二维卷积 输入shape 4维
      # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
      # 这里我们需要保持X的形状以便后面可以做广播运算
      # 按照通道数求均值 把每一个通道的所有批量和所有高宽像素都用来求均值 keepdim=True 输出是1*n*1*1四维向量
      mean = X.mean(dim=(0, 2, 3), keepdim=True)
      # 输出是1*n*1*1四维向量
      var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
      # 当前X求出的 mean  var
    # 训练模式下,用当前的均值和方差做标准化
    # 训练用小批量的均值和方差
    X_hat = (X - mean) / torch.sqrt(var + eps)
    # 训练中 更新移动平均的均值和方差
    # smooth计算真实均值方差的方法  无限逼近真实数据的均值方差
    moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
    moving_var = momentum * moving_var + (1.0 - momentum) * var
  Y = gamma * X_hat + beta  # 缩放和移位
  # .data 返回tensor
  return Y, moving_mean.data, moving_var.data

class BatchNorm(nn.Module):
  # num_features:完全连接层的输出数量或卷积层的输出通道数。
  # num_dims:2表示完全连接层,4表示卷积层
  def __init__(self, num_features, num_dims):
    super().__init__()
    if num_dims == 2:
      shape = (1, num_features)
    else:
      shape = (1, num_features, 1, 1)
    # 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
    # 放在nn.Parameter里面 需要被迭代参数
    self.gamma = nn.Parameter(torch.ones(shape))  # 拟合的方差 不能为0 做乘法为0的话乘积为0训练不动了
    self.beta = nn.Parameter(torch.zeros(shape)) # 拟合的均值
    # 非模型参数的变量初始化为0和1  # 不用背迭代
    self.moving_mean = torch.zeros(shape)
    self.moving_var = torch.ones(shape)
  
  def forward(self, X):
    # 如果X不在内存上,将moving_mean和moving_var
    # 复制到X所在显存上
    if self.moving_mean.device != X.device:
      # 这两个参数没有放到nn.Parameter中 要手动挪下数据
      self.moving_mean = self.moving_mean.to(X.device)
      self.moving_var = self.moving_var.to(X.device)
    # 保存更新过的moving_mean和moving_var
    # 在不同框架下实现的时候,注意eps的设定
    Y, self.moving_mean, self.moving_var = batch_norm(
      X, self.gamma, self.beta, self.moving_mean,
      self.moving_var, eps=1e-5, momentum=0.9)
    return Y

# LeNet模型
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
    nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
    nn.Linear(84, 10))

# 训练网络
lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

# gamma学出来的是什么样子的
print(net[1].gamma.reshape((-1,)), net[1].beta.reshape((-1,)))  
# 当网络学的很深的时候 可以对比最底层网络和最上层网络均值方差有什么区别
loss 0.263, train acc 0.903, test acc 0.834
17653.4 examples/sec on cuda:0
tensor([4.9547, 2.1918, 2.7570, 2.9478, 1.5544, 3.2417], device='cuda:0',
       grad_fn=<ViewBackward0>) tensor([-2.5122,  1.4937,  0.9189, -2.8232, -2.1574, -1.6293], device='cuda:0',
       grad_fn=<ViewBackward0>)

在这里插入图片描述

用BatchNorm收敛更快。
一样的数据,加不加BN可能效果一样。
当网络学的很深的时候 可以对比最底层网络和最上层网络均值方差有什么区别

调包简洁实现

import torch
from torch import nn
from d2l import torch as d2l

net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5), 
    nn.BatchNorm2d(6), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5),
    nn.BatchNorm2d(16), nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
    nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
    nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
    nn.Linear(84, 10))

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.277, train acc 0.898, test acc 0.874
27529.3 examples/sec on cuda:0

测试精度震荡的原因:因为学习率大,把学习率调低就行了,再加上lenet是小模型,容易震荡
在这里插入图片描述
在这里插入图片描述

QA

1 xavier 的归一化 和BN层本质上没有区别。 选取好的初始化,在训练初始的时候比较稳定。BN在训练中数值都比较稳定
2 权重衰退更新时把每个权重除以一个小值,使得权重变得比较小,BN不对权重做太多处理
4 BN可以用在MLP中,主要用在比较深的网络中。
7 BN框架实现 默认第二维度是通道数维度,把其他所有维度都拉平计算均值方差
8 BN也是线性变换 和加一个线性层没什么区别。线性层不一定学到自己的想要的东西,不做的的话可能数值不稳定,可能不能训练到比较好的值域里面。
9 加BN后收敛加快, BN使得梯度可以变大一点 + 可以使用更大的lr = 收敛加快
10 重要程度:epoch数【可以选大 收敛结束可以中途停掉】 batchSize【根据内存调 要选的合适】 学习率【调完batchsize调】 相互相关 框架【看个人习惯】
11 XXnormalization 特别多。BN 是feature维度对样本做normalization。 layerNormalization在每一个样本里面的特征做normalization,通常用于比较大的网络。
12 BN一般不用在激活函数之后,因为是输入和输出的线性相关,可以尝试看看效果。
14 增加batchsize 看看每秒处理的样本数,增到一定程度。。。。看看效果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1797053.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

React+TS前台项目实战(一)-- 项目初始化配置及开此系列的初衷

文章目录 前言一、初始化项目二、基础配置1. 项目目录及说明如下2. TS版本使用Craco需注意 总结 前言 前面 后台管理系统实战 系列教程暂时告一段落了&#xff0c;想了解全局各种配置的可自行查看。本次教程将重点介绍React前台项目的实操&#xff0c;关于具体的配置&#xff…

企业微信hook接口协议,ipad协议http,发送CDN语音消息

发送CDN语音消息 参数名必选类型说明uuid是String每个实例的唯一标识&#xff0c;根据uuid操作具体企业微信send_userid是long要发送的人或群idisRoom是bool是否是群消息 请求示例 {"uuid":"1753cdff-0501-42fe-bb5a-2a4b9629f7fb","send_userid&q…

耐用好用充电宝有哪些?畅销排行榜前四款充电宝推荐

在日常生活中&#xff0c;一款耐用且好用的充电宝是我们出行必备的利器&#xff0c;它可以为我们的手机、平板等设备提供持续的电力支持。然而&#xff0c;在市面上琳琅满目的充电宝品牌中&#xff0c;究竟哪些才是真正耐用又好用的选择&#xff1f;为了帮助大家更好地了解市场…

MYSQL基础_01_数据库概述

第01章_数据库概述 1. 为什么要使用数据库 持久化(persistence)&#xff1a;把数据保存到可掉电式存储设备中以供之后使用。大多数情况下&#xff0c;特别是企业级应用&#xff0c;数据持久化意味着将内存中的数据保存到硬盘上加以”固化”&#xff0c;而持久化的实现过程大多…

国产低功率立体声音频编解码器CJC8988Pin to Pin替代WM8988

CJC8988是一个低功率&#xff0c;高质量的立体声编解码器&#xff0c;和WM8988外围电路一致&#xff0c;管脚兼容&#xff1b;可以直接Pin to Pin替代WM8988&#xff0c;CJC8988是参考WM8988设计的&#xff0c;芯片可以直接替换&#xff1b;IIC格式一致&#xff0c;寄存器定义一…

笔记-Python pip配置国内源

众所周知&#xff0c;Python使用pip方法安装第三方包时&#xff0c;需要从 https://pypi.org/ 资源库中下载&#xff0c;但是会面临下载速度慢&#xff0c;甚至无法下载的尴尬&#xff0c;这时&#xff0c;你就需要知道配置一个国内源有多么重要了&#xff0c;通过一番摸索和尝…

Django里的Form组件

Form组件提供 自动生成HTML标签和半自动读取关联数据 (“半自动”是指还得需要自己手写输入数据进来)表单验证和错误提示 要想创建并使用该组件&#xff0c;操作步骤如下&#xff1a; 在 views.py 里创建类 # 在 views.py 文件里from django import formsclass AssetForm(fo…

杭州威雅学校2024-25学年招生简章

杭州市新学年招生政策现已公布&#xff0c;杭州威雅学校小学部、初中部招生工作全面开启。 招生简章与报名流程如下&#xff1a; Part.1 幼升小 01 招生对象 2017年9月1日至2018年8月31日出生满六周岁且符合杭州市萧山区小学招生条件的适龄儿童。 02 招生人数 新一年级…

CAPL如何发送一条UDP报文

UDP作为传输层协议,本身并不具有可靠性传输特点,所以不需要建立连接通道,可以直接发送数据。当然,前提是需要知道对方的通信端点,也就是IP地址和端口号。 端口号是传输层协议中最显著的特征,传输层根据它来确定上层绑定的应用程序,以达到把数据交给上层应用处理的目的。…

vsCode双击文件才能打开文件,单击文件只能预览?

解决&#xff1a; 1、打开设置 2、搜索workbench.editor.enablePreview 3、更改为不勾选状态 4、关闭设置 效果&#xff1a; 现在单击一个文件时&#xff0c;将会在编辑器中直接打开&#xff0c;而非是预览状态。

Flink SQL实践

环境准备 方式1&#xff1a;基于Standalone Flink集群的SQL Client 启动hadoop集群 [hadoopnode2 ~]$ start-cluster.sh [hadoopnode2 ~]$ sql-client.sh 使用Yarn Session启动Flink集群 [hadoopnode2 ~]$ start-cluster.sh [hadoopnode2 ~]$ sql-client.sh ... 省略若干…

2.Rust自动生成文件解析

目录 一、生成目录解析二、生成文件解析2.1 Cargo.toml2.2 main函数解析 一、生成目录解析 先使用cargo clean命令删除所有生成的文件&#xff0c;下图显示了目录结构和 main.rs文件 使用cargo new testrust时自动创建出名为testrust的Rust项目。内部主要包含一个src的源码文…

Unity开发——编辑器打包、3种方式加载AssetBundle资源

一、创建ab资源 &#xff08;一&#xff09;Unity资源设置ab格式 1、选中要打包成assetbundle的资源&#xff1b; 可以是图片&#xff0c;材质球&#xff0c;预制体等&#xff0c;这里方便展示用预制体打包设置展示&#xff1b; 2、AssetBundle面板说明 &#xff08;1&…

Android Uri转File path路径,Kotlin

Android Uri转File path路径&#xff0c;Kotlin /*** URI转化为file path路径*/private fun getFilePathFromURI(context: Context, contentURI: Uri): String? {val result: String?var cursor: Cursor? nulltry {cursor context.contentResolver.query(contentURI, null…

【Python机器学习】预处理对监督学习的作用

还是用cancer数据集&#xff0c;观察使用MinMaxScaler对学习SVC的作用。 首先&#xff0c;在原始数据上拟合SVC&#xff1a; cancerload_breast_cancer() X_train,X_test,y_train,y_testtrain_test_split(cancer.data,cancer.target,random_state0 ) svmSVC(C100) svm.fit(X_t…

【LeetCode】39.组合总和

组合总和 题目描述&#xff1a; 给你一个 无重复元素 的整数数组 candidates 和一个目标整数 target &#xff0c;找出 candidates 中可以使数字和为目标数 target 的 所有 不同组合 &#xff0c;并以列表形式返回。你可以按 任意顺序 返回这些组合。 candidates 中的 同一个…

专业文章 | AIGC绘制:基于Stable Diffusion制作端午海报

AIGC全称为AI Generated Content&#xff08;人工智能生产内容&#xff09;&#xff0c;即基于生成对抗网络GAN、大型预训练模型等人工智能技术&#xff0c;通过寻找已有数据规律与适当泛化能力生成相关技术内容。简单来说&#xff0c;任何AI技术生成的内容都可以视为AIGC。 2…

Aurora 8b/10b协议(高速收发器十五)

点击进入高速收发器系列文章导航界面 前面几篇文章通过自定义PHY协议去实现高速收发器收发数据&#xff0c;一帧数据包括帧头、数据、帧尾等信息&#xff0c;在空闲的时候发送FLSR伪随机序列降低电磁干扰&#xff0c;并且每隔固定空闲时间发送一次逗号&#xff0c;用于接收端字…

(文章复现)基于主从博弈的售电商多元零售套餐设计与多级市场购电策略

参考文献&#xff1a; [1]潘虹锦,高红均,杨艳红,等.基于主从博弈的售电商多元零售套餐设计与多级市场购电策略[J].中国电机工程学报,2022,42(13):4785-4800. 1.摘要 随着电力市场改革的发展&#xff0c;如何制定吸引用户选择的多类型零售套餐成为提升售电商利润的研究重点。为…

大模型备案重点步骤详细说明

随着人工智能技术的发展&#xff0c;大模型在语音识别、图像处理、自然语言处理等领域应用日益广泛&#xff0c;为进一步保障和监管大模型技术应用&#xff0c;我国出台了《生成式人工智能服务管理暂行办法》&#xff0c;为大模型的合规提供了明确的法律框架。2024年4月2日&…