365天深度学习训练营-第T5周:运动鞋品牌识别

news2024/9/20 14:51:17
  •  🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

 我的环境:

  • 语言环境:Python3.10.7
  • 编译器:VScode
  • 深度学习环境:TensorFlow2

 一、前期工作: 

1、导入数据集

from tensorflow import keras
from tensorflow.keras import layers, models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow as tf
 
data_dir = "D:\T5ShoseBand"
data_dir = pathlib.Path(data_dir)

#查看图片数量
image_count = len(list(data_dir.glob("*/*/*.jpg")))
print("图片数量为: ",image_count)

.`image_count = len(list(data_dir.glob("*/*/*.jpg")))`
   - `data_dir` 是一个变量,表示图像文件所在的目录路径。
   - `glob` 是一个函数,返回与指定模式匹配的所有文件路径的迭代器。glob() 方法可以返回匹配指定模式(通配符)的文件列表,该方法的参数 “/.jpg” 表示匹配所有子文件夹下以 .jpg 结尾的文件。
   - `data_dir.glob("*/*/*.jpg")` 会返回所有符合模式 `*/*/*.jpg` 的文件路径。这里的模式表示在 `data_dir` 目录下的任意子目录中的任意文件名为 `.jpg` 后缀的文件。
   - `list(data_dir.glob("*/*/*.jpg"))` 将迭代器转换为列表,以便可以获取其长度。
   - `len(list(data_dir.glob("*/*/*.jpg")))` 返回文件路径列表的长度,即图像文件的数量。

roses = list(data_dir.glob("train/nike/*.jpg"))
image_path = str(roses[0])
 
# Open the image using PIL
image = PIL.Image.open(image_path)
 
# Display the image using matplotlib
plt.imshow(image)
plt.axis("off") 
plt.show()

 2. 数据预处理

#设施图片格式
batch_size = 32
img_height = 224
img_width = 224

#划分训练集
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "D:\T5ShoseBand",
    seed = 123,
    image_size = (img_height, img_width),
    batch_size = batch_size)

#划分验证集
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "D:\T5ShoseBand",
    seed = 123,
    image_size = (img_height, img_width),
    batch_size = batch_size
)

#查看标签
class_names = train_ds.class_names
print(class_names)

#数据可视化
plt.figure(figsize = (20, 10))
 
for images, labels in train_ds.take(1):
  for i in range(20):
    plt.subplot(5, 10, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")
plt.show()

plt.figure(figsize = (20, 10)) 创建一个图形对象,并指定其大小为20x10英寸

for images, labels in train_ds.take(1): 遍历train_ds数据集中的第一个批次,每个批次包含一批图和对应的标签。这里使用take(1)函数从数据集中获取一个批次。

plt.subplot(5, 10, i + 1) 在图形对象中创建一个子图,这里的子图是一个5x10的网格,并将当前子图设置为第i+1个位置。


 plt.imshow(images[i].numpy().astype("uint8")) 使用Matplotlib的imshow函数显示当前图像。images[i]是当前图像的张量表示,使用.numpy()将其转换为NumPy数组,并使用.astype("uint8")将数据类型转换为uint8以便显示。


 plt.title(class_names[labels[i]]) 为当前图像设置标题,标题内容是通过索引labels[i]从class_names列表中获取的类别名称。

plt.axis(“off”) 是 Matplotlib 库中的一个函数调用,它用于控制图像显示时的坐标轴是否可见。
具体来说,当参数为 “off” 时,图像的坐标轴会被关闭,不会显示在图像周围。这个函数通常在 plt.imshow() 函数之后调用,以便在显示图像时去掉多余的细节信息,仅仅显示图像本身。

#检验数据
for image_batch, labels_batch in train_ds:
  print(image_batch.shape)
  print(labels_batch.shape)
  break

#配置数据
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size = AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size = AUTOTUNE)

三、搭建CNN网络

