深度学习归一化原理及代码实现(BatchNorm2d,LayerNorm,InstanceNorm,GroupNorm)

news2025/1/22 23:37:13

文章目录

    • 概述
    • 形式
    • 原理理解
    • 源代码实现
      • 1.BatchNorm2d
      • 2.LayerNorm
      • 3.InstanceNorm
      • 4.GroupNorm

概述

本文记录总结pytorch中四种归一化方式的原理以及实现方式。方便后续理解和使用。
本文原理理解参考自

https://zhuanlan.zhihu.com/p/395855181

形式

四种归一化的公式都是相同的,即
在这里插入图片描述
其实就是普通的归一化公式,

((x-均值)/标准差)*γ +β

γ和β是可学习参数,代表着对整体归一化值的缩放(scale) γ和偏移(shift) β。

四种不同形式的归一化归根结底还是归一化维度的不同。

形式原始维度均值/方差的维度
BatchNorm2dNCHW1C11
LayerNormNCHWN111
InstanceNormNCHWNC11
GroupNormNCHWNG11 (G=1,LN,G=C,IN)

原理理解

在这里插入图片描述

  1. BatchNorm2d 从维度上分析,就是在NHW维度上分别进行归一化,保留特征图的通道尺寸大小进行的归一化。
    由上图理解,蓝色位置代表一个归一化的值,BN层的目的就是将每个batch的hw都归一化,而保持通道数不变。抽象的理解就是结合不同batch的通道特征。因此这种方式比较适合用于分类,检测等模型,因为他需要对多个不同的图像有着相同的理解。
  2. LayerNorm 从维度上分析,就是在CHW上对对象的归一化,该归一化的目的可以保留每个batch的自有特征。抽象上来理解,就是通过layernorm让每个batch都有不同的值,有不同的特征,因此适用于图像生成或RNN之类的工作
  3. InstanceNorm从维度上来分析,就是将HW归一化为一个值,保留在通道上C和batch上的特征N。相当于对每个batch每个通道做了归一化。可以保留原始图像的信号而不混杂,因此常用于风格迁移等工作。
  4. GroupNorm从维度上来分析,近似于IN和LN,但是就是在通道上可以分成若干组(G),当G代表权重通道时就变成了LN,当G代表单通道就变成了IN,我也不清楚为什么用这个,但是G通常好像设置为32.

源代码实现

结合以上理解,就可以从原理上实现pytorch中封装的四个归一化函数。如下所示。

1.BatchNorm2d

import torch
import torch.nn as nn

class CustomBatchNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,scale=1,shift=0):
        super(CustomBatchNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        
        # 可训练参数
        self.scale = scale
        self.shift = shift
        
        # 不可训练的运行时统计信息
        self.running_mean = torch.zeros(num_features)
        self.running_var = torch.ones(num_features)
    
    def forward(self, x):
        # 计算输入张量的均值和方差
        mean = x.mean(dim=(0, 2, 3), keepdim=True)
        print("mean.shape",mean.shape)
        var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
        print("var.shape",var.shape)
        
        # 更新运行时统计信息 (Batch Normalization在训练和推理模式下的行为不同)
        if self.training:
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.squeeze()
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var.squeeze()
        
        # 归一化输入张量
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        
        # 应用 scale 和 shift 参数
        scaled_x = self.scale.view(1, -1, 1, 1) * x_normalized + self.shift.view(1, -1, 1, 1)
        
        return scaled_x
if __name__ =="__main__":
    # 创建示例输入张量
    x = torch.randn(16, 3, 32, 32)  # 示例输入数据

    scale = nn.Parameter(torch.randn(x.size(1)))
    shift = nn.Parameter(torch.randn(x.size(1)))

    # 创建自定义批量归一化层
    custom_batchnorm = CustomBatchNorm2d(num_features=3,scale=scale,shift=shift)

    # 调用自定义批量归一化层
    normalized_x_custom = custom_batchnorm(x)

    # 创建官方的批量归一化层
    official_batchnorm = nn.BatchNorm2d(num_features=3)
    official_batchnorm.weight=scale
    official_batchnorm.bias=shift

    # 调用官方批量归一化层
    normalized_x_official = official_batchnorm(x)

    # 检查自定义层和官方层的输出是否一致
    are_equal = torch.allclose(normalized_x_custom, normalized_x_official, atol=1e-5)
    print("自定义批量归一化和官方批量归一化是否一致:", are_equal)


2.LayerNorm

import torch
import torch.nn as nn

class CustomLayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5,scale=1,shift=0):
        super(CustomLayerNorm, self).__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps
        
        # 可训练参数
        # self.scale = nn.Parameter(torch.ones(normalized_shape))
        # self.shift = nn.Parameter(torch.zeros(normalized_shape))
        self.scale = scale
        self.shift = shift
    
    def forward(self, x):
        # 计算输入张量 x 的均值和方差
        mean = x.mean(dim=(1,2,3), keepdim=True)
        variance = x.var(dim=(1,2,3), unbiased=False, keepdim=True)
        
        # 归一化输入张量
        x_normalized = (x - mean) / torch.sqrt(variance + self.eps)
        
        # # 应用 scale 和 shift 参数
        
        scaled_x = self.scale * x_normalized + self.shift
        # 应用 scale 和 shift 参数
        #scaled_x = self.scale.view(-1, 1, 1, 1) * x_normalized + self.shift.view(-1, 1, 1, 1)
        
        return scaled_x

