【从零开始学习深度学习】6.使用torchvision下载与查看图像分类数据集Fashion-MNIST

news2025/1/15 6:48:24

目录

    • 1.1 获取Fashion-MNIST数据集
    • 2.2 读取小批量
    • 小结

图像分类数据集中最常用的是手写数字识别数据集MNIST。但大部分模型在MNIST上的分类精度都超过了95%。为了更直观地观察算法之间的差异,我们将使用一个图像内容更加复杂的数据集Fashion-MNIST。

本节我们将使用torchvision包,它是服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。torchvision主要由以下几部分构成:

  1. torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
  2. torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
  3. torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
  4. torchvision.utils: 其他的一些有用的方法。

1.1 获取Fashion-MNIST数据集

首先导入本节需要的包或模块。

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys

下面,我们通过torchvision的torchvision.datasets来下载这个数据集。第一次调用时会自动从网上获取数据。我们通过参数train来指定获取训练数据集或测试数据集(testing data set)。测试数据集也叫测试集(testing set),只用来评价模型的表现,并不用来训练模型。

另外我们还指定了参数transform = transforms.ToTensor()使所有数据转换为Tensor,如果不进行转换则返回的是PIL图片。transforms.ToTensor()将尺寸为 (H x W x C) 且数据位于[0, 255]的PIL图片或者数据类型为np.uint8的NumPy数组转换为尺寸为(C x H x W)且数据类型为torch.float32且位于[0.0, 1.0]的Tensor

注意: 由于像素值为0到255的整数,所以刚好是uint8所能表示的范围,包括transforms.ToTensor()在内的一些关于图片的函数就默认输入的是uint8型,若不是,可能不会报错但可能得不到想要的结果。所以,如果用像素值(0-255整数)表示图片数据,那么一律将其类型设置成uint8,避免不必要的bug。

mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

上面的mnist_trainmnist_test都是torch.utils.data.Dataset的子类,所以我们可以用len()来获取该数据集的大小,还可以用下标来获取具体的一个样本。训练集中和测试集中的每个类别的图像数分别为6,000和1,000。因为有10个类别,所以训练集和测试集的样本数分别为60,000和10,000。

print(type(mnist_train))
print(len(mnist_train), len(mnist_test))

输出:

<class 'torchvision.datasets.mnist.FashionMNIST'>
60000 10000

我们可以通过下标来访问任意一个样本:

feature, label = mnist_train[0]
print(feature.shape, label)  # Channel x Height x Width

输出:

torch.Size([1, 28, 28]) tensor(9)

变量feature对应高和宽均为28像素的图像。由于我们使用了transforms.ToTensor(),所以每个像素的数值为[0.0, 1.0]的32位浮点数。需要注意的是,feature的尺寸是 (C x H x W) 的,而不是 (H x W x C)。第一维是通道数,因为数据集中是灰度图像,所以通道数为1。后面两维分别是图像的高和宽。

Fashion-MNIST中一共包括了10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。以下函数可以将数值标签转成相应的文本标签。

def get_fashion_mnist_labels(labels):
    # 将数值标签转成相应的文本标签
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

下面定义一个可以在一行里画出多张图像和对应标签的函数。

def show_fashion_mnist(images, labels):
    d2l.use_svg_display()
    # 这里的_表示我们忽略(不使用)的变量
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()

现在,我们看一下训练数据集中前10个样本的图像内容和文本标签。

X, y = [], []
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))

在这里插入图片描述

2.2 读取小批量

我们将在训练数据集上训练模型,并将训练好的模型在测试数据集上评价模型的表现。前面说过,mnist_traintorch.utils.data.Dataset的子类,所以我们可以将其传入torch.utils.data.DataLoader来创建一个读取小批量数据样本的DataLoader实例。

在实践中,数据读取经常是训练的性能瓶颈,特别当模型较简单或者计算硬件性能较高时。PyTorch的DataLoader中一个很方便的功能是允许使用多进程来加速数据读取。这里我们通过参数num_workers来设置4个进程读取数据。

batch_size = 256
if sys.platform.startswith('win'):
    num_workers = 0  # 0表示不用额外的进程来加速读取数据
else:
    num_workers = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

查看读取一遍训练数据需要的时间。