#设置Sequential模型,创建神经网络
model = models.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255,input_shape=(img_height,img_width,3)),
    #设置二维卷积层1,设置32个3*3卷积核,activation参数将激活函数设置为ReLU函数
    #input_shape设置图形的输入形状
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),
    #池化层1,2*2采样
    layers.AveragePooling2D(2*2),
    #设置二维卷积层2,设置64个3*3卷积核,激活函数设置为ReLU函数
    layers.Conv2D(64, (3, 3), activation='relu'),
    #池化层2,2*2采样
    layers.AveragePooling2D((2, 2)),
    #设置停止工作概率,防止过拟合
    layers.Dropout(0.3),
 
    #Flatten层,用于连接卷积层与全连接层
    layers.Flatten(),
    #全连接层,特征进一步提取,64为输出空间的维数(神经元),激活函数为ReLU函数
    layers.Dense(128,activation='relu'),
    #输出层,4为输出空间的维数
    layers.Dense(4)
])
#打印网络结构
model.summary()

 四、编译 

#设置动态学习率
initial_learning_rate = 0.001
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps = 30,
    decay_rate = 0.92,
    staircase = True
)

#设置优化器
opt = keras.optimizers.Adam(learning_rate=lr_schedule)
model.compile(
    #设置优化器为Adam优化器
    optimizer = opt,
    #设置损失函数为交叉熵损失函数
    #from_logits为True时,会将y_pred转化为概率
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    #设置性能指标列表,将在模型训练时对列表中的指标进行监控
    metrics = ['accuracy']
)

`initial_learning_rate = 0.001`
   - 代码定义了初始学习率为 `0.001`。学习率是在训练神经网络模型时控制权重更新步长的超参数。

 `lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate, decay_steps=30, decay_rate=0.92, staircase=True)`
   - `tf.keras.optimizers.schedules.ExponentialDecay` 是 TensorFlow Keras 中用于指数衰减学习率的调度器对象。
   - `initial_learning_rate` 是初始学习率。
   - `decay_steps` 是衰减步数,表示经过多少步之后学习率会进行衰减。
   - `decay_rate` 是衰减率,表示每个 `decay_steps` 步学习率将衰减为原来的 `decay_rate` 倍。
   - `staircase` 是一个布尔值,指定是否将衰减应用于阶梯函数。如果为 `True`,则学习率会在每个 `decay_steps` 步骤之后衰减;如果为 `False`,则学习率将根据每个训练步骤进行平滑衰减。
   - `lr_schedule` 是一个调度器对象,用于动态调整学习率。

通过指数衰减学习率,初始学习率为 `0.001`,每经过 `30` 步衰减一次,衰减率为 `0.92`,并且采用阶梯函数进行衰减。调度器对象 `lr_schedule` 可以在训练过程中根据指定的规则自动调整学习率。

学习率大与学习率小的优缺点分析:
学习率大
优点:
。1、加快学习速率。
。2、有助于跳出局部最优值。
●缺点:
。1、 导致模型训练不收敛。
0 2、单单使用大学习率容易导致模型不精确。
学习率小.
优点:
。1、有助于模型收敛、模型细化。
0 2、提高模型精度。
● 缺点:
。1、很难跳出局部最优值。
。2、收敛缓慢。

这里设置的动态学习率为:指数衰减型(ExponentialDecay) 。在每一-个epoch开始前, 学习率
(learning_ rate) 都将会重置为初始学习率(initial. _learning. rate),然后再重新开始衰减。计算公式如下: 

learning_ rate = initial. learning. _rate * decay_ rate ^ (step / decay_ steps) 

`decay_steps` 参数对训练结果有以下影响:

1. 控制学习率衰减的频率:`decay_steps` 指定了学习率衰减的步数。当训练步数达到或超过 `decay_steps` 时,学习率会进行一次衰减。较小的 `decay_steps` 值会使学习率更频繁地进行衰减,而较大的 `decay_steps` 值会使学习率衰减得更慢。衰减频率直接影响权重更新的速度和幅度。

2. 控制学习率的衰减速度:`decay_steps` 还影响学习率的衰减速度。较小的 `decay_steps` 值会使学习率更快地衰减,而较大的 `decay_steps` 值会使学习率衰减得更慢。衰减速度决定了学习率从初始值逐渐减小到较小值所经过的训练步数。

适当选择 `decay_steps` 可以帮助优化模型的训练过程。如果 `decay_steps` 设置得太小,学习率会频繁地进行衰减,可能导致模型收敛速度过慢。相反,如果 `decay_steps` 设置得太大,学习率衰减得太慢,可能导致模型在训练早期无法充分收敛。

