Pytorch的torch.utils.data中Dataset以及DataLoader等详解

news2025/1/6 17:14:28

在我们进行深度学习的过程中,不免要用到数据集,那么数据集是如何加载到我们的模型中进行训练的呢?以往我们大多数初学者肯定都是拿网上的代码直接用,但是它底层的原理到底是什么还是不太清楚。所以今天就从内置的Dataset函数和自定义的Dataset函数做一个详细的解析。

文章目录

  • 前言
  • 1、自定义Dataset类
  • 2、torchvision.datasets
  • 3、DataLoader
  • 4、torchvision.transforms

前言

torch.utils.dataPyTorch提供的一个模块,用于处理和加载数据。该模块提供了一系列工具类和函数,用于创建、操作和批量加载数据集。

下面是 torch.utils.data 模块中一些常用的类和函数:

  • Dataset: 定义了抽象的数据集类,用户可以通过继承该类来构建自己的数据集。Dataset 类提供了两个必须实现的方法:__getitem__ 用于访问单个样本,__len__ 用于返回数据集的大小。
  • TensorDataset: 继承自 Dataset 类,用于将张量数据打包成数据集。它接受多个张量作为输入,并按照第一个输入张量的大小来确定数据集的大小。
  • DataLoader: 数据加载器类,用于批量加载数据集。它接受一个数据集对象作为输入,并提供多种数据加载和预处理的功能,如设置批量大小、多线程数据加载和数据打乱等。
  • Subset: 数据集的子集类,用于从数据集中选择指定的样本。
  • random_split: 将一个数据集随机划分为多个子集,可以指定划分的比例或指定每个子集的大小。
  • ConcatDataset: 将多个数据集连接在一起形成一个更大的数据集。
  • get_worker_info: 获取当前数据加载器所在的进程信息。

除了上述的类和函数之外,torch.utils.data 还提供了一些常用的数据预处理的工具,如随机裁剪、随机旋转、标准化等。

通过 torch.utils.data 模块提供的类和函数,可以方便地加载、处理和批量加载数据,为模型训练和验证提供了便利。但是,我们最常用的两个类还是DatasetDataLoader类。

1、自定义Dataset类

torch.utils.data.Dataset是 PyTorch 中用于表示数据集的抽象类,用于定义数据集的访问方式和样本数量。

Dataset 类是一个基类,我们可以通过继承该类并实现下面两个方法来创建自定义的数据集类:

getitem(self, index): 根据给定的索引 index,返回对应的样本数据。索引可以是一个整数,表示按顺序获取样本,也可以是其他方式,如通过文件名获取样本等。
len(self): 返回数据集中样本的数量。

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        # 根据索引获取样本
        return self.data[index]

    def __len__(self):
        # 返回数据集大小
        return len(self.data)

# 创建数据集对象
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)

# 根据索引获取样本
sample = dataset[2]
print(sample)
# 3

上面的代码样例主要实现的是一个自定义Dataset数据集类的方法,这一般都是在我们需要训练自己的数据时候需要定义的。但是一般我们作为深度学习初学者来讲,使用的都是MNIST、CIFAR-10等内置数据集,这时候就不需要再自己定义Dataset类了。至于为什么,我们下面进行详解。

2、torchvision.datasets

如果要使用PyTorch中的内置数据集,通常是通过torchvision.datasets模块来实现。torchvision.datasets模块提供了许多常用的计算机视觉数据集,如MNIST、CIFAR10、ImageNet等。

下面是使用内置数据集的示例代码:

import torch
from torchvision import datasets, transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化图像
])

# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

在上述代码中,我们实现的便是一个内置MNIST(手写数字)数据集的加载和使用。可以看到,我们在这里面并未用到上面所提到的torch.utils.data.Dataset类,这是为什么呢?

这是因为在 torchvision.datasets 模块中,内置的数据集类已经实现了torch.utils.data.Dataset 接口,并直接返回一个可用的数据集对象。因此,在使用内置数据集时,我们可以直接实例化内置数据集类,而不需要显式地继承 torch.utils.data.Dataset 类。

内置数据集类(如 torchvision.datasets.MNIST)的实现已经包含了对 __getitem____len__ 方法的定义,这使得我们可以直接从内置数据集对象中获取样本和确定数据集的大小。这样,我们在使用内置数据集时可以直接将内置数据集对象传递给 torch.utils.data.DataLoader 进行数据加载和批量处理。

在内置数据集的背后,它们仍然是基于 torch.utils.data.Dataset 类进行实现,只是为了方便使用和提供更多功能,PyTorch 将这些常用数据集封装成了内置的数据集类。

为此,我专门到pytorch官网去查看了该内置数据集的加载代码,如下图所示:
在这里插入图片描述
可以看出,确实以及内置了Dataset数据集类。

3、DataLoader

torch.utils.data.DataLoader 是 PyTorch 中用于批量加载数据的工具类。它接受一个数据集对象(如 torch.utils.data.Dataset 的子类)并提供多种功能,如数据加载、批量处理、数据打乱等。

