5-1 Dataset和DataLoader

news2025/1/4 15:05:56

Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道。
Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素。
DataLoader定义了按batch加载数据集的方法,它是一个实现了**iter**方法的可迭代对象,每次迭代输出一个batch的数据。
DataLoader能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法(collate_fn),并且能够使用多进程读取数据。
在绝大部分情况下,用户只需实现Dataset的__len__方法和__getitem__方法,就可以轻松构建自己的数据集,并用默认数据管道进行加载。

一、深入理解Dataset和DataLoader的原理

1. 获取一个batch数据的步骤

让我们考虑一下从一个数据集中获取一个batch的数据需要哪些步骤。
(假定数据集的特征和标签分别表示为张量X和Y,数据集可以表示为(X,Y), 假定batch大小为m)
1,首先我们要确定数据集的长度n。
结果类似:n = 1000。
2,然后我们从0到n-1的范围中抽样出m个数(batch大小)。
假定m=4, 拿到的结果是一个列表,类似:indices = [1,4,8,9]
3,接着我们从数据集中去取这m个数对应下标的元素。
拿到的结果是一个元组列表,类似:samples = [(X[1],Y[1]),(X[4],Y[4]),(X[8],Y[8]),(X[9],Y[9])]
4,最后我们将结果整理成两个张量作为输出。
拿到的结果是两个张量,类似batch = (features,labels),
其中 features = torch.stack([X[1],X[4],X[8],X[9]])
labels = torch.stack([Y[1],Y[4],Y[8],Y[9]])

2.Dataset和DataLoader的功能分工

上述第1个步骤确定数据集的长度是由 Dataset的__len__ 方法实现的。
第2个步骤从0到n-1的范围中抽样出m个数的方法是由 DataLoader 的 sampler 和 batch_sampler参数指定的。
sampler参数指定单个元素抽样方法,一般无需用户设置,程序默认在DataLoader的参数
shuffle=True时采用随机抽样,shuffle=False时采用顺序抽样。

batch_sampler参数将多个抽样的元素整理成一个列表,一般无需用户设置,默认方法在DataLoader的参数drop_last=True时会丢弃数据集最后一个长度不能被batch大小整除的批次,在drop_last=False时保留最后一个批次。
第3个步骤的核心逻辑根据下标取数据集中的元素 是由 Dataset的 getitem方法实现的。
第4个步骤的逻辑由DataLoader的参数
collate_fn
指定。一般情况下也无需用户设置。

import torch 
from torch.utils.data import TensorDataset,Dataset,DataLoader
from torch.utils.data import RandomSampler,BatchSampler 


ds = TensorDataset(torch.randn(1000,3),
                   torch.randint(low=0,high=2,size=(1000,)).float())
dl = DataLoader(ds,batch_size=4,drop_last = False)
features,labels = next(iter(dl))
print("features = ",features )
print("labels = ",labels )  

image.png

# step1: 确定数据集长度 (Dataset的 __len__ 方法实现)
ds = TensorDataset(torch.randn(1000,3),
                   torch.randint(low=0,high=2,size=(1000,)).float())
print("n = ", len(ds)) # len(ds)等价于 ds.__len__()

# step2: 确定抽样indices (DataLoader中的 Sampler和BatchSampler实现)
sampler = RandomSampler(data_source = ds)
batch_sampler = BatchSampler(sampler = sampler, 
                             batch_size = 4, drop_last = False)
for idxs in batch_sampler:
    indices = idxs
    break 
print("indices = ",indices)

# step3: 取出一批样本batch (Dataset的 __getitem__ 方法实现)
batch = [ds[i] for i in  indices]  #  ds[i] 等价于 ds.__getitem__(i)
print("batch = ", batch)

# step4: 整理成features和labels (DataLoader 的 collate_fn 方法实现)
def collate_fn(batch):
    features = torch.stack([sample[0] for sample in batch]) # torch.stack是一个torch库中的函数,用于沿着指定的维度对输入的张量序列进行堆叠(即堆叠张量)
    labels = torch.stack([sample[1] for sample in batch])
    return features,labels 

