参考: https://blog.csdn.net/Chinesischguy/article/details/103198921
参考: https://zhuanlan.zhihu.com/p/76893455
参考:https://blog.csdn.net/lilai619/article/details/118784730
参考:https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader
本博客旨在介绍PyTorch深度学习框架中Dataset、Dataloader、Sampler、collate_fn组件之间相互关系,以及如何自定义各组件。这些组件是深度学习项目中不可或缺的组成部分,对于理解和使用PyTorch框架进行深度学习任务至关重要。
在PyTorch深度学习框架中,Dataset、Dataloader、Sampler和collate_fn是数据加载和处理过程中非常重要的组成部分。它们之间的调用关系如下:
-
Dataset:定义了数据集的接口,用于读取和处理数据。通常情况下,Dataset是从文件或数据库中读取数据的集合,它可以对数据进行预处理、增强等操作,并返回一个可迭代的对象,用于后续的数据加载过程。
-
Dataloader:实现了数据集的批量加载功能。Dataloader可以根据Dataset返回的可迭代对象,将数据分成多个batch,并按照指定的采样方式(如随机采样、分层采样等)进行采样。同时,Dataloader还可以自动调整batch size、设置数据加载器状态等。
-
Sampler:定义了数据集中每个batch所包含的数据的位置索引。通常情况下,Sampler是在数据加载之前设置的一个对象,它可以根据用户指定的要求(如按照类别、标签等)对数据集进行采样,并返回每个batch所包含的数据的位置索引。
-
collate_fn:用于将一个batch中的数据进行拼接和整理。通常情况下,collate_fn是在Dataloader创建时设置的一个函数,它可以根据Dataset返回的可迭代对象和Sampler返回的位置索引,将不同长度的输入数据转换为统一的形状,并返回一个新的tensor作为batch的数据。
综上所述,Dataset、Dataloader、Sampler和collate_fn之间是相互协作的,它们共同完成了数据加载和处理的过程。具体来说,Dataset提供了数据集的接口和一些基本的操作;Dataloader实现了数据的批量加载和一些高级的功能;Sampler根据用户指定的要求对数据集进行采样;collate_fn负责将不同长度的输入数据转换为统一的形状。本文将讨论这四个组件的使用方法,并提供一些自定义各组件的技术实践经验。我们将从以下几个方面来探讨:
1. Dataset的使用方法和自定义技巧;
2. Sampler的使用方法和自定义技巧;
3. collate_fn的使用方法和自定义技巧。
DataLoader, Sampler, Dataset三者的关系
1. Sampler提供indicies
2. Dataset根据indicies提供data,使用__getitem__方法
3. DataLoader将上面两个组合起来,提供最终的batch训练数据,其中collate_fn可以对batch中的数据做额外的处理
自定义Dataset
在PyTorch中,可以通过继承torch.utils.data.Dataset
类来自定义数据集(Dataset)类。自定义的数据集类可以包含自己的数据加载和预处理方法,以及一些额外的元数据。
import torch
from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler
import torchvision
from torchvision.io import read_image
import random
import numpy as np
from matplotlib import pyplot as plt
from collections import Counter
class MyDataset(Dataset):
"""
加载磁盘上的图像文件,并进行transform变换,返回变换后的图片和与之对应的标签编号
"""
def __init__(self, filenames, labels, transforms_pipeline=None):
super().__init__()
# 所有图像的路径列表
self.filenames = filenames
# 所有图片对应的label标签编号,从0开始
self.labels = labels
# 图像预处理
self.transforms_pipeline = transforms_pipeline
def __len__(self):
return len(self.filenames)
def __getitem__(self, idx):
filepath = self.filenames[idx]
img = read_image(filepath, mode=torchvision.io.ImageReadMode.RGB)
if self.transforms_pipeline:
img = self.transforms_pipeline(img)
return img, self.labels[idx]
以上代码自定义了一个Dataset类用于加载训练数据,训练数据中cat和dog目录下分别存储的是猫和狗的图片。
使用以下代码片段测试自定义的Dataset数据加载情况:
transforms_pipeline = torchvision.transforms.Compose(
[
torchvision.transforms.Resize((224, 224)),
]
)
# 图像存放位置,其中包含两个目录,cat和dog,cat下存放猫的图片,dog下存放狗的图片
data_path = "XXX"
image_folder = torchvision.datasets.ImageFolder(data_path)
# image_folder.samples 中存放的是图像数据的文件路径和类别索引编号(从0开始编号)
random.shuffle(image_folder.samples)
# image_folder.classes image_folder.samples中存放的类别索引编号相对应
classes = image_folder.classes
# 用于存放图像路径列表
filenames = []
# 用于存放图像对应的类别
labels = []
for image_path, label in image_folder.samples:
# print(image_path, label)
filenames.append(image_path)
labels.append(label)
print(filenames, labels)
# 使用自定义Dataset类加载磁盘上的图上数据
my_dataset = MyDataset(filenames, labels, transforms_pipeline)
img, label = my_dataset[10]
print(img.shape, label)
自定义Sampler
在PyTorch中,可以通过继承torch.utils.data.Sampler
类来自定义采样器(Sampler)类。自定义的采样器类可以控制数据集中每个样本的采样方式,例如随机采样、分块采样等。
class MySampler(Sampler):
"""
自定义Sampler,在__iter__函数中定义indices的生成方式,也叫生成顺序
"""
def __init__(self, labels):
self.labels = np.array(labels)
self.image_ids = []
def __iter__(self):
"""
在每个batch中包含的每个类别的数量相等
:return:
"""
indices = []
counter = Counter(self.labels)
# 统计数据量最多的类别
most_common = counter.most_common(1)[0][1]
# 统计每张图片在filenames这个列表中对应的索引编号
for c in range(len(counter)):
indices.append(np.where(self.labels == c)[0].tolist())
# 所有类别通过复制的方式与最多的类别对齐
for indice in indices:
if len(indice) < most_common:
indice.extend(random.choices(indice, k=most_common - len(indice)))
random.shuffle(indice)
# 依次从所有类别中分别取一张图片组成batch
for ids in zip(*indices):
self.image_ids.extend(list(ids))
return iter(self.image_ids)
def __len__(self):
return len(self.image_ids)
以上自定义Sampler控制在返回训练样本编号的逻辑,使得每个batch中的各类别数据量相等,Sampler返回训练样本的编号,然后使用Dataset的__getitem__方法取出对应的样本。
使用以下代码片段测试自定义的Sampler的数据采样情况:
my_sampler = MySampler([1, 2, 3, 4, 1, 2, 3, 4, 0, 0, 0])
sample_labels = []
for x in my_sampler:
print(x)
sample_labels.append(my_sampler.labels[x])
print(sample_labels)
print(len(my_sampler))
自定义collate_fn函数
在PyTorch中,自定义collate_fn
函数可以用于对数据集中的数据进行整合和处理。当使用自定义采样器(Sampler)加载数据时,collate_fn
函数会被自动调用来整合每个batch的数据。
def collate_fn(batch_data):
"""
对batch中的图像使用mixup,并返回mixup之后的结果
:param batch_data:
:return:
"""
def mixup_data(x, y, alpha=1.0, use_cuda=False):
if alpha > 0:
lam = np.random.beta(alpha, alpha)
else:
lam = 1
batch_size = x.size()[0]
if use_cuda:
index = torch.randperm(batch_size).cuda()
else:
index = torch.randperm(batch_size)
mixed_x = lam * x + (1 - lam) * x[index, :]
y_a, y_b = y, y[index]
return mixed_x, y_a, y_b, lam
batch_img = []
batch_label = []
for img, label in batch_data:
batch_img.append(img)
batch_label.append(label)
batch_img = torch.stack(batch_img, dim=0)
batch_label = torch.tensor(batch_label)
# print(batch_img.shape, batch_label.shape)
batch_img, batch_label_a, batch_label_b, batch_lam = mixup_data(batch_img, batch_label)
return batch_img, batch_label_a, batch_label_b, batch_lam
在以上自定义collate_fn函数中,我们在每个batch批量样本之间使用mixup数据增强,并返回mixup之后的增强数据以及对应的标签和参数。
自定义Dataset、Sampler、collate_fn,以及使用Dataloader的完整代码
# coding:utf-8
import torch
from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler
import torchvision
from torchvision.io import read_image
import random
import numpy as np
from matplotlib import pyplot as plt
from collections import Counter
class MyDataset(Dataset):
"""
加载磁盘上的图像文件,并进行transform变换,返回变换后的图片和与之对应的标签编号
"""
def __init__(self, filenames, labels, transforms_pipeline=None):
super().__init__()
# 所有图像的路径列表
self.filenames = filenames
# 所有图片对应的label标签编号,从0开始
self.labels = labels
# 图像预处理
self.transforms_pipeline = transforms_pipeline
def __len__(self):
return len(self.filenames)
def __getitem__(self, idx):
filepath = self.filenames[idx]
img = read_image(filepath, mode=torchvision.io.ImageReadMode.RGB)
if self.transforms_pipeline:
img = self.transforms_pipeline(img)
return img, self.labels[idx]
def collate_fn(batch_data):
"""
对batch中的图像使用mixup,并返回mixup之后的结果
:param batch_data:
:return:
"""
def mixup_data(x, y, alpha=1.0, use_cuda=False):
if alpha > 0:
lam = np.random.beta(alpha, alpha)
else:
lam = 1
batch_size = x.size()[0]
if use_cuda:
index = torch.randperm(batch_size).cuda()
else:
index = torch.randperm(batch_size)
mixed_x = lam * x + (1 - lam) * x[index, :]
y_a, y_b = y, y[index]
return mixed_x, y_a, y_b, lam
batch_img = []
batch_label = []
for img, label in batch_data:
batch_img.append(img)
batch_label.append(label)
batch_img = torch.stack(batch_img, dim=0)
batch_label = torch.tensor(batch_label)
# print(batch_img.shape, batch_label.shape)
batch_img, batch_label_a, batch_label_b, batch_lam = mixup_data(batch_img, batch_label)
return batch_img, batch_label_a, batch_label_b, batch_lam
class MySampler(Sampler):
"""
自定义Sampler,在__iter__函数中定义indices的生成方式,也叫生成顺序
"""
def __init__(self, labels):
self.labels = np.array(labels)
self.image_ids = []
def __iter__(self):
"""
在每个batch中包含的每个类别的数量相等
:return:
"""
indices = []
counter = Counter(self.labels)
# 统计数据量最多的类别
most_common = counter.most_common(1)[0][1]
# 统计每张图片在filenames这个列表中对应的索引编号
for c in range(len(counter)):
indices.append(np.where(self.labels == c)[0].tolist())
# 所有类别通过复制的方式与最多的类别对齐
for indice in indices:
if len(indice) < most_common:
indice.extend(random.choices(indice, k=most_common - len(indice)))
random.shuffle(indice)
# 依次从所有类别中分别取一张图片组成batch
for ids in zip(*indices):
self.image_ids.extend(list(ids))
return iter(self.image_ids)
def __len__(self):
return len(self.image_ids)
## 测试自定义Sampler
# my_sampler = MySampler([1, 2, 3, 4, 1, 2, 3, 4, 0, 0, 0])
# sample_labels = []
# for x in my_sampler:
# print(x)
# sample_labels.append(my_sampler.labels[x])
# print(sample_labels)
# print(len(my_sampler))
transforms_pipeline = torchvision.transforms.Compose(
[
torchvision.transforms.Resize((224, 224)),
]
)
# 图像存放位置,其中包含两个目录,cat和dog,cat下存放猫的图片,dog下存放狗的图片
data_path = r"C:\WorkDir\PythonWorkspace\MusicRecognition\mixup-cifar10-main\data\cat_and_dog"
image_folder = torchvision.datasets.ImageFolder(data_path)
# image_folder.samples 中存放的是图像数据的文件路径和类别索引编号(从0开始编号)
random.shuffle(image_folder.samples)
# image_folder.classes image_folder.samples中存放的类别索引编号相对应
classes = image_folder.classes
# 用于存放图像路径列表
filenames = []
# 用于存放图像对应的类别
labels = []
for image_path, label in image_folder.samples:
# print(image_path, label)
filenames.append(image_path)
labels.append(label)
print(filenames, labels)
# 使用自定义Dataset类加载磁盘上的图上数据
my_dataset = MyDataset(filenames, labels, transforms_pipeline)
# img, label = my_dataset[10]
# print(img.shape, label)
# 使用自定义collate_fn函数,在每个batch中进行mixup图片增强,并返回增强后的图片数据、标签、以及mixup系数
dataloader = DataLoader(
my_dataset,
batch_size=8, # batch_size要能整除类别数
shuffle=False, # 使用sampler时,shuffle参数要设置为False
sampler=MySampler(labels), # 自定义Sampler,返回的batch中每种类别的数量相等
batch_sampler=None,
collate_fn=collate_fn # 自定义collate_fn,其中执行mixup数据增强
)
for batch_img, batch_label_a, batch_label_b, batch_lam in dataloader:
print(batch_img.shape, batch_label_a.shape, batch_label_b.shape, batch_lam)
# batch中包含每个类别的数量相等,猫和狗都是4张
# {0: 4, 1: 4}
print(Counter(batch_label_a.detach().cpu().numpy().tolist()))
break
for idx, img in enumerate(batch_img):
plt.imshow(img.permute(1, 2, 0).int().clamp(min=0, max=255).detach().cpu().numpy())
plt.show()