机器学习洞察 | 一文带你“讲透” JAX

news2025/1/10 21:40:31

上篇文章中,我们详细分享了 JAX 这一新兴的机器学习模型的发展和优势,本文我们将通过 Amazon SageMaker 示例展示如何部署并使用 JAX。

JAX 的工作机制


JAX 的完整工作机制可以用下面这幅图详细解释:

图片来源:“Intro to JAX” video on YouTube by Jake VanderPlas, Tech leader from JAX team

在图片左侧是开发者自己编写的 Python 代码,JAX 会追踪并变换成 JAX IR 的中间表示,并按照 Python 代码,通过 jax.jit 将其编译成 HLO (High Level Optimized) 代码,代表高级的优化代码,提供给 XLA 进行读取。XLA 在获取编译的 HLO 代码之后,会分配到对应的 CPU、GPU、TPU 或者 ASIC。

对于开发者来说,只需完成您的 Python 代码即可实现这一流程。开发者可以将 JAX 转换视为首先对 Python 函数进行跟踪专门化,然后将其转换为一个小而行为良好的中间形式,然后使用特定于转换的解释规则进行解释。

为什么 JAX 可以在如此小的软件包中提供如此强大的功能呢?

首先,它从熟悉且灵活的编程接口(使用 NumPy 的 Python)开始,并且使用实际的 Python 解释器来完成大部分繁重的工作;其次,它将计算的本质提炼成一个静态具有高阶功能的类型表达式语言,即 Jaxpr 语言。

JAX 应用场景


自 2019 年 JAX 出现之后,使用它的开发者逐年增多。在 2022 年更是达到了非常火热的状态,甚至有人认为它有可能会取代其他的机器学习框架。

支持 JAX 生态的应用场景包括:

  • 深度学习 (Deep Learning):JAX 在深度学习场景下应用很广泛,很多团队基于 JAX 开发了更加高级的 API 支持不同的场景,方便开发者使用。

  • 科学模拟 (Scientific Simulation):JAX 的出现不仅仅是针对于深度学习,其实也拥有很多其他的使命,如科学模拟。

  • 机器人与控制系统 (Robotics and Control Systems)

  • 概率编程 (Probabilistic Programming)

训练和部署深度学习模型


我们用下面这个具体例子展示使用 JAX 来和 Amazon SageMaker 训练和部署深度学习模型,会用到 Amazon SageMaker 的 BYOC 这种模式。

如上图所示,在这个 Amazon SageMaker 的示例中提供了 JAX 的代码示例:https://sagemaker-examples.readthedocs.io/en/latest/advanced_functionality/jax_bring_your_own/train_deploy_jax.html

在 Amazon SageMaker 上基于 JAX 的框架可使用自定义的容器来训练神经网络。

如图的 Amazon SageMaker Examples 提供的 JAX 示例中,我们使用自定义容器在 SageMaker 上 基于 JAX 框架或库训练神经网络。这在单个容器上是可能的,因为我们使用了 sagemaker-training-toolkit,它允许你在自己的自定义容器中使用脚本模式。自定义容器可以使用内置的 SageMaker 训练作业功能,如竞价训练和超参数调整。

训练模型后,您可以将经过训练的模型部署到托管端点。如前所述,SageMaker 具有推理容器,这些容器已针对亚马逊云科技的硬件和常用深度学习框架进行了优化。其中一项优化是针对 TensorFlow 框架的优化。由于 JAX 支持将模型导出为 TensorFlow SavedModel 格式,因此我们使用该功能来展示如何在优化的 SageMaker TensorFlow 推理端点上部署经过训练的模型。

整个训练和部署主要分为以下五个步骤:

  1. 创建 Docker 镜像并将其推送到 Amazon ECR。

  1. 使用 SageMaker 开发工具包传教自定义框架估算器,以便将模型输出归类为 TensorFlowModel。

  1. 代码仓库中有训练估算器的脚本。

  1. 使用 GPU 上的 SageMaker 训练作业来训练每个模型。

  1. 将模型部署到完全托管的终端节点。