features,labels = collate_fn(batch)
print("features = ",features)
print("labels = ",labels)

image.png

3.Dataset和DataLoader的核心源码

import torch 
class Dataset(object):
    def __init__(self):
        pass
    
    def __len__(self):
        raise NotImplementedError
        
    def __getitem__(self,index):
        raise NotImplementedError
        

class DataLoader(object):
    def __init__(self,dataset, batch_size, collate_fn = None, shuffle = True, drop_last = False):
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.sampler =torch.utils.data.RandomSampler if shuffle else \
           torch.utils.data.SequentialSampler
        self.batch_sampler = torch.utils.data.BatchSampler
        self.sample_iter = self.batch_sampler(
            self.sampler(self.dataset),
            batch_size = batch_size,drop_last = drop_last)
        self.collate_fn = collate_fn if collate_fn is not None else \
            torch.utils.data._utils.collate.default_collate
        
    def __next__(self):
        indices = next(iter(self.sample_iter))
        batch = self.collate_fn([self.dataset[i] for i in indices])
        return batch
    
    def __iter__(self):
        return self
    

对源码进行测试:

class ToyDataset(Dataset):
    def __init__(self,X,Y):
        self.X = X
        self.Y = Y 
    def __len__(self):
        return len(self.X)
    def __getitem__(self,index):
        return self.X[index],self.Y[index]
    
X,Y = torch.randn(1000,3),torch.randint(low=0,high=2,size=(1000,)).float()
ds = ToyDataset(X,Y)

dl = DataLoader(ds,batch_size=4,drop_last = False)
features,labels = next(iter(dl))
print("features = ",features )
print("labels = ",labels )  

image.png

二、使用Dataset创建数据集

Dataset创建数据集常用的方法有:

  • 使用 torch.utils.data.TensorDataset 根据Tensor创建数据集(numpy的array,Pandas的DataFrame需要先转换成Tensor)。
  • 使用 torchvision.datasets.ImageFolder 根据图片目录创建图片数据集。
  • 继承 torch.utils.data.Dataset 创建自定义数据集。

此外,还可以通过

  • torch.utils.data.random_split 将一个数据集分割成多份,常用于分割训练集,验证集和测试集。
  • 调用Dataset的加法运算符(+)将多个数据集合并成一个数据集。

根据Tensor创建数据集

创建数据集:

# 根据Tensor创建数据集

from sklearn import datasets 
iris = datasets.load_iris()
ds_iris = TensorDataset(torch.tensor(iris.data),torch.tensor(iris.target))

# 分割成训练集和预测集
n_train = int(len(ds_iris)*0.8)
n_val = len(ds_iris) - n_train
ds_train,ds_val = random_split(ds_iris,[n_train,n_val])

print(type(ds_iris))
print(type(ds_train))

image.png
加载数据集:

# 使用DataLoader加载数据集
dl_train,dl_val = DataLoader(ds_train,batch_size = 8),DataLoader(ds_val,batch_size = 8)

for features,labels in dl_train:
    print(features,labels)
    break

image.png
演示加法运算符(+)的合并作用:

# 演示加法运算符(`+`)的合并作用

ds_data = ds_train + ds_val

print('len(ds_train) = ',len(ds_train))
print('len(ds_valid) = ',len(ds_val))
print('len(ds_train+ds_valid) = ',len(ds_data))

print(type(ds_data))

image.png

根据图片目录创建图片数据集

先定义图片增强操作:

# 定义图片增强操作

transform_train = transforms.Compose([
   transforms.RandomHorizontalFlip(), #随机水平翻转
   transforms.RandomVerticalFlip(), #随机垂直翻转
   transforms.RandomRotation(45),  #随机在45度角度内旋转
   transforms.ToTensor() #转换成张量
  ]
) 

transform_valid = transforms.Compose([
    transforms.ToTensor()
  ]
)

根据图片目录创建数据集:

# 根据图片目录创建数据集

