Pytorch进行自定义Dataset 和 Dataloader 原理

news2024/9/22 19:25:00

1、自定义加载数据

在pytorch中,数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset,并实现相应的方法。

在学习Pytorch的教程时,加载数据许多时候都是直接调用torchvision.datasets里面集成的数据集,直接在线下载,然后使用torch.utils.data.DataLoader进行加载。
那么,我们怎么使用我们自己的数据集,然后用DataLoader进行加载呢?

常见的两种形式的导入:

1.1、一种是整个数据集都在一个文件下,内部再另附一个label文件,说明每个文件的状态。这种存放数据的方式可能更时候在非分类问题上得到应用。下面就是我们经常使用的数据存放方式。

1.2、一种则是更适合在分类问题上,即把不同种类的数据分为不同的文件夹存放起来。这样,我们可以从文件夹或文件名得到label。使用torchvision.datasets.imageFolder函数生成数据集。这种方式没有用过,暂时不介绍了

2、重写 Dataset 类

2.1、Pytorch自定义Dataset的步骤:

官方:torch.utils.data.Dataset 是一个抽象类,

def __getitem__(self, index):
	raise NotImplementedError

def __len__(self):
	raise NotImplementedError

用户想要加载自定义的数据只需要继承这个类(torch.util.data.Dataset),并且覆写__len__ 和 __getitem__两个方法, 不覆写这两个方法会直接返回错误因此步骤如下:

  1. 继承torch.util.data.Dataset
  2. __init__:改写__init__函数时,需要添加对父类的初始化,该方法主要就是一些参数初始化工作,定义一些路径或者变量什么的
  3. __getitem__: 该方法是加载数据用的,用于读取每一条数据,他会有一个参数idx,就是对应的索引,可以用来获取一些索引的数据,使dataset [i] 返回数据集中第i个样本。
  4. __len__:实现len(dataset),返回整个数据集的大小

建立的自定义类如下:

# 加载数据集,自己重写DataSet类
class dataset(Dataset):
    # image_dir为数据目录,label_file,为标签文件
    def __init__(self, image_dir, label_file, transform=None):
        super(dataset, self).__init__()    # 添加对父类的初始化
        self.image_dir = image_dir         # 图像文件所在路径
        self.labels = read(label_file)     # 图像对应的标签文件, read label_file之后的结果
        self.transform = transform         # 数据转换操作
        self.images = os.listdir(self.image_dir )#目录里的所有img文件
    
    # 加载每一项数据
    def __getitem__(self, idx):
        image_index = self.images[index]    #根据索引index获取该图片
        img_path = os.path.join(self.image_dir, image_index) #获取索引为index的图片的路径名    
        labels = self.labels[index]   # 对应标签

        image = Image.open(img_name)
        if self.transform:
            image = self.transform(image)
        # 返回一张照片,一个标签
        return image, labels
    
    # 数据集大小
    def __len__(self):
        return (len(self.images))

设置好数据类之后,我们就可以将其用torch.utils.data.DataLoader加载,并访问它。

if __name__=='__main__':
    data = AnimalData(img_dir_path, label_file, transform=None)#初始化类,设置数据集所在路径以及变换
    dataloader = DataLoader(data,batch_size=128,shuffle=True)#使用DataLoader加载数据
    for i_batch,batch_data in enumerate(dataloader):
        print(i_batch)#打印batch编号
        print(batch_data['image'].size())#打印该batch里面图片的大小
        print(batch_data['label'])#打印该batch里面图片的标签

其实Dataset类不局限于这么写,它可以实现多种数据读取方法,只需要把读取数据以及数据处理逻辑写在__getitem__方法中即可,然后将处理好后的数据以及标签返回即可。

3、Dataloader

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

参数解释:

  1. dataset(Dataset): 传入的数据集
  2. batch_size(int, optional): 每个batch有多少个样本
  3. shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序
  4. sampler(Sampler, optional): 自定义从数据集中取样本的策略,如果指定这个参数,那么shuffle必须为False
  5. batch_sampler(Sampler, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)
  6. num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)
  7. collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数
  8. pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.
  9. drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了…
  10. 如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。
  11. timeout(numeric, optional): 如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是大于等于0。默认为0
  12. worker_init_fn (callable, optional): 每个worker初始化函数 If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

Dataloader的处理逻辑是先通过Dataset类里面的 __getitem__ 函数获取单个的数据,然后组合成batch,再使用collate_fn所指定的函数对这个batch做一些操作,比如padding啊之类的。

