MNIST内置手写数字数据集的实现

news2024/11/27 6:17:47

torchvision库

torchivision库是PyTorch中用来处理图像和视频的一个辅助库,接下来我们就会使用torchvision库加载内置的数据集进行分类模型的演示

为了统一数据加载和处理代码,PyTorch提供了两个类用于处理数据加载,他们分别是torch.utils.data.Dataset类和torch.utils.data.DataLoader类,通过这两个类可使数据集加载和预处理代码与模型训练代码脱钩,从而获得更好的代码模块化和代码可读性。torchvision加载的内置图片数据集均继承自torch.utils.data.Dataset类,因此可直接使用加载的内置数据集创建DataLoader.

加载内置图片数据集

PyTorch的内置图片数据集均在torchvision.datasets模块下,包含Caltech、CelebA、CIFAR、Cityscapes、COCO、Fashion-MNIST、ImageNet、MNIST等很多著名的数据集,其中MNIAT数据集是手写数字数据集,这是一个很适合入门者学习使用的小型计算机视觉数据集,它包含0到9的手写数字图片和每一张图片对应的标签。接下来我们就以此数据集为例子进行学习。

import torchvision  # 导入torchvision库
from torchvision.transforms import ToTensor  #做好准备工作,导入所需要的包
import torch
import matplotlib.pyplot as plt
import numpy as np

首先就是对我们所需要的库进行导入。

我对上述的代码进行一下解读,首先导入了torchvision库,从torchvision.transforms模块下导入ToTensor类。torchvision.transforms模块包含了转换函数,使用它可以很方便的对加载的图形进行各种变换,这里用到的ToTensor类,该类的主要作用有以下3点。

  1. 将输入转换为张量
  2. 将读取图片的格式规范为(channel,height,width),这里和我们经常遇到的图片格式有可能会有一些去呗,PyTorch中的图片格式一般是通道数(channel)在前,然后是高度(height)和宽度(width)
  3. 将图片像素的取值范围归一化,规范为0到1的范围内
train_ds=torchvision.datasets.MNIST('data/',train=True,transform=ToTensor(),download=True)
test_ds=torchvision.datasets.MNIST('data/',train=False,transform=ToTensor(),download=True)

通过torchvision.datasets.MNIST方法加载MNIST数据集,方法中的第一个参数为data/表示下载数据集存放的位置,参数train表示是否是训练数据,若为True,则加载训练数据集,若为False,则加载测试数据集;
使用参数transform表示对加载数据的预处理,参数值为ToTensor();
最后一个参数download=True表示将下载此数据集,一旦下载完成后,下一次执行此代码是,将优先从本地文件夹直接加载,如果咱们的计算机不能连接互联网,也可以直接将文件复制到data文件夹中,这样就能从本地直接加载数据了。

现在我们得到了两个数据集,分别是训练数据集和测试数据集,PyTorch还提供了torch.utils.data.DataLoader类用以对数据集做进一步的处理,DataLoader接收数据集,并执行复杂的操作,如小批次处理、多线程、随机打乱等,以便从数据集中获取数据。它接收来自用户的Dataset实例,并使用采样器策略将数据采样为小批次。DataLoader的目的如下

1.使用shuffle参数对数据集做乱序的操作,一般情况下,需要对训练数据集进行乱序的操作,因为原始的数据在样本均衡的情况下肯呢个是按照某种顺序进行排列的,经过顺序打乱之后,数据的排列就会拥有一定的随机性,这样做可以避免出现模型反复依次序学习数据的特征或者学习到的只是数据的次数特征的情况。

2.将数据采样为小批次,可用batch_size参数指定批次大小。首先单个样本训练有一个很大的缺点,就是损失和梯度会受到单个样本的影响,如果样本分布不均匀,或者有错误标注样本,则会引起梯度的巨大震荡,从而导致模型训练效果很差。为了解决这个问题,我们可以考虑使用批量数据训练(也叫做批量梯度下降算法),通过遍历全部数据集算一次损失函数,然后计算损失对各个参数的梯度,并更新参数。这种训练方式没更新一次,参数都要把数据集里所有样本都看一遍,不仅计算开销大,而且计算速度慢。为了克服上述方法的缺点,一般采用的是一种折中手段进行损失函数计算:即把数据分为若干个小的批次,按批次来更新参数,这样,一个批次中的一组数据共同决定了本次梯度的方向,大大降低了参数更新时的梯度方差,下降起来更加稳定,减少了随机性,与单样本训练相比,小批次训练可利用矩阵操作进行有效的梯度计算,计算量也不是很大,对计算机内存的要求也不高。

3.可以充分利用多个子进程加速数据预处理。num_workers参数可以指定子进程的数量

4.可通过collate_fn参数传递批次数据的处理函数,实现在DataLoader中对批次数据做转换处理

train_dl=torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)
test_dl=torch.utils.data.DataLoader(test_ds,batch_size=46)

上面代码中分别创建了训练数据和测试数据的DataLoader,并设置他们的批次大小为64,对训练数据设置了shuffle为True;对测试数据,由于仅仅作为测试,没必要做乱序。

