再探pytorch的Dataset和DataLoader

news2025/1/15 17:09:27

本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052

本文从分类检测分割三大任务的角度来剖析pytorch得datasetdataloader源码,可以让初学者深刻理解每个参数的由来和使用,并轻松自定义dataset。

思考:在探究Dataset和DataLoader之前,需要明白一个事情,就是当我们不管做是分类、检测还是分割任务时,我们的数据集一定由很多张图片组成的,形状大小各异;那麽我们在使用pytorch时,图片是怎么以一个batch的形式进行打包的呢,形状不同怎么处理,数据格式有什么要求,Dataset类中的初始化参数transform如何自定义,并是怎么调用的,在哪里调用?训练时如何取出DataLoader中的数据的?带着种种疑问,开始Dataset和DataLoader的探索之旅吧!

Dataset探究

Dataset的定义

  • 问题1:Dataset到底是什么,如何去调用?

  • 我们通常见到的Dataset的定义如下图,VOCSegmentation()在这里其实是一个

voc_train = VOCSegmentation('F:\pytorch_yolov3\yolo3-pytorch',year='2007',image_set='train',transform=transform,target_transform=target_transform)

自定义得Dataset类中,必有三个函数:__init__(),__getitem__(),__len__()函数,

  1. __init__(self,...):是几乎类中都会有的一个初始化函数,会在类进行实例化时初始化一些变量,供类中其他函数使用

  1. __getitem__(self, index):是python得magic method,一般如果想通过使用索引访问元素时,就可以在类中定义这个方法。

  1. __len__():返回数据集的长度

凡是在类中定义了这个__getitem__ 方法,那么它的实例对象(假定为p),可以像这样p[key] 取值,当实例对象做p[key] 运算时,会调用类中的方法__getitem__。

返回值为处理过后的img和seg

class VOCSegmentation(Dataset):
    def __init__('F:\pytorch_yolov3\yolo3-pytorch',year='2007',image_set='train',transform=transform,target_transform=target_transform):
        pass

    def __getitem__(self, index):
        pass

    def __len__(self):
        pass

分类任务的Dataset

  • 自定义分类Dataset

class MyDataset(Dataset): # 继承Dataset类
    def __init__(self, txt_path, transform=None, target_transform=None): # 定义txt_path参数
        fh = open(txt_path, 'r') # 读取txt文件
        imgs = []  # 定义imgs的列表
        for line in fh:
            line = line.rstrip() # 默认删除的是空白符('\n', '\r', '\t', ' ')
            words = line.split() # 默认以空格、换行(\n)、制表符(\t)进行分割,大多是"\"
            imgs.append((words[0], int(words[1]))) # 存放进imgs列表中

        self.imgs = imgs        # 最主要就是要生成这个list, 然后DataLoader中给index,通过getitem读取图片数据
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index] # fn代表图片的路径,label代表标签
        img = Image.open(fn).convert('RGB')     # 像素值 0~255,在transfrom.totensor会除以255,使像素值变成 0~1   参考:https://blog.csdn.net/icamera0/article/details/50843172

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.imgs)   # 返回图片的数量
  • 官方的MNIST类

def __init__(self, root, train=True, transform=None, download=False)

参数解析:

  • root:为数据集的路径,如果没找到本地路径,配合download使用,可以将download设置为True,会从网上为你下载数据

  • train:是一个boolean类型参数,用来定义数据使用的时train数据集还是test数据集,如果train为True,就会将train数据集路径加载进来,供后面读取数据使用。

  • download:boolean类型参数,如果为True,就会直接跳过本地数据集,去网上下载,下载后存放在当前工程目录下。

  • transform:可以是一个实例化对象,也可以使用torchvison.transforms.Compose来包裹一个含有多个实例化转换类对象的列表

注意:非常重要的是掌握transform参数需要的参数类型是什么,后面方便自己自定义transform,在自定义tranform前,先来看看pytorch官方自己定义的一些变换方法吧(变换都在torchvison.transforms.transforms.py中)。
  • transform解析

