TensorRT推理手写数字分类(一)

news2024/9/21 23:35:17

系列文章目录

(一)使用pytorch搭建模型并训练


文章目录

  • 系列文章目录
  • 前言
  • 一、网络搭建
    • 1.LeNet网络结构
    • 2.pytorch代码
  • 二、网络训练
    • 1.pytorch代码
    • 2.结果展示
  • 三、保存和加载模型
    • 1.保存整个网络
    • 2.保存网络中的参数
  • 总结


前言

  为了学习一下使用TensorRT进行推理的全过程,便想着写一个TensorRT推理手写数字分类的小例程。这个例程包括使用pytorch进行LeNet网络的搭建、训练、保存pytorch格式的模型(pth)、将模型(pth)转为onnx通用格式、使用tensorRT解析onnx模型进行推理等。
  本节介绍使用pytorch进行手写数字分类网络的搭建,并进行训练。


一、网络搭建

1.LeNet网络结构

网络结构图如下所示:
在这里插入图片描述

结构说明:输入是单通道的12828的灰度图像,经过卷积、池化、卷积、池化后shape变为5044(50为通道数)。将其展平后维度为1*800,然后连接一个维度为500的线性层C5,C5层的输出经过ReLU函数激活后再连接一个维度为10的线性层C6,C6层的输出就为网络的输出。
一般来说,我们要求的是输入图片属于某一类的概率,所有我们要将C6的输出通过softmax函数进行转换。

2.pytorch代码

新建model.py文件,包含以下代码:

# 搭建网络模型
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchinfo import summary  #用来打印网络层的信息
# from torchkeras import summary  module 'torch.backends' has no attribute 'mps'

class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, kernel_size=5)
        self.conv2 = nn.Conv2d(20, 50, kernel_size=5)
        self.fc1 = nn.Linear(800, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.max_pool2d(self.conv1(x), kernel_size=2, stride=2)
        x = F.max_pool2d(self.conv2(x), kernel_size=2, stride=2)
        x = x.view(-1, 800)  # 将其展平
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  # 要使用NLLLoss()损失函数,所以输出要先经过log_softmax

if __name__ == "__main__":
    net = Net()
    summary(net, (1,1,28,28))

二、网络训练

1.pytorch代码

新建train.py,包含以下代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

from model import Net
import numpy as np
import os
import torch.utils.data
from random import randint
class MnistModel(object):
    def __init__(self):
        self.batch_size = 64  # 训练batch_size
        self.test_batch_size = 100  # 测试batch_size 
        self.learning_rate = 0.0025  #学习率
        self.sgd_momentum = 0.9
        self.log_interval = 100

        # 构造数据
        self.train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                "./tmp/mnist/data",
                train=True,
                download=True,
                transform=transforms.Compose(  # 预处理:对训练数据只进行标准化
                                            [transforms.ToTensor(),
                                            transforms.Normalize((0.1307),(0.3081,))])
                ),
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=4,
                timeout=600,
        )
        self.test_loader = torch.utils.data.DataLoader(
                datasets.MNIST(
                    "./tmp/mnist/data",
                    train=False,
                    transform=transforms.Compose(
                                            [transforms.ToTensor(),
                                            transforms.Normalize((0.1307),(0.3081,))])
                ),
                batch_size = self.test_batch_size,
                shuffle = True,
                num_workers=4,
                timeout=600,
        )
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.network = Net()
        self.network.to(self.device)  

    def learn(self, num_epochs=2):  # 训练两个epoch
        # 
        # Train the network for a single epoch
        def train(epoch):
            self.network.train()
            optimizer = optim.SGD(self.network.parameters(), lr=self.learning_rate, momentum=self.sgd_momentum)  # 使用SGD优化器
            for batch, (data, target) in enumerate(self.train_loader):
                data, target = Variable(data.to(self.device)), Variable(target.to(self.device))
                optimizer.zero_grad()
                output = self.network(data)
                loss = F.nll_loss(output, target).to(self.device)
                loss.backward()
                optimizer.step()
                if batch % self.log_interval == 0: #每100个batch打印一次信息
                    print(
                        "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                            epoch,
                            batch * len(data),
                            len(self.train_loader.dataset),
                            100.0 * batch / len(self.train_loader),
                            loss.data.item(),
                        )
                    )

        # Test the network
        def test(epoch):
            self.network.eval()
            test_loss = 0
            correct = 0
            for data, target in self.test_loader:
                with torch.no_grad():
                    data, target = Variable(data.to(self.device)), Variable(target.to(self.device))
                output = self.network(data)
                test_loss += F.nll_loss(output, target).data.item()
                pred = output.data.max(1)[1]  # 输出最大值的索引为预测的类别
                correct += pred.eq(target.data).cpu().sum()
            test_loss /= len(self.test_loader)v # 测试集每一个batch的平均损失
            print(
                "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
                    test_loss, correct, len(self.test_loader.dataset), 100.0 * correct / len(self.test_loader.dataset)
                )
            )

        for e in range(num_epochs):
            train(e + 1)
            test(e + 1)
