Batch Normalization和Layer Normalization和Group normalization

news2025/1/18 9:51:43

文章目录

  • 前言
  • 一、Group normalization
  • 二、批量规范化(Batch Normalization)
  • 三、层规范化(Layer Normalization)


前言

  批量规范化和层规范化在神经网络中的每个批次或每个层上进行规范化,而GroupNorm将特征分成多个组,并在每个组内进行规范化。这种规范化技术使得每个组内的特征具有相同的均值和方差,从而减少了特征之间的相关性。通常,组的大小是一个超参数,可以手动设置或自动确定。
  相对于批量规范化,GroupNorm的一个优势是它对批次大小的依赖性较小。这使得GroupNorm在训练小批量样本或具有不同批次大小的情况下更加稳定。另外,GroupNorm还可以应用于一维、二维和三维的输入,适用于不同类型的神经网络架构。
  GroupNorm的一种变体是分组卷积(Group Convolution),它将输入通道分成多个组,并在每个组内进行卷积操作。这种结构可以减少计算量,并提高模型的效率。

在这里插入图片描述

  • BatchNorm:batch方向做归一化,算N* H*W的均值
  • LayerNorm:channel方向做归一化,算C* H* W的均值
  • InstanceNorm:一个channel内做归一化,算H*W的均值
  • GroupNorm:将channel方向分group,然后每个group内做归一化,算(C//G) * H * W的均值

一、Group normalization

  Group normalization(GroupNorm)是深度学习中用于规范化神经网络激活的一种技术。它是一种替代批量规范化(BatchNorm)和层规范化(LayerNorm)等其他规范化技术的方法。

import torch
import torch.nn as nn

class GroupNorm(nn.Module):
    def __init__(self, num_groups, num_channels, eps=1e-5):
        super(GroupNorm, self).__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        
        self.weight = nn.Parameter(torch.ones(1, num_channels, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
    
    def forward(self, x):
        batch_size, num_channels, height, width = x.size()
        
        # 将特征重塑成 (batch_size * num_groups, num_channels // num_groups, height, width)
        x = x.view(batch_size, self.num_groups, -1, height, width)
        
        # 计算每个组内的均值和方差
        mean = x.mean(dim=(2, 3, 4), keepdim=True)
        var = x.var(dim=(2, 3, 4), keepdim=True)
        
        # 规范化
        x = (x - mean) / torch.sqrt(var + self.eps)
        
        # 重塑特征
        x = x.view(batch_size, num_channels, height, width)
        
        # 应用缩放和平移
        x = x * self.weight + self.bias
        
        return x

# 使用示例
group_norm = GroupNorm(num_groups=4, num_channels=64)
inputs = torch.randn(32, 64, 32, 32)
outputs = group_norm(inputs)
print(outputs.shape)

二、批量规范化(Batch Normalization)

  BatchNorm的基本思想是对每个特征通道在一个小批次(即一个批次中的多个样本)的数据上进行规范化,使得其均值接近于0,方差接近于1。这种规范化可以有助于加速神经网络的训练,并提高模型的泛化能力。
  具体而言,对于给定的一个特征通道,BatchNorm的计算过程如下:

  1. 对于一个小批次中的输入数据,计算该特征通道上的均值和方差。
  2. 使用计算得到的均值和方差对该特征通道上的数据进行规范化,使得其均值为0,方差为1。
  3. 对规范化后的数据进行缩放和平移操作,使用可学习的参数进行调整,以恢复模型对数据的表示能力。

  通过在训练过程中对每个小批次的数据进行规范化,BatchNorm有助于解决梯度消失和梯度爆炸等问题,从而加速模型的收敛速度。此外,BatchNorm还具有一定的正则化效果,可以减少模型对输入数据的依赖性,增强模型的鲁棒性。

import torch
import torch.nn as nn

# 输入数据形状:(batch_size, num_features)
input_data = torch.randn(32, 64)

# 使用BatchNorm进行批量规范化
batch_norm = nn.BatchNorm1d(64)
output = batch_norm(input_data)

print(output.shape)

三、层规范化(Layer Normalization)

  与批量规范化相比,层规范化更适用于对序列数据或小批次样本进行规范化,例如自然语言处理任务中的文本序列。它在每个样本的特征维度上进行规范化,使得每个样本在特征维度上具有相似的分布。
层规范化的计算过程如下:
对于每个样本,计算该样本在特征维度上的均值和方差。

  1. 使用计算得到的均值和方差对该样本的特征进行规范化,使得其均值为0,方差为1。
  2. 对规范化后的特征进行缩放和平移操作,使用可学习的参数进行调整,以恢复模型对数据的表示能力。
import torch
import torch.nn as nn

# 输入数据形状:(batch_size, num_features)
input_data = torch.randn(32, 64)

# 使用LayerNorm进行层规范化
layer_norm = nn.LayerNorm(64)
output = layer_norm(input_data)

print(output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1492279.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【AGI】大模型 深度学习入门学习路径

【AGI】大模型 深度学习入门学习路径 1. 深度学习入门学习路径可以从以下几个方面入手:2. 深度学习中的基本概念和算法有哪些?3. Python在深度学习中的应用和库有哪些?4. PyTorch的基本计算结构和应用教程推荐?5. 如何通过实战项目加深对深度学习模型训练的理解?6. 参考资…

Linux运维:在线/离线安装Telnet客户端和Telnet服务

Linux运维:在线/离线安装Telnet客户端和Telnet服务 前言1.1 在线安装Telnet1.2 离线安装Telnet1.3 Telnet服务有关的命令 💖The Begin💖点点关注,收藏不迷路💖 前言 Telnet是一种用于远程登录到其他计算机的协议&…

springboot+jsp汽车配件管理系统idea maven 项目lw

springbootweb汽车配件销售业绩管理系统服务于汽车配件公司业务,实现了客户管理,主要负责对客户相关数据的增删改查方面、渠道管理,主要对渠道信息也就是设备的供应商渠道信息进行管理、项目管理,主要是一些项目信息的记录与整理、销售数据管…

深入探索Docker数据卷:实现容器持久化存储的完美方案(下)

🐇明明跟你说过:个人主页 🏅个人专栏:《Docker入门到精通》 《k8s入门到实战》🏅 🔖行路有良友,便是天堂🔖 目录 四、Docker数据卷的高级管理 1、数据卷的生命周期管理 2、数据…

基于巨控GRM561/562/563Y西门子1200PLC发邮件

巨控GRM560,GRM600系列同比之前的GRM530,除短信,微信,电话语音播报增加了邮件发送功能,简单介绍一下PLC发邮件。 1在博途中建立好DB块 2.打开GRMDEV6,新建工程,做好数据采集,这里以DB4.D0&#…

【笔记】React-Native React DevTools

/** * 官网文档:https://reactnative.dev/docs/next/react-devtools */ 1、本想在Demo项目中添加依赖(npx react-devtools),但其他项目就需要再操作一次,所以全局安装就好了 yarn global add react-devtools 或 npm install -g react-devto…

linux 将 api_key设置环境变量里

vi ~/.bashrc在最后添加api_key的环境变量 export GEMINI_API_KEYAIza**********WvpX7FwbdM刷新配置 source ~/.bashrc使用python 读取环境变量 import os gemini_api_key os.getenv(GEMINI_API_KEY) print(gemini_api_key)

实战解析:打造风控特征变量平台,赋能数据驱动决策

金融业务产品授信准入、交易营销等环节存在广泛的风控诉求,随着业务种类增多,传统的专家规则、评分卡模型难以应付日趋复杂的风控场景。 在传统风控以专家规则系统为主流应用的语境下,规则模型的入参习惯被称为“变量”。基于专家规则的风险…

每日好题3.5

前缀和 这个题目巨妙,打的时候没写出来,后面补题发现太牛了 思路:当前区间左端点 L L L ,当我们向右移动一次,就相当于,原式 - f ( L ) f ( L 1 e 18 ) f(L) f(L 1e18) f(L)f(L1e18),值就…

列车调度——典型的验证栈的出栈合不合法的问题,值得一看

题目描述 有n列火车按照1,2,3...n的顺序排列,现所有的火车需要掉头,所以需要火车先驶入一个调度站,再开出来。 由于只有一根铁轨,所以要么最前面的一辆火车进去调度站,要么调度栈内最上面一辆火车开出调度栈。 现给…

go并发模式之----工作池/协程池模式

常见模式之四:工作池/协程池模式 定义 顾名思义,就是有固定数量的工人(协程),去执行批量的任务 使用场景 适用于需要限制并发执行任务数量的情况 创建一个固定大小的 goroutine 池,将任务分发给池中的 g…

如何用VisualVM工具查看堆内存文件

1.找到安装JDK的bin目录,找到 jvisualvm.exe可执行文件运行即可; 2.然后导入堆内存文件 .hprof文件,看类; 3.分析是哪些对象占了多少内存。

七大 排序算法(一篇文章梳理)

一、引言 排序算法是计算机科学中不可或缺的一部分,它们在数据处理、数据库管理、搜索引擎、数据分析等多个领域都有广泛的应用。排序算法的主要任务是将一组数据元素按照某种特定的顺序(如升序或降序)进行排列。本文将对一些常见的排序算法…

Altium Designer 22 性能优化

目录 AD22 使用起来很卡,完全受不了,卡到我的快捷鼠标宏都无法使用,来试着优化一下它。 每点完一步,都需要点击应用,否则不下心关掉了会很难受 打开右上角齿轮进入设置,取消勾选这几个勾: 接…

不同用户同时编辑商品资料导致的db并发覆盖

背景 这个问题的背景来源于有用户反馈,他在商品系统中对商品打的标签不见了,影响到了前端页面上商品的资料显示 不同用户编辑同一商品导致的数据覆盖问题分析 查询操作日志发现用户B确实编辑过商品资料,并且日志显示确实打上了标签&#x…

【无标题】计算机主要应用于哪些领域

科学计算(或称为数值计算)、数据处理(信息管理)、辅助工程、生产自动化、人工智能。1、科学计算(或称为数值计算):早期的计算机主要用于科学计算。目前,科学计算仍然是计算机应用的一…

6_怎么看原理图之协议类接口之LCD笔记

首先想一想再前几篇文章讲的协议类的前提 1、双方约定好通信的协议 2、双方满足一定的时序要求 以上第二点又有一些要求: 1)弄清2440在这个通信协议中,能设置哪些时序的值,这些值的含义是什么——2440手册 2)弄清楚这…

【Leetcode每日一刷】贪心算法| 45.跳跃游戏 II

1、45.跳跃游戏 II 🦄解题思路: 这题还是比【55.跳跃游戏】难一些的。第一个版本只是说,求跳跃的范围,覆盖到了终点即可。这题则是,能保证覆盖范围到达终点,求的是最少跳几次,跳到终点。 这题…

针对conan install下载source失败问题解决

ps:下面操作是Linux系统,针对win操作系统也适合 问题现象 在运行conan install时,本地没有对应的库的缓存,conan会自动从conan center下载,可能会出现以下情况,重试多次,仍然是报错。 libssh2/1.11.0: C…

Spring基础——XML配置Bean的依赖注入

目录 什么是依赖注入依赖的解析 Spring提供的两种注入方式1. 基于构造器的赖注入1.1 通过类型注入1.2 通过索引注入1.3 通过参数名注入1.4 通过静态工厂方法参数注入 基于Setter的依赖注入 Spring对不同类型的注入方式1. 字面值(String,基本类型&#xf…