【深度学习】图像分类数据集

news2025/4/8 7:07:39

图像分类数据集

MNIST数据集是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。
我们将使用类似但更复杂的Fashion-MNIST数据集

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()#设置图表大小,具体实现过程及其底层逻辑见微积分一节

读取数据集

我们可以[通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中]。

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,

# 并除以255使得所有像素的数值均在0~1之间

trans = transforms.ToTensor()

mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
    
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

这段代码的主要目的是从 torchvision 库中下载并加载 Fashion - MNIST 数据集,同时对数据进行预处理,将图像转换为 PyTorch 张量。
代码主要分为三个部分:定义图像预处理操作、加载训练集数据、加载测试集数据。下面逐行进行详细解释。

1. 定义图像预处理操作

trans = transforms.ToTensor()

  • 功能:创建一个图像预处理的转换对象 transtransforms.ToTensor()torchvision.transforms 模块里的一个类,专门用于将 PIL(Python Imaging Library)图像或者 NumPy 数组(一般是 uint8 类型)转换为 torch.FloatTensor 类型的张量。
  • 转换细节
    - 在转换过程中,会把图像的像素值归一化到 [0.0, 1.0] 范围。例如,原始图像像素值范围是 [0, 255],经过该转换后,像素值会除以 255,变成 [0.0, 1.0] 之间的浮点数。
    - 同时,转换后张量的维度也会发生变化。对于单通道的灰度图像,会从 (H, W)(高度和宽度)变为 (1, H, W);对于三通道的彩色图像,会从 (H, W, C) 变为 (C, H, W),这里 C 代表通道数。

2. 加载训练集数据

mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)

  • 功能:创建一个 FashionMNIST 数据集对象 mnist_train,用于加载 Fashion - MNIST 数据集的训练集部分。
  • 参数解释
    - root="../data":指定数据集的存储路径。若该路径下没有数据集,下载的数据会存于此;若已存在,则直接从该路径加载数据。
    - train=True:表明要加载的是训练集数据。Fashion - MNIST 数据集包含 60,000 张训练图像和 10,000 张测试图像,通过此参数区分加载的是训练集还是测试集。
    - transform=trans:指定对图像数据进行的预处理操作。这里使用之前创建的 trans 对象,即对每个图像应用 ToTensor() 变换,将其转换为张量
    - download=True:如果指定路径下未找到数据集,会自动从网络下载 Fashion - MNIST 数据集。

3. 加载测试集数据

mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

  • 功能:创建一个 FashionMNIST 数据集对象 mnist_test,用于加载 Fashion - MNIST 数据集的测试集部分。
  • 参数解释:与加载训练集的代码基本相同,唯一区别在于 train=False,表示加载的是测试集数据。

Fashion-MNIST由10个类别的图像组成,每个类别由训练数据集(train dataset)中的6000张图像
测试数据集(test dataset)中的1000张图像组成。
因此,训练集和测试集分别包含60000和10000张图像。测试数据集不会用于训练,只用于评估模型性能。

len(mnist_train), len(mnist_test)

在这里插入图片描述
每个输入图像的高度和宽度均为28像素。
数据集由灰度图像组成,其通道数为1。
为了简洁起见,将高度 h h h像素、宽度 w w w像素图像的形状记为 h × w h \times w h×w或( h h h, w w w)。

mnist_train[0][0].shape

在这里插入图片描述
[两个可视化数据集的函数]

Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。
以下函数用于在数字标签索引及其文本名称之间进行转换。

def get_fashion_mnist_labels(labels):  #@save
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

列表推导式
[expression for item in iterable]

  • expression:对每个 item 进行操作后得到的结果,它将成为新列表中的一个元素。
  • item:从 iterable 中取出的单个元素。
  • iterable:一个可迭代对象,如列表、元组、字符串等。

示例代码

text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
               'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
labels = [0, 2, 4]
result = [text_labels[int(i)] for i in labels]
print(result)  # 输出: ['t-shirt', 'pullover', 'coat']

我们现在可以创建一个函数来可视化这些样本。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """绘制图像列表"""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(img.numpy())
        else:
            # PIL图片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

