pytorch卷积神经网络CNN 手写数字识别 MNIST数据集

news2024/11/24 2:09:09

模型结构和训练代码来自这里 https://blog.csdn.net/weixin_41477928/article/details/123385000

俺又加了离线测试的代码:

  • 第一次运行此代码,需有网络,会下载开源数据集MNIST
  • 训练的过程中会把10个epoch的模型均保存到./models下,可能需要你创建好models文件夹。训练过程中的输出如下:
    [1,  300] loss:0.257
    [1,  600] loss:0.078
    [1,  900] loss:0.060
    Accuracy on test set:98 %
    ...
    [10,  300] loss:0.002
    [10,  600] loss:0.003
    [10,  900] loss:0.004
    Accuracy on test set:99 %
    
  • 如果想加载保存的模型文件,然后推理一个手写照片看预测结果,可将最下面main函数中的两个函数,注释第一个,使用第二个
    • 比如测试如下图片:
      在这里插入图片描述

    • 输出结果:

       The predicted digit is 5
      
import torch
from torchvision import transforms  # 是一个常用的图片变换类
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F

import cv2

# 如果有GPU那么就使用GPU跑代码,否则就使用cpu。cuda:0表示第1块显卡
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # 将数据放在GPU上跑所需要的代码

# 定义数据批的大小,预处理
batch_size = 64
transform = transforms.Compose(
    [
        transforms.ToTensor(),  # 把数据转换成张量
        transforms.Normalize((0.1307,), (0.3081,))  # 0.1307是均值,0.3081是标准差
    ]
)

# 训练集、测试集 (首次运行会下载到root下)
train_dataset = datasets.MNIST(root='./data/',
                               train=True,
                               download=True,
                               transform=transform)
train_loader = DataLoader(train_dataset,
                          shuffle=True,
                          batch_size=batch_size)
test_dataset = datasets.MNIST(root='./data/',
                              train=False,
                              download=True,
                              transform=transform)
test_loader = DataLoader(test_dataset,
                         shuffle=True,
                         batch_size=batch_size)
 

# 定义一个神经网络
class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 25, kernel_size=3),
            torch.nn.BatchNorm2d(25),
            torch.nn.ReLU(inplace=True)
        )
 
        self.layer2 = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2)
        )
 
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(25, 50, kernel_size=3),
            torch.nn.BatchNorm2d(50),
            torch.nn.ReLU(inplace=True)
        )
 
        self.layer4 = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2)
        )
 
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(50 * 5 * 5, 1024),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(1024, 128),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(128, 10)
        )
 
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = x.view(x.size(0), -1)  # 在进入全连接层之前需要把数据拉直Flatten
        x = self.fc(x)
        return x


# 实例化,得到神经网络的结构
model = MyNet()
model.to(device)  # 将数据放在GPU上跑所需要的代码

def train(epochs):
    criterion = torch.nn.CrossEntropyLoss()  # 使用交叉熵损失
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.5)  # momentum表示冲量,冲出局部最小

    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        inputs, target = inputs.to(device), target.to(device)  # 将数据放在GPU上跑所需要的代码
        optimizer.zero_grad()
        # 前向+反馈+更新
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
 
        running_loss += loss.item()
        if batch_idx % 300 == 299:  # 不让他每一次小的迭代就输出,而是300次小迭代再输出一次
            print('[%d,%5d] loss:%.3f' % (epochs + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0

    torch.save(model.state_dict(), 'models/model_{}.pth'.format(epochs))
 
 
def test():
    correct = 0
    total = 0
    with torch.no_grad():  # 下面的代码就不会再计算梯度
        for data in test_loader:
            inputs, target = data
            inputs, target = inputs.to(device), target.to(device)  # 将数据放在GPU上跑所需要的代码
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, dim=1)  # _为每一行的最大值,predicted表示每一行最大值的下标
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print('Accuracy on test set:%d %%' % (100 * correct / total))
 

# 方式1:训练、测试
def train_test():
    for epoch in range(10):
        train(epoch)
        test()

# 方式2:加载保存到本地的模型权重,然后推理得到预测结果
def load_model_test():
    model.load_state_dict(torch.load("models/model_9.pth"))
    model.eval()

    # 使用 OpenCV 处理本地手写数字图片
    img = cv2.imread('data/5-1.png', cv2.IMREAD_GRAYSCALE)
    img = cv2.resize(img, (28, 28))
    img = img / 255.0

    img = torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0)
    img = img.to(device)

    with torch.no_grad():
        output = model(img)  # 推理并得到输出

        # 导出模型为onnx
        torch_out = torch.onnx.export(model, 
            img, 
            "./models/model_9.onnx",
            input_names=['i0'],
            export_params=True,
            opset_version=11,     # 转换为哪个版本的 onnx
            do_constant_folding=True,  # 是否执行常量折叠优化
            operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK  # 命名输入输出;支持超前op
        )

    pred = torch.argmax(output, dim=1)

    print(f'The predicted digit is {pred.item()}')

