MLX vs MPS vs CUDA:苹果新机器学习框架的基准测试

news2025/1/10 22:24:55

如果你是一个Mac用户和一个深度学习爱好者,你可能希望在某些时候Mac可以处理一些重型模型。苹果刚刚发布了MLX,一个在苹果芯片上高效运行机器学习模型的框架。

最近在PyTorch 1.12中引入MPS后端已经是一个大胆的步骤,但随着MLX的宣布,苹果还想在开源深度学习方面有更大的发展。

在本文中,我们将对这些新方法进行测试,在三种不同的Apple Silicon芯片和两个支持cuda的gpu上和传统CPU后端进行基准测试。

这里把基准测试集中在图卷积网络(GCN)模型上。这个模型主要由线性层组成,所以对于其他的模型也应该得到类似的结果。

创造环境

要为MLX构建环境,我们必须指定是使用i386还是arm架构。使用conda,可以使用:

 CONDA_SUBDIR=osx-arm64 conda create -n mlx python=3.10 numpy pytorch scipy requests -c conda-forge
 conda activate mlx

如果检查你的env是否实际使用了arm,下面命令的输出应该是arm,而不是i386(因为我们用的Apple Silicon):

 python -c "import platform; print(platform.processor())"

然后就是使用pip安装MLX:

 pip install mlx

GCN模型

GCN模型是图神经网络(GNN)的一种,它使用邻接矩阵(表示图结构)和节点特征。它通过收集邻近节点的信息来计算节点嵌入。每个节点获得其邻居特征的平均值。这种平均是通过将节点特征与标准化邻接矩阵相乘来完成的,并根据节点度进行调整。为了学习这个过程,特征首先通过线性层投射到嵌入空间中。

我们将使用MLX实现一个GCN层和一个GCN模型:

 import mlx.nn as nn
 
 class GCNLayer(nn.Module):
     def __init__(self, in_features, out_features, bias=True):
         super(GCNLayer, self).__init__()
         self.linear = nn.Linear(in_features, out_features, bias)
 
     def __call__(self, x, adj):
         x = self.linear(x)
         return adj @ x
 
 class GCN(nn.Module):
     def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
         super(GCN, self).__init__()
 
         layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]
         self.gcn_layers = [
             GCNLayer(in_dim, out_dim, bias)
             for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])
         ]
         self.dropout = nn.Dropout(p=dropout)
 
     def __call__(self, x, adj):
         for layer in self.gcn_layers[:-1]:
             x = nn.relu(layer(x, adj))
             x = self.dropout(x)
 
         x = self.gcn_layers[-1](x, adj)
         return x

可以看到,mlx的模型开发方式与tf2基本一样,都是调用

__call__

进行前向传播,其实torch也一样,只不过它自定义了一个forward函数。

下面就是训练

 gcn = GCN(
     x_dim=x.shape[-1],
     h_dim=args.hidden_dim,
     out_dim=args.nb_classes,
     nb_layers=args.nb_layers,
     dropout=args.dropout,
     bias=args.bias,
 )
 mx.eval(gcn.parameters())
 
 optimizer = optim.Adam(learning_rate=args.lr)
 loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)
 
 # Training loop
 for epoch in range(args.epochs):
 
     # Loss
     (loss, y_hat), grads = loss_and_grad_fn(
         gcn, x, adj, y, train_mask, args.weight_decay
     )
     optimizer.update(gcn, grads)
     mx.eval(gcn.parameters(), optimizer.state)
 
     # Validation
     val_loss = loss_fn(y_hat[val_mask], y[val_mask])
     val_acc = eval_fn(y_hat[val_mask], y[val_mask])

在MLX中,计算是惰性的,这意味着eval()通常用于在更新后实际计算新的模型参数。而另一个关键函数是nn.value_and_grad(),它生成一个计算参数损失的函数。第一个参数是保存当前参数的模型,第二个参数是用于前向传递和损失计算的可调用函数。它返回的函数接受与forward函数相同的参数(在本例中为forward_fn)。我们可以这样定义这个函数:

 def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
     y_hat = gcn(x, adj)
     loss = loss_fn(y_hat[train_mask], y[train_mask], weight_decay, gcn.parameters())
     return loss, y_hat

它仅仅包括计算前向传递和计算损失。Loss_fn()和eval_fn()定义如下:

 def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
     l = mx.mean(nn.losses.cross_entropy(y_hat, y))
 
     if weight_decay != 0.0:
         assert parameters != None, "Model parameters missing for L2 reg."
 
         l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()
         return l + weight_decay * l2_reg
 
     return l
 
 def eval_fn(x, y):
     return mx.mean(mx.argmax(x, axis=1) == y)

损失函数是计算预测和标签之间的交叉熵,并包括L2正则化。由于L2正则化还不是内置特性,需要手动实现。

本文的完整代码:https://github.com/TristanBilot/mlx-GCN

可以看到除了一些细节函数调用的差别,基本的训练流程与pytorch和tf都很类似,但是这里的一个很好的事情是消除了显式地将对象分配给特定设备的需要,就像我们在PyTorch中经常使用.cuda()和.to(device)那样。这是因为苹果硅芯片的统一内存架构,所有变量共存于同一空间,也就是说消除了CPU和GPU之间缓慢的数据传输,这样也可以保证不会再出现与设备不匹配相关的烦人的运行时错误。

