pytorch加载的cifar10数据集,到底有没有经过归一化

news2024/11/25 18:25:01

pytorch加载cifar10的归一化

  • pytorch怎么加载cifar10数据集
    • torchvision.datasets.CIFAR10
    • transforms.Normalize()进行归一化到底在哪里起作用?【CIFAR10源码分析】
  • torchvision.datasets加载的数据集搭配Dataloader使用
  • model.train()和model.eval()

pytorch怎么加载cifar10数据集

torchvision.datasets.CIFAR10

pytorch里面的torchvision.datasets中提供了大多数计算机视觉领域相关任务的数据集,可以根据实际需要加载相关数据集——需要cifar10就用torchvision.datasets.CIFAR10(),需要SVHN就调用torchvision.datasets.SVHN()。

针对cifar10数据集而言,调用torchvision.datasets.CIFAR10(),其中root是下载数据集后保存的位置;train是一个bool变量,为true就是训练数据集,false就是测试数据集;download也是一个bool变量,表示是否下载;transform是对数据集中的"image"进行一些操作,比如归一化、随机裁剪、各种数据增强操作等;target_transform是针对数据集中的"label"进行一些操作。
在这里插入图片描述

示例代码如下:

# 加载训练数据集
train_data = datasets.CIFAR10(root='../_datasets', train=True, download=True,
                                  transform= transforms.Compose([  
                                                 transforms.ToTensor(),  
                                                 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 归一化  
                                                 ])  )
# 加载测试数据集
test_data = datasets.CIFAR10(root='../_datasets', train=False,download=True, 
                             transform= transforms.Compose([  
                                               transforms.ToTensor(),  
                                               transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 归一化  
                                               ])  )

transforms.Normalize()进行归一化到底在哪里起作用?【CIFAR10源码分析】

上面的代码中,我们用transforms.Compose([……])组合了一系列的对image的操作,其中trandforms.ToTensor()transforms.Normalize()都涉及到归一化操作:

  • 原始的cifar10数据集是numpy array的形式,其中数据范围是[0,255],pytorch加载时,并没有改变数据范围,依旧是[0,255],加载后的数据维度是(H, W, C),源码部分:
    在这里插入图片描述

  • __getitem__()函数中进行transforms操作,进行了归一化:实际上传入的transform在__getitem__()函数中被调用,其中transforms.Totensor()会将data(也就是image)的维度变成(C,H, W)的形式,并且归一化到[0.0,1.0]

在这里插入图片描述

  • transforms.Normalize()会根据z = (x-mean) / std 对数据进行归一化,上述代码中mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]是可以将3个通道单独进行归一化,3个通道可以设置不同的mean和std,最终数据范围变成[-0.5,+0.5] 。
    在这里插入图片描述

所以如果通过pytorch的cifar10加载数据集后,针对traindataset.data,依旧是没有进行归一化的;但是比如traindataset[index].data,其中[index]这样的按下标取元素的操作会直接调用的__getitem__()函数,此时的data就是经过了归一化的。
除traindataset[index]会隐式自动调用__getitem__()函数外,还有什么时候会调用这个函数呢?毕竟……只有调用了这个函数才会调用transforms中的归一化处理。——答案是与dataloader搭配使用!

torchvision.datasets加载的数据集搭配Dataloader使用

torchvision.datasets实际上是torch.utils.data.Dataset的子类,那么就能传入Dataloader中,迭代的按batch-size获取批量数据,用于训练或者测试。其中dataloader加载dataset中的数据时,就是用到了其__getitem__()函数,所以用dataloader加载数据集,得到的是经过归一化后的数据
在这里插入图片描述

model.train()和model.eval()

我发现上面的问题,是我用dataloader加载了训练数据集用于训练resnet18模型,训练过程中,我训练好并保存后,顺便测试了一下在测试数据集上的准确度。但是在测试的过程中,我没有用dataloader加载测试数据集,而是直接用的dataset.data来进行的测试。并且!由于是并没有将model设置成model.eval()【其实我设置了,但是我对自己很无语,我写的model.eval,忘记加括号了,无语呜呜】……也就是即便我的测试数据集没有经过归一化,由于模型还是在model.train()模式下,因此模型的BN层会自己调整,使得模型性能不受影响,因此在测试数据集上的accuracy达到了0.86,我就没有多想。
后来我用模型的时候,设置了model.eval()后,依旧是直接用的dataset.data(也就是没有归一化),不管是在测试数据集上还是在训练数据集上,accuracy都只有0.10+,我表示非常的迷茫疑惑啊!然后才发现是归一化的问题。

  • model.train()模式下进行预测时,PyTorch会默认启用一些训练相关的操作,例如Batch Normalization和Dropout,并且模型的参数是可变的,能够根据输入进行调整。这些操作在训练模式下可以帮助模型更好地适应训练数据,并产生较高的准确度。
  • model.eval()模式下进行预测时,PyTorch会将模型切换到评估模式,这会导致一些训练相关的操作行为发生变化。具体而言,Batch Normalization层会使用训练集上的统计信息进行归一化,而不是使用当前批次的统计信息。因此,如果输入数据没有进行归一化,模型在评估模式下的准确度可能会显著下降。

