通过卷积神经网络(CNN)识别和预测手写数字

news2024/12/26 11:23:43

一:卷积神经网络(CNN)和手写数字识别MNIST数据集的介绍

卷积神经网络(Convolutional Neural Networks,简称CNN)是一种深度学习模型,它在图像和视频识别、分类和分割任务中表现出色。CNN通过模仿人类视觉系统的工作原理来处理数据,能够从图像中自动学习和提取特征。以下是CNN的一些关键特点和组成部分:

卷积层(Convolutional Layer)

卷积层是CNN的核心,它使用滤波器(或称为卷积核)在输入图像上滑动,以提取图像的局部特征。

每个滤波器负责检测图像中的特定特征,如边缘、角点或纹理等。

卷积操作会产生一个特征图(feature map),它表示输入图像在滤波器下的特征响应。

激活函数

通常在卷积层之后使用非线性激活函数,如ReLU(Rectified Linear Unit),以增加网络的非线性表达能力。

激活函数帮助网络处理复杂的模式,并使网络能够学习更复杂的特征组合。

池化层(Pooling Layer)

池化层用于降低特征图的空间尺寸,减少参数数量和计算量,同时使特征检测更加鲁棒。

最常见的池化操作是最大池化(max pooling)和平均池化(average pooling)。

全连接层(Fully Connected Layer)

在多个卷积和池化层之后,CNN通常包含一个或多个全连接层,这些层将学习到的特征映射到最终的输出类别上。

全连接层中的每个神经元都与前一层的所有激活值相连。

softmax层

在网络的最后一层,通常使用softmax层将输出转换为概率分布,用于多分类任务中。

softmax函数确保输出层的输出值在0到1之间,并且所有输出值的总和为1。

卷积神经网络的训练

CNN通过反向传播算法和梯度下降法进行训练,以最小化损失函数(如交叉熵损失)。

在训练过程中,网络的权重通过大量图像数据进行调整,以提高分类或识别的准确性。

数据增强(Data Augmentation)

为了提高CNN的泛化能力,经常使用数据增强技术,如旋转、缩放、裁剪和翻转图像,以创建更多的训练样本。

迁移学习(Transfer Learning)

迁移学习是一种技术,它允许CNN利用在一个大型数据集(如ImageNet)上预训练的网络权重,来提高在小型或特定任务上的性能。

CNN在计算机视觉领域的应用非常广泛,包括但不限于图像分类、目标检测、语义分割、物体跟踪和面部识别等任务。由于其强大的特征提取能力,CNN已成为这些任务的主流方法之一。

MNIST数据集是一个广泛使用的手写数字识别数据集,可以通过TensorFlow库Pytorch库来获取, 也可以从官方网站下载:MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

MNIST数据集它包含四个部分:训练数据集、训练数据集标签、测试数据集和测试数据集标签。这些文件是IDX格式的二进制文件,需要特定的程序来读取。这个数据集包含了60,000张训练集图像和10,000张测试集图像,每张图像都是28x28像素的手写数字,范围从0到9。这些图像被处理为灰度值,其中黑色背景用0表示,手写数字用0到1之间的灰度值表示,数值越接近1,颜色越白。

MNIST数据集的图像通常被拉直为一个一维数组,每个数组包含784个元素(28x28像素)。数据集中的每个图像都有一个对应的标签,标签以one-hot编码的形式给出,例如数字5的标签表示为[0, 0, 0, 0, 0, 1, 0, 0, 0, 0]。

在机器学习模型中,MNIST数据集常用于训练分类器,以识别和预测手写数字。例如,在深度学习中,可以使用卷积神经网络(CNN)来处理这些图像,学习从图像像素到数字标签的映射。

二:通过Pytorch库建立CNN模型训练MNIST数据集

使用Python的Pytorch库来完成一个卷积神经网络(CNN)来训练MNIST数据集,需要遵循以下步骤:

  1. 导入必要的库:我们需要导入Pytorch以及其它可能需要的库,如torchvision用于数据加载和变换。
  2. 加载MNIST数据集:使用torchvision库中的datasets和DataLoader来加载和预处理MNIST数据集。
  3. 定义卷积神经网络结构:设计一个简单的CNN结构,包括卷积层、池化层和全连接层。
  4. 定义损失函数和优化器:选择一个合适的损失函数,如交叉熵损失,以及一个优化器,如Adam或SGD。
  5. 训练模型:在训练集上训练模型,并保存训练过程中的损失和准确率。
  6. 测试模型:在测试集上评估模型的性能。

接下来,我们将按照这些步骤使用Python代码来完成这个任务。

Step1:导入必要的库

# 导入必要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
  • import torch: 导入了PyTorch的主库,这是进行深度学习任务的基础。
  • import torch.nn as nn: 导入了PyTorch的神经网络模块,它包含了构建神经网络所需的许多类和函数。
  • import torch.nn.functional as F: 导入了PyTorch的功能性API,它提供了不需要维护状态的神经网络操作,例如激活函数、池化等。
  • import torchvision: 导入了PyTorch的视觉库,它提供了许多视觉任务所需的工具和数据集。
  • import torchvision.transforms as transforms: 导入了对数据进行预处理的工具。
  • from torch.utils.data import DataLoader: 导入了PyTorch的数据加载器,它可以方便地迭代数据集。

Step2:加载MNIST数据集

# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
  • transform = transforms.Compose(...): 创建了一个转换管道,用于对数据进行预处理。Compose是一个函数,它将多个转换步骤组合成一个转换。
  • transforms.ToTensor(): 将图像数据从PIL Image或NumPy ndarray格式转换为浮点张量,并且将像素值缩放到[0,1]范围内。
  • transforms.Normalize((0.5,), (0.5,)): 对图像进行归一化处理。给定均值(mean)和标准差(std),这个转换将张量的每个通道都减去均值并除以标准差。在这里,它将每个像素值从[0,1]范围转换为[-1,1]范围。
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
  • 这两行代码分别加载了MNIST数据集的训练集和测试集。
  • root='./data': 指定数据集下载和存储的根目录。
  • train=True: 对于trainset,表示加载数据集的训练部分。
  • train=False: 对于testset,表示加载数据集的测试部分。
  • download=True: 表示如果数据集不在指定的root目录下,则从互联网上下载。
  • transform=transform: 应用之前定义的转换。
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64, shuffle=False)
  • 这两行代码创建了两个DataLoader对象,用于在训练和测试时迭代数据集。
  • batch_size=64: 指定每个批次的样本数量。
  • shuffle=True: 对于trainloader,在每次迭代时打乱数据,这对于训练是有益的,因为它可以减少模型学习数据的顺序性。
  • shuffle=False: 对于testloader,不打乱数据,因为测试时不需要随机性。

得到了一个名为data的文件夹:

847242f10504407ca060290107d1bc8d.png

Step3:定义卷积神经网络结构

# 定义卷积神经网络结构
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
  • 这段代码定义了一个名为CNN的卷积神经网络类,它继承自nn.Module
  • __init__方法初始化了网络的结构:
    • self.conv1是一个2D卷积层,输入通道为1(MNIST图像为单通道),输出通道为32,卷积核大小为3x3,并带有1像素的填充。
    • self.pool是一个2x2的最大池化层,用于减小数据的维度。
    • self.conv2是第二个2D卷积层,输入通道为32,输出通道为64,卷积核大小为3x3,并带有1像素的填充。
    • self.fc1是一个全连接层,它将64个通道的7x7图像映射到1024个特征。
    • self.fc2是另一个全连接层,它将1024个特征映射到10个输出,对应于MNIST数据集的10个类别。
  • forward方法定义了数据通过网络的前向传播路径:
    • x首先通过conv1卷积层,然后应用ReLU激活函数,并使用pool进行池化。
    • 接着,x通过conv2卷积层,再次应用ReLU激活函数和池化。
    • x.view(-1, 64 * 7 * 7)将数据扁平化,为全连接层准备。
    • x通过fc1全连接层,并应用ReLU激活函数。
    • 最后,x通过fc2全连接层,输出结果。