基准测试

我们将使用MLX与MPS, CPU和GPU设备进行比较。我们的测试平台是一个2层GCN模型,应用于Cora数据集,其中包括2708个节点和5429条边。

对于MLX, MPS和CPU测试,我们对M1 Pro, M2 Ultra和M3 Max进行基准测试。在两款NVIDIA V100 PCIe和V100 NVLINK上进行测试

MPS:比M1 Pro的CPU快2倍以上,在其他两个芯片上,与CPU相比有30-50%的改进。

MLX:比M1 Pro上的MPS快2.34倍。与MPS相比,M2 Ultra的性能提高了24%。在M3 Pro上MPS和MLX之间没有真正的改进。

CUDA V100 PCIe & NVLINK:只有23%和34%的速度比M3 Max与MLX,这里的原因可能是因为我们的模型比较小,所以发挥不出V100和NVLINK的优势(NVLINK主要GPU之间的数据传输大的情况下会有提高)。这也说明了苹果的统一内存架构的确可以消除CPU和GPU之间缓慢的数据传输。

总结

与CPU和MPS相比,MLX可以说是非常大的金币,在小数据量的情况下它甚至接近特斯拉V100的性能。也就是说我们可以使用MLX跑一些不是那么大的模型,比如一些表格数据。

从上面的基准测试也可以看到,现在可以利用苹果芯片的全部力量在本地运行深度学习模型(我一直认为MPS还没发挥苹果的优势,这回MPS已经证明了这一点)。

MLX刚刚发布就已经取得了惊人的影响力,并展示了巨大的潜力。相信未来几年开源社区的进一步增强,可以期待在不久的将来更强大的苹果芯片,将MLX的性能提升到一个全新的水平。

另外也说明了MPS(虽然也发布不久)还是有巨大的发展空间的,毕竟切换框架是一件很麻烦的事情,如果MPS能达到MLX 80%或者90%的速度,我想不会有人去换框架的。

最后说到框架,现在已经有了Pytorch,TF,JAX,现在又多了一个MLX。各种设备、各种后端包括:TPU(pytorch使用的XLA),CUDA,ROCM,现在又多了一个MPS。

https://avoid.overfit.cn/post/eb87d12f29eb4665adb43ad59fd3d64f

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1325540.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Catboost算法助力乳腺癌预测:Shap值解析关键预测因素

一、引言 乳腺癌是一种常见的恶性肿瘤,对女性健康和生命造成严重威胁。乳腺癌的预测和治疗是当前研究的热点和难点。传统的预测方法主要基于临床病理学特征,但准确率有待提高。随着机器学习技术的发展,数据驱动的预测方法逐渐受到关注。Catbo…

ChatGPT如何计算token数?

GPT 不是适用于某一门语言的大型语言模型,它适用于几乎所有流行的自然语言。所以 GPT 的 token 需要 兼容 几乎人类的所有自然语言,那意味着 GPT 有一个非常全的 token 词汇表,它能表达出所有人类的自然语言。如何实现这个目的呢?…

EDI能够为企业间信息传输带来哪些帮助?

EDI全称电子数据交换,它的历史可以追溯到 1960 年代,从供应链和物流到医疗保健和金融都可以看到EDI的身影,没有任何行业限制。尽管EDI技术已经在国外得到广泛使用,国内自研EDI软件产品也已有十余年的历史,但国内目前的…

大数据----33.hbase中的shell文件操作

HBase的命令行工具,最简单的接口,适合HBase管理使用,可以使用shell命令来查询HBase中数据的详细情况。 注意:如果进入hbase后长时间不操作; 发生hbase自动关闭没有了进程; 原因是内存不够;可以关…

leetCode算法—11. 盛最多水的容器

11.给定一个长度为 n 的整数数组 height 。有 n 条垂线,第 i 条线的两个端点是 (i, 0) 和 (i, height[i]) 。 难度:中等 ** 找出其中的两条线,使得它们与 x 轴共同构成的容器可以容纳最多的水。 返回容器可以储存的最大水量。 说明&#x…

Deap 框架细节介绍

创建一个 gp.PrimitiveSet 对象&#xff0c;对象名为 MAIN&#xff0c;自变量为 3 pset gp.PrimitiveSet("MAIN", 3) print(pset)<deap.gp.PrimitiveSet object at 0x000001FBE182AB20>gp.py&#xff08;均为产生函数集与节点集&#xff09; PrimitiveSet …

State of PostgreSQL 2023 报告解读

基于 PostgreSQL 内核的时序数据库厂商 Timescale 发布了一年一度的 State of Postgres 2023 报告。 Timescale 介绍 简单先介绍一下 Timescale 这家公司的历史。它最早是提供了一个 PG 的插件&#xff0c;引入了 Hypertable 这个概念&#xff0c;来高效地处理时序数据&…

分享一个项目——Sambert UI 声音克隆

文章目录 前言一、运行ipynb二、数据标注三、训练四、生成总结 前言 原教程视频 项目链接 运行一个ipynb&#xff0c;就可操作 总共四步 1&#xff09;运行ipynb 2&#xff09;数据标注 3&#xff09;训练 4&#xff09;生成 一、运行ipynb 等运行完毕后&#xff0c;获得该…

mt5和mt4交易软件有什么区别?

MetaTrader 4&#xff08;MT4&#xff09;和MetaTrader 5&#xff08;MT5&#xff09;是两种广泛使用的外汇和金融市场交易平台&#xff0c;由MetaQuotes公司开发。尽管它们都是外汇交易的常见选择&#xff0c;但在功能和特性上存在一些区别。以下是MT4和MT5之间的主要区别&…

2023 英特尔On技术创新大会直播 |我感受到的“芯”魅力

文章目录 每日一句正能量前言AI时代&#xff0c;云与PC结合为用户带来更好体验全新处理器&#xff0c;首次引入针对人工智能加速的NPU大模型时代&#xff0c;软硬结合带来更好训练成果后记 每日一句正能量 成长是一条必走的路路上我们伤痛在所难免。 前言 在2023年的英特尔On技…

Go语言HTTP编程入门指南

如果你是一名开发者&#xff0c;那么你一定听说过Go语言。Go&#xff0c;也被称为Golang&#xff0c;是由Google开发的一种静态类型、编译型语言。它的设计理念是“简单、快速、高效”&#xff0c;这使得Go语言在许多方面都表现出色&#xff0c;尤其是在网络编程和并发编程方面…

fastGitHub工具推荐(如果打不开github或者使用很慢可以使用该工具)

目录 一&#xff0c;针对问题二&#xff0c;下载1&#xff0c;github里面下载FastGitHub2&#xff0c;博客上传了下载资源 三&#xff0c;安装使用点击执行文件即可 一&#xff0c;针对问题 当使用github很慢&#xff0c;或者根本打不开的时候&#xff0c;就可以使用该工具 …

HDFS NFS Gateway(环境配置,超级详细!!)

HDFS NFS Gateway简介: ​ HDFS NFS Gateway是Hadoop Distributed File System&#xff08;HDFS&#xff09;中的一个组件&#xff0c;它允许客户端通过NFS&#xff08;Network File System&#xff0c;网络文件系统&#xff09;与HDFS进行交互。具体来说&#xff0c;HDFS NFS…

搭建esp32-idf开发环境并烧入第一个程序

ESP32下载idf并烧入第一个程序 一.官网下载idf安装包二.安装idf三 .测试安装是否成功3.1进入idf控制台3.2 查看安装版本3.3 编译工程 四.下载程序4.1查看所在端口4.2下载程序4.3 监听串口 一.官网下载idf安装包 点击下载 如图&#xff1a; 我们选择离线下载&#xff0c;注意…

行业前景咋样?大厂找我用C++抓取化工产品数据并分析

最近又来活了&#xff0c;天天忙到半夜&#xff0c;但是想想收益还是再坚持坚持。是这么一回事&#xff0c;兄弟所在的化工公司最近想看看某些行业数据&#xff0c;看看市面的同类型产品销量收益等情况是否满足预期效果&#xff0c;也就找到我让我给用爬虫写一个采集并分析的报…

如何实现设备联网控制?

在工业自动化领域&#xff0c;设备联网控制已经成为一种趋势。通过设备联网&#xff0c;可以实现设备的远程监控和管理&#xff0c;提高设备的可用性和效率。本文将介绍如何实现设备联网控制。 设备如何联网&#xff1f; 使用网关联网&#xff1a; HiWoo Box是一款功能强大的…

Sectigo的ov多域名ssl证书

OV多域名SSL证书和EV多域名SSL证书都只支持企事业单位申请&#xff0c;但是EV多域名SSL证书审核比较严格&#xff0c;价格也比较高&#xff0c;OV多域名SSL证书能加密网站传输数据&#xff0c;也能对服务器身份进行认证。对于大多数事业单位&#xff0c;OV多域名SSL证书就能满足…

外贸业务员该如何写好一份有质感的年终总结?内附外贸大神例文

庄子云&#xff1a;人生天地之间&#xff0c;若白驹之过隙&#xff0c;忽然而已... 2023年注定是不平凡的一年&#xff0c;临近年终&#xff0c;你可能听到最多的关键词就是外贸有点难做。不管是因为什么&#xff0c;客观来说2023年的外贸之路确实不太平坦&#xff0c;最近胡塞…

车辆违规开启远光灯检测系统:融合YOLO-MS改进YOLOv8

1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 研究背景与意义 随着社会的不断发展和交通工具的普及&#xff0c;车辆违规行为成为了一个严重的问题。其中&#xff0c;车辆违规开启远光灯是一种常见的违规行为&#xff0c;给其…

快速能访问服务器的文件

1、背景 访问ubuntu上的文件 2、方法 python3 -m http.server 8081 --directory /home/ NAS 共享访问协议 — NFS、SMB、FTP、WebDAV 各有何优势&#xff1f;http://1 Ubuntu 搭建文件服务器&#xff08;Nginx&#xff09;