卷积神经网络(CNN)mnist手写数字分类识别的实现

news2025/1/13 3:37:35

文章目录

  • 前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
      • 我的环境:
    • 2. 导入数据
    • 3.归一化
    • 4.可视化
    • 5.调整图片格式
  • 二、构建CNN网络模型
  • 三、编译模型
  • 四、训练模型
  • 五、预测
  • 六、知识点详解
    • 1. MNIST手写数字数据集介绍
    • 2. 神经网络程序说明
    • 3. 网络结构说明

前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")

2. 导入数据

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt

(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

3.归一化

# 将像素的值标准化至0到1的区间内。
train_images, test_images = train_images / 255.0, test_images / 255.0

train_images.shape,test_images.shape,train_labels.shape,test_labels.shape

4.可视化

plt.figure(figsize=(20,10))
for i in range(20):
    plt.subplot(5,10,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(train_labels[i])
plt.show()

在这里插入图片描述

5.调整图片格式

#调整数据到我们需要的格式
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))

# train_images, test_images = train_images / 255.0, test_images / 255.0

train_images.shape,test_images.shape,train_labels.shape,test_labels.shape

二、构建CNN网络模型

model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10)
])

model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d (Conv2D)             (None, 26, 26, 32)        320       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 13, 13, 32)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 11, 11, 64)        18496     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 5, 5, 64)         0         
 2D)                                                             
                                                                 
 flatten (Flatten)           (None, 1600)              0         
                                                                 
 dense (Dense)               (None, 64)                102464    
                                                                 
 dense_1 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 121,930
Trainable params: 121,930
Non-trainable params: 0
_________________________________________________________________

三、编译模型

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

四、训练模型

history = model.fit(train_images, train_labels, epochs=10, 
                    validation_data=(test_images, test_labels))
Epoch 1/10
1875/1875 [==============================] - 15s 8ms/step - loss: 0.1429 - accuracy: 0.9562 - val_loss: 0.0550 - val_accuracy: 0.9803
Epoch 2/10
1875/1875 [==============================] - 14s 7ms/step - loss: 0.0460 - accuracy: 0.9856 - val_loss: 0.0352 - val_accuracy: 0.9883
Epoch 3/10
1875/1875 [==============================] - 13s 7ms/step - loss: 0.0312 - accuracy: 0.9904 - val_loss: 0.0371 - val_accuracy: 0.9880
Epoch 4/10
1875/1875 [==============================] - 14s 7ms/step - loss: 0.0234 - accuracy: 0.9925 - val_loss: 0.0330 - val_accuracy: 0.9900
Epoch 5/10
1875/1875 [==============================] - 14s 8ms/step - loss: 0.0176 - accuracy: 0.9944 - val_loss: 0.0311 - val_accuracy: 0.9904
Epoch 6/10
1875/1875 [==============================] - 16s 9ms/step - loss: 0.0136 - accuracy: 0.9954 - val_loss: 0.0300 - val_accuracy: 0.9911
Epoch 7/10
1875/1875 [==============================] - 14s 8ms/step - loss: 0.0109 - accuracy: 0.9964 - val_loss: 0.0328 - val_accuracy: 0.9909
Epoch 8/10
1875/1875 [==============================] - 14s 7ms/step - loss: 0.0097 - accuracy: 0.9969 - val_loss: 0.0340 - val_accuracy: 0.9903
Epoch 9/10
1875/1875 [==============================] - 15s 8ms/step - loss: 0.0078 - accuracy: 0.9974 - val_loss: 0.0499 - val_accuracy: 0.9879
Epoch 10/10
1875/1875 [==============================] - 13s 7ms/step - loss: 0.0078 - accuracy: 0.9976 - val_loss: 0.0350 - val_accuracy: 0.9902

五、预测

通过下面的网络结构我们可以简单理解为,输入一张图片,将会得到一组数,这组代表这张图片上的数字为0~9中每一个数字的几率,out数字越大可能性越大。
在这里插入图片描述

plt.imshow(test_images[1])

在这里插入图片描述

输出测试集中第一张图片的预测结果

pre = model.predict(test_images)
pre[1]
313/313 [==============================] - 1s 2ms/step
array([  3.3290668 ,   0.29532072,  21.943724  ,  -7.09336   ,
       -15.3133955 , -28.765621  ,  -1.8459738 ,  -5.761892  ,
        -2.966585  , -19.222878  ], dtype=float32)

