PyTorch翻译官网教程3-DATASETS DATALOADERS

news2024/11/20 14:37:37

官网链接

Datasets & DataLoaders — PyTorch Tutorials 2.0.1+cu117 documentation

数据集和数据加载器

处理样本数据的代码可能会变得混乱并且难以维护。理想情况下,我们希望我们的数据集代码与模型训练代码解耦,以获得更好的可读性和模块化。PyTorch提供了两个数据源:torch.utils.data.DataLoader和torch.utils.data.Dataset,它们允许你使用预加载的数据集和你自己的数据集。Dataset存储样本及其相应的标签,DataLoader在Dataset之上包装一个可迭代对象,以便于访问样本。

加载数据集

下面是一个如何从TorchVision加载Fashion-MNIST数据集的示例。Fashion-MNIST是Zalando文章图像的数据集,由60,000个训练样例和10,000个测试样例组成。每个示例都包含一个28×28灰度图像和来自10个类之一的关联标签。

我们用以下参数加载FashionMNIST数据集:

  • root 是存储训练/测试数据的路径
  • train 指定训练或者测试数据集
  • download=True 如果在root目录下不可用,是否从互联网上下载数据
  • transform 和 target_transform 指定特征和标签的转换
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

输出

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz

  0%|          | 0/26421880 [00:00<?, ?it/s]
  0%|          | 65536/26421880 [00:00<01:12, 363415.33it/s]
  1%|          | 229376/26421880 [00:00<00:38, 679946.01it/s]
  3%|2         | 753664/26421880 [00:00<00:12, 2072129.51it/s]
  7%|6         | 1802240/26421880 [00:00<00:06, 3878939.55it/s]
 16%|#6        | 4358144/26421880 [00:00<00:02, 9473273.38it/s]
 25%|##4       | 6553600/26421880 [00:00<00:01, 10918007.48it/s]
 34%|###4      | 9011200/26421880 [00:01<00:01, 14051286.10it/s]
 43%|####3     | 11370496/26421880 [00:01<00:01, 14100224.63it/s]
 52%|#####2    | 13762560/26421880 [00:01<00:00, 16167133.39it/s]
 61%|######1   | 16187392/26421880 [00:01<00:00, 15640376.47it/s]
 70%|#######   | 18612224/26421880 [00:01<00:00, 17384518.47it/s]
 80%|#######9  | 21069824/26421880 [00:01<00:00, 16443689.83it/s]
 89%|########8 | 23429120/26421880 [00:01<00:00, 17854523.89it/s]
 98%|#########8| 25952256/26421880 [00:01<00:00, 16957283.54it/s]
100%|##########| 26421880/26421880 [00:02<00:00, 13185733.62it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz

  0%|          | 0/29515 [00:00<?, ?it/s]
100%|##########| 29515/29515 [00:00<00:00, 327080.02it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz

  0%|          | 0/4422102 [00:00<?, ?it/s]
  1%|1         | 65536/4422102 [00:00<00:12, 361139.75it/s]
  5%|5         | 229376/4422102 [00:00<00:06, 678952.39it/s]
 19%|#9        | 851968/4422102 [00:00<00:01, 2356375.68it/s]
 44%|####3     | 1933312/4422102 [00:00<00:00, 4134961.37it/s]
100%|##########| 4422102/4422102 [00:00<00:00, 6052787.05it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz

  0%|          | 0/5148 [00:00<?, ?it/s]
100%|##########| 5148/5148 [00:00<00:00, 43184553.98it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw


迭代和可视化数据集

我们可以像列表一样手动索引数据集:training_data[index]。我们可以使用matplotlib来可视化训练数据中的一些样本。

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

 

通过文件创建自定义数据集

自定义Dataset类必须实现三个函数:__init__, __len__和__getitem__。FashionMNIST图像存储在img_dir目录中,它们的标签单独存储在CSV文件annotations_file中。

在接下来的部分中,我们将分解这些函数中发生的事情。

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label


 

__init__

__init__函数在实例化Dataset对象时运行一次。我们初始化的目录包含图像、注释文件和两个transforms(下一节将详细介绍)

label .csv文件看起来像这样:

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform

__len__

__len__函数返回数据集中的样本数。

示例:

def __len__(self):
    return len(self.img_labels)

__getitem__

__getitem__函数使用给定的索引idx,从数据集中加载并返回一个样本。基于索引,它识别图像在磁盘上的位置,使用read_image函数将其转换为张量,从self.img_labels中的CSV数据中检索相应的标签。调用它们的transform函数(如果可用)。并在元组中返回张量图像和相应的标签。

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label

