c++调用tf.keras的模型

news2024/10/5 19:14:12

环境:

  • ubuntu 20.04

  • python 3.8

  • tensorflow-gpu 2.4.0

  • 显卡 nvidia rtx A6000 驱动 495.29.05

  • cuda 11.5

  • cudnn 8.3.0

  • tensorRT 8.4

1.将keras保存的h5模型转成darknet的weight,然后用opencv加载

  cv::dnn::Net net = cv::dnn::readNetFromDarknet("load_model.weight");

卷积网络和全连接可以加载成功,lstm网络不支持,opencv源码中没有lstm网络:

opencv源码

2.将keras模型转成onnx,两种方式,keras2onnx和tensorflow2onnx:

  1. keras-onnx
    keras-onnx已经停止维护了
    keras2onnx has been tested on Python 3.5 - 3.8, with tensorflow 1.x/2.0 - 2.2 (CI build). It does not support Python 2.x.
import keras2onnx
import onnx
import tensorflow as tf
from tensorflow.keras.models import load_model

if "__main__" == __name__:
    model = load_model('./model.h5')
    model.summary()
    onnx_model = keras2onnx.convert_keras(model, model.name)
    temp_model_file = 'model.onnx'
    onnx.save_model(onnx_model, temp_model_file)
  1. tensorflow-onnx
    在这里插入图片描述
  • 安装
pip install -U tf2onnx
import tensorflow as tf
from tensorflow.keras.models import load_model

if "__main__" == __name__:
    model = load_model('./model.h5')
    tf.saved_model.save(model, './model')


# python -m tf2onnx.convert --saved-model ./model/ --output model.onnx

先load预训练的h5模型,用tf.saved_model.save转化成pb格式,再执行

python -m tf2onnx.convert --saved-model ./model/ --output model.onnx

转成onnx格式
可以在https://netron.app/查看你的模型架构:
在这里插入图片描述

3.将onnx转成tensorRT的格式:

在这里插入图片描述

  • 在https://developer.nvidia.com/nvidia-tensorrt-8x-download下载并安装8.4版本:
sudo dpkg -i nv-tensorrt-repo-ubuntu2004-cuda11.6-trt8.4.0.6-ea-20220212_1-1_amd64.deb
sudo apt-key add /var/nv-tensorrt-repo-ubuntu2004-cuda11.6-trt8.4.0.6-ea-20220212/7fa2af80.pub
sudo apt-get update
sudo apt-get install tensorrt
  • 测试tensorrt是否安装成功
    准备MNIST的.PGM图像,将10张0~9的命名为0.PGM,1.PGM…的
    MNIST图像拷贝至data/mnist文件夹
sudo cp *.pgm /usr/src/tensorrt/data/mnist
cd /usr/src/tensorrt/samples/sampleMNIST
sudo make
cd ../../bin/
./sample_mnist

在这里插入图片描述

  • 查看tensorrt版本
dpkg -l | grep TensorRT
  • 安装tensorRT如果报错
    • libdvd-pkg: apt-get check failed, you may have broken packages. Aborting…
    • sudo dpkg-reconfigure libdvd-pkg
    • 参考https://blog.csdn.net/qq_22945165/article/details/87653208
将tensorRT加入环境变量

export PATH=$PATH:/usr/src/tensorrt/bin

trtexec --onnx=model.onnx --saveEngine=model.trt
  • 也可参考https://github.com/onnx/onnx-tensorrt这种方式
onnx2trt my_model.onnx -o my_engine.trt

4.模型加载

参考https://github.com/shouxieai/tensorRT_Pro/

static void test_load_onnx(){
    /** 加载编译好的引擎 **/
    auto infer = TRT::load_infer("./model.trt");
    if(infer == nullptr){
        INFOE("Engine is nullptr");
        return;
    }

    /** 设置输入的值 **/
    infer->input(0)->set_to(1.0f);

    /** 引擎进行推理 **/
    infer->forward();

    /** 取出引擎的输出并打印 **/
    auto out = infer->output(0);
    INFO("out.shape = %s", out->shape_string());
    for(int i = 0; i < out->channel(); ++i)
        INFO("%f", out->at<float>(0, i));
}
static void test_load_onnx(){
    /** 加载编译好的引擎 **/
    auto infer = TRT::load_infer("/home/user/my_python_test/net_float16/float32-model-2400/actor_model_32.trt");
    if(infer == nullptr){
        INFOE("Engine is nullptr");
        return;
    }
    infer->print();

    double timeAll = 0;
    int inferTimes = 1000;
    for(int i = 0; i < inferTimes; i++){        
        /** 设置输入的值 **/
        // infer->input(0)->set_to(1.0f);
        std::vector<float> input_ = {0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                                    1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                                    2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                                    3., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                                    4., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                                    5., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                                    6., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                                    7., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                                    8., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                                    9., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.};

        memcpy(infer->input(0)->cpu<float>(), input_.data(), sizeof(float) * input_.size());
        std::cout<<"############### input.data(): "<<input_.data()<<std::endl;
        std::cout<<"############### input.size(): "<<input_.size()<<std::endl;
        std::cout<<"############### infer->input(0): "<<infer->input(0)<<std::endl;
        std::cout<<"############### infer->input(0)->count(): "<<infer->input(0)->count()<<std::endl;
        std::cout<<"############### infer->input(0)->cpu<float>(0): "<<infer->input(0)->cpu<float>()<<std::endl;
        std::cout<<"############### infer->input(0)->at<float>(): "<<infer->input(0)->at<float>(0, 3)<<std::endl;
        
        clock_t start = clock();
        /** 引擎进行推理 **/
        infer->forward();
        clock_t end = clock();

        /** 取出引擎的输出并打印 **/
        auto out_0 = infer->output(0);
        auto out_1 = infer->output(1);
        auto out_2 = infer->output(2);
        auto out_3 = infer->output(3);

        INFO("out_0.shape = %s", out_0->shape_string());
        for(int i = 0; i < out_0->channel(); ++i)
            INFO("%f", out_0->at<float>(0, i));

        INFO("out_1.shape = %s", out_1->shape_string());
        for(int i = 0; i < out_1->channel(); ++i)
            INFO("%f", out_1->at<float>(0, i));

        INFO("out_2.shape = %s", out_2->shape_string());
        for(int i = 0; i < out_2->channel(); ++i)
            INFO("%f", out_2->at<float>(0, i));

        INFO("out_3.shape = %s", out_3->shape_string());
        for(int i = 0; i < out_3->channel(); ++i)
            INFO("%f", out_3->at<float>(0, i));
        
        double spend_time = (double)(end - start)/CLOCKS_PER_SEC*1000;
        std::cout << "############## Total inference time is " << spend_time << "ms" << std::endl;
        
        if(i>0)
            timeAll += spend_time;
    }
    std::cout << "############## average inference time is " << timeAll/inferTimes << "ms" << std::endl;


}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/23491.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

链表中快慢指针的应用

目录 一、链表的中间结点 二、回文链表 三、链表中倒数第K个结点 四、删除链表的倒数第n个结点 一、链表的中间结点 给定一个头结点为 head 的非空单链表&#xff0c;返回链表的中间结点。 如果有两个中间结点&#xff0c;则返回第二个中间结点。 先设置两个low和fast都指…

【MySQL】测试题03

文章目录1、创建数据库2、使用数据库3、创建数据表【3.1】创建学生信息表Student【3.2】创建课程信息表Course【3.3】创建教师信息表Teacher【3.4】创建成绩信息表Score4、添加数据【4.1】向学生student表添加数据【4.2】向课程course表添加数据【4.3】向教师信息teacher表添加…

【动手学深度学习】softmax回归的从零开始实现(PyTorch版本)(含源代码)

目录&#xff1a;softmax回归的从零开始实现一、理论基础1.1 前言1.2 分类问题1.3 网络架构1.4 全连接层的参数开销1.5 softmax运算1.6 小批量样本的矢量化1.7 损失函数1.7.1 对数似然1.7.2 softmax及其导数1.7.3 交叉熵损失1.8 信息论基础1.8.1 熵1.8.2 信息量1.8.3 重新审视交…

19 02-检索满足客户端定义的状态掩码的DTC列表

诊断协议那些事儿 诊断协议那些事儿专栏系列文章&#xff0c;19服务作为UDS中子功能最多的服务&#xff0c;一共有28种子功能&#xff0c;本文将介绍常用的19 02服务&#xff1a;根据状态掩码读取DTC列表。 关联文章&#xff1a; 19服务List 19 01-通过状态掩码读取DTC数目 …

详细教程。2022年滁州市明光市、来安县等各地区高新技术企业申报

安徽省大力鼓励企业申报高新技术企业&#xff0c;于高企申报也有很多奖补。滁州市企业申报奖补政策发布&#xff0c;企业可以根据自身情况申请奖补&#xff0c;奖补金额为10万元至30万元不等&#xff0c;明光市&#xff0c;凤阳县等各地区奖补申请可以通过市级机关办理。 下面小…

跟艾文学编程《Python数据可视化》(01)基于Plotly的动态可视化绘图

作者&#xff1a;艾文&#xff0c;计算机硕士学位&#xff0c;企业内训讲师和金牌面试官&#xff0c;公司资深算法专家&#xff0c;现就职BAT一线大厂。邮箱&#xff1a;1121025745qq.com博客&#xff1a;https://wenjie.blog.csdn.net/内容&#xff1a;跟艾文学编程《Python数…

2022-11-21 mysql列存储引擎-架构实现缺陷梳理-P2

摘要: 收集现有代码的糟糕实现&#xff0c;前事不忘后事之师&#xff0c;把这些烂东西定死在耻辱柱上以免再次发生 糟糕的设计: 一. DGMaterializedIterator::GetNextPackrow 函数实现: int DimensionGroupMaterialized::DGMaterializedIterator::GetNextPackrow(int dim, int…

