Mindspore 初学教程 - 4. 数据集 Dataset

news2024/9/19 10:40:24

数据是深度学习的基础,MindSpore 提供基于 Pipeline 的 数据引擎,通过数据集 数据集(Dataset) 和 数据变换(Transforms) 实现高效的数据预处理。其中 Dataset 是 Pipeline 的起始,用于加载原始数据。mindspore.dataset 提供了内置的文本、图像、音频等数据集加载接口,并提供了自定义数据集加载接口。

一、数据集加载

这里使用 Mnist 数据集作为样例,使用 mindspore.dataset 进行加载的方法。mindspore.dataset 提供的接口 仅支持解压后的数据文件,因此我们使用 download 库下载数据集并解压。

def download_dataset(url, path="./"):
    """
    通过 download 下载数据集
    :param url: 下载链接
    :param path: 数据集保存地址
    :return:
    """
    try:
        if os.path.exists(path):
            print("{} 文件以存在".format(path))
        else:
            path = download(url, path, kind="zip", replace=True)
            print("下载完成:{}".format(path))

    except RuntimeWarning as e:
        print("数据集下载失败:{}".format(e))

下载完数据集后,可以通过 MnistDataset 加载数据集,其数据类型为 mindspore.dataset.engine.datasets_vision.MnistDataset

二、数据集迭代

数据集加载后,一般以迭代方式获取数据,然后送入神经网络中进行训练。我们可以用 create_tuple_iterator 或 create_dict_iterator 接口创建数据迭代器,迭代访问数据。访问的数据类型默认为 Tensor;若设置 output_numpy=True,访问的数据类型为 Numpy。这里可以定义一个可视化函数,迭代 9 张图片进行展示。

def show_visualize(dataset):
    # 创建一个画布
    figure = plt.figure(figsize=(4, 4))
    cols, rows = 3, 3

    plt.subplots_adjust(wspace=0.5, hspace=0.5)

    for idx, (image, label) in enumerate(dataset.create_tuple_iterator()):
        figure.add_subplot(rows, cols, idx + 1)
        plt.title(label)
        plt.axis("off")
        plt.imshow(image.asnumpy().squeeze(), cmap="gray")
        if idx == cols * rows - 1:
            break

    plt.show()

使用 Mnist 数据集作为示例,顺序展示 Mnist 数据集的 9 张图片。

 # 展示数据集
 url_mnist = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
 download_dataset(url_mnist, './mnist')
 train_dataset = MnistDataset('./mnist/MNIST_Data/train', shuffle=False)
 show_visualize(train_dataset)

请添加图片描述

三、数据集常用操作

Pipeline 的设计理念使得数据集的常用操作采用 dataset = dataset.operation() 的异步执行方式,执行操作返回新的Dataset,此时不执行具体操作,而是在 Pipeline 中加入
节点,最终进行迭代时,并行执行整个 Pipeline。下面分别介绍几种常见的数据集操作。

2.1 shuffle

数据集随机 shuffle 可以消除数据排列造成的分布不均问题。shuffle 操作就是打乱数据集中样例的顺序,起到解决数据列分布不均的问题,如下图所示:
在这里插入图片描述
mindspore.dataset 提供的数据集在加载时可配置 shuffle=True,或调用 shuffle 方法来打乱数据集中样例的顺序。

# 方法1:加载时配置 `shuffle=True`
dataset = MnistDataset(data_path, shuffle=False)

# 方法2:调用 `shuffle` 方法
dataset = dataset.shuffle(buffer_size=64)

Mnist 数据集作为示例,以分布均匀的方式展示 Mnist 数据集的 9 张图片。

def show_shuffle(data_path='./mnist/MNIST_Data/train'):
    dataset = MnistDataset(data_path, shuffle=False)
    dataset = dataset.shuffle(buffer_size=64)
    show_visualize(dataset)

# 展示 shuffle
show_shuffle(data_path='./mnist/MNIST_Data/train')
show_shuffle(data_path='./mnist/MNIST_Data/train')

