pytorch学习(十二)c++调用minist训练的onnx模型

news2024/9/24 9:21:33

在实际使用过程中,使用python速度不够快,并且不太好嵌入到c++程序中,因此可以把pytorch训练的模型转成onnx模型,然后使用opencv进行调用。

所需要用到的库有:

opencv

1.完整的程序如下

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import LambdaLR
import os
import re
from PIL import Image



cur_pwd_path = os.getcwd()

def getBestModuleFilename(browser):
    file_name = browser             #"tf_logs/save_module"
    filenames = os.listdir(file_name)
    pattern = r"d+"
    result = []
    for i in range(len(filenames)):
        rst = int(filenames[i][10:-4])

        result.append(rst)
    val = max(result)
    index = result.index(val)
    file_best = filenames[index]
    print(file_best)
    return file_best

tensor = torch.randn(3,3)
bTensor = type(tensor) == torch.Tensor
print(bTensor)
print("tensor is on ", tensor.device)
#数据转到GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
if torch.cuda.is_available():
    tensor = tensor.to(device)
    print("tensor is on ",tensor.device)
#数据转到CPU
if tensor.device == 'cuda:0':
    tensor = tensor.to(torch.device("cpu"))
    print("tensor is on", tensor.device)
if tensor.device == "cpu":
    tensor = tensor.to(torch.device("cuda:0"))
    print("tensor is on", tensor.device)

trainning_data = datasets.MNIST(root="data",train=True,transform=ToTensor(),download=True)
print(len(trainning_data))
test_data = datasets.MNIST(root="data",train=True,transform=ToTensor(),download=False)

train_loader = DataLoader(trainning_data, batch_size=64,shuffle=True)
test_loader = DataLoader(test_data, batch_size=64,shuffle=True)




print(len(train_loader)) #分成了多少个batch
print(len(trainning_data)) #总共多少个图像
# for x, y in train_loader:
#     print(x.shape)
#     print(y.shape)



class MinistNet(nn.Module):
    def __init__(self):
        super().__init__()
        # self.flat = nn.Flatten()
        self.conv1 = nn.Conv2d(1,1,3,1,1)
        self.hideLayer1 = nn.Linear(28*28,256)
        self.hideLayer2 = nn.Linear(256,10)
    def forward(self,x):
        x= self.conv1(x)
        x = x.view(-1,28*28)
        x = self.hideLayer1(x)
        x = torch.sigmoid(x)
        x = self.hideLayer2(x)
        # x = nn.Sigmoid(x)
        return x


model_path = "E:\\TOOLE\\slam_evo\\pythonProject\\tf_logs\\save_module\\ckpt_best_10.pth"
img_path = "E:\\TOOLE\\slam_evo\\pythonProject\\2.jpg"
img = Image.open(img_path)
test_model = MinistNet()
test_model1 = torch.load(model_path)
test_model.load_state_dict(test_model1["net"])

test_model.eval()
test_model.to("cuda")

transform =torchvision.transforms.Compose([
torchvision.transforms.Grayscale(),
torchvision.transforms.ToTensor()
])

img = transform(img)
img = torch.unsqueeze(img, 0)
img = img.to("cuda")

result = test_model(img)
result = result.to("cpu")
val,index = torch.max(result,dim=1)
print(index)

model = MinistNet()
model = model.to(device)
cuda = next(model.parameters()).device
print(model)
criterion = nn.CrossEntropyLoss()
optimer = torch.optim.RMSprop(model.parameters(),lr= 0.001)

scheduler_1 = LambdaLR(optimer, lr_lambda=lambda epoch: 1/(epoch+1))



num_epoches =10
min_loss_val = 100000
Resume = False

