T5打卡-运动鞋识别

news2024/11/6 7:22:11
  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

1.导入及查看数据:

from tensorflow import keras
from tensorflow.keras import layers,models
import os,PIL,pathlib
import matplotlib.pyplot as plt
import tensorflow as tf

gpus=tf.config.list_physical_devices("GPU")
if gpus:
    gpu0=gpus[0]
    tf.conifg.experimental.set_memory_growth(gpu0,True)
    tf.config.set_visible_devices([gpu0],"GPU")
    
gpus

data_dir="data/46-data"
data_dir=pathlib.Path(data_dir)
#查看图片数量
image_count=len(list(data_dir.glob("*/*/*.jpg")))
print("图片总数为:",image_count)
#查看图片
roses=list(data_dir.glob('train/nike/*.jpg'))
PIL.Image.open(str(roses[0]))

2.加载数据:

batch_size = 32
img_height = 224
img_width = 224

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "data/46-data/train/",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "data/46-data/test/",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
#查看种类
class_names = train_ds.class_names
print(class_names)

3.数据可视化

plt.figure(figsize=(20,10))
for images,labels in train_ds.take(1):
    for i in range(20):
        ax=plt.subplot(5,10,i+1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

4.检查数据与配置数据:

for image_batch,labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

5.构建模型:

model = models.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
    
    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)), 
    layers.AveragePooling2D((2, 2)),              
    layers.Conv2D(32, (3, 3), activation='relu'),  
    layers.AveragePooling2D((2, 2)),               
    layers.Dropout(0.3),  
    layers.Conv2D(64, (3, 3), activation='relu'), 
    layers.Dropout(0.3),  
    
    layers.Flatten(),                       
    layers.Dense(128, activation='relu'),  
    layers.Dense(len(class_names))               
])

model.summary() 

6.编译并训练模型

initial_learning_rate =0.0001

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, 
        decay_steps=10,      
        decay_rate=0.92,     
        staircase=True)

optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

