python与深度学习(七):CNN和fashion_mnist

news2025/1/9 1:38:52

目录

  • 1. 说明
  • 2. fashion_mnist实战
  • 2.1 导入相关库
    • 2.2 加载数据
    • 2.3 数据预处理
    • 2.4 数据处理
    • 2.5 构建网络模型
    • 2.6 模型编译
    • 2.7 模型训练
    • 2.8 模型保存
    • 2.9 模型评价
    • 2.10 模型测试
    • 2.11 模型训练结果的可视化
  • 3. fashion_mnist的CNN模型可视化结果图
  • 4. 完整代码

1. 说明

本篇文章是CNN的另外一个例子,fashion_mnist,是衣服类的数据集。
可以搭建和手写数字识别的一样的模神经网络来训练模型。

2. fashion_mnist实战

2.1 导入相关库

以下第三方库是python专门用于深度学习的库

# 导入tensorflow
import tensorflow as tf
# 导入keras
from tensorflow import keras
from keras.datasets import fashion_mnist
# 引入绘制acc和loss曲线的库
import matplotlib.pyplot as plt
# 引入ANN的必要的类
from keras.layers import Dense, Conv2D, MaxPool2D, Flatten
from keras.models import Sequential
from keras import optimizers, losses

2.2 加载数据

把fashion_mnist数据集进行加载

"1.加载数据"
"""
x_train是fashion_mnist训练集图片,大小的28*28的,y_train是对应的标签是数字
x_test是fashion_mnist测试集图片,大小的28*28的,y_test是对应的标签是数字
"""
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()  # 加载fashion_mnist数据集
print('mnist_data:', x_train.shape, y_train.shape, x_test.shape, y_test.shape)  # 打印训练数据和测试数据的形状

2.3 数据预处理

