CNN对 MNIST 数据库中的图像进行分类

news2025/1/14 1:06:31

加载 MNIST 数据库

MNIST 是机器学习领域最著名的数据集之一。

  • 它有 70,000 张手写数字图像 - 下载非常简单 - 图像尺寸为 28x28 - 灰度图
from keras.datasets import mnist

# 使用 Keras 导入MNIST 数据库
(X_train, y_train), (X_test, y_test) = mnist.load_data()

print("The MNIST database has a training set of %d examples." % len(X_train))
print("The MNIST database has a test set of %d examples." % len(X_test))

 将前六个训练图像可视化

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.cm as cm
import numpy as np

# 绘制前六幅训练图像
fig = plt.figure(figsize=(20,20))
for i in range(6):
    ax = fig.add_subplot(1, 6, i+1, xticks=[], yticks=[])
    ax.imshow(X_train[i], cmap='gray')
    ax.set_title(str(y_train[i]))

查看图像的更多细节 

def visualize_input(img, ax):
    ax.imshow(img, cmap='gray')
    width, height = img.shape
    thresh = img.max()/2.5
    for x in range(width):
        for y in range(height):
            ax.annotate(str(round(img[x][y],2)), xy=(y,x),
                        horizontalalignment='center',
                        verticalalignment='center',
                        color='white' if img[x][y]<thresh else 'black')

fig = plt.figure(figsize = (12,12)) 
ax = fig.add_subplot(111)
visualize_input(X_train[0], ax)

 预处理输入图像:通过将每幅图像中的每个像素除以 255 来调整图像比例

# 调整比例,使数值在 0 - 1 范围内 [0,255] --> [0,1]
X_train = X_train.astype('float32')/255
X_test = X_test.astype('float32')/255 

print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

 对标签进行预处理:使用单热方案对分类整数标签进行编码

from keras.utils import to_categorical

num_classes = 10 
# 打印前十个(整数值)训练标签
print('Integer-valued labels:')
print(y_train[:10])

# 对标签进行一次性编码
# 将类别向量转换为二进制类别矩阵
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

# 打印前十个(单次)训练标签
print('One-hot labels:')
print(y_train[:10])

 重塑数据以适应我们的 CNN(和 input_shape)

# 输入图像尺寸为 28x28 像素的图像。
img_rows, img_cols = 28, 28

X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)

print('input_shape: ', input_shape)
print('x_train shape:', X_train.shape)

定义模型架构

您必须传递以下参数:

  • filters - 滤波器的数量。
  • kernel_size - 指定(正方形)卷积窗口高度和宽度的数值。

还有一些额外的、可选的参数需要调整:

  • strides - 卷积的步长。如果不指定任何参数,strides 将设为 1。
  • padding - "有效 "或 "相同 "之一。如果不做任何指定,padding 将设置为 "有效"。
  • activation - 通常为 "relu"。如果不指定任何内容,则不会应用激活。我们强烈建议你为网络中的每个卷积层添加 ReLU 激活函数。

 需要注意的事项

  • 始终为 CNN 中的 Conv2D 层添加 ReLU 激活函数。除网络中的最后一层外,密集层也应具有 ReLU 激活函数。
  • 在构建分类网络时,网络的最终层应是具有 softmax 激活函数的密集层。最终层的节点数应等于数据集中的类总数。
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

# 创建模型对象
model = Sequential()

# CONV_1: 添加 CONV 层,采用 RELU 激活,深度 = 32 内核
model.add(Conv2D(32, kernel_size=(3, 3), padding='same',activation='relu',input_shape=(28,28,1)))
# POOL_1: 对图像进行下采样,选择最佳特征
model.add(MaxPooling2D(pool_size=(2, 2)))

# CONV_2: 在这里,我们将深度增加到 64
model.add(Conv2D(64, (3, 3),padding='same', activation='relu'))
# POOL_2: more downsampling
model.add(MaxPooling2D(pool_size=(2, 2)))

# 由于维度过多,我们只需要一个分类输出
model.add(Flatten())

# FC_1: 完全连接,获取所有相关数据
model.add(Dense(64, activation='relu'))

# FC_2: 输出软最大值,将矩阵压制成 10 个类别的输出概率
model.add(Dense(10, activation='softmax'))

model.summary()

需要注意的事项:
  • 网络以两个卷积层的序列开始,然后是最大池化层。
  • 最后一层为数据集中的每个对象类别设置了一个条目,并具有软最大激活函数,因此可以返回概率。
  • Conv2D 深度从输入层的 1 增加到 32 到 64。
  • 我们还想减少高度和宽度--这就是 maxpooling 的作用所在。请注意,在池化层之后,图像尺寸从 28 减小到 14。
  • 可以看到,每个输出形状都用 None 代替了批量大小。这是为了便于在运行时更改批次大小。
  • 最后,我们会添加一个或多个全连接层来确定图像中包含的对象。例如,如果在上一个最大池化层中发现了车轮,那么这个 FC 层将转换该信息,以更高的概率预测图像中出现了一辆汽车。如果图像中有眼睛、腿和尾巴,那么这可能意味着图像中有一只狗。

