基于 Tensorflow 2.x 实现多层卷积神经网络,实践 Fashion MNIST 服装图像识别

news2024/11/24 17:53:47

一、 Fashion MNIST 服装数据集

Fashion MNIST 数据集,该数据集包含 10 个类别的 70000 个灰度图像。大小统一是 28x28的长宽,其中 60000 张作为训练数据,10000张作为测试数据,该数据集已被封装在了 tf.keras.datasets 工具包下,数据如图所示:

在这里插入图片描述
Fashion MNIST 数据集更多样化,比常规 MNIST 更具挑战性。标签是整数数组,介于 09 之间。这些标签对应于图像所代表的服装类:

标签分类
0T恤/上衣
1裤子
2套头衫
3连衣裙
4外套
5凉鞋
6衬衫
7运动鞋
8
9短靴

可以通过下面程序对该数据进行可视化预览:

import tensorflow as tf
import matplotlib.pyplot as plt

keras = tf.keras
fashion_mnist = tf.keras.datasets.fashion_mnist
plt.rcParams['font.sans-serif'] = ['SimHei']

classify = {
    0: 'T恤/上衣',
    1: '裤子',
    2: '套头衫',
    3: '连衣裙',
    4: '外套',
    5: '凉鞋',
    6: '衬衫',
    7: '运动鞋',
    8: '包',
    9: '短靴'
}

# 加载 fashion_mnist 数据集
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

print(x_train.shape)
print(y_train.shape)

for i in range(10):
    image = x_test[i]/255.0
    label = y_test[i]
    plt.imshow(image, cmap=plt.cm.gray)
    plt.title(label=('标签值: ' + str(classify[label])))
    plt.show()

在这里插入图片描述

二、搭建多层卷积神经网络模型

本文基于 Tensorflow 2.x 构建多层卷积神经网络,在 Tensorflow 2.x 中官方更推荐的上层API工具为 Keras ,本文也是使用 Keras 进行实验测试。

设计模型结构如下所示:

在这里插入图片描述

通过 Keras 建立模型结构,具体的解释都写在了注释中:

import tensorflow as tf

keras = tf.keras
fashion_mnist = tf.keras.datasets.fashion_mnist

# 定义模型类
class mnistModel():
    # 初始化结构
    def __init__(self, checkpoint_path, log_path, model_path):
        # checkpoint 权重保存地址
        self.checkpoint_path = checkpoint_path
        # 训练日志保存地址
        self.log_path = log_path
        # 训练模型保存地址:
        self.model_path = model_path
        # 初始化模型结构
        self.model = tf.keras.models.Sequential([
            # 输入层,第一层卷积 ,卷积核 3x3 ,输出 (None, 28, 28, 32),卷积模式 same
            keras.layers.Conv2D(32, (3, 3),
                                kernel_initializer=keras.initializers.truncated_normal(stddev=0.05),
                                activation=tf.nn.relu,
                                kernel_regularizer=keras.regularizers.l2(0.001),
                                padding='same',
                                input_shape=(28, 28, 1)),
            # 输入层,第二层卷积 ,卷积核 3x3 ,输出 (None, 28, 28, 64),卷积模式 same
            keras.layers.Conv2D(32, (3, 3),
                                kernel_initializer=keras.initializers.truncated_normal(stddev=0.05),
                                activation=tf.nn.relu,
                                kernel_regularizer=keras.regularizers.l2(0.001),
                                padding='same',
                                input_shape=(28, 28, 1)),
            # 池化层,2x2 MaxPool,输出 (None, 14, 14, 64)
            keras.layers.MaxPooling2D(2, 2),

            # 第三层卷积 ,卷积核 3x3 ,输出 (None, 14, 14, 128) ,卷积模式 same
            keras.layers.Conv2D(64, (3, 3),
                                kernel_initializer=keras.initializers.truncated_normal(stddev=0.05),
                                activation=tf.nn.relu,
                                kernel_regularizer=keras.regularizers.l2(0.001),
                                padding='same'),
            # 第四层卷积 ,卷积核 3x3 ,输出 (None, 14, 14, 256) ,卷积模式 same
            keras.layers.Conv2D(64, (3, 3),
                                kernel_initializer=keras.initializers.truncated_normal(stddev=0.05),
                                activation=tf.nn.relu,
                                kernel_regularizer=keras.regularizers.l2(0.001),
                                padding='same'),
            # 池化层,2x2 MaxPool,输出 (None, 7, 7, 256)
            keras.layers.MaxPooling2D(2, 2),
            # Dropout 随机失活,防止过拟合,输出 (None, 7, 7, 256)
            keras.layers.Dropout(0.2),
            # 转为全链接层,输出 (None, 12544)
            keras.layers.Flatten(),
            # 第一层全链接层,输出 (None, 512)
            keras.layers.Dense(512,
                               kernel_initializer=keras.initializers.truncated_normal(stddev=0.05),
                               kernel_regularizer=keras.regularizers.l2(0.001),
                               activation=tf.nn.relu),
            # 第二层全链接层,输出 (None, 256)
            keras.layers.Dense(256,
                               kernel_initializer=keras.initializers.truncated_normal(stddev=0.05),
                               kernel_regularizer=keras.regularizers.l2(0.001),
                               activation=tf.nn.relu),
            # softmax 层,输出 (None, 10)
            keras.layers.Dense(10, activation=tf.nn.softmax)
        ])

    # 编译模型
    def compile(self):
        # 输出模型摘要
        self.model.summary()
        # 定义训练模式
        self.model.compile(optimizer='adam',
                           loss='sparse_categorical_crossentropy',
                           metrics=['accuracy'])

    # 训练模型
    def train(self, x_train, y_train):
        # tensorboard 训练日志收集
        tensorboard = keras.callbacks.TensorBoard(log_dir=self.log_path)

        # 训练过程保存 Checkpoint 权重,防止意外停止后可以继续训练
        model_checkpoint = keras.callbacks.ModelCheckpoint(self.checkpoint_path,  # 保存模型的路径
                                                           monitor='val_loss',  # 被监测的数据。
                                                           verbose=0,  # 详细信息模式,0 或者 1
                                                           save_best_only=True,  # 如果 True, 被监测数据的最佳模型就不会被覆盖
                                                           save_weights_only=True,
                                                           # 如果 True,那么只有模型的权重会被保存 (model.save_weights(filepath)),否则的话,整个模型会被保存,(model.save(filepath))
                                                           mode='auto',
                                                           # {auto, min, max}的其中之一。 如果 save_best_only=True,那么是否覆盖保存文件的决定就取决于被监测数据的最大或者最小值。 对于 val_acc,模式就会是 max,而对于 val_loss,模式就需要是 min,等等。 在 auto模式中,方向会自动从被监测的数据的名字中判断出来。
                                                           period=3  # 每3个epoch保存一次权重
                                                           )
        # 填充数据,迭代训练
        self.model.fit(
            x_train,  # 训练集
            y_train,  # 训练集的标签
            validation_split=0.2,  # 验证集的比例
            epochs=30,  # 迭代周期
            batch_size=30,  # 一批次输入的大小
            verbose=2,  # 训练过程的日志信息显示,一个epoch输出一行记录
            callbacks=[tensorboard, model_checkpoint]
        )
        # 保存训练模型
        self.model.save(self.model_path)

    def evaluate(self, x_test, y_test):
        # 评估模型
        test_loss, test_acc = self.model.evaluate(x_test, y_test)
        return test_loss, test_acc

上面优化器使用的 adamlosssparse_categorical_crossentropy,一共训练 30 个周期,每个 batch 30 张图片,验证集的比例为 20%,并且每三个周期保存一次权重,防止意外停止后继续训练,最后保存了 h5 的训练模型,方便后面进行测试预测效果。

下面开始训练模型:

import tensorflow as tf

keras = tf.keras
fashion_mnist = tf.keras.datasets.fashion_mnist