def train():
    global min_loss_val
    start_epoch = -1
    if Resume == False:
        start_epoch = 0
    else:
        #找到数字最大的pth文件


        path_checkpoint = r'tf_logs/'+"save_module"
        best_path_checkpoint = getBestModuleFilename(path_checkpoint)
        if(best_path_checkpoint == ""):
            return
        else:
            checkpointResume = torch.load(path_checkpoint)
            start_epoch = checkpointResume["epoch"]
            model.load_state_dict(checkpointResume["net"])
            optimer.load_state_dict(checkpointResume["optimizer"])
            scheduler_1.load_state_dict(checkpointResume["lr_schedule"])

    train_losses = []
    train_acces = []
    eval_losses = []
    eval_acces = []
    #训练
    model.train()
    tensorboard_ind =0;
    for epoch in range(num_epoches):
        batchsizeNum = 0
        train_loss = 0
        train_acc = 0
        train_correct = 0
        for x,y in train_loader:
            # print(epoch)
            # print(x.shape)
            # print(y.shape)
            x = x.to('cuda')
            y = y.to('cuda')
            bte = type(x)==torch.Tensor
            bte1 = type(y)==torch.Tensor
            A = x.device
            B = y.device
            pred_y = model(x)
            loss = criterion(pred_y,y)
            optimer.zero_grad()
            loss.backward()
            optimer.step()
            loss_val = loss.item()
            batchsizeNum = batchsizeNum +1
            train_acc += (pred_y.argmax(1) == y).type(torch.float).sum().item()
            train_loss += loss.item()
            tensorboard_ind += 1
        train_losses.append(train_loss / len(trainning_data))
        train_acces.append(train_acc / len(trainning_data))


        #测试
        test_loss_value = 0
        model.eval()
        with torch.no_grad():
            num_batch = len(test_data)
            numSize = len(test_data)
            test_loss, test_correct = 0,0
            for x,y in test_loader:
                x = x.to(device)
                y = y.to(device)
                pred_y = model(x)
                test_loss += criterion(pred_y, y).item()
                test_correct += (pred_y.argmax(1) == y).type(torch.float).sum().item()
            test_loss /= num_batch
            test_correct /= numSize
            eval_losses.append(test_loss)
            eval_acces.append(test_correct)
            test_loss_value = test_loss
            print("test result:",100 * test_correct,"%  avg loss:",test_loss)
        scheduler_1.step()
        #设置checkpoint
        if epoch > int(num_epoches/3) and test_loss_value < min_loss_val:
            min_loss_val = test_loss_value
            checkpoint = {"epoch": epoch,
                        "net": model.state_dict(),
                          "optimizer":optimer.state_dict(),
                          "lr_schedule":scheduler_1.state_dict()}

            if not os.path.isdir(r'tf_logs/' + "save_module"):
                os.makedirs("tf_logs/" + "save_module")
            PATH = r'tf_logs/'+"save_module" + "/ckpt_best_%s.pth"%(str(epoch+1))
            torch.save(checkpoint, PATH)



def test_singleFrame():
    model_path = "E:\\TOOLE\\slam_evo\\pythonProject\\tf_logs\\save_module\\ckpt_best_10.pth"
    img_path = "E:\\TOOLE\\slam_evo\\pythonProject\\1.jpg"
    img =Image.open(img_path)
    test_model = MinistNet()
    test_model = torch.load(model_path)
    test_model.to("cuda")

    transform=ToTensor()
    img = transform(img)
    img.to("cuda")

    result = test_model(img)
    val, index = torch.max(result)
    print(index)


# Press the green button in the gutter to run the script.

if __name__ == '__main__':
        train()

        #保存onnx
        model.cpu()
        model.eval()
        x= torch.randn(1,1,28,28)
        torch.onnx.export(model,x,"model.onnx")

2.训练并保存模型

        if epoch > int(num_epoches/3) and test_loss_value < min_loss_val:
            min_loss_val = test_loss_value
            checkpoint = {"epoch": epoch,
                        "net": model.state_dict(),
                          "optimizer":optimer.state_dict(),
                          "lr_schedule":scheduler_1.state_dict()}

            if not os.path.isdir(r'tf_logs/' + "save_module"):
                os.makedirs("tf_logs/" + "save_module")
            PATH = r'tf_logs/'+"save_module" + "/ckpt_best_%s.pth"%(str(epoch+1))
            torch.save(checkpoint, PATH)

3.加载并测试模型

model_path = "E:\\TOOLE\\slam_evo\\pythonProject\\tf_logs\\save_module\\ckpt_best_10.pth"
img_path = "E:\\TOOLE\\slam_evo\\pythonProject\\2.jpg"
img = Image.open(img_path)
test_model = MinistNet()
test_model1 = torch.load(model_path)
test_model.load_state_dict(test_model1["net"])

test_model.eval()
test_model.to("cuda")

transform =torchvision.transforms.Compose([
torchvision.transforms.Grayscale(),
torchvision.transforms.ToTensor()
])

img = transform(img)
img = torch.unsqueeze(img, 0)
img = img.to("cuda")

result = test_model(img)
result = result.to("cpu")
val,index = torch.max(result,dim=1)
print(index)

结果如下:

按照第0,1来数数,tensor([1])刚好就是2.

4.保存onnx模型

if __name__ == '__main__':
        train()

        #保存onnx
        model.cpu()
        model.eval()
        x= torch.randn(1,1,28,28)
        torch.onnx.export(model,x,"model.onnx")

5.使用C++加opencv实现minist手写数字的识别

// test_onnm.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。
//

#include<ostream>
#include<opencv2/opencv.hpp>
#include<opencv2/dnn.hpp>

#include <iostream>

using namespace std;
using namespace cv;
using namespace dnn;