DataLoader是可迭代对象,我们观察它返回的数据集的类型,给大家对对DataLoader和MNIST数据集有一个直观的印象

imgs,labels=next(iter(train_dl))#创建生成器,并用next方法返回一个批次的数据
print(imgs.shape)
print(labels.shape)

我们使用iter方法将DataLoader对象创建为生成器,并使用next方法反悔了一个批次的图像(imgs)和对应的一个批次的标签(labels),image.shape为torch.Size([64,1,28,28]),这里的64是批次,我们可以认为这代表64张形状为(1,28,28)的图片,其中1为通道数,28和28分别为高和宽;既然这里有64张图片,那么就对应着有64个标签,也就是labels.shape所显示的torch.Size([64])

结果绘制

# 我们使用Matplotlib来绘制一下前10张的图片
plt.figure(figsize=(20,2))  # 创建一个(10,1)大小的画布
for i,img in enumerate(imgs[:20]):
    npimg=img.numpy()  # 将张量转换为ndarray
    npimg=np.squeeze(npimg)  # 图片形状由(1,28,28)转换为(28,28)
    plt.subplot(1,20,i+1)  # 初始化子图,3个参数表示1行10列的第i+1个子图
    plt.imshow(npimg)  #在子图中绘制单张图片
    plt.axis('off')  # 关闭显示子图坐标

plt.imshow() 是一个用于显示图像的函数,通常用于在 Python 中使用 Matplotlib 库绘制图像。它可以接受一个数组或图像数据,并将其显示为图像。这个函数通常用于可视化图像数据,比如热图、灰度图、彩色图等。plt.imshow() 可以接受一些参数,比如 cmap(颜色映射)、interpolation(插值方法)等,用来控制图像的显示效果。

接下来,我们打印对应的标签

print(labels[:20])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1318765.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Typescript中Omit数据类型的理解

在 TypeScript 中&#xff0c;Omit 是一个内置的工具类型&#xff0c;它用于从对象类型中排除指定的属性&#xff0c;并返回剩余的属性。 Omit 的语法如下所示&#xff1a; type Omit<T, K> Pick<T, Exclude<keyof T, K>>;其中&#xff0c;T 表示原始类型…

Leetcode刷题笔记题解(C++):224. 基本计算器

思路&#xff1a; step 1&#xff1a;使用栈辅助处理优先级&#xff0c;默认符号为加号。 step 2&#xff1a;遍历字符串&#xff0c;遇到数字&#xff0c;则将连续的数字字符部分转化为int型数字。 step 3&#xff1a;遇到左括号&#xff0c;则将括号后的部分送入递归&#x…

面向对象三大特征——继承

目录 1. 概述 2. 继承的限制 2.1 单继承 2.2 访问修饰符 2.3 . final 3. 重写 4. super 4.1super的作用 4.2访问父类的成员和被重写方法 4.3调用父类的构造器 1. 概述 多个类中存在相同属性和行为时&#xff0c;将这些内容抽取到单独一个类中&#xff0c;那么就无需在…

DeciLM-7B:突破极限,高效率、高精准度的70亿参数AI模型

引言 在人工智能领域&#xff0c;语言模型的发展速度令人瞩目。Deci团队最近推出了一款具有革命性意义的语言模型——DeciLM-7B。这款模型在速度和精确度上都实现了显著的突破&#xff0c;以其70亿参数的规模&#xff0c;在语言模型的竞争中脱颖而出。 Huggingface模型下载&am…

C# 基本桌面编程(二)

一、前言 本章为C# 基本桌面编程技术的第二节也是最后一节。前一节在下面这个链接 C# 基本桌面编程&#xff08;一&#xff09;https://blog.csdn.net/qq_71897293/article/details/135024535?spm1001.2014.3001.5502 二、控件布局 1 叠放顺序 在WPF当中布局&#xff0c;通…

我与Datawhale的故事之长篇

Datawhale成员 作者&#xff1a;Datawhale团队成员 前 言 上周五周年文章发出后大家反响比较热烈&#xff1a; 我们与Datawhale背后的故事&#xff01; 本期给大家带来三篇长篇回忆 胡锐峰 我与Datawhale的故事 题记&#xff1a;我和你的相遇就像春风拂面&#xff0c;就像夏雨…

[原创][R语言]股票分析实战[2]:周级别涨幅趋势的相关性

[简介] 常用网名: 猪头三 出生日期: 1981.XX.XX QQ联系: 643439947 个人网站: 80x86汇编小站 https://www.x86asm.org 编程生涯: 2001年~至今[共22年] 职业生涯: 20年 开发语言: C/C、80x86ASM、PHP、Perl、Objective-C、Object Pascal、C#、Python 开发工具: Visual Studio、D…

UE5 C++(三)— 基本用法(生命周期、日志、基础变量)

文章目录 生命周期日志打印Outlog打印屏幕打印 基础变量类型FString、FName 和 FText&#xff0c;三者之间的区别 基础数据类型打印 忘记说了每次在Vscode修改后C脚本后&#xff0c;需要编译一下脚本&#xff0c;为了方便我是点击这里编译脚本 生命周期 Actor 生命周期官方文档…

