TensorFlow会被JAX代替吗,使用JAX训练第一个机器学习模型

news2025/1/12 1:04:01

上期文章我们分享了JAX的概念,Jax 是来自 Google 的一个相对较新的机器学习库。它更像是一个 autograd 库,可以区分每个本机 python 和 NumPy 代码。

“Python+NumPy 程序的可组合转换:微分、向量化、JIT 到 GPU/TPU 等等”。该库利用 grad 函数转换将函数转换为返回原始函数梯度的函数。Jax 还提供了一个函数转换 JIT,用于对现有函数进行即时编译,并分别提供了用于矢量化和并行化的 vmap 和 pmap

JAX 是Autograd和XLA的结合,JAX 本身不是一个深度学习的框架,他是一个高性能的数值计算库,更是结合了可组合的函数转换库,用于高性能机器学习研究。深度学习只是其中的一部分而已,但是你完全可以把自己的深度学习移植到JAX 上面。

自2018 年底谷歌的 JAX出现以来,它的受欢迎程度一直在稳步增长。DeepMind 202年宣布使用 JAX 来加速自己的相关研究,越来越多来自Google 大脑与其他项目也在使用 JAX。随着JAX越来越火,似乎 JAX 是下一个大型深度学习框架?虽然JAX并不是一个神经网络框架,但是随着JAX的发展,很多深度学习相关的研究也可以使用JAX来实现,本来tensorflow与pytorch 2个主流框架已经争的热火朝天,现在Google又加了一把火,让JAX进军深度学习。

上期文章我们也分享了JAX 与numpy 的速度对比,相比没有JAX加速的numpy,其速度远远落后于JAX,本期我们就使用JAX训练第一个机器学习模型。

使用JAX训练第一个机器学习模型

在使用JAX之前,我们需要安装JAX,好在JAX可以使用pip进行安装,但是JAX目前无法在Windows平台使用,小伙伴们可以使用Linux虚拟机进行体验。

pip install jaxpip install autogradpip install numpypip install jaxlib

首先我们需要安装上JAX等相关的第三方库,并import相关的第三方库。

import numpy as npimport jax.random as randomimport jaxfrom jax import numpy as jnpfrom jax import make_jaxprfrom jax import grad, jit, vmap, pmapimport matplotlib.pyplot as plt

然后我们建立一个y=ax+b的一个线性函数,其中参数a是直线的一个斜率,b是直线在Y轴方向的移动参数,并使用random随机函数生成一个随机的X数据,这样我们就得到了一个完成的y=ax+b线性函数,我们可以使用matplotlib来显示此函数的曲线。​​​​​​​

key = random.PRNGKey(56)x = random.normal(key, shape=(128, 1))a = 3.0b = 5.0ys = (a*xs) + b
plt.scatter(xs, ys)plt.xlabel("xs")plt.ylabel("ys")plt.title("Linear F(x)")plt.show()

运行以上代码后,我们就得到了一个y=ax+b的线性函数。

有了以上的线性函数,我们就搭建一个线性模型,使用机器学习的方式,来预测此条直线。​​​​​​​

def linear(theta, x):    weight, bias = theta    pred = x * weight + bias    return pred

然后我们再定义一个线性函数,此函数也是同样有2个参数,一个weight(权重),一个bias(偏差),训练的目的是找到一个合适的weight与bias参数,以便来预测上面的线性函数。当然,我们还需要建立一个loss函数,以便后期进行训练时,让loss逐渐减小。这里使用均方差作为损失函数来计算预测值与真实值的损失。

def p_loss(theta, x, y):    pred = linear(theta, x)    loss = jnp.mean((y - pred)**2)    return loss@jitdef update_step(theta, x, y, lr):    loss, gradient = jax.value_and_grad(p_loss)(theta, x, y)    updated_theta = theta - lr * gradient    return updated_theta, loss

然后使用jax.value_and_grad函数来更新loss,lr参数是神经网络的学习效率,这里我们可以随机一个比较小的值即可。有了以上的函数,我们就可以进行一个机器学习的模型训练了。

weight = 0.0bias = 0.0theta = jnp.array([weight, bias])epochs = 20000for item in range(epochs):    theta, loss_p = update_step(theta, xs, ys, 1e-4)    if item % 1000 == 0 and item != 0:        print(f"item {item} | loss {loss_p:.4f}")

我们初始化weight与bias参数,使用for循环来训练神经网络,使loss越来越来越小,这里我们每隔1000步来打印一下loss参数。​​​​​​​