int main()
{
    std::cout << "Hello World!\n";

    //cv::dnn::Net net = cv::dnn::readTensorFromONNX();
    cv::dnn::Net net = cv::dnn::readNetFromONNX("E:\\TOOLE\\slam_evo\\pythonProject\\model.onnx");
    if (net.empty())
    {
        std::cout << "加载onnx模型失败" << std::endl;
        return -1;
    }

    net.setPreferableBackend(DNN_BACKEND_OPENCV);
    net.setPreferableTarget(DNN_TARGET_CPU);

    cv::Mat img = cv::imread("E:\\TOOLE\\slam_evo\\pythonProject\\1.jpg",cv::IMREAD_GRAYSCALE);

    if(img.cols != 28 || img.rows != 28)
    {
        return -1;
    }

    cv::Mat blob;
    float scaleFactor = 1 / 255.0;
    blobFromImage(img, blob, scaleFactor, Size(), Scalar(), true, false, CV_32F);

    net.setInput(blob);
    cv::Mat predict = net.forward();
    for (int i = 0; i < predict.total(); i++)
    {
        std::cout << predict.at<float>(i) << "  ";
    }
    std::cout << std::endl;


    double minVal, maxVal;
    Point minLoc, maxLoc;

    // 查找最大值和最小值及其位置
    minMaxLoc(predict, &minVal, &maxVal, &minLoc, &maxLoc);
    cout << maxVal << "    " << maxLoc.x<<"   "<< maxLoc.y << "\n";
    return 0;

}

结果展示:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1938044.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

06. 截断文本 选择任何链接 :root 和 html 有什么区别

截断文本 对超过一行的文本进行截断,在末尾添加省略号(…)。 使用 overflow: hidden 防止文本超出其尺寸。使用 white-space: nowrap 防止文本超过一行高度。使用 text-overflow: ellipsis 使得如果文本超出其尺寸,将以省略号结尾。为元素指定固定的 width,以确定何时显示省略…

韩顺平0基础学Java——第35天

p689-714 格式化语句 gpt说的&#xff1a; System.out.println 方法不支持像 printf 一样的格式化字符串。要使用格式化字符串&#xff0c;你可以使用 System.out.printf 方法或将格式化后的字符串传递给 System.out.println。下面是两种修正的方法&#xff1a; ### 方法一…

科研绘图系列:R语言circos图(circos plot)

介绍 Circos图是一种数据可视化工具,它以圆形布局展示数据,通常用于显示数据之间的关系和模式。这种图表特别适合于展示分层数据或网络关系。Circos图的一些关键特点包括: 圆形布局:数据被组织在一个或多个同心圆中,每个圆可以代表不同的数据维度或层次。扇区:每个圆被划…

昇思25天学习打卡营第25天|MindNLP ChatGLM-6B StreamChat

配置环节 %%capture captured_output !pip uninstall mindspore -y !pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore2.2.14 !pip install mindnlp !pip install mdtex2html配置国内镜像 !export HF_ENDPOINThttps://hf-mirror.com下载与加载模型 from m…

【safari】react在safari浏览器中,遇到异步时间差的问题,导致状态没有及时更新到state,引起传参错误。如何解决

在safari浏览器中&#xff0c;可能会遇到异步时间差的问题&#xff0c;导致状态没有及时更新到state&#xff0c;引起传参错误。 PS&#xff1a;由于useState是一个普通的函数&#xff0c; 定义为() > void;因此此处不能用await/async替代setTimeout&#xff0c;只能用在返…

Vue3 composition api计算属性活学活用(作业题1 - 计算扁平化树树节点的索引)

本示例节选自vue3最新开源组件实战教程大纲&#xff08;持续更新中&#xff09;的tree组件开发部分。在学习了tree组件实现折叠与展开功能&#xff08;方式2 - visible计算属性&#xff09;后&#xff0c;给读者朋友留的一道编程作业题。 作业要求 合理的设计和实现树节点的计…

【C#】计算两条直线的交点坐标

问题描述 计算两条直线的交点坐标&#xff0c;可以理解为给定坐标P1、P2、P3、P4&#xff0c;形成两条线&#xff0c;返回这两条直线的交点坐标&#xff1f; 注意区分&#xff1a;这两条线是否垂直、是否平行。 代码实现 斜率解释 斜率是数学中的一个概念&#xff0c;特别是…

HTML开发笔记:1.环境、标签和属性、CSS语法

一、环境与新建 在VSCODE里&#xff0c;加载插件&#xff1a;“open in browser” 然后新建一个文件夹&#xff0c;再在VSCODE中打开该文件夹&#xff0c;在右上角图标新建文档&#xff0c;一定要是加.html&#xff0c;不要忘了文件后缀 复制任意一个代码比如&#xff1a; <…