通常,对于较大的数据集或复杂的模型,较大的 `decay_steps` 值可能更合适,因为模型需要更多的训练步骤来进行权重调整。而对于较小的数据集或简单的模型,较小的 `decay_steps` 值可能更合适,因为模型可以更快地收敛到最优解。

需要根据具体问题和实验来调整和选择合适的 `decay_steps` 值,以获得更好的训练结果。

 decay_steps取值不同时训练效果:

decay_steps = 5:

decay_steps = 20: 

decay_steps = 70: 

 

五、训练模型 

#模型训练
from tensorflow.keras.callbacks import ModelCheckpoint
 
epochs = 50
 
checkpointer = ModelCheckpoint(
    'best_model.h5',
    monitor = 'val_accuracy',
    verbose = 1,
    save_best_only = True,
    save_weights_only = True
)
 
history = model.fit(
    train_ds,
    validation_data = val_ds,
    epochs = epochs,
    callbacks = [checkpointer]
)

 六、模型评估

6.1Loss和Acc图

#模型评估
loss = history.history['loss']
val_loss = history.history['val_loss']
 
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
 
epochs_range = range(len(loss))
 
plt.figure(figsize = (12, 4))
 
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label = "Training Acc")
plt.plot(epochs_range, val_acc, label = "Validation Acc")
plt.legend(loc = 'lower right')
plt.title("Training And Validation Acc")
 
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label = "Training Loss")
plt.plot(epochs_range, val_loss, label = "Validation Loss")
plt.legend(loc = 'upper right')
plt.title("Training And Validation Loss")
 
plt.show()

6.2指定结果进行预测

model.load_weights('best_model.h5')
 
from PIL import Image
import numpy as np
 
img = Image.open("D:/T4ShoseBand/train/nike/1 (12).jpg")
 
image = tf.image.resize(img, [img_height, img_width]) 
 
img_array = tf.expand_dims(image, 0)
 
 
predictions = model.predict(img_array)
#这个函数用于对输入图像进行分类预测。它使用已经训练好的模型来对输入数据进行推断,并输出每个类别的概率分布。
print("预测结果为:", class_names[np.argmax(predictions)])
 

七、完整代码

from tensorflow import keras
from tensorflow.keras import layers, models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow as tf
from PIL import Image
import numpy as np
 
data_dir = "D:\T5ShoseBand"
data_dir = pathlib.Path(data_dir)

#设施图片格式
batch_size = 32
img_height = 224
img_width = 224

#划分训练集
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "D:\T5ShoseBand",
    seed = 123,
    image_size = (img_height, img_width),
    batch_size = batch_size)

#划分验证集
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "D:\T5ShoseBand",
    seed = 123,
    image_size = (img_height, img_width),
    batch_size = batch_size
)

#查看标签
class_names = train_ds.class_names
print(class_names)

#检验数据
for image_batch, labels_batch in train_ds:
  print(image_batch.shape)
  print(labels_batch.shape)
  break

#配置数据
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size = AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size = AUTOTUNE)

#设置Sequential模型,创建神经网络
model = models.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255,input_shape=(img_height,img_width,3)),
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),
    layers.AveragePooling2D(2*2),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.AveragePooling2D((2, 2)),
    layers.Dropout(0.3),
    layers.Flatten(),
    layers.Dense(128,activation='relu'),
    layers.Dense(4)
])

#设置动态学习率
initial_learning_rate = 0.001
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps = 30,
    decay_rate = 0.92,
    staircase = True
)

#设置优化器
opt = keras.optimizers.Adam(learning_rate=lr_schedule)
model.compile(
    optimizer = opt,
    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = ['accuracy']
)

#模型训练
from tensorflow.keras.callbacks import ModelCheckpoint
 
epochs = 50
 
checkpointer = ModelCheckpoint(
    'best_model.h5',
    monitor = 'val_accuracy',
    verbose = 1,
    save_best_only = True,
    save_weights_only = True
)
 
history = model.fit(
    train_ds,
    validation_data = val_ds,
    epochs = epochs,
    callbacks = [checkpointer]
)

#模型评估
loss = history.history['loss']
val_loss = history.history['val_loss']
 
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
 