# 创建示例输入张量
x = torch.randn(16, 3, 32, 32)  # 示例输入数据

scale = nn.Parameter(torch.randn(3,32,32))
shift = nn.Parameter(torch.randn(3,32,32))

# 创建自定义 Layer Normalization 层
#custom_layernorm = CustomLayerNorm(normalized_shape=16)
custom_layernorm = CustomLayerNorm(normalized_shape=(3,32,32),scale=scale,shift=shift)

# 调用自定义 Layer Normalization 层
normalized_x_custom = custom_layernorm(x)

# 创建官方的 Layer Normalization 层
#official_layernorm = nn.LayerNorm(normalized_shape=3)
official_layernorm = nn.LayerNorm(normalized_shape=(3,32,32))
official_layernorm.weight=scale
official_layernorm.bias=shift
#official_layernorm = nn.LayerNorm(normalized_shape=(0,2,3))

# 调用官方 Layer Normalization 层
normalized_x_official = official_layernorm(x)
#print(normalized_x_official.shape)

# # 检查自定义层和官方层的输出是否一致
are_equal = torch.allclose(normalized_x_custom, normalized_x_official, atol=1e-5)
print("自定义 Layer Normalization 和官方 Layer Normalization 是否一致:", are_equal)

3.InstanceNorm

import torch
import torch.nn as nn

class CustomInstanceNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5,scale=1,shift=0):
        super(CustomInstanceNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        
        # 不可训练参数
        # self.scale = nn.Parameter(torch.ones(num_features))
        # self.shift = nn.Parameter(torch.zeros(num_features))
        self.scale = scale
        self.shift = shift
    
    def forward(self, x):
        # 计算输入张量 x 的均值和方差
        mean = x.mean(dim=(2, 3), keepdim=True)
        variance = x.var(dim=(2, 3), unbiased=False, keepdim=True)
        
        # 归一化输入张量
        x_normalized = (x - mean) / torch.sqrt(variance + self.eps)
        
        # 应用 scale 和 shift 参数
        scaled_x = self.scale.view(1, -1, 1, 1) * x_normalized + self.shift.view(1, -1, 1, 1)
        
        return scaled_x

# 创建示例输入张量
x = torch.randn(16, 3, 32, 32)  # 示例输入数据

# 创建自定义 Instance Normalization 层
scale = nn.Parameter(torch.randn(3))
shift = nn.Parameter(torch.randn(3))
custom_instancenorm = CustomInstanceNorm(num_features=3,scale=scale,shift=shift)

# 调用自定义 Instance Normalization 层
normalized_x_custom = custom_instancenorm(x)

# 创建官方的 Instance Normalization 层
official_instancenorm = nn.InstanceNorm2d(num_features=3)
official_instancenorm.weight=scale
official_instancenorm.bias=shift

# 调用官方 Instance Normalization 层
normalized_x_official = official_instancenorm(x)

# # 检查自定义层和官方层的输出是否一致
are_equal = torch.allclose(normalized_x_custom, normalized_x_official, atol=1e-5)
print("自定义 Layer Normalization 和官方 Layer Normalization 是否一致:", are_equal)



4.GroupNorm

import torch
import torch.nn as nn

class CustomGroupNorm(nn.Module):
    def __init__(self, num_groups, num_channels, eps=1e-5,scale=1,shift=0):
        super(CustomGroupNorm, self).__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        
        # 不可训练参数
        self.scale = scale
        self.shift = shift
    
    def forward(self, x):
        # 将输入张量 x 分成 num_groups 个组
        # 注意:这里假定 num_channels 可以被 num_groups 整除
        group_size = self.num_channels // self.num_groups
        x = x.view(-1, self.num_groups, group_size, x.size(2), x.size(3))
        
        # 计算每个组的均值和方差
        mean = x.mean(dim=(2, 3, 4), keepdim=True)
        variance = x.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
        
        # 归一化输入张量
        x_normalized = (x - mean) / torch.sqrt(variance + self.eps)
        
        # 将组合并并应用 scale 和 shift 参数
        x_normalized = x_normalized.view(-1, self.num_channels, x.size(3), x.size(4))
        scaled_x = self.scale.view(1, -1, 1, 1) * x_normalized + self.shift.view(1, -1, 1, 1)
        
        return scaled_x

# 创建示例输入张量
x = torch.randn(16, 6, 32, 32)  # 示例输入数据,有6个通道

# 创建自定义 Group Normalization 层
scale = nn.Parameter(torch.randn(6))
shift = nn.Parameter(torch.randn(6))
custom_groupnorm = CustomGroupNorm(num_groups=3, num_channels=6,scale=scale,shift=shift)

# 调用自定义 Group Normalization 层
normalized_x_custom = custom_groupnorm(x)

# 创建官方的 Group Normalization 层
official_groupnorm = nn.GroupNorm(num_groups=3, num_channels=6)
official_groupnorm.weight = scale
official_groupnorm.bias = shift

# 调用官方 Group Normalization 层
normalized_x_official = official_groupnorm(x)

# # 检查自定义层和官方层的输出是否一致
are_equal = torch.allclose(normalized_x_custom, normalized_x_official, atol=1e-5)
print("自定义 Layer Normalization 和官方 Layer Normalization 是否一致:", are_equal)

如果有用帮忙点个赞哦

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1020988.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

油猴Safari浏览器辅助插件:Tampermonkey for Mac中文版

油猴脚本Tampermonkey是一款油猴Safari浏览器辅助插件,是一款适用于Safari用户的脚本管理,能够方便管理不同的脚本。虽然有些受支持的浏览器拥有原生的用户脚本支持,但tampermonkey油猴插件将在您的用户脚本管理方面提供更多的便利&#xff0…

使用Python进行时间序列分析的8种图

时间序列数据 时间序列数据是按时间顺序以固定的时间间隔排列的观测值的集合。每个观察对应于特定的时间点,并且数据可以以各种频率(例如,每日、每月、每年)。这种类型的数据在许多领域都非常重要,包括金融&#xff0…

linux+c语言杂记(二)

一、在 Ubuntu 20.04 上安装 GCC 默认的 Ubuntu 软件源包含了一个软件包组,名称为 “build-essential”,它包含了 GNU 编辑器集合,GNU 调试器,和其他编译软件所必需的开发库和工具。 想要安装开发工具软件包,以 拥有 sudo 权限用…

新一代最强开源UI自动化测试神器Playwright(Java版)(对话框处理)

🎭Playwright让网页对话框🌐💬处理变得更加快捷!网页对话框是在网页上出现的常见弹窗,包括Alert、Confirm和Prompt等。这些对话框通常需要用户输入信息或进行某些选择,但是在自动化测试中处理它们可能会很棘…

双周赛113(枚举、分类讨论 + 二分查找、枚举值域两数之和、换根DP)

文章目录 双周赛113[2855. 使数组成为递增数组的最少右移次数](https://leetcode.cn/problems/minimum-right-shifts-to-sort-the-array/)暴力枚举贪心 O(n) [2856. 删除数对后的最小数组长度](https://leetcode.cn/problems/minimum-array-length-after-pair-removals/)分类讨…

[MAUI]实现动态拖拽排序网格

文章目录 创建页面元素创建可绑定对象创建绑定服务类拖拽(Drag)拖拽悬停,经过(DragOver)释放(Drop) 限流(Throttle)和防抖(Debounce)项目地址 上一章我们使用拖放(drag-drop)手势识别实现了可拖…

TCP/IP协议栈各层涉及到的协议

21/tcp FTP 文件传输协议 22/tcp SSH 安全登录、文件传送(SCP)和端口重定向 23/tcp Telnet 远程连接 80/tcp HTTP 443/tcp HTTPS 计算机各层网络协议 五层:应用层: (典型设备:应用程序,如FTP,SMTP ,HTTP) DHCP(Dynamic Host…

Pikachu Burte Force(暴力破解)

一、Burte Force(暴力破解)概述 ​ “暴力破解”是一攻击具手段,在web攻击中,一般会使用这种手段对应用系统的认证信息进行获取。 其过程就是使用大量的认证信息在认证接口进行尝试登录,直到得到正确的结果。 为了提高…

RFID与人工智能怎么融合,RFID与人工智能融合的应用

随着物联网技术的不断发展,现实世界与数字世界的桥梁已经被打通。物联网通过各种传感器,将现实世界中的光、电、热等信号转化为有价值的数据。这些数据可以通过RFID技术进行自动收集和传输,然后经由人工智能算法进行分析、建模和预测&#xf…

uniapp cli创建 vue3 + typeScript项目 配置eslint prettier husky

1 命令创建项目 npx degit dcloudio/uni-preset-vue#vite-ts my-vue3-project2 下载依赖 npm install3 填写appid 4 运行项目并且微信开发工具打开 npm run dev:mp-weixin5 安装 vscode 插件 安装 **Vue Language Features (Volar)** :Vue3 语法提示插件 安装 *…

伦敦银一手是多少?

伦敦银是以国际现货白银价格为跟踪对象的电子合约交易,无论投资者通过什么地方的平台进入市场,执行的都是统一国际的标准,一手标准的合约所代表的就是5000盎司的白银,如果以国内投资者比较熟悉的单位计算,那约相当于15…

http客户端Feign使用

一、RestTemplate方式调用存在的问题 先来看我们以前利用RestTemplate发起远程调用的代码: String url "http://userservice/user/" order.getUserId(); User user restTemplate.getForObject(url, User.class);存在下面的问题: 代码可读…

Mosh Java课程自学(一)

目录 一、前言 二、全局介绍 三、Types 一、前言 首先推荐一下B站上转载的Mosh讲Java课程,当然,建议有一定基础并且英文水平尚可的同学学习,否则你可能会被搞得很累并逐渐失去对编程的兴趣。 Mosh 【JAVA终极教程】中英文字幕 高清完整版…

口袋参谋:如何高效一键下载真实买家秀?

​在淘宝天猫上,即使卖一支笔都有上万个宝贝竞争,所有卖家拼的就是权重带来的曝光度,能展示给买家多少,自己收获多少流量。 如何用自己的优势将流量访客转化为顾客,提升店铺的转化率。而买家秀,就是为此而生…

Java常用类之 String、StringBuffer、StringBuilder

Java常用类 文章目录 一、字符串相关的类1.1、String的 不可变性1.2、String不同实例化方式的对比1.3、String不同拼接操作的对比1.4、String的常用方法1.5、String类与其他结构之间的转换1.5.1、String 与基本数据类型、包装类之间的转换1.5.2、String 与char[]的转换1.5.3、…

ipad可以使用其他品牌的手写笔吗?开学平价电容笔推荐

新学期已经来临,相信不少同学已经开始着手筹备新学期的该准备什么了,毕竟原装的苹果Pencil,功能强大,但价格昂贵,一般人根本买不起。那么,有没有像苹果原装那样的电容笔呢?当然是有的。国产的平…

长安链上线可视化敏捷测试工具v1.0版本

开发者对区块链底层平台进行初步的了解后,一项经常会涉及到的工作是对平台进行测试以考量其性能及稳定性是否符合自身使用需求。长安链推出了可视化UI操作界面的区块链敏捷测试工具v1.0版本,当前版本可对内置合约进行压测并生成网络拓扑图以验证组网方式…

免费开箱即用微鳄售后工单管理系统

编者按:本文介绍基于天翎MyApps低代码平台开发的微鳄售后工单管理系统, 引入低代码平台可以帮助企业快速搭建和部署售后工单管理系统, 以工作流作为支撑,在线完成各环节数据审批,解决售后 工单 服务的全生命周期过程管…

《2023中国氢能源行业分析报告》丨附下载_三叠云

✦ ✦✦ ✦✦ ✦✦ ✦ 1. 国内氢能政策梳理 直接涉及氢能政策:1)21年以来,发布国家级10个、省级83个、 市县级252个;2)涉及发展规划占比45%、财政支持占比 20%、项目支持占比17%、管理办法占比16%、 氢能安全和标准占…

公私钥非对称加密 生成和验证JSON Web Token (JWT)

前言 这是我在这个网站整理的笔记,关注我,接下来还会持续更新。 作者:神的孩子都在歌唱 公私钥非对称加密 生成和验证JSON Web Token 什么是JSON Web Token (JWT)Java程序中生成和验证JWT代码解析 什么是JSON Web Token (JWT) JSON Web Tok…