pytorch标准化与模型训练推理以及中间层注意点

news2025/1/10 20:58:12
1.图像归一化和通道转换操作
a = np.arange(3*3*3).reshape(3,3,3).astype(np.uint8)
print(a)
function = transforms.ToTensor()#注意只能转换3维度的ndarray或者PIL的Image类型
c = function(a)
print(c)

'''
D:\anaconda3\python.exe E:\test\pythonProject\test.py 
[[[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]]

 [[ 9 10 11]
  [12 13 14]
  [15 16 17]]

 [[18 19 20]
  [21 22 23]
  [24 25 26]]]
tensor([[[0.0000, 0.0118, 0.0235],
         [0.0353, 0.0471, 0.0588],
         [0.0706, 0.0824, 0.0941]],

        [[0.0039, 0.0157, 0.0275],
         [0.0392, 0.0510, 0.0627],
         [0.0745, 0.0863, 0.0980]],

        [[0.0078, 0.0196, 0.0314],
         [0.0431, 0.0549, 0.0667],
         [0.0784, 0.0902, 0.1020]]])

进程已结束,退出代码为 0
'''

        注意: transforms.ToTensor()#注意只能转换3维度的ndarray或者PIL的Image类型的数据,源数据类型必须为uint8。

        transforms.ToTensor()的作用是把RGB拆分为三通道R,G,B,然后再把每个数除以 255 归一化。

2.Normalize

a = np.arange(3*3*3).reshape(3,3,3).astype(np.float32)
b = torch.from_numpy(a)
trans1_5 = transforms.Normalize(mean=1,std=0.5)
trans3_3 = transforms.Normalize(mean=(1,2,3),std=(1,2,1))
print(a)
c = trans1_5(b)
d = trans3_3(b)
print(c)
print(d)


''' 
D:\anaconda3\python.exe E:\test\pythonProject\test.py 
[[[ 0.  1.  2.]
  [ 3.  4.  5.]
  [ 6.  7.  8.]]

 [[ 9. 10. 11.]
  [12. 13. 14.]
  [15. 16. 17.]]

 [[18. 19. 20.]
  [21. 22. 23.]
  [24. 25. 26.]]]
tensor([[[-2.,  0.,  2.],
         [ 4.,  6.,  8.],
         [10., 12., 14.]],

        [[16., 18., 20.],
         [22., 24., 26.],
         [28., 30., 32.]],

        [[34., 36., 38.],
         [40., 42., 44.],
         [46., 48., 50.]]])
tensor([[[-1.0000,  0.0000,  1.0000],
         [ 2.0000,  3.0000,  4.0000],
         [ 5.0000,  6.0000,  7.0000]],

        [[ 3.5000,  4.0000,  4.5000],
         [ 5.0000,  5.5000,  6.0000],
         [ 6.5000,  7.0000,  7.5000]],

        [[15.0000, 16.0000, 17.0000],
         [18.0000, 19.0000, 20.0000],
         [21.0000, 22.0000, 23.0000]]])

进程已结束,退出代码为 0

'''

         注意:transforms.Normalize() 转换的参数必须是float类型的tensor,当数据为3通道的时候输入的 MEAN和STD参数为列表或者三个数的元组时,是每个通道分别做上面公式的运算,从以上代码,以及打印数据可以看出来。

3.批标准化层

        transforms.Normalize()  transforms.Totensor()函数基本上都是在图像预处理阶段使用,他们都集成在 torchvision 这个库里面,而批标准化层一般在多层网络结构中都会使用它集成在 nn 模块里面,批标准化主要是为了解决梯度消失,加速收敛速度及稳定性的算法,经过批标准化,模型的形状是不变的,BatchNorm后是不改变输入的shape的;

        nn.BatchNorm1d: N * d --> N * d

        nn.BatchNorm2d: N * C * H * W  -- > N * C * H * W

        nn.BatchNorm3d: N * C * d * H * W --> N * C * d * H * W

我们常用的是 BatchNorm1d 和 BatchNorm2d,对于BatchNorm其主要参数为:

nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True,track_running_stats=True, device=None, dtype=None)

        主要参数介绍:

                num_features: 输入维度,也就是数据的特征维度;

                eps: 是在分母上加的一个值,是为了防止分母为0的情况,让其能正常计算;

                affine: 是仿射变化,将分别初始化为1和0;

BatchNorm1d具体计算方式:

方差类型,E(X)是母体方差

        torch.var()函数的参数 unbiased 为 True 则执行样本方差,unbiased 为 False 则执行母体方差(书本上教学的方差,正常理解的方差)。torch.std(),为计算标准差。

(1)BatchNorm1d计算方式


