基于keras中Lenet对于mnist的处理

news2025/1/17 17:57:39

文章目录

  • MNIST
  • 导入必要的包
    • 加载数据
    • 可视化数据集
    • 查看数据集的分布
    • 开始训练
      • 画出loss图
      • 画出accuracy图
    • 使用数据外的图来测试
    • 图片可视化
    • 转化灰度图的可视化
    • 可视化卷积层的特征图
      • 第一层卷积 conv1 和 pool1
      • 第二层卷积 conv2 和 pool2

MNIST

MNIST(Modified National Institute of Standards and Technology database)是一个经典的手写数字数据集,通常用于计算机视觉和机器学习的基准测试。以下是关于MNIST数据集的介绍:

数据集内容:

MNIST数据集包含了大约70000张28x28像素的手写数字图片。
这些图片包括了从0到9的10个不同数字,每个数字都有大约7000张图片。

用途:

MNIST数据集通常用于图像分类任务,目标是将手写数字图片分为0到9的10个类别。
它被广泛用于测试和验证各种图像处理和机器学习算法,特别是深度学习模型。

数据特点:

  • 每张图片都是灰度图像,即只有一个颜色通道(黑白)。
  • 图像的大小固定为28x28像素,总共784个像素。
  • 每个像素的值在0到255之间,表示像素的亮度。

挑战:

MNIST数据集相对较小,对于现代深度学习模型来说,通常被认为是一个相对简单的任务。
然而,MNIST仍然具有一定的挑战性,因为手写数字的风格和字体会有很大的差异,有些数字可能写得非常潦草或难以识别。

应用领域:

MNIST数据集通常用于教育和研究,帮助初学者理解图像分类和深度学习概念。
它还可以作为一个基准测试数据集,用于验证新的机器学习算法或深度学习架构的性能。

在这里插入图片描述

https://classic.d2l.ai/chapter_convolutional-neural-networks/lenet.html

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam

def leNet_model():
    model = Sequential()
    model.add(Conv2D(30, (5, 5), input_shape=(28, 28, 1), activation='relu'))  # 24*24*30
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(15, (3, 3), activation='relu'))  # 15*30*3*3
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(500, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(num_classes, activation='softmax'))
    
    # 使用Adam优化器,学习率为0.01
    optimizer = Adam(learning_rate=0.01)
    
    # 编译模型,使用交叉熵损失函数和准确率作为指标
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    
    return model

model = leNet_model()
print(model.summary())
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d_6 (Conv2D)           (None, 24, 24, 30)        780       
                                                                 
 max_pooling2d_6 (MaxPoolin  (None, 12, 12, 30)        0         
 g2D)                                                            
                                                                 
 conv2d_7 (Conv2D)           (None, 10, 10, 15)        4065      
                                                                 
 max_pooling2d_7 (MaxPoolin  (None, 5, 5, 15)          0         
 g2D)                                                            
                                                                 
 flatten_3 (Flatten)         (None, 375)               0         
                                                                 
 dense_5 (Dense)             (None, 500)               188000    
                                                                 
 dropout_3 (Dropout)         (None, 500)               0         
                                                                 
 dense_6 (Dense)             (None, 10)                5010      
                                                                 
=================================================================
Total params: 197855 (772.87 KB)
Trainable params: 197855 (772.87 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

导入必要的包

!pip install Keras==2.0.6
!pip install np_utils
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras.datasets import mnist
from keras.utils.np_utils import to_categorical
import random
from keras.models import Model

加载数据

# 设置随机种子,以确保结果可重复
np.random.seed(0)

# 使用Keras的mnist.load_data()加载MNIST数据集,将数据集分为训练集和测试集
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# 打印训练集和测试集的形状(维度)
print(X_train.shape) #(60000, 28, 28)
print(X_test.shape) #(10000, 28, 28)

可视化数据集

# 创建一个空列表用于存储每个数字类别的样本数量
num_of_samples = []

# 定义图形中的列数和数字类别总数
cols = 5
num_classes = 10

# 创建一个图形子图,包含num_classes行和cols列
fig, axs = plt.subplots(nrows=num_classes, ncols=cols, figsize=(5, 10))
fig.tight_layout()

# 遍历每列和每个数字类别
for i in range(cols):
    for j in range(num_classes):
        # 从训练集中选择特定类别的图像
        x_selected = X_train[y_train == j]
        # 随机选择一张图像并显示在子图中
        axs[j][i].imshow(x_selected[random.randint(0, (len(x_selected) - 1)), :, :], cmap=plt.get_cmap('gray'))
        axs[j][i].axis("off")
        if i == 2:
            axs[j][i].set_title(str(j))  # 在中间列添加标题
            num_of_samples.append(len(x_selected))  # 记录每个类别的样本数量

# 打印每个数字类别的样本数量
print(num_of_samples)

# 创建一个新的图形,用于显示样本数量
plt.figure(figsize=(12, 4))

生成5示例和10个种类

查看数据集的分布

# 使用Matplotlib绘制柱状图,展示训练数据集中各个类别的样本数量分布
plt.bar(range(0, num_classes), num_of_samples)  # 创建柱状图,x轴为类别编号,y轴为样本数量
plt.title("Distribution of the train dataset")  # 设置图表标题
plt.xlabel("Class number")  # 设置x轴标签
plt.ylabel("Number of images")  # 设置y轴标签
plt.show()  # 显示柱状图

# 重新调整图像数据的形状,以匹配卷积神经网络(CNN)的输入要求
X_train = X_train.reshape(60000, 28, 28, 1)  # 将训练数据集的形状重塑为(样本数量, 28, 28, 1)
X_test = X_test.reshape(10000, 28, 28, 1)  # 将测试数据集的形状重塑为(样本数量, 28, 28, 1)

# 对训练和测试标签进行独热编码,以适应多类别分类任务
y_train = to_categorical(y_train, 10)  # 将训练标签独热编码为(样本数量, 10)
y_test = to_categorical(y_test, 10)    # 将测试标签独热编码为(样本数量, 10)

# 数据归一化:将图像像素值从0到255的范围缩放到0到1的范围,有助于模型训练
X_train = X_train / 255  # 训练数据集归一化处理
X_test = X_test / 255    # 测试数据集归一化处理

在这里插入图片描述

开始训练

history  = model.fit(X_train, y_train,\
epochs=10,  validation_split = 0.1, batch_size = 400,\
 verbose = 1, shuffle = True)
参数描述
X_train训练数据
y_train训练标签
epochs训练的轮数
validation_split将训练数据的一部分用于验证的比例
batch_size每个小批量的样本数量
verbose控制训练过程中的输出信息级别
shuffle是否在每轮训练前随机打乱训练数据

画出loss图

# 使用Matplotlib绘制训练过程中的损失曲线
plt.plot(history.history['loss'])       # 绘制训练集上的损失值曲线
plt.plot(history.history['val_loss'])   # 绘制验证集上的损失值曲线
plt.legend(['loss', 'val_loss'])        # 添加图例,标记曲线含义
plt.title('Loss')                       # 设置图表标题为"Loss"
plt.xlabel('epoch')                     # 设置x轴标签为"epoch",表示训练轮数

画出accuracy图

# 使用Matplotlib绘制训练过程中的准确率曲线
plt.plot(history.history['accuracy'])         # 绘制训练集上的准确率曲线
plt.plot(history.history['val_accuracy'])     # 绘制验证集上的准确率曲线
plt.legend(['accuracy', 'val_accuracy'])      # 添加图例,标记曲线含义
plt.title('Accuracy')                        # 设置图表标题为"Accuracy",表示准确率曲线的含义
plt.xlabel('epoch')                          # 设置x轴标签为"epoch",表示训练轮数

使用数据外的图来测试

给图片加点noise,看看测试结果

图片可视化

import requests  # 导入 requests 库,用于发送 HTTP 请求
from PIL import Image  # 导入 PIL 库,用于图像处理
import numpy as np  # 导入 NumPy 库,用于数组操作
import matplotlib.pyplot as plt  # 导入 Matplotlib 库,用于图像显示

# 指定要下载的图像的URL
url = "https://colah.github.io/posts/2014-10-Visualizing-MNIST/img/mnist_pca/MNIST-p1815-4.png"

# 发送HTTP GET请求,获取图像数据,stream=True 表示以流的方式获取数据
response = requests.get(url, stream=True)

# 打印HTTP响应状态码,以确认是否成功获取图像数据
print(response)

# 使用PIL库打开图像,并将数据存储在img对象中
img = Image.open(response.raw)

# 使用Matplotlib显示图像,指定灰度色彩映射
plt.imshow(img, cmap=plt.get_cmap('gray'))

# 将图像转换为NumPy数组
img_array = np.asarray(img)

# 打印图像数组的形状,以了解图像的尺寸和通道数
print(img_array.shape)

在这里插入图片描述

转化灰度图的可视化

import cv2
resized = cv2.resize(img_array,(28,28)) #通过opencv像素
gray_scale = cv2.cvtColor(resized,cv2.COLOR_BGR2GRAY)#取灰度值
image = cv2.bitwise_not(gray_scale)# 对灰度图像进行按位取反操作,即将图像中的白色像素变为黑色,黑色像素变为白色
plt.imshow(gray_scale,cmap=plt.get_cmap("gray")) #将图片显示位黑白

在这里插入图片描述

image = image/255
image = image.reshape(1,28,28,1) #让他和我们训练得时候一样
prediction = np.argmax(model.predict(image))#多类数据集得预测模型
print("预测结果: ",str(prediction))
1/1 [==============================] - 0s 18ms/step
预测结果数字:  4
#看测试结果
score = model.evaluate(X_test,y_test,verbose=0)
print(type(score))
print('Test socre', score[0])
print('Test accuracy',score[1])

可视化卷积层的特征图

# 利用 Model API 获取模型中间层的输出
# 创建两个新的模型,layer1 和 layer2,分别将输入和输出连接到模型的第一个和第三个层
layer1 = Model(inputs=model.layers[0].input, outputs=model.layers[0].output)
layer2 = Model(inputs=model.layers[0].input, outputs=model.layers[2].output)

# 使用 layer1 和 layer2 对输入图像进行预测,获取中间层的输出
visual_layer1, visual_layer2 = layer1.predict(image), layer2.predict(image)

# 打印中间层的输出形状
print(visual_layer1.shape)
print(visual_layer2.shape)

第一层卷积 conv1 和 pool1

# 创建一个 10x6 的大图,用于显示多个特征图
plt.figure(figsize=(10, 6))

# 循环遍历每个卷积核的特征图
for i in range(30):
    # 在大图中创建子图,6行5列,i+1 表示子图的位置
    plt.subplot(6, 5, i+1)
    
    # 显示特征图,使用 'jet' 色彩映射以增强可视化效果
    plt.imshow(visual_layer1[0, :, :, i], cmap=plt.get_cmap('jet'))
    
    # 关闭坐标轴
    plt.axis('off')

# 显示整个图像
plt.show()

这段代码通过循环遍历第一层卷积层的每个卷积核(共30个),并将其特征图可视化显示出来。每个特征图都以不同的颜色显示,通过色彩映射(‘jet’)可以增强特征图的可视化效果。这有助于理解卷积层在图像中检测到的不同特征或模式。

通过在 plt.subplot() 中设置合适的行列数和位置,可以将多个特征图显示在同一图像中,以便一次性查看多个特征。这种可视化方法有助于深入了解神经网络的特征提取过程。

在这里插入图片描述

第二层卷积 conv2 和 pool2

plt.figure(figsize=(10,6))
for i in range(15):
  plt.subplot(3,5,i+1)
  plt.imshow(visual_layer2[0,:,:,i],cmap=plt.get_cmap('jet'))
  plt.axis('off')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/997901.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从零开始的PICO教程(4)--- UI界面绘制与响应事件

从零开始的PICO教程(4)— UI界面绘制与响应事件 文章目录 从零开始的PICO教程(4)--- UI界面绘制与响应事件一、前言1、大纲2、教程示例 二、具体步骤1、PICO VR环境配置2、XR的UI Canvas画布创建与调整(1)C…

【雷达原理】雷达信号级建模与仿真

目录 前言一、LFMCW信号概述1.1 优点1.2 缺点 二、LFMCW信号模型2.1 发射信号模型2.2 接收信号模型2.3 信号混频 三、MATLAB仿真3.1 仿真结果3.2 代码 四、参考文献 前言 雷达信号形式多种多样,按照雷达的体制进行分类,有脉冲雷达和连续波雷达。脉冲雷达…

C#程序到底从哪里开始看,从Main函数开始,那么Main函数是什么?

视觉人机器视觉粉丝问我,拿到自己公司得架构,问我,C#程序到底从哪里看,从Main函数开始,那么Main函数是什么? Main()函数 Main()是C#应用程序的入口点,执行这个函数就是执行应用程序。也就是说,在执行过程开始时,会执行Main()函数,在Main()函数执行完毕时,执行过…

微信小程序上拉触底事件

一、什么是上拉触底事件 上拉触底是移动端的专有名词,通过手指在屏幕上的上拉滑动操作,从而加载更多数据的行为。 二、监听上拉触底事件 在页面的.js文件中,通过onReachBottom()函数即可监听当前页面的上拉触底事件。 三、配置上拉触底距…

vue组件库开发,webpack打包,发布npm

做一个像elment-ui一样的vue组件库 那多好啊!这是我前几年就想做的 但webpack真的太难用,也许是我功力不够 今天看到一个视频,早上6-13点,终于实现了,呜呜 感谢视频的分享-来龙去脉-大家可以看这个视频:htt…

【List篇】ArrayList 的线程不安全介绍

ArrayList 为什么线程不安全? 主要原因是ArrayList是非同步的,没有同步机制,并且其底层实现是基于数组,而数组的长度是固定的。当对 ArrayList 进行增删操作时,需要改变数组的长度,这就会导致多个线程可能同时操作同一个数组&…

Unlikely argument type for equals(): int seems to be unrelated to String

前面字符串 后面数值 if (new Integer(2).equals(loginUser.getStatus())) 或者另外定义一个吧

JAVASE 窗口

本文目录 1、前言2、JFrame、JButton3、JLabl4、ImageIcon 1、前言 java提供了很多已经写好了的类供我们使用,而我们没必要去细腻研究它的构成原理,就好比我们让我们编程让机器人动起来,没必要细腻研究机器人每个器件是怎么做出来的一样&…

Qt Designer UI设计布局小结

目录 前言1 居中布局2 左右布局3 上下布局4 复杂页面布局总结 前言 本文总结了在开发Qt应用程序时使用 Designer 进行UI布局的一些心得体会。Qt Designer是Qt提供的一个可视化界面设计工具,旨在帮助开发人员快速创建和布局用户界面。它提供了丰富的布局管理器和控件…

系统架构设计专业技能 · 计算机组成与结构

现在的一切都是为将来的梦想编织翅膀,让梦想在现实中展翅高飞。 Now everything is for the future of dream weaving wings, let the dream fly in reality. 点击进入系列文章目录 系统架构设计高级技能 计算机组成与结构 一、计算机结构1.1 CPU 组成1.2 冯诺依曼…

【数据结构-队列】阻塞队列

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kuan 的首页,持续学…

vue学习之 v-for key

v-for key Vue使用 v-for渲染的元素列表时&#xff0c;它默认使用“就地更新”的策略。如果数据项的顺序被改变&#xff0c;Vue 将不会移动 DOM元素来匹配数据项的顺序&#xff0c;而是就地更新每个元素。创建 demo9.html,内容如下 <!DOCTYPE html> <html lang"…

60、RESTful 的高级配置---HttpMessageConverter

★ HttpMessageConverter的作用 RequestBody修饰处理方法的参数&#xff0c;如获取json格式的数据&#xff0c;将json格式的数据转换成我们需要的java对象&#xff0c; ResponseBody 这些把对象转成json格式响应给前端&#xff0c; 底层都是由这个HttpMessageConverter类实现的…

【Redis专题】大厂生产级Redis高并发分布式锁实战

目录 前言课程内容一、一个案例引发的思考二、Redis分布式锁的演进2.1 单纯使用Redis的setnx实现分布式锁2.2 setnx 过期时间3.3 Redisson实现分布式锁&#xff1a;setnx 过期时间 锁续命 三、Redisson客户端实现的分布式锁及源码分析 学习总结 前言 Redis中间件&#xff0…

文件上传之图片码混淆绕过(upload的16,17关)

目录 1.upload16关 1.上传gif loadup17关&#xff08;文件内容检查&#xff0c;图片二次渲染&#xff09; 1.上传gif&#xff08;同上面步骤相同&#xff09; 2.条件竞争 1.upload16关 1.上传gif imagecreatefromxxxx函数把图片内容打散&#xff0c;&#xff0c;但是不会…

Selenium - Tracy 小笔记2

selenium本身是一个自动化测试工具。 它可以让python代码调用浏览器。并获取到浏览器中加们可以利用selenium提供的各项功能。帮助我们完成数据的抓取。它容易被网站识别到&#xff0c;所以有些网站爬不到。 它没有逻辑&#xff0c;只有相应的函数&#xff0c;直接搜索即可 …

dubbo 服务注册使用了内网IP,而服务调用需要使用公网IP进行调用

一、问题描述&#xff1a; 使用dubbo时&#xff0c;提供者注册时显示服务地址ip为[内网IP:20880]&#xff0c;导致其他消费者在外部连接的情况下时&#xff0c;调用dubbo服务失败 二、解决办法 方法一、修改hosts文件 &#xff08;1&#xff09;. 先查询一下服务器的hostna…

【动态规划刷题 13】最长递增子序列 摆动序列

300. 最长递增子序列 链接: 300. 最长递增子序列 1.状态表示* dp[i] 表⽰&#xff1a;以 i 位置元素为结尾的「所有⼦序列」中&#xff0c;最⻓递增⼦序列的⻓度。 2.状态转移方程 对于 dp[i] &#xff0c;我们可以根据「⼦序列的构成⽅式」&#xff0c;进⾏分类讨论&#…

RabbitMQ管控台使用

安装成功RabbitMQ后&#xff0c;进入到管理控制台界面 拷贝配置文件到指定目录当中然后重启RabbitMQ。

FIR滤波器简述及FPGA仿真验证

数字滤波器的设计&#xff0c;本项目做的数字滤波器准确来说是FIR滤波器。 FIR滤波器&#xff08;有限冲激响应滤波器&#xff09;&#xff0c;与另一种基本类型的数字滤波器——IIR滤波器&#xff08;无限冲击响应滤波器&#xff09;相对应&#xff0c;其实就是将所输入的信号…