定义:
DataLoader类是一个用于从数据集(dataset)中加载数据,并以迭代器(iterator)的形式返回数据样本(data samples)的工具¹²。您给出的两个字典(dictionary)分别是训练集(train set)和测试集(test set)的数据加载参数,下面我会逐一解释它们的含义和默认值:
举例演示:
假设prefetch_factor=2,batch_size=4,persistent_workers=True,你的回答里面的每个worker每次获取多少数据呢?
假设您有一个名为train_dataset的数据集,它包含了100个数据样本,每个样本由一个特征向量和一个标签组成。您可以使用以下代码创建一个使用prefetch_factor=2,batch_size=4,persistent_workers=True的DataLoader对象:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=4, num_workers=8, prefetch_factor=2, persistent_workers=True)
小提示:
pin_memory将数据样本保存在锁页内存,这里的内存指的是cpu的内存还是gpu的内存?
pin_memory参数是一个布尔值,表示是否将数据样本保存在锁页内存(pinned memory)中。这里的内存指的是CPU的内存,而不是GPU的内存12。
降低内存占用,推到的时候用到:
testloader_params = {
'batch_size': 1,
'shuffle': False,
'num_workers': 0,
'pin_memory': True,
# 'prefetch_factor': 4,
# 'persistent_workers': True,
'drop_last': False
}