一文弄懂Pytorch的DataLoader,Dataset,Sampler之间的关系

news2025/1/12 13:29:32

很多文章都是从 D a t a s e t Dataset Dataset等对象自下网上进行介绍的,但是对于初学者而言,其实这并不好理解,因为有时候,会不自觉的陷入到一些细枝末节中去,而不能把握重点,所以本文将自上而下的对 P y t o r c h Pytorch Pytorch数据读取方法进行介绍。

自上而下理解三者关系

首先,我们看一下 D a t a L o a d e r . n e x t DataLoader.next DataLoader.next的源代码长什么样,为方便理解,我只选取了num_works为0的情况,(num_works)简单理解都是能够并行化读取数据

 def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

在阅读上面代码时候,我们可以假设,我们的数据是一组图像,每一张图像对应一个index,那么如果我们要读取的数据就只需要对应index即可,即上面代码中的 i n d i c e s indices indices,而选取index的方式有多种:有按顺序的,也有乱序的,所以这个工作需要 S a m p l e r Sampler Sampler来完成,现在你不需要具体的细节,后面会介绍,只需要了解 D a t a L o a d e r DataLoader DataLoader S a m p l e r Sampler Sampler在这里产生关系.
那么 D a t a s e t Dataset Dataset D a t a L o a d e r DataLoader DataLoader在什么时候产生关系呢?没错就是下面一行,我们已经拿到了 i n d i c e s indices indices,那么下一步,我们只需要根据 i n d i c e s indices indices对数据进行读取即可.

在下面 i f if if语句的作用都是,如果 p i n m e m o r y = T r u e , pin_memory=True, pinmemory=True,,那么 P y t o r c h Pytorch Pytorch会采用一系列操作把数据拷贝到GPU中,总之为了加速.

综上,可以了解DataLoader Sampler和Dataset三者关系如下:
在这里插入图片描述
在阅读后文中,始终需要将上面的关系记在心里,这样能帮助你更好的理解

Sampler

参数传递

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None)

要更加细致的理解 S a m p l e r Sampler Sampler原理,我们需要先阅读以下 D a t a L o a d e r DataLoader DataLoader的源代码 如下:
可以看到初始化参数有两种 S a m p l e r Sampler Sampler : Sampler和batch_sampler
都默认为None,前者作用是生成一系列 i n d e x index index,而batch_sampler则是将sampler生成indices打包分组,得到一个又一个batch的index,例如,下面所示示例:
Batchsampler将 S e q u e n t i a l S a m p l e r SequentialSampler SequentialSampler,生成的index按照指定的batchsize分组.

>>>in : list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
>>>out: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

pyTorch已经实现的sampler有以下几种

  • SequentialSampler

  • RandomSampler

  • WeightedSampler

  • SubsetRandomSampler
    需要注意的是DataLoader的部分初始化参数之间存在互斥关系,这个你可以通过阅读理解源码更深刻的理解,这里只做总结:

  • 源码

    • 如果自定义batch_sampler,那么这些参数都必须使用默认值:batch_size Shuffle sampler drop_last.
    • 如果自定义了sampler :那么shuffle需要设置为false
    • 如果sampler和batch_sampler都为None,那么batch_sampler使用Pytorch已经实现好的BatchSampler,而sampler分两种情况:
      • 若shuffle = True时,则sampler=RandomSampler(dataset)
      • 若shuffle = False时,则sampler=SequentialSampler(dataset)

如果定义sampler和BatchSampler

仔细查看源代码可以发现,所有采样器其实都继承同一个父类,即 S a m p l e r Sampler Sampler,其代码定义如下:

class Sampler(object):
    r"""Base class for all Samplers.
    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.
    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """
 
    def __init__(self, data_source):
        pass
 
    def __iter__(self):
        raise NotImplementedError
		
    def __len__(self):
        return len(self.data_source)

所以,你要做好的都是定义好__iter__(self) 函数,不过要注意的是该函数的返回值需要是可迭代的,例如 S e q u e n t i a l S a m p l e r SequentialSampler SequentialSampler返回的是:
iter(range(len(self.data_source)))
另外 B a t c h S a m p l e r BatchSampler BatchSampler与其他 S a m p l e r Sampler Sampler的主要区别是其需要将 S a m p l e r Sampler Sampler作为参数进行打包,进而每次迭代返回以batch size为大小的index列表,也就是说后面读取数据的过程中都是 b a t c h s a m p l e r batch sampler batchsampler.

Dataset

定义如下

class Dataset(object):
	def __init__(self):
		...
		
	def __getitem__(self, index):
		return ...
	
	def __len__(self):
		return ...

