简单的CNN实现——MNIST手写数字识别

news2025/1/12 2:45:13

0.概述

此文章不涉及复杂的理论知识,仅仅只是利用PyTorch组建一个简单的CNN去实现MNIST的手写数字识别,用好的效果去激发学习CNN的好奇心,并且以后以此为基础,去进行一些改造。(前提是把基础代码看明白)
本文CNN网络结构:
在这里插入图片描述
以下为最基本的代码(不需要GPU):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Super parameter
batch_size = 64
lr = 0.01
momentum = 0.5
epoch = 10
# Prepare dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Design model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.MaxPool2d(2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(10, 20, kernel_size=5),
            nn.MaxPool2d(2),
        )
        self.fc = nn.Sequential(
            nn.Linear(320, 10)
        )
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # flatten (batch, 20,4,4) ==> (batch,320)
        x = self.fc(x)
        return x
model = Net()
# Construct loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
# Train and Test
def train():
    for (images, target) in train_loader:
        outputs = model(images)
        loss = criterion(outputs, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for (images, target) in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print('[%d / %d]: %.1f %% ' % (i + 1, epoch, 100 * correct / total))
# Start train and Test
print('Accuracy on test set:')
for i in range(epoch):
    train()
    test()

输出结果:

Accuracy on test set:
[1 / 10]: 96.7 % 
[2 / 10]: 97.7 % 
[3 / 10]: 98.1 % 
[4 / 10]: 98.4 % 
[5 / 10]: 98.2 % 
[6 / 10]: 98.8 % 
[7 / 10]: 98.6 % 
[8 / 10]: 98.7 % 
[9 / 10]: 98.7 % 
[10 / 10]: 98.9 % 

1.MNIST数据集介绍

1.数据量

MNIST数据集共有70000张图像,其中训练集60000张,测试集10000张。所有图像都是28×28的单通道灰度图像,每张图像包含一个手写数字。

2.标注类别

共10个类别,每个类别代表0~9之间的一个数字,每张图像只有一个类别。

3.可视化

from matplotlib import pyplot as plt
from torchvision import datasets, transforms
# Prepare dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
# View picture
fig = plt.figure()
for i in range(12):
    plt.subplot(3, 4, i + 1)
    plt.tight_layout()
    plt.imshow(train_dataset.data[i], cmap='gray', interpolation='none')
    plt.title("Label: {}".format(train_dataset.targets[i]))
    plt.xticks([])
    plt.yticks([])
plt.show()

在这里插入图片描述

4.张量化

二进制压缩文件–>train_dataset->train_loader

# Prepare dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform,download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
1.train_dataset中的数据组织
from torchvision import datasets, transforms
# Prepare dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
# Explore train_dataset
x = train_dataset
print(type(x))  # <class 'torchvision.datasets.mnist.MNIST'>
print(len(x))  # 60000
print(x)

print(type(x[0]))  # <class 'tuple'>
print(x[0])
print(len(x[0]))  # 2

print(type(x[0][0]))  # <class 'torch.Tensor'>
print(type(x[0][1]))  # <class 'int'>

print(x[0][0].shape)  # torch.Size([1, 28, 28])  图片
print(x[0][1])  # 5  类别标签

结论:train_dataset是一个含有60000个数据点的Dataset类,每个数据点(如x[0])是一个长度为2的元组,索引0表示图片张量,索引1表示图片的类别标签(0~9)

2.train_loader中的数据组织

同理可得结论: train_loader是一个生成器,我们设置了batch_size是4,所以dataloader会把60000个样本,4个样本一组,按照组的顺序一组一组传给我们,总共938组,每组4张图片和对应标签。每一组的类型是长度为2的list列表,索引0表示一个412828的张量,即把4个图片张量拼在一起,索引1表示一个41的张量,即把4个标签拼在一起。

2.模型设计

图片张量维度的两个变化点:
1.通道数C:卷积层会改变它,1->10->20
2.尺寸W*H:卷积层会小幅改变它,池化层会大幅改变它,28->24->12->8->4

构造模型的两个关注点:
1.卷积层关注前后的通道数变化
2.全连接层关注连接前一张图片的全通道像素数320和连接后的分类标签数10
在这里插入图片描述
在连接到全连接层之前,将一张图片的所有通道全部展开和连接构成一个一维数组,即2044展开为320个元素组成的数组,经过全连接层将其按权重加和为10个类别标签。

按照模型图代码设计如下:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.MaxPool2d(2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(10, 20, kernel_size=5),
            nn.MaxPool2d(2),
        )
        self.fc = nn.Sequential(
            nn.Linear(320, 10)
        )
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # flatten (batch, 20,4,4) ==> (batch,320)
        x = self.fc(x)
        return x
model = Net()

1.torch.nn.Module

Module类是所有神经网络模块的基类,Module可以以树形结构包含其他的Module。Module类中包含网络各层的定义及forward方法,下面介绍我们如何定义自已的网络:

  1. 需要继承nn.Module类,并实现forward方法;
  2. 一般把网络中具有可学习参数的层放在构造函数__init__()中;
  3. 不具有可学习参数的层(如ReLU)可在forward中使用nn.functional来代替;
  4. 只要在nn.Module的子类中定义了forward函数,利用Autograd自动实现反向求导。

2.super(Net, self).init()

子类Net类继承父类nn.Module,super(Net, self).init()就是对继承自父类nn.Module的属性进行初始化。并且是用nn.Module的初始化方法来初始化继承的属性。也就是:用父类的方法初始化子类的属性。
为什么要用父类的方法去初始化属性呢?原因很简单:因为父类的方法已经写好了,我们只需要调用就可以了。不需要自己写一堆代码去初始化各种权重和参数和处理一堆forward和backward的逻辑。
python中__init()的作用:在python中创建类后,通常会创建一个 init ()方法,这个方法会在创建类的实例的时候自动执行。

3.torch.nn.Sequential()

torch.nn.Sequential 类是 torch.nn 中的一种序列容器,最主要的是,参数会按照我们定义好的序列自动传递下去。
不使用Sequential :
在这里插入图片描述
使用Sequential :
在这里插入图片描述
输出model结果如下:

Net(
  (conv1): Sequential(
    (0): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=320, out_features=10, bias=True)
  )
)

4.torch.nn.Conv2d(1, 10, kernel_size=5)

函数原型:

torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)

参数说明:
在这里插入图片描述

5.torch.nn.Linear(320, 10)

函数原型:torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
在这里插入图片描述

6.x = x.view(x.size(0), -1)

作用是将前面多维度的tensor展平成一维。一般出现在model类的forward函数中,具体位置一般都是在调用分类器之前。分类器是一个简单的nn.Linear()结构,输入输出都是维度为1的值。
x.size()为(batch_size,channels,H,W),则x.size(0)=batch_size。
view()函数的功能和reshape类似,用来转换size大小。x = x.view(batchsize, -1)中batchsize指转换后有几行,而-1指在不告诉函数有多少列的情况下,根据原tensor数据和batchsize自动分配列数。

3.训练与测试

1.训练

def train():
    for (images, target) in train_loader:
        outputs = model(images)
        loss = criterion(outputs, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

1.获取loss:输入图像和标签,通过infer计算得到预测值,计算损失函数。
2.optimizer.zero_grad() 清空过往梯度。
3.loss.backward() 反向传播,计算当前梯度。
4.optimizer.step() 根据梯度更新网络参数。

2.测试

def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for (images,target) in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print('[%d / %d]: %.1f %% ' % (i+ 1, epoch, 100 * correct / total))

如何理解_, predicted = torch.max(outputs.data, dim=1)
torch.max()这个函数返回的是两个值,第一个值是具体的value(我们用下划线_表示),第二个值是value所在的index(也就是predicted)。
在图像分类任务中,值所对应的index就对应着相应的类别class,当我们只关心网络预测的类别是什么,而不关心该类别的预测概率是多少时,就选择使用下划线_。
dim=1表示输出所在行的最大值,若改写成dim=0则输出所在列的最大值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/61451.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java计算机毕业设计ssm社团管理系统0gl2e(附源码、数据库)

java计算机毕业设计ssm社团管理系统0gl2e&#xff08;附源码、数据库&#xff09; 项目运行 环境配置&#xff1a; Jdk1.8 Tomcat8.5 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。…

Android Room的使用详解

Android Room的使用详解 一&#xff1a;Room的基本介绍 Room 是 Android 架构组件的一部分&#xff0c;Room 持久性库在 SQLite上提供了一个抽象层&#xff0c;以便在充分利用 SQLite 的强大功能的同时&#xff0c;能够流畅地访问数据库。具体来说&#xff0c;Room 具有以下优…

【Linux Kernel 6.1 代码剖析】- 进程管理概论

目录 进程与线程的概念&#xff08;内核线程和用户线程&#xff09; 进程的3种基本状态 引入挂起后的7种基本状态 Linux 内核6.1 - 进程的8种详细状态 进程控制块 PCB SMP 架构 进程与线程的概念&#xff08;内核线程和用户线程&#xff09; 进程是正在运行的程序实体&a…

基于java+ssm+vue+mysql的旅游管理系统

项目介绍 随着现在网络的快速发展&#xff0c;网上管理系统也逐渐快速发展起来&#xff0c;网上管理模式很快融入到了许多企业的之中&#xff0c;随之就产生了“旅游信息管理系统”&#xff0c;这样就让旅游信息管理系统更加方便简单。 对于本旅游信息管理系统的设计来说&…

QDir(目录)

QDir 类提供对目录结构及其内容的访问&#xff0c;QDir 用于操作路径名、访问有关路径和文件的信息以及操作底层文件系统&#xff0c;它也可以用来访问Qt的资源系统。 Qt使用“/”作为通用目录分隔符&#xff0c;就像URL中的“/”用作路径分隔符一样。如果您始终使用“/”作为…

2022年大一学生实训作业【基于HTML+CSS制作中华传统文化传统美德网站 (6页面)】

&#x1f389;精彩专栏推荐 &#x1f4ad;文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 &#x1f482; 作者主页: 【主页——&#x1f680;获取更多优质源码】 &#x1f393; web前端期末大作业&#xff1a; 【&#x1f4da;毕设项目精品实战案例 (10…

【YOLOv7/YOLOv5系列算法改进NO.47】改进激活函数为GELU

文章目录前言一、解决问题二、基本原理三、​添加方法四、总结前言 作为当前先进的深度学习目标检测算法YOLOv7&#xff0c;已经集合了大量的trick&#xff0c;但是还是有提高和改进的空间&#xff0c;针对具体应用场景下的检测难点&#xff0c;可以不同的改进方法。此后的系列…

SparkSQL - 介绍及使用 Scala、Java、Python 三种语言演示

一、SparkSQL 前面的文章中使用 RDD 进行数据的处理&#xff0c;优点是非常的灵活&#xff0c;但需要了解各个算子的场景&#xff0c;需要有一定的学习成本&#xff0c;而 SQL 语言是一个大家十分熟悉的语言&#xff0c;如果可以通过编写 SQL 而操作RDD&#xff0c;学习的成本…

ARM汇编之程序状态寄存器传输指令

ARM汇编之程序状态寄存器传输指令前言 首先&#xff0c;请问大家几个小小问题&#xff0c;你清楚&#xff1a; CLZ指令的常见使用场景&#xff1b;状态寄存器访问指令有哪些&#xff1f; 今天&#xff0c;我们来一起探索并回答这些问题。为了便于大家理解&#xff0c;以下是…

[附源码]Python计算机毕业设计SSM金牛社区疫情防控系统(程序+LW)

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

[附源码]JAVA毕业设计老年人健康饮食管理系统(系统+LW)

[附源码]JAVA毕业设计老年人健康饮食管理系统&#xff08;系统LW&#xff09; 目运行 环境项配置&#xff1a; Jdk1.8 Tomcat8.5 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项…

LeetCode 0542. 01 矩阵

【LetMeFly】542.01 矩阵 力扣题目链接&#xff1a;https://leetcode.cn/problems/01-matrix/ 给定一个由 0 和 1 组成的矩阵 mat &#xff0c;请输出一个大小相同的矩阵&#xff0c;其中每一个格子是 mat 中对应位置元素到最近的 0 的距离。 两个相邻元素间的距离为 1 。 示…

MySQL数据库之存储引擎

MySQL数据库之存储引擎数据存储引擎介绍MyISAM数据引擎概述MyISAM的特点介绍及数据引擎对应文件MyISAM的存储格式分类MyISAM适用的生产场景举例InnoDB数据引擎概述InnoDB特点介绍及数据引擎对应文件InnoDB适用生产场景分析企业选择存储引擎的依据如何配置存储引擎查看系统支持的…

c<8>指针

目录 2&#xff0c;指针的赋值 2.1C语言允许指针赋值为0&#xff08;初始化&#xff09; 2.2指针赋值例 2.3输出指针的值 3&#xff0c;用指针引用数组 3.1利用指针输入数组 3.2优先级问题 4.多维数组 5.字符串 5.1通过指针引用字符串 4.函数中对指针的应用 4.1将指针变…

[附源码]计算机毕业设计车源后台管理系统

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

Acer W700废物利用- 第一章 - 安装Linux系统Debian 11.5

前言 收拾房子时在犄角旮旯发现了一台N年前的Windows平板&#xff0c;也就是今天的主角&#xff1a;Acer W700 &#xff0c;机器配置是&#xff1a;CPU&#xff1a;I5-3337U&#xff1b;内存&#xff1a;4G&#xff1b;硬盘&#xff1a;128G固态&#xff1b; 插上充电线&…

YOLOv5图像分割--SegmentationModel类代码详解

目录 ​编辑 SegmentationModel类 DetectionModel类 推理阶段 DetectionModel--forward() BaseModel--forward() Segment类 Detect--forward SegmentationModel类 定义model将会调用models/yolo.py中的类SegmentationModel。该类是继承父类--DetectionModel类。 cl…

数学基础从高一开始1、集合的概念

数学基础从高一开始1、集合的概念 目录 数学基础从高一开始1、集合的概念 一、课程引入 解析&#xff1a;方程​编辑2是否有解&#xff1f; 解析&#xff1a;所有到定点的距离等于定长的点组成何种图形&#xff1f; 结论&#xff1a; 二、课程讲解 问题1&#xff1a; 集…

1548_AURIX_TC275_锁步比较逻辑LCL

全部学习汇总&#xff1a; GreyZhang/g_TC275: happy hacking for TC275! (github.com) 这可能是这段时间看过的最简单的一个章节了&#xff0c;所有的章节内容都可以放进这一份笔记也不显得多。 1. 首先明确LCL的意思&#xff0c;其实是锁步核比较器逻辑的意思&#xff0c;还不…

知识点1--认识Docker

IT界2014年之前&#xff0c;对于服务器虚拟化的使用&#xff0c;有过一个流派&#xff0c;基于Windows server系统VMware组成服务器集群&#xff0c;但是后期由于这样的使用方式维护成本相当高&#xff0c;比如服务器的序列、服务器台账以及服务器与服务器之间的切换等等&#…