BatchNorm1d的复现以及对参数num_features的理解

news2024/11/17 1:38:29

0. Intro

  1. 以pytorch为例,BatchNorm1d的参数num_features涉及了对什么数据进行处理,但是我总是记不住,写个blog帮助自己理解QAQ

1. 复现nn.BatchNorm1d(num_features=1)

  1. 假设有一个input tensor:
input = torch.tensor([[[1.,2.,3.,4.]],[[0.,0.,0.,0.]]])
print(input.shape)
# torch.Size([2, 1, 4])
  1. nn.BatchNorm1d(num_features=1)函数介绍
  • 这个函数长这个样子:
    torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
  • 使用起来是这样的:
BN1 = nn.BatchNorm1d(num_features=1,affine=False,eps=0)   
# input只有1个feature(只有1个channel),每个features的长度=4,第一个batch
print("---BN1---")
print(torch.squeeze(BN1(input)))
  • 注意1:函数参数eps=0是为了让下图这个batchnorm的公式的这个等于0(起保护作用),eps默认为1e-5
    在这里插入图片描述
  • 注意2:上式里的 γ \gamma γ β \beta β分别默认值是1和0,因此只要设置affine=False就可以使用了,注意affine默认为True
  • input shape符合BatchNorm1d要求的[B,C,L]的格式,这里num_features=1C对应
    上面函数的输出为:
---BN1---
tensor([[-0.1690,  0.5071,  1.1832,  1.8593],
        [-0.8452, -0.8452, -0.8452, -0.8452]])
  1. nn.BatchNorm1d(num_features=1)复现结果:
ans = (input-torch.mean(torch.flatten(input)))/torch.sqrt(torch.var(torch.flatten(input),unbiased=False))
print(torch.squeeze(ans))
  • 注意1:torch.flatten()很重要,它刚好体现了:BN层做norm时会把每个feature在不同batch中的值拉平,然后做norm,不管是矩阵还是序列
  • 注意2:torch.var的参数unbiased=False表示求方差时分母是n,也就是不需要求无偏的方差
    它的输出为:
tensor([[-0.1690,  0.5071,  1.1832,  1.8593],
        [-0.8452, -0.8452, -0.8452, -0.8452]])
  • 一模一样

2. 复现nn.BatchNorm1d(num_features=4)

  1. 依然假设有一个input tensor,和上面一样,复制过来
input = torch.tensor([[[1.,2.,3.,4.]],[[0.,0.,0.,0.]]])
print(input.shape)
# torch.Size([2, 1, 4])
  1. nn.BatchNorm1d(num_features=4) 函数介绍
  • 首先这个函数使用起来是这样的:
BN2 = nn.BatchNorm1d(num_features=4,affine=False,eps=0)
print("---BN2---")
print(BN2(torch.squeeze(input)))
  • 注意点1:torch.squeeze是必须的,使用之后tensor的shape会从torch.Size([2, 1, 4])变为torch.Size([2, 4]),符合BatchNorm1d要求的[B,C]的格式,这里num_features=4C对应
  • 上面的函数输出为
---BN2---
tensor([[ 1.,  1.,  1.,  1.],
        [-1., -1., -1., -1.]])
  1. 复现
  • 重点来了,我们理解一下num_features=4,对于现在的input data(经过squeeze之后shape为[B,C] = [2,4]),input data的每个feature现在是一个single value值(不是序列或者矩阵),因此这里可以对某个feature手动计算一下:

    • 以最后一个feature为例:[4,0],可以计算得mean=2,sqrt(var)=2,因此([4,0]-mean)/sqrt(var)=[1,-1]
    • 同理可以计算其他3个feature
  • 一模一样


上面的代码

input = torch.tensor([[[1.,2.,3.,4.]],[[0.,0.,0.,0.]]])
print(input.shape)

