OpenCV读取tensorflow神经网络模型:SavedModel格式转为frozen graph的方法

news2024/11/15 7:05:50

  本文介绍基于Pythontensorflow库,将tensorflowkeras训练好的SavedModel格式神经网络模型转换为frozen graph格式,从而可以用OpenCV库在C++ 等其他语言中将其打开的方法。

  如果我们需要训练使用一个神经网络模型,一般情况下都是首先借助Python语言中完善的神经网络模型API对其加以训练,训练完毕后在C++Java等语言环境下高效、快速地使用它。最近,就需要在C++ 中打开、使用几个前期已经在Pythontensorflow库中训练好的神经网络模型。但是,由于训练模型时使用的是2.X版本的tensorflow库(且用的是keras的框架),所以训练模型后保存的是SavedModel格式的神经网络模型文件——就是包含3.pb格式文件,以及assetsvariables2个文件夹那种形式的模型;如下图所示。

  而在C++ 中读取神经网络模型,首先是可以借助tensorflow库的C++ API来实现,但是这种方法非常复杂——完整的TensorFlow C++ API部署起来非常困难——需要系统盘至少40 G50 G的剩余空间、动辄0.5 h1 h的编译时长,经常需要花费一周的时间才可以配置成功;所以如果仅仅是需要在C++ 中读取已经训练好的神经网络模型的话,没必要花费这么大功夫去配置TensorFlow C++ API。而同时,基于OpenCV库,我们则可以在简单、快速地配置完其环境后,就基于1个函数对训练好的tensorflow库神经网络模型加以读取、使用。这里如果大家需要配置C++ 环境的OpenCV库,可以参考文章C++计算机视觉库OpenCV在Visual Studio 2022的配置方法(https://blog.csdn.net/zhebushibiaoshifu/article/details/128260507)。

  但是,还有一个问题——OpenCV库自身目前仅支持读取tensorflowfrozen graph格式的神经网络模型,不支持读取SavedModel格式的模型。因此,如果希望基于OpenCV库读取tensorflowSavedModel格式的模型,就需要首先将其转换为frozen graph格式;那么,本文就介绍一下这个操作的具体方法,并给出2种实现这一转换功能的Python代码。

  首先,本文神经网络模型格式转换的代码是基于Python环境中tensorflow库实现的,因此需要配置好这一个库(大家都已经需要转换神经网络模型的格式了,那Python环境中tensorflow库肯定早已经配置好了);如果没有配置,可以参考文章Anaconda配置Python新版本tensorflow库(CPU、GPU通用)的方法(https://blog.csdn.net/zhebushibiaoshifu/article/details/129285815)。

  第1种代码如下。

# -*- coding: utf-8 -*-
"""
Created on Sat Mar  9 14:31:18 2024

@author: fkxxgis
"""

import tensorflow as tf
from tensorflow.keras import models
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

# model_save_model = tf.saved_model.load("F:/Data_Reflectance_Rec/model/model_blue/")
model_save_model = models.load_model("F:/Data_Reflectance_Rec/model/model_blue/")

signatures = model_save_model.signatures["serving_default"]
graph = tf.function(lambda x: model_save_model(x))
graph = graph.get_concrete_function(tf.TensorSpec(signatures.inputs[0].shape.as_list(), signatures.inputs[0].dtype.name))
frozen_variable = convert_variables_to_constants_v2(graph)
frozen_variable.graph.as_graph_def();

tf.io.write_graph(graph_or_graph_def = frozen_variable.graph, 
                  logdir = "F:/Data_Reflectance_Rec/model/model_blue_new", 
                  name = "frozen_graph.pb", 
                  as_text = False)
# tf.io.write_graph(graph_or_graph_def = frozen_variable.graph, 
#                   logdir = "F:/Data_Reflectance_Rec/model/model_blue_new", 
#                   name = "frozen_graph.pbtxt", 
#                   as_text = True)

  其中,我们首先需要导入对应的Python模块和convert_variables_to_constants_v2()函数。

  随后,加载我们待转换的、SavedModel格式的tensorflow神经网络模型。这里需要注意,我写了2句不同的代码来加载初始的模型——其中,如果用第1句代码加载模型,倒也可以不报错地运行完成上述代码,但是等到用C++ 环境的OpenCV库读取这个转换后的模型时,会出现Microsoft C++ 异常: cv::Exception字样的报错,如下图所示;而如果用第2句代码加载模型,就没有问题。之所以会这样,应该是因为我当初训练这个神经网络模型时,用的是tensorflowkeras模块的Model,所以导致加载模型时,就不能用传统的加载SavedModel格式模型的方法了(可能是这样)。

  接下来,我们从初始模型中获取其签名tensorflow库中的签名(Signature),是用于定义模型输入、输出的一种机制——其定义了模型接受的输入参数和返回的输出结果的名称、数据类型和形状等信息;这个默认签名为serving_default,我们这里获取这个默认的签名即可。

  接下来,这个graph = tf.function(lambda x: model_save_model(x))表示将模型封装在tensorflow的图函数中;随后,get_concrete_function()获取具体函数并指定输入张量的形状和数据类型。说实话,这里的2行代码我也搞不太清楚具体详细含义是什么——但大体上,这些内容应该是tensorflow1.X版本中的一些操作与名词(因为frozen graph格式的模型本来就是tensorflow1.X版本中用的,而SavedModel格式则是2.X版本中常用的)。

  再次,通过convert_variables_to_constants_v2()函数,将图中的变量转换为常量,并基于as_graph_def()定义1个冻结图。

  最后,就可以通过tf.io.write_graph()函数,将冻结图写入指定的目录中,输出文件名为frozen_graph.pbas_text = False表示以二进制格式保存这个模型(如果不加这个参数,就相当于成了.pbtxt文件了,导致后续用C++环境的OpenCV库还是读取不了这个模型)。代码末尾,还有一段注释的部分——如果取消注释,将以文本格式保存冻结图,也就是.pbtxt文件。因为我们只要.pb文件就够了,所以就不需要这段代码了。

  执行上述代码,在结果文件夹中,我们将看到1.pb格式的神经网络模型结果文件,如下图所示。

  接下来,在C++Python等语言的OpenCV库中,我们都可以基于cv::dnn::readNetFromTensorflow()这个函数,来读取我们的神经网络模型了。

  除此之外,再给出另一个版本的转换代码;这个代码其实和前述代码的含义差不多,如果前述代码不能执行,大家可以再尝试尝试下面这个。

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

loaded = tf.saved_model.load('F:/Data_Reflectance_Rec/model/model_nir/')
infer = loaded.signatures['serving_default']

f = tf.function(infer).get_concrete_function(tf.TensorSpec(infer.inputs[0].shape.as_list(), dtype=tf.float32))
f2 = convert_variables_to_constants_v2(f)
graph_def = f2.graph.as_graph_def()

with tf.io.gfile.GFile('frozen_graph.pb', 'wb') as f:
    f.write(graph_def.SerializeToString())

  至此,大功告成。

欢迎关注:疯狂学习GIS

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1516141.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何保证缓存与数据库的双写一致性?(史上最全)

目录 一、面试官心理分析 1. Cache Aside Pattern 2. 最初级的缓存不一致问题及解决方案 3. 比较复杂的数据不一致问题分析 一、面试官心理分析 你只要用缓存,就可能会涉及到缓存与数据库双存储双写,你只要是双写,就一定会有数据一致性的问…

django-comment-migrate 模型注释的使用

django-comment-migrate 的使用 django-comment-migrate 是一个 Django 应用,用于将模型注释自动迁移到数据库表注释中。它可以帮助您保持数据库表注释与模型定义的一致性,并提高代码的可读性。 安装 要使用 django-comment-migrate,您需要…

使用大型语言模型进行实体提取

原文地址:Using A Large Language Model For Entity Extraction LLM 能否比传统 NLP 方法更好地提取实体? 2022 年 7 月 12 日 Large Language Models for Generative Information Extraction: A Survey 实体简介 使用Co:here大型语言模型。 实体可以被视…

Qt学习--自定义命名空间

假设我们要创建一个命名空间来包含与圆形相关的功能。我们可以命名这个命名空间为 Cir : 在这个头文件中,我们定义了一个名为 Cir 的命名空间,其中包含了计算圆的面积和周长的函数,以及 圆周率常量 PI 。 使用命名空间 在…

软考73-上午题-【面向对象技术2-UML】-UML中的图4

一、构件图(组件图) 1-1、构件图的定义 展现了,一组构件之间的组织和依赖。 构件图专注于系统的静态实现图。 构件图与类图相关,通常把构件映射为一个、多个类、接口、协作。 【回顾】: 类图展示了一组对象、接口、…

学生时期学习资源同步-1 第一学期结业考试题4

原创作者:田超凡(程序员田宝宝) 版权所有,引用请注明原作者,严禁复制转载

【node版本问题】运行项目报错 PostCSS received undefined instead of CSS string

最近该项目没有做任何修改,今天运行突然跑不起来报错了 PostCSS received undefined instead of CSS string 【原因】突然想起来期间有换过 node 版本为 16.17.1 【解决】将 node 版本换回之前的 14.18.0 就可以了

【Java - 框架 - Mybatis】(02) SpringBoot整合Mybatis操作Mysql - 快速上手

“SpringBoot"整合"Mybatis"操作"Mysql” - 快速上手; 环境 Java版本"1.8.0_202";Spring Boot版本"2.5.9";Windows 11 专业版_22621.2428;IntelliJ IDEA 2021.1.3(Ultimate Edition)&a…

VC++ BitBlt函数学习

1 BitBlt BitBlt函数执行与像素矩形相对应的颜色数据的位块传输,从指定的源设备上下文传输到目标设备上下文。 把位块从一个DC传到另一个DC; VC单文档工程,写3句代码如下; void CDeskdcView::OnDraw(CDC* pDC) {CDeskdcDoc* pDoc = GetDocument();ASSERT_VALID(pDoc);//…

mac输入su命令报错如何重置密码

diannao1xiejiandeMacBook-Air ~ % su Password: su: Sorry输入 sudo passwd 命令重置密码即可。

Seata 2.x 系列【10】回滚日志表 undo_log

有道无术,术尚可求,有术无道,止于术。 本系列Seata 版本 2.0.0 本系列Spring Boot 版本 3.2.0 本系列Spring Cloud 版本 2023.0.0 源码地址:https://gitee.com/pearl-organization/study-seata-demo 文章目录 1. 概述2. 表语句…

[抽象]工厂模式([Abstract] Factory)——创建型模式

[抽象]工厂模式——创建型模式 什么是抽象工厂? 抽象工厂模式是一种创建型设计模式,让你能够保证在客户端程序中创建一系列有依赖的对象组时,无需关心这些对象的类型。 具体来说: 对象的创建与使用分离: 抽象工厂模…

手把手带你实现大模型检索增强生成RAG(一)——数据清洗准备

首先,需要整理一大堆可以用来检索的文本数据,这些数据可以是网页、论文、报告、电影脚本、电视剧脚本等等。这些数据可以是原始的文本数据,也可以是经过清洗、处理过的文本数据。 作为IT打工仔,我从二道贩子处购入一本软考秘籍。…

【Numpy】基础学习:一文了解np.expand_dims的作用、用法

【Numpy】基础学习:一文了解np.expand_dims的作用、用法 🌈 个人主页:高斯小哥 🔥 高质量专栏:Matplotlib之旅:零基础精通数据可视化、Python基础【高质量合集】、PyTorch零基础入门教程👈 希望…

电脑闹钟软件哪个好用,电脑闹钟软件推荐助你高效工作

在这个快节奏的社会中,时间对于每个人来说都是宝贵的。如何有效利用时间,提高工作效率成为了现代人所面临的重要问题之一。而电脑闹钟软件作为一款实用的工作助手,可以提醒我们按时完成任务,规划好时间,使我们更加高效…

STC89C52单片机 启动!!!(一)

跑马灯实现 直接上代码 #include<regx52.h> sbit D1P2^0; sbit D2P2^1; sbit D3P2^2; sbit D4P2^3; sbit D5P2^4; sbit D6P2^5; sbit D7P2^6; sbit D8P2^7; void delay(int num){while(num--){} } void led_running(){//从第1盏灯到第8盏灯依次点亮D10;delay(40000);D2…

解决方案:淘宝NPM镜像证书到期导致的安装Node失败

博主猫头虎的技术世界 &#x1f31f; 欢迎来到猫头虎的博客 — 探索技术的无限可能&#xff01; 专栏链接&#xff1a; &#x1f517; 精选专栏&#xff1a; 《面试题大全》 — 面试准备的宝典&#xff01;《IDEA开发秘籍》 — 提升你的IDEA技能&#xff01;《100天精通鸿蒙》 …

宏集案例 | 风电滑动轴承齿轮箱内多点温度采集与处理

前言 风力发电机组中的滑动轴承齿轮箱作为关键的传动装置&#xff0c;承担着将风能转化为电能的重要角色。齿轮箱内多点温度的实时监测可以有效地预防设备故障和性能下降。实时监测齿轮箱内多点温度可以有效地预防设备故障和性能下降。 为了确保风力发电机组的安全稳定运行&a…

HarmonyOS NEXT应用开发之深色模式适配

介绍 本示例介绍在开发应用以适应深色模式时&#xff0c;对于深色和浅色模式的适配方案&#xff0c;采取了多种策略如下&#xff1a; 固定属性适配&#xff1a;对于部分组件的颜色属性&#xff0c;如背景色或字体颜色&#xff0c;若保持不变&#xff0c;可直接设定固定色值或…

数组名结合指针的面试题的讲解

笔试题 第一题&#xff1a; 已知条件&#xff1a; 已知p为结构体指针变量&#xff0c;值为0x100000&#xff0c;并且结构体的大小为20字节&#xff0c;并且打印格式均为%p&#xff0c;%p不会在乎正负数&#xff0c;它会以补码的形式直接打印&#xff0c;0x1为16进制的1。 第一问…