以下是我没有用dataloader加载数据集,进行预测的代码:

def correctness(model,data,target, device):
    batchsize = 1000
    batch_num = int(len(data) / batchsize)   
    # 对原始的数据进行操作 从H.W.C变成C.H.W 
    data = torch.tensor(data).permute(0,3,1,2).type(torch.FloatTensor).to(device)
    # 手动归一化
    data = data/255
    data = (data - 0.5) / 0.5 
    # 求一个batch的correctness
    def _batch_correctness(i):
        images, labels = data[i*batchsize : (i+1)*batchsize], target[i*batchsize : (i+1)*batchsize]
        predict = model(images).detach().cpu()    
        correctness = np.array(torch.argmax(predict, dim = 1).numpy() == np.array(labels) , dtype= np.float32)
        return correctness
    result = np.array([_batch_correctness(i) for i in range(batch_num)])

    return result.flatten().sum()/data.shape[0]

我后面用上面的代码测试了四种情况:

  • model.eval() + 没有归一化:train_accuracy = 0.10,test_accuracy = 0.10;
  • model.eval() + 手动归一化:train_accuracy = 0.95,test_accuracy = 0.84;
  • model.train() + 没有归一化:train_accuracy = 0.95,test_accuracy = 0.83;
  • model.train() + 手动归一化:train_accuracy = 0.94,test_accuracy = 0.84;

由此可见,在model.eval()模式下,数据归一化对最终的测试结果有很大影响。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1179384.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Paste v4.1.2(Mac剪切板)

Paste for Mac是一款运行在Mac OS平台上的剪切板小工具,拥有华丽的界面效果,剪切板每一条记录可显示(预览)文本,图片等记录的完整内容,可以记录最近指定条数的剪切板信息,方便用户随时调用&…

输电线路AR可视化巡检降低作业风险

随着现代工业的快速发展,各行业的一线技术工人要处理的问题越来越复杂,一些工作中棘手的问题迫切需要远端专家的协同处理。但远端专家赶来现场往往面临着专家差旅成本高、设备停机损失大、专业支持滞后、突发故障无法立即解决等痛点。传统的远程协助似乎…

Emacs之高亮显示超过80个字符部分(一百三十)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…

Linux入门知识

​ 文章目录 一、Linux发展历程1.1、Linux前身-Unix1.2、Linux 诞生 二、Linux系统特点三、Linux 分支四、Linux系统架构4.1、系统调用4.2、Linux shell4.3、Linux文件系统4.4、Linux内核 团队博客: 汽车电子社区 一、Linux发展历程 1.1、Linux前身-Unix 1968年Multics 项目…

【Excel】如何画不同时序交叉的百分比堆积柱状图

