pytorch基础实践-数据与预处理

news2025/1/11 8:18:59

文章目录

  • 数据集
    • Fashion-MNIST 数据集
  • 数据预处理
    • 包的导入
    • 在Pytorch中进行 ETL
      • 利用torchvison包获取和处理数据集(E+T)
    • 访问数据集
      • 访问和查看 train_set 中的单个数据
      • 利用 DataLoader 成批访问数据

数据集

Fashion-MNIST 数据集

  • MNIST
    MNIST,Modified National Institute of Standards and Technology database,前面加了“modified ”是因为这个数据集已经是在原始的 NIST 数据集上修改过的版本。

    简单来说 MNIST 就是一个包含了 0-9 十个数字(十个类别)的手写图片数据集,都是灰度图片,每张图片 28x28 像素,每个类别 7000 张图片,一共 70000 张。并且划分了 60000 张图片作为训练集,10000 张图片作为测试集。
    在这里插入图片描述
    MNIST 在图像分类领域非常流行,主要有两个原因:一是这个数据集特别简单,适合新手上手;二是学术圈为了比较各自的算法优劣,会在相同的数据集上训练算法,就是 MNIST。

    MNIST 也有它的问题就是太简单了(图像分类领域的“hello world”),所以有一帮人就开发了 Fashion-MNIST 想要来取代 MNIST。

  • Fashion-MNIST
    Fashion-MNIST 是一个德国的时装公司 Zalando 下面的研究院 Zalando Research 开发的,它用10类服装的图片取代了十类手写数字图片。十个类别分别是:
    在这里插入图片描述
    Fashion-MNIST 的设计理念就是作为 MNIST 的直接取代(direct dropin replacement),就是说以前使用 MNIST 的模型,除了数据集的链接(URL),其他什么都不用改。但替换之后的图像分类有了更高的难度。所以 Fashion-MNIST 和 MNIST 一样,都是灰度图片,28x28 像素,每类 7000 张,一共 70000 张,其中训练集 60000 张,测试集 10000 张。数据集链接
    Fashion-MNIST 是直接从 Zalando 网站上的商品图片提取出来制作的,包括以下7步:① 转换为PNG;② 裁剪;③ 长边缩放为28像素;④ 锐化;⑤ 补足空白;⑥ 取负片;⑦ 取灰度。在这里插入图片描述

数据预处理

通过 PyTorch 的 torchvision 包获取 Fashion-MNIST 数据集。

一般而言,对一个数据集的预处理流程为 ETL,即包含 extract、transform、load 三个步骤。

1.Extract - Get the Fashion-MNIST image data from the data source.Transform - 2.Transform image data into a desirable PyTorch tensor format.
3.Load - Put data into a suitable structure to make it easily accessible.

完成 ETL 流程之后,就可以开始构建和训练深度学习模型。

包的导入

需要把所有需要的PyTorch包导入:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

对各个包的描述如下:

  • torch - The top-level PyTorch package and tensor library.
  • torch.nn - A subpackage that contains modules and extensible classes for building neural networks.
  • torch.optim - A subpackage that contains standard optimization operations like SGD and Adam.
  • torch.nn.functional - A functional interface that contains typical operations used for building neural networks like loss functions and convolutions.
  • torchvision - A package that provides access to popular datasets, model architectures, and image transformations for computer vision.
  • torchvision.transforms - An interface that contains common transforms for image processing.

在Pytorch中进行 ETL

对于 ETL 流程,PyTorch 提供了两个类(class):
在这里插入图片描述
使用 PyTorch 创建自定义的数据集,我们通过创建子类并继承 Dataset 中的函数,来实现 Dataset 的扩展,然后就可以传递给 DataLoader 对象。
在这里插入图片描述
len() 和 getitm() 是其中两个必要的函数,前者的功能是计算数据集的长度,后者的功能是在数据集中按指定的索引编号将数据取出。

利用torchvison包获取和处理数据集(E+T)

利用 torchvision 获取并创建 Fashion-MNIST 数据集的一份实例(instance),这个过程中同时完成的数据集的获取(E)和转化(T),代码如下:

train_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

在这里插入图片描述
参数解释如下:
在这里插入图片描述
因为希望将图片数据集转换为张量,所以在 transform 中使用了 transforms.ToTensor();将此数据集命名为train_set,是因为我们希望将其作为训练数据;另外数据集仅会被下载一次,程序下载之前会检查本地有没有。

之后将获取的 train_set 打包给 DataLoader,使得数据集可以通过 DataLoader 方便的访问和加载(L):

train_loader = torch.utils.data.DataLoader(train_set)

至此已经完成了数据集的 Extract(利用url从网页下载)和 Transform(上面的transforms.ToTensor()),并且已经打包给了 DataLoader,可以通过 DataLoader 来实现 Load,比如设置 batch_size 和 shuffle:

train_loader = torch.utils.data.DataLoader(train_set
    ,batch_size=1000
    ,shuffle=True
)

访问数据集

首先可以查看数据集中有多少个图片,使用 Python 的 len() 函数:

 >len(train_set)
 60000

查看所有图片的标签,只需要访问 train_set.targets 属性:

> train_set.targets
tensor([9, 0, 0, ..., 3, 0, 5])

如果希望查看数据集中每一个类别有多少个标签(即多少个图片,适用于图片全部有标记的情况),可以用 PyTorch 的 bincount() 函数:

>train_set.targets.bincount()
tensor([6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000])

Fashion-MNIST 数据集中每一类都有 6000 个图片和标签对,这种每一类的样本数量相等的数据集称作 balanced dataset,反之类别之间样本数量不一致的数据集称为 unbalanced dataset。

访问和查看 train_set 中的单个数据

一次只查看单张图片,首先将 train_set 这个对象传递给 Python 内建的 iter() 函数,它会返回一个可以在其上迭代的代表数据流(stream of data)的对象,使我们可以沿数据流访问数据。

接下来再使用 Python 的内建函数 next() 来获取数据流中的下一个数据元素,如此就可以获取数据集中的一个单独数据(因此下面命名变量都是单数形式):

> sample = next(iter(train_set))

> len(sample)
2

> type(sample)
tuple

获取的一个单独数据长度为 2,这是因为数据集是由图片-标签对的形式组成的,每一个 data element 中都包含两个东西,一个是存储图片数据的张量,另一个是其对应的标签。

sample 的数据类型是 tuple,tuple 是Python中的一种 sequence types,是一个可以迭代的顺序不可变的数据序列。
可以用 sequence unpacking 来将其中的图像和标签分别提取出来:

> image, label = sample

和下面这种写法是等效的:

> image = sample[0]
> label = sample[1]

查看数据类型和shape:

> type(image)
torch.Tensor

> type(label)
int

> image.shape
torch.Size([1, 28, 28]) 

> torch.tensor(label).shape
torch.Size([])

Fashion-MNIST 数据集是单通道的灰度图,所一张图片的 tensor shape 就是 1x28x28。把没有用的颜色通道 squeeze 掉:

> image.squeeze().shape
torch.Size([28, 28])

显示出图片和标签:

> plt.imshow(image.squeeze(), cmap="gray")
> torch.tensor(label)
tensor(9)

在这里插入图片描述
标签是“9”,代表靴子,与图片是相符的。

利用 DataLoader 成批访问数据

> batch = next(iter(train_loader))

> len(batch)
2

> type(batch)
list

list 也是一种 Python sequence types,与 tuple 的不同在于 list 是可变序列。

一次访问 10 张图片,则需要给 DataLoader 指定 batch_size:

> display_loader = torch.utils.data.DataLoader(
    train_set, batch_size=10
)

关于 DataLoader 中的“shuffle=True”:如果“shuffle=True”,则每次调用 next() 返回的 batch 都会不同,训练集中的第一组样本将在第一次调用 next() 时返回,这个功能默认是 False。

可以像上面一样对 display_loader 使用 iter() 和 next() 来每次查看 10 张图片:

> batch = next(iter(display_loader))
> print('len:', len(batch))
len: 2

进行 sequence unpacking:

> images, labels = batch

> print('types:', type(images), type(labels))
> print('shapes:', images.shape, labels.shape)
types: <class 'torch.Tensor'> <class 'torch.Tensor'>
shapes: torch.Size([10, 1, 28, 28]) torch.Size([10])

此时返回的图像张量是 [10, 1, 28, 28] 的四阶张量,标签是一个长度为 10 的一阶张量。可以单独查看其中每一个图片和标签:

> images[0].shape
torch.Size([1, 28, 28])

> labels[0]
tensor(9)