start = time.time()
for X, y in train_iter:
    continue
print('%.2f sec' % (time.time() - start))

输出:

3.36 sec

小结

  • Fashion-MNIST是一个10类服饰分类数据集,之后章节里将使用它来检验不同算法的表现。
  • 我们将高和宽分别为 h h h w w w像素的图像的形状记为 h × w h \times w h×w(h,w)

如果内容对你有帮助,感谢点赞+关注哦!

关注下方GZH,可获取更多干货内容~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/62563.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

博图Modbus组态及参数设定源码

1、组态选择 协议为Modbus,可在程序里设置通讯方式 2、参数初始化设置 3、选择伺服Modbus 地址 4、写入负值&#xff0c;两个字都必须是负值 5、接线方式 伺服端&#xff1a;驱动器通过通讯连接器与计算机相连&#xff0c;使用者可利用 MODBUS 通讯结合汇编语言来操作驱动器&…

运放失调电压失调电流,计算输入电压信号大小,设计反向放大器

笔者电子信息专业硕士毕业&#xff0c;获得过多次电子设计大赛、大学生智能车、数学建模国奖&#xff0c;现就职于南京某半导体芯片公司&#xff0c;从事硬件研发&#xff0c;电路设计研究。对于学电子的小伙伴&#xff0c;深知入门的不易&#xff0c;特开次博客交流分享经验&a…

销售抓住客户心理的话术

由于销售会与各种各样的人打交道&#xff0c;因此销售也是最容易洞察客户心理的职业&#xff0c;优秀的销售懂得如何抓住客户的心理&#xff0c;从而挖掘客户的需求和痛点&#xff0c;进一步促进成交。 前言 由于销售会与各种各样的人打交道&#xff0c;因此销售也是最容易洞察…

【NC65】主子表单据按照单表结构展现 节点客开

需求描述: 需要将【采购入库】按照【采购订单关闭】节点的形式展现数据。 客开思路: 功能注册增加功能节点,(40080603)勾选启用。菜单注册增加 菜单 并关联 功能节点(40080603)。初始化 单据模板,查询模板采购入库单主子表汇总VO是PurchaseInViewVO ,系统里其他单据节点 …

MyBatis-Plus之ActiveRecord[基础增删改查操作]

系列文章目录 Mybatis-Plus知识点[MyBatisMyBatis-Plus的基础运用]_心态还需努力呀的博客-CSDN博客 Mybatis-PlusSpringBoot结合运用_心态还需努力呀的博客-CSDN博客MyBaits-Plus中TableField和TableId用法_心态还需努力呀的博客-CSDN博客 MyBatis-Plus中的更新操作&#xf…

应用层之HTTP和HTTPS协议(必备知识)

文章目录1、什么是HTTP协议2、HTTP协议格式<1>HTTP请求方法<2>HTTP的状态码3、HTTP是不保存状态的协议<1>使用Cookie的状态管理3、HTTPS<1>加密方式<2>理解HTTPS加密过程1、什么是HTTP协议 HTTP协议常被称为超文本传输协议&#xff0c;HTTP协议…

国产蓝牙耳机什么牌子好?2022蓝牙耳机品牌排行

随着蓝牙耳机市场的快速发展&#xff0c;国产蓝牙耳机品牌也越来越多。那么在众多的国产蓝牙耳机当中&#xff0c;什么牌子的比较好呢&#xff1f;下面&#xff0c;一起来看看2022蓝牙耳机品牌排行吧。 一、南卡小音舱蓝牙耳机 售价&#xff1a;299 蓝牙&#xff1a;5.3 发…

ADSP-21489的开发详解:VDSP+自己编程写代码开发(4-按键控制 LED 灯)

以上全部都 OK 之后&#xff0c;我们就可以开始跑程序了。&#xff08;抱歉上面几项写的很罗嗦&#xff0c;都是我这近 15 年来开发 ADI DSP 实际项目里碰到问题的经验之谈&#xff0c;希望能够对用户有帮助&#xff09; 跑程序就涉及到了 Visual DSP软件的操作&#xff0c;我…

高校教材征订系统(Java+Web+MySQL)