子图坐标轴对象
在 matplotlib 中,一个图形(Figure)可以包含多个子图(Axes),每个子图就是一个独立的绘图区域,子图坐标轴对象(Axes 对象)就代表了这些独立的绘图区域。它可以被看作是一个 “画布”,你可以在这个 “画布” 上进行各种绘图操作,比如绘制线条、散点、柱状图等,还可以设置坐标轴的范围、标签、标题等。

以下是对 show_images 函数的详细解释:

  • def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    • 定义了一个名为 show_images 的函数,用于将一组图像以网格形式展示出来。
    • imgs:是一个包含图像的列表,这些图像可以是 PyTorch 张量,也可以是 PIL(Python Imaging Library)图像对象。
    • num_rows:指定了要展示的图像网格的行数。
    • num_cols:指定了要展示的图像网格的列数。
    • titles:是一个可选参数,类型为列表,用于为每个图像设置对应的标题。如果不提供该参数,则默认不显示标题。
    • scale:同样是可选参数,是一个浮点数,用于调整图像显示的缩放比例,默认值为 1.5。
  • figsize = (num_cols * scale, num_rows * scale):
    • 这行代码根据 num_cols(列数)、num_rows(行数)和 scale(缩放比例)计算出整个图像展示窗口的大小。
    • figsize 是一个元组,第一个元素是窗口的宽度,由列数乘以缩放比例得到;第二个元素是窗口的高度,由行数乘以缩放比例得到。
  • _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    • num_rowsnum_cols 分别指定了子图的行数和列数,也就是图像网格的布局。
    • figsize=figsize 表示使用之前计算好的窗口大小。
    • subplots 函数返回两个值,第一个是 Figure 对象,这里用 _ 占位表示我们不关心这个返回值;第二个是一个包含所有子图坐标轴对象的数组,赋值给 axes
  • axes = axes.flatten()
    • axes 原本是一个二维数组,因为它对应着 num_rows 行和 num_cols 列的子图布局。
    • flatten 方法将这个二维数组转换为一维数组,这样在后续遍历图像和子图时会更加方便。
  • for i, (ax, img) in enumerate(zip(axes, imgs))
    • zip(axes, imgs)axes 数组(包含所有子图坐标轴对象)和 imgs 列表(包含所有要展示的图像)中的元素一一对应地组合起来。
    • enumerate 函数用于为组合后的元素添加索引,i 就是当前元素的索引。
    • 在每次循环中,ax 代表当前子图的坐标轴对象,img 代表当前要展示的图像。
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(img.numpy())
        else:
            # PIL图片
            ax.imshow(img)
  • torch.is_tensor(img) 用于判断当前的 img 是否为 PyTorch 张量。
  • 如果是张量,使用 img.numpy() 将其转换为 NumPy 数组,因为 matplotlibimshow 函数更适合处理 NumPy 数组。然后使用 ax.imshow 函数在当前子图上显示图像。
  • 如果不是张量,说明 img 可能是 PIL 图像对象,直接使用 ax.imshow 函数显示该图像。
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
  • ax.axes.get_xaxis() 获取当前子图的 x 轴对象,set_visible(False) 方法将 x 轴设置为不可见。
  • 同理,ax.axes.get_yaxis() 获取当前子图的 y 轴对象,set_visible(False) 方法将 y 轴设置为不可见。这样可以使图像显示更加简洁,只专注于图像内容。
        if titles:
            ax.set_title(titles[i])
  • if titles: 检查是否提供了 titles 列表。
  • 如果提供了,使用 ax.set_title 方法为当前子图设置对应的标题,标题从 titles 列表中根据当前索引 i 取出。
    return axes
  • 最后,函数返回 axes 数组,这个数组包含了所有子图的坐标轴对象。返回它的目的是方便在调用该函数后,对图形进行进一步的操作,例如修改坐标轴属性等。

以下是训练数据集中前[几个样本的图像及其相应的标签]。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

在这里插入图片描述

读取小批量

为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。
回顾一下,在每次迭代中,数据加载器每次都会[读取一小批量数据,大小为batch_size]。
通过内置数据迭代器,我们可以随机打乱了所有样本,从而无偏见地读取小批量。

batch_size = 256

def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4
#shuffle表示在每个训练周期开始时,对数据集进行随机打乱
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())