使用DataLoaders准备训练数据集

Dataset每次检索我们数据集中一个样本的特征和标签。在训练模型时,我们通常希望以“小批量”的方式传递样本,在每个epoch中重新洗数据以减少模型过拟合,并使用Python的多进程来加速数据检索。

DataLoader是一个可迭代对象,它用一个简单的API为我们抽象了复杂性。

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

遍历DataLoader

我们已经将该数据集加载到DataLoader中,并且可以根据需要迭代该数据集。每次迭代都返回一批train_features和train_labels(分别包含batch_size=64个特征和标签)。因为我们指定了shuffle=True,所以在遍历所有批次之后,将对数据进行清洗(要对数据加载顺序进行更细粒度的控制,请查看样例)

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

输出

 

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/680626.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

轻松了解工作与学习必备的版本控制+Git,全程舒适~

目录 一、版本控制 二、版本控制器 三、Git 四、项目实操 第一步 在github上创建一个新的远程仓库 第二步 克隆到本地文件夹 第三步 IDEA&#xff08;PyCharm为例&#xff09;集成Git 一、版本控制 概念&#xff1a;版本控制是指对软件开发过程中各种程序代码、配置文件…

【spring cloud学习】4、创建服务提供者

注册中心Eureka Server创建并启动之后&#xff0c;接下来介绍如何创建一个Provider并且注册到Eureka Server中&#xff0c;再提供一个REST接口给其他服务调用。 首先一个Provider至少需要两个组件包依赖&#xff1a;Spring Boot Web服务组件和Eureka Client组件。如下所示&…

ADRC自抗扰控制(CODESYS平台完整源代码)

博途PLC ADRC完整源代码请参考下面文章链接: 博途PLC ADRC自抗扰控制完整SCL源代码_adrc控制算法代码_RXXW_Dor的博客-CSDN博客关于自抗扰控制框图可以参看专栏的其它文章,这里不再讲解具体算法过程,详细了解也可以参看韩京清研究员写的 《ADRC自抗扰》一书。_adrc控制算法…

基于混合策略的改进哈里斯鹰优化算法-附代码

基于混合策略的改进哈里斯鹰优化算法 文章目录 基于混合策略的改进哈里斯鹰优化算法1.哈里斯鹰优化算法2.改进哈里斯鹰优化算法2.1 Sobol 序列初始化种群2.2 limit 阈值执行全局搜索阶段2.4 动态反向学习 3.实验结果4.参考文献5.Matlab代码6.python代码 摘要&#xff1a;针对原…

ElasticSearch-Kibana的安装

Kibana的安装 什么是ELK? ELK是Elasticsearch,Logstash,Kibana三大开源框架首字母大写简称,ELK属于大数据,是拆箱即用的,上手比较快 什么是Kibana? Kibana是一个针对ES的开源分析以及可视化平台,用来搜索,查看交互存储在ES索引中的数据,使用Kibana可以通过各类图标进行高级…

Flink(1)-概述

1.1 Apache Flink是什么&#xff1f; 在当前数据量激增的时代&#xff0c;各种业务场景都有大量的业务数据产生&#xff0c;对于这些不断产生的数据应该如何进行有效的处理&#xff0c;成为当下大多数公司所面临的问题。目前比较流行的大数据处理引擎Apache Spark&#xff0c;…

SpringBoot第14讲:SpringBoot 如何统一异常处理

SpringBoot第14讲&#xff1a;SpringBoot 如何统一异常处理 本文是SpringBoot第14讲&#xff0c;SpringBoot接口如何对异常进行统一封装&#xff0c;并统一返回呢&#xff1f;以上文的参数校验为例&#xff0c;如何优雅的将参数校验的错误信息统一处理并封装返回呢 文章目录 Sp…

诊断测试工具CANoe.DiVa从入门到精通系列——开门见山

我是穿拖鞋的汉子,魔都中坚持长期主义的工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 人们会在生活中不断攻击你。他们的主要武器是向你灌输对自己的怀疑:你的价值、你的能力、你的潜力。他们往往会将此伪装成客观意见,但无一例外的是,他们想…

网络安全就业前景如何?是否还能入行?

网络安全专业是2015年新设立的专业&#xff0c;作为新兴专业吸引了很多人准备入行&#xff0c;那么它的就业前景怎么样&#xff1f;大致可以分为3个版块来介绍。 1.就业领域前景广阔 目前互联网、通信、新能源、房地产、金融证券、电子技术等行业迫切需要网络安全人才&#x…

