1、准备数据集(测试集)
import torchvision
test_data = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor())
注意数据集中的图片是PIL的格式,需要格式转换。
2、使用DataLoader
from torch.utils.data import DataLoader
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
3、查看数据集中图片的尺寸及target
# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)
结果如下:
4、DataLoader的返回
其做了一个打包处理。
测试如下:
for data in test_loader:
imgs, targets = data
print(imgs.shape)
print(targets)
结果如下:
4代表4张图片(batch_size的大小)
5、drop_last的作用
如果为True,则若有剩余且数量小于batch_size,直接丢弃;
如果为False,则保留。
6、实际中需要结合epoch来使用
代码如下:
import torchvision
# 准备的测试集
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
test_data = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)
writer = SummaryWriter('dataloader')
for epoch in range(2):
step = 0
for data in test_loader:
imgs, targets = data
# print(imgs.shape)
# print(targets)
writer.add_images('Epoch:{}'.format(epoch), imgs, step)
step = step + 1
writer.close()
结果如下: