利用卷积神经网络进行手写数字的识别

news2025/3/1 15:58:17

数据集介绍

MNIST(Modified National Institute of Standards and Technology)数据集是一个广泛使用的手写数字识别数据集,常用于机器学习和计算机视觉领域中的分类任务。它包含了从0到9的手写数字样本,常用于训练和测试各种图像分类算法。

数据集概况

MNIST数据集由60,000个训练样本和10,000个测试样本组成,每个样本是一张28×28像素的灰度图像,表示一个手写数字。每个图像是一个二维矩阵,像素值范围从0(黑色)到255(白色),灰度值表示不同的颜色深度。数据集中的标签是这些图像对应的数字(0-9)。

数据集格式

  • 训练集:60,000个图像,每个图像有一个对应的标签(0到9之间的数字)。
  • 测试集:10,000个图像,也有对应的标签。

使用场景

  1. 图像分类任务:由于数据集较小且标准化,MNIST是机器学习算法(尤其是深度学习模型)测试和比较性能的一个标准数据集。
  2. 模型性能评估:MNIST被广泛用于评估各种机器学习模型的效果,尤其是在图像处理领域。
  3. 教学:由于其简单性,MNIST常作为入门学习机器学习和神经网络的教学材料。

特点

  • 图像尺寸固定:28×28像素,适合用作标准输入。
  • 图像内容简单:大多数手写数字都是规范且易于分辨的。
  • 数据集较小,适合于快速实验和初步的模型验证。

数据集获取

MNIST数据集可以通过多个平台获取,例如:

  • 通过TensorFlow、PyTorch等框架的内建API加载。
  • 从MNIST官网下载。

数据预处理及参数选择

数据处理

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# softmax归一化指数函数(https://blog.csdn.net/lz_peter/article/details/84574716),其中0.1307是mean均值和0.3081是std标准差

train_dataset = datasets.MNIST(root='./data/mnist', train=True, transform=transform)  # 本地没有就加上download=True
test_dataset = datasets.MNIST(root='./data/mnist', train=False, transform=transform)  # train=True训练集,=False测试集
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

参数的选择

batch_size = 64                #每个批次大小中有64个样本
learning_rate = 0.01           #学习率
momentum = 0.5                 #梯度下降冲量
epochs = 10                    #训练轮数
  • batch_size = 64:每次训练时使用64个样本来计算梯度并更新权重。
  • learning_rate = 0.01:每次权重更新时,步长为0.01,影响训练速度和稳定性。
  • momentum = 0.5:通过加权平均过去的梯度,帮助加速收敛并减少梯度更新的震荡。
  • epochs = 10:模型将在训练数据上进行10次完整的迭代,通常可以在这个范围内找到适合的训练状态。

网络模型

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 10, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(10, 20, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
        )
        self.conv3 = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(320, 50),
            torch.nn.Linear(50, 10),
        )

    def forward(self, x):
        batch_size = x.size(0)
        x = self.conv1(x)  # 一层卷积层,一层池化层,一层激活层(图是先卷积后激活再池化,差别不大)
        x = self.conv2(x)  # 再来一次
        x = self.conv3(x)
        return x  # 最后输出的是维度为10的,也就是(对应数学符号的0~9)
  • 输入层:

    • 输入尺寸:每张输入图像是28x28像素的灰度图,单通道。输入张量的形状为 (batch_size, 1, 28, 28),其中 batch_size 是一次处理的图像数量,1 是表示单通道的灰度图像,28x28 是图像的尺寸。
  • 第一层卷积层(conv1):

    • 卷积层:使用一个大小为 5x5 的卷积核,将输入图像的1个通道(灰度)转换为10个通道。卷积核的步幅为1,填充为0(即没有边缘扩展)。这会产生一个大小为 24x24 的特征图(由于没有填充,尺寸会减少)。
    • 激活函数:ReLU(Rectified Linear Unit),它会对每个像素值进行非线性转换(ReLU(x) = max(0, x)),有效地引入了非线性特性。
    • 池化层:最大池化层使用 2x2 的池化窗口和步幅为2。池化操作减少了特征图的尺寸,将每个 2x2 的区域映射为最大值。池化操作将图像尺寸减半,从 24x24 减小为 12x12,同时减少计算量。
  • 第二层卷积层(conv2):

    • 卷积层:卷积核的大小为 5x5,将前一层输出的10个通道转换为20个通道。同样,步幅为1,没有填充。这个操作将特征图的大小从 12x12 减少到 8x8
    • 激活函数:使用ReLU激活函数。
    • 池化层:再次使用最大池化,池化窗口为 2x2,步幅为2。此操作将尺寸从 8x8 减小为 4x4
  • 全连接层(conv3):

    • 展平操作(Flatten):经过两层卷积和池化操作后,输出特征图的大小为 20x4x4。在传入全连接层之前,需要将这个多维的张量展平成一维向量。展平后的尺寸是 320(即 20 * 4 * 4)。
    • 第一个全连接层:将展平后的320个元素映射到50个神经元。该层的作用是通过加权和偏置的线性变换对输入进行处理,并通过激活函数进行非线性转换。
    • 第二个全连接层:将50个神经元映射到10个神经元,输出的每个神经元代表一个数字类别(0到9)。
  • 输出层:

    • 输出尺寸:最终输出为一个10维的向量,其中每个值表示输入图像属于每个类别的“分数”。这个分数可以通过softmax层转化为概率,用于多类分类任务。

模型训练

# Construct loss and optimizer ------------------------------------------------------------------------------
loss_f = torch.nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)  # lr学习率,momentum冲量


# Train and Test CLASS --------------------------------------------------------------------------------------
# 把单独的一轮一环封装在函数类里
def train(epoch):
    running_loss = 0.0  # 这整个epoch的loss清零
    running_total = 0
    running_correct = 0
    for batch_idx, data in enumerate(train_loader, 0):  #第一个代表训练的批次,data中包括数据和标签,第一个数据代表输入即inputs,第二个数据代表标签labels

        inputs, target = data
        optimizer.zero_grad()   #将之前的梯度清零

        # forward + backward + update
        outputs = model(inputs)
        loss = loss_f(outputs, target)
        #反向传播
        loss.backward()
        #参数更新
        optimizer.step()

        # 把运行中的loss累加起来,为了下面300次一除
        running_loss += loss.item()
        # 把运行中的准确率acc算出来
        _, predicted = torch.max(outputs.data, dim=1)
        running_total += inputs.shape[0]
        running_correct += (predicted == target).sum().item()
        if batch_idx % 300 == 299:  # 不想要每一次都出loss,浪费时间,选择每300次出一个平均损失,和准确率
            print('[%d, %5d]: loss: %.3f , acc: %.2f %%'
                  % (epoch + 1, batch_idx + 1, running_loss / 300, 100 * running_correct / running_total))
            running_loss = 0.0  # 这小批300的loss清零
            running_total = 0
            running_correct = 0  #


            # 这小批300的acc清零

        # torch.save(model.state_dict(), './model_Mnist.pth')
        # torch.save(optimizer.state_dict(), './optimizer_Mnist.pth')


def test():
    correct = 0
    total = 0
    with torch.no_grad():  # 测试集不用算梯度
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)  # dim = 1 列是第0个维度,行是第1个维度,沿着行(第1个维度)去找1.最大值和2.最大值的下标
            total += labels.size(0)  # 张量之间的比较运算
            correct += (predicted == labels).sum().item()
    acc = correct / total
    print('[%d / %d]: Accuracy on test set: %.1f %% ' % (epoch + 1, epochs, 100 * acc))  # 求测试的准确率,正确数/总数
    return acc


# Start train and Test --------------------------------------------------------------------------------------
if __name__ == '__main__':
    acc_list_test = []
    for epoch in range(epochs):
        train(epoch)
        # if epoch % 10 == 9:  #每训练10轮 测试1次
        acc_test = test()
        acc_list_test.append(acc_test)

    plt.plot(acc_list_test)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy On TestSet')
    plt.show()

训练结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2259517.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

题解 - 取数排列

题目描述 取1到N共N个连续的数字(1≤N≤9),组成每位数不重复的所有可能的N位数,按从小到大的顺序进行编号。当输入一个编号M时,就能打印出与该编号对应的那个N位数。例如,当N=3时,可…

如何在 ASP.NET Core 3.1 应用程序中使用 Log4Net

介绍 日志记录是应用程序的核心。它对于调试和故障排除以及应用程序的流畅性非常重要。 借助日志记录,我们可以对本地系统进行端到端的可视性,而对于基于云的系统,我们只能提供一小部分可视性。您可以将日志写入磁盘或数据库中的文件&#xf…

监控易监测对象及指标之:宝兰德中间件JMX监控指标解读

监控易作为一款全面的IT监控软件,能够为企业提供深入、细致的监控服务,确保企业IT系统的稳定运行。在本文中,我们将详细解读监控易针对宝兰德中间件JMX的监控指标,以帮助用户更好地理解和应用这些监控数据。 监测指标概览&#x…

Ubuntu 安装 Samba Server

在 Mac 上如何能够与Ubuntu 服务器共享文件夹,需要在 Ubuntu 上安装 Samba 文件服务器。本文将介绍如何在 Ubuntu 上安装 Samba 服务器从而达到以下目的: Mac 与 Ubuntu 共享文件通过用户名密码访问 安装 Samba 服务 sudo apt install samba修改配置文…

数字化招聘系统如何帮助企业实现招聘效率翻倍提升?

众所周知,传统的招聘方式已经难以满足现代企业对人才的需求,而数字化招聘系统的出现,为企业提供了全新的解决方案。通过数字化招聘系统,企业可以自动化处理繁琐的招聘流程,快速筛选合适的候选人,从而大幅提…

C语言数组和字符串笔记

C语言数组和字符串笔记 1. 数组及其相关概念 1.1 为什么需要使用数组? 数组是一个有序的、类型相同的数据集合。这些数据被称为数组的元素。每个数组都有一个名字,数组名代表数组的起始地址。数组的元素通过索引或下标访问,索引从0开始。 …