一次绘制一批图像,可以使用 torchvision.utils.make_grid() 函数创建一个可以按网格绘制图片的 grid:

> grid = torchvision.utils.make_grid(images, nrow=10)    # nrow指定每行多少列图片

> plt.figure(figsize=(15,15))        # 缩放图像显示大小?
> plt.imshow(grid.permute(1,2,0))    # 这一步让grid符合imshow的要求,不清楚细节

> print('labels:', labels)
labels: tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5])\

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/912671.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

day1 链表专题 牛客TOP100 BM 1-10

文章目录 链表BM1 反转链表BM2 链表内指定区间反转BM3 链表中的节点每k个一组翻转BM4 合并两个排序的链表BM5 合并k个已排序的链表BM6 判断链表中是否有环BM7 链表中环的入口结点BM8 链表中倒数最后k个结点BM9 删除链表的倒数第n个节点BM10 两个链表的第一个公共结点 链表 BM1…

ssm+vue绿色农产品推广应用网站源码和论文PPT

ssmvue绿色农产品推广应用网站041 开发工具&#xff1a;idea 数据库mysql5.7 数据库链接工具&#xff1a;navcat,小海豚等 技术&#xff1a;ssm 摘 要 21世纪的今天&#xff0c;随着社会的不断发展与进步&#xff0c;人们对于信息科学化的认识&#xff0c;已由低层次向高…

浅谈限流式保护器在电气线路火灾中的应用

安科瑞 华楠 电气线路起火的主要原因 1.线路短路 所谓短路就是交流电路的两根导线互相触碰&#xff0c;电流不经过线路中的用电设备&#xff0c;而直接形成回路。由于电线本身的电阻比较小&#xff0c;若仅是通过电线这个回路&#xff0c;电流就会急剧变大&#xff0c;比正常情…

HAProxy的配置与搭建

Haproxy概念 HAProxy是可提供高可用性、负载均衡以及基于TCP和HTTP应用的代理&#xff0c;是免费、快速并且可靠的一种解决方案。HAProxy非常适用于并发大&#xff08;并发量达1w以上&#xff09;web站点&#xff0c;这些站点通常又需要会话保持或七层处理。HAProxy的运行模式…

MySQL创建表报错

CREATE TABLE IF NOT EXISTS nhooo_b1 (nhooo_id INT UNSIGNED AUTO_INCREMENT,nhooo_title VARCHAR(100) NOT NULL,nhooo_author VARCHAR(40) NOT NULL,submission_date DATE,PRIMARY KEY (nhooo_id) ) ENGINEINNODB DEFAULT CHARSETutf8;创建表始终报以下错误&#xff1a; 这…

Linux操作系统调度基本准则和实现

今天分享一篇处理器调度相关的理论介绍文章。 1&#xff0c;基本概念 在多道程序系统中&#xff0c;进程的数量往往多于处理机的个数&#xff0c;进程争用处理机的情况就在所难免。处理机调度是对处理机进行分配&#xff0c;就是从就绪队列中&#xff0c;按照一定的算法&…

k8s-ingress-context deadline exceeded

报错&#xff1a; rancher-rke-01:~/rke # helm install rancher rancher-latest/rancher --namespace cattle-system --set hostnamewww.rancher.local Error: INSTALLATION FAILED: Internal error occurred: failed calling webhook "validate.nginx.ingress.kube…

【使用Node.js搭建自己的HTTP服务器】

文章目录 前言1.安装Node.js环境2.创建node.js服务3. 访问node.js 服务4.内网穿透4.1 安装配置cpolar内网穿透4.2 创建隧道映射本地端口 5.固定公网地址 前言 Node.js 是能够在服务器端运行 JavaScript 的开放源代码、跨平台运行环境。Node.js 由 OpenJS Foundation&#xff0…

二叉搜索树的(查找、插入、删除)

一、二叉搜索树的概念 二叉搜索树又称二叉排序树&#xff0c;它或者是一棵空树&#xff0c;或者是具有以下性质的二叉树: 1、若它的左子树不为空&#xff0c;则左子树上所有节点的值都小于根节点的值&#xff1b; 2、若它的右子树不为空&#xff0c;则右子树上所有节点的值都…

使用rook搭建Ceph集群

