【torch.utils.data】 Dataset和Dataloader的解读和使用

news2024/9/21 14:24:03

文章目录

  • torch.utils.data
    • 前言
    • Dataset
    • Dataloader
    • 实践
  • 参考

torch.utils.data

前言

Pytorch中的 torch.utils.data 提供了两个抽象类:DatasetDataloaderDataset 允许你自定义自己的数据集,用来存储样本及其对应的标签。而 Dataloader 则是在 Dataset 的基础上将其包装为一个可迭代对象,以便我们更方便地(小批量)访问数据集。

import torch
from torch.utils.data import Dataset, Dataloader
  • 一些必备概念:
    • Data Size:整个数据集的大小;
    • Batch Size :在训练过程中,我们不可能把所有样本一次性投喂给神经网络,只能分批次投喂。每个小批量的样本个数就是 Batch Size
    • Iteration :将一个 Batch 投喂给神经网络称为一次 Iteration;
    • Epoch :将所有的样本(即所有 Batch)投喂给神经网络后称为一个 Epoch。

在这里插入图片描述

一般来说PyTorch中深度学习训练的流程是这样的:

  1. 创建Dateset
  2. Dataset传递给DataLoader
  3. DataLoader迭代产生训练数据提供给模型
# 创建Dateset(可以自定义)     
dataset = face_dataset  # Dataset部分自定义过的face_dataset 
# Dataset传递给DataLoader     
dataloader = torch.utils.data.DataLoader(dataset,batch_size=64,shuffle=False,num_workers=8)
# DataLoader迭代产生训练数据提供给模型    
for i in range(epoch):
        for index,(img,label) in enumerate(dataloader):
            pass

到这里应该就PyTorch的数据集和数据传递机制应该就比较清晰明了。Dataset负责建立索引到样本的映射,DataLoader负责以特定的方式从数据集中迭代的产生一个个batch的样本集合。在enumerate过程中实际上是dataloader按照其参数sampler规定的策略调用了其dataset__getitem__方法。

Dataset

可以看出,Dataset 是一个抽象类,我们自己编写的数据集类必须继承 Dataset,且需重新改写 __getitem____len__ 方法。

__getitem__ :传入指定的索引 index 后,该方法能够根据索引返回对应的单个样本及其对应的标签(以元组形式)。
__len__ :返回整个数据集的大小,即前面所说的 Data Size。
若我们自定义的类在继承 Dataset 时没有改写__getitem__ ,则程序会抛出 NotImplementedError 的异常。此外,因为 Dataset 类中提供了 add 方法,所以继承之后我们的数据集也会拥有此方法,从而合并数据集只需使用 + 运算即可。

一般而言,我们自定义的数据集的框架如下:

class MyDataset(Dataset):
    def __init__(self):
        # 初始化数据集的存储路径
        # 载入数据集(转化为tensor格式)
        # ...
    
    def __getitem__(self, index):
        # 返回单个样本及其标签
        pass
    
    def __len__(self):
        # 返回整个数据集的大小
        pass

Dataloader

绝大多数时候我们需要以 batch 的形式访问数据集。Dataloader这个接口提供了这样的功能,它能够基于我们自定义的数据集将其转换成一个可迭代对象以便我们批量访问。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

参数介绍:

  • dataset (Dataset) – 定义好的Map式或者Iterable式数据集。
  • batch_size (python:int, optional) – 一个batch含有多少样本 (default: 1)。
  • shuffle (bool, optional) – 每一个epoch的batch样本是相同还是随机 (default: False)。
  • sampler (Sampler, optional) – 决定数据集中采样的方法. 如果有,则shuffle参数必须为False。
  • batch_sampler (Sampler, optional) – 和 sampler 类似,但是一次返回的是一个batch内所有样本的index。和 batch_size, shuffle, sampler, and drop_last 三个参数互斥。
  • num_workers (python:int, optional) – 多少个子程序同时工作来获取数据,多线程。 (default: 0)
  • collate_fn (callable, optional) – 合并样本列表以形成小批量。
  • pin_memory (bool, optional) – 如果为True,数据加载器在返回前将张量复制到CUDA固定内存中。
  • drop_last (bool, optional) – 如果数据集大小不能被batch_size整除,设置为True可删除最后一个不完整的批处理。如果设为False并且数据集的大小不能被batch_size整除,则最后一个batch将更小。(default: False)
  • timeout (numeric, optional) – 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。 (default: 0)
  • worker_init_fn (callable, optional) – 每个worker初始化函数 (default: None)

