【标准化方法】(3) Group Normalization 原理解析、代码复现,附Pytorch代码

news2024/9/23 5:20:27

今天和各位分享一下深度学习中常用的标准化方法,Group Normalization 数据分组归一化,向大家介绍一下数学原理,并用 Pytorch 复现。

Group Normalization 论文地址:https://arxiv.org/pdf/1803.08494.pdf


1. 原理介绍

在目标检测,视频分类等大型计算机视觉应用中,受到计算机内存的限制,必须设置较小的样本数量,但是样本量小势必会导致批归一化的性能有所影响。

分组归一化(Group Normalization,GN)是针对批归一化算法对批次大小依赖性强这一弱点而提出的改进算法。因为 BN 层统计信息的计算与批次的大小有关因此当批次变小时,很明显统计均值和方差的计算会越不准确和稳定,最终会有小批次高错误率的这一现象发生

分组归一化 GN 介于层归一化 LN 和实例归一化 IN 之间,对于输入大小为 [N,C,H,W] 的图像,N 代表批次的大小,C 表示输入通道数,H,W 表示输入图片高度和宽度。

分组归一化首先将输入通道 C 分为 G 个小组,然后分别对每一小组做归一化操作,也就是先把输入的特征维度由 [N,C,H,W] 变成 [N,G,\frac{C}{G},H,W]归一化的维度[\frac{C}{G},H,W]

事实上,当 G 等于 1 时,即所有的输入通道为 1 组GN 与 LN 的计算方式相同,而当 G 等于 C 时1 个输入通道为 1 组GN 与 IN 的计算方式相同

上图是批归一化算法 BN、层归一化算法 LN、实例归一化 IN 和分组归一化 GN 的简单图示。图中的立方体是三维,蓝色的方块是各个算法计算均值和方差的区域

其中 C 代表通道数,N 是批量大小,H,W 是高度和宽度,第三个维度的大小是 H*W,这样输入就可以用三维图形来表示。从上图中可以看出只有 BN  的计算与批次大小 N 有关LN、IN 和 GN 的计算都在单个样本上进行,  LN、IN 和 GN 三者可相互转换。

通常来说,归一化的方式如下所示:

\mu_i=\frac{1}{m}\sum_{k\in S_i}x_k

\sigma_i=\sqrt{\frac{1}{m}\sum_{k\in S_i}\left(x_k-\mu_i\right)^2+\epsilon}

S_i 是均值和方差的计算区域,在 BN 中有:

S_i=\left\{k|k_C=i_C\right\}

在 LN 中:

S_i=\left\{k|k_N=i_N\right\}

在 GN 中:

S_i=\{k\mid k_N=i_N,floor(\frac{k_C}{C/G})=floor(\frac{i_C}{C/G})\}

优点:不依赖批量大小。

缺点:当批量大小较大时,性能不如BN。


2. 代码展示

import torch 
from torch import nn

class GN(nn.Module):
    # 初始化
    def __init__(self, groups:int, channels:int, 
                 eps:float=1e-5, affine:bool=True):
        super(GN, self).__init__()
        # 通道数要整除组数
        assert channels % groups == 0, 'channels should be evenly divisible by groups'
        self.groups = groups  # 把通道分成多少组
        self.channels = channels  # 通道数
        self.eps = eps  # 防止分母为0
        self.affine = affine  # 是否使用可学习的线性变化参数
        if self.affine:
            self.scale = nn.Parameter(torch.ones(channels))  # 缩放因子
            self.shift = nn.Parameter(torch.zeros(channels))  # 偏置
    # 前向传播
    def forward(self, x: torch.Tensor):
        x_shape = x.shape  # 输入特征的维度 [b,c,w,h]
        batch_size = x_shape[0]  # 样本量
        assert self.channels == x.shape[1]  # 预设通道数和输入特征的通道数要保持一致
        # [b,c,w,h]-->[b,g,w*h*c/g]
        x = x.view(batch_size, self.groups, -1)
        # 在最后一个维度上做标准化
        mean = x.mean(dim=[-1], keepdim=True)  # [b,g,1]
        mean_x2 = (x**2).mean(dim=[-1], keepdim=True)  # [b,g,1]
        var = mean_x2 - mean**2
        x_norm = (x-mean) / torch.sqrt(var+self.eps)  # [b,g,w*h*c/g]
        # 线性变化
        if self.affine:
            x_norm = x_norm.view(batch_size, self.channels, -1)  # [b,c,w*h]
            x_norm = self.scale.view(1,-1,1)* x_norm + self.shift.view(1,-1,1)  # [1,c,1]*[b,c,w*h]+[1,c,1]
        # [b,c,w*h]-->[b,c,w,h]
        return x_norm.view(x_shape)