class CenterCrop(object):
    """Crops the given PIL Image at the center.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        return F.center_crop(img, self.size)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)

以CenterCrop为例,可以看出是类的实现。类中包含三个方法

  • __init__(self, size):用于初始化必要的变量

  • __call__(self, img):为python的magic方法,在类中实现这一方法可以使该类的实例(对象)像函数一样被调用。默认情况下该方法在类中是没有被实现的。使用callable()方法可以判断某对象是否可以被调用。

  • __repr__(self):支持打印功能,可以使用print(实例化对象),返回变换类名包括初始化的一些参数信息。

class People(object):
    def __init__(self,name):
        self.name=name

    def __call__(self):
        print("hello "+self.name)

a = People('无忌!')
a.__call__()       # 调用方法一
a()                # 调用方法二
#两种调用方法等价
#__call__()方法的作用其实是把一个类的实例化对象变成了可调用对象,
#也就是说把一个类的实例化对象变成了可调用对象,只要类里实现了__call__()方法就行。
#如当类里没有实现__call__()时,此时的对象p 只是个类的实例,不是一个可调用的对象,
#当调用它时会报错:‘Person’ object is not callable.
注意: __call__(self, img)方法中的参数是img,即图像,格式为PIL Image,所以一般是对图像做完各种图像变换之后再进行ToTensor()和Normalize()操作,顺序很重要,否则会报数据类型错误。__call__(self,img)的具体实现F.center_crop(img, self.size)在torchvision.transforms.functional.py中。

  • 问题:transforms中的compose为何可以以列表的形式接受各种变换

代码解析
class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        #先批量初始化变换方法
        self.transforms = transforms

    def __call__(self, img):
        #__call__方法是python的魔法方法,会
        #遍历list,得到第一个transforms.CenterCrop(10),每一个变换都是一个类,CenterCrop(10)
        #对中心裁剪类进行了实例化,真正调用__call__()方法是在,__getitem__(self, index)方法中
        #self.transform(img),返回变换后的img(循环顺序调用,像是一个管道,img从一边输入,从另一边
        #另一边输出,输出的图像已经变成了我们想要的模样)
        for t in self.transforms:
            img = t(img)
        return img

DataLoader

DataLoader的定义

问题:DataLoader到底是什么?
DataLoader也是以 类的方式实现的, DataLoader是一个可迭代对象,其类中实现了__iter__方法,因此在实例化DataLoader后,可以通过for循环的方式来迭代获取打包后的batch数据,在使用for循环遍历DataLoader时,会调用DataLoader里的__iter__方法,返回一个_SingleProcessDataLoaderIter(self)迭代器,后面每次迭代时会调用__next__方法获得下一批数据.

def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None):
参数解析:
dataset:自定义的数据集
batch_size:每次喂入网络几张图
shuffle:数据是否乱序
sampler:一个采样器
batch_sampler:以sampler为参数,在遍历dataloader时,会对采样器进行迭代产生一个batch的index
num_workers:使用几个进程来加载数据,默认是0,由主进程来加载数据
collate_fn:主要用来打包通过一个batch的index取到的数据,将img,target分开打包成tensor

collate_fn的参数可以自定义实现,因为pytorch官方提供的collate_fn函数只是将数据简单stack在一起,如果需要实现一些其他功能,可以自行定义,然后作为参数传给collate_fn.

DataLoader案例

Dataset和DataLoader的关系

参考文献:

https://blog.51cto.com/u_15274944/2921682

https://zhuanlan.zhihu.com/p/35698470

https://zhuanlan.zhihu.com/p/27661382

https://zhuanlan.zhihu.com/p/359998425

https://blog.csdn.net/qq_28057379/article/details/115427052

https://zhuanlan.zhihu.com/p/30385675

https://cloud.tencent.com/developer/article/1728103

https://ai.deepshare.net/p/t_pc/course_pc_detail/video/v_5e1694ec1470e_SJ7JwMgA?product_id=p_5d552aea89dd9_CWCmRNUu&type=5

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/419436.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SQL LIMIT

SQL LIMIT SQL LIMIT子句简介 要检索查询返回的行的一部分,请使用LIMIT和OFFSET子句。 以下说明了这些子句的语法: SELECT column_list FROMtable1 ORDER BY column_list LIMIT row_count OFFSET offset;在这个语法中, row_count确定将返…

Html5版贪吃蛇游戏制作(经典玩法)

回味经典小游戏,用Html5做了个贪吃蛇的小游戏,完成了核心经典玩法的功能。 游戏可以通过电脑的键盘“方向键”控制,也可以点击屏幕中的按钮进行控制。(支持移动端哈) 点击这里试玩 蛇的移动是在18 x 18的格子中进行移…

sqoop数据导入

创建数据库 mysql全表数据导入hdfs mysql查询数据导入hdfs mysql指定列导入hdfs 使用查询条件关键字将mysql数据导入hdfs mysql数据导入hive 创建数据库 hive中创建user表 create table users( id bigint, name string ) row format delimited fields terminated by &…

数据结构 - 归并排序 | C

思路分析 什么是归并&#xff1f; 示例&#xff1a;&#xff08;归并后的结果copy到原数组&#xff09; 逻辑&#xff1a; if (a[begin1] < a[begin2]) {tmp[i] a[begin1];} else {tmp[i] a[begin2];} 归并排序 分解到“有序”再归并 递归 int middle (left righ…

哈希——unordered系列关联式容器

目录 unordered系列关联式容器 概念 unordered_map 无序去重 operator[] unordered_set 无序去重 OJ练习题 重复n次的元素 两个数组的交集 两个数的交集二 底层结构 概念 哈希冲突 闭散列 结点的定义 扩容 字符串取模 插入 查找 删除 闭散列完整代码 开…

安卓远程控制软件哪个好用

如果您曾希望将个人电脑放在口袋里&#xff0c;那么您可能只需要安卓远程访问软件。 没有远程访问应用程序&#xff1a;使用和控制计算机的唯一方法是坐在计算机前并手动输入命令。 使用远程访问应用程序&#xff1a;您可以在世界任何地方通过 Internet 连接从您的安卓平板电…

【30天python从零到一】---第七天:列表和元组

&#x1f34e; 博客主页&#xff1a;&#x1f319;披星戴月的贾维斯 &#x1f34e; 欢迎关注&#xff1a;&#x1f44d;点赞&#x1f343;收藏&#x1f525;留言 &#x1f347;系列专栏&#xff1a;&#x1f319; Python专栏 &#x1f319;请不要相信胜利就像山坡上的蒲公英一样…

计算机组成原理---第五章中央处理器

&#xff08;一&#xff09;CPU 的功能和组成 CPU 的功能 Ⅰ 概述&#xff1a;当程序指令装入内存储器后&#xff0c;CPU 用来自动完成取指令和执行指令的任务。 Ⅱ CPU 的功能&#xff1a;①指令控制 ②操作控制 ③时间控制 ④数据加工 2.CPU 的基本组成 CPU 的基本部分为运…

【论文阅读】[JBHI] VLTENet、[ISBI]

[JBHI] VLTENet 论文连接&#xff1a;VLTENet: A Deep-Learning-Based Vertebra Localization and Tilt Estimation Network for Automatic Cobb Angle Estimation | IEEE Journals & Magazine | IEEE Xplore Published in: IEEE Journal of Biomedical and Health Infor…

9.1 相关分析

学习目标&#xff1a; 如果我要学习相关分析&#xff0c;我可能会按照以下步骤进行&#xff1a; 确定学习相关分析的目的和应用场景&#xff0c;例如研究两个变量之间的相关性、了解变量之间的关系、预测未来趋势等。学习相关分析的基本概念和原理&#xff0c;包括相关系数、…

VS——Visual Studio 2022 社区版——快捷键

VS——Visual Studio 2022 社区版——快捷键官网简介PDF完整PDF编辑编辑&#xff1a;常用快捷方式菜单栏 会显示 快捷键功能搜索大纲 折叠 展开Ctrl M M 切换官网 https://learn.microsoft.com/zh-cn/visualstudio/ide/default-keyboard-shortcuts-in-visual-studio?viewvs-2…

数据结构 — 【排序算法】

目录 1.排序的概念及其运用 1.1排序的概念 1.2排序运用 1.3 常见的排序算法 2.常见排序算法的实现 2.1 插入排序 直接插入排序 希尔排序 2.2 选择排序 直接选择排序 堆排序 2.3 交换排序 冒泡排序 快速排序 2.4 归并排序 2.5 非比较排序 计数排序 基数排序 3.排序算法…

【Unity入门】12.MonoBehaviour事件函数

【Unity入门】MonoBehaviour事件函数 大家好&#xff0c;我是Lampard~~ 欢迎来到Unity入门系列博客&#xff0c;所学知识来自B站阿发老师~感谢 &#xff08;一&#xff09;常用的事件函数 &#xff08;1&#xff09;start和update方法 之前我们写的脚本&#xff0c;会默认帮助…

4.3 分部积分法

学习目标&#xff1a; 学习分部积分法&#xff0c;我可能会按照以下步骤进行&#xff1a; 理解分部积分法的基本思想。分部积分法是一种通过对积分式中的不同部分进行乘积分解&#xff0c;然后对乘积中的某一项进行积分&#xff0c;对另一项进行微分&#xff0c;从而将原积分式…

NumPy 秘籍中文第二版:五、音频和图像处理

原文&#xff1a;NumPy Cookbook - Second Edition 协议&#xff1a;CC BY-NC-SA 4.0 译者&#xff1a;飞龙 在本章中&#xff0c;我们将介绍 NumPy 和 SciPy 的基本图像和音频&#xff08;WAV 文件&#xff09;处理。 在以下秘籍中&#xff0c;我们将使用 NumPy 对声音和图像进…

叮咚买菜基于 Apache Doris 统一 OLAP 引擎的应用实践

导读&#xff1a; 随着叮咚买菜业务的发展&#xff0c;不同的业务场景对数据分析提出了不同的需求&#xff0c;他们希望引入一款实时 OLAP 数据库&#xff0c;构建一个灵活的多维实时查询和分析的平台&#xff0c;统一数据的接入和查询方案&#xff0c;解决各业务线对数据高效实…

一键构建分布式云原生平台

目录专栏导读一、分布式云原生平台1、应用无所不能2、运行无处不在3、服务千行白业二、分布式云原生平台关键要素1、统一应用管理2、统一流量自治3、统一数据管理4、统一运维三、多云多集群已经广泛应用四、分布式云的优势&#xff1a;1、避免厂商锁定2、满足合规化要求3、增强…

收藏!7个国内「小众」的程序员社区

技术社区是大量开发者的集聚地&#xff0c;在技术社区可以了解到行业的最新进展&#xff0c;学习最前沿的技术&#xff0c;认识有相同爱好的朋友&#xff0c;在一起学习和交流。 国内知名的技术社区有CSDN、博客园、开源中国、51CTO&#xff0c;还有近两年火热的掘金&#xff…

基于决策树及集成算法的回归与分类案例

基于决策树及集成算法的回归与分类案例 描述 本任务基于决策树及集成算法分别实现鲍鱼年龄预测案例和肿瘤分类案例。鲍鱼年龄预测案例是建立一个回归模型&#xff0c;根据鲍鱼的特征数据&#xff08;长度、直径、高度、总重量、剥壳重量、内脏重量、壳重&#xff09;等预测其…

Python:超级大全网上面试题搜集整理(四)

转载参考&#xff1a; python 面试题(高级)_python高级面试题_梦幻python的博客-CSDN博客 cpython pypy_介绍Cython&#xff0c;Pypy Cpython Numba各有什么缺点【面试题详解】_函明的博客-CSDN博客 Cython、PyPy专题开篇 - 知乎 Python抽象类和接口类_python 接口类_代码输…