实践

假设当前工作目录下有一个 data.txt,其内容如下:

1 -14 -15
1 -1 -15
1 -11 -14
1 0 -2
0 -4 2
1 7 -2
1 -7 -17
0 9 12
0 5 -14
1 -13 13

其中每一行都是一个样例。每行中的后两个数字为样本的特征,第一个数字为样本对应的标签。可以看出,我们一共有 10 1010 个样本,它们均位于二维欧式空间中,且问题为二分类问题。

于是数据集的框架可以这样写:

class MyDataset(Dataset):
    def __init__(self, path):
        self.data = np.loadtxt(path)
        self._X = torch.from_numpy(self.data[:, 1:])
        self._y = torch.from_numpy(self.data[:, 0])
    
    def __getitem__(self, index):
        return self._X[index], self._y[index]
    
    def __len__(self):
        return len(self._X)

使用时只需创建实例即可

path = './data.txt'
data = MyDataset(path)

我们可以调用各个方法来观察一下

len(data)
# 10

data[1]
# (tensor([ -1., -15.], dtype=torch.float64), tensor(1., dtype=torch.float64))

事实上,data是一个可迭代对象,我们可以直接使用 for 循环来输出整个数据集:

for feature, label in data:
    print(feature, label)
# tensor([-14., -15.], dtype=torch.float64) tensor(1., dtype=torch.float64)
# tensor([ -1., -15.], dtype=torch.float64) tensor(1., dtype=torch.float64)
# tensor([-11., -14.], dtype=torch.float64) tensor(1., dtype=torch.float64)
# tensor([ 0., -2.], dtype=torch.float64) tensor(1., dtype=torch.float64)
# tensor([-4.,  2.], dtype=torch.float64) tensor(0., dtype=torch.float64)
# tensor([ 7., -2.], dtype=torch.float64) tensor(1., dtype=torch.float64)
# tensor([ -7., -17.], dtype=torch.float64) tensor(1., dtype=torch.float64)
# tensor([ 9., 12.], dtype=torch.float64) tensor(0., dtype=torch.float64)
# tensor([  5., -14.], dtype=torch.float64) tensor(0., dtype=torch.float64)
# tensor([-13.,  13.], dtype=torch.float64) tensor(1., dtype=torch.float64)

创建dataloader :

dataloader = DataLoader(data, batch_size=3, shuffle=False, drop_last=False)

该代码将创建一个可迭代对象,我们将其列表化来观察一下:

list(dataloader)
# [[tensor([[-14., -15.],
#           [ -1., -15.],
#           [-11., -14.]], dtype=torch.float64),
#   tensor([1., 1., 1.], dtype=torch.float64)],
#  [tensor([[ 0., -2.],
#           [-4.,  2.],
#           [ 7., -2.]], dtype=torch.float64),
#   tensor([1., 0., 1.], dtype=torch.float64)],
#  [tensor([[ -7., -17.],
#           [  9.,  12.],
#           [  5., -14.]], dtype=torch.float64),
#   tensor([1., 0., 0.], dtype=torch.float64)],
#  [tensor([[-13.,  13.]], dtype=torch.float64),
#   tensor([1.], dtype=torch.float64)]]

可以看出,列表化后,每一个 batch 均以列表的形式存储。这说明我们可以通过 for 循环来遍历所有的 batch,具体做法如下:

for inputs, labels in dataloader:
    print(inputs, labels)