六、知识点详解

本文使用的是最简单的CNN模型- -LeNet-5,如果是第一次接触深度学习的话,可以先试着把代码跑通,然后再尝试去理解其中的代码。

1. MNIST手写数字数据集介绍

MNIST手写数字数据集来源于是美国国家标准与技术研究所,是著名的公开数据集之一。数据集中的数字图片是由250个不同职业的人纯手写绘制,数据集获取的网址为:http://yann.lecun.com/exdb/mnist/ (下载后需解压)。我们一般会采用(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()这行代码直接调用,这样就比较简单

MNIST手写数字数据集中包含了70000张图片,其中60000张为训练数据,10000为测试数据,70000张图片均是28*28,数据集样本如下:

在这里插入图片描述

如果我们把每一张图片中的像素转换为向量,则得到长度为28*28=784的向量。因此我们可以把训练集看成是一个[60000,784]的张量,第一个维度表示图片的索引,第二个维度表示每张图片中的像素点。而图片里的每个像素点的值介于0-1之间。

在这里插入图片描述

2. 神经网络程序说明

神经网络程序可以简单概括如下:
在这里插入图片描述

3. 网络结构说明

在这里插入图片描述

各层的作用

  • 输入层:用于将数据输入到训练网络
  • 卷积层:使用卷积核提取图片特征
  • 池化层:进行下采样,用更高层的抽象表示图像特征
  • Flatten层:将多维的输入一维化,常用在卷积层到全连接层的过渡
  • 全连接层:起到“特征提取器”的作用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1218863.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

6.docker运行mysql容器-理解容器数据卷

运行mysql容器-理解容器数据卷 1.什么是容器数据卷2.如何使用容器数据卷2.1 数据卷挂载命令2.2 容器数据卷的继承2.3 数据卷的读写权限2.4 容器数据卷的小实验(加深理解)2.4.1 启动挂载数据卷的centos容器2.4.2 启动后,在宿主机的data目录下会…

IDEA创建SpringBoot的多模块项目教程

