文章目录
- 一、前置知识
- 1.dataloader简要介绍
- 2.dataloader 官方文档(翻译后)
- 二、DataLoader的使用
一、前置知识
1.dataloader简要介绍
DataLoader 是 PyTorch 中用于加载数据的实用工具,它可以处理数据集的批量加载、数据集的随机打乱、多进程数据加载等功能。通过使用 DataLoader,可以更高效地将数据提供给模型进行训练或推理。
具体来说,DataLoader 提供了以下功能:
数据批量加载:DataLoader 可以将数据集划分为固定大小的批次,使得模型可以逐批次地处理数据。
数据集随机打乱:在训练模型时,通常会希望对数据集进行随机打乱,以避免模型学习到数据的顺序性特征。DataLoader 可以在每个周期(epoch)开始时对数据集进行随机打乱。
多进程数据加载:DataLoader 支持多进程数据加载,可以加快数据加载速度,尤其是当数据预处理耗时较长时。
自定义数据加载顺序:可以通过设置 sampler 或 batch_sampler 参数来自定义数据加载的顺序,比如指定按照某种策略抽取样本。
2.dataloader 官方文档(翻译后)
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device=‘’)
DataLoader 类结合了数据集和采样器,并提供了对给定数据集的可迭代访问。
DataLoader 支持 map-style 和 iterable-style 数据集,可以进行单进程或多进程加载,自定义加载顺序,以及可选的自动分批(collation)和内存固定。
参数:
dataset(Dataset):要加载数据的数据集。
batch_size(int,可选):每个批次要加载的样本数(默认为 1)。
shuffle(bool,可选):设置为 True 时,每个周期都会对数据进行重新洗牌(默认为 False)。
sampler(Sampler 或 Iterable,可选):定义从数据集中抽取样本的策略。可以是任何实现了 len 方法的可迭代对象。如果指定了 sampler,则不能指定 shuffle。
batch_sampler(Sampler 或 Iterable,可选):类似于 sampler,但一次返回一批索引。与 batch_size、shuffle、sampler 和 drop_last 互斥。
num_workers(int,可选):用于数据加载的子进程数。0 表示数据将在主进程中加载(默认为 0)。
collate_fn(Callable,可选):合并样本列表以形成 Tensor 的小批量。当从 map-style 数据集进行批处理加载时使用。
pin_memory(bool,可选):如果为 True,则数据加载器将在返回数据之前将 Tensor 复制到设备/CUDA 固定内存中。
drop_last(bool,可选):设置为 True 时,如果数据集大小不能被批量大小整除,则丢弃最后一个不完整的批次。如果为 False 且数据集的大小不能被批次大小整除,则最后一个批次将较小(默认为 False)。
timeout(数值,可选):如果为正值,则为从工作进程收集批次的超时值。应始终为非负数(默认为 0)。
worker_init_fn(Callable,可选):如果不为 None,则会在每个工作进程上调用,输入为工作进程的 id(范围在 [0, num_workers - 1] 之间),在种子化之后数据加载之前使用(默认为 None)。
multiprocessing_context(str 或 multiprocessing.context.BaseContext,可选):如果为 None,则使用操作系统的默认多进程上下文(默认为 None)。
generator(torch.Generator,可选):如果不为 None,则 RandomSampler 将使用此 RNG 生成随机索引,多进程用于生成工作进程的基础种子(默认为 None)。
prefetch_factor(int,可选,仅限关键字参数):每个工作进程预先加载的批次数。2 表示所有工作进程总共会预先加载 2 * num_workers 个批次。
persistent_workers(bool,可选):如果为 True,则数据加载器在数据集被消耗一次后不会关闭工作进程。这允许保持工作进程的数据集实例处于活动状态(默认为 False)。
pin_memory_device(str,可选):如果 pin_memory 为 True,则用于内存固定的设备(默认为 “”)。
二、DataLoader的使用
代码如下:
import torchvision
# 准备测试的数据集
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# 从CIFAR10导入数据
test_data = torchvision.datasets.CIFAR10("./dataset1", train=False, transform=torchvision.transforms.ToTensor(), download=True)
# 定义数据加载方式
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
# 测试数据集中第一张图片及target
img, target = test_data[0]
# print(img)
# print(target)
writer = SummaryWriter("data_loader")
# 两轮获取数据
for epoch in range(2):
step=0
print(epoch)
for data in test_loader:
imgs, target = data
# print(img.shape)
# print(target)
writer.add_images("Epoch:{}".format(epoch), imgs, step)
step = step+1
writer.close()
若将shuffle设置为False,表示“不洗牌”,则两次结果一样:
若将shuffle设置为True,表示“洗牌”,则两次结果不一样: