机器学习洞察 | JAX,机器学习领域的“新面孔”

news2025/1/11 0:04:04

在之前的《机器学习洞察》系列文章中,我们分别针对于多模态机器学习和分布式训练、无服务器推理进行了解读,本文将为您重点介绍 JAX 的发展并剖析其演变和动机。下面,就让我们来认识一下 JAX 这一新崛起的深度学习框架——

亚马逊云科技开发者社区为开发者们提供全球的开发技术资源。这里有技术文档、开发案例、技术专栏、培训视频、活动与竞赛等。帮助中国开发者对接世界最前沿技术,观点,和项目,并将中国优秀开发者或技术推荐给全球云社区。如果你还没有关注/收藏,看到这里请一定不要匆匆划过,点这里让它成为你的技术宝库!

 

开源机器学习框架的演进

从这张 GitHub Star 趋势图可以看到,自 2019 年 JAX 出现到如今保持着一个向上的抛物线走势。

在考察一个开源机器学习框架时,例如开发者熟知的 PyTorch, TensorFlow, MXNet 等,往往会从支持模型的广泛性、部署的成熟性、生态系统的丰富性来对它做一个评估:包括是否支持 Hugging Face 等主流模型,以及其框架相关研究论文的数量,还有它可提供复现代码的论文数量等等。

JAX 的源起

  • 为什么 Eager 模式是在 TensorFlow 1.4 版本之后引入的?

  • Eager 模式在 TensorFlow 2.0 之后变成了一个默认的执行模式,和原有的 Graph 模式的区别是什么?

回归并理清这些历史问题有助于开发者了解机器学习的演变逻辑,并了解 JAX 是如何吸取之前的教训,帮助开发者更方便地实践深度学习或机器学习应用。

Eager 模式 V.S. Graph 模式

在 TF 引进了 Eager 模式之后,它会采用更直观的界面,使用自然的 Python 代码和数据结构,而且享受更加便携的调试,在 Eager 模式中可以通过直接调用操作来检查和测试模型,而之前 Graph 这种模式有点类似于 C 和 C++,它的编程是写好程序之后要先进行编译才能运行。

Eager 模式有自然控制的流程,使用 Python 而不是图控制流,以及支持 GPU 和 TPU 的加速。做为开发者,我们希望可以客观地看待不同的框架,而不是比较他们的优劣。值得思考的一个问题是:通过了解 TF 的 Eager 模式对于 Graph 模式的改进,它的改进逻辑和思路在 JAX 中都有身影。

什么是 JAX

JAX 作为现在越来越流行的库,是一种类似于 NumPy(使用 Python 开源的数值计算扩展库)的轻量级用于阵列的计算。JAX 最开始的设计不仅仅是为了深度学习而设计的,深度学习只是它的一小部分,它提供了编写 NumPy 程序的能力,这些程序可以使用 GPU/TPU 自动拆分和加速。

JAX 用于基于阵列的计算时,开发者无需修改代码就可以在 CPU/GPU/ASIC 上同时运行,并支持原生 Python 和 NumPy 函数的四种可组合函数转换:

  • 自动微分 (Autodiff)

  • 即时编译 (JIT compilation)

  • 自动向量化 (Vectorization)

  • 代码并行化 (Parallelization)

JAX 初体验

我们可以通过下面这个简单的测试对比 JAX 和 NumPy 的计算性能。

输入一个 100 X 100 的二维数组 X,选取 ml.g4dn.12xlarge 计算实例通过 NumPy 和 JAX 分别对矩阵的前三次幂求和:

def fn(x):
  return x + x*x + x*x*x
x = np.random.randn(10000, 10000).astype(dtype='float32')
%timeit -n5 fn(x)

436 ms ± 206 µs per loop (mean ± std. dev. of 7 runs, 5 loops each)

我们发现此计算大约需要 436 毫秒。接下来,我们使用 JAX 实现以下计算:

jax_fn = jit(fn)
x = jnp.array(x)
%timeit jax_fn(x).block_until_ready()
3.67 ms ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

JAX 仅在 3.67 毫秒内执行此计算,比 NumPy 快 118 倍以上。可见,JAX 有可能比 NumPy 快几个数量级(注意,JAX 使用 TPU 而 NumPy 正在使用 CPU)。

*以上为个人测试结果,非官方提供的数据,仅供研究参考

对比测试结果可得,NumPy 完成计算需要 436 毫秒,而 JAX 仅需要 3.67 毫秒,计算速度相差 100 多倍。这个测试也说明了为什么很多开发者对它的性能赞不绝口。

JAX 的动机剖析

我们希望通过回答这个问题来解读 JAX 的动机:

如何使用 Python 从头开始实现高性能和可扩展的深度神经网络?

在 NumPy 中创建深度学习系统

通常,Python 程序员会从 NumPy 之类的东西开始,因为它是一种熟悉的、基于数组的数据处理语言,在 Python 社区中已经使用了几十年。如果你想在 NumPy 中创建深度学习系统,你可以从预测方法开始。

这里可以用一个详细的例子说明问题,从 NumPy 上的深度学习的场景说起:

上述代码展示了订阅一个前馈的神经网络,它执行了一系列的点积和激活函数,然后将输入转化为某种可以学习的输出。一旦定义了这样的一个模型,接下来需要做就是要定义损失函数,这个函数将为你提供正在尝试优化的那些指标,来适应最佳的机器学习模型。例如以上代码的损失函数是以均方误差损失函数 MSE 为例。

现在我们来分析下:在深度学习场景使用 NumPy 还缺少什么?

硬件加速 (GPU/TPU)

自动微分 (autodiff) 快速优化

添加编译 (Compilation) 融合操作

向量化操作批处理 (batching)

大型数据集并行化 (Parallelization)

1)硬件加速 (GPU/TPU):首先深度学习需要大量的计算,我们想在加速的硬件上运行它。所以我们想在 GPU 和 TPU/ASIC 上运行这个模型,这对于经典的 NumPy 来说有点困难;

2)自动微分 (autodiff) 快速优化:接下来我们想要做自动微分,这样就可以有效地拟合这个损失函数,而不必自己来实现数值微分;

3)然后我们需要添加编译 (Compilation):这样你就可以将这些操作融合在一起,使它们更加高效;

4)向量化操作批处理 (Batching):另外,当我们编写了某些函数后,可能希望将其应用于多个数据片段,而不再需要重写预测和损失函数来处理这些批量数据;

5)大型数据集并行化 (Parallelization):最后,如果我们正在处理大型数据集,会希望能够支持跨多个 cores 或多台 machines 做并行化操作。

JAX 的动机剖析:XLA 和自动定位

JAX 非常重要的一个动机就是 XLA 和自动定位。让我们来看看 JAX 可以做些什么,来填补前面分析的在深度学习场景使用 NumPy 还缺少的功能。

首先,用 jax.numpy 替换 numpy 导入模块。在许多情况下,jax.numpy 与经典的 NumPy 具有相同的 API,但 jax.numpy 可以完成前面分析时发现 NumPy 缺少,但是在深度学习场景却非常需要的的东西。

JAX 可以通过 XLA 后端,来自动定位 CPU、GPU 和 TPU 或者 ASIC,以便快速计算模型和算法。

JAX 动机剖析:Autograd

第二个重要动机是 Autograd。开发者可以通过下面的代码调用 Autograd 版本:

通过 from jax import grad 模块,使用 Autograd 的更新版本,JAX 可以自动微分原生 Python 和 NumPy 函数。它可以处理 Python 功能的大子集,包括循环、Ifs、递归等,甚至可以接受导数的导数。

JAX 提供了一组可组合的变换,其中之一是 grad 变换

例子中,像 mse_loss 这样的损失函数,通过 grad (mse_loss) 将其转换为计算梯度的 Python 函数。

 

Autograd 的主要预期应用是基于梯度的优化。

有关更多信息,请查看 JAX 教程和示例:Https://github.com/hips/autograd

JAX 动机剖析:vmap

在使用梯度函数时,开发者希望将其应用于多个数据片段,而在 JAX 中,你不再需要重写预测和损失函数来处理这些批量数据。

如图中代码最后一行 (perexample_grads …) 所诠释的那样,如果你通过 vmap transform 传递它,这会自动向量化这个代码,这样就可以在多个批次中使用相同的代码。

JAX 动机剖析:jit

JAX 还有一个重要的组合函数——jit,开发者可以使用 jit transform 实现即时编译。

jit 结合后台可以使用 XLA 后端编译器将操作融合在一起,来自动定位 CPU、GPU 和 TPU 或者 ASIC,加速计算模型和算法。

JAX 动机剖析:pmap

最后,如果想并行化你的代码,有一个和 vmap 非常相似得转换叫 pmap。

通过代码运行 pmap,开发者能够本地定位系统中的多个内核或你有权访问的 GPU、TPU 或 ASIC 集群。

这最终成为一个非常强大的系统,可以在没有太多额外代码的情况下构建我们用类似于 NumPy 的熟悉 API,做深度学习的快速计算等工作负载。

JAX 的关键设计思想

通过上述对比可以看到, JAX 不仅为开发者提供了和 NumPy 相似的 API,上述的五大函数转换组合也让 JAX 可以在不需要额外代码的情况下,帮助开发者构建深度学习应用进行快速计算。

这里的关键思想是:

1)首先,在 JAX 中,Python 代码被追溯到中间表示,JAX 知道如何转换这个中间表示。

2)在下篇文章中我们也将详细分析 JAX 的工作机制:同样的中间表示,通过允许 XLA 进行特定领域 (CPU/GPU 等) 的编译,如何来瞄准不同的后端;

3)另外,JAX 还有基于 NumPy 和 SciPy 的面向用户的 API,如果开发者一直使用 Python 的技术栈,应该会对 JAX 感觉相当熟悉;

4)最后,JAX 提供了功能强大的变换:grad, git, vmap, pmap 等,来支持深度学习等计算,因此 JAX 可以做到之前 NumPy 代码无法做到的事情。

通过前面的介绍,我们可以看到,开发者熟悉的 API 和语法以及四种强大的转换组合让开发者更加喜欢 JAX,并让深度学习场景或者科学计算变得非常简便。

欢迎回顾关于机器学习的往期文章,以及更多面向开发者的技术分享。请持续关注 Build On Cloud 微信公众号!

往期推荐

  • 机器学习洞察 | 挖掘多模态数据机器学习的价值
  • 机器学习洞察 | 分布式训练让机器学习更加快速准确
  • 机器学习洞察 | 降本增效,无服务器推理是怎么做到的?

 

 

作者黄浩文

亚马逊云科技资深开发者布道师,专注于 AI/ML、Data Science 等。拥有 20 多年电信、移动互联网以及云计算等行业架构设计、技术及创业管理等丰富经验,曾就职于 Microsoft、Sun Microsystems、中国电信等企业,专注为游戏、电商、媒体和广告等企业客户提供 AI/ML、数据分析和企业数字化转型等解决方案咨询服务。

 文章来源:https://dev.amazoncloud.cn/column/article/63e33239e5e05b6ff897ca0e?sc_medium=regulartraffic&sc_campaign=crossplatform&sc_channel=CSDN

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/721134.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

threejs课程笔记-20 向量点乘、叉乘

向量点乘dot 点乘是向量的一种运算规则,点乘也有其它称呼,比如点积、数量积、标量积。 threejs三维向量Vector3封装了一个点乘相关的方法.dot(),本节课主要目的就是让大家能够灵活应用点乘方法.dot() 已知向量a和向量b 已知两个向量a和b&…

设计模式3:单例模式:静态内部类模式是怎么保证单例且线程安全的?

