01_pytorch中的DataSet

news2025/1/12 8:46:14

在pytorch 中,
Dataset: 用于数据集的创建;
DataLoader: 用于在训练过程中,传递获取一个batch的数据;

这里先介绍 pytorch 中的 Dataset 这个类,
torch.utils.data. dataset.py 是一个表示数据集的抽象类。任何自定义的数据集都需要继承这个类并覆写相关方法。

数据集,其实就是一个负责处理索引(index)到样本(sample)映射的一个类(class)。

在这里插入图片描述
在torch.utils.data. dataset.py 中可知,
pytorch 提供两种数据集:

  • Map 式数据集, 上图中MapDataPipe()
  • Iterable 式数据集, IterDataPipe()

1. Map 式数据集

即上图中MapDataPipe(),

1.1 需要重写的方法

一个Map式的数据集必须要重写__getitem__(self, index),
len(self) 两个内建方法,用来表示从索引到样本的映射(Map).

这样一个数据集dataset的作用如下,

  • 当使用dataset[idx]命令时,可以在你的硬盘中读取你的数据集中第idx张图片以及其标签(如果有的话);
  • len(dataset)则会返回这个数据集的容量。

1.2 使用方法

例子-1: 自己实验中写的一个例子:这里我们的图片文件储存在“./data/faces/”文件夹下,图片的名字并不是从1开始,而是从final_train_tag_dict.txt这个文件保存的字典中读取,label信息也是用这个文件中读取。大家可以照着注释阅读这段代码。

from torch.utils import data
import numpy as np
from PIL import Image


class face_dataset(data.Dataset):
	def __init__(self):
		self.file_path = './data/faces/'
		f=open("final_train_tag_dict.txt","r")
		self.label_dict=eval(f.read())
		f.close()

	def __getitem__(self,index):
		label = list(self.label_dict.values())[index-1]
		img_id = list(self.label_dict.keys())[index-1]
		img_path = self.file_path+str(img_id)+".jpg"
		img = np.array(Image.open(img_path))
		return img,label

	def __len__(self):
		return len(self.label_dict)

2. Iterable 式数据集

一个Iterable(迭代)式数据集是抽象类data.IterableDataset的子类,并且覆写了__iter__方法成为一个迭代器。

这种数据集主要用于数据大小未知,或者以流的形式的输入,本地文件不固定的情况,需要以迭代的方式来获取样本索引。

一个 Iterable 式的数据集必须要重写__iter__,

