opencv-python加载pytorch训练好的onnx格式线性回归模型

news2025/1/11 7:24:16

    opencv是一个开源的图形库,有针对java,c++,python的库依赖,它本身对模型训练支持的不好,但是可以加载其他框架训练的模型来进行预测。

    这里举一个最简单的线性回归的例子,使用深度学习框架pytorch训练模型,最后保存模型为onnx格式。最后使用opencv-python库来进行加载预测。

    这里准备的线性回归模型数据如下:

x_datay_data
12
24
36
4

?

8?
10?
15?

    直观的看,这里其实就是一个 y = 2x的线性方程。但是机器学习里面,它是通过不断迭代的方式求得最终的系数,w,b的值。 

    本文使用过的python库版本:

  •     torch:1.13.0
  •     opencv-python: 4.5.5
  •     numpy 1.24.2
  •     python: 3.10

    show me the code:

import torch

x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])


class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)


model = LinearModel()
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.02)

for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    if epoch % 100 == 0:
        print(epoch + 1, loss)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print('w = {}'.format(model.linear.weight.item()))
print('b = {}'.format(model.linear.bias.item()))

x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred(4.0) = ', y_test.data)

x_test = torch.Tensor([[8.0]])
y_test = model(x_test)
print('y_pred(8.0) = ', y_test.data)

x_test = torch.Tensor([[10.0]])
y_test = model(x_test)
print('y_pred(10.0) = ', y_test.data)

x_test = torch.Tensor([[15.0]])
y_test = model(x_test)
print('y_pred(15.0) = ', y_test.data)

model.eval()
dummy_input = torch.randn(1, 1)
input_name = ["input"]
output_name = ["output"]
onnx_name = "test.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_name,
    verbose=True,
    input_names=input_name,
    output_names=output_name
)

    运行,打印信息如下:

 

    这里重点关注线性回归系数,w = 1.99,b = 1.31。这里的w其实很接近2,毕竟是计算机算出来的,这里没有明确表示使用梯度下降算法,但是实际上在训练那部分,确实使用的是梯度下降法。

for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    if epoch % 100 == 0:
        print(epoch + 1, loss)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    最后,通过这个回归系数预测,符合预期。 

    算法最终会将模型保存在本地项目路径下的test.onnx文件中。

  //

    上面算是模型准备好了,下面我们通过opencv-python来加载这个模型,并预测:

from cv2 import dnn
import numpy as np

net = dnn.readNetFromONNX("test.onnx")
matblob = np.full((1, 1), 1024, dtype=np.int32)
net.setInput(matblob)
print('input = {}'.format(matblob))
output = net.forward()
print('output = {}'.format(output))

    这里,matblob其实就是一个Mat,但是在opencv-python里面,它可以通过numpy来创建,这里在网上的都是通过读取一个图片来生成matblob对象,我个人觉着我们这里很明确,就是需要指定一个数字1024,我们通过np.full() 来设置cols,rows,type,value就成功创建了这个输入Mat对象。

    运行这个代码,我们期望得到的是2048,实际结果如下所示:

    其实和2048很接近了,这是python代码的结果。如果使用opencv-c++来编码,代码基本类似。我没有去试c++,主要是这个模型就是通过python代码实现并生成的,而且opencv有python依赖库,所以就想着直接用python实现了,没必要再去研究opencv-c++的实现了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/473091.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【软考备战·希赛网每日一练】2023年4月28日

文章目录 一、今日成绩二、错题总结第一题第二题第三题 三、知识查缺 题目及解析来源:2023年04月28日软件设计师每日一练 一、今日成绩 二、错题总结 第一题 解析: 大体了解即可,题目要考察的核心意思:确定的有限自动机和不确定的…

js 操作数组内容

js 操作数组内容 数组添加元素(更改原数组) push和unshift会返回添加了新元素的数组长度 push从数组最后加入,unshift从数组最前面加入 const arr ["a", "b", "c"]; arr.push("d"); //返回4…

数据结构基础day9

题目&#xff1a;187. 重复的DNA序列 解法1&#xff1a;哈希表 class Solution { public:vector<string> findRepeatedDnaSequences(string s) {vector<string> ans;unordered_map<string, int> mp;int ns.size(), L10;for(int i0; i<n-L; i){ //从开头…

【fluent UDF】warning: unused variable警报:存在未使用的变量

一、问题背景 在编译UDF时&#xff0c;出现如下错误 curing_heat_v3.c: In function ‘iter_ending’: curing_heat_v3.c:105:14: warning: unused variable ‘volume_sum’ [-Wunused-variable] real volume_sum0.0; curing_heat_v3.c:104:14: warning: unused variable ‘…

【Python零基础学习入门篇②】——第二节:Python的常用语句

⬇️⬇️⬇️⬇️⬇️⬇️ ⭐⭐⭐Hello&#xff0c;大家好呀我是陈童学哦&#xff0c;一个普通大一在校生&#xff0c;请大家多多关照呀嘿嘿&#x1f601;&#x1f60a;&#x1f618; &#x1f31f;&#x1f31f;&#x1f31f;技术这条路固然很艰辛&#xff0c;但既已选择&…

网络编程之简单socket通信

一.什么是Socket? Socket&#xff0c;又叫套接字&#xff0c;是在应用层和传输层的一个抽象层。它把TCP/IP层复杂的操作抽象为几个简单的接口供应用层调用以实现进程在网络中通信。 socket分为流socket和数据报socket&#xff0c;分别基于tcp和udp实现。 SOCK_STREAM 有以下…

