Jax(Random、Numpy)常用函数

news2025/1/22 20:55:36

目录

Jax

 vmap

Array

reshape

Random

PRNGKey

uniform

normal

split

 choice

Numpy

expand_dims

linspace

jax.numpy.linalg[pkg]

dot

matmul

arange

interp 

tile

reshape


Jax

jit

jax.jit(funin_shardings=UnspecifiedValueout_shardings=UnspecifiedValuestatic_argnums=Nonestatic_argnames=Nonedonate_argnums=Nonedonate_argnames=Nonekeep_unused=Falsedevice=Nonebackend=Noneinline=Falseabstracted_axes=None)[source]

注:jax.jit 是 JAX 中的一个装饰器,用于将 Python 函数编译为高效的机器代码,以提高运行速度。JIT(Just-In-Time)编译可以加速函数的执行,尤其是在循环或需要多次调用。

>>>jax.jit(lambda x,y : x + y)
<PjitFunction of <function <lambda> at 0x7ea7b402f130>>
>>>jax.jit(lambda x,y : x + y)(1,2) #process jitfunc -> lambda fun
Array(3, dtype=int32, weak_type=True)
>>>@jax.jit
   def fun(x,y):
        return x + y
>>>fun
<PjitFunction of <function fun at 0x7ea7b402f5b0>>
>>>fun(1,2)
Array(3, dtype=int32, weak_type=True)

 vmap

jax.vmap(funin_axes=0out_axes=0axis_name=Noneaxis_size=Nonespmd_axis_name=None)[source]

注:对函数进行向量化处理,通常用于批量处理数据,而不需要显式地编写循环,函数映射调用,区别于pmap,vmap单个设备(CPU或GPU)上处理批量数据,pmap在多个设备(GPU或TPU)上并行处理数据(分布式)

>>>f_xy = lambda x,y : x + y
>>>x = jax.numpy.array([[1, 2], 
                        [3, 4]])  # shape (2, 2)
>>>y = jax.numpy.array([[5, 6], 
                        [7, 8]])  # shape (2, 2)

# in this x and y array, axis 0 is row , axis 1 is col, ref shape index
# in x and y, axis -1 is shape[-1] , axis -2 is shape[-2]

