使用Ray创建高效的深度学习数据管道

news2025/1/23 14:59:02

大家好,用于训练深度学习模型的GPU功能强大但价格昂贵。为了有效利用GPU,开发者需要一个高效的数据管道,以便在GPU准备好计算下一个训练步骤时尽快将数据传输到GPU,使用Ray可以大大提高数据管道的效率。

1.训练数据管道的结构

首先考虑下面的模型训练伪代码:

for step in range(num_steps):
  sample, target = next(dataset) # 步骤1
  train_step(sample, target) # 步骤2

在步骤1中,获取下一个小批量的样本和标签。在步骤2中,它们被传递给train_step函数,该函数会将它们复制到GPU上,执行前向传递和反向传递以计算损失和梯度,并更新优化器的权重。

当数据集太大无法放入内存时,步骤1将从磁盘或网络中获取下一个小批量数据。此外步骤1还涉及一定量的预处理,输入数据必须转换为数字张量或张量集合,然后再馈送给模型。在某些情况下,在将它们传递给模型之前,张量上还会应用其他转换,例如归一化、绕轴旋转等。

如果工作流程是严格按顺序执行的,即先执行步骤1,然后再执行步骤2,那么模型将始终需要等待下一批数据的输入、输出和预处理操作。GPU将无法得到有效利用,它将在加载下一个小批量数据时处于空闲状态。

为了解决这个问题,可以将数据管道视为生产者——消费者的问题。数据管道生成小批量数据并写入有界缓冲区。模型/GPU从缓冲区中消费小批量数据,执行前向/反向计算并更新模型权重。如果数据管道能够以模型/GPU消费的速度快速生成小批量数据,那么训练过程将会非常高效。

图片

2.Tensorflow tf.data API

Tensorflow tf.data API提供了一组丰富的功能,可用于高效创建数据管道,使用后台线程获取小批量数据,使模型无需等待。仅仅预先获取数据还不够,如果生成小批量数据的速度比GPU消费数据的速度慢,那么就需要使用并行化来加快数据的读取和转换。为此,Tensorflow提供了交错功能以利用多个线程并行读取数据,以及并行映射功能使用多个线程对小批量数据进行转换。

由于这些API基于多线程,因此可能会受到Python全局解释器锁(GIL)的限制。Python GIL限制了Python解释器一次只能运行单个线程的字节码。如果在管道中使用纯TensorFlow代码,通常不会受到这种限制,因为TensorFlow核心执行引擎在GIL的范围之外工作。但是,如果使用的第三方库没有发布GIL或者使用Python进行大量计算,那么依赖多线程来并行化管道就不可行。

3.使用多进程并行化数据管道

考虑以下生成器函数,该函数模拟加载和执行一些计算以生成小批量数据样本和标签。

def data_generator():
  for _ in range(10):
    # 模拟获取
    # 从磁盘/网络
    time.sleep(0.5)
    # 模拟计算
    for _ in range(10000):
      pass
    yield (
        np.random.random((4, 1000000, 3)).astype(np.float32), 
        np.random.random((4, 1)).astype(np.float32)
    )

接下来,在虚拟的训练管道中使用该生成器,并测量生成小批量数据所花费的平均时间。

generator_dataset = tf.data.Dataset.from_generator(
    data_generator,
    output_types=(tf.float64, tf.float64),
    output_shapes=((4, 1000000, 3), (4, 1))
).prefetch(tf.data.experimental.AUTOTUNE)

st = time.perf_counter()
times = []
for _ in generator_dataset:
    en = time.perf_counter()
    times.append(en - st)
    # 模拟训练步骤
    time.sleep(0.1)
    st = time.perf_counter()

print(np.mean(times))

据观察,平均耗时约为0.57秒(在配备Intel Core i7处理器的Mac笔记本电脑上测量)。如果这是一个真实的训练循环,GPU的利用率将相当低,它只需花费0.1秒进行计算,然后闲置0.57秒等待下一个批次数据。

为了加快数据加载速度,可以使用多进程生成器。

from multiprocessing import Queue, cpu_count, Process
def mp_data_generator():

    def producer(q):
        for _ in range(10):
            # 模拟获取
            # 从磁盘/网络
            time.sleep(0.5)
            # 模拟计算
            for _ in range(10000000):
                pass
            q.put((
                np.random.random((4, 1000000, 3)).astype(np.float32),
                np.random.random((4, 1)).astype(np.float32)
            ))
        q.put("DONE")

    queue = Queue(cpu_count()*2)

    num_parallel_processes = cpu_count()
    producers = []
    for _ in range(num_parallel_processes):
        p = Process(target=producer, args=(queue,))
        p.start()
        producers.append(p)
    done_counts = 0
    while done_counts < num_parallel_processes:
        msg = queue.get()
        if msg == "DONE":
            done_counts += 1
        else:
            yield msg
    queue.join()