model.compile(optimizer=optimizer,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

history=model.fit(train_ds,
                  validation_data=val_ds,
                  epochs=epochs,
                  callbacks=[checkpointer,earlystopper])

7.结果可视化:

acc=history.history['accuracy']
val_acc=val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(len(loss))

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

8.预测数据:

model.load_weights('best_model.h5')

from PIL import Image
import numpy as np

img = Image.open("data/46-data/test/adidas/1.jpg")  
image = tf.image.resize(img, [img_height, img_width])

img_array = tf.expand_dims(image, 0)

predictions = model.predict(img_array)
print("预测结果为:",class_names[np.argmax(predictions)])

总结:

1.设置动态学习率——ExponentailDecay函数:

        tf.keras.optimizers.schedules.ExponentialDecay是 TensorFlow 中的一个学习率衰减策略,用于在训练神经网络时动态地降低学习率。学习率衰减是一种常用的技巧,可以帮助优化算法更有效地收敛到全局最小值,从而提高模型的性能。

        主要参数:

        initial_learning_rate(初始学习率):初始学习率大小。

        decay_steps(衰减步数):学习率衰减的步数。在经过 decay_steps 步后,学习率将按照指数函数衰减。例如,如果 decay_steps 设置为 10,则每10步衰减一次。

        decay_rate(衰减率):学习率的衰减率。它决定了学习率如何衰减。通常,取值在 0 到 1 之间。

        staircase(阶梯式衰减):一个布尔值,控制学习率的衰减方式。如果设置为 True,则学习率在每个 decay_steps 步之后直接减小,形成阶梯状下降。如果设置为 False,则学习率将连续衰减。

2.早停设置——EarlyStopping函数:

参数:

  monitor: 被监测的数据。

  min_delta: 在被监测的数据中被认为是提升的最小变化, 例如,小于 min_delta 的绝对变化会被认为没有提升。

  patience: 没有进步的训练轮数,在这之后训练就会被停止。

  verbose: 详细信息模式。

  mode: {auto, min, max} 其中之一。 在 min 模式中, 当被监测的数据停止下降,训练就会停止;在 max 模式中,当被监测的数据停止上升,训练就会停止;在 auto 模式中,方向会自动从被监测的数据的名字中判断出来。

  baseline: 要监控的数量的基准值。 如果模型没有显示基准的改善,训练将停止。

  estore_best_weights: 是否从具有监测数量的最佳值的时期恢复模型权重。 如果为 False,则使用在训练的最后一步获得的模型权重。

3.保存最佳模型参数:

checkpointer = ModelCheckpoint('best_model.h5',
                                monitor='val_accuracy',
                                verbose=1,
                                save_best_only=True,
                                save_weights_only=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2180850.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AI与大数据的结合:如何从海量数据中提取价值

引言 在当今数字化时代,数据如同新石油,成为推动社会与商业进步的重要资源。随着物联网、社交媒体和企业运营中数据生成的激增,我们正处在一个数据爆炸的时代。然而,面对海量且复杂的数据信息,仅依靠传统的分析方法已经…

Python入门:asyncio异步编程结果处理

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 文章内容 📒📝 处理异步任务的基本概念📝 获取第一个结果📝 添加回调函数📝 使用`return_exceptions`处理异常📝 判断任务完成情况📝 获取结果详情⚓️ 相关链接 ⚓️📖 介绍 📖 在Python编程中,异步并发可能…

钉钉H5微应用Springboot+Vue开发分享

文章目录 说明技术路线注意操作步骤思路图 一、创建钉钉应用二、创建java项目三、创建vue项目(或uniapp项目),npm引入sdk的依赖四、拥有公网域名端口。开发环境可以使用(贝锐花生壳等工具)五、打开钉钉开发者平台&…

【YashanDB知识库】客户端字符集与数据库字符集兼容问题

本文转自YashanDB官网,具体内容请见https://www.yashandb.com/newsinfo/7352675.html?templateId1718516 问题现象 客户端yasql配置字符集为GBK,服务端yasdb配置字符集为UTF8,之后执行语句: 会发现: 期望是两个都…

【LeetCode】每日一题 2024_9_29 买票需要的时间(模拟)

前言 每天和你一起刷 LeetCode 每日一题~ LeetCode 启动! 昨天的每日一题是线段树二分,题目难度远超我的能力范围,所以更不出来了 题目:买票需要的时间 代码与解题思路 func timeRequiredToBuy(tickets []int, k int) (sum in…

【Kubernetes知识点】 解读 Service 和 EndpointSlice 之间的关系

【Kubernetes知识点】 解读 Service 和 EndpointSlice 之间的关系 目录 1 概念 1.1 Service的概念1.2 Endpoint 的概念1.3 EndpointSlice 的引入 1.3.1 EndpointSlice支持的地址1.3.2 EndpointSlice的状态1.3.3 EndpointSlice的拓扑信息 1.4 Service 、Endpoint和 EndpointSl…

Beyond Compare 比较CRC值、二进制比较、关联规则比较,有何区别?(CRC比较、CRC值比较)

文章目录 Beyond Compare文件比较方法深入分析CRC值比较定义及工作原理应用场景优点和缺点 二进制比较定义及工作原理应用场景优点和缺点 关联规则比较定义及工作原理应用场景优点和缺点 比较示例 性能差异CRC值比较的性能影响优点缺点 二进制比较的性能影响优点缺点 关联规则比…

C项目--带权限的图书管理系统(1000多行代码,代码数据可下载,极其适合初学练手)

本专栏目的 更新C/C的相关的项目 前言 C语言的图书权限管理系统完结(进阶的一点后面更新),1000多行代码(核心代码5、600行);本设计是一个比较综合的练习,用到数据结构(顺序表、链表、静态链表)、文件、排…

发布-订阅模式演示示例

<!DOCTYPE html> <html lang="en"><head><meta charset="UTF-8" /><meta name="viewport" content="width=device-width, initial-scale=1.0" /><title>发布-订阅模式示例</title><styl…

LC记录一:寻找旋转数组最小值、判断旋转数组是否存在给定元素

文章目录 33.搜索旋转排序数组81.搜索旋转排序数组||153.寻找旋转排序数组中的最小值154.寻找旋转排序数组中的最小值||参考链接 33.搜索旋转排序数组 https://leetcode.cn/problems/search-in-rotated-sorted-array/description/ 下面这张图片是LC154题官方题解提供的一个图…

重磅!25年3月起,PMP®考试将启用新教材!

近期&#xff0c;PMI对各科目教材进行了调整。9月27日宣布了2025年3月将会更新ACP的考试内容&#xff0c;新版考试仍将围绕敏捷思维和产品交付&#xff0c;但考试内容大纲(ECO)将整合为四大领域&#xff08;思维、领导力、产品、交付&#xff09;&#xff0c;融合现代新的项目类…

DAY18||530.二叉搜索树的最小绝对值差 |501.二叉搜索树中的众数| 236.二叉树的最近公共祖先

530.二叉搜索树的最小绝对值差 题目&#xff1a;530. 二叉搜索树的最小绝对差 - 力扣&#xff08;LeetCode&#xff09; 给你一个二叉搜索树的根节点 root &#xff0c;返回 树中任意两不同节点值之间的最小差值 。 差值是一个正数&#xff0c;其数值等于两值之差的绝对值。 …

MongoDB 快速入门+单机部署(附带脚本)

目录 介绍 体系结构 数据模型 BSON BSON 数据类型 特点 高性能 高可用 高扩展 丰富的查询支持 其他特点 部署 单机部署 普通安装 脚本安装 Docker Compose 安装 卸载 停止 MongoDB 删除包 删除数据目录 参考&#xff1a; https://docs.mongoing.com/ 介绍…

Ping到底干了啥?ICMP 协议详解

什么是 ICMP&#xff1f; ICMP&#xff08;Internet Control Message Protocol&#xff0c;互联网控制消息协议&#xff09;是一种网络层协议&#xff0c;主要用于在网络设备之间传递控制信息和错误消息。它是 IP 协议族的一部分&#xff0c;通常与 IP 协议一起使用。ICMP 的主…

cheese自动化平台开发环境搭建【图文版】

目的 cheese自动化平台是一款可以模拟鼠标和键盘操作的自动化工具。它可以帮助用户自动完成一些重复的、繁琐的任务&#xff0c;节省大量人工操作的时间。可以采用Vscode、IDEA编写&#xff0c;支持Java、Python、nodejs、GO、Rust、Lua。官方提供了视频版教程&#xff0c;对于…

leetcode每日一题day20(24.9.30)——座位预约管理系统

思路&#xff1a;由于一直是出最小的编号&#xff0c;并且除此之外只有添加元素的操作&#xff0c;如果使用数组存储&#xff0c;并记录&#xff0c;这样出最小编号时间是O(n)复杂度,释放一个座位则是O(1)在操作出线机会均等的情况下&#xff0c;平均是O(n/2), 但对于这种重复 …

CANoe_trace介绍以及如何使用C#仿制trace方案介绍

C#UI界面仿制trace界面介绍--初次制作&#xff0c;后续待完善 在汽车电子开发中&#xff0c;DBC&#xff08;Database Container&#xff09;文件对于定义和描述CAN&#xff08;Controller Area Network&#xff09;通信协议至关重要。随着项目的迭代和功能的扩展&#xff0c;手…

Elasticsearch 开放推理 API 增加了对 Google AI Studio 的支持

作者&#xff1a;来自 Elastic Jeff Vestal 我们很高兴地宣布 Elasticsearch 的开放推理 API 支持 Gemini 开发者 API。使用 Google AI Studio 时&#xff0c;开发者现在可以与 Elasticsearch 索引中的数据进行聊天、运行实验并使用 Google Cloud 的模型&#xff08;例如 Gemin…

0基础学前端 day7

大家好&#xff0c;欢迎来到无限大的频道 今天继续带领大家来了解前端的知识&#xff0c;深入了解互联网和浏览器背后的技术。 历史背景 互联网的起源可以追溯到20世纪60年代的ARPANET项目&#xff0c;作为研究机构之间的通信网络。最初的网页浏览器由Tim Berners-Lee于1990…

【工程测试技术】第3章 测试装置的基本特性,静态特性和动态特性,一阶二阶系统的特性,负载效应,抗干扰性

目录 3.1 概述 1测量装置的静态特性 2.标准和标准传递 3.测量装置的动态特性 4.测量装置的负载特性 5.测量装置的抗干扰性 1.线性度 2.灵敏度 3.回程误差 4.分辨力 5.零点漂移和灵敏度漂移 3.3.1 动态特性的数学描述 1.传递函数 2.频率响应函数 3.脉冲响应函数 …