BN1 = nn.BatchNorm1d(num_features=1,affine=False,eps=0)   # 每个features的长度=4,第一个batch
print("---BN1---")
print(torch.squeeze(BN1(input)))
print("---BN1 Repeat---")
ans = (input-torch.mean(torch.flatten(input)))/torch.sqrt(torch.var(torch.flatten(input),unbiased=False) )
print(torch.squeeze(ans))

BN2 = nn.BatchNorm1d(num_features=4,affine=False,eps=0)
print("---BN2---")
print(BN2(torch.squeeze(input)))
# BN2就手动算一下啦

3. 对于BatchNorm2d是类似的

  1. 注意点其实只有2点
    • 找准feature是什么
    • BN层做norm时会把每个feature在不同batch中的值拉平,然后做norm,不管是矩阵还是序列

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/399144.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Plsql使用

登录登录system用户,初始有两个用户sys和system,密码是自己安装oracle数据库时写的,数据库选择orcl创建用户点击user,右键新增填写权限关于3个基本去权限介绍: connect : 基本操作表的权限,比如增删改查、视图创建等 r…

Netty channelHandler注意事項——super.channelRead(ctx, msg)

通过nioSocketChannel.pipeline()的addLast添加入站处理器,如果有多个必须显示的唤醒下一个入站处理器,否则执行链中间会断掉。 protected void initChannel(NioSocketChannel nioSocketChannel) throws Exception {log.debug(nioSocketChannel.toStrin…

前端优化,webpack打包删除无用文件,并附上批量删除文件脚本!非常好用

前言 大家可能在webpack打包项目过程中,常遇见一些无用的图片,js文件,怎样能够自动检测哪些是无用的文件呢?本文中介绍使用插件useless-files-webpack-plugin查找无用文件,在terminal中删除,附加bat批量删…

Ngnix安装教程(2023.3.8)

Nginx安装教程(2023.3.8)引言1、Nginx简介2、Nginx安装2.1 下载Nginx安装包2.2 免安装启动Nginx(切记解压后将nginx-1.23.3文件夹需要放在英文路径下,实测中文路径不识别且启动不成功)2.3 熟悉Nginx文件夹目录结构2.4 …

平安银行LAMBDA实验室负责人崔孝林:提早拿到下一个计算时代入场券

量子前哨重磅推出独家专题《“量子”百人科学家》,我们将遍访全球探索赋能“量子”场景应用的百位优秀科学专家,从商业视角了解当下各行业领域的“量子”最新研究成果,多角度、多维度、多层面讲述该领域的探索历程,为读者解析商业…

Python - Pandas - 数据分析(2)

Pandas数据分析2前言常用的21种统计方法describe():numeric_only:偏度skewness:功能:含义:计算公式:演示:峰度值:用途:数值:计算公式:演示&#x…

[Java·算法·中等]LeetCode34. 在排序数组中查找元素的第一个和最后一个位置

每天一题,防止痴呆题目示例分析思路1题解1👉️ 力扣原文 题目 给你一个按照非递减顺序排列的整数数组 nums,和一个目标值 target。请你找出给定目标值在数组中的开始位置和结束位置。 如果数组中不存在目标值 target,返回 [-1,…

Windows 安装 MongoDB 并内网穿透远程连接

本文目录1.前言2.MongoDB数据库的安装2.1 MongoDB下载安装2.2 MongoDB连接测试2.3 cpolar下载安装3.Cpolar端口设置3.1 Cpolar云端设置3.2.Cpolar本地设置4.公网访问测试5.结语1.前言 现代电子技术日新月异,并且快速应用到我们的生活中,与之相应的&…

SAP BTEs的简介及实现

一、认识BTE BTE(Business Transaction Event)也称之为“业务交易事件”,一般的增强(Tcode:SMOD|CMOD)依旧使用ABAP进行二次开发,然而BTE则提供了RFC调用其它产品的可能(Tcode:FIBF)。BTE的设计思路更加简单,和BADI有点类似。在标准程序中留有…

ssm框架之spring:浅聊IOC

IOC 前面体验了spring,不过其运用了IOC,至于IOC( Inverse Of Controll—控制反转 ) 看一下百度百科解释: 控制反转(Inversion of Control,缩写为IoC),是面向对象编程中的一种设计原则&#x…

训练自己的GPT2-Chinese模型

文章目录效果抢先看准备工作环境搭建创建虚拟环境训练&预测项目结构模型预测续写训练模型遇到的问题及解决办法显存不足生成的内容一样文末效果抢先看 准备工作 从GitHub上拉去项目到本地,准备已训练好的模型百度网盘:提取码【9dvu】。 gpt2对联训…

又一个开源第一!飞桨联合百舸,Stable Diffusion推理速度遥遥领先

AIGC(AI Generated Content),即通过人工智能方法生成内容,是当前深度学习最热门的方向之一。其在绘画、写作等场景的应用也一直层出不穷,其中,AI绘画是大家关注和体验较多的方向。 Diffusion系列文生图模型可以实现AI绘画应用&…

八股总结(一)C++语言特性、基础语法、类与模板、内存管理、拷贝控制、STL及C++11新特性

layout: post title: 八股总结(一)C语言特性、基础语法、类与模板、内存管理、拷贝控制、STL及C11新特性 description: 八股总结(一)C语言特性、基础语法、类与模板、内存管理、拷贝控制、STL及C11新特性 tag: 八股总结 总结的大部…

使用python求PLS-DA的方差贡献率

以鸢尾花数据集为例,实现PLS-DA降维,画出降维后数据的散点图并求其方差贡献率。 效果图 完整代码 # 导入所需库 import numpy as np from sklearn.cross_decomposition import PLSRegression from sklearn.datasets import load_iris from sklearn.pre…

synchronized原理mointor

Monitor对象头 在java中普通对象的对象头信息 Mark Word记录分代年龄、加锁的状态;Klass Word指向类对象的指针; 其中Mark Word结构 monitor执行原理 我们在加了重量级锁synchronize后,对象头的mark word会指向一个monitor,mon…

pandas库中的read_csv函数读取数据时候的路径问题详解(ValueError: embedded null character)

read_csv()函数不仅是R语言中的一个读取csv文件的函数,也是pandas库中的一个函数。pandas是一个用于数据分析和处理的python库。它的read_csv函数可以读取csv文件里的数据,并将其转化为pandas里面的DataFrame对象。它由很多参数可以设置,例如…

Express的详细教程

Express 文章目录Express初识ExpressExpress简介Express的基本使用安装创建基本的web服务器监听GET请求监听POST请求把内容响应给客户端获取URL中携带的查询参数获取URL中的动态参数托管静态资源express.static()托管多个静态资源挂载路径前缀nodemon为什么要使用nodemon安装no…

【专项训练】动态规划-1

动态规划 以上,并没有什么本质的不一样,很多时候,就是一些小的细节问题! 要循环,要递归,就是有重复性! 动态规划:动态递推 分治 + 最优子结构 会定义状态,把状态定义对 斐波那契数列 递归、记忆化搜索,比较符合人脑思维 递推:直接开始写for循环,开始递推 这里…

mysql无法启动服务及其他问题总结

文章目录1.安装后关于配置的问题显示【发生系统错误,拒绝访问】命令行Command Line Client闪退2.显示【MySQL服务无法启动】问题检查端口被占用删除data文件并初始化配置my.ini/.conf文件重新安装MySQL1.安装后关于配置的问题 显示【发生系统错误,拒绝访…

Apache Dubbo 存在反序列化漏洞(CVE-2023-23638)

漏洞描述 Apache Dubbo 是一款轻量级 Java RPC 框架 该项目受影响版本存在反序列化漏洞,由于Dubbo在序列化时检查不够全面,当攻击者可访问到dubbo服务时,可通过构造恶意请求绕过检查触发反序列化,执行恶意代码 漏洞名称Apache …