item 1000 | loss 23.4526item 2000 | loss 15.4000item 3000 | loss 10.1152item 4000 | loss 6.6459item 5000 | loss 4.3678item 6000 | loss 2.8714item 7000 | loss 1.8883item 8000 | loss 1.2422item 9000 | loss 0.8174item 10000 | loss 0.5380item 11000 | loss 0.3543item 12000 | loss 0.2333item 13000 | loss 0.1538item 14000 | loss 0.1013item 15000 | loss 0.0668item 16000 | loss 0.0441item 17000 | loss 0.0291item 18000 | loss 0.0192item 19000 | loss 0.0127
从以上loss参数,我们可以看到,其模型的loss逐渐缩小,说明我们的设计的线性机器学习模型是有效的。我们也可以打印一下训练20000步后的模型输出函数。
plt.scatter(xs, ys, label="true")plt.scatter(xs, linear(theta, xs), label="pred")plt.legend()plt.show()

可以看到,其模型随着训练,其loss逐渐减小,当训练20000步后,其预测的y=ax+b函数与输入的初始函数值几乎重合,当然你也可以增加训练步骤,让loss再次缩小。

JAX虽然目前不被称之为一个神经网络的模型框架,但是随着pytorch,paddlepaddle以及mindSpore相关框架的加入,加剧神经网络框架之争,说不定Google会把JAX发展成下一代神经网络框架也不一定。

ChatGPT的大火,
带动了人工智能学习的热潮,
小编建立了一个AI学习圈,
分享相关人工智能技术,
大家一起学习。
https://wx2.expostar.cn/qz/pages/manor/index?id=1137&share_from_id=79482&sid=24
更多transformer模型
VIT模型
swin transformer模型
参考头条号:人工智能研究所 

daa5d4fdd65f42458ed5aa938c2f5dfb.gif 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/475409.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue 视频播放插件vue-video-player自定义样式

1、背景 项目中有涉及视频播放的需求,并且UI设计了样式,与原生的视频video组件有差异,所以使用了vue-video-player插件,并对vue-video-player进行样式改造,自定义播放暂停按钮、全屏按钮、时间进度条样式等 2、效果图…

10分钟叫你如何学会组织Prompt语言同AI沟通

提示词(Prompt)是与AI模型交流的语言,用以告诉AI模型想要生成的图像的特征。提示词的准确性、精准度直接决定了生成的图像是否符合我们的预期。 基础介绍 AIGC提示词通常由多个单词、词组或短句构成,以***,***分割组成&#xff…

如何更改Windows服务器时间

Windows操作系统自带时间同步功能,它会自动从互联网时间服务器获取时间,以保证系统时间的准确性。但是,有时候我们需要更改时间服务器,以获得更准确的时间同步。小编将为大家介绍如何更改Windows时间服务器,以及Window…

java基础知识——22.lambda表达式

这篇文章,我们来讲一下java的lambda表达式 目录 1.初识lambda表达式 2.lambda表达式介绍 2.1 函数式编程 2.2 lambda表达式的具体格式 2.3 Lambda表达式的好处 2.4 Lambda的省略写法 1.初识lambda表达式 首先,我们来看一下lambda表达式的应用 下…

运维——ssh无法登录云服务器

0x00 概述 一般来讲,无法登录ssh的原因挺多,如果无法登录云服务器,则除了要检查ssh端口是否放行,防火墙状态外,还需要检查云服务器web控制台入站规则是否开放了对应端口。如果你前面检查都是正常,那么还需…

实战打靶集锦-017-potato

提示:本文记录了博主的一次打靶过程 目录 1. 主机发现2. 端口扫描3. 服务枚举4. 服务探查4.1 Apache探查4.2 ProFTPD探查4.2.1 strcmp()函数绕过4.2.2 查找apache日志文件4.2.3 查看/etc/passwd文件4.2.4 破译密码4.2.5 突破边界 5. 提权5.1 系统信息枚举5.2 定时任…

基于Yolov5的NEU-DET钢材表面缺陷检测,优化组合新颖程度较高:CVPR2023 DCNV3和InceptionNeXt,涨点明显

1.钢铁缺陷数据集介绍 NEU-DET钢材表面缺陷共有六大类,分别为:crazing,inclusion,patches,pitted_surface,rolled-in_scale,scratches 每个类别分布为: 训练结果如下: 2.基于yolov5s的训练 map值: 2.1 Inception-MetaNeXtStage 对应博客:https://cv2023.blog.csdn.n…

实验5 彩色图像处理与图像变换

文章目录 一、实验目的二、实验内容1. 彩色图像平滑。(课本P310 例6.12)2. 彩色边缘检测。(课本P318 例6.16)3. 一维小波变换。(课本P364 例7.20)4. 二维小波变换。(课本P369 例7.22)5. 小波包分解。(课本P376 例7.24) 一、实验目的 掌握RGB彩色模型和HSI彩色模型之间的转换方…