以下是 torch.utils.data.DataLoader 的常用参数和功能:

  • dataset: 数据集对象,可以是 torch.utils.data.Dataset 的子类对象。
  • batch_size: 每个批次的样本数量,默认为 1。
  • shuffle: 是否对数据进行打乱,默认为 False。在每个 epoch 时会重新打乱数据。
  • num_workers: 使用多少个子进程加载数据,默认为 0,表示在主进程中加载数据。其实在Windows系统里面都设置为0,但是在Linux中可以设置成大于0的数。
  • collate_fn: 在返回批次数据之前,对每个样本进行处理的函数。如果为 None,默认使用 torch.utils.data._utils.collate.default_collate 函数进行处理。
  • drop_last: 是否丢弃最后一个样本数量不足一个批次的数据,默认为 False
  • pin_memory: 是否将加载的数据存放在 CUDA 对应的固定内存中,默认为 False
  • prefetch_factor: 预取因子,用于预取数据到设备,默认为 2。
  • persistent_workers: 如果为 True,则在每个 epoch 中使用持久的子进程进行数据加载,默认为 False

示例代码如下:

import torch
from torchvision import datasets, transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化图像
])

# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

# 使用数据加载器迭代样本
for images, labels in train_loader:
    # 训练模型的代码
    ...

4、torchvision.transforms

torchvision.transforms模块是PyTorch中用于图像数据预处理的功能模块。它提供了一系列的转换函数,用于在加载、训练或推断图像数据时进行各种常见的数据变换和增强操作。下面是一些常用的转换函数的详细解释:

  1. Resize:调整图像大小

    • Resize(size):将图像调整为给定的尺寸。可以接受一个整数作为较短边的大小,也可以接受一个元组或列表作为图像的目标大小。
  2. ToTensor:将图像转换为张量

    • ToTensor():将图像转换为张量,像素值范围从0-255映射到0-1。适用于将图像数据传递给深度学习模型。
  3. Normalize:标准化图像数据

    • Normalize(mean, std):对图像数据进行标准化处理。传入的mean和std是用于像素值归一化的均值和标准差。需要注意的是,mean和std需要与之前使用的数据集相对应。
  4. RandomHorizontalFlip:随机水平翻转图像

    • RandomHorizontalFlip(p=0.5):以给定的概率对图像进行随机水平翻转。概率p控制翻转的概率,默认为0.5。
  5. RandomCrop:随机裁剪图像

    • RandomCrop(size, padding=None):随机裁剪图像为给定的尺寸。可以提供一个元组或整数作为目标尺寸,并可选地提供填充值。
  6. ColorJitter:颜色调整

    • ColorJitter(brightness=0, contrast=0, saturation=0, hue=0):随机调整图像的亮度、对比度、饱和度和色调。可以通过设置不同的参数来调整图像的样貌。

在使用的时候,我们常常通过transforms.Compose来对这些数据处理操作进行一个组合,使用的时候,直接调用该组合即可。

示例代码如下:

from torchvision import transforms

# 定义图像预处理操作
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 缩放图像大小为 (256, 256)
    transforms.RandomCrop((224, 224)),  # 随机裁剪图像为 (224, 224)
    transforms.RandomHorizontalFlip(),  # 随机水平翻转图像
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化图像
])

# 对图像进行预处理
image = transform(image)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/902295.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Go】Go 文本匹配 - 正则表达式基础与编程中的应用 (8000+字)

正则表达式(Regular Expression, 缩写常用regex, regexp表示)是计算机科学中的一个概念,很多高级语言都支持正则表达式。 目录 何为正则表达式 语法规则 普通字符 字符转义 限定符 定位符 分组构造 模式匹配 regexp包 MatchString…

websocker无法注入依赖

在公司中准备用websocker统计在线人数,在WebSocketServer使用StringRedisTemplate保存数据到redis中去,但是在保存的时候显示 StringRedisTemplate变量为null 详细问题 2023-08-20 10:37:14.109 ERROR 28240 --- [nio-7125-exec-1] o.a.t.websocket.po…

【AI】文心一言的使用

一、获得内测资格: 1、点击网页链接申请:https://yiyan.baidu.com/ 2、点击加入体验,等待通过 二、获得AI伙伴内测名额 1、收到短信通知,点击链接 网页Link:https://chat.baidu.com/page/launch.html?fa&sourc…

LeetCode_Java_2236. 判断根结点是否等于子结点之和

2236. 判断根结点是否等于子结点之和 给你一个 二叉树 的根结点 root,该二叉树由恰好 3 个结点组成:根结点、左子结点和右子结点。 如果根结点值等于两个子结点值之和,返回 true ,否则返回 false 。 示例1 输入:roo…

【CMake保姆级教程】CMake的使用