epochs_range = range(len(loss))
 
plt.figure(figsize = (12, 4))
 
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label = "Training Acc")
plt.plot(epochs_range, val_acc, label = "Validation Acc")
plt.legend(loc = 'lower right')
plt.title("Training And Validation Acc")
 
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label = "Training Loss")
plt.plot(epochs_range, val_loss, label = "Validation Loss")
plt.legend(loc = 'upper right')
plt.title("Training And Validation Loss")
 
plt.show()

model.load_weights('best_model.h5')
 
 #预测
img = Image.open("D:/T4ShoseBand/train/nike/1 (12).jpg") 
image = tf.image.resize(img, [img_height, img_width])  
img_array = tf.expand_dims(image, 0)
predictions = model.predict(img_array)
#这个函数用于对输入图像进行分类预测。它使用已经训练好的模型来对输入数据进行推断,并输出每个类别的概率分布。
print("预测结果为:", class_names[np.argmax(predictions)])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/729773.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【活动】如何在工作中管理情绪

写在前面 近期发生的新闻热点再度引发公众对稳定情绪和心理健康的关注。有时候我们遇到的最大的敌人,不是运气也不是能力,而是失控的情绪和口无遮拦的自己。如何在工作中保持稳定的情绪?谈谈我的看法。 愤怒的危害 说到愤怒这种情绪&#xf…

基于深度学习的高精度鸟类目标检测识别系统(PyTorch+Pyside6+YOLOv5模型)

