2023年的深度学习入门指南(16) - JAX和TPU加速

news2025/1/23 21:10:28

2023年的深度学习入门指南(16) - JAX和TPU加速

上一节我们介绍了ChatGPT的核心算法之一的人类指示的强化学习的原理。我知道大家都没看懂,因为需要的知识储备有点多。不过没关系,大模型也不是一天能够训练出来的,也不可能一天就对齐。我们有充足的时间把基础知识先打好。

上一节的强化学习部分没有展开讲的原因是担心大家对于数学知识都快忘光了,而且数学课上学的东西也没有学习编程。这一节我们来引入两个基础工具,一个可以说是各个Python深度学习框架必然绕不过去的NumPy库,另一个是Google开发的可以认为是GPU和TPU版的NumPy库JAX。

学习这两个框架的目的还是补数学课,尤其是数学编程,这次也是TPU首次登场我们的教程部分。当然,也是可以用GPU的。

矩阵

NumPy最为核心的功能就是多维矩阵的支持。

我们可以通过pip install numpy的方式来安装NumPy,然后在Python中通过import numpy as np的方式来引入NumPy库。
但是,NumPy不能支持GPU和TPU加速,对于我们将来要处理的计算来说不太实用,所以我们这里引入JAX库。

JAX的安装文档请见JAX官方文档

之前我们多次使用CUDA来进行GPU加速了,这里我们不妨来看看TPU的加速效果。
TPU只有Google一家有,我们只能买到TPU的云服务,不过,我们可以使用Google Colab来使用TPU。

在Colab上,已经安装好了JAX和TPU的运行时。我们运行下面的代码即可激活TPU:

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

我们来看看有多少个TPU设备可以使用:

print(jax.device_count())
print(jax.local_device_count())
print(jax.devices())

输出结果如下:

8
8
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

说明我们有8个TPU设备可以用。

下面我们就用jax.numpy来代替numpy来使用。

NumPy最重要的功能就是多维矩阵的支持。我们可以通过np.array来创建一个多维矩阵。

我们先从一维的向量开始:

import jax.numpy as jnp
a1 = jnp.array([1,2,3])
print(a1)

然后我们可以使用二维数组来创建一个矩阵:

a2 = jnp.array([[1,2],[0,4]])
print(a2)

矩阵可以统一赋初值。zeros函数可以创建一个全0的矩阵,ones函数可以创建一个全1的矩阵,full函数可以创建一个全是某个值的矩阵。

比如给10行10列的矩阵全赋0值,我们可以这样写:

a3 = jnp.zeros((10,10))
print(a3)

全1的矩阵:

a4 = jnp.ones((10,10))

全赋100的:

a5 = jnp.full((10,10),100)

另外,我们还可以通过linspace函数生成一个序列。linpsace函数的第一个参数是序列的起始值,第二个参数是序列的结束值,第三个参数是序列的长度。比如我们可以生成一个从1到100的序列,长度为100:

a7 = jnp.linspace(1,100,100) # 从1到100,生成100个数
a7.reshape(10,10)
print(a7)

最后,JAX给矩阵生成随机值的方式跟NumPy并不一样,并没有jnp.random这样的包。我们可以使用jax.random来生成随机值。JAX的随机数生成函数都需要一个显式的随机状态作为第一个参数,这个状态由两个无符号32位整数组成,称为一个key。用一个key不会修改它,所以重复使用同一个key会得到相同的结果。如果需要新的随机数,可以使用jax.random.split()来生成新的子key。

from jax import random
key = random.PRNGKey(0) # a random key
key, subkey = random.split(key) # split a key into two subkeys
a8 = random.uniform(subkey,shape=(10,10)) # a random number using subkey
print(a8)

范数

范数(Norm)是一个数学概念,用于测量向量空间中向量的“大小”。范数需要满足以下性质:

  • 非负性:所有向量的范数都大于或等于零,只有零向量的范数为零。
  • 齐次性:对任意实数λ和任意向量v,有||λv|| = |λ| ||v||。
  • 三角不等式:对任意向量u和v,有||u + v|| ≤ ||u|| + ||v||。
    在实际应用中,范数通常用于衡量向量或矩阵的大小,比如在机器学习中,范数常用于正则化项的计算。