2.2 map

map 操作是数据预处理的关键操作,可以针对数据集指定列(column)添加数据变换(Transforms),将数据变换应用于该列数据的每个元素,并返回包含变换后元素的新数据集。mindspore.dataset.engine.datasets_vision.MnistDataset 支持的不同变换类型详见 数据变换 Transforms。以 Mnist 数据集作为示例,对数据集中的图片数据做缩放处理,将图像统一除以255,数据类型由 uint8 转为了 float32

def show_map(data_path='./mnist/MNIST_Data/train'):
    dataset = MnistDataset(data_path, shuffle=False)
    image, label = next(dataset.create_tuple_iterator())
    print("数据的列名")
    print(dataset.create_dict_iterator().get_col_names())

    print("数据类型调整前:")
    print(image.shape, image.dtype)

    print("数据类型调整后:")
    dataset = dataset.map(vision.Rescale(1.0 / 255.0, 0), input_columns='image')
    image, label = next(dataset.create_tuple_iterator())
    print(image.shape, image.dtype)

在这里插入图片描述
对比 map 前后的数据,可以看到数据类型变化。这里需要格外说明的是 MindSpore 对数据的处理可以分成三类分别是图片(vision)、文本(text)、音频(audio),这里我们处理的是图片数据,因此调用了相关的 Version 方法。

2.3 batch

将数据集打包为固定大小的 batch 是在有限硬件资源下使用梯度下降进行模型优化的折中方法,可以保证梯度下降的随机性和优化计算量。
op-batch

一般我们会设置一个固定的 batch size,将连续的数据分为若干批(batch)。以 Mnist 数据集作为示例,分别展示 batch 设置为 32 和 128 时,每次迭代获取的样例的维度。

def show_batch(data_path='./mnist/MNIST_Data/train'):
    dataset = MnistDataset(data_path, shuffle=False)
    dataset_32 = dataset.batch(batch_size=32)
    image, label = next(dataset_32.create_tuple_iterator())
    print("batch 为 32 时,每次迭代获取的样例:")
    print(image.shape, image.dtype)

    dataset = MnistDataset(data_path, shuffle=False)
    dataset_128 = dataset.batch(batch_size=128)
    image, label = next(dataset_128.create_tuple_iterator())
    print("batch 为 128 时,每次迭代获取的样例:")
    print(image.shape, image.dtype)

在这里插入图片描述

四、自定义数据集

mindspore.dataset 模块提供了一些常用的公开数据集和标准格式数据集的加载 API。对于 MindSpore 来说,暂不支持直接加载的数据集,可以构造自定义数据加载类或自定义数据集生成函数的方式来生成数据集,然后通过 GeneratorDataset 接口实现自定义方式的数据集加载。GeneratorDataset 支持通过可随机访问数据集对象、可迭代数据集对象和生成器(generator)构造自定义数据集,下面分别对其进行介绍。

4.1 可随机访问数据集

可随机访问数据集是实现了 __getitem____len__ 方法的数据集,表示可以通过索引(键)直接访问对应位置的数据样本。例如,当使用 dataset[idx] 访问这样的数据集时,可以读取 dataset 内容中第idx 个样本或标签。

class RandomAccessDataset:
    def __init__(self):
        self._data = np.ones((5, 2))
        self._label = np.zeros((5, 1))

    def __getitem__(self, index):
        return self._data[index], self._label[index]

    def __len__(self):
        return len(self._data)


def show_dataset():
    loader = RandomAccessDataset()

    dataset = GeneratorDataset(source=loader, column_names=["data", "label"])

    for data in dataset:
        print(data)

在这里插入图片描述

4.2 可迭代数据集

可迭代的数据集是实现了 __iter____next__ 方法的数据集,表示可以通过迭代的方式逐步获取数据样本。这种类型的数据集特别适用于随机访问成本太高或者不可行的情况。例如,当使用iter(dataset) 的形式访问数据集时,可以读取从数据库、远程服务器返回的数据流。下面构造一个简单迭代器,并将其加载至GeneratorDataset

