Tensorflow预训练模型ckpt与pb两种文件类型的介绍

news2024/9/25 1:20:35

我们在  Tensorflow无人车使用移动端的SSD(单发多框检测)来识别物体及Graph的认识 熟悉了Graph计算图以及在 Tensorflow2.0中function(是1.0版本的Graph的推荐替代)的相关知识介绍 这个tf.function的用法,了解到控制流与计算图的各自作用,无论使用哪种方式,在深度学习中,最关键的都将使用到预训练模型。

跟其他框架需要加载预训练模型一样,这里同样需要导入权重参数文件,也就是前面使用SSD来识别对象的代码,如下:

MODEL_NAME = 'ssdlite_mobilenet_v2_coco_2018_05_09'
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')

可以看到这里加载了两个文件:一个是pb文件,另外一个是pbtxt标签文件,其中pb文件,我们看它所在目录文件名称ssdlite_mobilenet_v2_coco_2018_05_09,也可以知道属于移动端的轻量级SSD预训练模型,而且使用的是COCO2018的数据集。
接下来就是着重对预训练模型文件操作的介绍。

1、保存ckpt模型

先来看一个示例,训练之后我们将其模型进行保存,这样做的目的是,不需要每次都重新训练,可以直接加载参数来推理了,也可以很方便的移植给其他程序使用。

import tensorflow.compat.v1 as tf 
tf.disable_eager_execution()

v1 = tf.Variable(tf.constant([[11],[22]]),name='v1')
v2 = tf.Variable(tf.constant([[33],[44]]),name='v2')
result = v1 * v2

saver = tf.train.Saver()
with tf.Session() as sess:
    # 初始化所有变量
    #tf.global_variables_initializer().run()
    sess.run(tf.global_variables_initializer())
    print(sess.run(v1))
    print(sess.run(v2))
    print(sess.run(result))
    #这个扩展名".ckpt"可以忽略
    saver.save(sess,'model/model.ckpt')
'''
[[11]
 [22]]
[[33]
 [44]]
[[363]
 [968]]
'''

如下图:

这样就通过save方法保存了预训练模型,一个标准的.ckpt模型文件包含以下文件:

checkpoint:文本文件,所以可以直接打开,内容如下:

model_checkpoint_path: "model.ckpt"和all_model_checkpoint_paths: "model.ckpt"

model.ckpt.data-00000-of-00001:保存变量的值
model.ckpt.index:保存变量的名称,可以跟上面的一起看做是Key-Value的形式
model.ckpt.meta:保存计算图的结构

我们来查看下模型里面变量的情况:

read_ckpt =tf.train.NewCheckpointReader("model/model.ckpt")
print(read_ckpt.debug_string().decode("utf-8"))
print(read_ckpt.get_variable_to_dtype_map())
print(read_ckpt.get_variable_to_shape_map())
'''
v1 (DT_INT32) [2,1]
v2 (DT_INT32) [2,1]
{'v1': tf.int32, 'v2': tf.int32}
{'v1': [2, 1], 'v2': [2, 1]}
'''

使用的是tf.train.NewCheckpointReader来读取模型文件,显示了变量名,数据类型,形状。后面两种方法分别获取的是变量类型与形状的字典类型。

2、加载ckpt模型

 上面将模型保存好了之后,我们来加载模型测试下:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

saver = tf.train.import_meta_graph('model/model.ckpt.meta')
with tf.Session() as sess:
    saver.restore(sess,'model/model.ckpt')
    out = tf.get_default_graph().get_tensor_by_name('mul:0')
    print(sess.run(out))
'''
INFO:tensorflow:Restoring parameters from model/model.ckpt
[[363]
 [968]]
'''

可以看到,通过tf.train.import_meta_graph首先加载计算图的结构。然后使用restore方法将模型文件恢复即可。
同样的也是需要临时禁用这个即时执行模式:tf.disable_eager_execution()
获取节点名称的方法,由于这里是做乘法运算,所以是mul,如果是加法就是add,当然如果是其他运算,想要知道名称也可以直接打印查看,比如在保存模型之前,可以先查看:

print(result)
Tensor("mul:0", shape=(2, 1), dtype=int32)

3、保存pb模型

上面的图片可以看到,保存的预训练模型有4个文件,各自分开的,显得比较的杂乱,我们可以将变量与权重值等,这些都一起写入到一个文件里面,就是前面文章中的SSD预训练模型使用到的pb文件。一起来看下,是如何保存为pb文件的

import tensorflow.compat.v1 as tf
from tensorflow.python.framework import graph_util
tf.disable_eager_execution()

v1 = tf.Variable(tf.constant([[1],[2]]),name='v1')
v2 = tf.Variable(tf.constant([[3],[4]]),name='v2')
result = v1 * v2

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    graph_def = tf.get_default_graph().as_graph_def()
    #print(graph_def)
    output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['mul'])
    with tf.gfile.GFile('model/newmodel.pb','wb') as f:
        f.write(output_graph_def.SerializeToString())

