Tensorflow入门实战 T05-运动鞋识别

news2024/9/21 2:36:37

目录

一、完整代码

二、训练过程

(1)打印2行10列的数据。

(2)查看数据集中的一张图片

(3)训练过程(训练50个epoch)

(4)训练结果的精确度

三、遇到的问题


  • 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

这篇博客的主要内容是,关于运动鞋的识别。

运动鞋数据集包含训练集和测试集,共578张。

一、完整代码

from tensorflow import keras
from keras import layers, models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]  # 如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True)  # 设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0], "GPU")

# 导入数据集
data_dir = "/Users/MsLiang/Documents/mySelf_project/pythonProject_pytorch/learn_demo/P_model/p05_sport/sport_data"
data_dir = pathlib.Path(data_dir)  # 打印文件夹目录
image_count = len(list(data_dir.glob('*/*/*.jpg')))
print("图片总数为:",image_count)
roses = list(data_dir.glob('train/nike/*.jpg'))
result = PIL.Image.open(str(roses[0]))
# result.show()

# 数据预处理
batch_size = 32
img_height = 224
img_width = 224

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_dir = str(data_dir) + "/train"
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_dir,
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
test_dir = str(data_dir) + "/test/"
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    test_dir,
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)  # 打印结果: ['adidas', 'nike']

# 可视化数据
plt.figure(figsize=(20, 10))

for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])

        plt.axis("off")
plt.show()

# 检查数据
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)   # (32, 224, 224, 3)
    print(labels_batch.shape)  # (32,)
    break

# 配置数据集
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

# 搭建神经网络
"""
关于卷积核的计算不懂的可以参考文章:https://blog.csdn.net/qq_38251616/article/details/114278995

layers.Dropout(0.4) 作用是防止过拟合,提高模型的泛化能力。
关于Dropout层的更多介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/115826689
"""

model = models.Sequential([
    keras.layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=(img_height, img_width, 3)),

    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)),  # 卷积层1,卷积核3*3
    layers.AveragePooling2D((2, 2)),  # 池化层1,2*2采样
    layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3
    layers.AveragePooling2D((2, 2)),  # 池化层2,2*2采样
    layers.Dropout(0.3),
    layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3
    layers.Dropout(0.3),

    layers.Flatten(),  # Flatten层,连接卷积层与全连接层
    layers.Dense(128, activation='relu'),  # 全连接层,特征进一步提取
    layers.Dense(len(class_names))  # 输出层,输出预期结果
])

# model.summary()  # 打印网络结构


# 设置动态学习率
# 设置初始学习率
initial_learning_rate = 0.1

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate,
        decay_steps=10,      # 敲黑板!!!这里是指 steps,不是指epochs
        decay_rate=0.92,     # lr经过一次衰减就会变成 decay_rate*lr
        staircase=True)

# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

model.compile(optimizer=optimizer,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])


from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

epochs = 50

# 保存最佳模型参数
checkpointer = ModelCheckpoint('best_model.h5',
                                monitor='val_accuracy',
                                verbose=1,
                                save_best_only=True,
                                save_weights_only=True)

# 设置早停
earlystopper = EarlyStopping(monitor='val_accuracy',
                             min_delta=0.001,
                             patience=20,
                             verbose=1)

# 模型训练
history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=epochs,
                    callbacks=[checkpointer, earlystopper])

# 模型评估图
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(len(loss))

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

# 指定图片进行预测
# 加载效果最好的模型权重
model.load_weights('best_model.h5')

from PIL import Image
import numpy as np

# img = Image.open("./45-data/Monkeypox/M06_01_04.jpg")  # 这里选择你需要预测的图片
img = Image.open(str(data_dir) + "/test/nike/1.jpg")  # 这里选择你需要预测的图片
image = tf.image.resize(img, [img_height, img_width])

img_array = tf.expand_dims(image, 0) # /255.0  # 记得做归一化处理(与训练集处理方式保持一致)

predictions = model.predict(img_array) # 这里选用你已经训练好的模型
print("预测结果为:",class_names[np.argmax(predictions)])

二、训练过程

(1)打印2行10列的数据。

(2)查看数据集中的一张图片

(3)训练过程(训练50个epoch)

模型早听了。

(4)训练结果的精确度

三、遇到的问题

在进行字符串拼接的时候,出现unsupported operand type(s) for +: 'PosixPath' and 'str'

查了下相关资料,添加str( )就可以。

原因:因为前面的data_dir 是经过pathlib.Path() 处理的。

添加str( ) 完美解决。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1840225.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

0618,0619 ,指针,指针与数组,字符串

目录 第八章(指针),第九章(指针和数组),第十章(字符串)思维导图 作业1 逆序打印 解答: 答案: 作业2 判断回文 解答: 答案: …

PCB设计隐藏的陷进

1、BGA芯片的开窗和过油设计。 加工工艺中,范式过孔都需要盖油设计,实心焊盘需要开窗设计,坚决不能盖油。 2、通孔设计的互联连通性 比如H3芯片的wifi设计,实际上是没有联通的,虽然四层板的中间层有焊盘,但…

【C++ | 移动构造函数】C++11的 移动构造函数 详解及例子代码

😁博客主页😁:🚀https://blog.csdn.net/wkd_007🚀 🤑博客内容🤑:🍭嵌入式开发、Linux、C语言、C、数据结构、音视频🍭 ⏰发布时间⏰:2024-06-12 2…

Mamaba3--RNN、状态方程、勒让德多项式

Mamaba3–RNN、状态方程、勒让德多项式 一、简单回顾 在Mamba1和Mamba2中分别介绍了RNN和状态方程。 下面从两个图和两个公式出发,对RNN和状态方程做简单的回顾: R N N : s t W s t − 1 U x t ; O t V s t RNN: s_t Ws_{t-1}Ux_t&…

前字节员工自爆:我原腾讯一哥们,跳槽去小公司做小领导,就签了竞业,又从小公司离职去了对手公司,结果被发现了,小公司要他赔80万

“世界那么大,我想去看看”,这句曾经火遍网络的辞职宣言,说出了多少职场人心中的渴望。然而,当我们真的迈出跳槽那一步时,才发现,现实远比想象中残酷得多。 最近,一起前字节跳动员工爆料的事件…

22种常用设计模式示例代码

文章目录 创建型模式结构型模式行为模式 仓库地址https://github.com/Xiamu-ssr/DesignPatternsPractice 参考教程 refactoringguru设计模式-目录 创建型模式 软件包复杂度流行度工厂方法factorymethod❄️⭐️⭐️⭐️抽象工厂abstractfactory❄️❄️⭐️⭐️⭐️生成器bui…

自学网络安全 or Web安全,一般人我还是劝你算了吧

由于我之前写了不少网络安全技术相关的文章,不少读者朋友知道我是从事网络安全相关的工作,于是经常有人私信问我: 我刚入门网络安全,该怎么学? 要学哪些东西? 有哪些方向? 怎么选?…

福昕PDF编辑器快速去除PDF水印方法

在福昕PDF编辑器软件中打开一个带有水印的PDF文件,点击如图下所示的页面管理->水印,点击全部移除 点击 是 水印消除(注:部分类型的水印可以消除,但是有些类型的水印无法通过此方法消除)

Day28:回溯法 491.递增子序列 46.全排列 47.全排列 II 332.重新安排行程 51. N皇后 37. 解数独 蓝桥杯 与或异或

491. 非递减子序列 给你一个整数数组 nums ,找出并返回所有该数组中不同的递增子序列,递增子序列中 至少有两个元素 。你可以按 任意顺序 返回答案。 数组中可能含有重复元素,如出现两个整数相等,也可以视作递增序列的一种特殊情…

Flask新手入门(一)