class IterableDataset(Dataset[T_co]):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    Such form of datasets is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    """
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])

所有表示数据样本可迭代的数据集都应该对其进行子类化。
这种形式的数据集在数据来自流的时候特别有用。

所有的子类都应该覆盖 :meth:__iter__,它将返回这个数据集中样本的迭代器。

  • 当子类与 :class:~torch.utils.data.DataLoader一起使用时,数据集中的每个项目将从 :class:~torch.utils.data.DataLoader迭代器中产生。
  • 当 :attr:num_workers > 0时,每个工作进程将有一个不同的数据集对象的副本,所以通常需要独立配置每个副本,以避免从工作进程返回重复的数据。
  • :func:~torch.utils.data.get_worker_info,当在工作进程中调用时,返回关于工作者的信息。它可以被用于数据集的 :meth:__iter__ 方法或 :class:~torch.utils.data.DataLoader 的 j-
  • :attr:worker_init_fn 选项来修改每个副本的行为

2.1  迭代器和生成器之间的关系

在这里插入图片描述

2.2 python 中的迭代器

顾名思义,迭代器就是用于迭代操作(for 循环)的对象,它像列表一样可以迭代获取其中的每一个元素,任何实现了 __next__ 方法 (python2 是 next)的对象都可以称为迭代器。

它与列表的区别在于,构建迭代器的时候,不像列表把所有元素一次性加载到内存,而是以一种延迟计算(lazy evaluation)方式返回元素,这正是它的优点。

比如列表含有中一千万个整数,需要占超过400M的内存,而迭代器只需要几十个字节的空间。
因为它并没有把所有元素装载到内存中,而是等到调用 next 方法时候才返回该元素
(按需调用 call by need 的方式,本质上 for 循环就是不断地调用迭代器的__next__方法)。

以斐波那契数列为例来实现一个迭代器:

class Fib:
    def __init__(self, n):
        self.prev = 0
        self.cur = 1
        self.n = n

    def __iter__(self):
        return self

    def __next__(self):
        if self.n > 0:
            value = self.cur
            self.cur = self.cur + self.prev
            self.prev = value
            self.n -= 1
            return value
        else:
            raise StopIteration()
    # 兼容python2
    def next(self):
        return self.__next__()

f = Fib(10)
print([i for i in f])
#[1, 1, 2, 3, 5, 8, 13, 21, 34, 5

2.3  python 中的生成器

知道迭代器之后,就可以正式进入生成器的话题了。

2.3.1 为什么需要生成器

通过列表生成式,我们可以直接创建一个列表,但是,受到内存限制,列表容量肯定是有限的,而且创建一个包含100万个元素的列表,不仅占用很大的存储空间,如果我们仅仅需要访问前面几个元素,那后面绝大多数元素占用的空间都白白浪费了。

所以,如果列表元素可以按照某种算法推算出来,那我们是否可以在循环的过程中不断推算出后续的元素呢?这样就不必创建完整的list,从而节省大量的空间,在Python中,这种一边循环一边计算的机制,称为生成器:generator

生成器是一个特殊的程序,可以被用作控制循环的迭代行为,python中生成器是迭代器的一种,使用yield返回值函数,每次调用yield会暂停,而可以使用next()函数和send()函数恢复生成器。

生成器类似于返回值为数组的一个函数,这个函数可以接受参数,可以被调用,但是,不同于一般的函数会一次性返回包括了所有数值的数组,生成器一次只能产生一个值,这样消耗的内存数量将大大减小,而且允许调用函数可以很快的处理前几个返回值,因此生成器看起来像是一个函数,但是表现得却像是迭代器

  • 生成器函数:也是用def定义的,利用关键字yield一次性返回一个结果,阻塞,重新开始

  • 生成器表达式:返回一个对象,这个对象只有在需要的时候才产生结果

普通函数用 return 返回一个值,和 Java 等其他语言是一样的,然而在 Python 中还有一种函数,用关键字 yield 来返回值,这种函数叫生成器函数,函数被调用时会返回一个生成器对象,生成器本质上还是一个迭代器,也是用在迭代操作中,因此它有和迭代器一样的特性,唯一的区别在于实现方式上不一样, 生成器更加简洁。

2.3.2 生成器函数

最简单的生成器函数:

>>> def func(n):
...     yield n*2
...
>>> func
<function func at 0x00000000029F6EB8>
>>> g = func(5)
>>> g
<generator object func at 0x0000000002908630>
>>>

func 就是一个生成器函数,调用该函数时返回对象就是生成器 g ,这个生成器对象的行为和迭代器是非常相似的,可以用在 for 循环等场景中。注意 yield 对应的值在函数被调用时不会立刻返回,而是调用next方法时(本质上 for 循环也是调用 next 方法)才返回。

>>> g = func(5)
>>> next(g)
10

>>> g = func(5)
>>> for i in g:
...     print(i)
...
10

那为什么要用生成器呢?用生成器它没有那么多冗长代码了,而且性能上一样的高效。

不足之处,便是需要多理解一下。

来看看用生成器实现斐波那契数列有多简单。

def fib(n):
    prev, curr = 0, 1
    while n > 0:
        n -= 1
        yield curr
        prev, curr = curr, curr + prev

print([i for i in fib(10)])
#[1, 1, 2, 3, 5, 8, 13, 21, 34, 55]

2.3.2 生成器表达式

  • 生成器表达式
    器表达式与列表推导式长的非常像,但是它俩返回的对象不一样,前者返回生成器对象,后者返回列表对象。
>>> g = (x*2 for x in range(10))
>>> type(g)
<type 'generator'>
>>> l = [x*2 for x in range(10)]
>>> type(l)
<type 'list'>

生成器的优势,就是迭代海量数据时,生成器会更加内存。

2.3  生成器在 DataLoader中应用

深度学习框架PyTorch中的DataLoader模块的实现就使用了生成器的机制来生成一次训练用的batch。

DataLoader的详解在博主的另一篇文章Pytorch之Dataloader中,这里只讲其中运用到生成器机制的Sampler类模块。

首先,是 RandomSampler, iter(randomSampler) 会返回一个可迭代对象,这个可迭代对象 每次 next 都会输出当前要采样的 index,SequentialSampler也是一样,只不过产生的 index 是顺序的

class RandomSampler(Sampler):

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(torch.randperm(len(self.data_source)).long())

    def __len__(self):
        return len(self.data_source)

BatchSampler 是一个普通 Sampler 的 wrapper, 普通Sampler 一次仅产生一个 index, 而 BatchSampler 一次产生一个 batch 的 indices。

class BatchSampler(Sampler):
    """Wraps another sampler to yield a mini-batch of indices. Args: sampler (Sampler): Base sampler. batch_size (int): Size of mini-batch. drop_last (bool): If ``True``, the sampler will drop the last batch if its size would be less than ``batch_size`` Example: >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) [[0, 1, 2], [3, 4, 5], [6, 7, 8]] """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(batch_size, _int_classes) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integeral value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

reference:

https://chenllliang.github.io/2020/02/04/dataloader/
https://chenllliang.github.io/2020/02/06/PyIter/
https://www.zhihu.com/question/20829330/answer/213544776
https://www.cnblogs.com/wj-1314/p/8490822.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/601999.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SharpContour论文精读

SharpContour: A Contour-based Boundary Refinement Approach for Efficient and Accurate Instance Segmentation 论文链接&#xff1a;[2203.13312] SharpContour: A Contour-based Boundary Refinement Approach for Efficient and Accurate Instance Segmentation (arxiv…

[SpringBoot]Knife4j框架

Knife4j框架 Knife4j框架是一款国人开发的、基于Swagger 2的在线API文档框架。 Knife4j框架的一些主要作用和特点&#xff1a; 自动生成API文档&#xff1a;Knife4j可以根据代码中的注解和配置信息&#xff0c;自动生成API接口文档。开发者只需要在代码中添加相关注解&#…

数据治理服务解决方案word

本资料是ppt格式&#xff0c;适用于方案规划、项目实施、工作汇报。本资料来源公开网络&#xff0c;仅供个人学习&#xff0c;请勿商用&#xff0c;如有侵权请联系删除。篇幅有限&#xff0c;无法完全展示&#xff0c;喜欢资料可转发评论&#xff0c;私信“方案”了解更多信息。…

亚马逊、沃尔玛、eBay、wish的测评风险:源头控制与有效规避

测评补单已逐渐成为跨境电商卖家的一种重要推广方式。然而&#xff0c;近期&#xff0c;一些卖家反映&#xff0c;由于平台规则日益严格&#xff0c;测评变得更为棘手。若违反评论政策并被捕获&#xff0c;卖家可能会面临一系列的处罚&#xff0c;如删除店铺所有产品的评论&…

GRPC CPP 开发单向Stream服务器

上周提到我们要给llama.cpp增加一个grpc入口&#xff0c;这是最终成果仓库&#xff0c;等待进一步测试后提交合并。 今天讲讲GRPC CPP开发的麻烦事情。 参考文档 Quick start | C | gRPC&#xff0c;参考文档就是官方的这篇文档了&#xff0c;安装grpc可以参考我上一篇文章&…

Pycharm:通过git拉取仓库代码并创建项目环境

一、使用pycharm打开空的文件夹 使用菜单栏&#xff1a;在 PyCharm 的菜单栏中&#xff0c;选择 "File"&#xff08;文件&#xff09;菜单&#xff0c;然后选择 "Open"&#xff08;打开&#xff09;或 "Open Folder"&#xff08;打开文件夹&…

Hive3.1.3

文章目录 1、Hive入门1.1 Hive简介1.2 Hive本质1.3 Hive架构原理 2、Hive安装2.1 Hive安装地址2.2 Hive安装部署2.2.1 安装Hive(最小化)2.2.2 启动并使用Hive 2.3 MySQL安装2.3.1 安装MySQL2.3.2 配置MySQL 2.4 配置Hive元数据存储到MySQL2.4.1 配置元数据到MySQL2.4.2 验证元数…

校验表格中的多个表单

要实现的效果是: 点击保存回校验当前页面的所有输入框 首先 分成两个上下两个子组件, 上面的子组件是一个表单包括规则名称和区域 下面的子组件是一个表格,表格可以是多行的,需要校验每一行的输入框 父组件调用两个子组件的校验方法, 第一个子组件可以直接校验,第二个子组件在…

深度学习笔记之循环神经网络(十)基于循环神经网络模型的简单示例

深度学习笔记之循环神经网络——基于循环神经网络模型的简单示例 引言文本表征&#xff1a; One-hot \text{One-hot} One-hot向量简单示例:文本序列的预测任务数据预处理过程生成文本数据遍历数据集&#xff0c;构建字典抓取数据&#xff0c;创建训练样本、标签字符特征与数字特…

Uni-app学习从0到1开发一个app——(2)windowns环境搭配

文章目录 0 引入1、使用HBuilderX构建工程2、使用vscode2.1 官方推荐的使用2.2 如何使用 3、总结 0 引入 工欲善其事必先利其器介绍两种开发小程序的方法&#xff0c;个人倾向于第一种&#xff0c;后续演示的的工程也是基于前者&#xff0c;毕竟官方的更有说服力。 1、使用HBu…

基于yolov5开发构建枪支刀具等危险物品检测识别系统

安全始终是重如泰山的事情&#xff0c;安全事件如果能够做到早发现早制止可能结果就会完全不一样了&#xff0c;本文的核心目的很简单&#xff0c;就是想基于目标检测模型来尝试构建枪支刀具等危险物品检测识别系统&#xff0c;希望基于人工智能手段来打击犯罪行为&#xff0c;…

【JavaSE】Java基础语法(四十三):反射

文章目录 概述&#xff1a;1. java.lang.Class1.1 获取 Class 对象1.2 通过反射创建对象1.3 通过反射获取类的属性、方法和注解等1.3.1 反射获取构造方法1.3.2 反射通过构造器创建对象1.3.3 反射获取成员方法1.3.4 反射获取属性 2. 工具类操作3. 反射是如何破坏单例模式的4. 反…

linux0.12-12-2-buffer

基本上看完赵老师中文解释&#xff0c;都可以自己写这部分的代码。 [622页] 12-2 buffer.c程序 从本节起&#xff0c;我们对fs/目录下的程序逐一进行说明和注释。按照本章第2节中的描述&#xff0c; 本章的程序可以被分成4个部分&#xff1a; 高速缓冲管理&#xff1b; 文件…

基于ATC89C51单片机的超市临时储物柜密码锁设计

点击链接获取Keil源码与Project Backups仿真图: https://download.csdn.net/download/qq_64505944/87855870?spm=1001.2014.3001.5503 源码获取 摘 要 随着微机测量和控制技术的迅速发展与广泛应用,以单片机为核心的电子密码锁的设计研发与应用在很大程度上改善了人们的…

windows 部署多个tomcat

去官网下载tomcat&#xff0c;地址&#xff1a;Apache Tomcat - Apache Tomcat 8 Software Downloads 选择对应的版本下载&#xff0c;下载完成后&#xff0c;直接解压文件&#xff0c; 修改第二个解压的tomcat的catalina.bat 和 startup.bat和service.bat文件的配置&#x…

iptables 基础

iptables防火墙 主要实现数据包的过滤、封包重定向和网络地址转换&#xff08;NAT&#xff09;等功能 iptables&#xff1a;用户空间的命令行工具&#xff0c;用于管理配置netfilter&#xff1a;真正实现功能的是netfilter运行在内核空间 iptables的4表5链 链&#xff1a;通过…

想管好数据资源,不妨了解大数据分析开源框架

在如今快节奏的时代中&#xff0c;办公自动化早已成为各行各业的发展趋势和方向。随着业务量的激增&#xff0c;数据资源也不断增多&#xff0c;如果没有一套完善的大数据分析开源框架&#xff0c;那这么多的数据资源就不能很好地利用和发挥其价值&#xff0c;如果采用专业的大…

基于AT89C52单片机的交通灯设计

点击链接获取Keil源码与Project Backups仿真图&#xff1a; https://download.csdn.net/download/qq_64505944/87855439?spm1001.2014.3001.5503 源码获取 一、实验目的 掌握单片机的综合应用设计。加强对单片机和汇编语言的认识&#xff0c;充分掌握和理解设计各部分的工作…

华为防火墙双机热备外线vrrp地址和接口地址非同网段

主防火墙FW1: HRP_Mdis current-configuration 2023-06-02 15:51:48.270 08:00 !Software Version V500R005C10SPC300 sysname USG6000V1 l2tp domain suffix-separator undo info-center enable ipsec sha2 compatible enable undo telnet server enable undo telnet ipv6 se…

Office Visio 2007安装教程

哈喽&#xff0c;大家好。今天一起学习的是Visio 2007的安装&#xff0c;这是一个绘制流程图的软件&#xff0c;用有效的绘图表达信息&#xff0c;比任何文字都更加形象和直观。Office Visio 是office软件系列中负责绘制流程图和示意图的软件&#xff0c;便于IT和商务人员就复杂…