def transform_label(x):
    return torch.tensor([x]).float()

ds_train = datasets.ImageFolder("./eat_pytorch_datasets/cifar2/train/",
            transform = transform_train,target_transform= transform_label)
ds_val = datasets.ImageFolder("./eat_pytorch_datasets/cifar2/test/",
                              transform = transform_valid,
                              target_transform= transform_label)


print(ds_train.class_to_idx)

# 使用DataLoader加载数据集

dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True)
dl_val = DataLoader(ds_val,batch_size = 50,shuffle = True)


for features,labels in dl_train:
    print(features.shape)
    print(labels.shape)
    break

image.png

创建自定义数据集

下面我们通过另外一种方式,即继承 torch.utils.data.Dataset 创建自定义数据集的方式来对 cifar2构建 数据管道。

from pathlib import Path 
from PIL import Image 

class Cifar2Dataset(Dataset): # 继承torch.utils.data.Dataset
    def __init__(self,imgs_dir, img_transform):
        self.files = list(Path(imgs_dir).rglob("*.jpg"))
        self.transform = img_transform
        
    def __len__(self,):
        return len(self.files)
    
    def __getitem__(self,i):
        file_i = str(self.files[i])
        img = Image.open(file_i)
        tensor = self.transform(img)
        label = torch.tensor([1.0]) if  "1_automobile" in file_i else torch.tensor([0.0])
        return tensor,label 
    
    
train_dir = "./eat_pytorch_datasets/cifar2/train/"
test_dir = "./eat_pytorch_datasets/cifar2/test/"

使用:

# 定义图片增强
transform_train = transforms.Compose([
   transforms.RandomHorizontalFlip(), #随机水平翻转
   transforms.RandomVerticalFlip(), #随机垂直翻转
   transforms.RandomRotation(45),  #随机在45度角度内旋转
   transforms.ToTensor() #转换成张量
  ]
) 

transform_val = transforms.Compose([
    transforms.ToTensor()
  ]
)
ds_train = Cifar2Dataset(train_dir,transform_train)
ds_val = Cifar2Dataset(test_dir,transform_val)


dl_train = DataLoader(ds_train,batch_size = 50,shuffle = True)
dl_val = DataLoader(ds_val,batch_size = 50,shuffle = True)


for features,labels in dl_train:
    print(features.shape)
    print(labels.shape)
    break

image.png

三、使用DataLoader加载数据集

DataLoader能够控制batch的大小batch中元素的采样方法(随机否),以及将batch结果整理成模型所需输入形式的方法(collate_fn),并且能够使用多进程读取数据
DataLoader的函数签名如下。

DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
    multiprocessing_context=None,
)

一般情况下,我们仅仅会配置 dataset, batch_size, shuffle, num_workers, pin_memory, drop_last这六个参数,
有时候对于一些复杂结构的数据集,还需要自定义collate_fn函数,其他参数一般使用默认值即可。
DataLoader除了可以加载我们前面讲的 torch.utils.data.Dataset 外,还能够加载另外一种数据集 torch.utils.data.IterableDataset。
和Dataset数据集相当于一种列表结构不同,IterableDataset相当于一种迭代器结构。 它更加复杂,一般较少使用。

  • dataset : 数据集
  • batch_size: 批次大小
  • shuffle: 是否乱序
  • sampler: 样本采样函数,一般无需设置
  • batch_sampler: 批次采样函数,一般无需设置
  • num_workers: 使用多进程读取数据,设置的进程数。
  • collate_fn: 整理一个批次数据的函数
  • pin_memory: 是否设置为锁业内存。默认为False,锁业内存不会使用虚拟内存(硬盘),从锁业内存拷贝到GPU上速度会更快。
  • drop_last: 是否丢弃最后一个样本数量不足batch_size批次数据。
  • timeout: 加载一个数据批次的最长等待时间,一般无需设置。
  • worker_init_fn: 每个worker中dataset的初始化函数,常用于 IterableDataset。一般不使用。