# 实例化网络
net = CNN()
  • 创建了一个CNN类的实例,名为net

Step4:定义损失函数和优化器

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
  • criterion是交叉熵损失函数,常用于多分类问题。
  • optimizer是Adam优化器,用于更新网络的权重。

Step5:训练模型

# 训练模型
epochs = 5
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {running_loss/(i+1)}")

下面是这段代码的逐行解释:

  1. epochs是一个变量,表示训练过程中模型将遍历整个训练数据集的次数。这里设置为5,意味着整个训练数据集将被遍历5次。
  2. 外层for循环,它将执行epochs次。在每次迭代中,epoch变量将代表当前的迭代次数,从0开始到epochs-1结束。
  3. 在每次epoch开始时,running_loss被重置为0.0。这个变量用于累加每个epoch中的所有批次损失,以便计算平均损失。
  4. 这是一个嵌套的for循环,它遍历trainloader返回的批次数据。enumerate函数用于遍历可迭代对象,同时跟踪当前的索引(这里是i)。
  5. trainloader是之前定义的数据加载器,它负责分批加载数据,以便于训练。
  6. 参数0指定了索引的起始值。
  7. 然后解包了data元组,其中包含输入(图像)和标签(目标值)。inputs是模型的输入数据,labels是这些输入数据的正确类别标签。
  8. 在每次迭代开始时,调用optimizer.zero_grad()来清除之前梯度计算的结果。这是必要的,因为PyTorch的梯度是累加的。
  9. 输入inputs传递给神经网络net,并得到输出outputs。这是模型的前向传播步骤。
  10. 计算了模型输出的损失。criterion是之前定义的交叉熵损失函数,它比较outputs(模型的预测)和labels(实际类别标签)来计算损失。
  11. 执行了反向传播。它计算了损失相对于模型参数的梯度。
  12. 更新了模型的权重。optimizer使用计算出的梯度来调整网络参数,以减少下一次迭代的损失。
  13. 将当前的批次损失累加到running_loss变量中,用于后续计算平均损失。
  14. 在每个epoch结束时,打印出当前epoch的编号和平均损失。epoch+1是为了从1开始计数epoch,而不是从0开始。running_loss/(i+1)计算了当前epoch的平均损失,其中i+1是当前epoch中批次的数量。

最终得到每个epoch的平均损失如下:

49592faa38b84b699f4458f2cf76a433.png

Step6:测试模型

# 测试模型
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy of the network on the 10000 test images: {100 * correct / total}%")
  1. correcttotal是两个变量,分别用于跟踪模型在测试数据集上正确预测的样本数量和总的样本数量。
  2. with torch.no_grad()是一个上下文管理器,用于在测试阶段禁用梯度计算。因为测试阶段不需要计算梯度,这样可以节省内存并加快计算速度。
  3. for循环,遍历testloader返回的测试数据集的批次数据。
  4. 这行代码解包了data元组,其中包含测试图像images和它们对应的真实标签labels
  5. 这行代码将测试图像images输入到训练好的神经网络net中,并得到输出outputs
  6. torch.max(outputs.data, 1)返回两个值:第一个是每个批次中最大值的元素,第二个是这些最大值的索引。在这里,最大值代表模型对每个图像的预测类别,而索引则代表预测的类别标签。
  7. predicted是模型预测的类别标签的向量。
  8. 这行代码累加测试集中总的样本数量。labels.size(0)给出了当前批次中样本的数量。
  9. (predicted == labels)是一个布尔表达式,它比较模型的预测predicted和真实标签labels,并返回一个布尔张量,其中正确预测的位置为True,否则为False。
  10. .sum()计算布尔张量中True的数量,即正确预测的样本数量。
  11. .item()将计算得到的张量(只有一个元素)转换为Python的标量值。
  12. 这行代码计算并打印出模型在测试数据集上的准确率。准确率是通过将正确预测的样本数量correct除以总样本数量total,然后乘以100来得到的百分比。这里假设测试数据集包含10000个样本。