上篇文章:设计模式3:单例模式:静态内部类单例模式简单测试了静态内部类单例模式,确实只生成了一个实例。我们继续深入理解。 静态变量什么时候被初始化? public class Manager {private static class ManagerHolder …

探索 Jetson Nano 为 myCobot 280 机械臂提供的强大功能

探索 Jetson Nano 为 myCobot 280 提供的强大功能,机器人技术的一个有前途的组合 介绍 近年来,科学技术的发展给我们的生活带来了许多新的产品和服务,包括机器人在各个领域的集成。机器人已经成为我们生活中必不可少的一部分,从…

C语言求鸡兔同笼问题案例讲解

前言: 作者本人在今年4月份参加了一个C语言考试,编程大题里有一道鸡兔同笼问题;本来以为简简单单,几分钟搞定,拿个满分;结果翻车了,因为我在考场的时候想着,母鸡到底有几只脚呢&…

FlinkCDC第二部分-搭建Flink单机服务,ctrl就完事~

Flink版本:1.16 环境:Linux CentOS 7.0、jdk1.8 基础文件:flink-1.16.2-bin-scala_2.12.tgz、flink-connector-jdbc-3.0.0-1.16.jar、flink-sql-connector-mysql-cdc-2.3.0.jar 1. 在目录/home/flink下解压flink-1.16.2-bin-scala_2.12.tg…

基于 R 对卫星图像进行无监督 kMeans 分类

一、前言 本文将向您展示如何使用 R 对卫星图像执行非常基本的 kMeans 无监督分类。我们将在 Sentinel-2 图像的一小部分上执行此操作。 Sentinel-2 是由欧洲航天局发射的一颗卫星,其数据可在此处免费访问。 我要使用的图像显示了 Neusiedl 湖的北部(奥地…

系统移植 根文件系统的移植 7.5

根文件系统的移植 根文件系统:根目录下的所有文件和工具的集合 根文件系统是内核启动后挂载的第一个文件系统系统引导程序会在根文件系统挂载后从中把一些基本的初始化脚本和服务等加载到内存中去运行文件系统层次结构标准 文件具体的属性只能在内核中看到&#xf…

django-vue-admin curd_demo 快速crud教程

django-vue-admin curd_demo 快速crud教程 快速CRUD开发教程:https://bbs.django-vue-admin.com/article/9.html 如何在 env.py 文件配置Mysql数据库:https://bbs.django-vue-admin.com/question/4.html 导入导出配置教程:https://bbs.djang…

Linux MySQL三种安装方式

MySQL 三种常用安装方式: 离线安装: 在mysql官网进行下载,步骤如下: 然后找到这个: 因为我这里使用的OS为CentOS7,大家可以按自己的系统进行选择。 最后通过XFTP/SCRTXF将文件传到虚拟机上。 然后将…

剑指 Offer 15. 二进制中1的个数 / LeetCode 191. 位1的个数(位运算)

题目: 链接:剑指 Offer 15. 二进制中1的个数;LeetCode 191. 位1的个数 难度:简单 编写一个函数,输入是一个无符号整数(以二进制串的形式),返回其二进制表达式中数字位数为 ‘1’ 的…

MYSQL04高级_逻辑架构剖析、查询缓存、解析器、优化器、执行器、存储引擎

文章目录 ①. 逻辑架构剖析②. 服务层 - 查询缓存③. 服务层 - 解析器④. 服务层 - 优化器⑤. 服务层 - 执行器 ①. 逻辑架构剖析 ①. 服务器处理客户端请求 ②. 连接层 系统(客户端)访问MySQL服务器前,做的第一件事就是建立TCP连接经过三次握手建立连接成功后,MySQL服务器对…

安装centos报错usb2-port3: Cannot enable. Maybe the USB cable is bad?的垃圾解决办法

使用U盘安装系统,用Rufus烧录,建议使用DVD版本,MINIMAL没有图形界面,同时安装的时候也要选安装GNOME联想P330在开机Lenovo出现时狂按F12,选USB UEFI partition1进入,差不多这个界面,还有一些BIO…

rt-thread-------内存管理(内存堆)

系列文章目录 rt-thread 之 fal移植 rt-thread 之 生成工程模板 STM32------串口理论篇 rt-thread------串口V1版本(一)配置 rt-thread------串口V1版本(二)发送篇 rt-thread------串口V1版本(三)接收篇 r…

系统移植 uboot移植 7.3

给fs4412板子配置uboot uboot激活流程 (arch/arm/cpu/armv7/start.S) reset 1.设置CPU模式为SVC//在这里加点灯的代码。看程序的代码有没有执行// ldr r0,0x11000C40 ldr r1,[r0] bic r1,r1,#0xf0000000 orr r1,r1,#0x10000000 str r1,[r0]ldr r0,0x11…

打包时未添加livepusher模块

我们的项目采用的是混入开发,html5, 使用到了安卓离线打包,其中使用到了livepusher模块,本来没什么难事的,很简单的一个问题,但是中文的官方文档却介绍错了包名,一直在郁闷为啥不行,痛苦啊。本来…

WiFi cfg80211的kernel架构(基于Linux 3.08)

目录 1.框架 2.主要流程 2.1.malloc & init(softmac) 2.1.3 内存分配 2.2. 结构体关系 2.3.初始化顺序 2.4.beacon frame 2.4.1.接收流程 2.4.2.beacon响应流程 2.5.scan 2.6.auth and associate 2.7. rx/tx data 2.7.1.rx 2.7.2.xmit 2.8.csa 2.9.missi…

MATLAB基础篇(下)

本文为MATLAB基础篇(上)的后续。 二、 MATLAB基本语法 7、基本绘图方法 Ⅰ、 MATLAB绘图的一般步骤 对数轴进行采样对采样点计算相应的函数值, 得到平面(或空间)上的点的数据运用绘图命令将数据进行图形化显示 x-1:0.01:1; %对数轴进行采样ysin(1./x);…

Linux之基础git命令的使用

Linux之基础git命令的使用 提交第一步提交第二步提交第三步查看历史提交记录查看是否需要提交过滤提交时的文件 git命令的初始使用 在使用之前,我们先确定我们的xshell是否安装的git,需要输入命令 git --version 如果没有显示版本号,则需要进…

【操作系统】c语言--使用信号量解决生产者和消费者问题

创作不易&#xff0c;本篇文章如果帮助到了你&#xff0c;还请点赞 关注支持一下♡>&#x16966;<)!! 主页专栏有更多知识&#xff0c;如有疑问欢迎大家指正讨论&#xff0c;共同进步&#xff01; &#x1f525;c系列专栏&#xff1a;C/C零基础到精通 &#x1f525; 给大…

Vue--》Vue3打造可扩展的项目管理系统后台的完整指南(九)

今天开始使用 vue3 ts 搭建一个项目管理的后台&#xff0c;因为文章会将项目的每一个地方代码的书写都会讲解到&#xff0c;所以本项目会分成好几篇文章进行讲解&#xff0c;我会在最后一篇文章中会将项目代码开源到我的GithHub上&#xff0c;大家可以自行去进行下载运行&…