20--Set集合

1、Set集合 1.1 Set集合概述 java.util.Set接口和java.util.List接口一样&#xff0c;同样继承自Collection接口&#xff0c;它与Collection接口中的方法基本一致&#xff0c;并没有对Collection接口进行功能上的扩充&#xff0c;只是比Collection接口更加严格了。与List接口…

wordpress安装之正式开始安装wordpress

1、拉取wordpress镜像 docker pull wordpress 2、启动容器 启动容器&#xff0c;设置容器名为wordpress2并把80端口映射到宿主机的9988端口 docker run -it --name wordpress2 -p 9988:80 -d wordpress 3、查看容器状态 docker ps 4、安装wordpress博客程序 因为我们前面启…

SLAM算法与工程实践——相机篇:传统相机使用(3)

SLAM算法与工程实践系列文章 下面是SLAM算法与工程实践系列文章的总链接&#xff0c;本人发表这个系列的文章链接均收录于此 SLAM算法与工程实践系列文章链接 下面是专栏地址&#xff1a; SLAM算法与工程实践系列专栏 文章目录 SLAM算法与工程实践系列文章SLAM算法与工程实践…

关于找不到XINPUT1_3.dll,无法继续执行代码问题的5种不同解决方法

一、xinput1_3.dll的作用 xinput1_3.dll是Windows操作系统中的一款动态链接库文件&#xff0c;主要用于支持游戏手柄和游戏输入设备。这款文件属于Microsoft Xbox 360兼容性库&#xff0c;它包含了与游戏手柄和其他输入设备相关的功能。在游戏中&#xff0c;xinput1_3.dll负责…

计算机操作系统-第十八天

目录 进程调度时机 补充知识 进程调度的方式 非剥夺调度方式 剥夺调度方式 进程的切换与过程 本节思维导图 进程调度时机 进程调度&#xff08;低级调度&#xff09;&#xff0c;即按照某种算法从就绪队列中选择一个进程为其分配处理机。 共有两种需要进行进程调度与…

CCNP课程实验-OSPF-CFG

目录 实验条件网络拓朴需求 配置实现基础配置1. 配置所有设备的IP地址 实现目标1. 要求按照下列标准配置一个OSPF网络。 路由协议采用OSPF&#xff0c;进程ID为89 &#xff0c;RID为loopback0地址。3. R4/R5/R6相连的三个站点链路OSPF网络类型配置成广播型&#xff0c;其中R5路…

PMP项目管理 - 资源管理

系列文章目录 PMP项目管理 - 质量管理 PMP项目管理 - 采购管理 PMP项目管理 - 资源管理 PMP项目管理 - 风险管理 现在的一切都是为将来的梦想编织翅膀&#xff0c;让梦想在现实中展翅高飞。 Now everything is for the future of dream weaving wings, let the dream fly in…

DISC-MedLLM—中文医疗健康助手

文章目录 DISC-MedLLM 项目介绍数据集构建重构AI医患对话知识图谱生成问答对医学图谱构建图谱生成QA对 人类偏好引导的对话样例其他数据MedMCQA通用数据 模型微调评估评估方式评估结果 总结 DISC-MedLLM 项目介绍 DISC-MedLLM 是一个专门针对医疗健康对话式场景而设计的医疗领…

「斗破年番」小医仙黑皇城遭调戏,五品丹换药材,获取菩提涎消息

Hello,小伙伴们&#xff0c;我是拾荒君。 《斗破苍穹年番》的第75集已经更新了&#xff0c;喜欢这部国漫的小伙伴应该都去观看了吧&#xff0c;拾荒君也是看了看这一集。在这一集中&#xff0c;萧炎成功地帮助吴昊等人摆脱了鹰爪老人的围困&#xff0c;然后便前往了黑皇城。 黑…

openGauss学习笔记-163 openGauss 数据库运维-备份与恢复-导入数据-使用COPY FROM STDIN导入数据-简介

文章目录 openGauss学习笔记-163 openGauss 数据库运维-备份与恢复-导入数据-使用COPY FROM STDIN导入数据-简介163.1 关于COPY FROM STDIN导入数据163.2 CopyManager类简介163.2.1 CopyManager的继承关系163.2.2 构造方法163.2.3 常用方法 openGauss学习笔记-163 openGauss 数…

torch中张量与数据类型的介绍

PyTorch张量的定义介绍 PyTorch最基本的操作对象是张量&#xff0c;它表示一个多维数组&#xff0c;类似NumPy的数组&#xff0c;但是前者可以在GPU上加速计算 初始化张量 ttorch.tensor([1,2]) # 创建一个张量 print(t) t.dtype #打印t的数据类型为torch.int…

Vue 指定class区域增加水印显示(人员姓名+时间)

效果 代码&#xff0c;存放位置 /utils/waterMark.js //waterMark.js文件let waterMark {}let setWaterMark (str,str1) > {let id 1.23452384164.123412416;if (document.getElementById(id) ! null) {//ui-table是table上的一个样式&#xff0c;一般水印显示在这个tab…