测量等待下一个小批次数据所花费的时间,得到的平均时间为0.08秒,速度提高了近7倍,但理想情况下,希望这个时间接近0。

如果进行分析,可以发现相当多的时间都花在了准备数据的反序列化上。在多进程生成器中,生产者进程会返回大型NumPy数组,这些数组需要进行准备,然后在主进程中进行反序列化。

4.使用Ray并行化数据管道

Ray是一个用于在Python中运行分布式计算的框架,它带有一个共享内存对象存储区,可在不同进程间高效地传输对象。在不进行任何序列化和反序列化的情况下,对象存储区中的Numpy数组可在同一节点上的worker之间共享。Ray还可以轻松实现数据加载在多台机器上的扩展,并使用Apache Arrow高效地序列化和反序列化大型数组。

Ray带有一个实用函数from_iterators,可以创建并行迭代器,开发者可以用它包装data_generator生成器函数。

import ray
def ray_generator():
    num_parallel_processes = cpu_count()
    return ray.util.iter.from_iterators(
        [data_generator]*num_parallel_processes
    ).gather_async()

使用ray_generator,测量等待下一个小批量数据所花费的时间为0.02秒,比使用多进程处理的速度提高了4倍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1268378.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

优化邮件群发效果的方法与策略

怎样优化邮件群发效果&#xff1f;这是许多企业在进行邮件营销时常常被问到的问题。邮件营销是一种高效且经济实惠的市场推广方式&#xff0c;但如何使邮件真正引起接收者的兴趣并产生预期的效果并不容易。好的营销效果可以带来高回报、高收益率&#xff0c;但是怎么提升群发效…

工会排队奖励模式:创新营销策略,实现共赢局面

在当今的商业环境中&#xff0c;创新营销策略的重要性日益凸显。工会排队奖励模式作为一种新型的营销策略&#xff0c;旨在通过结合线上和线下消费&#xff0c;激励消费者购买产品或服务&#xff0c;并获得返现奖励。这种模式通过将消费者的支出和商家的抽成资金纳入奖金池&…

CSS3样式详解之圆角、阴影及变形

目录 前言一、圆角样式&#xff08;border-radius&#xff09;二、元素阴影&#xff08;box-shadow&#xff09;三、过渡动画样式&#xff08;transition&#xff09;1. transition-property(用于设置属性名称)2. transition-duration&#xff08;设置时间&#xff09;3. trans…

7、信息收集(2)

文章目录 一、指纹识别1、Nmap工具2、Wafw00f工具 二、使用Maltego进行情报收集 一、指纹识别 1、Nmap工具 命令一&#xff1a;nmap -sS -sV <ip>&#xff0c;使用TCP SYN的方式&#xff0c;扫描目标主机上常规端口运行的服务版本。 -sS&#xff1a;指定使用TCP SYN的方…

注解(概念、分类、自定义注解)

注解基本概念 注解(元数据)为我们在代码中添加信息提供一种形式化的方法&#xff0c;我们可以在某个时刻非常方便的使用这些数据。将的通俗一点&#xff0c;就是为这个方法增加的说明或功能。 作用&#xff1a; 编写文档&#xff1a;通过代码里标识的注解生成文档【生成doc文…

常使用的定时任务

常使用的定时任务 一、 linux自带的定时任务 1、crontab 有这样一个需求&#xff1a;我们使用Java写一个工具jar包在系统空闲的时候去采集已经部署在Linux系统上的项目的一 些数据&#xff0c;可以使用 linux 系统的 crontab。 运行crontab -e&#xff0c;可以编辑定时器&…

git stash save untracked not staged

git stash save untracked not staged 如图 解决方案&#xff1a; git stash save "tag标记信息" --include-untracked或者&#xff1a; git stash save -u "tag标记信息" git stash clear清空本地暂存代码_zhangphil的博客-CSDN博客文章浏览阅读486次。…

合阔智云:实现API无代码开发,连接ERP系统和CRM系统提高运营效率

概述 合阔智云&#xff0c;一家成立于2011年的科技公司&#xff0c;核心业务是提供云原生和移动化设计的新一代全渠道“云端一体”履约中台和去中心化模式智能门店供应链业务中台。他们的系统可以无需API开发即可实现电商系统和客服系统的连接和集成&#xff0c;大大提高了企业…

通过python脚本上传本地/远程服务器文件到minio