>>>jax.vmap(f_xy,in_axes=(0,0))(x,y)      # default out_axes = 0,row ouput
# x row + y row , need x row dim equal y row dim
Array([[ 6,  8],
       [10, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(0,0),out_axes=1)(x,y) #show output by col
Array([[ 6,  8],
       [10, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(0,1))(x,y) 
# x row + y col , need x row's dim equal y col's dim
Array([[ 6,  9],
       [ 9, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(0,1),out_axes=1)(x,y) #show output by col 
Array([[ 6,  9],
       [ 9, 12]], dtype=int32)
>>>jax.vmap(f_xy,in_axes=(None,0))(x,y) #no vector x by row or col, x is block
# x block + y row vector, x shape (2,2) , y shape(2,2), need x row equal y row
# return shape(y_dim_2,x_dim_1,x_dim2)
Array([[[ 6,  8],
        [ 8, 10]],
       [[ 8, 10],
        [10, 12]]], dtype=int32)

ref:Learning about JAX :axes in vmap()

Array

reshape

abstract Array.reshape(*argsorder='C')[source]

注:Array对象的实例方法,引用jax.numpy.reshape函数

Random

PRNGKey

jax.random.PRNGKey(seed*impl=None)[source]#

注:创建一个 PRNG key,作为生成随机数的种子Seed

eg:       

>>>jax.random.PRNGKey(0)
Array([0, 0], dtype=uint32)

uniform

jax.random.uniform(keyshape=()dtype=<class 'float'>minval=0.0maxval=1.0)[source]

注:在给定的形状(shape)和数据类型(dtype)下,从 [minval, maxval) 区间内采样均匀分布的随机值

>>>k = jax.random.PRNGKey(0)
>>>jax.random.uniform(k,shape=(1,))
Array([0.41845703], dtype=float32)

normal

normal(keyshape=()dtype=<class 'float'>)[source]

注:在给定的形状shape和浮点数据类型dtype下,采样标准正态分布的随机值

>>>k = jax.random.PRNGKey(0)
>>>jax.random.normal(k,shape=(1,))
Array([-0.20584226], dtype=float32)

split

jax.random.split(keynum=2)[source]

注:用于生成伪随机数生成器(PRNG)状态的函数。它允许你从一个现有的 PRNG 状态中生成多个新的状态,从而实现随机数的可重复性和并行性。 

>>>k = jax.random.PRNGKey(1)
>>>k1,k2 = jax.random.split(k)
>>>k1
Array([2441914641, 1384938218], dtype=uint32)
>>>k2
Array([3819641963, 2025898573], dtype=uint32)

 choice

jax.random.choice(keyashape=()replace=Truep=Noneaxis=0)[source]

注:从给定数组a中按shape生成随机样本,区别于numpy.random.choice函数。default choice one elem。

>>>k = jax.random.PRNGKey(0)
>>>a = jax.numpy.array([1,2,3,4,5,6,7,8,9,0])
>>>jax.random.choice(k,a,(10,)) # random no seq
Array([9, 6, 8, 7, 8, 4, 1, 2, 3, 3], dtype=int32)
>>>jax.random.choice(k,a,(2,5))
Array([[9, 6, 8, 7, 8],
       [4, 1, 2, 3, 3]], dtype=int32)

Numpy

expand_dims

expand_dims(aaxis)[source]

注:为数组a的维度axis增加1维度

>>>arr = jax.numpy.array([1,2,3])
>>>arr.shape
(3,)
>>>jax.numpy.expand_dims(arr,axis=0)
Array([[1, 2, 3]], dtype=int32)
>>>jax.numpy.expand_dims(arr,axis=0).shape
(1, 3)
>>>jax.numpy.expand_dims(arr,axis=1)
Array([[1],
       [2],
       [3]], dtype=int32)
>>>jax.numpy.expand_dims(arr,axis=1).shape
(3, 1)

linspace

linspace(start: ArrayLikestop: ArrayLikenum: int = 50endpoint: bool = Trueretstep: Literal[False] = Falsedtype: DTypeLike | None = Noneaxis: int = 0*device: xc.Device | Sharding | None = None) → Array[source]

注:在给定区间[start,stop]内返回均匀间隔的数字

>>>jax.numpy.linspace(0,1,5)
Array([0.  , 0.25, 0.5 , 0.75, 1.  ], dtype=float32)

jax.numpy.linalg[pkg]

jax.numpy.linalg 是 JAX 库中用于线性代数操作的模块,对应numpy.linalg库实现

        jax.numpy.linalg.cholesky(a*upper=False)[source]

        注:计算一个正定矩阵A的 Cholesky 分解,得到满足A=L@L.T等式的下三角或上三角矩阵L,@为Python1.5定义的矩阵乘运算(jax.numpy.matmul),L.T为L转置矩阵L^{T}

>>> d = jax.numpy.array([[2. , 1.],
                         [1. , 2.]])
>>>jax.numpy.linalg.cholesky(d)
Array([[1.4142135 , 0.        ],
       [0.70710677, 1.2247449 ]], dtype=float32)

>>>L = jax.numpy.linalg.cholesky(d)
>>>L@L.T
Array([[1.9999999 , 0.99999994],
       [0.99999994, 2.        ]], dtype=float32)

dot

dot(ab*precision=Nonepreferred_element_type=None)[source]

注:用于计算两个数组的点积(dot product),对于一维数组,它计算的是向量的内积;对于二维数组(矩阵),它计算的是矩阵乘积;对于更高维度的数组,它执行的是逐元素的点积,并在最后一个轴上进行求和

  • 对于一维数组(向量)numpy.dot(a, b) 计算的是向量 a 和 b 的点积,结果是一个标量。
  • 对于二维数组(矩阵)numpy.dot(A, B) 计算的是矩阵 A 和 B 的乘积,其中 A 的列数必须与 B 的行数相等。结果是一个新的矩阵。
  • 对于更高维度的数组numpy.dot() 可以进行更复杂的广播和求和运算,但通常用于计算张量积(tensor product)的某个维度上的和。
>>>jax.numpy.dot(jax.numpy.array([1,2,3]),2)
Array([2, 4, 6], dtype=int32)
>>>jax.numpy.dot(jax.numpy.array([1,2,3]),jax.numpy.array([1,2,3]))
Array(14, dtype=int32)
>>>jax.numpy.dot(jax.numpy.array([[1,2,3],
                                  [4,5,6]]),
                  jax.numpy.array([1,2,3]))
Array([14, 32], dtype=int32)
>>>jax.numpy.dot(jax.numpy.array([[1,2],
                                  [4,5]]),
                 jax.numpy.array([[1,2],
                                  [4,5]]))
Array([[ 9, 12],
       [24, 33]], dtype=int32)
>>>a = jax.numpy.zeros((1,3,2))
>>>b = jax.numpy.zeros((1,2,4))
>>>jax.numpy.dot(a,b).shape
(1, 3, 1, 4) #matmul ret (1,3,4)

matmul

matmul(ab*precision=Nonepreferred_element_type=None)[source]#

注:于执行矩阵乘法,也称为 @ 运算符(在 Python 3.5+ 中引入),对于一维数组(向量),它计算的是内积(与 dot 相同);对于二维数组(矩阵),它计算的是矩阵乘积(与 dot 相同);对于更高维度的数组,它执行的是逐元素的矩阵乘法,并保留其他轴

  • 对于一维数组(向量)numpy.matmul(a, b) 通常不被定义为向量之间的运算,除非 a 是一个二维数组(表示多个向量)的单个行或列,并且 b 的形状与之兼容。
  • 对于二维数组(矩阵)numpy.matmul(A, B) 计算的是矩阵 A 和 B 的乘积,其中 A 的列数必须与 B 的行数相等。这与 numpy.dot() 对于二维数组的行为相同。
  • 对于更高维度的数组numpy.matmul() 遵循爱因斯坦求和约定(Einstein summation convention)的特定规则,允许在不同维度的数组之间执行矩阵乘法。这包括批处理矩阵乘法,其中每个批次独立地进行乘法运算。
>>>jax.numpy.matmul(jax.numpy.array([1,2,3]),jax.numpy.array([1,2,3]))
Array(14, dtype=int32)
>>>jax.numpy.matmul(jax.numpy.array([[1,2,3],
                                     [4,5,6]]),
                     jax.numpy.array([1,2,3]))
Array([14, 32], dtype=int32)
>>>jax.numpy.matmul(jax.numpy.array([[1,2],
                                     [4,5]]),
                    jax.numpy.array([[1,2],
                                     [4,5]]))
Array([[ 9, 12],
       [24, 33]], dtype=int32)
>>>a = jax.numpy.zeros((1,3,2))
>>>b = jax.numpy.zeros((1,2,4))
>>>jax.numpy.matmul(a,b).shape
(1, 3, 4) #dot ret (1,3,1,4)

arange

jax.numpy.arange(startstop=Nonestep=Nonedtype=None*device=None)[source]

注:default step 为1,在区间[start,stop)生成步长为1的数组,类似range函数

>>>jax.numpy.arange(0,10,1)
Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

interp 

interp(xxpfpleft=Noneright=Noneperiod=None)[source]

注:在xp点列表中线性插值x,线性插值满足y=y_{i}+\frac{y_{i+1}-y_{i}}{x_{i+1}-x_{i}}(x-x_{i}),x\epsilon [x_{i},x_{i+1}),xi和xi+1表示xp数组相邻两点,插值x位于两点区间之间,xp点对于y值为fp,线性插值为保持符合fp = fun(xp)两点区间斜率的增量

>>>xp = jax.numpy.arange(0,10,1)
>>>fp = jax.numpy.array(range(0,10,1)) * 2
>>>x = jax.numpy.array([1,2,3])
>>>jax.numpy.interp(x,xp,fp)
Array([2., 4., 6.], dtype=float32)

tile

jax.numpy.tile(Areps)[source]

注:将A数组按reps重复化生成新Array

a = jax.numpy.array([1,2,3])
>>>jax.numpy.tile(a,2)
Array([1, 2, 3, 1, 2, 3], dtype=int32)
>>>jax.numpy.tile(a,(2,))
Array([1, 2, 3, 1, 2, 3], dtype=int32)
>>>jax.numpy.tile(a,(1,1))
Array([[1, 2, 3]], dtype=int32)
>>>jax.numpy.tile(a,(2,1)) # repeat axis 0 (row) by 2, repeat axis 1 (col) by 1
Array([[1, 2, 3],
       [1, 2, 3]], dtype=int32)

reshape

jax.numpy.reshape(ashape=Noneorder='C'*newshape=Deprecatedcopy=None)[source]

注:从定义Array a的shape形状为shape元组(),支持-1,推断dim数值

>>>a = jax.numpy.array([[1, 2, 3],
                        [4, 5, 6]])
>>>jax.numpy.reshape(a,6) # equal reshape(a,(6,))
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>>jax.numpy.reshape(a,-1) # equal reshape(a,6)  -1 is inferred to be 3
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>>jax.numpy.reshape(a,(-1,2)) # equal reshape(a,(3,2)) , -1 is inferred to be 3
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)
>>>jax.numpy.reshape(a,(1,-1)) # not (n,) inferred to 2 d
Array([[1, 2, 3, 4, 5, 6]], dtype=int32)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2193034.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

国际象棋和大模型的内部世界 (2)

国际象棋和大模型的内部世界 &#xff08;2&#xff09; 最近一直在做大模型的一些实践和应用工作。最近看了一些agent的一些在大模型上的探索&#xff0c;包括基于大模型驱动的类似MUD类的游戏。 最近2篇论文都是基于国际象棋的&#xff0c;作者的思路基本上差不多&#xff0c…

SQL第12课——联结表

三点&#xff1a;什么是联结&#xff1f;为什么使用联结&#xff1f;如何编写使用联结的select语句 12.1 联结 SQL最强大的功能之一就是能在数据查询的执行中联结&#xff08;join)表。联结是利用SQL的select能执行的最重要的操作。 在使用联结前&#xff0c;需要了解关系表…

Vue组件库Element-ui

Vue组件库Element-ui Element是一套为开发者、设计师和产品经理准备的基于Vue2.0的桌面端组件库。Element - 网站快速成型工具 安装element-ui npm install element-ui # element-ui版本&#xff08;可以指定版本号引入ElementUI组件库&#xff0c;在main.js中添加内容得到&…

【动态规划-最长公共子序列(LCS)】力扣1035. 不相交的线

在两条独立的水平线上按给定的顺序写下 nums1 和 nums2 中的整数。 现在&#xff0c;可以绘制一些连接两个数字 nums1[i] 和 nums2[j] 的直线&#xff0c;这些直线需要同时满足&#xff1a; nums1[i] nums2[j] 且绘制的直线不与任何其他连线&#xff08;非水平线&#xff09…

graphql--快速了解graphql特点

graphql--快速了解graphql特点 1.它的作用2.demo示例2.1依赖引入2.2定义schema2.3定义GrapQL端点2.4运行测试2.5一些坑 今天浏览博客时看到graphQL,之前在招聘网站上第一次接触,以为是图数据查询语言, 简单了解后,发现对graphQL的介绍主要是用作API的查询语言,不仅限于图数据查…

dbeaver的使用

新增mysql连接 新增clickhouse 连接 新建编辑器 执行 结果&#xff0c;想看某条结果明细&#xff0c;选中某行安tab键 设置快捷键 窗口-》首选项-》用户界面-》键

ReentrantLock 实现原理

文章目录 ReentrantLock 基本使用可重入锁等待可中断设置超时时间公平锁条件变量 ReentrantLock 原理加锁流程解锁流程可重入锁原理可打断原理公平锁原理条件变量原理 ReentrantLock 基本使用 在Java中&#xff0c;synchronized 和 ReentrantLock 都是用于确保线程同步的锁&am…

JUPITER Benchmark Suite:是一套全面的23个基准测试程序,目的支持JUPITER——欧洲首台E级超级计算机的采购

2024-08-30&#xff0c;由于利希超级计算中心 创建JUPITER Benchmark Suite&#xff0c;这是一个全面的 23 个基准测试程序集合&#xff0c;经过精心记录和设计&#xff0c;目的支持购买欧洲第一台百万兆次级超级计算机 JUPITER。 一、研究背景&#xff1a; 随着E级超级计算机…

AI大模型有哪些,收藏起来防踩坑

大模型是指具有数千万甚至数亿参数的深度学习模型&#xff0c;通常由深度神经网络构建而成&#xff0c;拥有数十亿甚至数千亿个参数。大模型的设计目的是为了提高模型的表达能力和预测性能&#xff0c;能够处理更加复杂的任务和数据。以下是对大模型的详细数据与分析&#xff1…

在网页中渲染LaTex公式

概述 MathJax可以实现网页浏览器中的LaTex公式渲染。 引入 可以使用特定的模板形式引入和配置&#xff0c;具体可参考&#xff1a;配置mathjax — MathJax 3.2 文档 (osgeo.cn)。其中代码可以以CDN形式引入&#xff1a;mathjax (v3.2.2) -BootCDN。 <script> MathJax …

【C++驾轻就熟】vector深入了解及模拟实现

​ 目录 ​编辑​ 一、vector介绍 二、标准库中的vector 2.1vector常见的构造函数 2.1.1无参构造函数 2.1.2 有参构造函数&#xff08;构造并初始化n个val&#xff09; 2.1.3有参构造函数&#xff08;使用迭代器进行初始化构造&#xff09; 2.2 vector iterator 的使…

集全CNS!西北农林发表建校以来第一篇Nature

9月25日&#xff0c;国际学术期刊《自然》&#xff08;Nature&#xff09;在线发表了西北农林科技大学青年科学家岳超研究员领衔的团队题为《极端森林大火放大火后地表升温》的研究成果。该研究首次从林火规模这一独特视角&#xff0c;揭示了极端大火对生态系统破坏性、林火碳排…

受电端取电快充协议芯片的工作原理

随着电池技术的不断进步&#xff0c;快充技术应运而生&#xff0c;它以惊人的速度解决了“电量焦虑”成为手机技术发展的重要程碑。 快充技术&#xff0c;通过提高充电功率&#xff0c;大幅度缩短手机等设备充电时间的技术。相对于传统的慢充方式&#xff0c;快充技术能够在短…

ASP.NET MVC 下拉框的传值-foreach循环

数据表&#xff1a; -- 创建包裹分类表 CREATE TABLE PackageCategories (CategoryID INT PRIMARY KEY IDENTITY(1,1), -- 分类ID&#xff1a;整数类型&#xff0c;主键&#xff0c;自增&#xff0c;包裹分类的唯一标识CategoryName NVARCHAR(255) NOT NULL -- 包裹分类名称&a…

从零开始讲PCIe(11)——数据链路层介绍

一、概述 数据链路层这一层的逻辑是用来负责链路管理的&#xff0c;它主要表现为 3 个功能TLP 错误纠正、流量控制以及一些链路电源管理。它是通过如图 2-24 所示的DLLP&#xff08;Data Link Layer Packet&#xff09;来完成这些功能的。 二、DLLPs 数据链路层包&#xff08;D…

基于Springboot+Vue的在线问诊系统的设计与实现(含源码数据库)

1.开发环境 开发系统:Windows10/11 架构模式:MVC/前后端分离 JDK版本: Java JDK1.8 开发工具:IDEA 数据库版本: mysql5.7或8.0 数据库可视化工具: navicat 服务器: SpringBoot自带 apache tomcat 主要技术: Java,Springboot,mybatis,mysql,vue 2.视频演示地址 3.功能 系统中…

PDFToMarkdown

pdf转markdown 安装Tesseract-OCR项目拉取pytorch安装开始转换转换单个文件转换多个文件总结github开源PDF转markdown git clone https://github.com/VikParuchuri/marker.git 注意该项目有些包的语法需要python3.10,所以需要安装python3.10. 导入pycharm,下面选择取消 安…

Git分支-团队协作以及GitHub操作

Git分支操作 在版本控制过程中&#xff0c;同时推进多个任务> 程序员开发与开发主线并行&#xff0c;互不影响 分支底层也是指针的引用 hot-fix:相当于若在进行分支合并后程序出现了bug和卡顿等现象&#xff0c;通过热补丁来进行程序的更新&#xff0c;确保程序正常运行 常…

【Conda】Conda命令详解:高效更新与环境管理指南

目录 1. Conda 更新命令1.1 更新 Conda 核心1.2 更新所有包 2. 严格频道优先级3. 强制安装特定版本4. 创建与管理环境4.1 创建新环境4.2 激活和停用环境4.3 导出和导入环境4.4 删除环境 5. 清理缓存总结 Conda 是一个强大的包管理和环境管理工具&#xff0c;广泛应用于数据科学…

Linux中环境变量

基本概念 环境变量Environmental variables一般是指在操作系统中用来指定操作系统运行环境一些参数。 我们在编写C、C代码时候&#xff0c;在链接的时候从来不知道我们所链接的动态、静态库在哪里。但是还是照样可以链接成功。生成可执行程序。原因就是相关环境变量帮助编译器…