t = torch.arange(6,dtype=torch.float32).view(2,3)
print("t",t)
print('----------------------------------')
me = t.mean(0)
print("t.mean",me)
print('----------------------------------')
print("t.var",t.var(0,False))#执行正常理解的方差
print('----------------------------------')
sq = torch.sqrt(t.var(0,False)+1e-5)
print("t.sqrt",sq)
print('----------------------------------')
x = (t-me)
print("t-me",x )
print('----------------------------------')
y = x/sq
print("x/sq",y )
print('----------------------------------')
bn1d = torch.nn.BatchNorm1d(3,1e-5,0.1,True,
                            True,None,None)
print("weight:",bn1d.weight,"\r\n","bias:",bn1d.bias)
print('----------------------------------')
bnout = bn1d(t)
print("bnout:",bnout)
print('----------------------------------')

'''
D:\anaconda3\python.exe E:\test\pythonProject\test.py 
t tensor([[0., 1., 2.],
        [3., 4., 5.]])
----------------------------------
t.mean tensor([1.5000, 2.5000, 3.5000])
----------------------------------
t.var tensor([2.2500, 2.2500, 2.2500])
----------------------------------
t.sqrt tensor([1.5000, 1.5000, 1.5000])
----------------------------------
t-me tensor([[-1.5000, -1.5000, -1.5000],
        [ 1.5000,  1.5000,  1.5000]])
----------------------------------
x/sq tensor([[-1.0000, -1.0000, -1.0000],
        [ 1.0000,  1.0000,  1.0000]])
----------------------------------
weight: Parameter containing:
tensor([1., 1., 1.], requires_grad=True) 
 bias: Parameter containing:
tensor([0., 0., 0.], requires_grad=True)
----------------------------------
bnout: tensor([[-1.0000, -1.0000, -1.0000],
        [ 1.0000,  1.0000,  1.0000]], grad_fn=<NativeBatchNormBackward0>)
----------------------------------

进程已结束,退出代码为 0
'''

 注意: BatchNorm1d在模型训练模式的时候至少需要提供两组参数,在模型预测模式可以提供一组参数,在模型训练模式的时候至少需要提供两组参数组成2维的参数,因为训练的时候要根据两组参数来计算。BatchNorm1d平均值mean是按照列求的,目的是求这个批次的均值。 

        官方说明翻译:

将批量归一化应用于4D输入(具有额外通道维度的小批量2D输入),如论文《批量归一化:通过减少内部协变量偏移加速深度网络训练》中所述。
y = (x − E[x])/(√(Var[x] + ϵ))*γ + β
平均值和标准偏差是在小批量的每个维度上计算的,γ和β是大小为C的可学习参数向量(其中C是输入大小)。默认情况下,γ的元素设置为1,而β的元素则设置为0。在前向通道的训练时间,通过有偏估计器计算标准偏差,等效于torc.var(输入,无偏=False)。然而,存储在标准偏差移动平均值中的值是通过无偏估计器计算的,相当于torc.var(输入,无偏=True)。
同样在默认情况下,在训练过程中,该层保持对其计算的平均值和方差的估计,然后在评估过程中用于归一化。运行估计值保持0.1的默认动量。
如果track_running_stats设置为False,则该层不保持运行估计,而是在评估期间使用批统计信息。
笔记
这个动量论点不同于优化器类中使用的动量论点和动量的传统概念。从数学上讲,这里运行统计数据的更新规则是新的 = 1. −  推进力 × x + 推进力 × xt,其中x是估计的统计量,xt是新的观测值。
由于批处理规范化是在C维度上完成的,计算(N,H,W)切片的统计信息,因此称之为“空间批处理规范”是一个常见的术语。

(1)BatchNorm2d计算方式

t = torch.arange(2*3*3*3,dtype=torch.float32).view(2,3,3,3)#生成tensor数据
print("t",t)
bn1d = torch.nn.BatchNorm2d(3,1e-5,0.1,True,
                            True,None,None)
print("weight:",bn1d.weight,"\r\n","bias:",bn1d.bias)
print('----------------------------------')
bnout = bn1d(t)
print("bnout:",bnout)
print('----------------------------------')

#取出来第二通道数据
x = [9., 10., 11.,12., 13., 14.,15., 16., 17.,36., 37., 38.,39., 40., 41.,42., 43., 44.]
x = np.array(x)
mean = x.mean()
print("mean",mean)
out = (x-mean)/np.sqrt(x.var()+1e-5)
print("out",out.reshape(2,3,3))
# x.var()   这个是求方差
# x.std()   这个是求标准差
# print(x.std(),x.var())