reserve和resize

void test_vector4() {vector<int> v1;//cout << v1.max_size() << endl;//v1.reserve(10);v1.resize(10);for (size_t i 0; i < 10; i){v1[i] i;}for (auto e : v1){cout << e << " ";}cout << endl;} 在上面这段代码中对…

数学建模--国赛备赛---TOPSIS算法

目录 1.准备部分 1.1提交材料 1.2MD5码相关要求 2.TOPSIS算法 2.1算法概述 2.2基本概念 2.3算法核心思想 2.4拓展思考 3.适用赛题 3.1适用赛题说明 3.2适用赛题举例 4.赛题分析 4.1指标的分类 4.2数据预处理 4.2.1区间型属性的变换 4.2.2向量规范化 4.3数据加…

vue 侧边锚点外圆角

环境&#xff1a;uniapp、vue3、unocss、vant4 效果&#xff1a; 代码 主要是&#xff1a;pointTop 、pointCentent 、pointBottom&#xff0c;这三个样式 html <div v-show"!showPoint" class"fixedLeftDiv"><div><div class"pointT…

RPG素材Unity7月20闪促限时4折游戏开发资产兽人角色模型动画休闲放置模板物理交互流体水下焦散VR界面UI2D模板场景20240720

今天这个是RPG素材比较多&#xff0c;还有一些休闲放置模板、FPS场景素材、角色模型、动画、特效。 详细内容展示&#xff1a;www.bilibili.com/video/BV1Tx4y1s7vm 闪促限时4折&#xff1a;https://prf.hn/l/0eEOG1P 半价促销&#xff1a;https://prf.hn/l/RlDmDeQ 7月闪促…

java开发报错合集

mapstruct 1. 报错信息&#xff1a; mapstruct 错误 java.lang.NoSuchMethodError: Ljava/lang/Double 错误 解决方案&#xff1a; mapstruct 错误 java.lang.NoSuchMethodError: Ljava/lang/Double 错误_mapstruct nosuchmethoderror-CSDN博客 2. 报错信息&#xff1a; maps…

数据结构——线性表(单链表)

一、链式存储结构定义 线性表的链式存储结构定义是指使用指针将线性表中的元素按照其逻辑次序依次存储在存储空间中&#xff0c;通过指针来表示数据元素之间的逻辑关系。具体来说&#xff0c;链式存储结构由数据域和指针域组成&#xff0c;数据域存储数据元素的数值&#xff0…

手机接Usb hub再连接电脑下D+D-波形

&#x1f3c6;本文收录于《CSDN问答解答》专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收藏&…

UE4-光照渲染、自动曝光、雾

目录 一.光源种类 二.灯光的移动性 三.自动曝光 四.指数级高度雾 五.实现光束 一.光源种类 1.定向光源 用来模拟现实中的太阳光。 2.点光源 比如现实中的灯泡 3.聚光源 4.矩形光源 是这几个光源中性能开销最大的&#xff0c;一般不用到游戏场景中&#xff0c;因为游…

win安装mysql

解压到目录没如果多个mysql创建不同的名字 创建data和my.ini my.ini内容 [client] default-character-setutf8mb4[mysqld] #设置3306端口 port 3306 # 设置mysql的安装目录 basedirF:\mysql-5.7.31 # 设置mysql的数据存放目录 datadirF:\mysql-5.7.31\data # 允许最大连接数 …

ComfyUI面部修复FaceDetailer使用指南

原文&#xff1a;ComfyUI面部修复完全指南 (chinaz.com) 让我们开始使用ComfyUI中的人脸详细修复器吧。人脸详细修复器节点乍一看可能很复杂&#xff0c;但不要担心&#xff0c;我们会一点一点地分解它。通过理解每个输入、输出和参数&#xff0c;你很快就能像专业人士一样使用…

处理AI模型中的“Type Mismatch”报错:数据类型转换技巧

处理AI模型中的“Type Mismatch”报错&#xff1a;数据类型转换技巧 &#x1f504; 处理AI模型中的“Type Mismatch”报错&#xff1a;数据类型转换技巧 &#x1f504;摘要引言正文内容1. 错误解析&#xff1a;什么是“Type Mismatch”&#xff1f;2. 数据类型转换技巧2.1 检查…

大数据环境下的房地产数据分析与预测研究的设计与实现

1绪论 1.1研究背景及意义 随着经济的快速发展和城市化进程的推进&#xff0c;房地产市场成为了国民经济的重要组成部分。在中国&#xff0c;房地产行业对经济增长、就业创造和资本投资起到了重要的支撑作用。作为中国西南地区的重要城市&#xff0c;昆明的房地产市场也备受关…