实战:基于卷积的MNIST手写体分类

news2025/1/19 7:18:10

前面实现了基于多层感知机的MNIST手写体识别,本章将实现以卷积神经网络完成的MNIST手写体识别。

1.  数据的准备

在本例中,依旧使用MNIST数据集,对这个数据集的数据和标签介绍,前面的章节已详细说明过了,相对于前面章节直接对数据进行“折叠”处理,这里需要显式地标注出数据的通道,代码如下:

import numpy as np

import einops.layers.torch as elt

#载入数据

x_train = np.load("../dataset/mnist/x_train.npy")

y_train_label = np.load("../dataset/mnist/y_train_label.npy")

x_train = np.expand_dims(x_train,axis=1)   #在指定维度上进行扩充

print(x_train.shape)

这里是对数据的修正,np.expand_dims的作用是在指定维度上进行扩充,这里在第二维(也就是PyTorch的通道维度)进行扩充,结果如下:

(60000, 1, 28, 28)

2.  模型的设计

下面使用PyTorch 2.0框架对模型进行设计,在本例中将使用卷积层对数据进行处理,完整的模型如下:

import torch
import torch.nn as nn
import numpy as np
import einops.layers.torch as elt
class MnistNetword(nn.Module):
    def __init__(self):
        super(MnistNetword, self).__init__()
        #前置的特征提取模块
        self.convs_stack = nn.Sequential(
            nn.Conv2d(1,12,kernel_size=7),  	#第一个卷积层
            nn.ReLU(),
            nn.Conv2d(12,24,kernel_size=5), 	#第二个卷积层
            nn.ReLU(),
            nn.Conv2d(24,6,kernel_size=3)  	#第三个卷积层
        )
        #最终分类器层
        self.logits_layer = nn.Linear(in_features=1536,out_features=10)
    def forward(self,inputs):
        image = inputs
        x = self.convs_stack(image)        
        #elt.Rearrange的作用是对输入数据的维度进行调整,读者可以使用torch.nn.Flatten函数完成此工作
        x = elt.Rearrange("b c h w -> b (c h w)")(x)
        logits = self.logits_layer(x)
        return logits
model = MnistNetword()
torch.save(model,"model.pth")

这里首先设定了3个卷积层作为前置的特征提取层,最后一个全连接层作为分类器层,需要注意的是,对于分类器的全连接层,输入维度需要手动计算,当然读者可以一步一步尝试打印特征提取层的结果,依次将结果作为下一层的输入维度。最后对模型进行保存。

3.  基于卷积的MNIST分类模型

下面进入本章的最后示例部分,也就是MNIST手写体的分类。完整的训练代码如下:

import torch
import torch.nn as nn
import numpy as np
import einops.layers.torch as elt
#载入数据
x_train = np.load("../dataset/mnist/x_train.npy")
y_train_label = np.load("../dataset/mnist/y_train_label.npy")
x_train = np.expand_dims(x_train,axis=1)
print(x_train.shape)
class MnistNetword(nn.Module):
    def __init__(self):
        super(MnistNetword, self).__init__()
        self.convs_stack = nn.Sequential(
            nn.Conv2d(1,12,kernel_size=7),
            nn.ReLU(),
            nn.Conv2d(12,24,kernel_size=5),
            nn.ReLU(),
            nn.Conv2d(24,6,kernel_size=3)
        )
        self.logits_layer = nn.Linear(in_features=1536,out_features=10)
    def forward(self,inputs):
        image = inputs
        x = self.convs_stack(image)
        x = elt.Rearrange("b c h w -> b (c h w)")(x)
        logits = self.logits_layer(x)
        return logits
device = "cuda" if torch.cuda.is_available() else "cpu"
#注意记得将model发送到GPU计算
model = MnistNetword().to(device)
model = torch.compile(model)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
batch_size = 128
for epoch in range(42):
    train_num = len(x_train)//128
    train_loss = 0.
    for i in range(train_num):
        start = i * batch_size
        end = (i + 1) * batch_size
        x_batch = torch.tensor(x_train[start:end]).to(device)
        y_batch = torch.tensor(y_train_label[start:end]).to(device)
        pred = model(x_batch)
        loss = loss_fn(pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()  # 记录每个批次的损失值
    # 计算并打印损失值
    train_loss /= train_num
    accuracy = (pred.argmax(1) == y_batch).type(torch.float32).sum().item() / batch_size
    print("epoch:",epoch,"train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

在这里,我们使用了本章新定义的卷积神经网络模块作为局部特征抽取,而对于其他的损失函数以及优化函数,只使用了与前期一样的模式进行模型训练。最终结果如下所示,请读者自行验证。

(60000, 1, 28, 28)
epoch: 0 train_loss: 2.3 accuracy: 0.11
epoch: 1 train_loss: 2.3 accuracy: 0.13
epoch: 2 train_loss: 2.3 accuracy: 0.2
epoch: 3 train_loss: 2.3 accuracy: 0.18
…
epoch: 58 train_loss: 0.5 accuracy: 0.98
epoch: 59 train_loss: 0.49 accuracy: 0.98
epoch: 60 train_loss: 0.49 accuracy: 0.98
epoch: 61 train_loss: 0.48 accuracy: 0.98
epoch: 62 train_loss: 0.48 accuracy: 0.98

Process finished with exit code 0

本文节选自《PyTorch 2.0深度学习从零开始学》,本书实战案例丰富,可带领读者快速掌握深度学习算法及其常见案例。

   

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/951561.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

对比Flink、Storm、Spark Streaming 的反压机制

分析&回答 Flink 反压机制 Flink 如何处理反压? Storm 反压机制 Storm反压机制 Storm 在每一个 Bolt 都会有一个监测反压的线程(Backpressure Thread),这个线程一但检测到 Bolt 里的接收队列(recv queue)出现了…

C++笔记之临时变量与临时对象与匿名对象

C笔记之临时变量与临时对象与匿名对象 code review! 文章目录 C笔记之临时变量与临时对象与匿名对象1.C中的临时变量指的是什么?2.C中的临时对象指的是什么?3.C中临时对象的作用是什么?什么时候要用到临时对象?4.给我列举具体的例子说明临…

idm下载视频

idm下载视频 安装后 地址为: 链接:下载地址 提取码:fgzv 安装后 设置浏览器插件 完成 参考文章

自动驾驶和辅助驾驶系统的概念性架构(二)

摘要: 本篇为第二部分主要介绍底层计算单元、示例工作负载 前言 本文档参考自动驾驶计算联盟(Autonomous Vehicle Computing Consortium)关于自动驾驶和辅助驾驶计算系统的概念系统架构。该架构旨在与SAE L1-L5级别的自动驾驶保持一致。本文主要介绍包括功能模块图…

自从学了C++之后,小雅兰就有对象了!!!(类与对象)(下)——“C++”

各位CSDN的uu们好呀,好久没有更新啦,今天继续类和对象的内容,下面,让我们进入西嘎嘎类和对象的世界吧!!! 1. 再谈构造函数 2. Static成员 3. 友元 4. 内部类 5.匿名对象 6.拷贝对象时的一些…

设计模式的使用——模板方法模式+动态代理模式

一、需求介绍 现有自己写的的一套审批流程逻辑,由于代码重构,需要把以前的很多业务加上审批的功能,再执行完审批与原有业务之后,生成一个任务,然后再统一处理一个任务(本来是通过数据库作业去处理的&#x…

pandas数据分析之数据绘图

一图胜千言,将信息可视化(绘图)是数据分析中最重要的工作之一。它除了让人们对数据更加直观以外,还可以帮助我们找出异常值、必要的数据转换、得出有关模型的想法等等。pandas 在数据分析、数据可视化方面有着较为广泛的应用。本文…

C# 试图加载格式不正确的程序。 (异常来自 HRESULT:0x8007000B)

C# 在调用Cdll时,可能会出现 :试图加载格式不正确的程序。 (异常来自 HRESULT:0x8007000B)这个错误。 一般情况下是C#目标平台跟Cdll不兼容,64位跟32位兼容性问题, a.客户端调用Cdll报的错则, 1)允许的话把C#客户端…

SQL-basics

SQL 一些常用的查询语句用法 SQL 中的聚合函数 SQL 中的子查询 SQL 使用实例 SELECT F_NAME , L_NAME FROM EMPLOYEES WHERE ADDRESS LIKE ‘%Elgin,IL%’; SELECT F_NAME , L_NAME FROM EMPLOYEES WHERE B_DATE LIKE ‘197%’; SELECT * FROM EMPLOYEES WHERE (SALARY BET…

总结974

今日共计学习12h,日计划完成90%.今晚又把总结时间占用了,明天预留0.5h进行月总结吧,重新制定学习时间表,之前的已经用不了。 跟一个学府的老师聊了聊天,感觉聊完之后,本以为会心情舒畅,没想到反…

ELK安装、部署、调试(五)filebeat的安装与配置

1.介绍 logstash 也可以收集日志,但是数据量大时太消耗系统新能。而filebeat是轻量级的,占用系统资源极少。 Filebeat 由两个主要组件组成:harvester 和 prospector。 采集器 harvester 的主要职责是读取单个文件的内容。读取每个文件&…

同学,您有一张校招绿通卡请查收!

“金三银四”过去,马上“金九银十”了,有实习还没着落的,有在实习但留下成谜的,也有不顾其他只忙秋招的;还有依旧撒网投简历的;以及23 届还在找工作的。 小伙伴们,今年真的好难。 &#xff08…

SQLI-labs-第三关

目录 知识点:单括号)字符get注入 1、判断注入点: 2、判断当前表的字段数 3、判断回显位置 4、爆库名 5、爆表名 6、爆字段名,以users表为例 7、爆值 知识点:单括号)字符get注入 思路: 1、判断注入点&#xff1…

jmeter+nmon+crontab简单的执行接口定时压测

一、概述 临时接到任务要对系统的接口进行压测,上面的要求就是:压测,并发2000 在不熟悉系统的情况下,按目前的需求,需要做的步骤: 需要有接口脚本需要能监控系统性能需要能定时执行脚本 二、观察 >针…

springsecurity+oauth 分布式认证授权笔记总结12

一 springsecurity实现权限认证的笔记 1.1 springsecurity的作用 springsecurity两大核心功能是认证和授权,通过usernamepasswordAuthenticationFilter进行认证;通过filtersecurityintercepter进行授权。springsecurity其实多个filter过滤链进行过滤。…

STM32+RTThread配置以太网无法ping通,无法获取动态ip的问题

记录一个非常蠢的问题,今天在移植rtthread的以太网驱动的时候出现无法获取动态ip的问题,问题如下: 设置为动态ip时不管是连接路由器还是电脑主机都无法ping通,也无法获取dns地址。 设置为静态ip时无法ping通主机。 使用wireshark…

Xilinx-7系列之可配置逻辑块CLB

目录 一、概览 二、CLB结构 三、Slice内部结构 3.1 SliceM结构 3.2 SliceL结构 3.3 查找表LUT 3.4 多路复用器 3.5 存储单元 3.6 进位逻辑 四、应用 4.1 分布式RAM 4.2 ROM(只读存储器) 4.3 Shift Registers( 移位寄存器) 4.4 存储资源容量…

【【萌新的STM32学习23----数据通信的基本类型】】

萌新的STM32学习23----数据通信的基本类型 数据通信的基本概念 数据通信方式可以分为串行通信,并行通信 串行通信: 数据逐位按顺序依次传输 并行: 数据各位通过多条线同时传输 串行通信: 传输效率低,抗干扰能力强&am…

Milk-V Duo开发板实战——基于MobileNetV2的的图像分类

Milk-V Duo开发板实战——基于MobileNetV2的的图像分类 本教程介绍使用TPU-MLIR工具链对MobileNet-Caffe模型进行转换,生成MLIR以及MLIR量化成INT8模型,并在Milk-V Duo开发板上进行部署测试,完成图像分类任务,涉及以下步骤&#…

ELK安装、部署、调试(三)zookeeper安装,配置

1.准备 java安装,系统自带即可 2.下载zookeeper zookeeper.apache.org上可以下载 tar -zxvf apache-zookeeper-3.7.1-bin.tar.gz -C /usr/local mv apache-zookeeper-3.7.1-bin zookeeper 3.配置zookeeper mv zoo_sample.cfg zoo.cfg /usr/local/zookeeper/con…