''' 
D:\anaconda3\python.exe E:\test\pythonProject\test.py 
t tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.],
          [ 6.,  7.,  8.]],

         [[ 9., 10., 11.],
          [12., 13., 14.],
          [15., 16., 17.]],

         [[18., 19., 20.],
          [21., 22., 23.],
          [24., 25., 26.]]],


        [[[27., 28., 29.],
          [30., 31., 32.],
          [33., 34., 35.]],

         [[36., 37., 38.],
          [39., 40., 41.],
          [42., 43., 44.]],

         [[45., 46., 47.],
          [48., 49., 50.],
          [51., 52., 53.]]]])
weight: Parameter containing:
tensor([1., 1., 1.], requires_grad=True) 
 bias: Parameter containing:
tensor([0., 0., 0.], requires_grad=True)
----------------------------------
bnout: tensor([[[[-1.2732, -1.2005, -1.1277],
          [-1.0550, -0.9822, -0.9094],
          [-0.8367, -0.7639, -0.6912]],

         [[-1.2732, -1.2005, -1.1277],
          [-1.0550, -0.9822, -0.9094],
          [-0.8367, -0.7639, -0.6912]],

         [[-1.2732, -1.2005, -1.1277],
          [-1.0550, -0.9822, -0.9094],
          [-0.8367, -0.7639, -0.6912]]],


        [[[ 0.6912,  0.7639,  0.8367],
          [ 0.9094,  0.9822,  1.0550],
          [ 1.1277,  1.2005,  1.2732]],

         [[ 0.6912,  0.7639,  0.8367],
          [ 0.9094,  0.9822,  1.0550],
          [ 1.1277,  1.2005,  1.2732]],

         [[ 0.6912,  0.7639,  0.8367],
          [ 0.9094,  0.9822,  1.0550],
          [ 1.1277,  1.2005,  1.2732]]]], grad_fn=<NativeBatchNormBackward0>)
----------------------------------
mean 26.5
out [[[-1.27321838 -1.20046305 -1.12770771]
  [-1.05495237 -0.98219704 -0.9094417 ]
  [-0.83668637 -0.76393103 -0.69117569]]

 [[ 0.69117569  0.76393103  0.83668637]
  [ 0.9094417   0.98219704  1.05495237]
  [ 1.12770771  1.20046305  1.27321838]]]

进程已结束,退出代码为 0
'''

以上代码的通道2,bn计算后的值和np按照公式计算出来的值一样。

注意:BatchNorm2d的均值是按照通道内所有数据求和除以个数得来的,比如(b,c,w,h),求通道一的均值,把 b 批次内所有通道1数据全部加起来除以个数求mean均值,其他通道类似,每个通道互不干扰。

4.Dropout

        当一个复杂的前馈神经网络被训练在小的数据集时,容易造成过拟合。为了防止过拟合,可以通过阻止特征检测器的共同作用来提高神经网络的性能。Dropout会随机选取一些张量置为0,阻止梯度向后传播,是为了减少张量之间的依赖关系,让模型学到更多特征。

t = torch.arange(6,dtype=torch.float32).view(1,-1,3)
print(t)
dp = torch.nn.Dropout1d(0.5)
dpt = dp(t)
print(dpt)

'''
D:\anaconda3\python.exe E:\test\pythonProject\test.py 
tensor([[[0., 1., 2.],
         [3., 4., 5.]]])
tensor([[[0., 2., 4.],
         [0., 0., 0.]]])

进程已结束,退出代码为 0
'''

        Dropout1d支持2维和3维输入 ,随机置0输入,不置0的输入变为,输入/Dropout1d概率 ,置0输入可以让梯度不往后面再进行传输

t = torch.arange(27,dtype=torch.float32).view(1,3,3,3)
print(t)
dp = torch.nn.Dropout2d(0.5)
dpt = dp(t)
print(dpt)

'''
D:\anaconda3\python.exe E:\test\pythonProject\test.py 
tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.],
          [ 6.,  7.,  8.]],

         [[ 9., 10., 11.],
          [12., 13., 14.],
          [15., 16., 17.]],

         [[18., 19., 20.],
          [21., 22., 23.],
          [24., 25., 26.]]]])
tensor([[[[ 0.,  2.,  4.],
          [ 6.,  8., 10.],
          [12., 14., 16.]],

         [[ 0.,  0.,  0.],
          [ 0.,  0.,  0.],
          [ 0.,  0.,  0.]],

         [[36., 38., 40.],
          [42., 44., 46.],
          [48., 50., 52.]]]])

进程已结束,退出代码为 0
'''

        Dropout2d支持4维的输入(bach,c,w,h),随机置 0 某个通道 ,不置0的输入变为,输入/Dropout1d概率 ,置0输入可以让梯度不往后面再进行传输

        注意Dropout 是在模型训练的时候才起作用,在模型预测的时候不起作用,BatchNorm在训练的时候会优化自己的参数,在预测模式时候保持权重和偏置参数不变,会使 输出数值 = 输入*weight + bias ,所以在不同使用场景要设置模型为不同的模式。

