数据是深度学习的基础,MindSpore 提供基于 Pipeline 的 数据引擎,通过数据集 数据集(Dataset) 和 数据变换(Transforms) 实现高效的数据预处理。其中 Dataset 是 Pipeline 的起始,用于加载原始数据。mindspore.dataset
提供了内置的文本、图像、音频等数据集加载接口,并提供了自定义数据集加载接口。
一、数据集加载
这里使用 Mnist 数据集作为样例,使用 mindspore.dataset
进行加载的方法。mindspore.dataset
提供的接口 仅支持解压后的数据文件,因此我们使用 download
库下载数据集并解压。
def download_dataset(url, path="./"):
"""
通过 download 下载数据集
:param url: 下载链接
:param path: 数据集保存地址
:return:
"""
try:
if os.path.exists(path):
print("{} 文件以存在".format(path))
else:
path = download(url, path, kind="zip", replace=True)
print("下载完成:{}".format(path))
except RuntimeWarning as e:
print("数据集下载失败:{}".format(e))
下载完数据集后,可以通过 MnistDataset
加载数据集,其数据类型为 mindspore.dataset.engine.datasets_vision.MnistDataset。
二、数据集迭代
数据集加载后,一般以迭代方式获取数据,然后送入神经网络中进行训练。我们可以用 create_tuple_iterator 或 create_dict_iterator 接口创建数据迭代器,迭代访问数据。访问的数据类型默认为 Tensor
;若设置 output_numpy=True
,访问的数据类型为 Numpy
。这里可以定义一个可视化函数,迭代 9 张图片进行展示。
def show_visualize(dataset):
# 创建一个画布
figure = plt.figure(figsize=(4, 4))
cols, rows = 3, 3
plt.subplots_adjust(wspace=0.5, hspace=0.5)
for idx, (image, label) in enumerate(dataset.create_tuple_iterator()):
figure.add_subplot(rows, cols, idx + 1)
plt.title(label)
plt.axis("off")
plt.imshow(image.asnumpy().squeeze(), cmap="gray")
if idx == cols * rows - 1:
break
plt.show()
使用 Mnist 数据集作为示例,顺序展示 Mnist 数据集的 9 张图片。
# 展示数据集
url_mnist = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
download_dataset(url_mnist, './mnist')
train_dataset = MnistDataset('./mnist/MNIST_Data/train', shuffle=False)
show_visualize(train_dataset)
三、数据集常用操作
Pipeline 的设计理念使得数据集的常用操作采用 dataset = dataset.operation()
的异步执行方式,执行操作返回新的Dataset,此时不执行具体操作,而是在 Pipeline 中加入
节点,最终进行迭代时,并行执行整个 Pipeline。下面分别介绍几种常见的数据集操作。
2.1 shuffle
数据集随机 shuffle
可以消除数据排列造成的分布不均问题。shuffle
操作就是打乱数据集中样例的顺序,起到解决数据列分布不均的问题,如下图所示:
mindspore.dataset
提供的数据集在加载时可配置 shuffle=True
,或调用 shuffle
方法来打乱数据集中样例的顺序。
# 方法1:加载时配置 `shuffle=True`
dataset = MnistDataset(data_path, shuffle=False)
# 方法2:调用 `shuffle` 方法
dataset = dataset.shuffle(buffer_size=64)
以 Mnist 数据集作为示例,以分布均匀的方式展示 Mnist 数据集的 9 张图片。
def show_shuffle(data_path='./mnist/MNIST_Data/train'):
dataset = MnistDataset(data_path, shuffle=False)
dataset = dataset.shuffle(buffer_size=64)
show_visualize(dataset)
# 展示 shuffle
show_shuffle(data_path='./mnist/MNIST_Data/train')
show_shuffle(data_path='./mnist/MNIST_Data/train')
2.2 map
map
操作是数据预处理的关键操作,可以针对数据集指定列(column)添加数据变换(Transforms),将数据变换应用于该列数据的每个元素,并返回包含变换后元素的新数据集。mindspore.dataset.engine.datasets_vision.MnistDataset 支持的不同变换类型详见 数据变换 Transforms。以 Mnist 数据集作为示例,对数据集中的图片数据做缩放处理,将图像统一除以255,数据类型由 uint8 转为了 float32。
def show_map(data_path='./mnist/MNIST_Data/train'):
dataset = MnistDataset(data_path, shuffle=False)
image, label = next(dataset.create_tuple_iterator())
print("数据的列名")
print(dataset.create_dict_iterator().get_col_names())
print("数据类型调整前:")
print(image.shape, image.dtype)
print("数据类型调整后:")
dataset = dataset.map(vision.Rescale(1.0 / 255.0, 0), input_columns='image')
image, label = next(dataset.create_tuple_iterator())
print(image.shape, image.dtype)
对比 map 前后的数据,可以看到数据类型变化。这里需要格外说明的是 MindSpore 对数据的处理可以分成三类分别是图片(vision)、文本(text)、音频(audio),这里我们处理的是图片数据,因此调用了相关的 Version 方法。
2.3 batch
将数据集打包为固定大小的 batch
是在有限硬件资源下使用梯度下降进行模型优化的折中方法,可以保证梯度下降的随机性和优化计算量。
一般我们会设置一个固定的 batch size,将连续的数据分为若干批(batch)。以 Mnist 数据集作为示例,分别展示 batch 设置为 32 和 128 时,每次迭代获取的样例的维度。
def show_batch(data_path='./mnist/MNIST_Data/train'):
dataset = MnistDataset(data_path, shuffle=False)
dataset_32 = dataset.batch(batch_size=32)
image, label = next(dataset_32.create_tuple_iterator())
print("batch 为 32 时,每次迭代获取的样例:")
print(image.shape, image.dtype)
dataset = MnistDataset(data_path, shuffle=False)
dataset_128 = dataset.batch(batch_size=128)
image, label = next(dataset_128.create_tuple_iterator())
print("batch 为 128 时,每次迭代获取的样例:")
print(image.shape, image.dtype)
四、自定义数据集
mindspore.dataset
模块提供了一些常用的公开数据集和标准格式数据集的加载 API。对于 MindSpore 来说,暂不支持直接加载的数据集,可以构造自定义数据加载类或自定义数据集生成函数的方式来生成数据集,然后通过 GeneratorDataset
接口实现自定义方式的数据集加载。GeneratorDataset
支持通过可随机访问数据集对象、可迭代数据集对象和生成器(generator)构造自定义数据集,下面分别对其进行介绍。
4.1 可随机访问数据集
可随机访问数据集是实现了 __getitem__
和 __len__
方法的数据集,表示可以通过索引(键)直接访问对应位置的数据样本。例如,当使用 dataset[idx]
访问这样的数据集时,可以读取 dataset 内容中第idx 个样本或标签。
class RandomAccessDataset:
def __init__(self):
self._data = np.ones((5, 2))
self._label = np.zeros((5, 1))
def __getitem__(self, index):
return self._data[index], self._label[index]
def __len__(self):
return len(self._data)
def show_dataset():
loader = RandomAccessDataset()
dataset = GeneratorDataset(source=loader, column_names=["data", "label"])
for data in dataset:
print(data)
4.2 可迭代数据集
可迭代的数据集是实现了 __iter__
和 __next__
方法的数据集,表示可以通过迭代的方式逐步获取数据样本。这种类型的数据集特别适用于随机访问成本太高或者不可行的情况。例如,当使用iter(dataset)
的形式访问数据集时,可以读取从数据库、远程服务器返回的数据流。下面构造一个简单迭代器,并将其加载至GeneratorDataset
。
class IterableDataset:
def __init__(self):
self._data = np.ones((5, 2))
self._label = np.zeros((5, 1))
self._index = len(self._label) + 1
def __next__(self):
if next(self.index):
print(self.index)
return next(self.data), next(self.label)
def __iter__(self):
self.index = iter(self.breaker, 3)
self.data = iter(self._data)
self.label = iter(self._label)
return self
def breaker(self):
self._index -= 1
return self._index
def show_iter_dataset():
loader = IterableDataset()
dataset = GeneratorDataset(source=loader, column_names=["data", "label"])
for data in dataset:
print(data)