train_model = MnistModel()
train_model.learn()

总的来说,训练代码中没有太值得让人注意的地方。如果非要说有,那我觉得以下三点可能是要注意的地方:

  1. 对输入的预处理,转为Tensor,然后作了标准化(均值为0,标准差为1),除此之外再也没有做其他的操作。
  2. 损失函数这里,我们决定使用交叉熵损失函数。因为我们在定义网络时,网络最后一层的输出经过了log_softmax,所以这里使用了nn.NLLLoss()损失函数即可。如果你网络最后一层的输出没有经过log_softmax,那么你可以使用nn.CrossEntropyLoss(),因为nn.NLLLoss()+log_softmax=nn.CrossEntropyLoss()。在代码中,我们使用的是F.nll_loss()函数,其实与nn.NLLLoss()没有区别(nn.NLLLoss()类其实也是调用F.nll_loss()函数)。
  3. 这里选择只训练两个epoch,是因为我在训练的时候,两个epoch后网络在验证集上就有比较好的效果,网络训练打印的信息在结果展示中贴出。

2.结果展示

在这里插入图片描述
可以看到,两个epoch后,模型的准确率为99%,所以我选择停止训练,然后保存模型。


三、保存和加载模型

在pytorch中保存模型有两种形式,一种是保存整个网络,一种是只保存网络中的参数。

1.保存整个网络

保存整个网络的方法如下:

# 保存整个网络
torch.save(net, path)
# 加载网络
model = torch.load(path)

2.保存网络中的参数

只保存网络中的参数的方法如下:

# 保存
torch.save(net.state_dict(), path)
# 加载
model = model.load_state_dict(torch.load(path))

在这个demo中,我们只需要在train.py后加上

torch.save(net.state_dict(), './model.pth')

就可以保存模型为model.pth文件。

总结

本节我们进行了模型的搭建、训练以及保存模型。下一节我们将介绍如何将我们保存的pth文件转为onnx通用格式,同时对我们转成的onnx文件进行检查和验证。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/69544.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

硬核!Github星标79.4K的阿里强推Java面试参考指南到底有多强?

谈到Java面试,相信大家第一时间脑子里想到的词肯定是金三银四,金九银十。好像大家的潜意识里做Java开发的都得在这个时候才能出去面试,跳槽成功率才高!但LZ不这么认为,LZ觉得我们做技术的一生中会遇到很多大大小小的面…

Java并发编程—如何写好代码?链式调用该怎么玩?

目录 一、案例说明 二、原生方式代码流程 三、链式调用代码流程 四、链式调用手搓的方式: 五、总结: 在上一篇博客https://blog.csdn.net/qq_52545155/article/details/128212148?spm1001.2014.3001.5501,博主在写商城统计商品价格的时…

mybatis中其他数据源也使用XML进行操作(SqlSessionFactory.openSession(Connection connection)方法)

文章目录1. 前言2. 先说结论3. 例子1. 准备数据2. 思考过程3. 结论1. 前言 当前在使用springbootmybatis的时候,通常会先在配置文件中配置好数据源,并在Mapper.xml文件编写好相关SQL,使用mybatis进行对数据库进行所谓的crud操作。 有时候会出…

nginx代理https妈妈级手册

目录 背景说明 相关地址 https证书生成 nginx安装及配置 结果展示​编辑 背景说明 为了保证传输加密、访问安全,我们采用nginx服务器将http服务代理为https。所需材料:openssl(用来生成证书)、http服务、nginx自身。 相关地址…

C/C++第三方库zeromq、log4cpp交叉编译、本地安装ubuntu180.04

一、zeromq的编译安装 1)ubuntu下命令 apt-get install libzmq3-dev不推荐这种方式,因为很可能安装的版本并不是最新的; 2)自己编译安装(推荐) 地址:https://github.com/zeromq/libzmq/relea…

设计模式--装饰者模式

文章目录前言一、未使用设计模式二、装饰者模式1.定义2.角色三、应用场景四、优缺点优缺前言 晓子(咖啡店员),来一杯美式,加点威士忌和砂糖。 抱歉啊,猫。收银系统还没有你说的组合,要不换一个&#x1f60…

React 的调度系统 Scheduler

大家好,我是前端西瓜哥。今天来学习 React 的调度系统 Scheduler。 React 版本为 18.2.0 React 使用了全新的 Fiber 架构,将原本需要一次性递归找出所有的改变,并一次性更新真实 DOM 的流程,改成通过时间分片,先分成一…

nnUnet测试

https://github.com/MIC-DKFZ/nnUNet nnUnet要在Windows上跑起来有点麻烦,主要是项目路径的问题,我目前测试了2分类遥感数据(其实只要是二分类都行,无所谓什么数据),我这里说难是因为我没有安装&#xff0…