我们看一下读取训练数据所需的时间。

timer = d2l.Timer()
for X, y in train_iter:
    continue
f'{timer.stop():.2f} sec'

整合所有组件

现在我们[定义load_data_fashion_mnist函数],用于获取和读取Fashion-MNIST数据集。
这个函数返回训练集和验证集的数据迭代器。
此外,这个函数还接受一个可选参数resize,用来将图像大小调整为另一种形状。

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    #trans初始化为一个包含transforms.ToTensor()的列表
    
    if resize:
        trans.insert(0, transforms.Resize(resize))
        #在 trans 列表的开头插入 transforms.Resize(resize) 操作
        
    trans = transforms.Compose(trans)
    #将 trans 列表中的所有变换操作组合成一个完整的变换序列 trans
    
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
        
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

下面,我们通过指定resize参数来测试load_data_fashion_mnist函数的图像大小调整功能。

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)#X.shape表示张量 X 的形状,X.dtype表示张量 X 中元素的数据类型
    break

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2286497.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【四川乡镇界面】图层shp格式arcgis数据乡镇名称和编码2020年wgs84无偏移内容测评

本文将详细解析标题和描述中提到的IT知识点,主要涉及GIS(Geographic Information System,地理信息系统)技术,以及与之相关的文件格式和坐标系统。 我们要了解的是"shp"格式,这是一种广泛用于存储…

ubuntu解决普通用户无法进入root

项目场景: 在RK3566上移植Ubuntu20.04之后普通用户无法进入管理员模式 问题描述 在普通用户使用sudo su试图进入管理员模式的时候报错 解决方案: 1.使用 cat /etc/passwd 查看所有用户.最后一行是 若无用户,则使用 sudo useradd -r -m -s /…

第3章 基于三电平空间矢量的中点电位平衡策略

0 前言 在NPC型三电平逆变器的直流侧串联有两组参数规格完全一致的电解电容,由于三电平特殊的中点钳位结构,在进行SVPWM控制时,在一个完整开关周期内,直流侧电容C1、C2充放电不均匀,各自存储的总电荷不同,电容电压便不均等,存在一定的偏差。在不进行控制的情况下,系统无…

网络工程师 (8)存储管理

一、页式存储基本原理 (一)内存划分 页式存储首先将内存物理空间划分成大小相等的存储块,这些块通常被称为“页帧”或“物理页”。每个页帧的大小是固定的,例如常见的页帧大小有4KB、8KB等,这个大小由操作系统决定。同…

实验一---典型环节及其阶跃响应---自动控制原理实验课

一 实验目的 1.掌握典型环节阶跃响应分析的基本原理和一般方法。 2. 掌握MATLAB编程分析阶跃响应方法。 二 实验仪器 1. 计算机 2. MATLAB软件 三 实验内容及步骤 利用MATLAB中Simulink模块构建下述典型一阶系统的模拟电路并测量其在阶跃响应。 1.比例环节的模拟电路 提…

【BQ3568HM开发板】如何在OpenHarmony上通过校园网的上网认证

引言 前面已经对BQ3568HM开发板进行了初步测试,后面我要实现MQTT的工作,但是遇到一个问题,就是开发板无法通过校园网的认证操作。未认证的话会,学校使用的深澜软件系统会屏蔽所有除了认证用的流量。好在我们学校使用的认证系统和…

PythonFlask框架

文章目录 处理 Get 请求处理 POST 请求应用 app.route(/tpost, methods[POST]) def testp():json_data request.get_json()if json_data:username json_data.get(username)age json_data.get(age)return jsonify({username: username测试,age: age})从 flask 中导入了 Flask…

【电工基础】1.电能来源,触电伤害,触电预防,触电急救

一。电能来源 1.电能来源 发电-》输电-》变电-》配电 2.分配电 一类负荷 如果供电中断会造成生命危险,造成国民经济的重大损失,损坏生产的重要设备以致使生产长期不能恢复或产生大量废品,破坏复杂的工艺过程,以及破坏大…

大数据学习之Kafka消息队列、Spark分布式计算框架一