这里我们可以看到,跟保存ckpt文件还是有区别,首先获取计算图,然后使用graph_util.convert_variables_to_constants将变量转换成常量(这样就将网络架构和权重都保存在一个文件里面了),最后通过SerializeToString()将其转为字节流写入到文件。

4、加载pb模型

上面保存了pb模型文件之后,我们加载模型来测试下:

with tf.Session() as sess:
    with tf.io.gfile.GFile('model/newmodel.pb','rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    result = tf.import_graph_def(graph_def,return_elements=['mul:0'])
    print(sess.run(result))
'''
[array([[3],
       [8]])]
'''

同样的是使用tf.io.gfile.GFile方法,这里指定参数rb来读取这个模型文件,然后读取并解析成计算图,最后使用tf.import_graph_def将其导入到默认图即可,这里的Session里面参数如果没有指定计算图,就是默认图。
另外不管是保存还是读取,都是在tf.Session这个会话里面进行,需要计算节点的值,就通过sess.run()函数即可。

5、tf.placeholder占位符

我们在前面的文章打印了计算图,可以看到里面的op确实很多,而且每次的minibatch都将是一个op,这也是一种资源开销,所以我们使用占位符tf.placeholder来处理重复的操作,比如说,每次的minibatch传入到x = tf.placeholder(tf.float32,[None,32])上,在下一次传进来的x将直接替换掉上一次的x,不会重新产生新的op,这样就节省了开销。

import tensorflow.compat.v1 as tf
import numpy as np
tf.disable_eager_execution()

a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
output = tf.multiply(a, b)
 
with tf.Session() as sess:
    print(sess.run(output, feed_dict = {a:[12.], b: [3.5]}))#[42.]

 使用占位符,然后使用feed_dict喂入数据即可。

6、错误处理

RuntimeError: When eager execution is enabled, `var_list` must specify a list or dict of variables to save

运行时错误:当启用即时执行时,' var_list '必须指定要保存的变量列表或字典

我们可以通过禁用即时执行来解决,这个即时执行的意思就是Tensorflow会立即执行每个操作,而不是先建立计算图,这个eager execution在Tensorflow2.0版本开始是默认开启的,我这里为了便于介绍相关知识点,将使用1.0,并临时禁用,这样就会在运行时建立计算图。tf.disable_eager_execution()

RuntimeError: The Session graph is empty. Add operations to the graph before calling run().
运行时错误:会话图为空。在调用run()之前向图中添加操作。

这里也跟即时执行有关,同样的我们临时禁用:tf.disable_eager_execution()

另外 from tensorflow.python.framework.graph_util_impl 这个在2.0版本之后都将移除,被淘汰,上述代码也是为了做演示而兼容1.0版本做的测试。

RuntimeError: tf.placeholder() is not compatible with eager execution.
运行时错误:占位符与即时执行不兼容

同样的临时禁用即时执行:tf.disable_eager_execution()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/806184.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

向量vector与erase()

运行代码: //向量vector与erase() #include"std_lib_facilities.h" //声明Item类 struct Item {string name;int iid;double value;Item():name(" "),iid(0),value(0.0){}Item(string ss,int ii,double vv):name(ss),iid(ii),value(vv){}frien…

数据库原理1——《小猫猫大课堂》数据库原理篇

宝子,你不点个赞吗?不评个论吗?不收个藏吗? 最后的最后,关注我,关注我,关注我,你会看到更多有趣的博客哦!!! 喵喵喵,你对我真的很重要…

分布式锁中的王者方案 - Redission

文章目录 5.1 分布式锁-redission功能介绍5.2 分布式锁-Redission快速入门5.3 分布式锁-redission可重入锁原理5.4 分布式锁-redission锁重试和WatchDog机制5.5 分布式锁-redission锁的MutiLock原理 5.1 分布式锁-redission功能介绍 基于setnx实现的分布式锁存在下面的问题&am…

【如何训练一个中英翻译模型】LSTM机器翻译模型部署之onnx(python)(四)

系列文章 【如何训练一个中英翻译模型】LSTM机器翻译seq2seq字符编码(一) 【如何训练一个中英翻译模型】LSTM机器翻译模型训练与保存(二) 【如何训练一个中英翻译模型】LSTM机器翻译模型部署(三) 【如何…

极简每周计划应用程序WeekToDo

什么是 WeekToDo ? WeekToDo 是一款免费的极简每周计划应用程序,专注于隐私。使用待办事项列表和日历安排您的任务和项目。适用于 Windows、Mac、Linux 或在线。 WeekToDo 是一个免费且开源的极简每周计划程序。借助 WeekToDo,您可以以简单直观的方式定…

Matplotlib_绘制柱状图

绘制柱状图 🧩bar方法 bar()是Matplotlib.pyplot库中用于绘制条形图(bar chart)的函数。条形图是一种常见的数据可视化图表,用于显示不同类别之间的比较。 函数签名: matplotlib.pyplot.bar(x, height, width0.8, …

ICMP协议(网际报文控制协议)详解

ICMP协议(网际报文控制协议)详解 ICMP协议的功能ICMP的报文格式常见的ICMP报文差错报文目的站不可达数据报超时 查询报文回送请求或回答 ICMP协议是一个网络层协议。 一个新搭建好的网络,往往需要先进行一个简单的测试,来验证网络…

JDBC Some Templates

JDBCTemplate 是Spring对JDBC的封装&#xff0c;使用JDBCTemplate方便实现对数据的操作。 <!-- orm:Object relationship mapping m对象 关系 映射-->引入依赖 <!-- 基于Maven依赖的传递性&#xff0c;导入spring-content依赖即可导入当前所需的所有…

Spring项目启动报错无法访问org.springframework.boot.SpringApplication:6

当springBoot项目启动后报错如下 解决办法如下&#xff1a;将jdk版本调为11,springBoot版本降低为2.7.12。然后clean&#xff0c;再package重新打包。最后重新启动项目

存储论——经济订货批量的R实现

存储论又称库存理论&#xff0c;是运筹学中发展较早的分支。早在 1915 年&#xff0c;哈李斯&#xff08;F.Harris&#xff09;针对银行货币的储备问题进行了详细的研究&#xff0c;建立了一个确定性的存储费用模型&#xff0c;并求得了最佳批量公式。1934 年威尔逊&#xff08…

第五章 HL7 架构和可用工具 - 创建新的自定义架构

文章目录 第五章 HL7 架构和可用工具 - 创建新的自定义架构创建新的自定义架构定义新段 第五章 HL7 架构和可用工具 - 创建新的自定义架构 创建新的自定义架构 要从管理门户启动自定义架构编辑器&#xff0c;请从主页选择互操作性 > 互操作 > HL7 v2.x >HL7 v2.x 架…

单机和集群以及分布式的浅析

假设一个大系统分为A、B、C、D、E五个模块&#xff0c;也可以认为是五个基本的服务&#xff0c;该系统靠这五个模块协同工作&#xff0c;共同为用户提供服务。 单机 单机&#xff1a;显然&#xff0c;单机表名该系统完完全全的部署在该台机器上&#xff0c;拥有完整的服务&am…

算法38:反转链表【O(n)方案】

一、需求 给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5] 输出&#xff1a;[5,4,3,2,1] 示例 2&#xff1a; 输入&#xff1a;head [1,2] 输出&#xff1a;[2,1] 示例3&#xff…

监听镜像版本变化触发 GitOps工作流

文章目录 前言工作流总览安装和配置 ArgoCD Image Updater创建 Image Pull Secret&#xff08;可选&#xff09;创建 Helm Chart 仓库创建 ArgoCD Application删除旧应用&#xff08;可选&#xff09;配置仓库访问权限创建 ArgoCD 应用 体验 GitOps 工作流总结 前言 在【GitOps…

FastDeploy的方式在OK3588上部署yolov7-- C++

FastDeploy介绍 ⚡️FastDeploy是一款全场景、易用灵活、极致高效的AI推理部署工具&#xff0c; 支持云边端部署。提供超过 &#x1f525;160 Text&#xff0c;Vision&#xff0c; Speech和跨模态模型&#x1f4e6;开箱即用的部署体验&#xff0c;并实现&#x1f51a;端到端的…

附录1-将uni-app运行到微信开发者工具

目录 1 在manifest.json写入AppID 2 配置微信开发者工具的安装路径 3 微信开发者工具的安全设置 4 运行 5 修改一些配置项 1 在manifest.json写入AppID 2 配置微信开发者工具的安装路径 如果你忘了安装在哪里了&#xff0c;可以右键快捷方式看一下属性 在运行设置…

邻接矩阵与邻接表

文章目录 0 前面几种数据结构的回顾1 图1.1 图的定义1.2 常见术语1.3 图的抽象数据类型定义1.4 表示一个图1.4.1 邻接矩阵表示法1.4.2 邻接表 1.5 图的构建1.5.1 邻接矩阵法1.5.2 邻接表法 0 前面几种数据结构的回顾 1 图 1.1 图的定义 图&#xff1a; G (V,E) // Graph (V…

Moke 一百万条 Mysql 的数据

文章目录 前言创建数据库创建表结构生成数据 前言 想研究一下&#xff0c;数据量大的情况下&#xff0c;如何优化前端分页&#xff0c;所以需要 Moke 一些数据 创建数据库 在 Mysql的基础上&#xff0c;可以写个语句执行 CREATE DATABASE test_oneMillion; USE test_oneMi…

Jmeter —— 录制脚本

1. 第一步&#xff1a;添加http代理服务器&#xff0c;在测试计划--》添加--》非测试元件--》http代理服务器 2. 第二步&#xff1a;添加线程组&#xff08;这个线程组是用来放录制的脚本&#xff0c;不添加也可以&#xff0c;就直接放在代理服务器下&#xff09; 测试计划--》…

【Linux】sed修改文件指定内容

sed修改文件指定内容&#xff1a; 参考&#xff1a;(5条消息) Linux系列讲解 —— 【cat echo sed】操作读写文件内容_shell命令修改文件内容_星际工程师的博客-CSDN博客