onnx模型的保存与使用

news2024/9/20 22:34:00

1 onnx模型的保存

在网络训练结束之后,通常会将模型的权重参数保存到.pth或.pt文件中,如果部署环境中有pytorch,那么直接新建一个模型类对象,然后导入权重参数即可,但如果部署环境中只有OpenCV,没有pytorch,那么该如何部署呢?
答:先在训练环境中将.pth文件转成onnx文件,再将onnx文件部署到最终的环境中,导出onnx文件的命令为torch.onnx.export即可。

import torch
import torchvision as tv
from torch.utils.data import DataLoader


class CNN_Mnist(torch.nn.Module):
    def __init__(self):
        super(CNN_Mnist, self).__init__()
        self.cnn_layers = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1, stride=1),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=8, out_channels=32, kernel_size=3, padding=1, stride=1),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.ReLU()
        )
        self.fc_layers = torch.nn.Sequential(
            torch.nn.Linear(7*7*32, 200),
            torch.nn.ReLU(),
            torch.nn.Linear(200, 100),
            torch.nn.ReLU(),
            torch.nn.Linear(100, 10),
            torch.nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        out = self.cnn_layers(x)
        out = out.view(-1, 7*7*32)
        out = self.fc_layers(out)
        return out


def train_and_test():
    model = CNN_Mnist().cuda()
    print("Model's state_dict:")
    for param_tensor in model.state_dict():
        print(param_tensor, "\t", model.state_dict()[param_tensor].size())
    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for s in range(5):
        print("run in epoch : %d" % s)
        for i, (x_train, y_train) in enumerate(train_dl):
            x_train = x_train.cuda()
            y_train = y_train.cuda()
            y_pred = model.forward(x_train)
            train_loss = loss(y_pred, y_train)
            if (i + 1) % 100 == 0:
                print(i + 1, train_loss.item())
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

    torch.save(model.state_dict(), './cnn_mnist_model.pt')
    model.eval()

    total = 0
    correct_count = 0
    for test_images, test_labels in test_dl:
        pred_labels = model(test_images.cuda())
        predicted = torch.max(pred_labels, 1)[1]
        correct_count += (predicted == test_labels.cuda()).sum()
        total += len(test_labels)
    print("total acc : %.2f\n"%(correct_count / total))


if __name__ == '__main__':
    # 数据预处理方法
    transform = tv.transforms.Compose([tv.transforms.ToTensor(),
                                       tv.transforms.Normalize((0.5,), (0.5,)),
                                       ])
    # 数据集
    train_ts = tv.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_ts = tv.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    # 数据集导入器
    train_dl = DataLoader(train_ts, batch_size=32, shuffle=True, drop_last=False)
    test_dl = DataLoader(test_ts, batch_size=64, shuffle=True, drop_last=False)
    
    # 训练与测试
    train_and_test()

    # 模型的导入
    model = CNN_Mnist()
    model.load_state_dict(torch.load('cnn_mnist_model.pt'))

    """下面演示将模型转化为onnx格式文件"""
    # 要把模型切换到评估状态,这样可以让某些层(如drop_out失效)
    model.eval()

    # 随机一个输入张量,这个张量的作用是告诉onnx框架,输入张量的shape是什么样的
    dummy_input = torch.randn(1, 1, 224, 224)

    # 导出onnx文件
    torch.onnx.export(model, (dummy_input), 'cnn_mnistorch.onnx', verbose=True) # onnx文件可以通过netron查看结构
    # verbose=True,则打印一些转换日志,并且onnx文件中会包含doc_string,即用于说明模型的文档字符串

这段程序执行完成后,就会在当前目录下得到一个名为“cnn_mnistorch.onnx”的文件

2 在OpenCV中调用onnx模型文件

OpenCV调用onnx文件也很方便,只需要cv.dnn.readNetFromONNX即可,剩下的部分就是调用OpenCV了

import cv2 as cv
import numpy as np


def mnist_onnx_demo():
    # 从onnx文件中读取网络
    mnist_net = cv.dnn.readNetFromONNX("cnn_mnist.onnx")

    # 读取图片并做相应的预处理
    image = cv.imread("test.png")
    gray = cv.cvtColor(image, cv.COLOR_BGR2GRAY)
    cv.imshow("input", gray)
    blob = cv.dnn.blobFromImage(gray, 0.00392, (28, 28), (127.0)) / 0.5		# 这个API稍后会讲
    """上面对应图像在pytorch中的变换
    tv.transforms.Compose([tv.transforms.ToTensor(),
                            tv.transforms.Normalize((0.5,), (0.5,)),])"""
    print(blob.shape)

    # 前向传播,OpenCV的dnn模块,前向传播需要先设置网络的输入
    mnist_net.setInput(blob)
    result = mnist_net.forward()

    # 后处理
    pred_label = np.argmax(result, 1)
    print("predit label : %d" % pred_label)

    cv.waitKey(0)
    cv.destroyAllWindows()
    # 整个过程没有调用pytorch框架,也没使用模型对应的类,也就是说,onnx文件可以摆脱对框架的依赖


if __name__ == '__main__':
    mnist_onnx_demo()

在这里插入图片描述
输出:

(1, 1, 28, 28)
predit label : 3

可以看到,部署的过程中,完全拜托了对pytorch框架的依赖。

这里面比较重要的是cv2.dnn.blobFromImage这个函数,它的作用是将图像转化为网络模型的输入,API如下:

cv2.dnn.blobFromImage(image[, scalefactor[, size[, mean[, swapRB[, crop[, ddepth]]]]]])
作用:
对图像进行预处理,包括减均值,比例缩放,裁剪,交换通道等,返回一个4通道的blob(blob可以简单理解为一个N维的数组,用于神经网络的输入)

参数:
image:输入图像(13或者4通道)
可选参数
scalefactor:图像各通道数值的缩放比例
size:图像要转化成的空间尺寸,如size=(200,300)表示高h=300,宽w=200,相当于resize
mean:用于各通道减去的值,以降低光照的影响(e.g. image为BGR3通道的图像,mean=[104.0, 177.0, 123.0],表示b通道的值-104,g-177,r-123)
swapRB:交换RB通道,默认为False.(cv2.imread读取的是彩图是bgr通道)
crop:图像裁剪,默认为False.
	当值为True时,先按保持原来的高宽比缩放,直到其中一条边等于对应方向的长度,另一条边大于对应方向长度,然后从中心裁剪成size尺寸;
	如果值为False,则不管高宽比,直接缩放成指定尺寸(即size参数)。
	e.g.原图(300, 200),目标尺寸(400, 300)
	若crop=True(300, 200) --resize-->(450, 300)--crop-->(400, 300)
	若crop=False(300, 200) --resize-->(400, 300)
ddepth:输出的图像深度,可选CV_32F 或者 CV_8U.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/821817.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

融合大数据、物联网和人工智能的智慧校园云平台源码 智慧学校源码

电子班牌系统用以展示各个班级的考勤信息、授课信息、精品课程、德育宣传、班级荣誉、校园电视台、考场信息、校园通知、班级风采,是智慧校园和智慧教室的对外呈现窗口,也是学校校园文化宣传和各种信息展示的重要载体。将大数据、物联网和人工智能等新兴…

27岁到来之际,我在阿里实现了年薪30W+的小目标

毕业快 5 年了,每当和人聊起自己的职场飞升之路,都不由得感激当初果断逃离舒适圈的自己。出身一所非 211、985 院校,毕业后入职了一家小型互联网公司,当着普普通通的初级测试工程师,工作期间虽然也时常遇到挑战&#x…

性能优化点

Arts and Sciences - Computer Science | myUSF 索引3层(高度为3)一般对于数据库地址千万级别的表 大于2000万的数据进行分库分表存储 JVM整体结构及内存模型 JVM调优:主要为减少FULL GC的执行次数或者减少FULL GC执行时间 Spring Boot程序…

在线文档管理工具都有什么值得推荐的?

在线文档管理工具是现代企业和个人必备的工具之一,它们可以帮助用户方便地创建、编辑、共享和管理文档。 几个值得推荐的在线文档管理工具: Google 文档:Google 文档是一款免费的在线文档工具,它提供了和 Microsoft Word 类似的…

微信公众号开发学习

申请测试号 地址 通过F12抓取体验接口权限表的HTML 解析HTML 引入pom <dependency><groupId>org.projectlombok</groupId><artifactId>lombok</artifactId><optional>true</optional></dependency><dependency><…

物联网|可变参数的使用技巧|不一样的点灯实验|访问外设的寄存器|操作寄存器实现点灯|硬件编程的基本流程-学习笔记(11)

文章目录 可变参数的使用技巧第三阶段-初级实验Lesson5:不一样的点灯实验---学习I/O的输出 ☆点灯的电路图分析1 一起看看点灯的电路图Tip1:另一种点灯的电路Tip1:如何访问外设的寄存器2 STM32F407中操作GPIO的方法 通过直接操作寄存器实现点灯实验Tip1:硬件编程的基本流程 2代…

数据可视化(4)散点图及面积图

1.简单散点图 #散点图 #scatter(x,y) x数据&#xff0c;y数据 x[i for i in range(10)] y[random.randint(1,10) for i in range(10)] plt.scatter(x,y) plt.show()2.散点图分析 #分析广告支出与销售收入相关性 dfcarpd.read_excel(广告支出.xlsx) dfdatapd.read_excel(销售…

VS开发Qt程序,无法打印QDebug调试信息,VS进行Qt开发时Qt Designer无法使用“转到槽”选项

VS开发Qt程序&#xff0c;无法打印QDebug调试信息&#xff0c;VS进行Qt开发时Qt Designer无法使用“转到槽”选项 VS开发Qt程序&#xff0c;无法打印QDebug调试信息VS进行Qt开发时Qt Designer无法使用“转到槽”选项 VS开发Qt程序&#xff0c;无法打印QDebug调试信息 解决方案…

使用Idea提交项目到远程仓库

使用Idea提交项目到远程仓库 1.在Idea中打开本地要推送的项目2.创建远程仓库并提交 1.在Idea中打开本地要推送的项目 tips: 首先你得有git工具&#xff0c;没有的话可以参考下面的这篇文章 git与gitee结合使用&#xff0c;提交代码&#xff0c;文件到远程仓库 从导航栏中选择 V…

如何快速开拓海外华人市场?附解决方案!

开拓华人市场对于企业来说是非常必要的。华人市场庞大且潜力巨大&#xff0c;拥有巨额的消费能力。随着华人经济的不断增长&#xff0c;越来越多的企业开始意识到华人市场的重要性。 通过开拓华人市场&#xff0c;企业可以获得更多的销售机会&#xff0c;并且在竞争激烈的市场…

Go语言time库,时间和日期相关的操作方法

time库 用于处理时间、日期和时区的核心库。在实际开发中&#xff0c;常常需要与时间打交道&#xff0c;例如记录日志、处理时间差、计算时间间隔等等。因此&#xff0c;掌握time库的使用方法对于Go开发者来说非常重要。 在Go语言中&#xff0c;时间表示为time.Time类型&…

嵌入式开发的学习内容和技能包括:

. 熟悉C语言编程 掌握基础电子知识&#xff0c;如数字电路、模拟电路和单片机 .熟练掌握嵌入式操作系统的原理、内核架构和应用&#xff0c;如Linux、RTOS等 了解各种外设接口及其驱动程序开发&#xff0c;如SPI、I2C、USART等 熟悉常用的嵌入式开发工具和软件工程流程&#…

【ASPICE】:学习记录

学习记录 ASPICE中文资料什么是ASPICE过程参考模型 ASPICE全称“Automotive Software Process Improvement and Capability dEtermination”&#xff0c;即“汽车软件过程改进及能力评定”模型框架 ASPICE中文资料 主要资料来源 什么是ASPICE 过程参考模型

神经网络原理概述

文章目录 1.神经元和感知器1.1.什么是感知器1.2.什么是单层感知器1.3.多层感知机&#xff08;Multi-Layer Perceptron&#xff0c;MLP&#xff09; 2.激活函数2.1.单位阶跃函数2.2.sigmoid函数2.3.ReLU函数2.4.输出层激活函数 3.损失函数4.梯度下降和学习率5.过拟合和Dropout6.…

python学到什么程度算入门,python从入门到精通好吗

本篇文章给大家谈谈python学到什么程度算入门&#xff0c;以及python从入门到精通好吗&#xff0c;希望对各位有所帮助&#xff0c;不要忘了收藏本站喔。 学习 Python 之 进阶学习 一切皆对象 1. 变量和函数皆对象2. 模块和类皆对象3. 对象的基本操作 (1). 可以赋值给变量(2). …

JAVA- SQL注入案例(黑马程序员)和避免 超级详细

文章目录 sql注入准备1.创建应该新的数据库用于测试&#xff1b;2.修改配置3.启动jar包4.打开网页测试5.测试sql注入 sql注入避免1. java中的登录逻辑代码2.演示sql注入3.原因5.参数化查询-PreparedStatement SQL注入是什么&#xff1f; SQL 注入&#xff08;SQL Injection&…

【Python】Web学习笔记_flask(2)——getpost

flask提供的request请求对象可以实现获取url或表单中的字段值 GET请求 从URL中获取name、age两个参数 from flask import Flask,url_for,redirect,requestappFlask(__name__)app.route(/) def index():namerequest.args.get(name)agerequest.args.get(age)messagef姓名:{nam…

【LeetCode 75】第十七题(1493)删掉一个元素以后全为1的最长子数组

目录 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 代码运行结果&#xff1a; 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 给一个数组&#xff0c;求删除一个元素以后能得到的连续的最长的全是1的子数组。 我们可以先单独统计出连续为1的子数组分别长度…

命令模式-请求发送者与接收者解耦

去小餐馆吃饭的时候&#xff0c;顾客直接跟厨师说想要吃什么菜&#xff0c;然后厨师再开始炒菜。去大点的餐馆吃饭时&#xff0c;我们是跟服务员说想吃什么菜&#xff0c;然后服务员把这信息传到厨房&#xff0c;厨师根据这些订单信息炒菜。为什么大餐馆不省去这个步骤&#xf…

【JVM】(一)深入理解JVM运行时数据区

文章目录 一、JVM 运行流程二、虚拟机栈&#xff08;线程私有&#xff09;三、本地方法栈 &#xff08;线程私有&#xff09;四、方法区&#xff08;元数据区&#xff09;五、堆&#xff08;线程共享&#xff09;六、程序计数器&#xff08;线程私有&#xff09; 一、JVM 运行流…