22. 算法之图的最短路径

前言 关于图的最短路径问题&#xff0c;是图这种数据结构中的经典问题。也是与我们的生活息息相关的&#xff0c;比如上海四通八达的地铁线路&#xff0c;从一个地铁站&#xff0c;到另一个地铁站&#xff0c;可能有很多种不同的路线。那么&#xff0c;我们选哪种路线&#xf…

JavaFX第五篇 Image图片加载处理

JavaFX第五篇 Image图片加载处理 1. 代码2. 讲解3. 代码仓 图片已经成为每个网站的必备了&#xff0c;不仅可以提升个人网站的标识度而且还可以美化网站&#xff0c; 所以这里需要讲解一下如何加载图片&#xff0c;展示到前台给用户查看。 本次只是简单的讲解如何展示使用&…

【算法证明 七】深入理解深度优先搜索

深度优先搜索包含一个递归&#xff0c;对其进行分析要复杂一些。与上一篇文章一样&#xff0c;还是给节点定义几个状态&#xff0c;然后详细分析深度优先搜索算法有哪些性质。 算法描述 定义状态 v . c o l o r &#xff1a;初始状态为白色&#xff0c;被发现时改为灰色&…

Mysql的SQL性能分析【借助EXPLAIN分析】

性能分析 要说sql有问题&#xff0c;需要拿出证据&#xff0c;因此需要性能分析 Mysql查询优化器&#xff08;Mysql Query Optimizer&#xff09; Mysql中有专门负责优化SELECT语句的优化器模块&#xff0c;主要功能&#xff1a;通过计算分析系统中收集到的统计信息&#xf…

Xline v0.4.1: 一个用于元数据管理的分布式KV存储

Xline是什么&#xff1f;我们为什么要做Xline&#xff1f; Xline是一个基于Curp协议的&#xff0c;用于管理元数据的分布式KV存储。现有的分布式KV存储大多采用Raft共识协议&#xff0c;需要两次RTT才能完成一次请求。当部署在单个数据中心时&#xff0c;节点之间的延迟较低&a…

python机器学习——分类模型评估 分类算法(k近邻,朴素贝叶斯,决策树,随机森林,逻辑回归,svm)

目录 分类模型的评估模型优化与选择1.交叉验证2.网格搜索 【分类】K近邻算法【分类】朴素贝叶斯——文本分类实例&#xff1a;新闻数据分类 【分类】决策树和随机森林1.决策树2.决策树的算法3.代码实现实例&#xff1a;泰坦尼克号预测生死 【集成学习】随机森林1.集成学习2.随机…

LOMO:在受限资源上全参数微调

LOMO&#xff1a;Full Parameter Fine-Tuning for large language models with limited resources IntroductionMethodRethink the functionality of optimizerUsing SGD LOMO&#xff1a; LOw-Memory Optimization 实验参考 Introduction 在这篇文章中&#xff0c;作者的目的…

Go 语言进阶 - 工程进阶

前言&#xff1a; \textcolor{Green}{前言&#xff1a;} 前言&#xff1a; &#x1f49e;这个专栏就专门来记录一下寒假参加的第五期字节跳动训练营 &#x1f49e;从这个专栏里面可以迅速获得Go的知识 今天的内容包括以下两个内容。关于实践的内容我会在后续发布出来。 01.语言…

新零售破局丨2023年探索全新电商运维模式——永倍达模式深度解析

新零售破局丨2023年探索全新电商运维模式——永倍达模式深度解析 大家好&#xff01;我是微三云胡佳东&#xff0c;一家专业的电商软件开发公司的负责人。 近年来&#xff0c;随着电商的高速发展&#xff0c;不少电商平台成为了市场经济的优质榜样&#xff0c;互联网市场竞争也…

设计模型学习-UML图

1&#xff0c;简介 UML图有很多种类型&#xff0c;但掌握其中的类图、用例图和时序图就可以完成大部分的工作。其中最重要的便是「类图」&#xff0c;它是面向对象建模中最常用和最重要的图&#xff0c;是定义其他图的基础。 类图主要是用来显示系统中的类、接口以及它们之间的…

Ubuntu环境下读取罗技G29方向盘信息

本篇博客最早发布于实验室公共博客&#xff0c;但已无人维护&#xff0c;现迁移至个人博客 引言 实验室有这么酷的驾驶设备&#xff0c;来了一年还没有实际操作过&#xff0c;早就蠢蠢欲试了&#xff0c;哈哈哈不过之前负责的师兄还在就一直没敢用&#xff0c;现在他毕业了就可…