常见的范数有:

  • L0范数:向量中非零元素的个数。
  • L1范数:向量中各个元素绝对值之和,也被称为曼哈顿距离。
  • L2范数:向量中各个元素的平方和然后开方,也被称为欧几里得距离。
  • 无穷范数:向量中各个元素绝对值的最大值。
    需要注意的是,L0范数并不是严格意义上的范数,因为它违反了齐次性。但是在机器学习中,L0范数常用于衡量向量中非零元素的个数,因此也被称为“伪范数”。

我们先从计算一个一维向量的L1范数开始,不要L1范数这个名字给吓到,其实就是绝对值之和:

norm10_1 = jnp.linalg.norm(a10,ord=1)
print(norm10_1)

结果不出所料就是6.

下面我们再看L2范数,也就是欧几里得距离,也就是平方和开方:

a10 = jnp.array([1, 2, 3])
norm10 = jnp.linalg.norm(a10)
print(norm10)

根据L2范数的定义,我们可以手动计算一下:norm10 = jnp.sort(1 + 22 + 33) = 3.7416573.

我们可以看到,上面的norm10的值跟我们手动计算的是一样的。

下面我们来计算无穷范数,其实就是最大值:

norm10_inf = jnp.linalg.norm(a10, ord = jnp.inf)
print(norm10_inf)

结果为3.

我们来算一个大点的巩固一下:

a10 = jnp.linspace(1,100,100) # 从1到100,生成100个数
n10 = jnp.linalg.norm(a10,ord=2)
print(n10)

这个结果为581.67865.

逆矩阵

对角线是1,其它全是0的方阵,我们称为单位矩阵。在NumPy和JAX中,我们用eye函数来生成单位矩阵。

既然是方阵,就不用跟行和列两个值了,只需要一个值就可以了,这个值就是矩阵的行数和列数。用这一个值赋给eye函数的第一个参数,就可以生成一个单位矩阵。

下面我们来复习一下矩阵乘法是如何计算的。

对于矩阵A的每一行,我们需要与矩阵B的每一列相乘。这里的“相乘”意味着取A的一行和B的一列,然后将它们的对应元素相乘,然后将这些乘积相加。这个和就是结果矩阵中相应位置的元素。

举个例子,假设我们有两个2x2的矩阵A和B:

A = 1 2     B = 4 5
    3 4         6 7

我们可以这样计算矩阵A和矩阵B的乘积:

(1*4 + 2*6) (1*5 + 2*7)     16 19
(3*4 + 4*6) (3*5 + 4*7) =  34 43

我们用JAX来计算一下:

ma1 = jnp.array([[1,2],[3,4]])
ma2 = jnp.array([[4,5],[6,7]])
ma3 = jnp.dot(ma1,ma2)
print(ma1)
print(ma2)
print(ma3)

输出为:

[[1 2]
 [3 4]]
[[4 5]
 [6 7]]
[[16 19]
 [36 43]]

如果A*B=I,I为单位矩阵,那么我们称B是A的逆矩阵。

我们可以用inv函数来计算矩阵的逆矩阵。

ma1 = jnp.array([[1,2],[3,4]])
inv1 = jnp.linalg.inv(ma1)
print(inv1)

输出的结果为:

[[-2.0000002   1.0000001 ]
 [ 1.5000001  -0.50000006]]

导数与梯度

导数是一个函数在某点处的变化率,用于描述函数在该点处的变化率。导数可以表示函数在该点处的斜率,即函数在该点处的陡峭程度。

梯度(gradient)是一个向量,表示函数在该点处的方向导数沿着该方向取得最大值。梯度可以表示函数在该点处的变化最快和变化率最大的方向。在单变量的实值函数中,梯度可以简单理解为导数。

JAX作为支持深度学习的框架,对于梯度的支持是被优先考虑的。我们可以使用jax.grad函数来计算梯度。针对一个一元函数,梯度就是导数。我们可以用下面的代码来计算sin函数在x=1.0处的梯度:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x)

# 计算 f 在 x=1.0 处的梯度
grad_f = jax.grad(f)
print(grad_f(1.0))

我们如果每次都沿着梯度方向前进,那么我们就可以找到函数的极值。这种采用梯度方向前进的方法,就是梯度下降法。梯度下降法是一种常用的优化算法,它的核心思想是:如果一个函数在某点的梯度值为正,那么函数在该点沿着梯度方向下降的速度最快;如果一个函数在某点的梯度值为负,那么函数在该点沿着梯度方向上升的速度最快。因此,我们可以通过不断地沿着梯度方向前进,来找到函数的极值。