class IterableDataset:
    def __init__(self):
        self._data = np.ones((5, 2))
        self._label = np.zeros((5, 1))
        self._index = len(self._label) + 1

    def __next__(self):
        if next(self.index):
            print(self.index)
            return next(self.data), next(self.label)

    def __iter__(self):
        self.index = iter(self.breaker, 3)
        self.data = iter(self._data)
        self.label = iter(self._label)
        return self

    def breaker(self):
        self._index -= 1
        return self._index



def show_iter_dataset():
    loader = IterableDataset()

    dataset = GeneratorDataset(source=loader, column_names=["data", "label"])

    for data in dataset:
        print(data)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2108900.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

# centos7 安装 mysql

centos7安装mysql 1、添加 mysql 官方 yum 存储库 wget https://dev.mysql.com/get/mysql80-community-release-el7-3.noarch.rpmrpm -ivh mysql80-community-release-el7-3.noarch.rpm2、使用Yum安装MySQL服务器: sudo yum install mysql-server3、启动MySQL服务…

Redis集群技术2——redis基础

Redis安装 Redis 的安装相对简单,无论是 Windows、Linux 还是 macOS 系统,都有相应的安装方法。以下是针对不同操作系统的 Redis 安装简述。 1. Linux 系统安装 Redis 在 Linux 系统中安装 Redis 通常有多种方式,这里以 Ubuntu 和 CentOS 为…

配置阿里云千问大模型--环境变量dashscope

1 开通百炼 首先要进入到阿里云平台,然后进入百炼平台。 2 获取API-KEY 进入之后再右上角可以查看到自己的API-KEY,这个东西就是需要配置在环境变量里的。 点击查看就可以获取 3 配置DASHSCOPE环境变量 如果使用dashscope来进行千问大模型的API对…

速度滞后补偿控制

这里介绍的速度滞后补偿控制和我们前面介绍的前馈控制有所区别,前馈控制的前提是能够获取位置参考指令的速度或加速度信号。在无法获取位置参考指令的上述性息的前提下,我们可以采用速度滞后补偿控制提高机电伺服控制系统动态跟踪精度。前馈控制的一些基…

2024社区版IDEA springboot日志输出颜色

IDEA版本:IntelliJ IDEA 2024.1.4 (Community Edition) 1、纯白色终端 2、彩色终端 3、配置过程 1、打开配置 2、选择启动类 3、点击修改选项,勾选虚拟机选项 4、在虚拟机选项框输入以下代码 -Dspring.output.ansi.enabledALWAYS5、应用确定&#xff0…

NLP从零开始------18.文本中阶处理之序列到序列模型(3)

4.3 其他解码问题和解码技巧 贪心解码和束解码只是最基础的解码方法,其解码结果会出现许多问题。这里主要介绍3种常见问题,并简单介绍解决方案。 4.3.1 重复性问题 有时我们会发现序列到序列模型不断重复的输出同一个词。一个解决方案是解码时在所预测的…

GateWay三大案例组件

一、局部过滤器接口耗时(LogTime) 命名规则:以GatewayFilterFactory结尾编写接口耗时过滤器 Slf4j Component public class LogTimeGatewayFilterFactory extends AbstractNameValueGatewayFilterFactory {private static long timeSpan 0…

ruoyi-vue-pro快速修改的包名和选配功能板块

使用KIT进行构建 KIT是一个专门构建框架的网站,ruoyi-vue-pro也发布至KIT了,所以我们可以通过KIT快速的选配功能和修改报名等操作。 构建地址:http://www.goldpankit.com/space/service/install?space%E8%8A%8B%E9%81%93%E6%BA%90%E7%A0%8…

AI建模——AI生成3D内容算法产品介绍与模型免费下载

说明: 记录AI文生3D模型、图生3D模型的相关产品;记录其性能、功能、收费与免费方法 0.AI建模产品 Rodin MeshAnything Meshy 生成效果比较: Rodin效果最好、Meshy其次 1.Rodin 官网:gHyperHuman 支持:文生模型、…

TextIn ParseX:助力开发者解析版面元素信息

TextIn ParseX通用文档解析是一款大模型友好的解析工具,支持将pdf文档、jpg、img图像等文件快速转换为markdown格式,支持各类表格、公式解析,帮助大语言模型的数据清洗和文档问答任务。 产品特点 支持多种扫描内容:能良好处理各类…

[数据集][目标检测]西红柿缺陷检测数据集VOC+YOLO格式17318张3类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):17318 标注数量(xml文件个数):17318 标注数量(txt文件个数):17318 标…

