深度学习基础知识 BatchNorm、LayerNorm、GroupNorm的用法解析

news2024/9/20 7:54:29

深度学习基础知识 BatchNorm、LayerNorm、GroupNorm的用法解析

  • 1、BatchNorm
  • 2、LayerNorm
  • 3、GroupNorm
    • 用法:

BatchNorm、LayerNorm 和 GroupNorm 都是深度学习中常用的归一化方式。
它们通过将输入归一化到均值为 0 和方差为 1 的分布中,来防止梯度消失和爆炸,并提高模型的泛化能力

1、BatchNorm

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

import numpy as np
import torch.nn as nn
import torch
 
def bn_process(feature, mean, var):
    feature_shape = feature.shape
    for i in range(feature_shape[1]):
        # [batch, channel, height, width]
        feature_t = feature[:, i, :, :] # 得到每一个channel的height和width
        mean_t = feature_t.mean()
        # 总体标准差
        std_t1 = feature_t.std()
        # 样本标准差
        std_t2 = feature_t.std(ddof=1)
 
        # bn process
        # 这里记得加上eps和pytorch保持一致
        feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / np.sqrt(std_t1 ** 2 + 1e-5)
        # update calculating mean and var
        mean[i] = mean[i] * 0.9 + mean_t * 0.1
        var[i] = var[i] * 0.9 + (std_t2 ** 2) * 0.1
    print(feature)
 
 
# 随机生成一个batch为2,channel为2,height=width=2的特征向量
# [batch, channel, height, width]
feature1 = torch.randn(2, 2, 2, 2)
# 初始化统计均值和方差
calculate_mean = [0.0, 0.0]
calculate_var = [1.0, 1.0]
# print(feature1.numpy())
 
# 注意要使用copy()深拷贝
bn_process(feature1.numpy().copy(), calculate_mean, calculate_var)
 
bn = nn.BatchNorm2d(2, eps=1e-5)
output = bn(feature1)
print(output)

显示结果如下:
在这里插入图片描述

在这里插入图片描述

代码:

import torch
import torch.nn as nn
import numpy as np

featuer_array=(np.random.rand(2,4,2,2)).astype(np.float32)
print(featuer_array.dtype)

featuer_tensor=torch.tensor(featuer_array,dtype=torch.float32)
bn_out=nn.BatchNorm2d( num_features=featuer_array.shape[1],eps=1e-5)(featuer_tensor)
print(bn_out)

print("-----")

for i in range(featuer_array.shape[1]):
    channel=featuer_array[:,i,:,:]
    mean=channel.mean()
    var=channel.var()
    print(f"mean---{mean},var---{var}")

    featuer_array[:,i,:,:]=(channel-mean) / np.sqrt(var + 1e-5)
print(featuer_array)

打印结果:
在这里插入图片描述

2、LayerNorm

Transformer block 中会使用到 LayerNorm , 一般输入尺寸形为 :(batch_size, token_num, dim),会在最后一个维度做 归一化,其中dim维度为token的特征向量: nn.LayerNorm(dim)

在这里插入图片描述

import torch
import torch.nn as nn
import numpy as np


feature_array=(np.random.rand(2,3,2,2).astype(np.float32))

# 需要将其转化为[batch,token_num,dim]的形式
feature_array=feature_array.reshape((2,3,-1)).transpose(0,2,1)
print(feature_array.shape)   # (2, 4, 3)

feature_tensor=torch.tensor(feature_array.copy(),dtype=torch.float32)

layer_norm=nn.LayerNorm(normalized_shape=feature_array.shape[2])(feature_tensor)
print(layer_norm)

print("\n","*"*50,"\n")
batch,token_num,dim=feature_array.shape

feature_array=feature_array.reshape((-1,dim))
for i in range(batch * token_num):
    mean=feature_array[i,:].mean()
    var=feature_array[i,:].var()
    print(f"mean----{mean},var----{var}")

    feature_array[i,:]=(feature_array[i,:]-mean) / np.sqrt(var + 1e-5)
print(feature_array.reshape(batch,token_num,dim))

打印效果如下所示:
在这里插入图片描述

3、GroupNorm

在这里插入图片描述

用法:

torch.nn.GroupNorm:将channel切分成许多组进行归一化
torch.nn.GroupNorm(num_groups,num_channels)
num_groups:组数
num_channels:通道数量
在这里插入图片描述
代码:

