PyTorch入门教学——torchvision中数据集的使用
1、torchvision.datasets
datasets是torchvision工具集中的一个工具。 可以理解为调用官方数据集的一种方式,其中有很多开源的数据集,可供我们学习使用。 datasets官网:Datasets — Torchvision 0.16 documentation (pytorch.org)
2、使用
这里以使用CIFAR10中的数据为例。 其中有这个数据集的使用方法和具体介绍。 参数:(每个数据集的参数大致相同)
root:数据集下载后存放的目录。 train:如果为True,则从训练集创建数据集,否则从测试集创建。 transform:接收PIL图像的转换方式,并返回转换后的版本。 download:如果为True,则从互联网下载数据集,然后将其放在设置的目录中。如果数据集已下载,则不会再次下载。 代码演示——查看数据集中图片的信息
import torchvision
train_set = torchvision.datasets.CIFAR10(root="./Dataset/CIFAR10", train=True, download=True) # root:数据集要存放在什么位置
test_set = torchvision.datasets.CIFAR10(root="./Dataset/CIFAR10", train=False, download=True)
print(test_set[0]) # 第一张图片的信息,包含格式和标签
print(test_set.classes) # 数据集中所包含的图片类别
img, target = test_set[0]
print(img)
print(target) # 标签
print(test_set.classes[target]) # 第一张图片的标签为猫
img.show() # 显示图片
代码演示——将数据集中的前10张图片在tensorboard中展示出来。
import torchvision
from torch.utils.tensorboard import SummaryWriter
test_set = torchvision.datasets.CIFAR10(
root="./Dataset/CIFAR10",
transform=torchvision.transforms.ToTensor(), # 将图片转换为totensor数据类型
train=False,
download=True)
writer = SummaryWriter('logs') # writer把summary内容写在哪个目录下
for i in range(10):
img, target = test_set[i]
writer.add_image('test_set', img, i)
writer.close()
运行程序后,打开终端,输入下列命令打开tensorboard。 tensorboard --logdir=logs --port=6007
(该数据集的图片像素为32*32,所以比较模糊)
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1140862.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!