编译 MXNet 模型

news2024/11/29 7:57:25

本篇文章译自英文文档 Compile MXNet Models。

作者是 Joshua Z. Zhang,Kazutaka Morita。

更多 TVM 中文文档可访问 →TVM 中文站。

本文将介绍如何用 Relay 部署 MXNet 模型。

首先安装 mxnet 模块,可通过 pip 快速安装:

pip install mxnet --user

或参考官方安装指南:https://mxnet.apache.org/versions/master/install/index.html

# 一些标准的导包
import mxnet as mx
import tvm
import tvm.relay as relay
import numpy as np

从 Gluon Model Zoo 下载 Resnet18 模型

本节会下载预训练的 imagenet 模型,并对图像进行分类。

from tvm.contrib.download import download_testdata
from mxnet.gluon.model_zoo.vision import get_model
from PIL import Image
from matplotlib import pyplot as plt

block = get_model("resnet18_v1", pretrained=True)
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_name = "cat.png"
synset_url = "".join(
    [
        "https://gist.githubusercontent.com/zhreshold/",
        "4d0b62f3d01426887599d4f7ede23ee5/raw/",
        "596b27d23537e5a1b5751d2b0481ef172f58b539/",
        "imagenet1000_clsid_to_human.txt",
    ]
)
synset_name = "imagenet1000_clsid_to_human.txt"
img_path = download_testdata(img_url, "cat.png", module="data")
synset_path = download_testdata(synset_url, synset_name, module="data")
with open(synset_path) as f:
    synset = eval(f.read())
image = Image.open(img_path).resize((224, 224))
plt.imshow(image)
plt.show()

def transform_image(image):
    image = np.array(image) - np.array([123.0, 117.0, 104.0])
    image /= np.array([58.395, 57.12, 57.375])
    image = image.transpose((2, 0, 1))
    image = image[np.newaxis, :]
    return image

x = transform_image(image)
print("x", x.shape)

在这里插入图片描述
输出结果:

Downloading /workspace/.mxnet/models/resnet18_v1-a0666292.zip08d19deb-ddbf-4120-9643-fcfab19e7541 from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/resnet18_v1-a0666292.zip...
x (1, 3, 224, 224)

编译计算图

只需几行代码,即可将 Gluon 模型移植到可移植计算图上。mxnet.gluon 支持 MXNet 静态图(符号)和 HybridBlock。

shape_dict = {"data": x.shape}
mod, params = relay.frontend.from_mxnet(block, shape_dict)
## 添加 softmax 算子来提高概率
func = mod["main"]
func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs)

接下来编译计算图:

target = "cuda"
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(func, target, params=params)

输出结果:

/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
  "target_host parameter is going to be deprecated. "

在 TVM 上执行可移植计算图

接下来用 TVM 重现相同的前向计算:

from tvm.contrib import graph_executor

dev = tvm.cuda(0)
dtype = "float32"
m = graph_executor.GraphModule(lib["default"](dev))
# 设置输入
m.set_input("data", tvm.nd.array(x.astype(dtype)))
# 执行
m.run()
# 得到输出
tvm_output = m.get_output(0)
top1 = np.argmax(tvm_output.numpy()[0])
print("TVM prediction top-1:", top1, synset[top1])

输出结果:

TVM prediction top-1: 282 tiger cat

使用带有预训练权重的 MXNet 符号

MXNet 常用 arg_params 和 aux_params 分别存储网络参数,下面将展示如何在现有 API 中使用这些权重:

def block2symbol(block):
    data = mx.sym.Variable("data")
    sym = block(data)
    args = {}
    auxs = {}
    for k, v in block.collect_params().items():
        args[k] = mx.nd.array(v.data().asnumpy())
    return sym, args, auxs

mx_sym, args, auxs = block2symbol(block)
# 通常将其保存/加载为检查点
mx.model.save_checkpoint("resnet18_v1", 0, mx_sym, args, auxs)
# 磁盘上有 "resnet18_v1-0000.params" 和 "resnet18_v1-symbol.json"

对于一般性 MXNet 模型:

mx_sym, args, auxs = mx.model.load_checkpoint("resnet18_v1", 0)
# 用相同的 API 来获取 Relay 计算图
mod, relay_params = relay.frontend.from_mxnet(mx_sym, shape_dict, arg_params=args, aux_params=auxs)
# 重复相同的步骤,用 TVM 运行这个模型

下载 Python 源代码:from_mxnet.py

下载 Jupyter Notebook:from_mxnet.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/528306.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

4、picodet 小目标训练全流程

文章目录 1、数据准备1.1 VOC转COCO2、使用sahi切图2.1 切图分析及过程可视化2.2 使用完整的切图命令进行切图2.3 对各个数据集的状态进行查看2.4 过滤数据集中不合适的框 3、转换成VOC4、生成训练数据5、模型训练6、模型推理 使用picodet进行小目标检测。 本文以检测小目标乒乓…

索洛模型(二)

索洛模型(二) 文章目录 索洛模型(二)[toc]1 事实2 假设2.1 对生产函数的假设2.2对投入要素的假设 3 索洛模型的动态学3.1 k k k的动态学3.2 平衡增长路径 4 储蓄率变化的影响4.1 对产出的影响4.2 对消费的影响 索罗经济增长模型(Solow growth model)&am…

ClickHouse 安装部署

文章目录 ClickHouse 安装部署一、准备环节1、确认防火墙是在关闭状态2、CentOS 取消打开文件数限制3、安装依赖4、CentOS 取消 SELINUX 二、单机搭建三、启动server ClickHouse 安装部署 一、准备环节 1、确认防火墙是在关闭状态 输入命令: systemctl status fi…

Centos7.6系统里安装Superset,连接ClickHouse

​ 本文是在centos 7 虚拟机中安装Superset和clickhouse,首先要有 安装python3环境 Centos7.6默认有python2,要先安装python3,下边这个python3安装教程很详细。 参考连接:CentOS7下安装Python3,超详细完整教程_centos…

使用vercel免费搭建vue项目

之前是通过Github作为服务器来发布静态网站,今天有人告诉我,这里有一个叫vercel的商家可以直接白嫖,来试试给他上一课。 1 注册账号 进入官网vercel.com进行注册,并且绑定自己的 Github 2 项目代码 若是自己的项目就不用管; 不是…

夏令营教育小程序开发功能和优势有哪些?

随着人们生活水平的提高,对于孩子的教育问题也是越来越重视,无论是教育方式还是教育内容上都追求新颖、多样化。在暑假期间,很多家长也希望孩子能够在这个长假期之间参加一些活动,培养孩子兴趣的同时也丰富假期内容,让…

【云原生进阶之PaaS中间件】第一章Redis-2.1架构综述

1 Redis组件模型 Redis 组件的系统架构如图所示,主要包括事件处理、数据存储及管理、用于系统扩展的主从复制/集群管理,以及为插件化功能扩展的 Module System 模块。 Redis的客户端与服务端的交互过程如下所示: 1.1 事件处理机制 Redis 中的…

21天学会C++:Day3----缺省参数

CSDN的uu们,大家好。这里是C入门的第三讲。 座右铭:前路坎坷,披荆斩棘,扶摇直上。 博客主页: 姬如祎 收录专栏:C专题 目录 1. 知识引入 2. 缺省参数知识点 2.1 全缺省 2.2 半缺省 2.3 函数定义给缺…

MySQL 数据库 高可用 MAH

概述 什么是 MHA MHA(Master High Availability)是一套优秀的MySQL高可用环境下故障切换和主从复制的软件。 MHA 的出现就是解决MySQL 单点的问题。 MySQL故障切换过程中,MHA能做到0-30秒内自动完成故障切换操作。 MHA能在故障切换的过程中最…

数据结构-堆和堆排序-TopK问题

内容总览 1.堆的定义2.堆的实现接口(大堆)2.1 堆结构体定义2.2 堆的初始化与销毁2.3 堆的向上调整算法和插入2.4 堆的向下调整算法和删除堆顶元素2.5 堆的其他接口(调整堆递归版本) 3.建堆效率问题分析3.1 向上建堆3.2 向下建堆 4…

Java中的TCP (Android通用)

TCP服务端,创建了一个线程的接口 public class TCPServer implements Runnable {private static final String TAG "TCPServer";private String chaSet "UTF-8";private int port;private boolean isListen true;public TCPServer(int port)…

TypeScript 学习笔记 (学习中)

学习视频1:coderwhy 学习视频2:尚硅谷 文章目录 TypeScript 学习笔记概述TypeScript 开发环境搭建 类型注解类型推断 数据类型JS的7个原始类型Array数组object、Object 和 {}1.可选属性 ? 2.type 类型别名 和 接口interface函数TS类型: any类型 | unkno…

分享Python采集66个css3代码,总有一款适合您

分享Python采集66个css3代码,总有一款适合您 Python采集的66个css3代码下链接: 百度网盘 请输入提取码 提取码:mads css3svg炫酷水滴Loading特效 css剪裁GIF背景图片动画特效 纯CSS制作辛普森一家卡通人物动画特效 CSS3图片遮罩层变形…

1688商品详情数据采集技术,支持整站数据高并发采集

一、如何通过手动方式查看1688商品详情页面的数据 1.1688商品详情 API 接口(item_get - 获得1688商品详情接口),1688API 接口代码对接可以获取到宝贝 ID,宝贝标题,价格,优惠价,掌柜名称&a…

ArcSWAT报错:-2147217385;创建栅格数据集失败

文章目录 1 报错内容2 报错分析3 解决方案3.1 数据集路径错误3.2 数据格式不受支持3.3 文件访问权限问题 1 报错内容 此报错通常发生在建立了一个SWAT数据库后,执行Watershed Delineator中的Automatic Watershed Delineation操作中,在选择了DEM数据后弹出…

亚马逊云科技Amazon Compute Optimizer基础设施

亚马逊云科技Amazon Compute Optimizer如今推出了一项新功能,可以利用多个CPU架构(包括基于x86的实例和基于Amazon Graviton的实例)更轻松地优化EC2实例。Compute Optimizer是一项可选服务基础设施,可为工作负载推荐最佳Amazon资源…

Kali-linux使用OpenVAS

OpenVAS(开放式漏洞评估系统)是一个客户端/服务器架构,它常用来评估目标主机上的漏洞。OpenVAS是Nessus项目的一个分支,它提供的产品是完全地免费。OpenVAS默认安装在标准的Kali Linux上,本节将介绍配置及启动OpenVAS。…

Flink基础介绍-3 Time与Window

Flink基础介绍-3 Time与Window 三、流处理中的Time与Window3.1 Time3.2 window3.3 Window API3.4 Watermark 三、流处理中的Time与Window 3.1 Time Event Time:是事件创建的时间。它通常由事件中的时间戳描述,例如采集的日志数据中,每一条日…

SpringSecurity原理和实际应用

前提知识 认证:系统提供的用于识别用户身份的功能,通常提供用户名和密码进行登录其实就是在进行认证,认证的目的是让系统知道你是谁。 授权:用户认证成功后,需要为用户授权,其实就是指定当前用户可以操作哪…

Spring Resource接口 学习

Resource 接口是 Spring 资源访问策略的抽象,它本身并不提供任何资源访问实现,具体的资源访问由该接口的实现类完成——每个实现类代表一种资源访问策略。Resource一般包括这些实现类:UrlResource、ClassPathResource、FileSystemResource、S…