前言 将文件上传到MinIO对象存储后&#xff0c;MinIO会将文件存储为对象(.meta文件)&#xff0c;并为每个对象生成相应的元数据。元数据是描述对象的属性和信息的数据。 通常&#xff0c;元数据包括对象的名称、大小、创建日期等。 在MinIO中&#xff0c;对象的元数据存储在独立…

振南技术干货集:各大平台串口调试软件大赏(1)

注解目录 &#xff08;串口的重要性不言而喻。为什么很多平台把串口称为 tty&#xff0c;比如 Linux、MacOS 等等&#xff0c;振南告诉你。&#xff09; 1、各平台上的串口调试软件 1.1Windows 1.1.1 STCISP &#xff08;感谢 STC 姚老板设计出 STCISP 这个软件。&#xf…

FPGA程序执行相关知识点

1.目前&#xff0c;大多数FPGA芯片是基于 SRAM 的结构的&#xff0c; 而 SRAM 单元中的数据掉电就会丢失&#xff0c;因此系统上电后&#xff0c;必须要由配置电路将正确的配置数据加载到 SRAM 中&#xff0c;此后 FPGA 才能够正常的运行。 常见的配置芯片有EPCS 芯片 &#x…

基于傅里叶变换的运动模糊图像恢复算法matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 4.1、傅里叶变换与图像恢复 4.2、基于傅里叶变换的运动模糊图像恢复算法原理 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 matlab2022a 3.部分核心程序 %获取角度 img…

[密码学]DES

先声明两个基本概念 代换&#xff08;substitution&#xff09;,用别的元素代替当前元素。des的s-box遵循这一设计。 abc-->def 置换&#xff08;permutation&#xff09;&#xff0c;只改变元素的排列顺序。des的p-box遵循这一设计。 abc-->bac DES最核心的算法就是…

从零搭建AlibabaCloud微服务项目

1&#xff0c;创建maven项目工程如下 equipment-admin 后台equipment-applet 前台或小程序端或app、h5equipment-common 公共模块equipment-gateway 网关equipment-mapper mapper层操作数据库equipment-model 实体类对应数据库表 2&#xff0c;在父pom文件引入依赖 <proper…

找不到 sun.misc.BASE64Decoder ,sun.misc.BASE64Encoder 类

找不到 sun.misc.BASE64Decoder &#xff0c;sun.misc.BASE64Encoder 类 1. 现象 idea 引用报错 找不到对应的包 import sun.misc.BASE64Decoder; import sun.misc.BASE64Encoder;2. 原因 因为sun.misc.BASE64Decoder和sun.misc.BASE64Encoder是Java的内部API&#xff0c;通…

二叉树OJ题之二

今天我们一起来看一道判断一棵树是否为对称二叉树的题&#xff0c;力扣101题&#xff0c; https://leetcode.cn/problems/symmetric-tree/ 我们首先先来分析这道题&#xff0c;要判断这道题是否对称&#xff0c;我们首先需要判断的是这颗树根节点的左右子树是否对称&#xff0…

qt-C++笔记之主线程中使用异步逻辑来处理ROS事件循环和Qt事件循环解决相互阻塞的问题

qt-C笔记之主线程中使用异步逻辑来处理ROS事件循环和异步循环解决相互阻塞的问题 code review! 文章目录 qt-C笔记之主线程中使用异步逻辑来处理ROS事件循环和异步循环解决相互阻塞的问题1.Qt的app.exec()详解2.ros::spin()详解3.ros::AsyncSpinner详解4.主线程中结合使用的示…

图面试专题

一、概念 和二叉树的区别&#xff1a;图可能有环 常见概念 顶点&#xff08;Vertex&#xff09;&#xff1a; 图中的节点或点。边&#xff08;Edge&#xff09;&#xff1a; 顶点之间的连接线&#xff0c;描述节点之间的关系。有向图&#xff08;Directed Graph&#xff09;&…

05、基于梯度下降的协同过滤算法

05、基于梯度下降的协同过滤算法理论与实践Python 开始学习机器学习啦&#xff0c;已经把吴恩达的课全部刷完了&#xff0c;现在开始熟悉一下复现代码。对这个手写数字实部比较感兴趣&#xff0c;作为入门的素材非常合适。 协同过滤算法是一种常用的推荐算法&#xff0c;基于…

vue3+ts 实现时间间隔选择器

需求背景解决效果视频效果balancedTimeElement.vue 需求背景 实现一个分片的时间间隔选择器&#xff0c;需要把显示时间段显示成图表&#xff0c;涉及一下集中数据转换 [“02:30-05:30”,“07:30-10:30”,“14:30-17:30”]‘[(2,5),(7,10),(14,17)]’[4, 5, 6, 7, 8, 9, 10, …