下面我们来看看详细步骤:

  1. 创建 Docker 镜像并将其推送到 Amazon ECR。

*创建使用 JAX 训练模型容器的 Dockerfile

Docker 映像是在 NVIDIA 提供的支持 CUDA 的容器之上构建的。为了确保作为 JAX 中功能基础的 jaxlibpackage 支持 CUDA,请从 jax_releases 存储库中下载 jaxlib 软件包。

  • AX releases

https://storage.googleapis.com/jax-releases/jax_releases.html

这里需要注意的是:为了确保作为 JAX 中的功能基础的 JAX library package 能够支持 cuda,建议在去做这个创建自定义容器时,去看一下目前 JAX release 这个存储库中,它下载的这个 JAX library 包的版本号或者相关注意事项等等。

2、使用 SageMaker 开发工具包创建自定义框架估算器,以便将模型输出归类为 TensorFlowModel。

创建基本 SageMaker 框架估算器的子类,将估算器的模型类型指定为 TensorFlow 模型。为此,我们指定了一个自定义 create_model 方法,该方法使用现有的 TensorFlowModel 类来启动推理容器。

3、通过代码仓库训练估算器的脚本。

您可以通过传统的 SageMaker Python SDK 工作流通过模型执行训练、部署和运行推理。我们确保导入并初始化自定义框架估算器的代码片段中定义的 JaxEstimator,然后运行标准的 .fit () 和 .deploy () 调用。

对于 JAX ,可以调用 jax2tf 函数来执行相同的操作。代码在存储库中可用。设置正确的路径 /opt/ml/model/1 非常重要,这是 SageMaker wrapper(封装器) 假定模型已存储的地方。、

前面提到的 JAX 和 TF 的互操作性,目前 JAX 是通过 JAX to TF 这样的一个软件包,来为 JAX 和 TF 的互操作性提供支持,那 jax2tf.convert 是用于在 TensorFlow 的上下文中使用 JAX 函数,那 jax2tf.call_tf 是用于在 JAX 的上下文中使用的 TensorFlow 函数互操作来完成的。

4、使用 GPU 上的 SageMaker 训练作业来训练每个模型。

  1. 将模型部署到完全托管的终端节点。

vanilla_jax_predictor = vanilla_jax_estimator.deploy(
    initial_instance_count=1, instance_type="ml.m4.xlarge"
)
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
def test_image(predictor, test_images, test_labels, image_number):
    np_img 