编译模型

# rmsprop 和自适应学习率 (adaDelta) 是梯度下降的流行形式,仅次于 adam 和 adagrad
# 因为我们有多个类别 (10)

# 编译模型
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', 
              metrics=['accuracy'])

训练模型

from keras.callbacks import ModelCheckpoint   

# 训练模型
checkpointer = ModelCheckpoint(filepath='model.weights.best.hdf5', verbose=1, 
                               save_best_only=True)
hist = model.fit(X_train, y_train, batch_size=32, epochs=12,
          validation_data=(X_test, y_test), callbacks=[checkpointer], 
          verbose=2, shuffle=True)

 在验证集上加载分类准确率最高的模型

# 加载能获得最佳验证精度的权重
model.load_weights('model.weights.best.hdf5')

计算测试集的分类准确率 

# 评估测试的准确性
score = model.evaluate(X_test, y_test, verbose=0)
accuracy = 100*score[1]

# 打印测试精度
print('Test accuracy: %.4f%%' % accuracy)

 

注意事项:

MLP 和 CNN 通常不会产生可比较的结果。MNIST 数据集非常特别,因为它非常干净,而且经过了完美的预处理。例如,所有图像大小相同,并以 28x28 像素网格为中心。如果数字稍有偏斜或不居中,这项任务就会难得多。对于真实世界中杂乱无章的图像数据,CNN 将真正超越 MLP。

为了直观地了解为什么会出现这种情况,要将图像输入 MLP,首先必须将图像转换为矢量。然后,MLP 会将图像视为没有特殊结构的简单数字向量。它不知道这些数字原本是按空间网格排列的。

相比之下,CNN 的设计目的完全相同,即处理多维数据中的模式。与 MLP 不同的是,CNN 知道,相距较近的图像像素比相距较远的像素关系密切。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1259470.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

防火墙简介

防火墙概念 是指一种将内部网和公众访问网&#xff08;如Internet&#xff09;分开的方法&#xff0c;它实际上是一种建立在现代通信网络技术和信息安全技术基础上的应用性安全技术&#xff0c;隔离技术。 将需要保护的网络和不可信网络进行隔离&#xff0c;隐藏信息并…

【华为OD】统一考试B\C卷真题 100%通过:开源项目热榜 C/C++实现

目录 题目描述&#xff1a; 示例1 示例2 题目描述&#xff1a; 某个开源社区希望将最近热度比较高的开源项目出一个榜单&#xff0c;推荐给社区里面的开发者。对于每个开源项目&#xff0c;开发者可以进行关注(watch)、收藏(star)、fork、提issue、提交合并请求(MR)等。 数…

振南技术干货集:znFAT 硬刚日本的 FATFS 历险记(2)

注解目录 1、znFAT 的起源 1.1 源于论坛 &#xff08;那是一个论坛文化兴盛的年代。网友 DIY SDMP3 播放器激起了我的兴趣。&#xff09; 1.2 硬盘 MP3 推了我一把 &#xff08;“坤哥”的硬盘 MP3 播放器&#xff0c;让我深陷 FAT 文件系统不能自拔。&#xff09; 1.3 我…

spring Cloud在代码中如何应用,erueka 客户端配置 和 服务端配置,Feign 和 Hystrix做高可用配置

文章目录 Eureka一、erueka 客户端配置二、eureka 服务端配置 三、高可用配置FeignHystrix 通过这篇文章来看看spring Cloud在代码中的具体应用&#xff0c;以及配置和注解&#xff1b; Eureka 一、erueka 客户端配置 1、Eureka 启禁用 eureka.client.enabledtrue 2、Eurek…

在Windows上配置MySql开发java,导入JDBC的jar包后连接SQL Server数据库结合Java和MySql的一些简单实践

在Windows上配置MySql 我们先进入MySql官网 在官网中选择MySQL Installer for Windows 进入后选择第一个下载 接下来安装即可&#xff0c;在安装时&#xff0c;可以只安装MySql Server&#xff08;默认选项&#xff09;,选择Full也可&#xff0c;这样会同时安装workbench以及…

智能优化算法应用:基于蜻蜓算法无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于蜻蜓算法无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于蜻蜓算法无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.蜻蜓算法4.实验参数设定5.算法结果6.参考文献7.MATLAB…

SocialFi 和 GameFi 的碰撞 — Socrates 构建新的 Web3 流量入口

伴随着比特币现货 ETF 即将通过 SEC 批准的消息&#xff0c;整个加密市场在11月份达到了熊市以来的新高峰。市场普遍上涨&#xff0c;新的玩法和项目不断涌出吸引了大量老用户回归以及新用户加入。加密市场经过长期的低迷&#xff0c;终于来到了牛市的起点&#xff01; 上一轮牛…

[C++]六大默认成员函数详解

