第33周:运动鞋识别(Tensorflow实战第五周)

news2024/12/26 21:01:42

目录

前言

一、前期工作

1.1 设置GPU

1.2 导入数据

1.3 查看数据

二、数据预处理

2.1 加载数据

2.2 可视化数据

2.3 再次检查数据

2.4 配置数据集

2.4.1 基本概念介绍

2.4.2 代码完成

三、构建CNN网络

四、训练模型

4.1 设置动态学习率

4.2 早停与保存最佳模型参数

4.3 模型训练

五、模型评估

5.1 Loss和Accuracy图

5.2 指定图片进行预测

六、改进

总结


前言

  • 🍨 本文为[🔗365天深度学习训练营]中的学习记录博客
  • 🍖 原作者:[K同学啊]

说在前面

1)本周任务:基于CNN模型完成对两种品牌运动鞋图片的识别

2)运行环境:Python3.6、Pycharm2020、tensorflow2.4.0


一、前期工作

1.1 设置GPU

代码如下:

# 一、前期准备
# 1.1 设置GPU
from tensorflow import keras
from tensorflow.keras import layers, models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    gpu0 = gpus[0]  # 如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True)  # 设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0], "GPU")
print(gpus)

1.2 导入数据

代码如下:

# 1.2 导入数据
data_dir = "./data/"
data_dir = pathlib.Path(data_dir)

1.3 查看数据

代码如下:

# 1.3 查看数据
image_count = len(list(data_dir.glob('*/*/*.jpg')))
print("图片总数为:", image_count)
shoses = list(data_dir.glob('train/nike/*.jpg'))
PIL.Image.open(str(shoses[0]))

输出:

图片总数为: 578

二、数据预处理

2.1 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset,tf.keras.preprocessing.image_dataset_from_directory():是 TensorFlow 的 Keras 模块中的一个函数,用于从目录中创建一个图像数据集(dataset)。这个函数可以以更方便的方式加载图像数据,用于训练和评估神经网络模型

测试集与验证集的关系:

  • 验证集并没有参与训练过程梯度下降过程的,狭义上来讲是没有参与模型的参数训练更新的。
  • 但是广义上来讲,验证集存在的意义确实参与了一个“人工调参”的过程,我们根据每一个epoch训练之后模型在valid data上的表现来决定是否需要训练进行early stop,或者根据这个过程模型的性能变化来调整模型的超参数,如学习率,batch_size等等。因此,我们也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集
  • 因此,我们也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集

代码如下:

# 二、数据预处理
# 2.1 加载数据集
batch_size = 32
img_height = 224
img_width = 224
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "./data/train/",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "./data/test/",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
class_names = train_ds.class_names
print(class_names)

输出如下:

Found 502 files belonging to 2 classes.

Found 76 files belonging to 2 classes.

['adidas', 'nike']

2.2 可视化数据

代码如下:

# 2.2 可视化数据
plt.figure(figsize=(20, 10))
for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

输出:

2.3 再次检查数据

代码如下:

# 2.3再次检查数据
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

输出:

(32, 224, 224, 3)
(32,)

情况说明:

  • Image_batch是形状的张量(32,2224,224,3)。这是一批形状224x224x3的32张图片(最后一维指的是彩色通道RGB)。
  • Label_batch是形状(32,)的张量,这些标签对应32张图片

2.4 配置数据集

2.4.1 基本概念介绍

prefetch():CPU 正在准备数据时,加速器处于空闲状态。相反,当加速器正在训练模型时,CPU 处于空闲状态。因此,训练所用的时间是 CPU 预处理时间和加速器训练时间的总和。prefetch()将训练步骤的预处理和模型执行过程重叠到一起。当加速器正在执行第 N 个训练步时,CPU 正在准备第 N+1 步的数据。这样做不仅可以最大限度地缩短训练的单步用时(而不是总用时),而且可以缩短提取和转换数据所需的时间。如果不使用prefetch(),CPU 和 GPU/TPU 在大部分时间都处于空闲状态

​​

然后使用prefetch()可显著减少空闲时间:

​​

2.4.2 代码完成

# cache():将数据集缓存到内存当中,加速运行
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

代码如下:

# cache():将数据集缓存到内存当中,加速运行
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

三、构建CNN网络

         卷积神经网络(CNN)的输入是张量 (Tensor) 形式的 (image_height, image_width, color_channels),包含了图像高度、宽度及颜色信息。不需要输入batch size。color_channels 为 (R,G,B) 分别对应 RGB 的三个颜色通道(color channel)。在此示例中,我们的 CNN 输入形状是 (180, 180, 3)。我们需要在声明第一层时将形状赋值给参数input_shape

网络结构图如下:

​​

代码如下:

# 三、构建CNN网络
model = models.Sequential([
    layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=(img_height, img_width, 3)),
    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),  # 卷积层1,卷积核3*3
    layers.AveragePooling2D((2, 2)),  # 池化层1,2*2采样
    layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3
    layers.AveragePooling2D((2, 2)),  # 池化层2,2*2采样
    layers.Dropout(0.3),
    layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3
    layers.Dropout(0.3),
    layers.Flatten(),  # Flatten层,连接卷积层与全连接层
    layers.Dense(128, activation='relu'),  # 全连接层,特征进一步提取
    layers.Dense(len(class_names))  # 输出层,输出预期结果
])

model.summary()  # 打印网络结构

模型结构打印如下:

四、训练模型

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率

4.1 设置动态学习率

ExponentialDecay函数tf.keras.optimizers.schedules.ExponentialDecay是 TensorFlow 中的一个学习率衰减策略,用于在训练神经网络时动态地降低学习率。学习率衰减是一种常用的技巧,可以帮助优化算法更有效地收敛到全局最小值,从而提高模型的性能

🔍主要参数:

  • decay_steps(衰减步数):学习率衰减的步数。在经过 decay_steps 步后,学习率将按照指数函数衰减。例如,如果 decay_steps 设置为 10,则每10步衰减一次。
  • decay_rate(衰减率):学习率的衰减率。它决定了学习率如何衰减。通常,取值在 0 到 1 之间。
  • staircase(阶梯式衰减):一个布尔值,控制学习率的衰减方式。如果设置为 True,则学习率在每个 decay_steps 步之后直接减小,形成阶梯状下降。如果设置为 False,则学习率将连续衰减

代码如下:

initial_learning_rate = 0.0001
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate,
        decay_steps=10,      # 敲黑板!!!这里是指 steps,不是指epochs
        decay_rate=0.92,     # lr经过一次衰减就会变成 decay_rate*lr
        staircase=True)

# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
model.compile(optimizer=optimizer,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

⚠️⚠️⚠️注:这里设置的动态学习率为:指数衰减型(ExponentialDecay)。在每一个epoch开始前,学习率(learning_rate)都将会重置为初始学习率(initial_learning_rate),然后再重新开始衰减。计算公式如下:learning_rate = initial_learning_rate * decay_rate ^ (step / decay_steps)

Tips学习率大与学习率小的优缺点分析

  • 学习率大:优点--加快学习速率,有助于跳出局部最优值;缺点--导致模型训练不收敛,单单使用大学习率容易导致模型不精确;
  • 学习率小:优点--有助于模型收敛、模型细化,提高模型精度;缺点--很难跳出局部最优值,收敛缓慢

4.2 早停与保存最佳模型参数

🔍EarlyStopping()参数说明:

  • monitor:被监测的数据
  • min_delta:在被监测的数据中被认为是提升的最小变化, 例如,小于 min_delta 的绝对变化会被认为没有提升。
  • patience:没有进步的训练轮数,在这之后训练就会被停止。
  • verbose:详细信息模式。
  • mode: {auto, min, max} 其中之一。 在 min 模式中, 当被监测的数据停止下降,训练就会停止;在 max 模式中,当被监测的数据停止上升,训练就会停止;在 auto 模式中,方向会自动从被监测的数据的名字中判断出来。
  • baseline:要监控的数量的基准值。 如果模型没有显示基准的改善,训练将停止。
  • estoe_best_weights:是否从具有监测数量的最佳值的时期恢复模型权重。 如果为 False,则使用在训练的最后一步获得的模型权重

代码如下:

# 4.2 早停与保存最佳模型参数
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
epochs = 50
# 保存最佳模型参数
checkpointer = ModelCheckpoint('best_model.h5',
                                monitor='val_accuracy',
                                verbose=1,
                                save_best_only=True,
                                save_weights_only=True)
# 设置早停
earlystopper = EarlyStopping(monitor='val_accuracy',
                             min_delta=0.001,
                             patience=20,
                             verbose=1)

4.3 模型训练

代码如下:

# 4.3 模型训练
history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=epochs,
                    callbacks=[checkpointer, earlystopper])

打印训练过程:

五、模型评估

5.1 Loss和Accuracy图

代码如下:

# 五、模型评估
# 5、1 Loss与Accuracy图
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(len(loss))
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

训练结果可视化如下:

5.2 指定图片进行预测

代码如下:

# 5.2 指定图片进行预测
# 加载效果最好的模型权重
model.load_weights('best_model.h5')
from PIL import Image
import numpy as np

# img = Image.open("./45-data/Monkeypox/M06_01_04.jpg")  #这里选择你需要预测的图片
img = Image.open("./data/test/nike/1.jpg")  #这里选择你需要预测的图片
image = tf.image.resize(img, [img_height, img_width])
img_array = tf.expand_dims(image, 0)/255.0  # 记得做归一化处理(与训练集处理方式保持一致)
predictions = model.predict(img_array) # 这里选用你已经训练好的模型
print("预测结果为:",class_names[np.argmax(predictions)])

输出:

预测结果为: addidas

六、改进

从训练过程和指定预测的结果可以发现效果并不好,这里考虑修改initial_learning_rate,原始设置的是0.1,这里考虑将其调至0.001或0.0001


总结

本周在上周的基础上增加了动态学习率的设置,优化了训练过程,并实现了正确预测,也让我更加熟练基于Tensorflow框架下CNN模型搭建的流程,并对动态学习率初始值设置对结果影响进行了分析

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2251612.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Robot Screw Theory (Product of Exponentials)机器人螺旋理论(指数积)

Screw Theory 螺旋理论 Screw theory uses the fact that every rigid body transformation can be expressed by a rotational and translational movement.螺旋理论利用了每个刚体变换都可以通过旋转和平移运动来表示这一事实。 Robert Ball developed the theory in 19th ce…

【maven-5】Maven 项目构建的生命周期:深入理解与应用

1. 生命周期是什么 ​在Maven出现之前,项目构建的生命周期就已经存在,软件开发人员每天都在对项目进行清理,编译,测试及部署。虽然大家都在不停地做构建工作,但公司和公司间,项目和项目间,往往…

skywalking 配置elasticsearch持久化

下载和启动elasticsearch elasticsearch-7.17.25-linux-x86_64.tar.gz,解密文件tar -xvf elasticsearch-7.17.25-linux-x86_64.tar.gz 进入到bin目录,启动 elasticsearch -d 后台运行 下载skywalking服务包 apache-skywalking-apm-9.3.0.tar.gz&#x…

MySQL 复合查询

实际开发中往往数据来自不同的表,所以需要多表查询。本节我们用一个简单的公司管理系统,有三张表EMP,DEPT,SALGRADE 来演示如何进行多表查询。表结构的代码以及插入的数据如下: DROP database IF EXISTS scott; CREATE database IF NOT EXIST…

数据集-目标检测系列- 海边漫步锻炼人检测数据集 person >> DataBall

数据集-目标检测系列- 海边漫步锻炼人检测数据集 person >> DataBall DataBall 助力快速掌握数据集的信息和使用方式,会员享有 百种数据集,持续增加中。 需要更多数据资源和技术解决方案,知识星球: “DataBall - X 数据球…

【汇编语言】call 和 ret 指令(三) —— 深度解析汇编语言中的批量数据传递与寄存器冲突

文章目录 前言1. 批量数据的传递1.1 存在的问题1.2 如何解决这个问题1.3 示例演示1.3.1 问题说明1.3.2 程序实现 2. 寄存器冲突问题的引入2.1 问题引入2.2 分析与解决问题2.2.1 字符串定义方式2.2.2 分析子程序功能2.2.3 得到子程序代码 2.3 子程序的应用2.3.1 示例12.3.2 示例…

刷题日常(找到字符串中所有字母异位词,​ 和为 K 的子数组​,​ 滑动窗口最大值​,全排列)

找到字符串中所有字母异位词 给定两个字符串 s 和 p,找到 s 中所有 p 的 异位词的子串,返回这些子串的起始索引。不考虑答案输出的顺序。 题目分析: 1.将p里面的字符先丢进一个hash1中,只需要在S字符里面找到多少个和他相同的has…

Git 快速入门:全面了解与安装步骤

Git 快速入门:全面了解与安装步骤 一、关于Git 1.1 简介 Git 是一个开源的分布式版本控制系统,由 Linus Torvalds 于 2005 年创建,最初是为了更好地管理 Linux 内核开发而设计。 Git用于跟踪计算机文件的变化,特别是源代码文件…

私有库gitea安装

一 gitea是什么 Gitea是一款自助Git服务,简单来说,就是可以一个私有的github。 搭建很容易。 Gitea依赖于Git。 类似Gitea的还有GitHub、Gitee、GitLab等。 以下是安装步骤。 二 安装sqilite 参考: 在windows上安装sqlite 三 安装git…