if __name__ == '__main__':
    train_test()       # 先训练,再测试,并保存训练好的模型

    # load_model_test()   # 加载保存后的模型权重,推理预测


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/606082.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2023年第三届陕西省大学生网络安全技能大赛--本科高校组 Reverse题解

文章目录 一. 我的upx -d怎么坏了1. 查看节区信息2. 动态调试脱壳3.输出迷宫图4.走迷宫 二. babypython1.字节码简单分析2. gpt分析3. 程序逻辑4.解题脚本 三. BadCoffee1. 相关文章2.解混淆3.解题脚本 四. Web&Assembly(暂时没复现出来,提供一些相关文章)总结 这次比赛做出…

冈萨雷斯DIP第5章知识点

图像增强:主要是一种 主观处理,而图像复原很大程度上是一种 客观处理。 5.1 图像退化/复原处理的一个模型 如图5.1 本章把图像退化建模为一个算子 H \mathcal{H} H 该算子 与一个加性噪声项 η ( x , y ) η(x,y) η(x,y) 共同对输入图像 f ( x , y…

Rust每日一练(Leetday0013) 解数独、外观数列、组合总和

目录 37. 解数独 Sudoku Solver 🌟🌟🌟 38. 外观数列 Count and Say 🌟🌟 39. 组合总和 Combination Sum 🌟🌟 🌟 每日一练刷题专栏 🌟 Rust每日一练 专栏 Gola…

常微分方程(ODE)求解方法总结

常微分(ODE)方程求解方法总结 1 常微分方程(ODE)介绍1.1 微分方程介绍和分类1.2 常微分方程得计算方法1.3 线性微分方程求解的推导过程 2 一阶常微分方程(ODE)求解方法2.1 欧拉方法2.1.1 欧拉方法的改进思路…

逻辑推理——弟弟的编程课

前言 这篇文章不写代码,不科普知识。而是推理! 这是我弟弟编程课上的一个同学;不是我的同学,我都成年了,这还是个小毛孩! 这是他们学的: 乍一看这没任何问题,还有人会说&#xff…

谷歌地图模型自动下载

本工具是收费软件,学生党勿扰,闹眼子党勿扰 本工具收费1000元 视频教程 1 概述 记得去年写过一篇关于谷歌地图模型提取的博客,得到了广泛好评。有很多同学提出,能不能自动下载谷歌地图模型,由于提出此需求的人太多了…

【起点到终点 走哪条路径使得(路径长度排序从大到小后) 第k+1条边最小】通信线路

专注 效率 记忆 预习 笔记 复习 做题 欢迎观看我的博客,如有问题交流,欢迎评论区留言,一定尽快回复!(大家可以去看我的专栏,是所有文章的目录)   文章字体风格: 红色文字表示&#…

单片机GD32F303RCT6 (Macos环境)开发 (三十四)—— 数字加速度计 (ADXL345)

数字加速度计 (ADXL345)- 计算xyz轴重力值 1、i2c总线读取 1、接线 上一节的软件模式i2c的方式,选择PB10(SCL) PB11(SDA)。 GD32 ADXL345PB10 --------------- SCLPB11 --------------- SDA3.3 --…

Eclipse 教程 完结

Eclipse 快捷键 关于快捷键 Eclipse 的很多操作都提供了快捷键功能,我们可以通过键盘就能很好的控制 Eclipse 各个功能: 使用快捷键关联菜单或菜单项使用快捷键关联对话窗口或视图或编辑器使用快捷键关联工具条上的功能按钮 Eclipse 快捷键列表可通过…

《crossfire》游戏分析

文章目录 一、 穿越火线简介和定位二、 游戏发行三、 用户基础四、 游戏玩法枪王排位团队竞技爆破模式歼灭模式突围模式幽灵模式生化模式个人竞技挑战模式跳跳乐地图工坊 五、 游戏竞技公平性cf竞技公平性 六、CF火热到现在的原因分析1.时代、空间背景2.用户基础3.丰富的游戏模…

【iOS】—— nil、Nil、NULL和NSNull学习

nil、Nil、NULL和NSNull 文章目录 nil、Nil、NULL和NSNullnilNSNullNilNULL总结: 我们先来看看这几个苹果官方文档的解释: nil:Defines the id of a null instance.(定义空实例的id)Nil:Defines the id of…

给编程初学者的一封信

提醒:以下内容仅做参考,具体请自行设计。 随着信息技术的快速发展,编程已经成为一个越来越重要的技能。那么,我们该如何入门编程呢?欢迎大家积极讨论 一、自学编程需要注意什么? 要有足够的时间、精力等…

大数据治理入门系列:数据目录

在元数据管理一文中,我们曾将数据比喻为一本本的书,将书的作者、出版时间等信息比喻为元数据。试想一下,假如你是一名新任的图书管理员,如何快速掌握图书馆的馆藏情况呢?假如你是一名读者,如何快速找到你需…

Redis GEO功能详细介绍与实战

一、概述 Redis的Geo功能主要用于存储地理位置信息,并对其进行操作。该功能在Redis 3.2版本新增。Redis Geo操作方法包括: geoadd:添加地理位置的坐标;geopos:获取地理位置的坐标;geodist:计算…

第五届湖北省大学生程序设计竞赛(HBCPC 2023)vp赛后补题

Problem - B - Codeforces 思路: 数位dp,如果我们暴力的计算的状态的话,显然就是记录每个数字出现几次。但是显然这样难以发挥数位dp的记忆化功效,因为只有出现次数相同,你是什么数字,实际是无所谓的。所…

I2C学习笔记——I2C协议学习

1、I2C简介:一种简单、双线双向的同步串行总线,利用串行时钟线(SCL)和串行数据线(SDA)在连接总线的两个器件之间进行信息传递; 数据传输是通过对SCL和SDA线高低电平时序的控制,来产生I2C总线协议所需要的信号。在总线空闲状态时&a…

【Linux C】基于树莓派/香橙派的蓝牙服务端——支持多蓝牙设备接入

一、需求 在树莓派/香橙派上利用开发板自带的蓝牙作为一个蓝牙服务端(主机),允许外来设备(从机)通过蓝牙接入进行通信,通信格式为透传方式;采用的编程语言为Linux C 二、环境准备 bluez安装 …

三波混频下的相位失配原理

原理推导 在四波混频情况下,实现零相位失配是一件很困难的事情。因为在四波混频中,相位调制和增益都依赖于相同的参数,即克尔非线性 γ \gamma γ。这个问题可以用嵌入在传输线上的辅助共振元件的复杂色散工程来部分解决。 但是在三波混频中…

ceph集群监控

文章目录 Ceph Dashboard启用dashboard插件dashboard启用ssl Promethues监控ceph启用prometheus模块配置prometheus采集数据grafana数据展示 Ceph Dashboard ceph-dashboard官方介绍:https://docs.ceph.com/en/latest/mgr/dashboard/ Ceph Dashboard是一个内置的c…

数据库系统概论---选择题刷题实训

(一)选择题 1.下列选项中,不属于关系模型三要素的是( C ) A.数据结构 B.数据操纵 C.数据安全 D.数据完整性规则 2.保证数据库…