那么,梯度下降法有什么作用呢?我们可以用梯度下降法来求解函数的最小值。我们可以用下面的代码来求解函数 f ( x ) = x 2 f(x)=x^2 f(x)=x2的最小值:

import jax
import jax.numpy as jnp

def f(x):
    return x ** 2

grad_f = jax.grad(f)

x = 2.0  # 初始点
learning_rate = 0.1  # 学习率
num_steps = 100  # 迭代步数

for i in range(num_steps):
    grad = grad_f(x)  # 计算梯度
    x = x - learning_rate * grad  # 按负梯度方向更新 x

print(x)  # 打印最终的 x 值,应接近 0(函数的最小值)

我这一次运行的结果是4.0740736e-10. 也就是说,我们用梯度下降法求解函数 f ( x ) = x 2 f(x)=x^2 f(x)=x2的最小值,最终得到的x值接近于0,也就是函数的最小值。

其中,学习率(或称为步长)是一个正数,用于控制每一步更新的幅度。学习率需要仔细选择,过大可能导致算法不收敛,过小可能导致收敛速度过慢。

概率

唤醒完线性代数和高等数学的一些记忆之后,最后我们来回顾一下概率论。

我们还是从扔硬币说起。我们知道,假设一枚硬币是均匀的,那么只要扔的次数足够多,正面朝上的次数就会接近于总次数的一半。

这种只有两种可能结果的随机试验,我们给它起个高大上的名字叫做伯努利试验(Bernoulli trial)。

下面我们就用JAX的伯努利分布来模拟一下扔硬币的过程。

import jax
import time
from jax import random

# 生成一个形状为 (10, ) 的随机矩阵,元素取值为 0 或 1,概率为 0.5
key = random.PRNGKey(int(time.time()))
rand_matrix = random.bernoulli(key, p=0.5, shape=(10, ))
print(rand_matrix)
mean_x = jnp.mean(rand_matrix)
print(mean_x)

mean函数用来求平均值,也叫做数学期望。

打印的结果可能是0.5,也可能是0.3,0.8等等。这是因为我们只扔了10次硬币,扔的次数太少了,所以正面朝上的次数不一定接近于总次数的一半。

这是其中一次0.6的结果:

[ True  True  True  True False False  True False False  True]
0.6

多跑几次,出现0.1,0.9都不稀奇:

[False False False False False False False False False  True]
0.1

当我们把shape改成100,1000,10000等更大的数之后,这个结果就离0.5越来越近了。

下面再复习下表示偏差的两个值:

  • 方差(Variance):方差是度量数据点偏离平均值的程度的一种方式。换句话说,它描述了数据点与平均值之间的平均距离的平方。
  • 标准差(Standard Deviation):标准差是方差的平方根。因为方差是在平均偏差的基础上平方得到的,所以它的量纲(单位)与原数据不同。为了解决这个问题,我们引入了标准差的概念。标准差与原数据有相同的量纲,更便于解释。

这两个统计量都反映了数据分布的离散程度。方差和标准差越大,数据点就越分散;反之,方差和标准差越小,数据点就越集中。

我们可以用JAX的var函数来计算方差,用std函数来计算标准差。

import jax
import time
from jax import random

# 生成一个形状为 (1000, ) 的随机矩阵,元素取值为 0 或 1,概率为 0.5
key = random.PRNGKey(int(time.time()))
rand_matrix = random.bernoulli(key, p=0.5, shape=(1000, ))
#print(rand_matrix)
mean_x = jnp.mean(rand_matrix)
print(mean_x)
var_x = jnp.var(rand_matrix)
print(var_x)
std_x = jnp.std(rand_matrix)
print(std_x)

最后我们来复习一下之前讲到的信息量。我们来思考一个问题,如何能让伯努利分布的平均信息量最大?

我们先构造两个特殊情况,比如如果p=0,那么我们就永远不会得到正面朝上的结果,这个时候我们就知道了结果,信息量为0。如果p=1,那么我们就永远不会得到反面朝上的结果,这个时候我们也知道了结果,信息量也为0。

如果p=0.01,能给我们带来的平均信息量仍然不大,因为基本上我们可以盲猜结果是背面朝上的,偶然出现的正面朝上的结果,虽然带来了较大的单次信息量,但是出现的概率太低了,所以平均信息量仍然不大。

而如果p=0.5,我们就完全猜不到结果是正面朝上还是背面朝上,这个时候我们得到的平均信息量最大。

当然,这只是定性的分析,我们还需要给出一个定量的公式:

H ( X ) = − ∑ x ∈ X p ( x ) log ⁡ 2 p ( x ) H(X) = - \sum_{x \in X} p(x) \log_2 p(x) H(X)=xXp(x)log2p(x)

import jax.numpy as jnp

# 计算离散型随机变量 X 的平均信息量
def avg_information(p):
    p = jnp.maximum(p, 1e-10)
    return jnp.negative(jnp.sum(jnp.multiply(p, jnp.log2(p))))

# 计算随机变量 X 取值为 0 和 1 的概率分别为 0.3 和 0.7 时的平均信息量
p = jnp.array([0.3, 0.7])
avg_info = avg_information(p)
print(avg_info)

我们试几次计算可以得到,当p为0.3时,平均信息量是0.8812325;当p为0.01时,平均信息量为0.08079329;当p为0.5时,平均信息量为1.0,达到最大。

如果嫌使用Python函数的计算慢,我们可以调用JAX的jit函数来加速。我们只需要在函数定义的前面加上@jit即可。

import jax.numpy as jnp
from jax import jit

# 计算离散型随机变量 X 的平均信息量
@jit
def avg_information(p):
    p = jnp.maximum(p, 1e-10)
    return jnp.negative(jnp.sum(jnp.multiply(p, jnp.log2(p))))

# 计算随机变量 X 取值为 0 和 1 的概率分别为 0.3 和 0.7 时的平均信息量
p = jnp.array([0.01, 0.99])
avg_info = avg_information(p)
print(avg_info)

小结

上面我们选取了一些线性代数,高等数学和概率论的知识点,来唤醒大家的记忆。同时,我们也介绍了它们在JAX上的实现和加速。
虽然我们的例子都不起眼,但它们是确确实实在TPU上跑起来的。

大模型虽然提供了很强的能力,但是我们仍然要花充足的时间在基本功上。硬件和框架都在日新月部分,但是数学基础知识的进化是非常缓慢的,投入产出比很高。有了扎实的基本功之后,框架和新硬件就是边用边学就可以了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/624589.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

葡萄酒质量预测

本文中所有代码及数据均存放于:https://github.com/MADMAX110/WineQualityPrediction 本文根据酸度、残糖和酒精浓度等特征训练和调整一个随机的葡萄酒质量森林模型。 一、设置环境,确认你的电脑安装了以下环境 Python 3NumPyPandasScikit-Learn (a.k.a…

Ubuntu 18.04 交叉编译Opencv-4.6.0

环境 操作系统:Ubuntu 18.04 OpenCv版本:4.6.0 交叉工具链:arm-linux-gnueabihf-gcc-5.3.1 下载OpenCV源代码 这里推荐大家到网上找OpenCV的Linux版本安装包(.tar.gz结尾),不要github上clone&#xff08…

leetcode688. 骑士在棋盘上的概率(java)

骑士在棋盘上的概率 leetcode688. 骑士在棋盘上的概率题目描述 解题思路代码演示动态规划专题 leetcode688. 骑士在棋盘上的概率 来源:力扣(LeetCode) 链接:https://leetcode.cn/problems/knight-probability-in-chessboard 题目描…

【源码篇】基于ssm+vue+微信小程序的医疗科普小程序

系统介绍 这是一个ssmvue微信小程序的医疗科普小程序,分为pc端和微信小程序端 pc端包括:管理员角色和学生角色。 管理员拥有:学生管理、科普知识管理、论坛管理、收藏管理、试卷管理、留言板管理、试题管理、系统管理、考试管理 学生端拥…

AI实战营第二期 第五节 《目标检测与MMDetection》——笔记6

文章目录 摘要主要特性 常用概念框、边界框交并比 (loU)感受野有效感受野置信度 目标检测的基本思路难点滑框在特征图进行密集计算边界框回归基于锚框VS无锚框NMS(非极大值抑制)使周密集预测模型进行推理步骤如何训练密集预测模型的训练匹配的基本思路密…

C++ 教程(01)

C 教程 C 是一种高级语言,它是由 Bjarne Stroustrup 于 1979 年在贝尔实验室开始设计开发的。C 进一步扩充和完善了 C 语言,是一种面向对象的程序设计语言。C 可运行于多种平台上,如 Windows、MAC 操作系统以及 UNIX 的各种版本。 本教程通过…

节省90%编译时间,这是字节跳动开源的基于Rust的前端构建工具

