编译 Keras 模型

news2025/1/12 20:38:17

本篇文章译自英文文档 Compile Keras Models

作者是 Yuwei Hu

更多 TVM 中文文档可访问 →TVM 中文站。

本文介绍如何用 Relay 部署 Keras 模型。

首先安装 Keras 和 TensorFlow,可通过 pip 快速安装:

pip install -U keras --user
pip install -U tensorflow --user

或参考官网:https://keras.io/#installation

import tvm
from tvm import te
import tvm.relay as relay
from tvm.contrib.download import download_testdata
import keras
import tensorflow as tf
import numpy as np

加载预训练的 Keras 模型


加载 Keras 提供的预训练 resnet-50 分类模型:

if tuple(keras.__version__.split(".")) < ("2", "4", "0"):
    weights_url = "".join(
        [
            "https://github.com/fchollet/deep-learning-models/releases/",
            "download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5",
        ]
    )
    weights_file = "resnet50_keras_old.h5"
else:
    weights_url = "".join(
        [
            " https://storage.googleapis.com/tensorflow/keras-applications/",
            "resnet/resnet50_weights_tf_dim_ordering_tf_kernels.h5",
        ]
    )
    weights_file = "resnet50_keras_new.h5"

weights_path = download_testdata(weights_url, weights_file, module="keras")
keras_resnet50 = tf.keras.applications.resnet50.ResNet50(
    include_top=True, weights=None, input_shape=(224, 224, 3), classes=1000
)
keras_resnet50.load_weights(weights_path)

加载测试图像

这里使用的还是先前猫咪的图像:

from PIL import Image
from matplotlib import pyplot as plt
from tensorflow.keras.applications.resnet50 import preprocess_input

img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path).resize((224, 224))
plt.imshow(img)
plt.show()
# 预处理输入
data = np.array(img)[np.newaxis, :].astype("float32")
data = preprocess_input(data).transpose([0, 3, 1, 2])
print("input_1", data.shape)

请添加图片描述

输出结果:

input_1 (1, 3, 224, 224)

使用 Relay 编译模型

将 Keras 模型(NHWC 布局)转换为 Relay 格式(NCHW 布局):

shape_dict = {"input_1": data.shape}
mod, params = relay.frontend.from_keras(keras_resnet50, shape_dict)
# 编译模型
target = "cuda"
dev = tvm.cuda(0)

# TODO(mbs):opt_level=3 导致 nn.contrib_conv2d_winograd_weight_transform
# 很可能由于潜在的错误,最终出现在 cuda 上的内存验证失败的模块中。
# 注意:只能在 evaluate() 中传递 context,它不被 create_executor() 捕获。
with tvm.transform.PassContext(opt_level=0):
    model = relay.build_module.create_executor("graph", mod, dev, target, param).evaluate()

在 TVM 上执行

dtype = "float32"
tvm_out = model(tvm.nd.array(data.astype(dtype)))
top1_tvm = np.argmax(tvm_out.numpy()[0])

查找分类集名称

在 1000 个类的分类集中,查找分数最高的第一个:

synset_url = "".join(
    [
        "https://gist.githubusercontent.com/zhreshold/",
        "4d0b62f3d01426887599d4f7ede23ee5/raw/",
        "596b27d23537e5a1b5751d2b0481ef172f58b539/",
        "imagenet1000_clsid_to_human.txt",
    ]
)
synset_name = "imagenet1000_clsid_to_human.txt"
synset_path = download_testdata(synset_url, synset_name, module="data")
with open(synset_path) as f:
    synset = eval(f.read())
print("Relay top-1 id: {}, class name: {}".format(top1_tvm, synset[top1_tvm]))
# 验证 Keras 输出的正确性
keras_out = keras_resnet50.predict(data.transpose([0, 2, 3, 1]))
top1_keras = np.argmax(keras_out)
print("Keras top-1 id: {}, class name: {}".format(top1_keras, synset[top1_keras]))

输出结果:

Relay top-1 id: 285, class name: Egyptian cat
Keras top-1 id: 285, class name: Egyptian cat