(1) 将输入的图片进行归一化,从0-255变换到0-1;
(2) 将输入图片的形状(60000,28,28)转换成(60000,28,28,1),相当于将图片拉直,便于输入给神经网络;
(3) 将标签y进行独热编码,因为神经网络的输出是10个概率值,而y是1个数, 计算loss时无法对应计算,因此将y进行独立编码成为10个数的行向量,然后进行loss的计算 独热编码:例如数值1的10分类的独热编码是[0 1 0 0 0 0 0 0 0 0,即1的位置为1,其余位置为0。

"2.数据预处理"


def preprocess(x, y):  # 数据预处理函数
    x = tf.cast(x, dtype=tf.float32) / 255.  # 将输入的图片进行归一化,从0-255变换到0-1
    x = tf.reshape(x, [28, 28, 1])
    """
    # 将输入图片的形状(60000,28,28)转换成(60000,28,28,1),
    相当于将图片拉直,便于输入给神经网络
    """
    y = tf.cast(y, dtype=tf.int32)  # 将输入图片的标签转换为int32类型
    y = tf.one_hot(y, depth=10)
    """
    # 将标签y进行独热编码,因为神经网络的输出是10个概率值,而y是1个数,
    计算loss时无法对应计算,因此将y进行独立编码成为10个数的行向量,然后进行loss的计算
    独热编码:例如数值1的10分类的独热编码是[0 1 0 0 0 0 0 0 0 0,即1的位置为1,其余位置为0
    """
    return x, y

2.4 数据处理

数据加载进入内存后,需要转换成 Dataset 对象,才能利用 TensorFlow 提供的各种便捷功能。
通过 Dataset.from_tensor_slices 可以将训练部分的数据图片 x 和标签 y 都转换成Dataset 对象

batchsz = 128  # 每次输入给神经网络的图片数
"""
数据加载进入内存后,需要转换成 Dataset 对象,才能利用 TensorFlow 提供的各种便捷功能。
通过 Dataset.from_tensor_slices 可以将训练部分的数据图片 x 和标签 y 都转换成Dataset 对象
"""
db = tf.data.Dataset.from_tensor_slices((x_train, y_train))  # 构建训练集对象
db = db.map(preprocess).shuffle(60000).batch(batchsz)  # 将数据进行预处理,随机打散和批量处理
ds_val = tf.data.Dataset.from_tensor_slices((x_test, y_test))  # 构建测试集对象
ds_val = ds_val.map(preprocess).batch(batchsz)  # 将数据进行预处理,随机打散和批量处理

2.5 构建网络模型

构建了两层卷积层,两层池化层,然后是展平层(将二维特征图拉直输入给全连接层),然后是三层全连接层。

"3.构建网络模型"
model = Sequential([Conv2D(filters=6, kernel_size=(5, 5), activation='relu'),
                    MaxPool2D(pool_size=(2, 2), strides=2),
                    Conv2D(filters=16, kernel_size=(5, 5), activation='relu'),
                    MaxPool2D(pool_size=(2, 2), strides=2),
                    Flatten(),
                    Dense(120, activation='relu'),
                    Dense(84, activation='relu'),
                    Dense(10,activation='softmax')])

model.build(input_shape=(None, 28, 28, 1))  # 模型的输入大小
model.summary()  # 打印网络结构

2.6 模型编译

模型的优化器是Adam,另外一种优化方法,学习率是0.01,
损失函数是losses.CategoricalCrossentropy,多分类交叉熵,
性能指标是正确率accuracy。

"4.模型编译"
model.compile(optimizer='Adam',
              loss=losses.CategoricalCrossentropy(from_logits=False),
              metrics=['accuracy']
                )
"""
模型的优化器是Adam
损失函数是losses.CategoricalCrossentropy,
性能指标是正确率accuracy
"""

2.7 模型训练

模型训练的次数是20,每1次循环进行测试

"5.模型训练"
history = model.fit(db, epochs=20, validation_data=ds_val, validation_freq=1)
"""
模型训练的次数是20,每1次循环进行测试
"""

2.8 模型保存

以.h5文件格式保存模型

"6.模型保存"
model.save('cnn_fashion.h5')  # 以.h5文件格式保存模型

2.9 模型评价

得到测试集的正确率

"7.模型评价"
model.evaluate(ds_val)  # 得到测试集的正确率

2.10 模型测试

对模型进行测试

"8.模型测试"
sample = next(iter(ds_val))  # 取一个batchsz的测试集数据
x = sample[0]  # 测试集数据
y = sample[1]  # 测试集的标签
pred = model.predict(x)  # 将一个batchsz的测试集数据输入神经网络的结果
pred = tf.argmax(pred, axis=1)  # 每个预测的结果的概率最大值的下标,也就是预测的数字
y = tf.argmax(y, axis=1)  # 每个标签的最大值对应的下标,也就是标签对应的数字
print(pred)  # 打印预测结果
print(y)  # 打印标签数字

2.11 模型训练结果的可视化

对模型的训练结果进行可视化

"9.模型训练时的可视化"
# 显示训练集和验证集的acc和loss曲线
acc = history.history['accuracy']  # 获取模型训练中的accuracy
val_acc = history.history['val_accuracy']  # 获取模型训练中的val_accuracy
loss = history.history['loss']  # 获取模型训练中的loss
val_loss = history.history['val_loss']  # 获取模型训练中的val_loss
# 绘值acc曲线
plt.figure(1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
# 绘制loss曲线
plt.figure(2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()  # 将结果显示出来

3. fashion_mnist的CNN模型可视化结果图

Epoch 1/20
469/469 [==============================] - 13s 21ms/step - loss: 0.6610 - accuracy: 0.7600 - val_loss: 0.5067 - val_accuracy: 0.8097
Epoch 2/20
469/469 [==============================] - 12s 24ms/step - loss: 0.4408 - accuracy: 0.8375 - val_loss: 0.4202 - val_accuracy: 0.8491
Epoch 3/20
469/469 [==============================] - 12s 24ms/step - loss: 0.3844 - accuracy: 0.8595 - val_loss: 0.3868 - val_accuracy: 0.8605
Epoch 4/20
469/469 [==============================] - 13s 27ms/step - loss: 0.3507 - accuracy: 0.8707 - val_loss: 0.3924 - val_accuracy: 0.8533
Epoch 5/20
469/469 [==============================] - 13s 27ms/step - loss: 0.3272 - accuracy: 0.8805 - val_loss: 0.3621 - val_accuracy: 0.8674
Epoch 6/20
469/469 [==============================] - 14s 27ms/step - loss: 0.3106 - accuracy: 0.8859 - val_loss: 0.3436 - val_accuracy: 0.8728
Epoch 7/20
469/469 [==============================] - 14s 28ms/step - loss: 0.2923 - accuracy: 0.8923 - val_loss: 0.3429 - val_accuracy: 0.8732
Epoch 8/20
469/469 [==============================] - 14s 29ms/step - loss: 0.2821 - accuracy: 0.8962 - val_loss: 0.3268 - val_accuracy: 0.8802
Epoch 9/20
469/469 [==============================] - 14s 28ms/step - loss: 0.2714 - accuracy: 0.8994 - val_loss: 0.3208 - val_accuracy: 0.8832
Epoch 10/20
469/469 [==============================] - 14s 28ms/step - loss: 0.2621 - accuracy: 0.9031 - val_loss: 0.3187 - val_accuracy: 0.8822
Epoch 11/20
469/469 [==============================] - 12s 24ms/step - loss: 0.2517 - accuracy: 0.9059 - val_loss: 0.3154 - val_accuracy: 0.8870
Epoch 12/20
469/469 [==============================] - 12s 25ms/step - loss: 0.2418 - accuracy: 0.9098 - val_loss: 0.3058 - val_accuracy: 0.8928
Epoch 13/20
469/469 [==============================] - 12s 25ms/step - loss: 0.2344 - accuracy: 0.9129 - val_loss: 0.3182 - val_accuracy: 0.8885
Epoch 14/20
469/469 [==============================] - 12s 24ms/step - loss: 0.2277 - accuracy: 0.9150 - val_loss: 0.3212 - val_accuracy: 0.8824
Epoch 15/20
469/469 [==============================] - 12s 24ms/step - loss: 0.2190 - accuracy: 0.9177 - val_loss: 0.2903 - val_accuracy: 0.8981
Epoch 16/20
469/469 [==============================] - 12s 24ms/step - loss: 0.2141 - accuracy: 0.9198 - val_loss: 0.3071 - val_accuracy: 0.8895
Epoch 17/20
469/469 [==============================] - 12s 25ms/step - loss: 0.2091 - accuracy: 0.9211 - val_loss: 0.3042 - val_accuracy: 0.8897
Epoch 18/20
469/469 [==============================] - 13s 26ms/step - loss: 0.2018 - accuracy: 0.9239 - val_loss: 0.2985 - val_accuracy: 0.8973
Epoch 19/20
469/469 [==============================] - 13s 26ms/step - loss: 0.1942 - accuracy: 0.9275 - val_loss: 0.2867 - val_accuracy: 0.9026
Epoch 20/20
469/469 [==============================] - 12s 25ms/step - loss: 0.1887 - accuracy: 0.9286 - val_loss: 0.3019 - val_accuracy: 0.9005
79/79 [==============================] - 1s 9ms/step - loss: 0.3019 - accuracy: 0.9005

在这里插入图片描述
在这里插入图片描述
从以上结果可知,模型的准确率达到了90%

4. 完整代码

# python练习
# 重新学习时间:2023/5/2 14:46
# 导入tensorflow
import tensorflow as tf
# 导入keras
from tensorflow import keras
from keras.datasets import fashion_mnist
# 引入绘制acc和loss曲线的库
import matplotlib.pyplot as plt
# 引入ANN的必要的类
from keras.layers import Dense, Conv2D, MaxPool2D, Flatten
from keras.models import Sequential
from keras import optimizers, losses

"1.加载数据"
"""
x_train是fashion_mnist训练集图片,大小的28*28的,y_train是对应的标签是数字
x_test是fashion_mnist测试集图片,大小的28*28的,y_test是对应的标签是数字
"""
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()  # 加载fashion_mnist数据集
print('mnist_data:', x_train.shape, y_train.shape, x_test.shape, y_test.shape)  # 打印训练数据和测试数据的形状

"2.数据预处理"


def preprocess(x, y):  # 数据预处理函数
    x = tf.cast(x, dtype=tf.float32) / 255.  # 将输入的图片进行归一化,从0-255变换到0-1
    x = tf.reshape(x, [28, 28, 1])
    """
    # 将输入图片的形状(60000,28,28)转换成(60000,28,28,1),
    相当于将图片拉直,便于输入给神经网络
    """
    y = tf.cast(y, dtype=tf.int32)  # 将输入图片的标签转换为int32类型
    y = tf.one_hot(y, depth=10)
    """
    # 将标签y进行独热编码,因为神经网络的输出是10个概率值,而y是1个数,
    计算loss时无法对应计算,因此将y进行独立编码成为10个数的行向量,然后进行loss的计算
    独热编码:例如数值1的10分类的独热编码是[0 1 0 0 0 0 0 0 0 0,即1的位置为1,其余位置为0
    """
    return x, y


batchsz = 128  # 每次输入给神经网络的图片数
"""
数据加载进入内存后,需要转换成 Dataset 对象,才能利用 TensorFlow 提供的各种便捷功能。
通过 Dataset.from_tensor_slices 可以将训练部分的数据图片 x 和标签 y 都转换成Dataset 对象
"""
db = tf.data.Dataset.from_tensor_slices((x_train, y_train))  # 构建训练集对象
db = db.map(preprocess).shuffle(60000).batch(batchsz)  # 将数据进行预处理,随机打散和批量处理
ds_val = tf.data.Dataset.from_tensor_slices((x_test, y_test))  # 构建测试集对象
ds_val = ds_val.map(preprocess).batch(batchsz)  # 将数据进行预处理,随机打散和批量处理

"3.构建网络模型"
model = Sequential([Conv2D(filters=6, kernel_size=(5, 5), activation='relu'),
                    MaxPool2D(pool_size=(2, 2), strides=2),
                    Conv2D(filters=16, kernel_size=(5, 5), activation='relu'),
                    MaxPool2D(pool_size=(2, 2), strides=2),
                    Flatten(),
                    Dense(120, activation='relu'),
                    Dense(84, activation='relu'),
                    Dense(10,activation='softmax')])

model.build(input_shape=(None, 28, 28, 1))  # 模型的输入大小
model.summary()  # 打印网络结构

"4.模型编译"
model.compile(optimizer='Adam',
              loss=losses.CategoricalCrossentropy(from_logits=False),
              metrics=['accuracy']
                )
"""
模型的优化器是Adam
损失函数是losses.CategoricalCrossentropy,
性能指标是正确率accuracy
"""

"5.模型训练"
history = model.fit(db, epochs=20, validation_data=ds_val, validation_freq=1)
"""
模型训练的次数是20,每1次循环进行测试
"""
"6.模型保存"
model.save('cnn_fashion.h5')  # 以.h5文件格式保存模型

"7.模型评价"
model.evaluate(ds_val)  # 得到测试集的正确率

"8.模型测试"
sample = next(iter(ds_val))  # 取一个batchsz的测试集数据
x = sample[0]  # 测试集数据
y = sample[1]  # 测试集的标签
pred = model.predict(x)  # 将一个batchsz的测试集数据输入神经网络的结果
pred = tf.argmax(pred, axis=1)  # 每个预测的结果的概率最大值的下标,也就是预测的数字
y = tf.argmax(y, axis=1)  # 每个标签的最大值对应的下标,也就是标签对应的数字
print(pred)  # 打印预测结果
print(y)  # 打印标签数字

"9.模型训练时的可视化"
# 显示训练集和验证集的acc和loss曲线
acc = history.history['accuracy']  # 获取模型训练中的accuracy
val_acc = history.history['val_accuracy']  # 获取模型训练中的val_accuracy
loss = history.history['loss']  # 获取模型训练中的loss
val_loss = history.history['val_loss']  # 获取模型训练中的val_loss
# 绘值acc曲线
plt.figure(1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
# 绘制loss曲线
plt.figure(2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()  # 将结果显示出来

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/787570.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Install the Chinese input method on Linux

Open terminal and input: sudo -i apt install fcitx fcitx-googlepinyinWait for it to finish. Search fcitx: "设置"-->"输入法": Finally, we get the following result: Ctrl Space:Switch the input method. The test …

Redis追本溯源(三)内核:线程模型、网络IO模型、过期策略与淘汰机制、持久化

文章目录 一、Redis线程模型演化1.Redis4.0之前2.Redis4.0之后单线程、多线程对比3.redis 6.0之后 二、Redis的网络IO模型1.基于事件驱动的Reactor模型2.什么是事件驱动,事件驱动的Reactor模型和Java中的AIO有什么区别3.异步非阻塞底层实现原理 三、Redis过期策略1.…

印刷和数字设计的页面布局软件 QuarkXPress 2023 Crack

QuarkXPress 2023 用于印刷 和数字设计的页面布局软件,使用 QuarkXPress 释放您的创造力并最大限度地提高生产力 图形设计和桌面出版流程早就应该进行创新和颠覆,所以 QuarkXPress 就来了。自 1987 年首次亮相市场以来,成千上万的创意专业人士…

RocketMQ教程-(5)-功能特性-事务消息

事务消息为 Apache RocketMQ 中的高级特性消息,本文为您介绍事务消息的应用场景、功能原理、使用限制、使用方法和使用建议。 事务消息为 Apache RocketMQ 中的高级特性消息,本文为您介绍事务消息的应用场景、功能原理、使用限制、使用方法和使用建议。…

FFmpeg AVFilter的原理(三)- filter是如何被驱动的

1、下面是一个avfilter的graph 上图是ffmpeg中doc/examples中filtering_video.c案例的示意图。 本章节主要查看avfilter中的数据是怎么进入的,然后又是怎么出来的。 主要考察两个函数: av_buffersrc_add_frame_flags()av_buffers…

gcc编译的时候出现错误,可以用core查看错误信息

比如说我们有文件main.c,threadpool.c,threadpool.h main.c和threadpool.c都用了threadpool.h,也就是#include "threadpool.h" (1)如果我们直接使用gcc main.c -o a.out -lpthread会报如下的错 我们需要进行动态库链接 gcc -c threadpool.c -…

驱动开发 day3 (模块化驱动启动led,蜂鸣器,风扇,震动马达)

模块化驱动启动led,蜂鸣器,风扇,震动马达并加上Makefile 封装模块化驱动,可自由安装卸载驱动,便于驱动更新(附图) 1.安装模块驱动同时初始化各个设备并使能 2.该驱动会自动创建驱动节点. 3.通过c函数程序输入控制各个设备 4.卸载模块驱动 //编译驱动…

裂缝二维检测:裂缝类型判断

裂缝类型选择 裂缝类型有很多种,这里我们判断类型的目的是要搞明白是否有必要检测裂缝的长度。在本文中,需要判断的裂缝类型共有四种:横向裂缝、纵向裂缝、斜裂缝、网状裂缝。 环境搭建 上一节骨架图提取部分,我们已经安装了sk…

Linux centos安装openoffice在线预览

前言:由于项目里需要用到word、excel等文件的在线预览,所有选择了openoffice 1、下载openoffice Apache OpenOffice - Official Download 大家自行选择需要安装的版本,楼主由于之前在其他服务器安装过,选择了之前用过的版本&am…

4.Docker数据管理和容器互联

文章目录 Docker数据管理数据卷(容器与宿主机之间数据共享)数据卷容器(容器与容器之间数据共享)容器互联 Docker数据管理 数据卷(容器与宿主机之间数据共享) 数据卷是一个供容器使用的特殊目录&#xff0…

【UE5 多人联机教程】03-创建游戏

效果 步骤 打开“UMG_MainMenu”,增加创建房间按钮的点击事件 添加如下节点 其中,“FUNL Fast Create Widget”是插件自带的函数节点,内容如下: “创建会话”节点指游戏成功创建一个会话后,游戏的其他实例即可发现&am…

mysql -速成

目录 1.概述 1.3SQL的优点 1.4 SQL 语言的分类 2. 软件的安装与启动 2.1 安装 2.2 MySQL服务的启动和停止 2.3 MySQL服务的登录和退出 ​编辑 2.4 mysql常用命令 2.5 图形化用户结构Sqlyong 3.DQL 语言 3.1 基础查询 3.1.1、语法 3.1.2 特点 3.2 条件查询 3.2.1 …

数据库的聚合函数和窗口函数

1. 聚合函数 数据库的聚合函数是用于对数据集执行聚合计算的函数。它们将一组值作为输入,并生成单个聚合值作为输出。聚合函数通常与GROUP BY子句结合使用,以便在数据分组的基础上执行聚合操作。 1.1. 常用的聚合函数 COUNT():计算指定列或…

(五)springboot实战——springboot自定义事件的发布和订阅

前言 本节内容我们主要介绍一下springboot自定义事件的发布与订阅功能,一些特定应用场景下使用自定义事件发布功能,能大大降低我们代码的耦合性,使得我们应用程序的扩展更加方便。就本身而言,springboot的事件机制是通过观察者设…

Python(三十九)for-in循环

❤️ 专栏简介:本专栏记录了我个人从零开始学习Python编程的过程。在这个专栏中,我将分享我在学习Python的过程中的学习笔记、学习路线以及各个知识点。 ☀️ 专栏适用人群 :本专栏适用于希望学习Python编程的初学者和有一定编程基础的人。无…

JAVA设计模式——模板设计模式(heima)

JAVA设计模式——模板设计模式(heima) 文章目录 JAVA设计模式——模板设计模式(heima)一、模板类二、子类2.1 Tom类2.2 Tony类 三、测试类 一、模板类 package _01模板设计模式;public abstract class TextTemplate{public final…

利用FME实现批量提取图斑特征点、关键界址点提取、图斑拐点抽稀,解决出界址点成果表时点数过多问题的方法

目录 一、实现效果 二、实现过程 1.提取图斑界址点 2.计算各界址点的角度 3.筛选提取关键界址点 三、总结 对于范围较大的图斑,界址点数目较大,在出界址点成果表前,往往需要对界址点进行处理,提取出关键特征点作为出界址点成…

数据库集群方案详解

本期直播我们邀请 KaiwuDB 资深解决方案专家周幸骏,为大家分享数据库集群方案详解。周老师毕业于复旦大学数学系,从业 20 余年,曾在 IBM 公司任资深技术专家,并为多家国有大型商业银行提供技术咨询和数据库业务连续方案设计等服务…

IBM:2023 年数据泄露的平均成本将达到 445 万美元

IBM 发布年度《数据泄露成本报告》,显示 2023 年全球数据泄露平均成本达到 445 万美元,比过去 3 年增加了 15%。创下该报告的历史新高。 报告显示,企业在计划如何应对日益增长的数据泄露频率和成本方面存在分歧。研究发现,虽然 95…

Linux学成之路(基础篇)(二十三)MySQL服务(上)

目录 一、概述 一、什么是MySQL 二、数据库能干什么 三、为什么要用数据库,优势、特性? 二、数据库类型 一、关系型数据库 RDBMS 一、概述 二、特点 三、代表产品 二、非关系型数据库 一、概述 二、特点 三、代表产品 三、数据库模型 一、…