MNIST手写字符分类-卷积

news2024/11/16 10:29:48

MNIST手写字符分类-卷积

文章目录

  • MNIST手写字符分类-卷积
    • 1 模型构造
    • 2 训练
    • 3 推理
    • 4 导出
    • 5 onnx测试
    • 6 opencv部署
    • 7 总结

  在上一篇中,我们介绍了如何在pytorch中使用线性层+ReLU非线性层堆叠的网络进行手写字符识别的网络构建、训练、模型保存、导出和推理测试。本篇文章中,我们将要使用卷积层进行网络构建,并完成后续的训练、保存、导出,并使用opencv在C++中推理我们的模型,将结果可视化。

1 模型构造

  在pytorch中,卷积层的使用比较方便,需要注意的是卷积层的输入通道数、输出通道数、卷积核的大小等参数。这里直接放出构建的网络结构:

import torch
from torch import nn
from torch.utils.data import DataLoader
class ZKNNNet_Conv(nn.Module):
    def __init__(self):
        super(ZKNNNet_Conv, self).__init__()
        self.conv_stack = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(12*12*64, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        logits = self.conv_stack(x)
        return logits

在这里插入图片描述

从图中可以看出,该模型先堆叠了两个卷积层与ReLU单元,经过最大池化之后,展开并进行后续的全连接层训练。

2 训练

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
from ZKNNNet import ZKNNNet_Conv
import os
# Download training data from open datasets.
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

model = ZKNNNet_Conv()
if os.path.exists("./model/model_conv.pth"):
    model.load_state_dict(torch.load("./model/model_conv.pth"))
model = model.to(device)
print(model)

# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Loss function
loss_fn = nn.CrossEntropyLoss()

# Train
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

# Test
def test(dataloader, model):
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return correct

epochs = 200
maxAcc = 0
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    currentAcc = test(test_dataloader, model)
    if maxAcc < currentAcc:
        maxAcc = currentAcc
        torch.save(model.state_dict(), "./model/model_conv.pth")
print("Done!")

模型的训练代码与上一篇中的线性连接训练代码是一样的。
训练过程来看,使用卷积层,在相同数据集上训练,模型收敛速度比用线性层快很多。最终精度达到97.8%。

3 推理

模型训练完成之后,推理过程与上一篇一致,这里简单放一下推理代码。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision import datasets
from ZKNNNet import ZKNNNet_Conv

import matplotlib.pyplot as plt

# Get cpu or gpu device for inference.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device for inference".format(device))

# Load the trained model
model = ZKNNNet_Conv()
model.load_state_dict(torch.load("./model/model_conv.pth"))
model.to(device)
model.eval()

# Download test data from open datasets.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Create data loader.
test_dataloader = DataLoader(test_data, batch_size=64)



# Perform inference
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_dataloader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Visualize the image and its predicted result
        for i in range(len(images)):
            image = images[i].cpu()
            label = labels[i].cpu()
            prediction = predicted[i].cpu()

            plt.imshow(image.squeeze(), cmap='gray')
            plt.title(f"Label: {label}, Predicted: {prediction}")
            plt.show()

    accuracy = 100 * correct / total
    print("Accuracy on test set: {:.2f}%".format(accuracy))

4 导出

模型导出方式与上一篇一致。

import torch
import torch.utils
import os
from ZKNNNet import ZKNNNet_3Layer,ZKNNNet_5Layer,ZKNNNet_Conv
model_conv = ZKNNNet_Conv()
if os.path.exists('./model/model_conv.pth'):
    model_conv.load_state_dict(torch.load('./model/model_conv.pth'))
model_conv = model_conv.to(device)
model_conv.eval()
torch.onnx.export(model_conv,torch.randn(1,1,28,28),'./model/model_conv.onnx',verbose=True)

5 onnx测试

import onnxruntime as rt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision import datasets

import matplotlib.pyplot as plt

from PIL import Image

sess = rt.InferenceSession("model/model_conv.onnx")
input_name = sess.get_inputs()[0].name
print(input_name)

image = Image.open('./data/test/2.png')
image_data = np.array(image)
image_data = image_data.astype(np.float32)/255.0
image_data = image_data[None,None,:,:]
print(image_data.shape)

outputs = sess.run(None,{input_name:image_data})
outputs = np.array(outputs).flatten()

prediction = np.argmax(outputs)
plt.imshow(image, cmap='gray')
plt.title(f"Predicted: {prediction}")
plt.show()

# Download test data from open datasets.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Create data loader.
test_dataloader = DataLoader(test_data, batch_size=1)

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_dataloader:
        images = images.numpy()
        labels = labels.numpy()
        outputs = sess.run(None,{input_name:images})[0]
        outputs = np.array(outputs).flatten()
        prediction = np.argmax(outputs)

        # Visualize the image and its predicted result
        for i in range(len(images)):
            image = images[i]
            label = labels[i]

            plt.imshow(image.squeeze(), cmap='gray')
            plt.title(f"Label: {label}, Predicted: {prediction}")
            plt.show()

至此,模型已经成功的转换成onnx模型,可以用于后续各种部署环境的部署。

6 opencv部署

本例中,使用C++/opencv来尝试部署刚才训练的模型。输入为在之前的博文中提到的将MNIST测试集导出成png图片保存。

#include "opencv2/opencv.hpp"

#include <iostream>
#include <filesystem>
#include <string>
#include <vector>

int main(int argc, char** argv)
{
    if (argc != 3)
    {
        std::cerr << "Usage: MNISTClassifier_onnx_opencv <onnx_model_path> <image_path>" << std::endl;
        return 1;
    }

    cv::dnn::Net net = cv::dnn::readNetFromONNX(argv[1]);
    if (net.empty())
    {
        std::cout << "Error: Failed to load ONNX file." << std::endl;
        return 1;
    }
    std::filesystem::path srcPath(argv[2]);

    for (auto& imgPath : std::filesystem::recursive_directory_iterator(srcPath))
    {
        if(!std::filesystem::is_regular_file(imgPath))
            continue;

        const cv::Mat image = cv::imread(imgPath.path().string(), cv::IMREAD_GRAYSCALE);
        if (image.empty())
        {
            std::cerr << "Error: Failed to read image file." << std::endl;
            continue;
        }

        const cv::Size size(28, 28);
        cv::Mat resized_image;
        cv::resize(image, resized_image, size);

        cv::Mat float_image;
        resized_image.convertTo(float_image, CV_32F, 1.0 / 255.0);

        cv::Mat input_blob = cv::dnn::blobFromImage(float_image);
        
        net.setInput(input_blob);
        cv::Mat output = net.forward();

        cv::Point classIdPoint;
        double confidence;
        cv::minMaxLoc(output.reshape(1, 1), nullptr, &confidence, nullptr, &classIdPoint);
        const int class_id = classIdPoint.x;

        std::cout << "Class ID: " << class_id << std::endl;
        std::cout << "Confidence: " << confidence << std::endl;
        cv::Mat bigImg;
        cv::resize(image,bigImg,cv::Size(128,128));
        auto parentPath = imgPath.path().parent_path();
        auto label = parentPath.filename().string()+std::string("<->")+std::to_string(class_id);
        cv::putText(bigImg, label, cv::Point(10, 20), cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(255, 255, 255), 1);
        cv::imshow("img",bigImg);
        cv::waitKey();
    }

    return 0;
}

本部署方式需要依赖opencv dnn模块。试验中使用的是opencv4.8版本。

7 总结

使用卷积神经网络进行MNIST手写字符识别,在模型结构无明显复杂的情况下,模型收敛速度较全连接层构建的网络收敛速度快。

按照相同的套路导出成onnx模型之后,直接通过opencv可以部署,简化深度学习算法部署的难度。

本部署方式需要依赖opencv dnn模块。试验中使用的是opencv4.8版本。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1818338.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数字人的技术实现原理

数字人是一种利用计算机图形学、人工智能等技术创建的虚拟人物。数字人可以模拟真人进行各种动作和表情&#xff0c;并与用户进行交互。数字人的技术实现原理主要包括以下几个方面。北京木奇移动技术有限公司&#xff0c;专业的软件外包开发公司&#xff0c;欢迎交流合作。 1. …

如何将 API 管理从 Postman 转移到 Apifox

上一篇推文讲到用 Swagger 管理的 API 怎么迁移到 Apifox&#xff0c;有许多同学反馈说能不能介绍一下 Postman 的迁移以及迁移过程中需要注意的事项。那么今天&#xff0c;它来了&#xff01; 从 Postman 迁移到 Apifox 的方法有两种&#xff1a; 导出 Postman 集合 &#x…

诺派克ROPEX控制器维修RES-5008 RES-5006

德国希尔科诺派克ROPEX热封控制器维修型号包括&#xff1a;RES-401&#xff0c;RES-402&#xff0c;RES-403&#xff0c;RES-406&#xff0c;RES-407&#xff0c;RES-408&#xff0c;RES-409&#xff0c;RES-420&#xff0c;RES-440&#xff0c;RES-5008&#xff0c;RES-5006&a…

eNSP学习——PPP的认证

目录 主要命令 原理概述 实验目的 实验内容 实验拓扑 实验编址 实验步骤 1、基本配置 2、搭建OSPF网络 3、配置PPP的PAP认证 4、配置PPP的CHAP认证 主要命令 //设置本端的PPP协议对对端设备的认证方式为 PAP&#xff0c;认证采用的域名为huawei [R3]int s4/0/0 [R…

ToF原理记录

目录 1. ToF是什么&#xff1f;2. ToF深度测量原理2.1 脉冲调制法2.2 连续波调制法 1. ToF是什么&#xff1f; 飞行时间&#xff08;Time-of-Flight&#xff0c;ToF&#xff09;基本原理是通过连续发射光脉冲&#xff08;一般为不可见光&#xff09;到目标物体上&#xff0c;然…

BGP学习

BGP是一种矢量协议&#xff0c;使用TCP作为传输协议 ,目的端口号是179.是触发式更新&#xff0c;不是周期性更新 BGP的重点是策略路由的选路&#xff0c;能对路由进行路由汇总。运行BGP的路由器被称为BGP发言者&#xff0c;两个建立BGP会话的路由器互为对等体 IBGP和EBGP的区…

EasyExcel文件导出出现有文件但没有数据的问题

一开始由于JDK版本过高&#xff0c;我用的17&#xff0c;一直excel没有数据&#xff0c;表头也没有&#xff0c;后来摸索了好久&#xff0c;找了资料也没有&#xff0c;后来改了代码后报了一个错误&#xff08;com.alibaba.excel.exception.ExcelGenerateException: java.lang.…

Redis高性能原理:Redis为什么这么快?

目录 前言&#xff1a; 一、Redis知识系统观 二、Redis为什么这么快&#xff1f; 三、Redis 唯快不破的原理总结 四、Redis6.x的多线程 前言&#xff1a; Redis 为了高性能&#xff0c;从各方各面都进行了优化。学习一门技术&#xff0c;通常只接触了零散的技术点&#xff…

pioneer电源维修PM33213BP-10P-1-6PH-H

开关电源出现不启振的时候&#xff0c;我们通常需要查看开关频率是否正确、保护电路是否断路、电压反馈电路、电流反馈电路又没问题&#xff0c;开关管是否击穿等。 电源维修实践中&#xff0c;有许多开关电源采用UC38系列8脚PWM组件&#xff0c;大多数电源不能工作都是因为电…

搜索二叉树的概念及实现

搜索二叉树的概念 搜索二叉树规则&#xff08;左小右大&#xff09;&#xff1a; 非空左子树的键值小于其根节点的键值非空右子树的键值大于其根节点的键值左右子树均为搜索二叉树 如图&#xff1a; 在搜索时&#xff0c;若大于根&#xff0c;则去右子树寻找&#xff1b;若小…

基于单片机的多功能智能小车设计

第一章 绪论 1.1 课题背景和意义 随着计算机、微电子、信息技术的快速发展,智能化技术的发展速度越来越快,智能化与人们生活的联系也越来越紧密,智能化是未来社会发展的必然趋势。智能小车实际上就是一个可以自由移动的智能机器人,比较适合在人们无法工作的地方工作,也可…

基于WPF技术的换热站智能监控系统06--实现左侧故障统计

1、区域划分 2、ui实现 这里使用的是livechart的柱状图呈现的 3、运行效果 走过路过不要错过&#xff0c;点赞关注收藏又圈粉&#xff0c;共同致富&#xff0c;为财务自由作出贡献

【git使用二】gitee远程仓库创建与本地git命令用法

目录 gitee介绍 管理者注册gitee账号 管理者在gitee网站上创建远程仓库 每个开发者安装git与基本配置 1.git的下载和安装 2.配置SSH公钥 3.开发者信息配置 git命令用法 gitee介绍 Gitee&#xff08;又称码云&#xff09;是一个基于Git的代码托管服务&#xff0c;由开源…

python数据分析-量化分析

一、研究背景 随着经济的发展和金融市场的不断完善&#xff0c;股票投资成为了人们重要的投资方式之一。汽车行业作为国民经济的重要支柱产业&#xff0c;其上市公司的股票表现备受关注。Fama-French 三因子模型是一种广泛应用于股票市场的资产定价模型&#xff0c;它考虑了市场…

vue3中用setup写的数据,不能动态渲染(非响应式)解决办法

相比于2.0&#xff0c;vue3.0在新增了一个setup函数&#xff0c;我们在setup中可以写数据也可以写方法&#xff0c;就像我们以前最开始学习js一样&#xff0c;在js文件中写代码。 For instance <template><div class"person"><h2>姓名&#xff1…

中电联系列四:rocket手把手教你理解中电联协议!

分享《慧哥的充电桩开源SAAS系统&#xff0c;支持汽车充电桩、二轮自行车充电桩。》 电动汽车充换电服务信息交换 第4部分&#xff1a;数据传输与安全 Interactive of charging and battery swap service information for electric vehicle Part 4:Data transmission and secu…

平安科技智能运维案例

平安科技智能运维案例 在信息技术迅速发展的背景下&#xff0c;平安科技面临着运维规模庞大、内容复杂和交付要求高等挑战。通过探索智能运维&#xff0c;平安科技建立了集中配置管理、完善的运营管理体系和全生命周期运维平台&#xff0c;实施了全链路监控&#xff0c;显著提…

基于块生成最大剩余空间的三维装箱算法

问题简介 三维装箱问题&#xff08;3D Bin Packing Problem&#xff0c;3D BPP&#xff09;是一类组合优化问题。它涉及到将一定数量的三维物品放入一个或多个三维容器&#xff08;称为“箱子”&#xff09;中&#xff0c;同时遵循一定的约束&#xff0c;通常目标是最大化空间…

2024/06/13--代码随想录算法2/17| 62.不同路径、63. 不同路径 II、343. 整数拆分 (可跳过)、96.不同的二叉搜索树 (可跳过)

62.不同路径 力扣链接 动态规划5步曲 确定dp数组&#xff08;dp table&#xff09;以及下标的含义&#xff1a; dp[i][j] &#xff1a;表示从&#xff08;0 &#xff0c;0&#xff09;出发&#xff0c;到(i, j) 有dp[i][j]条不同的路径。确定递推公式&#xff0c;dp[i][j] d…