Rspack 是一个基于 Rust 的高性能构建引擎,它可以与 Webpack 生态系统交互,并提供更好的构建性能。 在处理具有复杂构建配置的巨石应用时,Rspack 可以提供 5~10 倍的编译性能提升。 字节跳动将 Rspack 开源后,它在 GitHub 上已有 …

Bert+FGSM/PGD实现中文文本分类(Loss=0.5L1+0.5L2)

任务目标:在使用FGSM/PGD来训练Bert模型进行文本分类,其实现原理可以简单概括为以下几个步骤: 对原始文本每个词转换为对应的嵌入向量。将每个嵌入向量与一个小的扰动向量相加,从而生成对抗样本。这个扰动向量的大小可以通过一个超…

2023年牛客网最新版大厂Java八股文面试题总结(覆盖所有面试题考点)

程序员真的是需要将终生学习贯彻到底的职业,一旦停止学习,离被淘汰,也就不远了。 金九银十跳槽季,这是一个千年不变的话题,每到这个时候,很多人都会临阵磨枪,相信不快也光。于是,大…

多业务线下,IT企业如何应对市场经济下行危机?

多业务线下,IT企业如何应对市场经济下行危机? 市场经济下行就像是一辆行驶的车子遇到了坡道,速度开始变慢甚至停下来。在这个情况下,经济的增长变得较为缓慢,消费减少,投资减少,也对企业会带来…

运筹说 第25期 | 对偶理论经典例题讲解

对偶理论是研究线性规划中原始问题与对偶问题之间关系的理论,主要研究经济学中的相互确定关系,涉及到经济学的诸多方面。产出与成本的对偶、效用与支出的对偶,是经济学中典型的对偶关系。 对偶理论中最有力的武器是影子价格,影子…

【MySQL】主从复制部署

文章目录 概述SQL数据库的三大范式 主从复制技术产生原因主从形式原理图主节点 binary log dump 线程从节点I/O线程作用从节点SQL线程作用 复制过程复制模式异步模式(mysql async-mode)半同步模式(mysql semi-sync)全同步模式 复制机制binlog记录模式GTI…

android frida检测绕过

Frida检测是一种常见的安卓逆向技术,常用于防止应用程序被反向工程。如果您遇到了Frida检测,您可以尝试以下方法来绕过它: 使用Magisk Hide模块:Magisk是一个强大的安卓root工具,它附带了一个Magisk Hide模块&#xff…

二阳大规模来袭,热图地图分析新冠疫情期间的高发地点,掌握防控重点!

一、概述 最近,新冠疫情似乎又要“卷土重来”... 身边逐渐有人传来“二阳”或者“三羊”的消息,网上相关的讨论和报道也变得越来越多。 据「钟南山院士」在大湾区科学论坛上的发言,预测模型seirs显示,第二波新冠疫情已于4月中旬开…

当数据汇聚成海,Excel 表成为我们的航海图,如何在茫茫数据中找到目标?——Excel 表中某个范围内的单元格遍历思路

本篇博客会讲解力扣“2194. Excel 表中某个范围内的单元格”的解题思路,这是题目链接。 先来审题: 以下是输出示例和提示: 这道题的解题思路是:模拟,遍历每一列,某一列遍历完后遍历下一列。 下面我们需…

爆肝整理,性能测试-全链路压测与普通压测区别总结,进阶高级测试...

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 抛出一个问题&…

Shell脚本:expect脚本免交互

Shell脚本:expect脚本免交互 expect脚本免交互 一、免交互基本概述:1.交互与免交互的区别:2.格式:3.通过read实现免交互:4.通过cat实现查看和重定向:5.变量替换: 二、expect安装:1.…

Docker Registry部署

之前执行 docker pull的命令都是从 docker hub上拉取的,是docker 公共仓库,如果在公司中使用docker,我们不可能把自己的镜像上传到公共仓库,这个时候就需要一个自己的仓库(私有仓库),在局域网之…

usb 驱动

usb 驱动 usb 的基本概念 这个忽略, 基本上usb 是啥都知道 usb 的拓扑结构 usb 是一种主从结构的系统 usb主机由usb主控之器(Host Controller)和根集线器(Root Hub) 构成 usb 主控制器: 主要负责数据处理(就是我…

微信怎么批量自动添加好友?

如何批量加客户资源到微信,怎么加微信好友,这个基本上熟悉的人都会知道。 实际上,你知道所有添加微信好友的方式吗?或者说,你知道如何批量加客户微信吗? 比如说在一定时间内,把你所有的客户资…