模型训练模式和预测模式调用以下函数改变状态:

model.train()#模型处于训练模式
model.eval()#模型处于预测模式

        注意:模型训练模式和推理模式设置只影响 dropout层和BatchNorm层,对其他层没有影响,在训练模式的时候dropout层有效,在推理模式dropout层无效。BatchNorm层在训练模式按照公式计算,在推理模式直接用 输入*weight+bias = 输出 (具体参数计算上面代码和运行输出有相关示例),我们在做推理的时候需要设置 with torch.no_grad():  不计算梯度 和  model.eval()  模型推理模式。图像预处理的相关转换都在 torchvision.transforms 这个模块里面。dropout和BatchNorm在torch.nn这个模块里面。


                

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1499316.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

智慧灯杆-智慧城市照明现状分析(2)

作为城市照明的主体,城市道路照明伴随着我国城市建设的高速发展,获得了快速的增长。国家统计局数据显示,从2004年至2014年,我国城市道路照明灯数量由1053.15万盏增加到3000万盏以上,年均复合增长率超过11%,城市道路照明行业保持持续快速发展的趋势。 近几年,随着中国路灯…

二维码门楼牌管理系统应用场景:数据管理的智慧新选择

文章目录 前言一、数据管理部门的智慧工具二、助力决策制定与优质服务提供三、二维码门楼牌管理系统的优势四、展望未来 前言 随着科技的飞速发展&#xff0c;二维码门楼牌管理系统正逐渐成为城市管理的智慧新选择。该系统不仅提升了数据管理效率&#xff0c;还为政府和企业提…

黑马点评-分布式锁业务

分布式锁原理和实现 分布式系统部署了多个tomcat&#xff0c;每个tomcat都有一个属于自己的jvm&#xff0c;那么假设在服务器A的tomcat内部&#xff0c;有两个线程&#xff0c;这两个线程由于使用的是同一份代码&#xff0c;那么他们的锁对象是同一个&#xff0c;是可以实现互…

【Proteus仿真】【STM32单片机】井盖安全检测装置设计

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真STM32单片机控制器&#xff0c;使用LCD1602液晶显示模块、WIFI模块、蜂鸣器、LED按键、ADC PCF8591、角度/可燃气检测传感器等。 主要功能&#xff1a; 系统运行后&#xff0c;LC…

鸿蒙ArkTS语言快速入门-TS(一)

ArkTS与TS的学习 ArkTS与TS的关系简述TypeScript&#xff08;TS&#xff09;简述基础类型1&#xff0c;let2&#xff0c;const3&#xff0c;布尔类型4&#xff0c;数字number5&#xff0c;字符串string6&#xff0c;数组Array7&#xff0c;元组 Tuple8&#xff0c;枚举 enum9&a…

CentOS7.9基于Apache2.4+Php7.4+Mysql8.0架构部署Zabbix6.0LTS 亲测验证完美通过方案

前言: Zabbix 由 Alexei Vladishev 创建,目前由 Zabbix SIA 主导开发和支持。 Zabbix 是一个企业级的开源分布式监控解决方案。 Zabbix 是一款监控网络的众多参数以及服务器、虚拟机、应用程序、服务、数据库、网站、云等的健康和完整性的软件。 Zabbix 使用灵活的通知机制,…

Vue中项目使用debugger,浏览器无效!

现象&#xff1a;下载了别的项目&#xff0c;启动之后&#xff0c;打了debugger&#xff0c;结果浏览器居然忽视&#xff0c;直接过去。打一堆日志&#xff0c;太麻烦了。 解决方案 第一步 F12打开浏览器调试器&#xff0c;找到设置 第二步 如果是英文的&#xff0c;找这…

自定义协议清理后,浏览器还一直弹出匹配提示用户新应用打开问题

问题 这段时间出现了自定义协议清理异常的问题。在一台电脑上&#xff0c;用chrome&#xff0c;一直出现问题&#xff0c;自定义协议可能存在了缓存或者其他内容。导致一直重复的弹出ms-store打开新应用的奇怪问题。 后来 第一步&#xff1a; 清理注册表&#xff0c;把注册…

Spring Boot异常处理和单元测试