摘要:基于深度学习的高精度鸟类目标(鹦鹉(Crested Myna)、麻雀(Eurasian Tree Sparrow)、黑头文鸟(Chestnut Munia)、白领翡翠(Collared Kingfisher)、太阳鸟…

【C语言】进阶指针(二)—>函数指针与回调函数

目录 前言: 一、函数指针 代码1分析: 代码2分析: 二、函数指针数组 三、指向函数指针数组的指针 四、回调函数(模拟实现库函数qsort) (一)void*类型指针的作用 (二&#xf…

Spark—Shell命令对WordCount案例的基本操作(统计、去重、排序、求平均值及join)

一、统计、去重 1、案例数据介绍 WordCount统计:某电商网站记录了大量的用户对商品的收藏数据,并将数据存储在名为buyer_favorite的文本文件中。文本数据格式如下: 2、启动spark-shell 配置好spark环境,若还没有环境可以参考…

windows下环境问题总结

nacos 启动后在spring 项目中无法加载yml配置文件 spring.datasource.platform mysql 注意一定要放开这行,不放的话,可能会导致服务可以成功注册,但是,我们无法使用局部的 nacos里yml配置文件的属性

Linux:项目自动化构建工具——make/Makefile

文章目录 一.make与Makefile的关系1.Makefile2.make 二.项目清理1.clean2. .PHONY 前言: 本章主要内容有认识与学习Linux环境下如何使用项目自动化构建工具——make/makefile。 一.make与Makefile的关系 当我们编写一个较大的软件项目时,通常需要将多个…

js实现图片压缩

创建一个type"file"的input标签&#xff0c;用于文件上传。 <input type"file" name"" id"upload" value"" />通过js实现图片压缩 window.onload function () {const upload document.getElementById("upload…

9.10UEC++生成、销毁actor

BeginPlay&#xff1a; 1.SpawnActor&#xff1a;<模板类>&#xff08;模板::staticclass&#xff08;&#xff09;&#xff0c;FVector const class&#xff0c;FRotation const class&#xff09; 生成一个actor 2.Destory&#xff08;&#xff09;从世界中销毁一个a…

SSM学习笔记-------Spring(一)

SSM学习笔记-------Spring&#xff08;一&#xff09; Spring_day011、课程介绍1.1 为什么要学?1.2 学什么?1.3 怎么学? 2、Spring相关概念2.1 初识Spring2.1.1 Spring家族2.1.2 了解Spring发展史 2.2 Spring系统架构2.2.1 系统架构图2.2.2 课程学习路线 2.3 Spring核心概念…

【zabbix 代理服务器】

目录 一、部署 zabbix 代理服务器1、设置 zabbix 的下载源&#xff0c;安装 zabbix-proxy2、初始化数据库1、创建数据库并指定字符集2、创建 zabbix 数据库用户并授权 3、导入数据库信息4、修改 zabbix-proxy 配置文件5、启动 zabbix-proxy6、在所有主机上配置 hosts 解析7、在…

Maven高级(四)--私服

一.作用 我们所拆分的模块是可以在同一个公司各个项目组之间的项目组之间进行资源共享的&#xff0c;这就需要Maven的私服来实现。 二.场景 两个项目组之间如何基于私服进行资源的共享的呢&#xff1f; 例如A开发了一个模块tlias-utils,B团队进行项目开发&#xff0c;要想使用…

CentOS7 主机万兆网卡配置端口聚合后如何与华为交换机连接并上网

环境: 组装机测试服务器 CentOS 7 CentOS Linux release 7.7.1908 (Core) 网卡1:瑞昱普通千兆板载网卡 网卡2:EB-LINK intel 82599芯片PCI-E X8 10G 光模块千兆单模1310 交换机 HW-S1730S-S48T4S-A Version 5.170 (S1730 V200R021C01SPC200) 光模块千兆单模1310 …

Python学习笔记(十七)————模块相关

目录 &#xff08;1&#xff09;模块 &#xff08;2&#xff09;模块的导入方式 ①import 模块名 ②from 模块名 import 功能名 ③from 模块名 import * ④as定义别名 &#xff08;3&#xff09;自定义模块 &#xff08;4&#xff09;测试模块 &#xff08;1&#xff09…

List移除元素的四种方式

List 移除某个元素 四种方式&#xff1a; 方式一&#xff0c;使用 Iterator &#xff0c;顺序向下&#xff0c;如果找到元素&#xff0c;则使用 remove 方法进行移除。方式二&#xff0c;倒序遍历 List &#xff0c;如果找到元素&#xff0c;则使用 remove 方法进行移除。方式…

使用TypeScript实现贪吃蛇小游戏(网页版)

本项目使用webpackts所编写 下边是项目的文件目录 /src下边的index.html页面是入口文件 index.ts是引入所有的ts文件 /modules文件夹是用来存放所有类的 index.html <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"…

VoxaMech 的机甲装备 — NFT 系列

一套新的机甲装备即将诞生&#xff1a;帝国甲虫套装是无与伦比的权威和威望的象征&#xff0c;体现了古代帝国甲虫的雄伟壮观。其华丽的铠甲散发着帝王般的优雅气息&#xff0c;其威严的外观在战场上令人肃然起敬。 每套装备由手臂部件、胸甲、头盔、腿部件和剑组成。每件装备单…

SpringBoot——业务层测试事务回滚

事务回滚 关于事务回滚的概念我们之前在学习数据库的时候已经提到过了&#xff0c;这里我们再次强化一下记忆。所谓的事务回滚就是在执行多条SQL语句的时候&#xff0c;如果其中一条SQL出现了异常导致执行失败&#xff0c;则数据库的状态回滚到执行多条SQL语句之前的状态&…

第六章:YOLO v1网络详解(统一的实时目标检测)

(目标检测篇&#xff09;系列文章目录 第一章:R-CNN网络详解 第二章:Fast R-CNN网络详解 第三章:Faster R-CNN网络详解 第四章:SSD网络详解 第五章:Mask R-CNN网络详解 第六章:YOLO v1网络详解 第七章:YOLO v2网络详解 第八章:YOLO v3网络详解 文章目录 系列文章目录技…

一起学SF框架系列5.7-模块Beans-BeanDefinition定义

在SF下&#xff0c;开发人员用xml或注解模式定义bean&#xff0c;框架把这些定义转化为内部BeanDefinition类&#xff0c;然后通过BeanDefinition类实现Bean的管理&#xff08;包括初始化、依赖注入及生命周期管理&#xff09;&#xff0c;因此了解Bean的定义、解析、使用过程非…

[kafka] windows下安装kafka(含安装包)

[kafka] windows下安装kafka&#xff08;含安装包&#xff09; 目录 前言 一、下载kafka安装包 1&#xff09;下载安装包 2&#xff09;解压安装包 二、运行zookeeper 1.运行zookeeper&#xff08;因为kafka必须要和zookeeper一起运行&#xff09; 三、运行kafka 四、使用fafka…