#构建输入数据管道
ds = TensorDataset(torch.arange(1,50))
dl = DataLoader(ds,
                batch_size = 10,
                shuffle= True,
                num_workers=2,
                drop_last = True)
#迭代数据
for batch, in dl:
    print(batch)

参考:https://github.com/lyhue1991/eat_pytorch_in_20_days

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1017712.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

无涯教程-JavaScript - EVEN函数

描述 EVEN函数返回四舍五入到最接近的偶数整数的数字。您可以使用此功能来处理两个项目。 语法 EVEN (number)争论 Argument描述Required/OptionalNumberThe value to round.Required Notes 如果数字为非数字,则EVEN返回#VALUE!错误值。 不管数字的符号如何,当从零开始调…

VisualStudio配置驱动远程部署

目标机器开启ping命令 默认情况下,Windows出于安全考虑不允许外部主机对其进行Ping测试。 允许ICMP回显 设置如下: 打开win7防火墙设置界面 左边的菜单中选择 【高级设置】 在弹出的 【高级安全 Windows 防火墙】 界面,选择 【入站规则】 …

Java常见面试题(含答案,持续更新中~~)

目录 1、JVM、JRE和JDK的关系 2、什么是字节码?采用字节码的最大好处是什么 3、Java和C的区别与联系 4、Java和GO的区别与联系 5、 和 equals 的区别是什么? 6、Oracle JDK 和 OpenJDK 的对比 7、String 属于基础的数据类型吗? 8、fi…

VHOST-SCSI代码分析(3)数据流处理

