在上篇文章中,我们详细分享了 JAX 这一新兴的机器学习模型的发展和优势,本文我们将通过 Amazon SageMaker 示例展示如何部署并使用 JAX。
JAX 的工作机制
JAX 的完整工作机制可以用下面这幅图详细解释:
图片来源:“Intro to JAX” video on YouTube by Jake VanderPlas, Tech leader from JAX team
在图片左侧是开发者自己编写的 Python 代码,JAX 会追踪并变换成 JAX IR 的中间表示,并按照 Python 代码,通过 jax.jit 将其编译成 HLO (High Level Optimized) 代码,代表高级的优化代码,提供给 XLA 进行读取。XLA 在获取编译的 HLO 代码之后,会分配到对应的 CPU、GPU、TPU 或者 ASIC。
对于开发者来说,只需完成您的 Python 代码即可实现这一流程。开发者可以将 JAX 转换视为首先对 Python 函数进行跟踪专门化,然后将其转换为一个小而行为良好的中间形式,然后使用特定于转换的解释规则进行解释。
为什么 JAX 可以在如此小的软件包中提供如此强大的功能呢?
首先,它从熟悉且灵活的编程接口(使用 NumPy 的 Python)开始,并且使用实际的 Python 解释器来完成大部分繁重的工作;其次,它将计算的本质提炼成一个静态具有高阶功能的类型表达式语言,即 Jaxpr 语言。
JAX 应用场景
自 2019 年 JAX 出现之后,使用它的开发者逐年增多。在 2022 年更是达到了非常火热的状态,甚至有人认为它有可能会取代其他的机器学习框架。
支持 JAX 生态的应用场景包括:
深度学习 (Deep Learning):JAX 在深度学习场景下应用很广泛,很多团队基于 JAX 开发了更加高级的 API 支持不同的场景,方便开发者使用。
科学模拟 (Scientific Simulation):JAX 的出现不仅仅是针对于深度学习,其实也拥有很多其他的使命,如科学模拟。
机器人与控制系统 (Robotics and Control Systems)
概率编程 (Probabilistic Programming)
训练和部署深度学习模型
我们用下面这个具体例子展示使用 JAX 来和 Amazon SageMaker 训练和部署深度学习模型,会用到 Amazon SageMaker 的 BYOC 这种模式。
如上图所示,在这个 Amazon SageMaker 的示例中提供了 JAX 的代码示例:https://sagemaker-examples.readthedocs.io/en/latest/advanced_functionality/jax_bring_your_own/train_deploy_jax.html
在 Amazon SageMaker 上基于 JAX 的框架可使用自定义的容器来训练神经网络。
如图的 Amazon SageMaker Examples 提供的 JAX 示例中,我们使用自定义容器在 SageMaker 上 基于 JAX 框架或库训练神经网络。这在单个容器上是可能的,因为我们使用了 sagemaker-training-toolkit,它允许你在自己的自定义容器中使用脚本模式。自定义容器可以使用内置的 SageMaker 训练作业功能,如竞价训练和超参数调整。
训练模型后,您可以将经过训练的模型部署到托管端点。如前所述,SageMaker 具有推理容器,这些容器已针对亚马逊云科技的硬件和常用深度学习框架进行了优化。其中一项优化是针对 TensorFlow 框架的优化。由于 JAX 支持将模型导出为 TensorFlow SavedModel 格式,因此我们使用该功能来展示如何在优化的 SageMaker TensorFlow 推理端点上部署经过训练的模型。
整个训练和部署主要分为以下五个步骤:
创建 Docker 镜像并将其推送到 Amazon ECR。
使用 SageMaker 开发工具包传教自定义框架估算器,以便将模型输出归类为 TensorFlowModel。
代码仓库中有训练估算器的脚本。
使用 GPU 上的 SageMaker 训练作业来训练每个模型。
将模型部署到完全托管的终端节点。
下面我们来看看详细步骤:
创建 Docker 镜像并将其推送到 Amazon ECR。
*创建使用 JAX 训练模型容器的 Dockerfile
Docker 映像是在 NVIDIA 提供的支持 CUDA 的容器之上构建的。为了确保作为 JAX 中功能基础的 jaxlibpackage 支持 CUDA,请从 jax_releases 存储库中下载 jaxlib 软件包。
AX releases
https://storage.googleapis.com/jax-releases/jax_releases.html
这里需要注意的是:为了确保作为 JAX 中的功能基础的 JAX library package 能够支持 cuda,建议在去做这个创建自定义容器时,去看一下目前 JAX release 这个存储库中,它下载的这个 JAX library 包的版本号或者相关注意事项等等。
2、使用 SageMaker 开发工具包创建自定义框架估算器,以便将模型输出归类为 TensorFlowModel。
创建基本 SageMaker 框架估算器的子类,将估算器的模型类型指定为 TensorFlow 模型。为此,我们指定了一个自定义 create_model 方法,该方法使用现有的 TensorFlowModel 类来启动推理容器。
3、通过代码仓库训练估算器的脚本。
您可以通过传统的 SageMaker Python SDK 工作流通过模型执行训练、部署和运行推理。我们确保导入并初始化自定义框架估算器的代码片段中定义的 JaxEstimator,然后运行标准的 .fit () 和 .deploy () 调用。
对于 JAX ,可以调用 jax2tf 函数来执行相同的操作。代码在存储库中可用。设置正确的路径 /opt/ml/model/1 非常重要,这是 SageMaker wrapper(封装器) 假定模型已存储的地方。、
前面提到的 JAX 和 TF 的互操作性,目前 JAX 是通过 JAX to TF 这样的一个软件包,来为 JAX 和 TF 的互操作性提供支持,那 jax2tf.convert 是用于在 TensorFlow 的上下文中使用 JAX 函数,那 jax2tf.call_tf 是用于在 JAX 的上下文中使用的 TensorFlow 函数互操作来完成的。
4、使用 GPU 上的 SageMaker 训练作业来训练每个模型。
将模型部署到完全托管的终端节点。
vanilla_jax_predictor = vanilla_jax_estimator.deploy(
initial_instance_count=1, instance_type="ml.m4.xlarge"
)
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
def test_image(predictor, test_images, test_labels, image_number):
np_img
= np.expand_dims(np.expand_dims(test_images[image_number], axis=
-1
), axis=
0
)
result = predictor.predict(np_img)
pred_y = np.argmax(result["predictions"])
print("True Label:", test_labels[image_number])
print("Predicted Label:", pred_y)
plt.imshow(test_images[image_number]
*部署和准备输入的测试图像
*进行推理
有关在 Amazon SageMaker 上使用 JAX 训练和部署深度学习模型的详细过程和代码,请参考亚马逊云科技官方博客。
如图所示,上面的两张图是一个部署模型的例子,下面的图是进行推理的例子。由于我们的 Framework Estimator 知道模型将使用 TensorFlowModel 提供服务,因此部署这些端点只是对 estimator.deploy () 方法做调用即可。
参考资料
Training and Deploying ML Models using JAX on SageMaker
Train and deploy deep learning models using JAX with Amazon SageMaker
AX core from scratch
Building JAX from source
JAX 是一种越来越流行的库,它支持原生 Python 或 NumPy 函数的可组合函数转换,可用于高性能数值计算和机器学习研究。JAX 提供了编写 NumPy 程序的能力,这些程序可以使用 GPU/TPU 自动差分和加速,从而形成了更灵活的框架来支持现代深度学习架构。在这两篇文章中我们讨论了有关 JAX 的一些主题,希望对您用使用 JAX 这一框架进行深度学习研究有所帮助。
往期推荐
机器学习洞察 | JAX,机器学习领域的“新面孔”
机器学习洞察 | 降本增效,无服务器推理是怎么做到的?
机器学习洞察 | 分布式训练让机器学习更加快速准确