# tensor([[-14., -15.],
#         [ -1., -15.],
#         [-11., -14.]], dtype=torch.float64) tensor([1., 1., 1.], dtype=torch.float64)
# tensor([[ 0., -2.],
#         [-4.,  2.],
#         [ 7., -2.]], dtype=torch.float64) tensor([1., 0., 1.], dtype=torch.float64)
# tensor([[ -7., -17.],
#         [  9.,  12.],
#         [  5., -14.]], dtype=torch.float64) tensor([1., 0., 0.], dtype=torch.float64)
# tensor([[-13.,  13.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)

有些时候,我们需要记录每个 batch 的索引(即 iteration),则需要用到 enumerate函数(这里为了方便展示将 batch_size设为了1):

dataloader = DataLoader(data, batch_size=1, shuffle=True, drop_last=True)
for batch_idx, (inputs, labels) in enumerate(dataloader):
    print(batch_idx, end=' ')
    print(inputs, labels)
# 0 tensor([[-4.,  2.]], dtype=torch.float64) tensor([0.], dtype=torch.float64)
# 1 tensor([[ -1., -15.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 2 tensor([[ 0., -2.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 3 tensor([[ 7., -2.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 4 tensor([[ 9., 12.]], dtype=torch.float64) tensor([0.], dtype=torch.float64)
# 5 tensor([[  5., -14.]], dtype=torch.float64) tensor([0.], dtype=torch.float64)
# 6 tensor([[-11., -14.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 7 tensor([[-14., -15.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 8 tensor([[ -7., -17.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)
# 9 tensor([[-13.,  13.]], dtype=torch.float64) tensor([1.], dtype=torch.float64)

若要使用 Dataloader 进行神经网络训练,则需要将特征转化为 torch.float32
型,标签转化为 torch.int64型。

参考

PyTorch学习笔记(三)–Dataset和DataLoader_Lareges的博客-CSDN博客_dataset和dataloader

Pytorch之Dataset与DataLoader

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/61650.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LTspice XVII > Transformer 变压器仿真

目录 第①步设置 第②步设置 第③步设置 第④步设置 输出结果 最近在看“无线电基础电路实作修订版 [(美)西尔弗 著] 2014年版”这本书,打算好好修炼下无线电方面的基础知识,让自己更加牛逼一些,工作中偶尔可以装…

指标与标签的区别?

概述 在公司数据建设过程中,经常会使用和提到指标和标签,但是很多小伙伴对于两者的区别确不能讲清楚。实际上标签与指标一样,是理解数据的两种方式,在赋能业务上,两者同样重要。接下来将结合自身的理解,从…

Java项目:SSM共享汽车租赁平台

作者主页:源码空间站2022 简介:Java领域优质创作者、Java项目、学习资料、技术互助 文末获取源码 项目介绍 本项目分为前后台,前台为普通用户登录,后台为管理员登录; 管理员角色包含以下功能: 管理员登录…

ElementUI组件-日期时间控件设置禁用日期

ElementUI组件-日期时间控件禁用指定日期 主要属性 查看官网,可以看到有个叫做picker-options的组件属性,没错,就是借助他来完成禁用指定日期的操作,如下 该属性值传入的是一个对象,对于时间选择器、日期选择器、日…

[阶段4 企业开发进阶] 3. 消息队列--RabbitMQ

文章目录1 消息队列1.1 MQ的概念基本介绍使用原因MQ分类如何选择1.2 RabbitMQRabbitMQ核心工作原理安装教程1 消息队列 1.1 MQ的概念 基本介绍 MQ本质是个队列,FIFO 先入先出,只不过队列中存放的内容是 message 而已是一种跨进程的通信机制&#xff0…

[附源码]计算机毕业设计校刊投稿系统Springboot程序

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

Py之removebg:removebg的简介、安装、使用方法之详细攻略

Py之removebg:removebg的简介、安装、使用方法之详细攻略 目录 removebg的简介 1、官网注册获取APIKey removebg的安装 removebg的使用方法 1、直接调用并实现抠图 2、更多案例 removebg的简介 Remove Image Background,是一款不用PS就完成抠图的强…

每日挠头算法题(十五)螺旋矩阵II