文章目录 前言CMake的使用注释注释行注释块 CMake操作共处一室VIP 包房 前言 在上节课我们已经讲了CMake的安装和简单使用,本节课我们来讲解CMake的命令和他的含义 CMake的使用 CMake支持大写、小写、混合大小写的命令。如果在编写CMakeLists.txt文件时使用的工具…

vue3中的router和useRouter的区别

结论:从vue-router的官方文档中发现,useRouter需要在setup中使用,而router可以在任何组件中使用。二者是等价的,适用场景不同。 问题:使用useRouter创建的实例,未能成功调用方法push()。会报错:…

gdb 常用命令

gdb 常用命令 文章目录 gdb 常用命令gdb 调试一般步骤常用命令infostep、next、continue、finish、untilexaminebreakinfo 、enable、disable和delete命令 backtrace和framebacktraceframe listprintwhatis和ptypethreadnext、stepreturn、finishuntiljumpdisassembleset args …

攻防世界-can_has_stdio?

原题 解题思路 这使用的是brainfuck语言,语言介绍如下:Brainfuck详解。 使用网站解码即可:CTF在线工具。

《HeadFirst设计模式(第二版)》第十一章代码——代理模式

代码文件目录: RMI: MyRemote package Chapter11_ProxyPattern.RMI;import java.rmi.Remote; import java.rmi.RemoteException;public interface MyRemote extends Remote {public String sayHello() throws RemoteException; }MyRemoteClient packa…

前端Vue自定义等分底部菜单导航按钮 自适应文字宽度 可更改组件位置

随着技术的发展,开发的复杂度也越来越高,传统开发方式将一个系统做成了整块应用,经常出现的情况就是一个小小的改动或者一个小功能的增加可能会引起整体逻辑的修改,造成牵一发而动全身。 通过组件化开发,可以有效实现…

Windows 平台下微软开发的神器 PowerToys 使用笔记

文章目录 Part.I IntroductionPart.II 安装Part.III 常用操作Chap.I 快捷键Chap.II 分屏示例 Reference Part.I Introduction PowerToys 是一款来自微软的系统增强工具,就像是一个神奇的 Win10 外挂,整套软件由若干子组件构成,包括&#xff…

从零开始创建 Spring Cloud 分布式项目,不会你打我

目录 一、Spring Cloud 和 分布式 二、创建新项目 三、导入 Spring Cloud 依赖 四、配置 Spring Cloud 一、Spring Cloud 和 分布式 Spring Cloud是一个基于Spring框架的开源微服务框架,它提供了一系列工具和组件,用于帮助开发人员构建分布式系统中…

EndNote极简入门【如何使用】

在开始菜单栏打开: (点左下角忽略即可) 第一次打开呢就如下图所示: 空白的,只有一个灰色的界面: 新建一个自己库: 然后就会弹出东西啦: 导出endnote学术期刊的引文:&…

人工智能时代未来程序员必备的三大利器:异,理,说

🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…

Python学习:迭代器与生成器的深入解析

函数在Python中扮演着重要角色,不仅可以封装代码逻辑,还能通过迭代器和生成器这两种强大的技术,实现更高效的数据处理和遍历。本篇博客将深入探讨Python函数的迭代器和生成器,结合实际案例为你揭示它们的神奇,以及如何…

《论文阅读18》 SSD: Single Shot MultiBox Detector

一、论文 研究领域: 2D目标检测论文:SSD: Single Shot MultiBox Detector ECCV 2016 论文链接论文github 二、论文概要 SSD网络是作者Wei Liu在ECCV 2016上发表的论文。对于输入尺寸300x300的网络 使用Nvidia Titan X在VOC 2007测试集上达到74.3%mA…

无涯教程-PHP - 条件判断

if... elseif ... else和switch语句用于根据不同条件进行判断。 您可以在代码中使用条件语句来做出决定, PHP支持以下三个决策语句- if ... else语句 - 如果要在条件为真时执行,而在条件不为真时执行另一个代码,请使用此语句 els…

nodejs+vue古诗词在线测试管理系统

一开始,本文就对系统内谈到的基本知识,从整体上进行了描述,并在此基础上进行了系统分析。为了能够使本系统较好、较为完善的被设计实现出来,就必须先进行分析调查。基于之前相关的基础,在功能上,对新系统进…

免费开源使用的几款红黑网络流量工具,自动化的多功能网络侦查工具、超级关键词URL采集工具、Burpsuite被动扫描流量转发插件

免费开源使用的几款红黑网络流量工具,自动化的多功能网络侦查工具、超级关键词URL采集工具、Burpsuite被动扫描流量转发插件。 #################### 免责声明:工具本身并无好坏,希望大家以遵守《网络安全法》相关法律为前提来使用该工具&am…

C++ string 的用法

目录 string类string类接口函数及基本用法构造函数,析构函数及赋值重载函数元素访问相关函数operator[]atback和front 迭代器iterator容量操作size()和length()capacity()max_sizeclearemptyreserveresizeshrink_to_fit string类对象修改操作operatorpush_backappen…