def main():
    # 加载数据集
    (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
    # 修改shape 数据归一化
    x_train, x_test = x_train.reshape(60000, 28, 28, 1) / 255.0, \
                      x_test.reshape(10000, 28, 28, 1) / 255.0

    checkpoint_path = './checkout/'
    log_path = './log'
    model_path = './model/model.h5'

    # 构建模型
    model = mnistModel(checkpoint_path, log_path, model_path)
    # 编译模型
    model.compile()
    # 训练模型
    model.train(x_train, y_train)
    # 评估模型
    test_loss, test_acc = model.evaluate(x_test, y_test)
    print(test_loss, test_acc)

if __name__ == '__main__':
    main()

运行后可以看到打印的网络结构:

在这里插入图片描述

从训练日志中,可以看到 loss 一直在减小:

在这里插入图片描述

等待训练结束后看下评估模型的结果:

在这里插入图片描述

下面看下 tensorboard 中可视化的损失及准确率:

tensorboard --logdir=log/train

在这里插入图片描述
使用浏览器访问:http://localhost:6006/ 查看结果:

在这里插入图片描述

三、模型预测

上面搭建的模型,训练后会在 model 下生成 model.h5 模型,下面直接加载该模型进行预测:

import tensorflow as tf
import matplotlib.pyplot as plt

keras = tf.keras
fashion_mnist  = tf.keras.datasets.fashion_mnist
plt.rcParams['font.sans-serif'] = ['SimHei']

# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# 数据归一化
x_train, x_test = x_train / 255.0, x_test / 255.0

classify = {
    0: 'T恤/上衣',
    1: '裤子',
    2: '套头衫',
    3: '连衣裙',
    4: '外套',
    5: '凉鞋',
    6: '衬衫',
    7: '运动鞋',
    8: '包',
    9: '短靴'
}

model = keras.models.load_model('./model/model.h5')

for i in range(10):
    image = x_test[i]
    label = y_test[i]
    softmax = model.predict(image.reshape([1, 28, 28, 1]))
    y_label = tf.argmax(softmax, axis=1).numpy()[0]
    plt.imshow(image, cmap=plt.cm.gray)
    plt.title(label = ('预测结果: '+ classify[y_label] + ',  真实结果: '+ classify[label]))
    plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/99446.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

move functions with VS without noexcept

本文所讲对移动函数使用noexcept修饰时带来的效率提升只针对std::vector。而对std::deque来说没有功效。 1. 针对std::vector 1.1 move functions with noexcept 当移动构造函数有noexcept修饰时,在对std::vector进行push_back扩充致使vector的size等于capacity时…

26. GPU以及 没有gpu的情况下使用colab

在PyTorch中,CPU和GPU可以用torch.device(‘cpu’) 和torch.device(‘cuda’)表示。 应该注意的是,cpu设备意味着所有物理CPU和内存, 这意味着PyTorch的计算将尝试使用所有CPU核心。 然而,gpu设备只代表一个卡和相应的显存。 如果…

【大数据技术Hadoop+Spark】Spark SQL、DataFrame、Dataset的讲解及操作演示(图文解释)

一、Spark SQL简介 park SQL是spark的一个模块,主要用于进行结构化数据的SQL查询引擎,开发人员能够通过使用SQL语句,实现对结构化数据的处理,开发人员可以不了解Scala语言和Spark常用API,通过spark SQL,可…

数据挖掘Java——Kmeans算法的实现

一、K-means算法的前置知识 k-means算法,也被称为k-平均或k-均值,是一种得到最广泛使用的聚类算法。相似度的计算根据一个簇中对象的平均值来进行。算法首先随机地选择k个对象,每个对象初始地代表了一个簇的平均值或中心。对剩余的每个对象根…

给 VitePress 添加 algolia 搜索

大家好,我是 Chocolate。 最近在折腾 VitePress,搭建了一个文档项目:ChoDocs,不过文档还不支持搜索功能,虽然目前内容不多,但待我同步完之后,搜索就很有必要了。 之前看 VitePress 官网发现没有…

pikachu靶场暴力破解绕过token防护详解

今天继续给大家介绍渗透测试相关知识,本文主要内容是pikachu靶场暴力破解绕过token防护详解。 免责声明: 本文所介绍的内容仅做学习交流使用,严禁利用文中技术进行非法行为,否则造成一切严重后果自负! 再次强调&#x…

基于改进的多目标粒子群算法的微电网多目标调度(三个目标函数)(matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

中央重磅文件明确互联网医疗服务可用医保支付!

文章目录中央重磅文件明确互联网医疗服务可用医保支付!中央重磅文件明确互联网医疗服务可用医保支付! 当下,互联网医疗机构已加入到新冠防治的“主战场”,在分流线下诊疗发挥了很大作用。国家层面也在进一步鼓励互联网医疗行业发…

基于多尺度形态学梯度进行边缘检测(Matlab代码实现)

👨‍🎓个人主页:研学社的博客 💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜…

C++中的继承

把握住自己能把握住的点滴,把它做到极致,加油! 本节目标1.继承的概念及定义1.1继承的概念1.2 继承定义1.2.1 定义格式1.2.2 继承方式和访问限定符1.2.3 继承基类成员访问方式的变化2.继承中的作用域练习3.基类和派生类对象赋值转换4.派生类的…

Java+SSM网上订餐系统点餐餐厅系统(含源码+论文+答辩PPT等)

项目功能简介: 该项目采用的技术实现如下 后台框架:Spring、SpringMVC、MyBatis UI界面:BootStrap、H-ui 、JSP 数据库:MySQL 系统功能 系统分为前台订餐和后台管理: 1.前台订餐 用户注册、用户登录、我的购物车、我的订单 商品列…

Linux 常用的命令

前言 Linux 的学习对于一个程序员的重要性是不言而喻的。前端开发相比后端开发,接触 Linux 机会相对较少,因此往往容易忽视它。但是学好它却是程序员必备修养之一。 作者使用的是阿里云服务器 ECS (最便宜的那种) CentOS 7.7 64…

快速了解JSON及JSON的使用

文章目录JSON简介JSON语法JSON 名称/值对JSON对象数组JSON的简单使用JSON简介 JSON(JavaScriptObjectNotation,JS对象简谱)是一种轻量级的数据交换格式 JS对象简谱,那么JSON如何转换为JS对象: JSON文本格式在语法上与…

多弹协同攻击时的无源定位

题目 采用被动接收方式的无源探测定位技术具有作用距离远、隐蔽接 收、不易被敌方发觉等优点,能有效提高探测系统在电子战环境下的 生存能力和作战能力。 在无源定位的研究中,测向定位技术(Direction of Arrival,DOA) …

SpringBoot操作Mongo

文章目录引入依赖yaml实体类集合操作创建删除相关注解文档操作添加实验 数据查询添加更新删除引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId> </dependency><de…

Jmeter配置不同业务请求比例,应对综合场景压测

背景 在进行综合场景压测时&#xff0c;遇到了如何实现不同的请求所占比例不同的问题。 有人说将这些请求分别放到单独的线程组下&#xff0c;然后将线程组的线程数按照比例进行配置。 这种方法不是很好&#xff0c;因为服务器对不同的请求处理能力不同&#xff0c;有的处理快…

C规范编辑笔记(八)

往期文章&#xff1a; C规范编辑笔记(一) C规范编辑笔记(二) C规范编辑笔记(三) C规范编辑笔记(四) C规范编辑笔记(五) C规范编辑笔记(六) C规范编辑笔记(七) 正文&#xff1a; 今天来给大家分享我们的第八篇C规范编辑笔记&#xff0c;话不多说&#xff0c;我们直接来看&…

计算机毕设Python+Vue新闻类网站(程序+LW+部署)

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

基于微信小程序的灯具商城系统-计算机毕业设计

项目介绍 开发语言&#xff1a;Java 框架&#xff1a;ssm JDK版本&#xff1a;JDK1.8 服务器&#xff1a;tomcat7 数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09; 数据库工具&#xff1a;Navicat11 开发软件&#xff1a;eclipse/myeclipse/idea Maven包&a…

Java中IO体系

File File类 File类 : 表示计算机中所有的文件和文件夹; [计算机硬盘上除了文件就是文件夹]如何创建File对象 :File(String pathname) : 传入文件路径[String],创建File对象并指向这个路径的文件/文件夹File(String parent, String child) :传入文件路径[String],创建File对象…