上面三个方法最基本的,其中__getitem__是最主要的方法,其规定了如何读取数据,但是其又不同于一般的方法,因为它是 p y t h o n b u i l t − i n python built-in pythonbuiltin方法,其主要作用是能让该类可以像list一样通过索引值对数据进行访问,加入你定义好一个dataset,那么可以直接通过dataset[0]来访问第一个数据,在之前,我一值没弄清__getitem__是什么作用,所以一值不知道该怎么进入这个函数进行调试,现在如果你想对__getitem__方法进行调试,可以写一个for循环遍历dataset来进行调试,而不用构建dataloader等一大堆东西啦,建议学会使用ipdb这个库,非常实用!!!以后有时间再写一篇ipdb的使用教程。另外,其实我们通过最前面的Dataloader的__next__函数可以看到DataLoader对数据的读取其实就是用了for循环来遍历数据,不用往上翻了,我直接复制了一遍,如下:

class DataLoader(object): 
    ... 
     
    def __next__(self): 
        if self.num_workers == 0:   
            indices = next(self.sample_iter)  
            batch = self.collate_fn([self.dataset[i] for i in indices]) # this line 
            if self.pin_memory: 
                batch = _utils.pin_memory.pin_memory_batch(batch) 
            return batch

我们仔细可以发现,前面有一个self.collate_fn方法,这个是干嘛用的呢?在介绍前,我们需要知道每个参数的含义:

  • indices: 表示每一个iteration,sampler返回的indices,即一个batch size大小的索引列表
  • self.dataset[i] 这里对第i个数据进行读取操作.
    一般来说:self.dataset[i]=(img, label)

我们不难猜出,collate_fn的作用就是将一个batch的数据进行合并的操作,默认的是collate_fn是将img和label分别合并成 i m g s imgs imgs l a b e l s labels labels,所以,如果你的__getitem__方法只是返回img,label.那么你可以使用默认的collate_fn方法,但是如果你每次读取的数据有img, box, label等等,那么你就需要自定义collate_fn来将对应的数据合并成一个batch数据,这样方便后续的训练步骤。

自己理解在这里插入图片描述

DataLoader Dataset和Sampler之间的关系

  • Sampler产生对数据进行采样
  • Dataset:产生数据
  • DataLoader将数据迭代产生batch_size数据格式.

总结

会自己看源代码,根据源代码了解,这里只是做总结
慢慢的将各种数据之间的关系都搞明白,全部都将其搞透彻.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/131522.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

HCIP第四天

HCIP实验配置一,实验要求二,172.16.0.0/16地址的划分三,搭建拓扑图四,配置IP地址和环回地址五,宣告并配置缺省路由下放,使用NAT技术六,R5中心站点配置隧道和静态IP七,R6分支站点的配…

canvas在小程序里写小游戏

最近接了个小需求需要写个小游戏,由简单的帧动画加上碰撞相关的处理,组成。具体页面信息如下图 具体的游戏步骤,是通过长按按钮蓄力,松开时卡通人物跳起,卡通人物跳起碰撞到上面的元宝等元素的得分,这里我们…

笔试题之编写SQL分析门店销售情况

销售员、客户、产品 文章目录前言一、SQL题目二、解答方法(一)建表插入测试数据(二)第一题解答(三)第二题解答(四)第三题解答总结前言 分享本人遇到的笔试真题与解法,并…

MATLAB算法实战应用案例精讲-【人工智能】语义分割(附实战应用案例及代码)

前言 语义分割是一种典型的计算机视觉问题,其涉及将一些原始数据(例如,平面图像)作为输入并将它们转换为具有突出显示的感兴趣区域的掩模。许多人使用术语全像素语义分割(full-pixel semantic segmentation),其中图像中的每个像素根据其所属的感兴趣对象被分配类别ID。…

[ XSS-labs通关宝典 ] xss-labs 通关宝典之 less1 - less5

🍬 博主介绍 👨‍🎓 博主介绍:大家好,我是 _PowerShell ,很高兴认识大家~ ✨主攻领域:【渗透领域】【数据通信】 【通讯安全】 【web安全】【面试分析】 🎉点赞➕评论➕收藏 养成习…

前端常见问题汇总(十)