这里写自定义目录标题 1 将两表交叉合并为一个表1.1 步骤一:在两独立表的工作天数和工资列下面按1-n顺次标号。1.2 步骤二:选中两表需要合并的部分,调出自定义排序1.3 步骤三:选项 ——> 按行排序 (选完后点确定&am…

LV.12 D17 中断控制器 学习笔记

一、中断控制器 在处理IRQ的时候,会将CPSR写入IRQ_SPSR,然后将CPU切换为IRQ模式,把状态改成ARM状态,把I位写成1禁止全部的IRQ,所以中断这样是我们不想要的。4412是一个四核的CPU,在发送中断前要确定发送给哪…

【UDS基础】简单介绍“统一诊断服务“

1. 前言 我们将在这个实用教程中介绍UDS的基础知识,重点关注在CAN总线上的UDS(UDSonCAN)和CAN诊断(DoCAN)。此外,我们还会介绍ISO-TP协议,并解释UDS、OBD2、WWH-OBD和OBDonUDS之间的差异。 最后,我们将解释如何请求、记录和解码UDS消息,并提供一些实际示例,例如记录…

libevent

libevent 库概念和特点 开源。精简。跨平台(Windows、Linux、maxos、unix)。专注于网络通信(不一定非用在网络当中,比如下面的读写管道)。 libevent特性:基于"事件",面向“文件描述符…

C++算法:第N位数的原理、源码及测试用例

本文涉及知识点 简单的数学知识。 本博文对应源码,审核比较慢,请耐心等待:https://download.csdn.net/download/he_zhidan/88504919 本博文在CSDN 学院有对应课程。 题目 给你一个整数 n ,请你在无限的整数序列 [1, 2, 3, 4, 5…

编译原理(1)----LL(1)文法(首符号集,后跟符号集,选择符号集)

一.首符号集(First()) 技巧:找最左边可能出现的终结符 例: 1.First(E) E->T,最左边为T,又因为T->F,最左边为F,F->(E)|i,则最左边为{(,i } 2.First(T):只需要看符号串最左…

Mac VsCode g++编译报错:不支持C++11语法解决

编译运行时报错: [Running] cd “/Users/yiran/Documents/vs_projects/c/” && g 1116.cpp -o 1116 && "/Users/yiran/Documents/vs_projects/c/"1116 1116.cpp:28:22: warning: range-based for loop is a C11 extension [-Wc11-extensi…

pda条码二维码扫描数据采集安卓手持终端扫码热敏标签打印一体机

HT800新一代移动物联终端是深圳联强优创信息科技有限公司自主研发的基于Android11操作系统的高性能、高可靠的工业级手持数据终端,能与其它设备进行无线通讯,提供良好的操作界面,支持条码扫描、RFID读写(NFC)、GPS定位…

ch579串口编程笔记

“CH579SFR.h”库文件,关于串口中断部分 /* UART interrupt identification values for IIR bits 3:0 */ #define UART_II_SLV_ADDR 0x0E // RO, UART0 slave address match #define UART_II_LINE_STAT 0x06 // R…

云服务器哪家便宜靠谱 | 简单了解亚马逊云科技发展史

云服务器哪家便宜又靠谱呢?为什么说亚马逊云科技在这道题答案的第一行,一篇故事告诉你。 1994年,杰夫贝索斯在西雅图创建了亚马逊,最初只是一个在线书店。 1997年,亚马逊在纳斯达克交易所上市,成为一家公…

基于 Odoo + Python 的网站快速开发指南

基于 Odoo Python 的网站快速开发指南 下载根据本指南开发的主题模块源码 Odoo 网站生成器是一个灵活的工具,可以轻松构建与 Odoo 应用完全集成的网站。使用其提供的主题选项 (options) 和构建块 (blocks) 很容易定制网站。然而,你还可以更进一步深度定…

JavaEE进阶3

传递数组: 当我们请求中,同一个参数有多个时,浏览器就会帮我们封装成一个数组 用逗号进行分割也是可以的(有的浏览器不能直接使用逗号,需要我们去转码) 传递集合: HTTP 状态码(不是后端自定义的) 2XX:成功 3XX:重定向 4XX:客户端错误 5XX:服务器错误 业务状态码:HTTP响应…

第十章 Python 自定义模块及导入方法

系列文章目录 第一章 Python 基础知识 第二章 python 字符串处理 第三章 python 数据类型 第四章 python 运算符与流程控制 第五章 python 文件操作 第六章 python 函数 第七章 python 常用内建函数 第八章 python 类(面向对象编程) 第九章 python 异常处理 第十章 python 自定…

Cookie、Session、Token、JWT 一篇就够了

什么是认证(Authentication) 通俗地讲就是验证当前用户的身份,证明“你是你自己”(比如:你每天上下班打卡,都需要通过指纹打卡,当你的指纹和系统里录入的指纹相匹配时,就打卡成功&a…

git工作流(待续)

一、git git是分布式管理系统和集中式的区别在于每个人具有的本地份数不同,集中式只有一份 分布式 主要是 协同工作 gitlab github 等是git仓库的一个托管平台 二、git安装、初始化 基础配置 第一次需要对身份进行说明 git config --global user.name "xx&…

阿里云安全恶意程序检测(速通二)

阿里云安全恶意程序检测 高阶数据探索变量分析连续数值变量与连续数值变量单个类别变量与连续数值变量两个类别变量与连续数值变量两个变量线性关系探索查看多个双变量关系的技巧 高阶数据探索多变量交叉探索 高阶数据探索 变量分析 连续数值变量与连续数值变量 分析连续数值…