import torch
import torch.nn as nn
import numpy as np

feature_array=(np.random.rand(2,4,2,2)).astype(np.float32)
print(feature_array.dtype)

feature_tensor=torch.tensor(feature_array.copy(),dtype=torch.float32)
group_result=nn.GroupNorm(num_groups=2,num_channels=feature_array.shape[1])(feature_tensor)
print(group_result)

feature_array = feature_array.reshape((2, 2, 2, 2, 2)).reshape((4, 2, 2, 2))

for i in range(feature_array.shape[0]):
    channel = feature_array[i, :, :, :]
    mean = feature_array[i, :, :, :].mean()
    var = feature_array[i, :, :, :].var()
    print(mean)
    print(var)


    feature_array[i, :, :, :] = (feature_array[i, :, :, :] - mean) / np.sqrt(var + 1e-5)
feature_array = feature_array.reshape((2, 2, 2, 2, 2)).reshape((2, 4, 2, 2))
print(feature_array)

打印结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1080978.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

学生用的台灯哪种比较好?分享专家推荐的学生台灯

对于学生来说,台灯是必不可少的一盏学习照明灯具,它能提供室内不足的光线、亮度,基本每个学生在宿舍、家里都会备着一台。不过台灯也并不是随便挑选一台使用就可以的,很多学生就是因为使用了一些价格低廉、质量安全没有保障的台灯…

【大数据】hadoop安装部署(学习笔记)

一、集群组成概述 Hadoop集群包括两个集群:HDFS集群、YARN集群 两个集群逻辑上分离、通常物理上在一起 两个集群都是标准的主从架构集群 HDFS集群(分布式存储): 主角色:NameNode从角色:DataNode主角色…

代码随想录算法训练营第五十八天 |583. 两个字符串的删除操作、72. 编辑距离、编辑距离总结篇

一、583. 两个字符串的删除操作 题目链接/文章讲解/视频讲解:代码随想录 思考: 1.确定dp数组(dp table)以及下标的含义 dp[i][j]:以i-1为结尾的字符串word1,和以j-1位结尾的字符串word2,想要达…

postman 密码rsa加密登录-2加密密码

上一篇讲了获取公钥,将环境准备好之后,在登录接口的Pre-request Scrip 里,使用公钥进行加密后在正常登录。本文采用的方案是使用第三方模块forge.js来实现加密。 1、环境准备好,系统git 和node都OK。下载forge.js git clone htt…

Java-Atomic原子操作类详解及源码分析,Java原子操作类进阶,LongAdder源码分析

文章目录 一、Java原子操作类概述1、什么是原子操作类2、为什么要用原子操作类3、CAS入门 二、基本类型原子类1、概述2、代码实例 三、数组类型原子类1、概述2、代码实例 四、引用类型原子类1、概述2、AtomicReference3、ABA问题与AtomicStampedReference4、一次性修改&#xf…

SpringBoot (1)