宿主机&#xff1a; MacBook Pro&#xff08;Apple M2 Max&#xff09; VMware Fusion Player 版本 13.0.2 VM软硬件&#xff1a; ubuntu 22.04.2 4核 CPU&#xff0c;5G 内存&#xff0c;40G硬盘 *每台机器分配硬件资源很重要&#xff0c;可以适当超过宿主机的资源量&am…

张驰咨询:有效导入精益生产咨询,企业提升竞争力的关键

精益生产是一种源于日本的先进生产管理理念&#xff0c;旨在通过消除生产过程中的浪费&#xff0c;提高生产效率和质量&#xff0c;降低成本&#xff0c;从而提升企业的竞争力。在我国&#xff0c;越来越多的企业开始尝试导入精益生产咨询&#xff0c;但效果并不尽如人意。为了…

关于slot-scope已经废弃的问题

说起来啊&#xff0c;这个问题啊&#xff0c;我之前一直没关注&#xff0c;还是webstorm给我的警告。 因为使用了element-ui的组件库&#xff0c;所以在使用组件的时候往往就cv大法了&#xff0c;直到今天用webstorm写代码是&#xff0c;提示了如下的错误 我这一看&#xff0c…

伦敦金短线好还是长线好

在伦敦金投之中&#xff0c;长期有一个争论很久的问题&#xff0c;那就是伦敦金投资究竟是长线好还是短线好&#xff1f;不同的投资者对这个问题有不同的看法&#xff0c;一般认为&#xff0c;伦敦金投资比较适合短线交易。笔者也将讨论这个问题&#xff0c;看看伦敦金投资是不…

《网络是怎样连接的》(四)

本文主要取材于 《网络是怎样连接的》 第四章。 目录 4.1 互联网的基本结构 4.2光纤接入网&#xff08;FTTH&#xff09; 4.3 接入网中使用的PPP和隧道 4.4 网络运营商的内部 4.5 跨越运营商的网络包 简述&#xff1a;本文主要内容是解释 网络包是如何通过互联网接入路由…

svg mask和stroke冲突问题

目录 先说结论各种样例首先是水平、垂直的线然后是斜线如果是图形加stroke呢用《g》标签包起来呢 总结 先说结论 实际上svg里&#xff0c;mask对svg内元素起作用的并非元素本身&#xff0c;而是元素几何形状的外包矩形&#xff0c;特别是和stroke有冲突&#xff0c;会产生奇怪…

opencv 进阶16-基于FAST特征和BRIEF描述符的ORB(图像匹配)

在计算机视觉领域&#xff0c;从图像中提取和匹配特征的能力对于对象识别、图像拼接和相机定位等任务至关重要。实现这一目标的一种流行方法是 ORB&#xff08;Oriented FAST and Rotated Brief&#xff09;特征检测器和描述符。ORB 由 Ethan Rublee 等人开发&#xff0c;结合了…

工作7年的测试员,明白了如何正确的“卷“

背景 近两年&#xff0c;出台和落地的反垄断法&#xff0c;明确指出要防止资本无序扩张。 这也就导致现在的各大互联网公司&#xff0c;不能再去染指其他已有的传统行业&#xff0c;只能专注自己目前存量的这些业务。或者通过技术创新&#xff0c;开辟出新的行业。 但创新这种…

vmware 虚拟机开机自启动脚本

1、建立一个txt文件 D:\VMware\VMware Workstation\vmrun.exe -T ws start "I:\Documents\Virtual Machines\centos\centos.vmx" nogui 注意&#xff1a;如果路径中有中文需要先转换txt文件编码格式ANSI 2、设置bat开机自启动 winr shell:startup 复制文本文件到…

【uniapp】微信小程序 , 海报轮播图弹窗,点击海报保存到本地,长按海报图片分享,收藏或保存

uivew 2.0 uniapp 海报画板 DCloud 插件市场 第一步&#xff0c;下载插件并导入HbuilderX 第二步&#xff0c;文件内 引入 海报组件 <template><painter ref"haibaorefs"></painter> <template> <script>import painter from /comp…

Docker关于下载,镜像配置,容器启动,停止,查看等基础操作

系列文章目录 文章目录 系列文章目录前言一、安装Docker并配置镜像加速器二、下载系统镜像&#xff08;Ubuntu、 centos&#xff09;三、基于下载的镜像创建两个容器 &#xff08;容器名一个为自己名字全拼&#xff0c;一个为首名字字母&#xff09;四、容器的启动、 停止及重启…