☃️个人主页&#xff1a;fighting小泽 &#x1f338;作者简介&#xff1a;目前正在学习C和Linux &#x1f33c;博客专栏&#xff1a;C入门 &#x1f3f5;️欢迎关注&#xff1a;评论&#x1f44a;&#x1f3fb;点赞&#x1f44d;&#x1f3fb;留言&#x1f4aa;&#x1f3fb; …

1980-2022年世界各国专利、商标申请数据/世界各国知识产权专利申请数据

1980-2022年世界各国专利、商标申请数据/世界各国知识产权专利申请数据 1、时间&#xff1a;1980-2022年 2、来源&#xff1a;WIPO数据库 3、范围&#xff1a;世界各国&#xff08;180多个国家&#xff09; 4、指标&#xff1a;国家名称、年份、代码、类型、专利申请总量、…

老鸟整理,银行测试业务+银行测试案例编写,超细汇总...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 银行的软件测试是…

2023中国SaaS大会完美收官,体验管理开辟SaaS续费增长新曲线

11月17日-19日&#xff0c;2023中国SaaS大会在苏州太湖万豪酒店完美收官。本场专属于SaaS人的行业盛会&#xff0c;设有运动会、实战闭门会、公开课、辩论赛、嘉宾对话及演讲等多项精彩活动&#xff0c;吸引了千余名To B&#xff08;SaaS&#xff09;领域创业者、投资人、企业客…

古埃及金字塔的修建

从理论上说&#xff0c;古埃及人完全有能力设计并建造出充满各种奇妙细节的胡夫金字塔&#xff0c;但后世还是不断涌现出质疑之声&#xff0c;原因倒也简单&#xff0c;那就是胡夫金字塔实在太大了。据推算&#xff0c;整座金字塔使用大约230万块巨石&#xff0c;总质量可达约5…

通俗易懂的spring Cloud;业务场景介绍 二、Spring Cloud核心组件:Eureka 、Feign、Ribbon、Hystrix、zuul

文章目录 通俗易懂的spring Cloud一、业务场景介绍二、Spring Cloud核心组件&#xff1a;Eureka三、Spring Cloud核心组件&#xff1a;Feign四、Spring Cloud核心组件&#xff1a;Ribbon五、Spring Cloud核心组件&#xff1a;Hystrix六、Spring Cloud核心组件&#xff1a;Zuul七…

深度学习中小知识点系列(五) 解读HSV模型随机增强图像

文章目录 图像HSV模型简介RGB模型转HSV模型opencv关于HSV模型实验随机增强图像HSV 图像HSV模型简介 HSV(Hue, Saturation, Value)是根据颜色的直观特性由A. R. Smith在1978年创建的一种颜色空间, 也称六角锥体模型(Hexcone Model)&#xff08;参考百度&#xff09;。在HSV模型…

Java研学-集合框架

一 关于集合框架 1 集合是Java提出的用来进行多个数据存储的"容器",数组也具备这样的功能, 2 由于数组一旦创建长度固定,且只能存放一种数据类型,不够灵活,Java提出更灵活,存放任意的数据类型的容器也就是集合 3 集合和数组的异同点 相同点&#xff1a;都是用来存…

西南科技大学数字电子技术实验一(数字信号基本参数与逻辑门电路功能测试及FPGA 实现 )预习报告

手写报告稍微认真点写,80+随便有 目录 一、计算/设计过程 1、通过虚拟示波器观察和测量信号 2、通过实际电路(电阻、开关、发光二极管)模拟逻辑门电路 二、画出并填写实验指导书上的预表

ELK----日志分析

ELK相关知识 ELK的概念与组件 ELK平台是一套完整的日志集中处理解决方案&#xff0c;将 ElasticSearch、Logstash 和 Kiabana 三个开源工具配合使用&#xff0c; 完成更强大的用户对日志的查询、排序、统计需求。 E&#xff1a;ElasticSearch &#xff08;ES&#xff09; ES是…

智能电表——电源应用

作为智能电网的重要组成部分&#xff0c;智能电表在智能电网中发挥着不可或缺的作用。智能电表是指以智能芯片为核心&#xff0c;通过运用通讯技术以及计算机技术等&#xff0c;能够进行电能计费、电功率的计量和计时&#xff0c;并且能够和上位机进行通讯、用电管理的电度表。…

STK Components 二次开发- 区域

1.创建区域 需要提供点坐标。最少三个点可以确定一个区域。 创建区域也是一样&#xff0c;创建对象然后设置点位置 &#xff0c;然后设置区域属性。 var referenceSurface m_earth.Shape; // We specify the boundary in terms of nodes connected by geodesics.var result…

PlantUML语法(全)及使用教程-时序图

目录 1. 参与者1.1、参与者说明1.2、背景色1.3、参与者顺序 2. 消息和箭头2.1、 文本对其方式2.2、响应信息显示在箭头下面2.3、箭头设置2.4、修改箭头颜色2.5、对消息排序 3. 页面标题、眉角、页脚4. 分割页面5. 生命线6. 填充区设置7. 注释8. 移除脚注9. 组合信息9.1、alt/el…