VHOST SCSI数据流如下所示: IO下发过程 虚拟机中应用态程序下发IO,依次经过VFS/文件系统层,BLOCK层,SCSI层,经VIRTIO SCSI驱动virtscsi_commit_rqs访问寄存器通知HOST内核中VHOST设备(VHOST KICK过程&#…

tolua源码分析(十一)代码生成

tolua源码分析(十一)代码生成 上一节我们分析了tolua中struct数据在lua和C#之间传递的过程,这一节我们来看一下tolua自动生成各种辅助代码的流程。 生成所有代码的入口位于ToLuaMenu.cs的GenLuaAll: [MenuItem("Lua/Genera…

达梦数据库-DW-国产化--九五小庞

武汉达梦数据库股份有限公司成立于2000年,是国内领先的数据库产品开发服务商,国内数据库基础软件产业发展的关键推动者。公司为客户提供各类数据库软件及集群软件、云计算与大数据等一系列数据库产品及相关技术服务,致力于成为国际顶尖的全栈…

读取yaml文件的值

记录一下,读取yaml文件中属性的值,这里用Kubernetes的deployment.yaml文件来举例。 读取yaml文件中的image的值 yaml文件 apiVersion: apps/v1 # 1.9.0 之前的版本使用 apps/v1beta2,可通过命令 kubectl api-versions 查看 kind: Deploy…

获取中文词组的汉语拼音首字母拼接

我们需要一个快捷批量处理&#xff1a;中文词组获取其汉语拼音首字母并拼接起来。 比如&#xff1a; 输出功率3&#xff1a;SCGL3 一鸣惊人&#xff1a;YMJR 我们可以采用字符字典法&#xff0c;穷举出所有的汉字【暂只考虑简体中文】 Dictionary<char,string> dict…

【数据分享】2006-2021年我国省份级别的市容环境卫生相关指标(20多项指标)

《中国城市建设统计年鉴》中细致地统计了我国城市市政公用设施建设与发展情况&#xff0c;在之前的文章中&#xff0c;我们分享过基于2006-2021年《中国城市建设统计年鉴》整理的2006—2021年我国省份级别的市政设施水平相关指标、2006-2021年我国省份级别的各类建设用地面积数…

【C++】STL—— unordered_map的介绍和使用、 unordered_map的构造函数和迭代器、 unordered_map的增删查改函数

文章目录 1. unordered_map的介绍2. unordered_map的使用2.1unordered_map的构造函数2.2unordered_map的迭代器2.3unordered_map的容量和访问函数2.4unordered_map的增删查改函数 1. unordered_map的介绍 unordered_map的介绍 &#xff08;1&#xff09;unordered_map是存储&l…

ElasticSearch(ES)简答了解

ES简介 Elasticsearch&#xff08;通常简称为ES&#xff09;是一个开源的分布式搜索和分析引擎&#xff0c;旨在处理各种类型的数据&#xff0c;包括结构化、半结构化和非结构化数据。它最初是为全文搜索而设计的&#xff0c;但随着时间的推移&#xff0c;它已经演变成一个功能…

web系统安全设计原则

一、前言 近日&#xff0c;针对西工大网络被攻击&#xff0c;国家计算机病毒应急处理中心和360公司对一款名为“二次约会”的间谍软件进行了技术分析。分析报告显示&#xff0c;该软件是美国国家安全局&#xff08;NSA&#xff09;开发的网络间谍武器。当下&#xff0c;我们发现…

【骑行之旅】昆明草海湿地公园和海晏村的美丽邂逅

这是一个九月的星期六&#xff0c;在昆明的大观公园门口&#xff0c;我们集合了一群热爱骑行的骑友。今天&#xff0c;阳光明媚&#xff0c;天空湛蓝&#xff0c;一切都充满了活力。我们的旅程从这里开始&#xff0c;一路向西&#xff0c;向着下一站&#xff0c;美丽的草海湿地…

虚拟机用户切换及设置root权限的密码

天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物。 每个人都有惰性&#xff0c;但不断学习是好好生活的根本&#xff0c;共勉&#xff01; 文章均为学习整理笔记&#xff0c;分享记录为主&#xff0c;如有错误请指正&#xff0c;共同学习进步。…

驱动开发,IO多路复用(select,poll,epoll三种实现方式的比较)

1.IO多路复用介绍 在使用单进程或单线程情况下&#xff0c;同时处理多个输入输出请求&#xff0c;需要用到IO多路复用&#xff1b;IO多路复用有select/poll/epoll三种实现方式&#xff1b;由于不需要创建新的进程和线程&#xff0c;减少了系统资源的开销&#xff0c;减少了上下…

从0到1搭建Halo博客系统教程

前期准备 云服务器&#xff0c;域名&#xff0c;命令工具&#xff08;这里使用是Mobaxterm&#xff09; 安装环境 宝塔面板 yum install -y wget && wget -O install.sh https://download.bt.cn/install/install_6.0.sh && sh install.sh ed8484bec在命令工…

【计算机视觉】Image Generation Models算法介绍合集

文章目录 一、Diffusion二、Guided Language to Image Diffusion for Generation and Editing&#xff08;GLIDE&#xff09;三、classifier-guidance四、Blended Diffusion五、DALLE 2六、AltDiffusion七、Group Decreasing Network八、Make-A-Scene九、Iterative Inpainting十…

c++ - 抽象类 和 使用多态当中一些注意事项

抽象类 纯虚函数 在虚函数的后面写上 0 &#xff0c;则这个函数为纯虚函数。 class A { public:virtual void func() 0; }; 纯虚函数不需要写函数的定义&#xff0c;他有类似声明一样的结构。 抽象类概念 我们把具有纯虚函数的类&#xff0c;叫做抽象类。 所谓抽象就是&a…

docker gitlab+jenkins搭建

一&#xff1a;gitlab搭建: 1&#xff1a;docker部署 2&#xff1a;修改root密码 3&#xff1a;创建普通账户 4&#xff1a;设置sshken 二&#xff1a;jenkins搭建 配置脚本 bash -x /var/jenkins_home/shell/game01.sh

图解数据结构

&#x1f31e;欢迎来到数据结构的世界 &#x1f308;博客主页&#xff1a;卿云阁 &#x1f48c;欢迎关注&#x1f389;点赞&#x1f44d;收藏⭐️留言&#x1f4dd; &#x1f31f;本文由卿云阁原创&#xff01; &#x1f4c6;首发时间&#xff1a;&#x1f339;2023年9月17日&…