【SQL】MVCC 多版本并发控制

MVCC多版本并发控制快照读与当前读隔离级别隐藏字段,undo log 版本链隐藏字段trx_id版本链read view举例说明read committed(读已提交)隔离级别下repeatable read(可重复读)隔离级别下innodb如何解决幻读总结并发问题的…

LaTex使用技巧9:argmin / argmax下标写法

记录两种写法 1.arg⁡max⁡θ\mathop{\arg\max}\limits_{\theta}θargmax​的写法 写法1: $\mathop{\arg\max}\limits_{\theta}$ 写法2: $\sideset{}{}{\arg\max}_{\theta}^{} $ 2.arg⁡min⁡θ\mathop{\arg\min}\limits_{\theta}θargmin​的写法 写法…

STL常用生成算法和集合算法(20221207)

STL的常用算法 概述&#xff1a; 算法主要是由头文件<algorithm> <functional> <numeric> 组成。 <algorithm>是所有STL头文件中最大的一个&#xff0c;涉及比较、交换、查找、遍历等等&#xff1b; <functional>定义了一些模板类&#xff0…

做一个公司网站大概要多少钱?

做一个公司网站大概要多少钱&#xff0c;很多公司在做网站之前可能已经简单了解过费用&#xff0c;但是费用差距都会比较大&#xff0c;为什么的呢&#xff0c;因为一般都是受到制作方式因素的影响。下面给大家说说不同的方式做一个公司网站大概要多少钱。 一、自己/团队做公司…

SQLyog —— 图形化工具使用

SQLyog下载链接&#xff1a; 点击跳转 在这一篇内容MySQL数据库 —— 常用语句当中讲到关于MySQL数据库命令的基本使用&#xff0c;这一篇是关于SQLyog数据库图形化工具的内容&#xff0c;先进行安装演示后在通过SQLyog进行操作数据库&#xff1a; SQLyog 安装 下载完成之后双击…

pageoffice在线打开word文件加盖电子印章

一、加盖印章的 js 方法 js方法 二、常见使用场景 1、常规盖章。弹出用户名、密码输入框&#xff0c;选择对应印章。 点击盖章按钮弹出用户名密码登录框&#xff0c;登录以后显示选择电子印章。 document.getElementById("PageOfficeCtrl1").ZoomSeal.AddSeal(…

Python模块pathlib操作文件和目录操作总结

前言 目前大家常用的对于文件和操作的操作使用 os.path 较多&#xff0c;比如 获取当前路径os.getcwd()&#xff0c;判断文件路径是否存在os.path.exists(folder) 等等。 在Python3.4开始&#xff0c;官方提供了 pathlib 面向对象的文件系统路径&#xff0c;核心的点在于 面向…

chatGPT代码写的有点好啊,程序员要失业了?

AI神器ChatGPT 火了。 能直接生成代码、会自动修复bug、在线问诊、模仿莎士比亚风格写作……各种话题都能hold住&#xff0c;它就是OpenAI刚刚推出的——ChatGPT。 有脑洞大开的网友甚至用它来设计游戏&#xff1a;先用ChatGPT生成游戏设定&#xff0c;再用Midjourney出图&…

element-plus elplus el-tree三种图标自定义 并且点击图标展开收起 点击文字获取数据

前言 公司需求,需要实现如下样式的树形列表 (基于vue3 element-plus) 当节点展开时,显示展开的文件夹图标,当节点收起时显示收起的文件夹,最后一级显示文件样式 废话没有了, 代码如下 <!-- 树形列表组件 --> <template><div class"tree-input" v-i…

Vue学习:回顾Object.defineProperty(给对象添加或者定义属性的)

<script>//定义对象let person{name:李四,sex:"男"}Object.defineProperty(person,age,{value:18});//参数:添加属性的对象 添加的属性名 配置项console.log(person)</script> 颜色不同&#xff1a;说明了age不可以枚举age属性不参与遍历 Object.keys(…

电脑屏幕录制怎么弄?电脑上怎么录制屏幕, 3个实用方法

对于日常办公的小伙伴来说&#xff0c;电脑、键盘、鼠标等办公设备都是不可分割的。事实上&#xff0c;不仅仅是在日常办公&#xff0c;在很多业余的活动中&#xff0c;也会使用到电脑设备。在使用电脑的时候&#xff0c;会经常有需要录制电脑屏幕的情况&#xff0c;比如记录会…

阿里云Linux热扩容云盘(growpart和resize2fs工具)

阿里云linux机器系统盘空间不够进行扩容 一、扩容物理盘 阿里云控制台在线扩容完成 二、安装growpart工具和resize2fs工具 [rootA ~]# yum install cloud-utils-growpart [rootA ~]# yum install xfsprogs 三、检查扩容磁盘属性 1、检查云盘大小 /dev/vda1显示容量为20G(在线…