苦学58天,最后就这结果......

背景 非计科大专一枚&#xff0c;当初学的机械自动化专业。大学完全可以说是玩过来的&#xff0c;临近毕业开始慌了&#xff0c;毕业后一直没能找到工作&#xff0c;在高中同学&#xff08;211 计科&#xff09;的引领下&#xff0c;入坑程序员&#xff0c;学的软件测试。 从…

Lombok简介

Lombok简介 1、lombok简介2、springboot整合lombok 1、lombok简介 Lombok是一个第三方的Java工具库&#xff0c;会自动插入编辑器和构建工具。Lombok提供了一组非常有用的注解&#xff0c;用来消除Java类中的大量样板代码&#xff0c;比如setter和getter方法、构造方法等。只需…

Vue(简单了解Cookie、生命周期)

一、了解Cookie 类似于对象响应携带数据 输入用户名密码跳转到指定页面 点击指定页面中其中一个按钮跳转到另一个指定页面&#xff08;再不需用输入用户名密码&#xff09; 例如现在很多浏览器实现七天免密登录 简单理解&#xff1a;就是在网站登录页面之后&#xff0c;服务…

新建Django项目

1. 创建项目 使用Django提供的命令&#xff0c;可以创建一个Django项目实例需要的配置项——包括数据库配置、Django配置和应用程序配置的集合。新建Django项目命令的语法格式如下&#xff1a; django-admin startproject 工程名称例如&#xff1a;想要在D:\的pythonProject目…

Mysql 存储过程 / 存储函数

目录 0 课程视频 1 基本语法 1.0 作用 ->在数据库中 封装sql语句 -> 复用 -> 减少网络交互 ->可接收参数返回数据 1.1 创建 1.2 调用 1.3 查看 1.4 删除 1.5 ; 封号结束符 改成 $$ 双刀符合结束语句 -> 因为打包封号结束有冲突 1.6 在cmd 中定义 存储过…

基于 SpringBoot+Vue+Java 的财务管理系统(附源码,教程)

文章目录 一 简介第二.主要技术第三、部分效果图第四章 系统设计4.1功能结构4.2 数据库设计4.2.1 数据库E/R图4.2.2 数据库表 第五章 系统功能实现5.1管理员功能模块 一 简介 财务管理系统的需求和管理上的不断提升&#xff0c;财务管理系统的潜力将无限扩大&#xff0c;财务管…

Postman预请求脚本、测试脚本(pre-request scripts、tests常用工作总结)

文章目录 Postman预请求脚本&#xff08;pre-request scripts工作常用总结&#xff09;Postman预请求脚本Postman测试脚本预请求脚本和测试脚本有什么区别常用工作总结登录接口返回的是Set-Cookie标头 Postman预请求脚本&#xff08;pre-request scripts工作常用总结&#xff0…

Spring Boot配置文件及日志信息

目录 前言&#xff1a; Spring Boot优点 配置文件 配置文件格式 读取配置文件 properties配置文件格式 properties优缺点分析 yml配置文件格式&#xff08;另一种标记语言&#xff09; yml优缺点分析 Spring Boot 不同平台配置文件规则 日志信息 日志的功能 Sprin…

Springboot +Flowable,设置任务处理人的四种方式(一)

一.简介 学习下UserTask 设置用户的三种方式&#xff0c;至于如何设置用户组&#xff0c;下篇文章再聊。 现在&#xff0c;假设我有如下一个简单的流程图&#xff1a; 那么该如何设置这个用户节点的处理人&#xff1f; 二.第一种&#xff1a;指定具体用户 第一种方式&…

电路中噪声来源

电路包括不同的部件和芯片&#xff0c;所有都有可能成为噪声的来源。例如&#xff0c;电阻会带来热噪声&#xff0c;这个噪声为宽频噪声&#xff0c;几乎涵盖所有频率范围&#xff1b;运算放大器其芯片内部会产生噪声&#xff1b;而 ADC产生的量化噪声相较于其他器件&#xff0…

【】右值引用完美转发

文章目录 右值引用和左值引用左值和右值概念左值引用 && 右值引用右值引用使用场景和意义左值引用的使用场景**左值引用的短板:**右值引用和移动语义STL容器增加的接口move函数右值引用的其他使用场景 完美转发万能引用完美转发保持值的属性完美转发的使用场景 右值引用…

【Linux】进程信号 --- 信号产生 信号递达和阻塞 信号捕捉

&#x1f34e;作者&#xff1a;阿润菜菜 &#x1f4d6;专栏&#xff1a;Linux系统编程 文章目录 一、预备知识二、信号产生1. 通过终端按键产生信号1.1 signal()1.2 core dump标志位、核心存储文件 2.通过系统调用向进程发送信号3.由软件条件产生信号3.1 alarm函数和SIGALRM信号…

vue3通过ref拿element弹框中的组件问题

写在<el-dialog>中的组件&#xff0c;在一开始&#xff0c;弹框没有弹出来的时候&#xff0c;<el-dialog>中的组件是没有被渲染出来的&#xff0c;因此在<el-dialog>中使用ref标记某个组件&#xff0c;在el-dialog没有被渲染出来之前去拿的话&#xff0c;是拿…

ml常见代码片段

常用ML代码片段 变换一列 new_df[brand] new_df[prod_name].apply(lambda x: x.split()[0])变换2列 new_df[chip_total_sales] new_df.apply(lambda x: x[total_sales] * x[is_chip], axis 1) # 重要的是axis1groupby 计数&#xff0c;求和&#xff0c;取第一个值&#x…