得到准确率如下:

9eaaa375532f47f496aa265cb2d0d615.png

使用这个建立好的卷积神经网络(CNN)模型,主要用于训练分类器。具体来说,这个模型能够识别手写数字图像,并将它们分类为0到9中的一个类别。它适用于MNIST数据集。这个示例能够帮助更好的了解卷积神经网络(CNN)的原理。

 

想要探索更多元化的数据分析视角,可以关注之前发布的相关内容。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2106152.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

快排的深入学习

目录 交换类排序 一、冒泡排序 1. 算法介绍 2.算法流程 3. 算法性能分析 (1)时间复杂度分析 (2) 空间复杂度分析 冒泡排序的特性总结: 二、快速排序 1.算法介绍 2. 执行流程 1). hoare版本 2). 挖坑法 3)…

5.9灰度直方图

目录 实验原理 实验代码 运行结果 实验原理 calcHist 函数通常是指在计算机视觉和图像处理中用于计算图像直方图的一个函数。 cv:calcHist () 用于计算一个或多个数组的直方图。它可以处理图像数据并返回一个表示像素强度分布的向量(对于灰度图像)或…

Java:集合的相关汇总介绍

主要包含Set(集)、 List(列表包含 Queue)和 Map(映射)。 1、Collection: Collection 是集合 List、 Set、 Queue 的最基本的接口。 2、Iterator:迭代器,可以通过迭代器遍历集合中的数据。 3、Map:是映射表的…

VTK随笔十三:QT与VTK的交互

一、基于 Ot的 VTK 应用程序 以 VTK 读入一幅 JPG 图像,然后在 Qt 界面上使用 VTK 显示该图像为例,演示QT与VTK的交互。 1、创建QT项目QT_VTK_Demo 2、配置VTK库 在CMakeLists.txt中添加如下代码: 配置完成后重新打开工程加载VTK库。 3、编…

制裁下的转型:俄罗斯加密货币战略布局与人民币挂钩BRICS稳定币的崛起

在国际制裁重压下,俄罗斯正在积极推进加密货币政策改革,通过设立加密货币交易所和推动与人民币挂钩的BRICS稳定币,试图在全球金融体系中谋求新的生存与发展路径。这一系列举措标志着俄罗斯在数字经济领域的重大转向,既是对当前经济…

Linux【5】远程管理

目录 shutdown关机 ifconfig输出网卡信息 ping ip地址——检测连接正常 ssh 【-p port】 userip scp不同主机之间的文件copy 当前文件复制到远程 远程文件复制到本地 复制文件夹 -r shutdown关机 shutdown -r 重启 ifconfig输出网卡信息 ping ip地址——检测连接正常…

集成电路学习:什么是PCB印刷电路板

一、PCB:印刷电路板 PCB,全称为Printed Circuit Board,即印刷电路板,是现代电子设备中不可或缺的基础构件。它作为电子元器件的载体和连接体,在电子设备中发挥着至关重要的作用。以下是对PCB的详细解析: 二…

【C++初阶】一、C++入门(万字总结)

「前言」 「专栏」C详细版专栏 🌈个人主页: 代码探秘者 🌈C语言专栏:C语言 🌈C专栏: C 🌈喜欢的诗句:无人扶我青云志 我自踏雪至山巅 目录 一、关于C 1.1 什么是C 1.2 C 发展史 二、C关键字(C…

5.8幂律变换

目录 示例代码1 运行结果1 示例代码2 运行结果2 补充示例原理 示例:使用cv::pow进行图像处理 代码 运行结果 ​编辑 补充 实验代码3 运行结果3​编辑 在OpenCV中,幂律变换(Power Law Transformations)是一种常用的图像…

集成电路学习:什么是MOSFET(MOS管)

一、MOSFET:MOS管 MOSFET,全称Metal-Oxide-Semiconductor Field-Effect Transistor,即金属-氧化物半导体场效应晶体管,也常被称为MOS管或金氧半场效晶体管。它是一种可以广泛使用在模拟电路与数字电路的场效应晶体管(f…

【游戏安全】CheatEngine基础使用——如何对不同类型的数值进行搜索?如何破解数值加密找到想修改的数值?

游戏安全 不同数值类型的搜索破解简单数值加密 不同数值类型的搜索 可以在游戏中看到很精确的物品数量,但是在CE中却什么都扫不到。 这是因为他的数值类型可能并不是四字节的,在游戏中这个数值的机制是一个慢慢增长的数值,所以他很有可能是…

【重学 MySQL】八、MySQL 的演示使用和编码设置

【重学 MySQL】八、MySQL 的演示使用和编码设置 MySQL 的使用演示登录 MySQL查看所有数据库创建数据库使用数据库创建表插入数据查询数据删除表或数据库注意事项 MySQL 的编码设置查看 MySQL 支持的字符集和排序规则服务器级别的编码设置数据库级别的编码设置表级别的编码设置列…

Python3.8绿色便携版嵌入式版制作

Python 的绿色便携版有两种:官方 Embeddable 版本(嵌入式版);安装版制作的绿色版。Embeddable 版适用于需要将 Python 集成到其他应用程序或项目中的情况,它不包含图形界面的安装程序,只提供了 Python 解释器和必要的库文件。安装版包含了 Python 解释器、标准库和其他一些…

基于ssm+vue+uniapp的高校课堂教学管理系统小程序

开发语言:Java框架:ssmuniappJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:M…

8. GIS数据分析师岗位职责、技术要求和常见面试题

本系列文章目录: 1. GIS开发工程师岗位职责、技术要求和常见面试题 2. GIS数据工程师岗位职责、技术要求和常见面试题 3. GIS后端工程师岗位职责、技术要求和常见面试题 4. GIS前端工程师岗位职责、技术要求和常见面试题 5. GIS工程师岗位职责、技术要求和常见面试…

Ferrari求解四次方程

参考: 1) https://proofwiki.org/wiki/Ferrari’s_Method#google_vignette 2)https://blog.csdn.net/qq_25777815/article/details/85206702

【精彩回顾·成都】COC 成都阿里云 CMeet:AIGC 创新应用技术实践!

文章目录 前言一、活动介绍二、精彩分享内容及活动议程2.1、《COC 成都社区情况和活动介绍》2.2、《浅谈 AIGC 商业化》2.3、《通义大模型与 AI 技术在各行业领域的实践与探索》2.4、《话题一:AIGC 在内容创作领域的革新》2.5、《话题二:AIGC 在技术与工…

【ubuntu使用笔记】nvme磁盘挂载失败问题记录

no object for d-bus interface 问题 no object for d-bus interface 解决方法 systemctl --user restart gvfs-udisks2-volume-monitor文件格式错误 问题 解决方法 sudo ntfsfix /dev/nvme4 sudo mount /dev/nvme4 soft/

828华为云征文|基于Flexus X实例云服务器的实际场景-等保三级服务器设置

🔴大家好,我是雄雄,欢迎关注微信公众号:雄雄的小课堂 先看这里 写在前面3️⃣mysql创建安全管理员、审计管理员✅解决方法增加安全管理员增加审计管理员账户❌问题描述✅解决方法 3️⃣linux登录失败问题❌问题描述✅解决方法 3️…

ARM N2微架构介绍

简介 之前在“ARM V2处理器微架构介绍”一文中介绍了面向服务器、云计算等应用的ARM V2处理器微架构,V系列具有更强性能,N系列强调性能和功耗等方向的平衡,本文就将介绍一下ARM N2处理器微架构相比较前代的一些提升。尽管ARM还具备一代N1/V1…