# ---------------------------------- #
# 验证
# ---------------------------------- #

if __name__ == '__main__':
    # 构造输入层
    x = torch.linspace(0, 47, 48, dtype=torch.float32)  # 构造输入层
    x = x.reshape([2,6,2,2])  # [b,c,w,h]
    # 实例化
    gn = GN(groups=3, channels=6)
    # 前向传播
    x = gn(x)
    print(x.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/505912.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Javascript - Cookie的获取和保存应用

在之前的博客介绍了如何利用 Selenium去搭建 cookie池,进行自动化登录、获取信息等。那什么是cookie呢?它的作用又是什么呢? 这里,再重复简单介绍一下。 cookie 是浏览器储存在用户电脑上的一小段文本文件。该文件里存了加密后的用…

LeetCode之回溯算法

文章目录 思想&框架1.组合/子集和排列问题2.组合应用问题 组合/子集问题1. lc77 组合2. lc216 组合总和III3. lc39 组合总和4. lc40 组合总和II5. lc78 子集6. lc90 子集II 排列1. 全排列I2. 全排列II 组合问题的应用1.lc17 电话号码的字母组合2.lc131 分割回文串3. lc19 复…

集约式智能自动化办公,实在智能门户开启政企数字化转型新范式

导语: 随着数字化和智能化的快速发展,数字技术已经深入到各个行业和领域。实在智能基于数字员工在行业的深厚理解和丰富的实践经验,打造一站式的智能化统一平台——智能门户,打破了技术壁垒和系统数据之间的割裂感,实现…

软考A计划-重点考点-专题五(计算机网络知识)

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例 👉关于作者 专注于Android/Unity和各种游戏开发技巧,以及各种资源分享&am…

Apache Sentry

官方 说明 Sentry是一种用于在Hadoop集群中控制和管理访问权限的工具。因此,CDH的Sentry指的是Cloudera Distribution for Hadoop中集成的Sentry组件,用于管理Hadoop集群中的访问控制和权限管理。 作用 Sentry是一个用于管理Hadoop集群中的访问权限的…

基于C++实现旅行线路设计

访问【WRITE-BUG数字空间】_[内附完整源码和文档] 系统根据风险评估,为旅客设计一条符合旅行策略的旅行线路并输出,系统能查询当前时刻旅客所处的地点和状态(停留城市/所在交通工具)。 实验内容和实验环境描述 1.1 实验内容 城…

【吐槽贴】项目经理如何进行高效沟通?

“项目最大的风险就是都觉得没有风险。” 这还是跟同行聊天时开玩笑的一句话,最近我却深有体会。一直以为一切正常的项目,最近却接连出了问题,复盘才发现几个关键性问题都出在沟通方面,还一直认为沟通能力是自己的优势。这次主要踩…

使用java-timeseries库,使用arima算法预测时间序列(

项目地址&#xff1a; GitHub - signaflo/java-timeseries: Time series analysis in Java maven&#xff1a; <dependency><groupId>com.github.signaflo</groupId><artifactId>timeseries</artifactId><version>0.4</version> &…

【剖析STL】String

1.什么是STL&#xff1f; 标准模板库&#xff08;Standard Template Library&#xff0c;STL&#xff09;是惠普实验室开发的一系列软件的统称。它是由Alexander Stepanov、Meng Lee和David R Musser在惠普实验室工作时所开发出来的。虽说它主要出现到C中&#xff0c;但在被引…

Dockerfile部署java项目

一、dockerfile展示 将DockerFile 配置文件放到 maven项目目录内&#xff0c;和pom.xml同级。 # Download code FROM bitnami/git:2 AS git RUN mkdir -p /home/app/src RUN git -c http.sslVerifyfalse -C /home/app/src clone -b local https://github.com/test.git# # Bui…

2023年宜昌市中等职业学校技能大赛 “网络搭建与应用”竞赛题-1

2023年宜昌市中等职业学校技能大赛 “网络搭建与应用”竞赛题 一、竞赛内容分布 “网络搭建及应用”竞赛共分二个部分&#xff0c;其中&#xff1a; 第一部分&#xff1a;企业网络搭建部署项目&#xff0c;占总分的比例为50%&#xff1b; 第二部分&#xff1a;企业网络服…

打工人使用ChatGPT的一天!

众所周知&#xff0c;ChatGPT 自去年OpenAI 推出以来&#xff0c;这款 AI 聊天机器人可以说迅速成为了 AI 界的「当红炸子鸡」 作为一名资深的打工人&#x1f477;&#x1f3fb;‍♂️&#xff0c;我们应该怎样利用ChatGPT提高工作效率呢&#xff1f;今天给大家介绍下打工人使…

c++ cuda加速学习笔记

1. 环境配置 (1)显卡驱动下载官网&#xff0c;需要知道自己电脑的显卡类型 搜索链接&#xff1a;https://www.nvidia.com/Download/index.aspx?langzh-cn(2)怎么知道自己的的显卡类型 https://jingyan.baidu.com/article/2a13832888b2a7464a134fef.html 此电脑->管理->…

通过Modbus实现TTS语音全彩声光告警-博灵语音通知终端-网络语音报警灯

背景 目前PLC在工业领域应用广泛&#xff0c;在运行过程中可能会涉及到各种告警。 为了简单快速的实现语音声光告警&#xff0c;本文以大连英仕博科技出品的博灵语音通知终端为例&#xff0c;演示如何通过Modbus TCP协议实现声光告警推送。 播报效果演示 Modbus-博灵语音通知…

Google Play 政策更新重点回顾 (上) | 2023 年 4 月

Google Play 始终如一地为大家打造值得信赖的安全平台&#xff0c;支持大家走向成功。为了让您更及时更清晰地掌握 Google Play 最新政策&#xff0c;我们准备了两篇文章为您详细说明 2023 年 4 月的政策更新内容&#xff0c;以及深度解析。本文是第一篇内容&#xff0c;我们将…

日撸 Java 三百行day46

文章目录 说明day46 快速排序1.基本思路2. 代码 说明 闵老师的文章链接&#xff1a; 日撸 Java 三百行&#xff08;总述&#xff09;_minfanphd的博客-CSDN博客 自己也把手敲的代码放在了github上维护&#xff1a;https://github.com/fulisha-ok/sampledata day46 快速排序 …

计算机网络安全--期末

计算机网络安全绪论 计算机网络实体是什么 计算机网络中的关键设备&#xff0c;包括各类计算机、网络和通讯设备、存储数据的媒体、传输路线…等 典型的安全威胁有哪些 ★ ⋆ \bigstar\star ★⋆ 窃听(敏感信息被窃听)重传(被获取在传过来)伪造(伪造信息发送&#xff09;篡…

kubeadm 部署Kubernetes 集群一主多从集群并完成pod部署

目录 搭建环境准备三台虚拟机&#xff1a; 环境条件限制&#xff1a; 一&#xff0c;关闭交换分区 二&#xff0c;禁用selinux 三&#xff0c;防火墙关闭 四&#xff0c;docker安装 五&#xff0c;设置IPv4 流量传递到 iptables 六&#xff0c;配置k8s的yum源 七&#…

vue3学习三 computed

vue3中的computed 也被封装成了一个组合api , 所以我们使用的时候&#xff0c; 要 import {computed} from “vue” 和vue2样&#xff0c; computed 是有两种书写方式 简写方式和全写方式 <template><div> firstname:<input v-model"firstname" /&g…

【专题连载】基于5G+机器视觉的芯片检测解决方案

基于5G机器视觉的芯片检测解决方案 背景 机器视觉的价值体现在它能为工业生产带来产量的增加和产品质量的提升&#xff0c;并同时降低生产成本&#xff0c;推动了工业生产的快速发展&#xff0c;使工业生产企业真正从中受益。为了进一步压缩生产成本&#xff0c;工业控制的产…