= np.expand_dims(np.expand_dims(test_images[image_number], axis=
-1
), axis=
0
)

    result = predictor.predict(np_img)
    pred_y = np.argmax(result["predictions"])

    print("True Label:", test_labels[image_number])
    print("Predicted Label:", pred_y)
    plt.imshow(test_images[image_number]

*部署和准备输入的测试图像

*进行推理

有关在 Amazon SageMaker 上使用 JAX 训练和部署深度学习模型的详细过程和代码,请参考亚马逊云科技官方博客

如图所示,上面的两张图是一个部署模型的例子,下面的图是进行推理的例子。由于我们的 Framework Estimator 知道模型将使用 TensorFlowModel 提供服务,因此部署这些端点只是对 estimator.deploy () 方法做调用即可。

参考资料

  • Training and Deploying ML Models using JAX on SageMaker

  • Train and deploy deep learning models using JAX with Amazon SageMaker

  • AX core from scratch

  • Building JAX from source

JAX 是一种越来越流行的库,它支持原生 Python 或 NumPy 函数的可组合函数转换,可用于高性能数值计算和机器学习研究。JAX 提供了编写 NumPy 程序的能力,这些程序可以使用 GPU/TPU 自动差分和加速,从而形成了更灵活的框架来支持现代深度学习架构。在这两篇文章中我们讨论了有关 JAX 的一些主题,希望对您用使用 JAX 这一框架进行深度学习研究有所帮助。

往期推荐


  • 机器学习洞察 | JAX,机器学习领域的“新面孔”

  • 机器学习洞察 | 降本增效,无服务器推理是怎么做到的?

  • 机器学习洞察 | 分布式训练让机器学习更加快速准确

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/352060.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python 的selenium自动操控浏览器教程(2)

人生苦短,我用py 文章目录人生苦短,我用py关于部分网页无法找到元素的问题1方案1方案2关于部分网页无法找到元素的问题2解决方案被网站检查出来我们使用了selenium了怎么办?如何实现前进后退当使用py删除文件时报禁止访问怎么办怎么使用py实现…

【服务器数据恢复】存储之间迁移数据时数据损坏的数据恢复案例

服务器数据恢复环境&故障: 一台某品牌的存储设备,Windows操作系统。由于业务需求,需要把这台存储设备中的数据迁移到另外一台存储设备中,在迁移数据过程中突然无法读取数据,管理界面报错。管理员查看服务器内的数据…

【Github的注册】

目录 一、打开官网 二、注册 1、点击右上角的 sign up 2、依次输入邮箱、密码、用户名 3、验证真人,create account,验证码 4、点击“个人“,“学生”,进入另一个页面后滚动鼠标直接点击continue,进入另一个页面后…

Yakit Web Fuzzer 终极能力强化:热加载 Fuzz

Background 在 HTB:BountyHunter 中,我们发现 Web Fuzzer 在使用中可以 “更强”,我们需要编写 Yak 脚本的事情,如果可以经过某些 Web Fuzzer 的优化,可以达到同样的效果。 在一个标签中,我们实现{{base6…

智慧工地火焰烟火识别检测 yolo

智慧工地火焰烟火识别检测算法通过yolo网络模型深度学习技术,智慧工地火焰烟火识别检测算法对现场浓烟和烟火情况,立即抓拍告警并进行存档。YOLO 的核心思想就是把目标检测转变成一个回归问题,利用整张图作为网络的输入,仅仅经过一…

图解LeetCode——2335. 装满杯子需要的最短总时长

一、题目 现有一台饮水机,可以制备冷水、温水和热水。每秒钟,可以装满 2 杯 不同 类型的水或者 1 杯任意类型的水。 给你一个下标从 0 开始、长度为 3 的整数数组 amount ,其中 amount[0]、amount[1] 和 amount[2] 分别表示需要装满冷水、温…

【THREE.JS】网页中的炫酷3D

web3d一、前言粒子特效二维漫画可视化后期处理二、项目使用流程2.1 项目结构2.2 基本使用2.3 项目模板2.4 技术栈三、基础动画3.1 THREE.Clock3.2 GASP四、照相机8.1 正交相机8.2 透视相机4.3 相机控制器五、画布和全屏六、几何体七、Debug UI八、纹理贴图8.1 mipmapping8.2 放…

关于IcmpSendEcho2的使用和回调问题

由于我的需求是短时间内ping多台机子,所以需要异步执行,微软提供的例子是同步方式的,根据微软官方提供的icmpSendEcho2 函数的信息 ,我需要定义一个空的宏PIO_APC_ROUTINE_DEFINED ,定义完之后,编译又出现…

Java基础:回调函数

因为在看Android代码的时候发现了许多关于回调函数的知识, 所以去了解了一下. 对于我来说不太好懂, 因为我觉得看的那些博文的讲法对我来说很绕, 所以我在理解了之后想写一篇关于回调函数的博文来给和我一样理解能力稍差的人一点帮助. 回调函数的作用其实就是将需要这个功能的调…

【JavaGuide面试总结】Redis篇·中

【JavaGuide面试总结】Redis篇中1.Redis 单线程模型了解吗?2.Redis6.0 之后为何引入了多线程?3.Redis 是如何判断数据是否过期的呢?4.过期的数据的删除策略了解么?5.Redis 内存淘汰机制了解么?6.什么是 RDB 持久化&…

【Python+Appium】自动化测试框架

目录:导读 appium简介 设计思路 测试框架设计 测试框架目录结构 测试框架思维导图 测试结果展示 appium简介 Appium 是一个开源的、跨平台的测试框架,可以用来测试 Native App、混合应用、移动 Web 应用(H5 应用)等&#xf…

Spring之依赖注入源码解析

Spring之依赖注入源码解析 依赖注入原理流程图: https://www.processon.com/view/link/5f899fa5f346fb06e1d8f570 Spring 中有几种依赖注入的方式? 首先分为两种: 1、手动注入 2、自动注入 1、手动注入 在 XML 中定义 Bean 时&#xff0c…

Gartner 再度预测2023低代码趋势,真的会赚钱吗?

2023年,从业者对低代码的发展充满了想象,人们认为,未来低代码的商业价值不可估量。 此话并非空穴来风。据Gartner的最新报告显示,到2023年,超过70%的企业将采用低代码作为他们发展战略的关键目标之一;到202…

训练自己的中文word2vec(词向量)--skip-gram方法

训练自己的中文word2vec(词向量)–skip-gram方法 什么是词向量 ​ 将单词映射/嵌入(Embedding)到一个新的空间,形成词向量,以此来表示词的语义信息,在这个新的空间中,语义相同的单…

双塔多目标MVKE

MVKE:Mixture of Virtual-Kernel Experts for Multi-Objective User Profile Modeling MVKE论文中是给用户打tag标记,构建用户画像。使用的也是经典的双塔模型,另外在双塔的基础上面叠加了ctr和cvr的多个目标。但是论文最大的创新点是在用户…

基于龙芯 CPU 的气井控制器的软件设计(三)

4.1 系统软件的总体设计 基于龙芯 CPU 的气井控制器的设计需要开发测试硬件模块的测试软件,主要对 RTC 模块、存储器模块、4G 通信、以太网通信、UART 串口以及 AI 模块进行了驱动程序和 应用程序设计。将各个模块设计为不同的任务,龙芯 RTU 软件设计流程…

Redis 监听过期的key(KeyExpirationEventMessageListener)

目录一、简介二、maven依赖三、编码实现3.1、application.properties3.2、Redis配置类3.3、监听器3.4、服务类3.5、工具类四、测试4.1、测试类4.2、单实例4.3、多实例结语一、简介 本文今天主要是讲Redis中对过期key的监听,可能很多小伙伴不会,或者使用会…

day15_常用类

今日内容 上课同步视频:CuteN饕餮的个人空间_哔哩哔哩_bilibili 同步笔记沐沐霸的博客_CSDN博客-Java2301 零、 复习昨日 一、作业 二、代码块[了解] 三、API 四、Object 五、包装类 六、数学和随机 零、 复习昨日 抽象接口修饰符abstractinterface是不是类类接口属性正常属性没…

Leetcode(每日一题)——1139. 最大的以 1 为边界的正方形

摘要 1139. 最大的以 1 为边界的正方形 一、以1为边界的最大正方形 1.1 动态规划 第530题需要正方形所有网格中的数字都是1,只要搞懂动态规划的原理,代码就非常简洁。而这题只要正方形4条边的网格都是1即可,中间是什么数字不用管。 这题…

Hive的安装与配置

一、配置Hadoop环境先看看伪分布式下的集群环境有没有错误的情况:输入命令:start-all.sh jps查看伪分布式的所有进程是否完善二、解压并配置HiveHive压缩包→ https://pan.baidu.com/s/1eOF_ICZV8rV-CEh3nX-7Xw 提取码: m31e 复制这段内容后打开百度网盘…