一、HTTP1.0和HTTP2.0的区别 http1.0:每次请求都需要重新建立tcp连接,请求完后立即断开与服务器连接,这很大程度造成了性能上的缺陷,http1.0被抱怨最多的就是连接无法复用。 http1.1:引入了长连接(keep-al…

麒麟系统虚拟机安装教程

作者:朱金灿 来源:clever101的专栏 为什么大多数人学不会人工智能编程?>>> 1.首先得安装VM Ware软件。 2.打开VM Ware,点击“文件”->“新建虚拟机”。 3.进入新建虚拟机向导,点击下一步。如下图&…

API管理神器:Apifox

前言 代码未动,文档先行 其实大家都知道 API 文档先行的重要性,但是在实践过程中往往会遇到很多困难。 程序员最讨厌的两件事:1. 写文档,2. 别人不写文档。大多数开发人员不愿意写 API 文档的原因是写文档短期收益远低于付出的…

2023—静待“雨中的海棠”发芽

2023—静待“雨中的海棠”发芽认真负责、全身心的投入工作减少抱怨勤思考、多总结—>高效工作保持7*24小时在线全身心BKGWY坚持不懈多运动骑车车、练哑铃、慢跑多看书看自己喜欢的书环青海湖准备环青海湖的攻略身体上的准备内心信念的支撑最后就静待“雨中的海棠”发芽吧&am…

kali - 扫描

数据来源 Whatweb WhateWhatweb是一个基于Ruby语言的开源网站指纹识别软件,正如它的名字一样,,whate能够识别各种关于网站的详细信息,包括:CMS类型、博客平台、中间件、web框架模块、网站服务器、脚本类型、 Javascript库、lP、 …

Apollo 配置中心

Apollo 配置中心目录概述需求:设计思路实现思路分析1.Apollo 配置中心2.Client端配置中心3.爬虫调度器5.Server端配置中心参考资料和推荐阅读Survive by day and develop by night. talk for import biz , show your perfect code,full busy,skip hardne…

(免费分享)基于jsp,ssm甜点网站

开发工具:eclipse,jdk1.8 数据库:mysql5.7,Tomcat8.0 package com.softeem.controller;import java.util.HashMap; import java.util.Map;import javax.annotation.Resource;import org.springframework.stereotype.Controller; …

labelImag安装及使用教程

在做目标检测任务时,需要进行标注,选择了LabelImg作为标注工具,下面是安装及使用过程。 我们使用Anconda的虚拟环境进行安装,激活环境后,执行: pip install labelimg -i https://pypi.tuna.tsinghua.edu.c…

WebSocket 协议详述( java在线聊天室_上篇)

文章目录1、 WebSocket 协议1.1、 何为WebSocket?1.2、 websocket 和 http(应用层的俩个协议)1.3、 websocket协议的具体过程1.4、websocket好处2、 WebSocket实现2.1、 客户端实现2.1.1、 websocket对象2.1.2、 websocket事件2.1.3、 websoc…

【linux】linux中vim/vi (linux基本开发工具)

本期主题:linux中vim/vi的使用和介绍。博客主页:小峰同学分享小编的在Linux中学习到的知识和遇到的问题小编的能力有限,出现错误希望大家不吝赐 目录 🍁vim键盘图 🍁vim基本概念 🍁vim的基本操作 &#x1…

Python使用库(二)

Python使用库(二) 第三方库 认识第三方库 第三方库就是别人已经实现好了的库, 我们可以拿过来直接使用. 虽然标准库已经很强大了, 但是终究是有限的. 而第三方库可以视为是集合了全世界 Python 程序猿的智慧, 可以说是几乎无穷无尽. 问题来了, 当我们…

Linux驱动入门-最简单字符设备驱动(基于pc ubuntu)

一.字符设备驱动概念 字符设备是 Linux 驱动中最基本的一类设备驱动,字符设备就是一个一个字节,按照字节流进行读写操作的设备,读写数据是分先后顺序的。比如我们最常见的点灯、按键、 IIC、 SPI,LCD 等等都是字符设备&#xff0…

公共管理老师赴英国G5名校-伦敦大学学院CSC公派访学

CSC青年骨干教师项目的实施院校一般都要求申请人提前上报邀请函等申请材料,以进行校内遴选。为提升竞争优势,A老师希望能获得英国名校的邀请函。最终我们为其申请到英国G5名校之一的伦敦大学学院,凭借该邀请函,A老师顺利通过了本校…

【2022年终总结】勇敢追梦,去和人生博弈

目录序言刚开始的1月松懈的2月忙碌的3月迷茫的4月开源项目的5月入职汇报的6月7月8月9月假期过后的10月至关重要的11月最后冲刺的12月2022年的总结2023年的目标往年回顾序言 在刚刚过完的平安夜和圣诞节之际,同时意味着2022年要画上一个句号。这一周算是比较煎熬的几…

高效的事件处理模式——Reactor、Proactor

IO模型 从理论上说,阻塞IO、IO复用和信号驱动IO都是同步IO模型。因为在这三种IO模型中,IO的读写操作,都是在IO事件发生之后,由应用程序来完成的。而POSIX规范所定义的异步IO模型则不同。对异步IO而言,用户可以直接对I…