文章目录
- 1. 简介
- 2. 查看PyTorch自带的数据集(可视化)
- 3. 准备材料
- 3.1 图片数据
- 3.2 标签数据
- 4. 方法
1. 简介
尽管PyTorch提供了许多自带的数据集,如MNIST、CIFAR-10、ImageNet等,但它们对于没有经验的用户来说,理解数据加载器的工作原理以及如何正确地配置数据加载器可能会有一定难度。 用户需要了解所使用的数据集,包括数据集的内容、结构、标签等信息。对于一些复杂的数据集,用户可能需要理解数据集的结构和标签的含义。通过定义自己的数据集类,您可以更好地控制数据的加载和处理过程,提高代码的灵活性、可读性和可维护性,同时更好地满足模型训练的需求。
2. 查看PyTorch自带的数据集(可视化)
为了更好的定义自己的数据集,我们首先查看PyTorch自带的数据集的内容,代码如下
# 导入所需的库
import matplotlib.pyplot as plt # 导入Matplotlib库,用于可视化
import torch # 导入PyTorch库
from torchvision.datasets import MNIST # 从torchvision中导入MNIST数据集
from torchvision import transforms # 导入transforms模块,用于数据预处理
import numpy as np # 导入NumPy库
# 加载MNIST数据集
train_mnist_data = MNIST(root='./data', # 数据集存储路径
train=True, # 加载训练集
transform=transforms.Compose([transforms.Resize(size=(28, 28)), transforms.ToTensor()]), # 数据预处理操作
download=True) # 如果数据集不存在,则自动下载
# 设置要显示的样本数量
num_samples = 10
# 创建包含多个子图的大图窗口
fig, axes = plt.subplots(1, num_samples, figsize=(10, 6))
# 遍历选择要显示的样本
for i in range(num_samples):
# 从数据集中获取图像数据和标签
image, label = train_mnist_data[i]
# 在子图中显示图像
axes[i].imshow(image.squeeze().numpy(), cmap='gray') # 使用imshow函数显示图像,将张量转换为NumPy数组
axes[i].set_title(f"Label: {label}") # 设置子图标题,显示图像对应的标签
axes[i].axis('off') # 关闭坐标轴显示
# 将图像保存为PNG格式的图片文件,文件名以图像的标签命名
plt.imsave(f"./data/mnist_images/{label}.png", image.squeeze().numpy(), cmap='gray')
# 显示图形窗口
plt.show()
这里,我们使用MNIST
类加载MNIST数据集。在加载数据集时,通过transform
参数指定了数据预处理操作,包括将图像大小调整为28x28像素,并将图像转换为张量。train=True
表示加载训练集,download=True
表示如果数据集不存在则自动下载到指定的路径。
接下来,我们选择一些样本进行可视化。我们在一个子图中显示了10个样本,每个样本对应一个数字图像和其对应的标签。通过循环遍历这些样本,从数据集中获取图像数据和标签,并使用Matplotlib的imshow()
函数将图像显示在子图中。
同时,使用imsave()
函数将每个图像保存为PNG格式的图片文件,文件名以标签命名。最后,使用plt.show()
显示图形窗口,显示图像的同时也会将图像保存到指定的路径中。这段代码的执行结果是显示10张MNIST数据集中的数字图像,并将这些图像保存到指定路径下。保存的图片如下所示
通过上面程序可以看到,数据集主要是由图片数据和对应的标签构成,那么我们就可以用这两个主要构成成分来构建自己的数据集。
3. 准备材料
3.1 图片数据
这里我们就用刚才保存的十张图片,即
当然,你也可以准备其它的图片,并给图片分别命名为“0.png, 1.png, …”。
这里,十张图片的相对路径为
imgs_path = "./data/mnist_images"
注:你们要根据自己存储的路径来给定。
3.2 标签数据
创建一个txt文件,为每一幅图片指定标签数据,如下所示
这里,txt文件的相对路径为
labels_path = "labels.txt"
4. 方法
在PyTorch中,您可以通过创建一个自定义的数据集类来定义自己的数据集。这个自定义类需要继承自torch.utils.data.Dataset
类,并且实现两个主要的方法:__len__
和 __getitem__
。__len__
方法应该返回数据集的长度,而__getitem__
方法则根据给定的索引返回数据集中的样本。
下面我们展示如何创建一个自定义的数据集类:
import os # 导入os模块,用于操作文件路径
from PIL import Image # 导入PIL库中的Image模块,用于图像处理
import torch # 导入PyTorch库
from torch.utils.data import Dataset # 从torch.utils.data模块导入Dataset类,用于定义自定义数据集
from torchvision import transforms # 导入transforms模块,用于数据预处理
import numpy as np # 导入NumPy库,用于数值处理
import matplotlib.pyplot as plt # 导入Matplotlib库,用于可视化
class CustomDataset(Dataset):
def __init__(self, image_dir, label_file, transform=None):
super().__init__() # 调用父类的构造函数
self.image_dir = image_dir # 图像数据的路径
self.label_file = label_file # 标签文本的路径
self.transform = transform # 数据预处理操作
self.samples = self._load_samples() # 加载数据集样本信息
def _load_samples(self):
samples = [] # 存储样本信息的列表
with open(self.label_file, 'r') as f: # 打开标签文本文件
for line in f: # 逐行读取标签文本文件中的内容
image_name, label = line.strip().split(',') # 根据逗号分隔每行内容,获取图像文件名和标签
image_path = os.path.join(self.image_dir, image_name) # 拼接图像文件的完整路径
samples.append((image_path, int(label))) # 将图像路径和标签组成元组,加入样本列表
return samples # 返回样本列表
def __len__(self):
return len(self.samples) # 返回数据集样本的数量
def __getitem__(self, index):
image_path, label = self.samples[index] # 获取指定索引处的图像路径和标签
image = Image.open(image_path).convert('L') # 打开图像文件并将其转换为灰度图像
if self.transform: # 如果定义了数据预处理操作
image = self.transform(image) # 对图像进行预处理操作
return image, label # 返回预处理后的图像和标签
# 设置图片数据路径和标签文本路径
image_dir = './data/mnist_images' # 图像数据的路径
label_file = 'labels.txt' # 标签文本的路径
# 定义数据预处理操作,根据需要添加其他预处理操作
transform = transforms.Compose([
transforms.Resize((28, 28)), # 调整图像大小
transforms.ToTensor(), # 将图像转换为张量
])
# 创建自定义数据集实例
custom_dataset = CustomDataset(image_dir, label_file, transform=transform)
# 创建数据加载器
data_loader = torch.utils.data.DataLoader(custom_dataset, batch_size=1, shuffle=False)
# 遍历数据加载器中的每个批次数据
for batch_images, batch_labels in data_loader:
# 使用squeeze()函数去除图像张量中的单维度,将图像数据转换为NumPy数组,并存储在变量image中
image = batch_images.squeeze().numpy()
# 使用imshow()函数显示图像,cmap='gray'指定使用灰度色彩映射
plt.imshow(image, cmap='gray')
# 设置图像标题,显示图像对应的标签,使用f-string格式化字符串,将batch_labels转换为Python标量并获取其值
plt.title(f"Label: {batch_labels.item()}")
# 关闭坐标轴显示,即不显示坐标轴
plt.axis('off')
# 显示图形窗口
plt.show()
这段代码实现了加载自定义数据集,并使用 PyTorch 的 DataLoader 将数据加载成批次,然后逐批次地展示图像。