RealBasicVSR模型转成ONNX以及用c++推理

news2024/11/27 12:54:33

文章目录

    • 安装RealBasicVSR的环境
      • 1. 新建一个conda环境
      • 2. 安装pytorch(官网上选择合适的版本)版本太低会有问题
      • 3. 安装 mim 和 mmcv-full
      • 4. 安装 mmedit
    • 下载RealBasicVSR源码
    • 下载模型文件
    • 写一个模型转换的脚步
    • 测试生成的模型

安装RealBasicVSR的环境

1. 新建一个conda环境

conda create -n RealBasicVSR_to_ONNX  python=3.8 -y
conda activate RealBasicVSR_to_ONNX

2. 安装pytorch(官网上选择合适的版本)版本太低会有问题

conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge

3. 安装 mim 和 mmcv-full

pip install openmim
mim install mmcv-full

4. 安装 mmedit

pip install mmedit

下载RealBasicVSR源码

git clone https://github.com/ckkelvinchan/RealBasicVSR.git

下载模型文件

模型文件下载 (Dropbox / Google Drive / OneDrive) ,随便选一个渠道下载就行

cd RealBasicVSR
#然后新建文件夹model
将模型文件放在model文件夹下

写一个模型转换的脚步

import cv2
import mmcv
import numpy as np
import torch
from mmcv.runner import load_checkpoint
from mmedit.core import tensor2img

from realbasicvsr.models.builder import build_model

def init_model(config, checkpoint=None):
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(config)}')
    config.model.pretrained = None
    config.test_cfg.metrics = None
    model = build_model(config.model, test_cfg=config.test_cfg)
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint)

    model.cfg = config  # save the config in the model for convenience
    model.eval()
    return model


def main():
    model = init_model("./configs/realbasicvsr_x4.py","./model/RealBasicVSR_x4.pth")
    src = cv2.imread("./data/img/test1.png")
    src = torch.from_numpy(src / 255.).permute(2, 0, 1).float()
    src = src.unsqueeze(0)
    input_arg = torch.stack([src], dim=1)

    torch.onnx.export(model,
        input_arg,
        'realbasicvsr.onnx',
        training= True,
        input_names= ['input'],
        output_names=['output'],
        opset_version=11,
        dynamic_axes={'input' : {0 : 'batch_size', 3 : 'w', 4 : 'h'}, 'output' : {0 : 'batch_size', 3 : 'dstw', 4 : 'dsth'}})


if __name__ == '__main__':
    main()

这里报错:

ValueError: SRGAN model does not support `forward_train` function.

直接将这个test_mode默认值改为Ture,让程序能走下去就行了。
在这里插入图片描述

测试生成的模型

这里已经得到了 realbasicvsr.onnx 模型文件了.

import onnxruntime as ort
import numpy as np
import onnx
import cv2


def main():
    onnx_model = onnx.load_model("./realbasicvsr.onnx")
    onnxstrongmodel = onnx_model.SerializeToString()
    sess = ort.InferenceSession(onnxstrongmodel)
    
    providers = ['CPUExecutionProvider']
    options = [{}]
    is_cuda_available = ort.get_device() == 'GPU'
    if is_cuda_available:
        providers.insert(0, 'CUDAExecutionProvider')
        options.insert(0, {'device_id': 0})
    sess.set_providers(providers, options)

    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[1].name
    print(sess.get_inputs()[0])
    print(sess.get_outputs()[0])
    print(sess.get_outputs()[0].shape)
    print(sess.get_inputs()[0].shape)

    img = cv2.imread("./data/img/test1.png")
    img = np.expand_dims((img/255.0).astype(np.float32).transpose(2,0,1), axis=0)
    imgs = np.array([img])
    print(imgs.shape)
    print(imgs)
    output = sess.run([output_name], {input_name : imgs})

    print(output)

    print(output[0].shape)
    output = np.clip(output, 0, 1)

    res = output[0][0][0].transpose(1, 2, 0)
    cv2.imwrite("./testout.png", (res * 255).astype(np.uint8))

if __name__ == '__main__':
    main()

至此模型转换部分就成功完成了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/383054.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第一次做性能测试,有亿点点紧张

确认需求 确定性能需求和性能测试方案、需要确定性能测试范围(覆盖哪些场景)、性能测试策略、并发用户数和加压方式。 时间、人员、任务的分配安排,一般在总体测试计划中会预留性能测试的时间。性能测试方案是开展性能测试前的核心内容&…

时序预测 | MATLAB实现Rmsprop算法优化LSTM长短期记忆神经网络时间序列多步预测(滚动预测未来,多指标,含验证Loss曲线)

时序预测 | MATLAB实现Rmsprop算法优化LSTM长短期记忆神经网络时间序列多步预测(滚动预测未来,多指标,含训练和验证Loss曲线) 目录 时序预测 | MATLAB实现Rmsprop算法优化LSTM长短期记忆神经网络时间序列多步预测(滚动预测未来,多指标,含训练和验证Loss曲线)效果一览基本描…

Flume使用入门

目录 一. Flume简单介绍 1. Agent 2. Source 3. Sink 4. Channel 5. Event 二. 环境安装 1. 创建日志目录 2. 修改日志配置文件 3.修改运行堆内存 4. 确定日志打印的位置 5. 修改flume使用内存 内存调大 三. 校验flume 1. 安装netcat工具和net-tools工具 2. 判…

MySQL服务端与客户端之间的连接过程

服务器程序和客户端程序本质上是两个进程,所以连接过程的本质就是两个进程直接的通信,MySQL支持下面三种方式来进行通信。一:TCP/IP数据库服务器进程和客户端进程可能运行在不同的主机中,需要通过网络来进行通信二:命名…

初识scrapy

认识scrapyscrapy是一个为了爬取网站数据,提取结构性数据而编写的应用框架,我们只需实现少量的代码,就能实现数据的快速抓取scrapy使用了Twisted异步网络架构,可以加快下载速度 pip install twisted安装:pip install s…

进程控制(详解)

进程控制上篇文章介绍了进程的相关概念,形如进程的内核数据结构task_struct 、进程是如何被操作系统管理的、进程的查看、进程标识符、进程状态、进程优先级、已经环境变量和进程地址空间等知识点; 本篇文章接着上篇文章继续对进程的控制进行展开&#x…

Spark 内存运用

RDD Cache 当同一个 RDD 被引用多次时,就可以考虑进行 Cache,从而提升作业的执行效率 // 用 cache 对 wordCounts 加缓存 wordCounts.cache // cache 后要用 action 才能触发 RDD 内存物化 wordCounts.count// 自定义 Cache 的存储介质、存储形式、副本…

【OJ比赛日历】快周末了,不来一场比赛吗? #03.04-03.10 #12场

CompHub 实时聚合多平台的数据类(Kaggle、天池…)和OJ类(Leetcode、牛客…)比赛。本账号同时会推送最新的比赛消息,欢迎关注!更多比赛信息见 CompHub主页 或 点击文末阅读原文以下信息仅供参考,以比赛官网为准目录2023-03-04&…

从100%进口到自主可控,从600块降到10块,中科院攻克重要芯片

前言 2月28日,“20多位中科院专家把芯片价格打到10块”冲上微博热搜,据河南省官媒大象新闻报道,热搜中提到的中科院专家所在企业为全球最大的PLC分路器芯片制造商仕佳光子,坐落于河南鹤壁。 为实现芯片技术自主可控自立自强&#…

根据栅格数据的范围和像元大小生成等比例的矢量数据

为啥会有这么一个需求呢,后面要是继续写的话会详细说,首先是有一个栅格数据,比如这样的:我的目标是这样的(矢量),就是这样,上面的栅格数据像大小是2*2的,直接上代码&…

【JAVA程序设计】【C00109】基于SSM(非maven)的员工工资管理系统

基于SSM(非maven)的员工工资管理系统项目简介项目获取开发环境项目技术运行截图项目简介 基于ssm框架非maven开发的企业工资管理系统共分为二个角色:系统管理员、员工 管理员角色包含以下功能: 系统后台登陆、管理员管理、员工信…

Linux基础命令-nice调整进程的优先级

文章目录 Nice 命令介绍 语法格式 常用参数 参考实例 1 调整bash的优先级为-10 2 调整脚本的优先级为6 3 调整指令的优先级 4 默认使用nice命令调整优先级 命令总结 Nice 命令介绍 nice命令的主要功能是用于调整进程的优先级,合理分配系统资源。Linux系…

torchaudio的I/O函数

info、load、save1.1 infotorchaudio.info(filepath: str, ...)Fetch meta data of an audio file. Refer to torchaudio.backend for the detail.返回音频文件的meta信息这里的meta元信息包括采样率、帧数、通道数、量化位数、音频格式info torchaudio.info(rE:\adins\data\2…

Java使用DFA算法实现敏感词过滤

1 前言敏感词过滤就是你在项目中输入某些字(比如输入xxoo相关的文字时)时要能检测出来,很多项目中都会有一个敏感词管理模块,在敏感词管理模块中你可以加入敏感词,然后根据加入的敏感词去过滤输入内容中的敏感词并进行…

测试软件5

一 css基础 css定义:可以设置网页中的样式,外观,美化 css中文名字:级联样式表,层叠样式表,样式表 二 css基础语法 1.style标签写在title标签后面 2.选择器{属性名1:属性值1;属性名…

ChatGPT最牛应用,让它帮你更新网站新闻吧!

谁能想到,ChatGPT火了!既能对话入流,又能写诗歌论文、出面试题、编代码,甚至还通过了谷歌面试拿到L3工程师offer,放在一年之前,没人相信这是当前AI能够达到的水平。ChatGPT自面世以来,凭借其极为…

【数据结构初阶】手把手带你实现栈

前言 在进入数据结构初阶的学习之后,我们学习了顺序表和链表,当然栈也是一种特殊的数据结构,他的特点是后进先出。 栈的概念及结构 栈(stack)又名堆栈,它是一种运算受限的线性表。限定仅在表尾进行插入和删…

iptables的介绍

iptables简介 1、 什么是iptables? iptables是linux防火墙工作在用户空间的管理工具,是 netfilter/iptables IP信息包过滤系统的一部分,用来设置、维护和检查Linux内核的IP数据包过滤规则 2、 iptables特点 iptables是基于内核的防火墙&…

【pytorch 入门系列】02 手把手多分类从0到1

温故而知新,通过手把手写一个多分类任务来复习之前所学过的知识。 前置知识 factorize的妙用:把文本数据枚举化 labels, uniques pd.factorize([b, b, a, c, b]) labels,uniques(array([0, 0, 1, 2, 0]), array([‘b’, ‘a’, ‘c’], dtypeobject))…

【C++】-- 特殊类设计

对于类的思维境界提升,没有太大的实际意义,但是锻炼思想。 目录 单例模式 饿汉模式 懒汉模式 #:请设计一个类,不能被拷贝。 拷贝只会发生在两个场景中:拷贝构造函数赋值运算符重载因此想要让一个类禁止拷贝&#xf…