网络上已经有公开的数据集,并且这些数据集被整合到了torchvision.datasets中,使用自带的函数可以直接下载。
1.数据集
具体有哪些数据可直接用torchvision.datasets加载呢?可以查看这个网址:
- datasets官网:Datasets — Torchvision 0.16 documentation (pytorch.org)
图像分类的有:
图像分割的有:
2.数据集的使用代码
接下来我们用代码来学习一下这些开源数据集的使用
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
trans = transforms.Compose([transforms.ToTensor()])
train_dataset = torchvision.datasets.CIFAR10(root = "./data/CIFAR10",transform=trans, train= True,download = True)
test_dataset = torchvision.datasets.CIFAR10(root = "./data/CIFAR10", transform=trans,train= False,download = True)
img0,target = train_dataset[0]
print(img0.size)
print(target)
name = train_dataset.classes[target]
print(train_dataset.classes)
print(train_dataset.classes[target])
for i in range(10):
img, target = train_dataset[i]
writer.add_image("train",img,i)
img1, target1 = test_dataset[i]
writer.add_image("test",img1,i)
writer.close()
3.结果展示
print(img0.size)
结果为 torch.Size([3, 32, 32])
print(target)
结果为6
name = train_dataset.classes[target]
print(train_dataset.classes)
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
print(train_dataset.classes[target])
frog
tensorboard中显示的为: