一、Dataloader是啥
前面我在写PyTorch的第一篇文章里讲过Dataset是啥,Dataset就是将数据集分类,并且分析出这些数据集它的位置哪、大小多少、这个数据集一共有多少数据......等等信息
那么把Dataset比作一副扑克牌,那么如果你就让这副牌放在桌子那不去取牌,那你怎么打牌?Dataloader就是做【取牌】这个操作,就是去【读取数据】
二、使用DataLoader
首先先看一下官方文档对于DataLoader是怎么使用的:torch.utils.data — PyTorch 2.4 documentation
其中框住的解释的是常用的参数变量的作用解释
用一些例子结合tensorboard,直观地生动地解释一下
【batch_size参数】:一次读取几个数据
;
【drop_last参数】:最后一次读取,剩余数据不足【batch_size】时,要不要舍去
;
【shuffle参数】:当多轮读取的时候,图片顺序是否一样,False是顺序一样
代码编写:导包(torchvision为了dataset,DataLoader则来自torch.utils.data)
然后先用dataset把数据集获取到,这里我用的是下载好的pytorch内置数据集CIFAR10,你们也可以用自定义数据集,注意语法区别就行
然后用DataLoader,设置好参数配置
import torchvision
from torch.utils.data import DataLoader
# 用dataset获取pytorch的内置数据集(我已经下载好,而且选用测试数据集)
test_dataset = torchvision.datasets.CIFAR10("./dataset2", train=False, transform=torchvision.transforms.ToTensor())
# 然后用DataLoader读取,并设置好参数(上面例子里没讲到的参数,你就当默认这么写就好了,我也不知道)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
三、结合transforms、tensorboard
语法都是之前学过的,直接创建SummaryWriter( )对象,指定图像文件生成在哪个文件夹;
然后遍历整个DataLoader返回的数据,返回的是一个列表;
每次循环,提取出每个元素里的【img】跟【target】,【img】就是tensorboard的【.add_images()】所需要的图像,另外step跟着遍历递增就行
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# 用dataset获取pytorch的内置数据集(我已经下载好,而且选用测试数据集)
test_dataset = torchvision.datasets.CIFAR10("./dataset2", train=False, transform=torchvision.transforms.ToTensor())
# 然后用DataLoader读取,并设置好参数(上面例子里没讲到的参数,你就当默认这么写就好了,我也不知道)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
step = 0
write = SummaryWriter("DataLoader_logs")
for item in test_loader:
img, target = item
# print(img.shape)
# print(target)
# 利用tensorboard生成图像
# 一定一定要注意!!是.add_images不是.add_image!不能漏了s
write.add_images("dataloader", img, step)
step += 1
write.close()