下载 Python 源代码:from_keras.py

下载 Jupyter Notebook:from_keras.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/636815.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

代码随想录算法训练营第五十五天|392.判断子序列|115.不同的子序列

LeetCode392.判断子序列 动态规划五部曲&#xff1a; 1&#xff0c;确定dp数组&#xff08;dp table&#xff09;以及下标的含义&#xff1a;dp[i][j] 表示以下标i-1为结尾的字符串s&#xff0c;和以下标j-1为结尾的字符串t&#xff0c;相同子序列的长度为dp[i][j]。注意这里…

postman中级:导入文件数据,批量化参数

建议阅读对象&#xff1a;已掌握postman的基本使用&#xff08;参见&#xff1a;postman入门-主界面认识&#xff0c;模拟请求&#xff09; 本地安装的版本&#xff1a;Postman for Windows Version 10.14.9 1.创建csv文件 或 txt文件 文件数据格式&#xff1a; 第一行写下参…

python生成excel文件的三种方式

在我们做平常工作中都会遇到操作excel&#xff0c;那么今天写一篇&#xff0c;如何通过python操作excel。当然python操作excel的库有很多&#xff0c;比如pandas&#xff0c;xlwt/xlrd&#xff0c;openpyxl等&#xff0c;每个库都有不同的区别&#xff0c;具体的区别&#xff0…

day19--栈

用两个栈实现队列 栈&#xff1a;先进后出&#xff1b;队列&#xff1a;先进先出>因此两个栈即可模拟队列 class Solution { public:void push(int node) {//进队stack1.push(node);//进栈}int pop() {//出队int t;if(stack2.empty()){//栈2空while(!stack1.empty()){//栈1…

vue table页展示

<template><el-container><el-header><el-tabsv-model"groupId"tab-click"tabChange"class"w-full pt-11 ml-5"><el-tab-panelabel"登记进度"name"0"></el-tab-pane><el-tab-panela…

MarkDown使用教程

MarkDown使用教程 1.标题 #: 一级标题 ##: 二级标题 ###: 三级标题 一共分为六级 2.字体 斜体文本 斜体文本 粗体文本 粗体文本 粗斜体文本 粗斜体文本 3.列表 无序号的使用*、、- 作为列表的标记&#xff0c;这些标记后面添加一个空格 第一项第二项第三项 第一项第二项…

ArcGIS如何统计面内点的数量

本文来源&#xff1a;GIS科研实验室公众号 1 数据来介绍 本次教程使用的数据为&#xff1a;各小区的点坐标&#xff08;来源房天下&#xff0c;坐标为CGCS2000&#xff09;&#xff1b;基础教育设施、商业服务设施、金融保险设施、医疗卫生设施的POI坐标&#xff08;来源高德…

java:找不到符号 符号:变量:log get set

问题&#xff1a;java&#xff1a;找不到符号&#xff1a;变量&#xff1a;log get set解决方法&#xff1a;在idea中&#xff0c;点击file-Settings&#xff0c;打开配置页面&#xff0c;如图红框位置&#xff0c;输入&#xff1a; -Djps.track.ap.dependenciesfalse

pyecharts案例四——动态GDP柱状图绘制

思路 for循环每一年的数据&#xff0c;基于每一年的数据&#xff0c;创建每一年的Bar对象&#xff0c;并且将该对象添加到时间线timeline中&#xff0c;最后设置自动播放并绘图 实现代码 from pyecharts.charts import Bar, Timeline from pyecharts.options import * from …

(7)自动调优

文章目录 前言 1 在自动调优模式下飞行前的设置 2 如何调用自动调优 3 在位置保持下调用自动调优 4 如果自动调优失败 5 补充说明 6 常见的问题 7 Dataflash日志记录 8 地面控制站消息 前言 AutoTune 试图自动调优稳定P&#xff0c;速率P和 D&#xff0c;以及最大旋转…

交流(直流)电流采集方案