[CTF/网络安全] 攻防世界 upload1 解题详析

[CTF/网络安全] 攻防世界 upload1 解题详析 考察文件上传&#xff0c;具体原理及姿势不再赘述。 姿势 在txt中写入一句话木马<?php eval($_POST[qiu]);?> 回显如下&#xff1a; 查看源代码&#xff1a; Array.prototype.contains function (obj) { var i this.…

基于HTML+CSS的房地产销售网站设计与实现

摘 要 房地产销售系统&#xff0c;在二十年来互联网时代下有着巨大的意义&#xff0c;随着互联网不断的发展扩大&#xff0c;一个方便直 观的房地产管理系统的网站开发是多么地有意义&#xff0c;不仅打破了传统的线下看房&#xff0c;线下获取资讯&#xff0c;也给房地产从业…

说说Elasticsearch拼写纠错是如何实现的?

大家好&#xff0c;我是锋哥。今天分享关于【说说Elasticsearch拼写纠错是如何实现的&#xff1f;】面试题。希望对大家有帮助&#xff1b; 说说Elasticsearch拼写纠错是如何实现的&#xff1f; 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 在 Elasticsearch 中&…

如何把产品3D模型放到网站上进行3D展示或3D互动?

要将产品3D模型放到网站上进行3D展示或3D互动&#xff0c;可以按照以下步骤进行&#xff1a; 一、准备3D模型 使用3D建模软件&#xff08;如3ds Max、Maya、Blender、C4D等&#xff09;制作好产品的3D模型。 确保3D模型的格式是网站或平台所支持的&#xff0c;常见的格式包括…

SpringBoot集成Flowable

一、工作流介绍 1、概念 通过计算机对业务流程的自动化管理。工作流是建立在业务流程的基础上&#xff0c;一个软件的系统核心根本上还是系统的业务流程&#xff0c;工作流只是协助进行业务流程管理。 解决的是&#xff1a;在多个参与者之间按照某种预定义的规则自动进行传递文…

1-1 Gerrit实用指南

注&#xff1a;学习gerrit需要拥有git相关知识&#xff0c;如果没有学习过git请先回顾git相关知识点 黑马程序员git教程 一小时学会git git参考博客 git 实操博客 1.0 定义 Gerrit 是一个基于 Web 的代码审查系统&#xff0c;它使用 Git 作为底层版本控制系统。Gerrit 的主要功…

鸿蒙开发:自定义一个任意位置弹出的Dialog

前言 鸿蒙开发中&#xff0c;一直有个问题困扰着自己&#xff0c;想必也困扰着大多数开发者&#xff0c;那就是&#xff0c;系统提供的dialog自定义弹窗&#xff0c;无法实现在任意位置进行弹出&#xff0c;仅限于CustomDialog和Component struct的成员变量&#xff0c;这就导致…

Matlab模块From Workspace使用数据类型说明

Matlab原文连接&#xff1a;Load Data Using the From Workspace Block 模型&#xff1a; 从信号来源的数据&#xff1a; timeseries 数据&#xff1a; sampleTime 0.01; numSteps 1001;time sampleTime*[0:(numSteps-1)]; time time;data sin(2*pi/3*time);simin time…

架构01-演进中的架构

零、文章目录 架构01-演进中的架构 1、原始分布式时代&#xff1a;Unix设计哲学下的服务探索 &#xff08;1&#xff09;背景 时间&#xff1a;20世纪70年代末到80年代初计算机硬件&#xff1a;16位寻址能力、不足5MHz时钟频率的处理器、128KB左右的内存转型&#xff1a;从…

社群赋能电商:小程序 AI 智能名片与 S2B2C 商城系统的整合与突破

摘要&#xff1a;本文聚焦于社群在电商领域日益凸显的关键地位&#xff0c;深入探讨在社群粉丝经济迅猛发展背景下&#xff0c;小程序 AI 智能名片与 S2B2C 商城系统如何与社群深度融合&#xff0c;助力电商突破传统运营局限&#xff0c;挖掘新增长点。通过分析社群对电商的价值…

【electron-vite】搭建electron+vue3框架基础

一、拉取项目 electron-vite 中文文档地址&#xff1a; https://cn-evite.netlify.app/guide/ 官网网址&#xff1a;https://evite.netlify.app/ 版本 vue版本&#xff1a;vue3 构建工具&#xff1a;vite 框架类型&#xff1a;Electron JS语法&#xff1a;TypeScript &…