C语言指针的使用

文章目录 前言一、指针基本概念介绍二、指针的大小三、使用指针访问变量和变量地址四、使用指针遍历数组总结 前言 一、指针基本概念介绍 在 C 语言中,指针是一种用于存储内存地址的数据类型。指针可以存储任何数据类型的内存地址,包括基本数据类型、数…

C语言之单链表的实现以及链表的介绍

一、为什么会存在链表 因为我们常用的顺序表会存在以下的一些问题: 1. 中间/头部的插入删除,时间复杂度为O(N) 2. 增容需要申请新空间,拷贝数据,释放旧空间。会有不小的消耗。 3. 增容一般是呈2倍的增长,势必会有一定…

算法的特性和空间复杂度---数据结构

目录 前言: 1.算法 1.1算法的特性 1.2设计算法 2.空间复杂度 3.学习复杂度的意义 ❤博主CSDN:啊苏要学习 ▶专栏分类:数据结构◀ 学习数据结构是一件有趣的事情,希望读者能在我的博文切实感受到数据之间存在的关系&#xff…

【3dmax】常用的快捷键总结以及如何修改快捷键

💗 未来的游戏开发程序媛,现在的努力学习菜鸡 💦本专栏是我关于建模的笔记 🈶本篇是3dmax常用的快捷键总结以及如何修改快捷键 3dmax常用的快捷键总结以及如何修改快捷键 3dmax常用快捷键如何添加或修改快捷键 3dmax常用快捷键 视…

go pprof性能调优工具

go pprof 一、性能调优原则二、pprof1、pprof 功能简介2、pprof 排查实战前置工作a、CPUb、Heapc、goroutined、mutexe、block 3、pprof 的采样过程和原理a、cpub、heapc、goroutine && threadCreated、block && mutex 三、调优流程1、业务优化a、流程 2、基础…

2023.4.17-4.23 AI行业周刊(第146期):创业要趁早

最近有很多外部拓展培训的需求,联盟的共学课程培训,公司视觉软件的培训,行业课程的培训,每一项培训听起来简单,但是其实都需要大量的时间精力。 前两年也准备过一份《30天入门人工智能》的视频课程,总共31…

Ansible自动化部署工具|各个模块的使用

Ansible自动化部署工具|各个模块的使用 一、自动化运维工具—Ansible二、安装Ansible查询webserver组中主机的日期 三 Ansible常用模块(1) ansible命令行模块(2) command模块(3) shell模块(4) cron模块(5) user模块(6) grup模块(7) copy模块(8) file模块(9) ping模块(10) servi…

内网穿透NPS和宝塔Nginx配合使用,开启SSL访问本地局域网网络

并非为了教学,仅供自己记录,方便下次用。所以内容不会刻意花时间写的很细节详细。 1. 服务器NPS配置 NPS install安装后,配置文件会在其他位置,通过是 /etc/nps/nps.conf目录。 找到进行修改,主要修改的是http_proxy_p…

【flask】三种路由和各自的比较配置文件所有的字母必须大写if __name__的作用核心对象循环引用的几种解决方式--难Flask的经典错误

三种路由 方法1:装饰器 python C#, java 都可以用这种方式 from flask import Flask app Flask(__name__)app.route(/hello) def hello():return Hello world!app.run(debugTrue)方法2: 注册路由 php python from flask import Flask app Flask(__name__)//app…

【以太坊 Solidity】管理员读写权限/访问控制/角色控制

摘要 在 Solidity 语言的多继承中,若多个合约共同继承一个父合约,则这多个合约 共享 父合约中的变量和函数。 1.测试的智能合约 合约继承路线如下: #mermaid-svg-DtimeTjOch5CJh50 {font-family:"trebuchet ms",verdana,arial,s…

应用,auto,内联函数

6.引用&#xff1a; //指针 int main() {int a 0;int& b a;int& c b;int& d c;cout << &a << endl;cout << &b << endl;cout << &c << endl;cout << &d << endl;b;d;cout << a <<…

WEB攻防通用漏洞跨域CORS资源JSONP回调域名接管劫持

目录 一、同源策略&#xff08;SOC&#xff09; 二、跨域资源&#xff08;COSP&#xff09; 三、回调跨域&#xff08;JSOP&#xff09; 四、CORS资源跨域-敏感页面原码获取 五、JSONP 回调跨域-某牙个人信息泄露 六、子域名劫持接管 一、同源策略&#xff08;SOC&#x…