目 录 ABSTRACT 2 1 概述 5 1.1开发背景 5 1.2 项目提出的意义 5 1.3 系统的开发方法 5 1.4 系统开发工具 6 1.4.1 JSP简介 6 1.4.2 JDK配置 7 1.4.3 数据库简介 8 1&#xff0e;4&#xff0e;4 tomcat配置 9 2 需求分析 11 2.1可行性分析 11 2.2 系统设计的要求 11 2.3 系统功…

01、RabbitMQ入门

目录 1.、什么是MQ 2、应用场景 3、主流MQ框架 4、Docker安装部署RabbitMQ 5、RabbitMQ管理平台 6、MQ的核心概念 单一生产者和单一消费者 7、springboot整合rabbitmq 执行测试方法testRabbitmq&#xff0c;控制台输出&#xff1a;receive msg : test rabbitmq messag…

[附源码]计算机毕业设计时间管理软件appSpringboot程序

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

Vector-常用CAN工具 - CANoe入门到精通_01

之前有写过相关的CANoe工程配置&#xff0c;不过没有进行系统的介绍&#xff0c;通过各位热心人士的反馈&#xff0c;有些内容无法看懂&#xff0c;因此后续的内容会做好排版&#xff0c;让大家从入门到精通&#xff0c;一次性掌握所有的相关内容。今天就主要来介绍下VN1640/VN…

读vue源码搞懂响应式原理

vue2响应式原理 Vue2是借助Object.defineProperty()实现的&#xff0c;而Vue3是借助Proxy实现的想深入学习 Vue2 的响应式原理, 需要先学习 Object.defineProperty() 方法为了称呼方便, 后续说 Vue 响应式原理统一指 Vue2 的响应式原理 1.Object.defineProperty 方法 - 简介 定…

go实战(1)-hello,world与包基础(1)-模块基础

目录程序结构程序代码使用和建立包建立目录模块参考中的身份验证计算哈希值go.sum文件Checksum database程序结构 声明一个main包(包是对函数进行分组的一种方法&#xff0c;它由同一目录中的所有文件组成)。 导入流行的fmt包&#xff0c;它包含格式化文本的函数&#xff0c;…

WIFI码挪车码创建生成CPS聚合流量主小程序开发

WIFI码挪车码创建生成CPS聚合流量主小程序开发 系统特点// 这不是一套普通的给别人开SAAS账号的CPS推广返利系统&#xff0c;而是一套服务商版的CPS推广返利系统&#xff01;所谓服务商版&#xff0c;就是所有CPS推广走你的渠道接口&#xff0c;除了可以给你的下级客户开账号外…

MySQL下载安装运行

方式1、MySQL 官方网站&#xff1a;http://www.mysql.com 拉到最下面&#xff1a; 方式2、Windows版 MySQL 的官方下载地址&#xff1a;https://dev.mysql.com/downloads/mysql/ 配置环境变量&#xff1a;在Path中添加至“\bin”&#xff08;系统盘C盘&#xff09;形式 使用管…

【OpenCV 例程 300篇】249. 特征描述之视网膜算法(FREAK)

『youcans 的 OpenCV 例程300篇 - 总目录』 【youcans 的 OpenCV 例程 300篇】249. 特征检测之视网膜算法&#xff08;FREAK&#xff09; 1. FREAK 算法简介 快速视网膜算法&#xff08;FREAK&#xff09;算法是 Alexandre Alahi 在 ICCV 2012 的论文 FREAK: Fast Retina Keyp…

1.32 Cubemx_STM32F429串口中断+空闲中断

1、简介 有时候串口接收数据时,没有帧头与帧尾,单纯使用单字节中断接收数据,不太好断帧。如果单纯使用空闲中断接收数据,当帧内数据不连续或者黏包,使用空闲中断接收就会出现接收的数据小于或者大于帧长度,比较难断帧。解决办法 方法1、单字节中断接收+空闲中断 发送命…

Spring Cache组件

《Spring Cache组件》 提示: 本材料只做个人学习参考,不作为系统的学习流程,请注意识别!!! 《Spring Cache组件》《Spring Cache组件》1. Spring Cache组件概述2. ConcurrentHashMap缓存管理3. Cacheable详解4. Caffeine缓存管理5. 缓存更新策略6. 缓存清除策略7. 多级缓存策略…

[附源码]计算机毕业设计基于Springboot景区直通车服务系统

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…