ORA-24067: exceeded maximum number of subscribers for queue ADMIN.SMS_MT_QUEUE

临时处理办法: delete from aq$_ss_MT_tab_D; delete from aq$_ss_MT_tab_g; delete from aq$_ss_MT_tab_h; delete from aq$_ss_MT_tab_i; delete from aq$_ss_MT_tab_p; delete from aq$_ss_MT_tab_s; delete from aq$_ss_MT_tab_t; commit; 根本处理办法&#x…

IIS 反向代理模块: URL Rewrite 和 Application Request Routing (ARR)

需要设置iis反向代理的场景其实挺多的。例如websocket、Server Sent Events(SSE) 都需要反向代理。 对于需要临时放公网访问的应用,直接运行127.0.0.1的开发环境,然后通过反向代理访问127.0.0.1就可以了,省去麻烦的iis设置。 IIS 实现反向代…

学习记录:js算法(二十五):合并两个有序链表

文章目录 合并两个有序链表我的思路网上思路 总结 合并两个有序链表 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 图一 示例 1:(如图一) 输入:l1 [1,2,4], l2 [1,3,4] …

解决职业摔跤手分类问题的算法与实现

解决职业摔跤手分类问题的算法与实现 引言问题定义算法设计二分图判定算法步骤伪代码C语言实现引言 在职业摔跤界,摔跤手通常被分为“娃娃脸”(“好人”)型和“高跟鞋”(“坏人”)型。在任意一对摔跤手之间,都有可能存在竞争关系。本文的目标是设计一个算法,用于判断是…

优化采样参数提升大语言模型响应质量:深入分析温度、top_p、top_k和min_p的随机解码策略

当向大语言模型(LLM)提出查询时,模型会为其词汇表中的每个可能标记输出概率值。从这个概率分布中采样一个标记后,我们可以将该标记附加到输入提示中,使LLM能够继续输出下一个标记的概率。这个采样过程可以通过诸如 temperature和 top_p等参…

openSUSE变更默认编译器

Debian很稳定,但是必须要添加unstable源才能安装一些需要更新的软件,比如说稳定版的firefox是ESR版的,必须要从unstable源才能安装新版。但是unstable源是把所有的软件包都放在里面,操作过程中一旦不小心把核心组件更新到unstable…

使用 RabbitMQ 和 Go 构建异步订单处理系统

使用 RabbitMQ 和 Go 构建异步订单处理系统 我们可以通过构建一个订单处理系统来演示如何使用消息队列(MQ)实现异步任务处理。这个项目将使用 RabbitMQ 作为消息队列,并使用 Go 语言来实现。以下是项目的详细教程和相关环境配置。 项目描述…

uniapp+vue3实现双通道透明MP4播放支持小程序和h5

双通道透明MP4视频播放的截图 以下是合成后结果,二个合并在一起进行播放 下载资源,打开运行直接使用看到效果 https://download.csdn.net/download/qq_40039641/89715780

[iBOT] Image BERT Pre-Training with Online Tokenizer

1、目的 探索visual tokenizer编码下的MIM(Masked Image Modeling) 2、方法 iBOT(image BERT pre-training with Online Tokenizer) 1)knowledge distillation(KD) distill knowledge from the…