1.SpringBoot异常处理 1.1.自定义错误页面 SpringBoot默认的处理异常的机制&#xff1a;SpringBoot 默认的已经提供了一套处理异常的机制。一旦程序中出现了异常 SpringBoot 会向/error 的 url 发送请求。在 springBoot 中提供了一个叫 BasicErrorController 来处理/error 请…

natfrp和FRP配置SSL的基本步骤和bug排查

获取免费/付费SSL 我直接买了一年的ssl证书 设置 主要参考&#xff1a;https://doc.natfrp.com/frpc/ssl.html 遇到的Bug root域名解析是ALIAS&#xff0c;不是CNAME

详细分析Python字典合并的五种方法(附Demo)

目录 前言1. 字典拼接2. {**dict1, **dict2}3. dict.update()4. collections.ChainMap5. collections.defaultdict6. 彩蛋&#xff08;不覆盖合并&#xff09; 前言 从项目中了解到这个函数&#xff1a;res {**res, **tmp}&#xff0c;也知道是字典的合并&#xff0c;且遇到相…

WordPress建站入门教程:如何上传安装WordPress主题?

我们成功搭建WordPress网站后&#xff0c;默认使用的是自带的最新主题&#xff0c;但是这个是国外主题&#xff0c;可能会引用一些国外的资源文件&#xff0c;所以为了让我们的WordPress网站访问速度更快&#xff0c;强烈建议大家使用国产优秀的WordPress主题。 今天boke112百…

msfconsole数据库连接不了的问题【已解决】

msfconsole数据库连接 1.msf数据库端口 msf使用的是postgresql&#xff0c;这个数据库默认端口是5432 单个模块的使用可以不需要数据库&#xff0c;但是模块与模块之间需要沟通的时候就会用到数据库。 2.查看msf数据库连接状态 db_status #msf内部查看systemctl status p…

Windows系统安装MongoDB并结合内网穿透实现公网访问本地数据库

文章目录 前言1. 安装数据库2. 内网穿透2.1 安装cpolar内网穿透2.2 创建隧道映射2.3 测试随机公网地址远程连接 3. 配置固定TCP端口地址3.1 保留一个固定的公网TCP端口地址3.2 配置固定公网TCP端口地址3.3 测试固定地址公网远程访问 前言 MongoDB是一个基于分布式文件存储的数…

24 Linux PWM 驱动

一、PWM 驱动简介 其实在 stm32 中我们就学过了 PWM&#xff0c;这里就是再复习一下。PWM&#xff08;Pulse Width Modulation&#xff09;&#xff0c;称为脉宽调制&#xff0c;PWM 信号图如下&#xff1a; PWM 最关键的两个参数&#xff1a;频率和占空比。 频率是指单位时间内…

【易飞】易飞ERP自动审核程序功能

易飞ERP自动审核程序功能 一、 使用场景二、 操作说明三、 安装方式 一、 使用场景 OA系统集成 与第三方OA系统软件集成&#xff0c;在OA软件审核完成后&#xff0c;直接将ERP中的单据审核。MES系统集成 MES系统生成单据写入到易飞ERP中&#xff0c;并需要自动审核单据&#x…

java SSM旅游景点与公交线路查询系统myeclipse开发mysql数据库springMVC模式java编程计算机网页设计

一、源码特点 java SSM旅游景点与公交线路查询系统是一套完善的web设计系统&#xff08;系统采用SSM框架进行设计开发&#xff0c;springspringMVCmybatis&#xff09;&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系…

微服务---Eureka注册中心

目录 一、服务中的提供者与消费者 二、Eureka工作流程 三、搭建Eureka服务 四、服务拉取 五、总结 1.搭建EurekaServer 2.服务注册 3.服务发现 一、服务中的提供者与消费者 服务提供者&#xff1a;一次业务中&#xff0c;被其他微服务调用的服务。即提供接口给其他微服务。…

Leetcode HOT150

55. 跳跃游戏 给你一个非负整数数组 nums &#xff0c;你最初位于数组的 第一个下标 。数组中的每个元素代表你在该位置可以跳跃的最大长度。 判断你是否能够到达最后一个下标&#xff0c;如果可以&#xff0c;返回 true &#xff1b;否则&#xff0c;返回 false 。 示例 1 …

(十五)【Jmeter】取样器(Sampler)之HTTP请求

简述 操作路径如下: HTTP请求 (HTTP Sampler): 作用:模拟发送HTTP请求并获取响应。配置:设置URL、请求方法、请求参数等参数。使用场景:测试Web应用程序的HTTP接口性能。优点:支持多种HTTP方法和请求参数,适用于大多数Web应用程序测试。缺点:功能较为基础,对于复杂…