u-boot移植、配置、编译学习笔记【刚开始就中止了】

教程视频地址 https://www.bilibili.com/video/BV1L24y187cK 【这个视频中途停更了…原因是实际中需要去改u-boot的情况比较少】 使用的u-boot的源码 视频中使用的是 u-boot-2017.03 学习到这里,暂停u-boot的移植、配置、编译学习,原因是经过与老师…

回归任务与分类任务应用及评价指标

能源系统中的回归任务与分类任务应用及评价指标 一、回归任务应用1.1 能源系统中的回归任务应用1.1.1 能源消耗预测1.1.2 负荷预测1.1.3 电池健康状态估计(SOH预测)1.1.4 太阳能发电量预测1.1.5 风能发电量预测 1.2 回归任务中的评价指标1.2.1 RMSE&…

【树莓派4B】MindSpore lite 部署demo

一个demo,mindspore lite 部署在树莓派4B ubuntu22.04中,为后续操作开个门! 环境 开发环境:wsl-ubuntu22.04分发版部署环境:树莓派4B,操作系统为ubuntu22.04mindspore lite版本:mindspore-li…

AI监控赋能健身馆与游泳馆全方位守护,提升安全效率

一、AI视频监控技术的崛起 随着人工智能技术的不断发展,AI视频监控正成为各行业保障安全、提升效率的关键工具。相比传统监控系统,AI技术赋予监控系统实时分析、智能识别和精准预警的能力,让“被动监视”转变为“主动防控”。 二、AI监控应用…

M|林中小屋

title: 林中小屋 The Cabin in the Woods time: 2024-12-13 周五 rating: 7 豆瓣: 7.6 上映时间: “2012” 类型: M恐怖 导演: 德鲁戈达德 Drew Goddard 主演: 克里斯汀康奈利 Kristen Connolly弗兰克朗茨 Fran Kranz 国家/地区: 美国 片长/分钟: 95分钟 M&#xff5…

Mysql中的sql语句怎么执行的?

1.连接MySQL 通过客户端使用TCP(数据传输协议)连接MySQL连接器,连接器接到请求后对它进行检验是否有权限,有就进行分配资源。(这个过程不能超过8小时) 2.成功连接(校验效验) 客户端发送sql语句&#xff…

流网络复习笔记

所以这里的19是118-019 <s , w> 1/3就是容量是3&#xff0c;流量是1 残留网络就是两个相对箭头上都是剩余对应方向还能同行的流量 所以s->w 3-1 2, w->s 1

Redis - 实战之 全局 ID 生成器 RedisIdWorker

概述 定义&#xff1a;一种分布式系统下用来生成全局唯一 ID 的工具 特点 唯一性&#xff0c;满足优惠券需要唯一的 ID 标识用于核销高可用&#xff0c;随时能够生成正确的 ID高性能&#xff0c;生成 ID 的速度很快递增性&#xff0c;生成的 ID 是逐渐变大的&#xff0c;有利于…

arXiv-2024 | VLM-GroNav: 基于物理对齐映射视觉语言模型的户外环境机器人导航

作者&#xff1a; Mohamed Elnoor, Kasun Weerakoon, Gershom Seneviratne, Ruiqi Xian, Tianrui Guan, Mohamed Khalid M Jaffar, Vignesh Rajagopal, and Dinesh Manocha单位&#xff1a;马里兰大学学院公园分校原文链接&#xff1a;VLM-GroNav: Robot Navigation Using Phys…

华为无线AC、AP模式与上线解析(Huawei Wireless AC, AP Mode and Online Analysis)

华为无线AC、AP模式与上线解析 为了实现fit 瘦AP的集中式管理&#xff0c;我们需要统一把局域网内的所有AP上线到AC&#xff0c;由AC做集中式管理部署。这里我们需要理解CAPWAP协议&#xff0c;该协议分为两种报文&#xff1a;1、管理报文 2、数据报文。管理报文实际在抓包过程…

简单vue3前端打包部署到服务器,动态配置http请求头后端ip方法教程

vue3若依框架前端打包部署到服务器&#xff0c;需要部署到多个服务器上&#xff0c;每次打包会很麻烦&#xff0c;今天教大家一个动态配置请求头api的方法&#xff0c;部署后能动态获取(修改)对应服务器的请求ip 介绍两种方法&#xff0c;如有需要可以直接尝试步骤一&#xff…

Java-DataX 插件机制示例

示例代码 DataXPluginExample: DataX 项目的plugin 机制学习https://gitee.com/wendgit/data-xplugin-example/ 摘要 DataXPluginExample 是一个我编写的专门解读DataX插件机制的示例项目&#xff0c;旨在深入解析和掌握DataX的插件机制。本示例通过简洁明了的实现方式&#…

基于AI网关的风电系统在线监测

风力发电是典型的清洁能源之一&#xff0c;也是我国能源结构转型的重要组成。近年来我国大力发展风能、水能、光伏等清洁能源&#xff0c;加快创造人与生态友好和谐的人居社会。由于风电机组通常部署于偏远的野外&#xff0c;经常面临狂风、暴雨、日晒等严苛工作形势&#xff0…