Table of Contents
- 3. Torchvision
- 1. Dataset
- 2. DataLoader
- 2.1 test_data
- 2.2 test_loader
- 2.3 drop_last
- 2.4 shuffle
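3. Torchvision
PyTorch website: https://pytorch.org
1. Dataset
Dataset description: https://www.cs.toronto.edu/~kriz/cifar.html
Dataset usage:
CIFAR10 dataset: https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR10.html#torchvision.datasets.CIFAR10
Parameters:
- root: directory where the dataset is stored
- train: True (training set) or False (test set)
- transform: transformation applied to each image
- target_transform: transformation applied to each target (label)
- download: whether to download the dataset if it is not already in root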
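To make the roles of transform and target_transform concrete, here is a minimal sketch of a map-style dataset in plain Python. ToyDataset, its sample data, and the shortened class list are hypothetical and not part of torchvision; the class only mimics the convention that indexing a sample passes the image through transform and the label through target_transform.

```python
class ToyDataset:
    """Hypothetical stand-in that mimics a torchvision-style dataset."""

    def __init__(self, samples, transform=None, target_transform=None):
        self.samples = samples                    # list of (image, target) pairs
        self.transform = transform                # applied to the image only
        self.target_transform = target_transform  # applied to the target only

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        img, target = self.samples[index]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


# Made-up data: each "image" is just three pixel values in [0, 255].
classes = ["airplane", "automobile", "bird", "cat"]
raw_samples = [([0, 128, 255], 3), ([255, 255, 0], 0)]

toy_set = ToyDataset(
    raw_samples,
    transform=lambda pixels: [v / 255 for v in pixels],  # scale to [0, 1]
    target_transform=lambda t: classes[t],               # index -> class name
)

img, target = toy_set[0]
print(img, target)
```

Indexing toy_set[0] returns an (image, target) pair, which is the same shape of result that test_set[0] gives for CIFAR10.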
Basic usage:
import torchvision
train_set = torchvision.datasets.CIFAR10(root="../data", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="../data", train=False, download=True)
print(test_set[0])
print(test_set.classes)
img, target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])
img.show()
Files already downloaded and verified
Files already downloaded and verified
(<PIL.Image.Image image mode=RGB size=32x32 at 0x23CD61F0220>, 3)
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
<PIL.Image.Image image mode=RGB size=32x32 at 0x23CD61F00D0>
3
cat
Convert the images to tensors and display them with TensorBoard:
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="../data", transform=dataset_transform, train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="../data", transform=dataset_transform, train=False, download=True)
writer = SummaryWriter("logs")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)
writer.close()
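2. DataLoader
Documentation: https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader
Parameters:
- batch_size: number of samples loaded per batch (default: 1)
- shuffle: if True, the data is reshuffled at every epoch (default: False)
- num_workers: number of subprocesses used to load data (default: 0, meaning data is loaded in the main process)
- drop_last: whether to drop the last incomplete batch when the dataset size is not divisible by batch_size (default: False)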
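The parameters above can be tried out without downloading anything. The following is a minimal sketch that assumes a synthetic TensorDataset of 10 random 3 x 32 x 32 tensors as a stand-in for CIFAR10; the names images, labels, toy_data and batch_shapes are illustrative only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for CIFAR10: 10 random "images" with labels 0..9.
images = torch.rand(10, 3, 32, 32)
labels = torch.arange(10)
toy_data = TensorDataset(images, labels)

# 10 samples with batch_size=4 and drop_last=False: two full batches plus one of 2.
loader = DataLoader(toy_data, batch_size=4, shuffle=False, num_workers=0, drop_last=False)

batch_shapes = []
for imgs, targets in loader:
    batch_shapes.append(tuple(imgs.shape))
    print(imgs.shape, targets)
```

Each batch stacks the individual samples along a new leading dimension, so a single image of shape [3, 32, 32] becomes part of a batch of shape [batch_size, 3, 32, 32].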
2.1 test_data
import torchvision
from torch.utils.data import DataLoader
# Prepare the test set
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())
# First image in the test set and its target
img, target = test_data[0]
print(img.shape)
print(target)
torch.Size([3, 32, 32]) # 3 channels, 32 x 32
3
2.2 test_loader
import torchvision
from torch.utils.data import DataLoader
# Prepare the test set
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
# First image in the test set and its target
# img, target = test_data[0]
# print(img.shape)
# print(target)
# Iterate over test_loader
for data in test_loader:
    imgs, targets = data
    print(imgs.shape)
    print(targets)
torch.Size([4, 3, 32, 32]) # 4 images, 3 channels, 32 x 32
tensor([1, 2, 0, 8]) # targets of the 4 images, collated into one tensor
...
...
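Note: the targets [1, 2, 0, 8] do not follow the dataset order. Because shuffle=True, the samples in each batch are drawn in random order.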
2.3 drop_last
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# Prepare the test set
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
# batch_size=64
writer = SummaryWriter("logs")
step = 0
for data in test_loader:
    imgs, targets = data
    # add_images (plural) logs a whole batch at once
    writer.add_images("test_data", imgs, step)
    step += 1
writer.close()
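Note: the last batch contains only 16 images because drop_last=False. The CIFAR10 test set has 10,000 images, and 10000 = 156 × 64 + 16, so the final batch is incomplete. When the dataset size is not divisible by batch_size, the loader either returns the smaller final batch (drop_last=False) or discards it (drop_last=True).
With drop_last=True, the final incomplete batch is dropped.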
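The batch counts for both settings can be checked with a little arithmetic. This is a sketch in plain Python; num_batches is a hypothetical helper written for this note, not a PyTorch function.

```python
import math


def num_batches(num_samples, batch_size, drop_last):
    """Number of batches one pass over the data produces (hypothetical helper)."""
    if drop_last:
        return num_samples // batch_size           # incomplete batch discarded
    return math.ceil(num_samples / batch_size)     # incomplete batch kept


num_samples, batch_size = 10000, 64   # CIFAR10 test set size, batch size used above
remainder = num_samples % batch_size  # images left over for the final batch

print(num_batches(num_samples, batch_size, drop_last=False))  # 157
print(num_batches(num_samples, batch_size, drop_last=True))   # 156
print(remainder)                                              # 16
```

So the run with drop_last=False logs 157 steps, the last one holding 16 images, while drop_last=True logs 156 full batches.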
2.4 shuffle
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# Prepare the test set
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=True)
# shuffle=False
writer = SummaryWriter("logs")
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        writer.add_images("Epoch:{}".format(epoch), imgs, step)
        step += 1
writer.close()
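Note: with shuffle=False, both epochs yield exactly the same batches in the same order. To reshuffle the data at every epoch, set shuffle=True.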
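The effect of shuffle can also be seen on a tiny synthetic dataset. The sketch below assumes a TensorDataset holding the integers 0 to 9 in place of images; epoch_order is a hypothetical helper, and the seeded torch.Generator is only there to make the shuffled run reproducible.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten samples whose value equals their index, so the visiting order is easy to read.
toy_data = TensorDataset(torch.arange(10))


def epoch_order(loader):
    """Flatten one full pass over the loader into a list of sample values."""
    return [int(x) for (batch,) in loader for x in batch]


# shuffle=False: every epoch visits the samples in dataset order.
ordered = DataLoader(toy_data, batch_size=4, shuffle=False)
first, second = epoch_order(ordered), epoch_order(ordered)
print(first)
print(second)

# shuffle=True: a new permutation is drawn each time the loader is iterated.
gen = torch.Generator().manual_seed(0)
shuffled = DataLoader(toy_data, batch_size=4, shuffle=True, generator=gen)
s_first, s_second = epoch_order(shuffled), epoch_order(shuffled)
print(s_first)
print(s_second)
```

With shuffle=False the two printed lists are identical. With shuffle=True each list is a permutation of the same ten values, and the two epochs will typically differ.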