最近在写一个多模块的SpringBoot项目,基于过程总了一些总结,故把SpringBoot多个模块的项目创建记录下来。 首先,先建立一个父工程: (1)在IDEA工具栏选择File->New->Project (2&#xff0…

热点检测/降级框架Akali的部分原理解析

发现个“轻量级本地化热点检测/降级框架 这个框架名为Akali,项目地址:https://gitee.com/bryan31/Akali主要有两个作用 1:热点检测及处理 2:降级检测及处理 从官网文档来看使用是比较简单的,一个注解就能搞定 怀着好奇的心情c…

Echarts 实现两两柱图重叠(背景和实际值柱图)

Echarts实现两两重叠柱状图_echarts 重叠柱状图_Web_阿凯的博客-CSDN博客 引用启发的博客 先来效果: option {backgroundColor: #03213D,animation: true, // 控制动画是否开启animationDuration: 1000, // 动画的时长, 它是以毫秒为单位animationDuration: func…

数据结构C语言之线性表

发现更多计算机知识,欢迎访问Cr不是铬的个人网站 1.1线性表的定义 线性表是具有相同特性的数据元素的一个有限序列 对应的逻辑结构图形: 从线性表的定义中可以看出它的特性: (1)有穷性:一个线性表中的元…

Homography详解在MVSNet中的应用

Homography(单应性变换)是计算机视觉中的一个重要概念,用于描述两个平面之间的透视关系。在图像处理和计算机视觉中,Homography通常表示两个平面之间的投影关系,这种关系可以通过一个3x3的矩阵来表示。 在数学上&…

YB4019是一款完整的单电池锂离子恒流/恒压线性充电器电池

YB4019 耐压18V 1A线性双节8.4V 锂电充电芯片 概述: YB4019是一款完整的单电池锂离子恒流/恒压线性充电器电池。底部采用热增强ESOP8封装,外部组件数量低使YB4019成为便携式应用的理想选择。此外,YB4019设计用于在USB电源规格范围内工作。Y…

洗袜子的洗衣机哪款好?内衣洗衣机测评

随着人们的生活水平的提升,越来越多小伙伴来开始追求更高的生活水平,一些智能化的小家电就被发明出来,而且内衣洗衣机是其中一个。现在通过内衣裤感染到细菌真的是越来越多,所以我们对内衣裤的清洗频次会高于普通衣服,…

File类和IO流

我是南城余!阿里云开发者平台专家博士证书获得者! 欢迎关注我的博客!一同成长! 一名从事运维开发的worker,记录分享学习。 专注于AI,运维开发,windows Linux 系统领域的分享! 本…

2024苹果笔记本清理内存清理优化工具CleanMyMac X

在使用苹果笔记本电脑的过程中,清理内存是保持电脑运行流畅的重要步骤之一。当我们使用大量的应用程序和文件时,电脑的内存可能会被占满,导致系统变慢甚至出现崩溃的情况。因此,了解如何清理苹果笔记本的内存是非常必要的。本文将…

代码随想录 Day47 动态规划15 LeetCode T583 两个字符串的删除操作 T72 编辑距离

LeetCode T583 两个字符串的删除操作 题目链接:583. 两个字符串的删除操作 - 力扣(LeetCode) 题目思路: 本题有两个思路 1.使用两个字符串的长度之和-2*最长公共子串(换汤不换药) 代码随想录Day45 动态规划13 LeetCode T1143最长公共子序列 T1135 不相交…

适用于 Mac 的 10 款最佳数据恢复工具

对于依赖计算机处理重要文件(无论是个人照片还是重要业务文档)的任何人来说,数据丢失都可能是一场噩梦。 值得庆幸的是,有多种数据恢复工具专门用于Mac用户,可以帮助您恢复丢失或意外删除的文件。 在本文中&#xff0c…

Rust图形界面:eGUI的Panel布局

文章目录 Panel布局尺寸调节源码 Panel布局 eGUI提供面板堆叠的布局方案,即Panel布局。其布局逻辑是,根据当前面板指定的方向,尽可能地填充空间。 CentralPanel 占据屏幕剩余部分的空间SidePanel 占据屏幕两侧的空间,在具体调用…

【网络】TCP协议的相关实验

TCP协议的相关实验 一、理解listen的第二个参数1、实验现象2、TCP 半连接队列和全连接队列3、关于listen的第二个参数的一些问题4、SYN洪水Ⅰ、什么是SYN洪水攻击Ⅱ、如何解决SYN洪水攻击? 二、使用Wireshark分析TCP通信流程 一、理解listen的第二个参数 在编写TCP…

【23真题】无耻!“官方”假真题!害人!

这套华侨23真题是学弟给我从考场抄出来的版本,我刚刚做完解析!后台就收到了另外一份“官方华侨23真题”的投稿。我本想对对回忆版,补充下题干。结果一对吓一跳!竟然一道题都不一样!给大家看下,真的好逼真&a…

《向量数据库指南》——TruLens + Milvus Cloud 构建RAG案例

具体案例 如前所述,RAG 配置选择可能对消除幻觉产生重大影响。下文中将基于城市百科文章构建问答 RAG 应用并展示不同的配置选择是如何影响应用性能的。在搭建过程中,我们使用 LlamaIndex 作为该应用的框架。大家可以在 Google Colab( https://colab.research.google.com/git…

Theory behind GAN

假如要生成一些人脸图,实际上就是想要找到一个分布,从这个分布内sample出来的图片像是人脸,分布之外生成的就不像人脸。而GAN要做的就是找到这个distribution。 在GAN之前用的是Maximum Likelihood Estimation。 Maximum Likelihood Estimat…

【用unity实现100个游戏之15】开发一个类保卫萝卜的Unity2D塔防游戏2(附项目源码)

文章目录 先看本次实现的最终效果前言敌人生命值扣血测试,敌人死亡控制敌人动画敌人死亡动画敌人转向问题源码完结 先看本次实现的最终效果 前言 本期紧接着上一篇,本期主要内容是实现敌人血条、动画和行为逻辑。 敌人生命值 绘制血条UI 新建 publ…

快速入门:构建您的第一个 .NET Aspire 应用程序

##前言 云原生应用程序通常需要连接到各种服务,例如数据库、存储和缓存解决方案、消息传递提供商或其他 Web 服务。.NET Aspire 旨在简化这些类型服务之间的连接和配置。在本快速入门中,您将了解如何创建 .NET Aspire Starter 应用程序模板解决方案。 …