Kafka消息队列 章节一.kafka入门 4.kafka入门_消息队列两种模式 5.kafka入门_架构相关名词 Kafka 入门 _ 架构相关名词 事件 记录了世界或您的业务中 “ 发生了某事 ” 的事实。在文档中 也称为记录或消息。当您向 Kafka 读取或写入数据时,您以事件的 形式执行…

SQL Server查询计划操作符(7.3)——查询计划相关操作符(5)

7.3. 查询计划相关操作符 38)Flow Distinct:该操作符扫描其输入并对其去重。该操作符从其输入得到每行数据时即将其返回(除非其为重复数据行,此时,该数据行会被抛弃),而Distinct操作符在产生任何输出前将消费所有输入。该操作符为逻辑操作符。该操作符具体如图7.2-38中…

单片机基础模块学习——NE555芯片

一、NE555电路图 NE555也称555定时器,本文主要利用NE555产生方波发生电路。整个电路相当于频率可调的方波发生器。 通过调整电位器的阻值,方波的频率也随之改变。 RB3在开发板的位置如下图 测量方波信号的引脚为SIGHAL,由上面的电路图可知,NE555已经构成完整的方波发生电…

ts 进阶

吴悠讲编程 : 20分钟TypeScript进阶!无废话快速提升水平 前端速看 https://www.bilibili.com/video/BV1q64y1j7aH

【C++】STL介绍 + string类使用介绍 + 模拟实现string类

目录 前言 一、STL简介 二、string类 1.为什么学习string类 2.标准库中的string类 3.auto和范围for 4.迭代器 5.string类的常用接口说明 三、模拟实现 string类 前言 本文带大家入坑STL,学习第一个容器string。 一、STL简介 在学习C数据结构和算法前,我…

【Redis】List 类型的介绍和常用命令

1. 介绍 Redis 中的 list 相当于顺序表,并且内部更接近于“双端队列”,所以也支持头插和尾插的操作,可以当做队列或者栈来使用,同时也存在下标的概念,不过和 Java 中的下标不同,Redis 支持负数下标&#x…

【愚公系列】《循序渐进Vue.js 3.x前端开发实践》033-响应式编程的原理及在Vue中的应用

标题详情作者简介愚公搬代码头衔华为云特约编辑,华为云云享专家,华为开发者专家,华为产品云测专家,CSDN博客专家,CSDN商业化专家,阿里云专家博主,阿里云签约作者,腾讯云优秀博主&…

PETSc源码分析: Optimization Solvers

本文结合PETSc源代码,分析PETSc中的优化求解器。 注1:限于研究水平,分析难免不当,欢迎批评指正。 注2:文章内容会不定期更新。 参考文献 Balay S. PETSc/TAO Users Manual, Revision 3.22. Argonne National Labora…

面向对象设计(大三上)--往年试卷题+答案

目录 1. UML以及相关概念 1.1 动态图&静态图 1.2 交互图 1.3 序列图 1.4 类图以及关联关系 1.4.1类图 1.4.2 关系类型 (1) 用例图中的包含、扩展关系(include & extend) (2) 类图中的聚合、组合关系(aggragation & composition) 1.5 图对象以及职责划…

芯片AI深度实战:进阶篇之vim内verilog实时自定义检视

本文基于Editor Integration | ast-grep,以及coc.nvim,并基于以下verilog parser(my-language.so,文末下载链接), 可以在vim中实时显示自定义的verilog 匹配。效果图如下: 需要的配置如下: 系列文章: 芯片…

几种K8s运维管理平台对比说明

目录 深入体验**结论**对比分析表格**1. 功能对比****2. 用户界面****3. 多租户支持****4. DevOps支持** 细对比分析1. **Kuboard**2. **xkube**3. **KubeSphere**4. **Dashboard****对比总结** 深入体验 KuboardxkubeKubeSphereDashboard 结论 如果您需要一个功能全面且适合…

TikTok 推出了一款 IDE,用于快速构建 AI 应用

字节跳动(TikTok 的母公司)刚刚推出了一款名为 Trae 的新集成开发环境(IDE)。 Trae 基于 Visual Studio Code(VS Code)构建,继承了这个熟悉的平台,并加入了 AI 工具,帮助开发者更快、更轻松地构建应用——有时甚至无需编写任何代码。 如果你之前使用过 Cursor AI,T…