芯片原理图 注意途中的绿色部分&#xff0c;说明此芯片可以采集交流或者直流 内部霍尔工作原理图 通过曲线可以确定再0A时输出电压为2.5v 只有随着电流的变化是大于2.5v或者小于2.5v&#xff08;交流负方向&#xff0c;或者直流负方向&#xff09; 下面是一个插排的拆解视频截…

9. ThreadLocal

9.1 ThreadLocal简介 9.1.1 面试题 ● ThreadLocal中ThreadLocalMap的数据结构和关系 ● ThreadLocal的key是弱引用&#xff0c;这是为什么&#xff1f; ● ThreadLocal内存泄漏问题你知道吗&#xff1f; ● ThreadLocal中最后为什么要加remove方法&#xff1f; 9.1.2 是什么&a…

KaiwuDB 受邀亮相山东省数字化转型论坛

4月21日&#xff0c;第十五届信博会暨中国&#xff08;济南&#xff09;数字经济高端峰会成功举办。KaiwuDB 受邀出席峰会重要论坛—山东省数字化转型论坛&#xff0c;并发表《工业物联网时代&#xff0c;数据库赋能企业数字化转型落地实践》主题演讲&#xff0c;与来自国内的 …

Java动态代理:优化静态代理模式的灵活解决方案

文章目录 代理模式定义具体实现分析优缺点 优化使用动态代理解决优化相关知识动态代理种类场景应用 代理模式 定义 代理模式&#xff0c;为其他对象提供一种代理以控制对这个对象的访问 具体实现 代理模式的具体实现描述可以分为以下几个步骤&#xff1a; 创建抽象对象接…

什么是Vue的JSX语法?如何使用JSX语法?

什么是Vue的JSX语法&#xff1f;如何使用JSX语法&#xff1f; 在Vue中&#xff0c;我们通常使用模板语法来编写组件的模板。但是&#xff0c;有些开发者更喜欢使用类似于React的JSX语法来编写组件。Vue也支持使用JSX语法来编写组件&#xff0c;本文将介绍什么是Vue的JSX语法以…

企业级信息系统开发讲课笔记4.7 Spring Boot整合JPA

文章目录 零、学习目标一、Spring Data JPA概述1、Spring Data JPA简介2、Spring Data JPA基本使用3、使用Spring Data JPA进行数据操作的多种实现方式4、自定义Repository接口中的Transactional注解5、变更操作&#xff0c;要配合使用Query与Modify注解 二、Spring Boot整合JP…

热门图表软件推荐,哪款更功能更强大?

在如今的数据化时代&#xff0c;各种企业都需要有一套高效的报表制作工具。而图表是报表中最常用、也是最重要的一部分&#xff0c;因此选择一款优秀的图表软件显得尤为重要。本文将为大家介绍5款热门图表软件&#xff0c;并突出介绍VeryReport图表软件的优势。 1. VeryReport…

NetApp 全闪存 ASA 系统可为您的任务关键型企业级应用程序、数据库和 VMware 基础架构提供简单专用的块存储

NetApp ASA&#xff1a;全闪存 SAN 阵列 在性能和效率之间进行艰难抉择的时代已经过去。NetApp ASA 系统提供简单专用的块存储&#xff0c;具有卓越的性能、高可用性和领先的效率 — 无需权衡取舍。 为什么选择适用于 SAN 的 NetApp ASA 系统&#xff1f; 简单的 SAN 存储&…

DCL单例及synchrosized问题

疑问待解&#xff1a; 1 synchronized代码块执行完后&#xff0c;在没有return INSTANCE之前&#xff0c;其他线程是否可见这个对象&#xff08;因为synchronized出块后会把工作内存写到主存&#xff09;&#xff1f; 如果可见&#xff0c;那么return的作用是不是可有可无&…

object类型(equals、hashCode、getClass、getName)

equals方法的改写 Override//重写equals方法&#xff0c;重写方法后对比的属性值&#xff08;没有重写前对比的是属性值&#xff09;public boolean equals(Object obj) {Students s (Students) obj;return this.name.equals(s.name) && this.age s.age;}public clas…