因为dataloader是有batch_size参数的,我们可以通过自定义collate_fn=myfunction来设计数据收集的方式,意思是已经通过上面的Dataset类中的__getitem__函数采样了batch_size数据,以一个包的形式传递给collate_fn所指定的函数。

dataloader 对于数据的读取延迟主要取决于num_workerspin_memory这两个参数。首先,我先介绍一下比较简单的 pin_memory 参数。

3.1、什么是 pin_memory

所谓的 pin_memory 就是锁页内存的意思。

计算机为了运行进程会先将进程和数据读到内存里。一般来说,计算机的内存都是比较小的,很难存的下太多的数据。但是,某个进程在某个时间段所需的进程和数据往往是比较少的,也就是说在某个时间点我们不需要将一个进程所需要的所有资源都放在内存里。我们可以将这些暂时用不到的数据或进程存放在硬盘一个被称为虚拟内存的地方。在进程运行的时候,我们可以不断交换内存和虚拟内存的数据以减少内存所需存储的数据。而且这些交换往往是通过某些规律预测下个时刻进程会用到的数据和代码并提前交换至内存的,这些规律的使用以及预测的准确性将会影响到进程的速度。

所谓的锁页内存就是说,我们不允许系统将某些内存里的数据交换至虚拟内存,毋庸置疑这将会提升进程的运行速度。但是也会是内存的存储占用消耗很多。

pin_memory 为 true 的时候速度的提升会有多大

3.2、Dataloader 的多进程读数据细节

Dataloader 多进程读取数据的参数是通过num_workers指定的,num_workers 为 0 的话就用主进程去读取数据,num_workers 为 N 的话就会多开 N 个进程去读取数据。这里的多进程是通过 python 的 multiprocessing module 实现的(其实 pytorch 在 multiprocessing 又加了一个 wraper 以实现shared memory)。

关于 num_workers的工作原理:

  1. 开启num_workers个子进程(worker)。
  2. 每个worker通过主进程获得自己需要采集的ids。
    ids的顺序由采样器(sampler)或shuffle得到。然后每个worker开始采集一个batch的数据。(因此增大num_workers的数量,内存占用也会增加。因为每个worker都需要缓存一个batch的数据)
  3. 在第一个worker数据采集完成后,会卡在这里,等着主进程把该batch取走,然后采集下一个batch。
  4. 主进程运算完成,从第二个worker里采集第二个batch,以此类推。
  5. 主进程采集完最后一个worker的batch。此时需要回去采集第一个worker产生的第二个batch。如果第一个worker此时没有采集完,主线程会卡在这里等。(这也是为什么在数据加载比较耗时的情况下,每隔num_workers个batch,主进程都会在这里卡一下。)

 所以:

  • 如果内存有限,过大的num_workers会很容易导致内存溢出。
  • 可以通过观察是否每隔num_workers个batch后出现长时间等待来判断是否需要继续增大num_workers。如果没有明显延时,说明读取速度已经饱和,不需要继续增大。反之,可以通过增大num_workers来缓解。
  • 如果性能瓶颈是在io上,那么num_workers超过(cpu核数*2)是有加速作用的。但如果性能瓶颈在cpu计算上,继续增大num_workers反而会降低性能。(因为现在cpu大多数是每个核可以硬件级别支持2个线程。超过后,每个进程都是操作系统调度的,所用时间更长)

Dataloader 读数据的整个流程:

  1. 首先每个 worker 的进程会拥有一个 index_queue,dataloader 初始化的时候,每个 worker 的 index_queue 会放入两个batch 的 index。index 的放入是根据 worker 的 id顺序放入的。
  2. 每个 worker 的进程会不断检查自己的 index_queue 里有没有值,没有的话就继续检查。有的话,就去读一个 batch(这个读的过程是通过调用 dataset 的get_item()实现的,并通过函数将数据合并为一个 batch)。放入所有 worker 共享的 data_queue(如果指定了 pin_memory,这个新加的 batch 是会被放入 pin_memory 的)
  3. Dataloader 会返回一个迭代器,每迭代一次,首先进程会检查这次要 load 的 idx 数据是不是之前已经 load 过了(已经从共享的 data_queue 里取出来了),并事先放在一个字典里存起来了(为什么会 load 过,下面会解释),如果是的话,就直接拿来用。 如果没有 load 过,就从 data_queue 获取下一个 batch 和相应的 idx,但是这里从 data_queue 获得的 batch 可能不是按顺序的,因为有的 worker 可能比较快提前将它的数据读好放到 data_queue 里了。这时候我们将这个提前来的 batch 先保存到 self.reorder_dict 这个字典里面,这就解释了上面为什么会出现 load 过的问题。如果一直等不到我们就会一直将提前来的 batch 放入 self.reorder_dict 暂存,直至我们等到那个按顺序来的 batch。
  4. 在每次迭代成功的时候,dataloader 会放入一个新的 batch_index 到特定 worker 的 index_queue 里面