前言 Flask是一个用Python编写的轻量级Web应用框架。它最初由Armin Ronacher作为Werkzeug的一个子项目在2010年开发出来。Werkzeug是一个综合工具包,提供了各种用于Web应用开发的工具和函数。自发布以来,Flask因其简洁和灵活性而迅速受到开发者的欢迎。…

昇思大模型学习·第一天

mindspore快速入门回顾 导入mindspore包 处理数据集 下载mnist数据集进行数据集预处理 MnistDataset()方法train_dataset.get_col_names() 打印列名信息使用create_tuple_iterator 或create_dict_iterator对数据集进行迭代访问 网络构建 mindspore.nn: 构建所有网络的基类用…

Docker(四)-Docker镜像

1.概念 镜像是一种轻量级的、可执行的独立软件包,它包含运行某个软件所需的所有内容,我们把应用程序和配置依赖 打包好形成一个可交付的运行环境(包括代码,运行时需要的库,环境变量和配置文件等),这个打包好的运行环境…

docker pull xxx拉取超时time out

文章目录 前言总结 前言 换了镜像源,改配置的都不行,弄了一个下午,最后运行一下最高指令就可以了 sudo docker_OPTS"--dns 8.8.8.8"总结 作者:加辣椒了吗? 简介:憨批大学生一枚,喜欢…

STM32自己从零开始实操06:无线电路原理图

一、WIFI 模块电路设计 1.1指路 延续使用 ESP-12S 芯片,封装 SMD 16x24mm。 实物图 原理图与PCB图 2.2电路图 电路较为简单,如下图: 2.2.1引脚说明 序号引脚名称描述1RST复位复位引脚,低电平有效3EN使能芯片使能端&#xff0c…

【深度学习】sdwebui A1111 加速方案对比,xformers vs Flash Attention 2

文章目录 资料支撑资料结论sdwebui A1111 速度对比测试sdxlxformers 用contorlnet sdxlsdpa(--opt-sdp-no-mem-attention) 用contorlnet sdxlsdpa(--opt-sdp-attention) 用contorlnet sdxl不用xformers或者sdpa ,用contorlnet sdxl不用xformers或者sdpa …

Windows安装配置jdk和maven(仅做记录)

他妈的远程连接不上公司电脑,只能在家重新配置一遍,在此记录一下后端环境全部配置 Windows安装配置JDK 1.8一、下载 JDK 1.8二、配置环境变量三、验证安装 Windows安装配置Maven 3.8.8一、下载安装 Maven并配置环境变量二、设置仓库镜像及本地仓库三、测…

BUU CODE REVIEW 11 代码审计之反序列化知识

打开靶场&#xff0c;得到的是一段代码。 通过分析上面代码可以构造下面代码&#xff0c;获取到序列化之后的obj。 <?php class BUU {public $correct "";public $input "";public function __destruct() {try {$this->correct base64_encode(u…

如何解决input输入时存在浏览器缓存问题?

浏览器有时会在你输入表单过后缓存你的输入&#xff0c;有时候能提供方便。 但是在某些新建或新页面情况下出现历史的输入信息&#xff0c;用户体验很差。 解决方案 设置 autocomplete关闭 &#xff1a;<input type"text" autocomplete"off">增加…

[AIGC] MyBatis-Plus中如何使用XML进行CRUD操作?

在MyBatis-Plus中&#xff0c;我们可以非常方便地使用XML进行CRUD&#xff08;创建、读取、更新、删除&#xff09;操作。以下是一些基本步骤和示例&#xff0c;希望能帮助到还在初学阶段的您。 文章目录 1. 创建Mapper接口2. 创建Mapper XML 文件3. 调用Mapper方法 1. 创建Ma…

【golang学习之旅】Go程序快速开始 Go程序开发的基本注意事项

系列文章 【golang学习之旅】使用VScode安装配置Go开发环境 【golang学习之旅】报错&#xff1a;a declared but not used 【golang学习之旅】Go 的基本数据类型 【golang学习之旅】深入理解字符串string数据类型 【golang学习之旅】go mod tidy 【golang学习之旅】记录一次 p…