目录 1 入门案例 1.1 环境准备 1.2 编写pom.xml 1.3 编写入口程序 1.4 编写接口 1.5 编写配置 1.6 快速部署 1.6.1 打jar包 1.6.2 部署 1.7 访问接口 2 全注解开发 2.1 常用注解 2.2 属性绑定注解 2.2.1 注册组件 2.2.2 ConfigurationProperties(prefix"te…

AlphaPose Pytorch 代码详解(一):predict

前言 代码地址:AlphaPose-Pytorch版 本文以图像 1.jpg(854x480)为例对整个预测过程的各个细节进行解读并记录 python demo.py --indir examples/demo --outdir examples/res --save_img1. YOLO 1.1 图像预处理 cv2读取BGR图像 img [480,…

哈希的应用--位图和布隆过滤器

哈希的应用--位图和布隆过滤器 位图1. 位图概念2. 位图在实际中的应用3. 位图相似应用给定100亿个整数,如何找到只出现一次的整数?1个文件100亿int,1G内存,如何找到不超过2次的所有整数 布隆过滤器1. 布隆过滤器的提出2. 布隆过滤…

HarmonyOS学习 -- ArkTS开发语言入门

文章目录 一、编程语言介绍二、TypeScript基础类型1. 布尔值2. 数字3. 字符串4. 数组5. 元组6. 枚举7. unknown8. void9. null 和 undefined10. 联合类型 三、TypeScript基础知识条件语句if语句switch语句 函数定义有名函数和匿名函数可选参数剩余参数箭头函数 类1. 类的定义2.…

华为认证 | 这门HCIA认证正式发布!

华为认证云计算工程师HCIA-Cloud Computing V5.5(中文版)自2023年9月28日起,正式在中国区发布。 01 发布概述 基于“平台生态”战略,围绕“云-管-端”协同的新ICT技术架构,华为公司打造了覆盖ICT领域的认证体系&#…

机器人制作开源方案 | 齿轮传动轴偏心轮摇杆简易四足

1. 功能描述 齿轮传动轴偏心轮摇杆简易四足机器人是一种基于齿轮传动和偏心轮摇杆原理的简易四足机器人。它的设计原理通常如下: ① 齿轮传动:通过不同大小的齿轮传动,实现机器人四条腿的运动。通常采用轮式齿轮传动或者行星齿轮传动&#xf…

git多分支、git远程仓库、ssh方式连接远程仓库、协同开发(避免冲突)、解决协同冲突(多人在同一分支开发、 合并分支)

1 git多分支 2 git远程仓库 2.1 普通开发者,使用流程 3 ssh方式连接远程仓库 4 协同开发 4.1 避免冲突 4.2 协同开发 5 解决协同冲突 5.1 多人在同一分支开发 5.2 合并分支 1 git多分支 ## 命令操作分支-1 创建分支git branch dev-2 查看分支git branch-3 分…

bash一行输入,多行回显demo脚本

效果图: 脚本: #!/bin/bash # 定义一个变量,用来存储输入的内容 input"" # 定义一个变量,用来存储输入的字符 char""# 为了让read能读到空格键 IFS_store$IFS IFS# 提示内容,在while循环中也有&a…

SMOS数据处理,投影变换,‘EPSG:6933‘转为‘EPSG:4326‘

在处理SMOS数据时,遇到了读取nc数据并存为tif后,影像投影无法改变,因此全球数据无法重叠。源数据的投影为EPSG:6933,希望转为EPSG:4326。 解决代码。 python import os import netCDF4 as nc import numpy as np from osgeo impo…

阿里云ModelScope 是一个“模型即服务”(MaaS)平台

简介 项目地址:https://github.com/modelscope/modelscope/tree/master ModelScope 是一个“模型即服务”(MaaS)平台,旨在汇集来自AI社区的最先进的机器学习模型,并简化在实际应用中使用AI模型的流程。ModelScope库使开发人员能够通过丰富的…

sap 一次性供应商 供应商账户组 临时供应商 <转载>

原文链接:https://blog.csdn.net/xianshengsun/article/details/132620593 sap中有一次性供应商这个名词,一次性供应商和非一次性供应商又有什么区别呢? 有如何区分一次性供应商和非一次性供应商呢? 1 区分一次性供应商和非一次性…

狄拉克函数及其性质

狄拉克函数及其性质 狄拉克函数 近似处理 逼近近似 积分近似 狄拉克函数的性质 狄拉克函数的Hermite展开

【C++】【自用】STL六大组件:算法

文章目录 🔺sortstable_sort🔺reverse🔺swap🔺find🔺max/min🔺next_permutation/prev_permutation 全排列binary_searchlower_bound/upper_bound 求下界和上界set_union/set_intersection/set_difference 求…

结构体课程自我理解

目录 1. 结构体类型的声明 1.2特殊的声明方法 1.3结构体的自引用 1.4 typedef 对结构体命名 2. 结构体变量的创建和初始化 3. 结构成员访问操作符 4. 结构体内存对⻬ 4.1下面给大伙4个练习题,有自我解析 4.2为什么存在内存对齐 4.3修改默认对齐数 5. …

Centos中利用自带的定时器Crontab_实现mysql数据库自动备份_linux中mysql自动备份脚本---Linux运维工作笔记056

这个经常需要,怕出问题因而需要经常备份数据库,可以利用centos自带的定时器,配合脚本实现自动备份. 1.首先查看一下,这个crontab服务有没有打开: 执行:ntsysv 可以看到已经开机自启动了. 注意这个操作界面,用鼠标不行,需要用,tab按键,直接tab到确定,或取消,然后按回车回到命…