可以看出,dataloader 只会在每次迭代成功的时候才会放入新的 index 到 index_queue 里面。因为上面写了在初始化 dataloader 的时候,我们一共放了 2 x self.num_workers 个 batch 的 index 到 index_queue。读了一个 batch 才会放新的 batch,所以这所有的 worker 进程最多缓存的 batch 数量就是 2 x self.num_workers 个。

以上流程的如果想看代码可以参考:Pytorch Dataloader 学习笔记 · 大专栏

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/133651.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

GO第 4 章:运算符

第 4 章 运算符 4.1 运算符的基本介绍 运算符是一种特殊的符号&#xff0c;用以表示数据的运算、赋值和比较等 运算符是一种特殊的符号&#xff0c;用以表示数据的运算、赋值和比较等 算术运算符 赋值运算符 比较运算符/关系运算符 逻辑运算符 位运算符 其它运算 4.2 …

Java开发环境安装

总步骤 第一步&#xff1a;安装JDK&#xff08;Java Development Kit&#xff0c;Java软件开发工具包&#xff09; 第二步&#xff1a;安装IDEA&#xff08;是Java语言的集成开发环境&#xff09; 一、安装JDK Windows下最简单的Java环境安装指南 - 大博哥VV6 - 博客园 (cnblo…

微信小程序框架

框架 小程序开发框架的目标是通过尽可能简单、高效的方式让开发者可以在微信中开发具有原生 APP 体验的服务。 整个小程序框架系统分为两部分&#xff1a;逻辑层&#xff08;App Service&#xff09;和 视图层&#xff08;View&#xff09;。小程序提供了自己的视图层描述语言…

【Linux】进程创建、进程终止和进程等待

​&#x1f320; 作者&#xff1a;阿亮joy. &#x1f386;专栏&#xff1a;《学会Linux》 &#x1f387; 座右铭&#xff1a;每个优秀的人都有一段沉默的时光&#xff0c;那段时光是付出了很多努力却得不到结果的日子&#xff0c;我们把它叫做扎根 目录&#x1f449;进程创建&…

力扣刷题记录——231. 2 的幂、228. 汇总区间、242. 有效的字母异位词

本专栏主要记录力扣的刷题记录&#xff0c;备战蓝桥杯&#xff0c;供复盘和优化算法使用&#xff0c;也希望给大家带来帮助&#xff0c;博主是算法小白&#xff0c;希望各位大佬不要见笑&#xff0c;今天要分享的是——《231. 2 的幂、228. 汇总区间、242. 有效的字母异位词》。…

【王道操作系统】2.2.4 作业进程调度算法(FCFS先来先服务、SJF短作业优先、HRRN高响应比优先)