“强大方能侠义” ------持续更新Blue Bridge杯入门系列算法实例-------- 如果你也喜欢Java和算法,欢迎订阅专栏共同学习交流! 你的点赞、关注、评论、是我创作的动力! -------希望我的文章对你有所帮助-------- 前言:最近可能…

【Python自学笔记】报错No module Named Wandb

【Python自学笔记】已经装了wandb,还报错No module Named Wandb 方法1.重启cmd和jupyter notebook 直接把窗口和cmd页面全关了,重新打开,再次运行安装和启动代码: !pip install wandbimport wandb wandb.init(project"你自…

【Matlab】一、解常微分方程ODE

文章目录求解常微分方程 ODE(1)求解解析解(2)求解数值解求解常微分方程 ODE ​ 在matlab中,我们可以求解常微分方程的解析解,和数值解,一般使用dsolve来求解常微分方程的解析解,使用…

jsp 上传文件及实体信息,ajax post 请求(formdata)报错400<======>前后端代码示例

Content-Type最常见的几种类型: 通常,没有声明,默认application/x-www-form-urlencoded application/x-www-form-urlencoded form表单默认的数据格式,提交的数据形式 key1val1&key2val2(参数少) mu…

[附源码]计算机毕业设计线上社区管理系统Springboot程序

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

〖全域运营实战白宝书 - 高转化文案速成篇③〗- 高打开率标题型文案的10大黄金法则

大家好,我是 哈士奇 ,一位工作了十年的"技术混子", 致力于为开发者赋能的UP主, 目前正在运营着 TFS_CLUB社区。 💬 人生格言:优于别人,并不高贵,真正的高贵应该是优于过去的自己。💬 &#x1f4e…

第一期 | 整洁,从桌面开始

文章目录前言一、主要内容介绍二、文件分类,整理你的桌面1.网格对齐图标,取消自动排列2.保持工作状态,提取近期文件3.用好排序,让文件一目了然4.分类整理,让文件听你的话5.按照实际情况作调整三、合理归档,…

[附源码]JAVA毕业设计框架的企业机械设备智能管理系统的设计与实现(系统+LW)

[附源码]JAVA毕业设计框架的企业机械设备智能管理系统的设计与实现(系统LW) 目运行 环境项配置: Jdk1.8 Tomcat8.5 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支…

文献阅读-VQAR-基于计算机视觉和自然语言处理的信息检索技术综述

VQAR: Review on Information Retrieval Techniques based on Computer Vision and Natural Language Processing 标题:VQAR-基于计算机视觉和自然语言处理的信息检索技术综述 Authors:Shivangi ModiDhatri Pandya Journal:2019 3rd Inter…

在Docker中运行Dubbo应用,详细教程,一学就会

Dubbo概述 Dubbo是阿里开源的一个分布式服务框架,在国内粉丝很多。官网上的介绍是: DUBBO是一个分布式服务框架,致力于提供高性能和透明化的RPC远程服务调用方案,是阿里巴巴SOA服务化治理方案的核心框架,每天为2,000…

Spring_第2章_注解开发+整合Mybatis+Junit

Spring_第2章_注解开发整合MybatisJunit 文章目录Spring_第2章_注解开发整合MybatisJunit一、第三方资源配置管理1 管理DataSource连接池对象问题导入1.1 管理Druid连接池【重点】1.2 管理c3p0连接池2 加载properties属性文件【重点】问题导入2.1 基本用法2.2 配置不加载系统属…

浅谈Android输入法(IME)架构

简介: 输入法 (IME) 是一种可让用户输入文本的用户控件。Android 提供了一种可扩展的输入法框架。借助该框架,应用可以为用户提供备选输入法,例如屏幕键盘,甚至语音输入。安装所需的 IME 后,用户可以从系统设置中选择要…

每日一题:斐波那契数列

每日一题:斐波那契数列 我们先来看一下斐波那契数列的定义: 斐波那契数列(Fibonacci sequence),又称黄金分割数列,因数学家莱昂纳多斐波那契(Leonardo Fibonacci)以兔子繁殖为例子而…