【Linux系统】第一篇:基础指令篇

文章目录一、Linux中的文件二、Linux用户三、Linux基本指令ls指令pwd命令cd指令touch指令mkdir指令rmdir指令rm 指令man指令cp指令mv指令cat指令tac指令more指令less指令head指令tail指令管道重定向date指令cal指令find指令which指令alias指令whereis指令grep指令wc指令sort指令…

Node的web编程(二)

一、JSON数据 1、定义 JavaScript Object Notation&#xff0c;是一种轻量级的前后端数据交换的格式(数据格式)。 2、特点 &#xff08;1&#xff09;容易阅读和编写 &#xff08;2&#xff09;语言无关性 &#xff08;3&#xff09;便于编译、解析 3、语法要求 &#…

Mac m1配置flutter开发环境

Mac m1配置flutter开发环境 文章目录Mac m1配置flutter开发环境一、下载Android Studio二、下载flutter sdk三、新建flutter project四、使用在线环境进行Flutter开发Dart在线运行环境Flutter在线运行环境一、下载Android Studio 进入官网下载&#xff0c;选择苹果芯片版本。 …

【Spring(三)】熟练掌握Spring的使用

有关Spring的所有文章都收录于我的专栏&#xff1a;&#x1f449;Spring&#x1f448; 目录 一、前言 二、通过静态工厂获取对象 三、通过实例工厂获取对象 四、通过FactoryBean获取对象 五、Bean配置信息重用 六、Bean创建顺序 七、Bean对象的单例和多例 八、Bean的生命周期 九…

Weblogic SSRF 漏洞(CVE-2014-4210)分析

Weblogic SSRF 漏洞是一个比较经典的SSRF 漏洞案例&#xff0c;该漏洞存在于 http://127.0.0.1:7001/uddiexplorer/SearchPublicRegistries. jsp 页面中&#xff0c;如图 1-1 所示图 1-1 Weblogic SSRF 漏洞Weblogic SSRF 漏洞可以通过向服务端发送以下请求参数进行触发&#x…

ARFoundation系列讲解 - 70 HumanBodyTracking3D

---------------------------------------------- 视频教程来源于网络,侵权必删! --------------------------------------------- 一、简介 HumanBodyTracking3D(身体跟踪3D)案例,当设备检查到人体时,会返回检测到人体关节点的3D空间位置(需要在iOS 13或更高版本的A12…

瞪羚优化算法(Matlab代码实现)

&#x1f468;‍&#x1f393;个人主页&#xff1a;研学社的博客 &#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜…

Java集合类——ArrayList(扩容机制)

线性表 线性表是n个相同类型元素的有限序列&#xff0c;逻辑上连续物理上不一定是连续的&#xff0c;存储结构上分为顺序存储和链式存储&#xff0c;常见的线性表有&#xff1a;顺序表&#xff0c;链表&#xff0c;栈&#xff0c;队列…… ArrayList 数据结构 ArrayList&am…

赋值运算符重载,取地址及const取地址操作符重载

赋值运算符重载1.运算符重载2.赋值运算符重载3.取地址及const取地址操作符重载如果一个类中什么成员都没有&#xff0c;那么该类简称为空类。而空类中其实并不是真的什么都没有&#xff0c;任何类在什么都不写时&#xff0c;编译器会自动生成以下6个默认成员函数。构造函数&…

同花顺_代码解析_技术指标_V,W

本文通过对同花顺中现成代码进行解析&#xff0c;用以了解同花顺相关策略设计的思想 目录 V&R VMA VMACD VOSC VPT VR VRFS VRSI VSTD W&R WVAD V&R 波动区间 用来衡量该股的市场波动风险.即95%的概率波动区间. 行号 1 n -> 250 2 x -> 收…

【考研英语语法】状语从句精讲

一、状语从句概述 &#xff08;一&#xff09;状语从句的含义 状语从句&#xff0c;指的就是一个句子作状语&#xff0c;表达“描述性的信息”&#xff0c;补充说明另一个句子&#xff08;主句&#xff09;。描述性的信息有很多种&#xff0c;可以描述时间、地点、原因、结果…

Web大学生网页成品HTML+CSS音乐吧 7页

⛵ 源码获取 文末联系 ✈ Web前端开发技术 描述 网页设计题材&#xff0c;DIVCSS 布局制作,HTMLCSS网页设计期末课程大作业 | 音乐网页设计 | 仿网易云音乐 | 各大音乐官网网页 | 明星音乐演唱会主题 | 爵士乐音乐 | 民族音乐 | 等网站的设计与制作 | HTML期末大学生网页设计作…

Django开发笔记

Django开发笔记Django学习1. Django安装path()函数2. 创建项目2.1 终端命令创建2.2 pycharm创建项目3. App4. 创建页面4.1 再写一个页面4.2 模板---Templates4.3 静态文件4.3.1 创建static目录4.3.2 静态文件的引用5. 模板语法案例&#xff1a;伪联通新闻中心6. 请求和响应案例…