作业进程调度算法(FCFS先来先服务、SJF短作业优先、HRRN高响应比优先) 文章目录作业进程调度算法(FCFS先来先服务、SJF短作业优先、HRRN高响应比优先)1.先来先服务(FCFS)2.短作业优先(SJF)3.高响应比优先(HRRN)4.三种算法的对比和总结1.先来先服务(FCFS) 先来先服务调度算法(F…

区间选点 and 最大不相交区间

区间选点 题目描述 给定 N 个闭区间 [ai,bi]&#xff0c;请你在数轴上选择尽量少的点&#xff0c;使得每个区间内至少包含一个选出的点。 输出选择的点的最小数量。 位于区间端点上的点也算作区间内。 输入输出及样例 最大不相交区间 题目描述 给定 N 个闭区间 [ai,bi]&…

ArcGIS基础实验操作100例--实验32计算栅格行列号

本实验专栏参考自汤国安教授《地理信息系统基础实验操作100例》一书 实验平台&#xff1a;ArcGIS 10.6 实验数据&#xff1a;请访问实验1&#xff08;传送门&#xff09; 高级编辑篇--实验32 计算栅格行列号 目录 一、实验背景 二、实验数据 三、实验步骤 &#xff08;1&am…

GPU存储器架构-- 全局内存 本地内存 寄存器堆 共享内存 常量内存 纹理内存

上表表述了各种存储器的各种特性。作用范围栏定义了程序的哪个部分能使用该存储器。而生存期定义了该存储器中的数据对程序可见的时间。除此之外&#xff0c;Ll和L2缓存也可以用于GPU程序以便更快地访问存储器。 总之&#xff0c;所有线程都有一个寄存器堆&#xff0c;它是最快…

【PDPTW】python调用guribo求解PDPTW问题(Li Lim‘s benchmark)之二

原文连接&#xff1a;知乎《使用Python调用Gurobi求解PDPTW问题&#xff08;Li & Lim’s benchmark&#xff09;》 分析文章&#xff1a;文章目录修改utlis.pytest.py运行DataPath"lc101.txt"修改 以及修改公示约束&#xff08;8&#xff09;与代码不符合的问题…

【QT开发笔记-基础篇】| 第五章 绘图QPainter | 5.13 抗锯齿

本节对应的视频讲解&#xff1a;B_站_视_频 https://www.bilibili.com/video/BV1YP4y1B7Ex 本节讲解抗锯齿效果 前面实现的效果中&#xff0c;仔细观看能看到明显的锯齿的效果&#xff0c;如下&#xff1a; 此时&#xff0c;可以增加抗锯齿的效果。 1. 关联信号槽 首先&…

22年12月日常实习总结

12月结束了&#xff0c;8月末开始准备的日常实习也算是告一段落了 准备了2个多月&#xff0c;面试了一个月&#xff0c;也拿了一些offer 算是小有感触&#xff0c;遂写下此文&#xff0c;供还在准备或者要准备日常实习的同学参考。 个人背景及投递的日常实录在这篇文章里 24…

RegNet——颠覆常规神经网络认知的卷积神经网络(网络结构详解+详细注释代码+核心思想讲解)——pytorch实现

RegNet的博客的准备我可谓是话费了很多的时间&#xff0c;参考了诸多大佬的资料&#xff0c;主要是网上对于这个网络的讲解有点少&#xff0c;毕竟这个网络很新。网上可以参考的资料太少&#xff0c;耗费了相当多的时间&#xff0c;不过一切都是值得的&#xff0c;毕竟学完之后…

第二证券|下周解禁市值超980亿元,多家机构参与解禁股评级

宁德年代迎来431.8亿元解禁。 下周A股解禁市值超980亿元 证券时报数据宝统计&#xff0c;1月3日至6日&#xff0c;A股商场将有53家上市公司迎来限售股解禁。以个股最新价计算&#xff0c;53股解禁市值合计981.68亿元。 从解禁规模来看&#xff0c;宁德年代和中国移动居前&…

4.搭建配置中心-使用SpringCloud Alibaba-Nacos

naocs除了做服务注册、发现&#xff0c;还可以做为配置中心&#xff0c;使用分以下几步 1.pom引入nacos-config依赖 <dependency><groupId>com.alibaba.cloud</groupId><artifactId>spring-cloud-starter-alibaba-nacos-config</artifactId> &…

python中的多态和抽象类接口

目录 一.多态 抽象类&#xff08;接口&#xff09; 小结 一.多态 多态&#xff0c;指的是:多种状态&#xff0c;即完成某个行为时&#xff0c;使用不同的对象会得到不同的状态。 同样的行为&#xff08;函数&#xff09;&#xff0c;传入不同的对象得到不同的状态 演示 cl…

降维算法-sklearn

1.概述 维度&#xff1a; 对于数组和series&#xff0c;维度就是功能shape返回的结果&#xff0c;shape中返回了几个数字&#xff0c;就是几个维度。降维算法中的”降维“&#xff0c;指的是降低特征矩阵中特征的数量。降维的目的是为了让算法运算更快&#xff0c;效果更好&am…

LabVIEW​​开关模块与万用表DMM扫描模式

LabVIEW​​开关模块与万用表DMM扫描模式 在同步扫描模式下(Synchronous scanning)&#xff0c;扫描列表里面的每一条目都会在开关模块收到一个来自多功能数字万用表(DMM)的数字脉冲(触发输入)后执行.而DMM被编程设置为以一个固定的时间间隔去测量以及在每次测量完产生一个数字…

机器学习--数据清理、数据变换、特征工程

目录 一、数据清理 二、数据变换 三、特征工程 四、总结 一、数据清理 数据清理是提升数据的质量的一种方式。 数据不干净&#xff08;噪声多&#xff09;&#xff1f; 需要做数据的清理&#xff0c;将错误的信息纠正过来&#xff1b; 数据比较干净&#xff08;数据不是…

STM32 TIM PWM初阶操作:非互补PWM输出

STM32 TIM PWM初阶操作详解&#xff1a;非互补PWM输出 STM32 TIM可以输出管脚PWM信号适合多种场景使用&#xff0c;功能包括单线/非互补PWM输出&#xff0c;双线/互补PWM输出&#xff0c;以及死区时间和刹车控制等。 实际上&#xff0c;因为早期IP Core的缺陷&#xff0c;早期…