T3 TensorFlow入门实战——天气识别

news2024/11/28 5:27:37
  • 🍨 本文為🔗365天深度學習訓練營 中的學習紀錄博客
  • 🍖 原作者:K同学啊 | 接輔導、項目定制

一、前期准备

1. 导入数据

# Import the required libraries
import numpy as np
import os,PIL,pathlib
import matplotlib.pyplot as plt
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers,models


# load the data
data_dir = './data/weather_photos/'
data_dir = pathlib.Path(data_dir)

data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[2] for path in data_paths]
classeNames

2. 查看数据 

image_count = len(list(data_dir.glob('*/*.jpg')))

print("图片总数为:",image_count)

roses = list(data_dir.glob('sunrise/*.jpg'))
PIL.Image.open(str(roses[0]))

二、数据预处理

1. 加载数据

# Data loading and preprocessing
batch_size = 32
img_height = 180
img_width = 180

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

 通过class_names同样可以输出数据集标签

class_names = train_ds.class_names
print(class_names)

 2. 可视化数据

# Visualize the data
plt.figure(figsize=(20, 10))

for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

# Check the shape of the data
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

 3. 配置数据集

  • prefetch()功能详细介绍:CPU 正在准备数据时,加速器处于空闲状态。相反,当加速器正在训练模型时,CPU 处于空闲状态。因此,训练所用的时间是 CPU 预处理时间和加速器训练时间的总和。prefetch()将训练步骤的预处理和模型执行过程重叠到一起。当加速器正在执行第 N 个训练步时,CPU 正在准备第 N+1 步的数据。这样做不仅可以最大限度地缩短训练的单步用时(而不是总用时),而且可以缩短提取和转换数据所需的时间。如果不使用prefetch(),CPU 和 GPU/TPU 在大部分时间都处于空闲状态。
  • cache():将数据集缓存到内存当中,加速运行
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

三、训练模型 

1. 构建CNN网络模型

卷积神经网络(CNN)的输入是张量 (Tensor) 形式的 (image_height, image_width, color_channels),包含了图像高度、宽度及颜色信息。不需要输入batch size。color_channels 为 (R,G,B) 分别对应 RGB 的三个颜色通道(color channel)。在此示例中,我们的 CNN 输入形状是 (180, 180, 3)。我们需要在声明第一层时将形状赋值给参数input_shape

# Define the model architecture
num_classes = 4

"""
关于卷积核的计算不懂的可以参考文章:https://blog.csdn.net/qq_38251616/article/details/114278995

layers.Dropout(0.4) 作用是防止过拟合,提高模型的泛化能力。
在上一篇文章花朵识别中,训练准确率与验证准确率相差巨大就是由于模型过拟合导致的

关于Dropout层的更多介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/115826689
"""

model = models.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
    
    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)), # 卷积层1,卷积核3*3  
    layers.AveragePooling2D((2, 2)),               # 池化层1,2*2采样
    layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3
    layers.AveragePooling2D((2, 2)),               # 池化层2,2*2采样
    layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3
    layers.Dropout(0.3),                           # 让神经元以一定的概率停止工作,防止过拟合,提高模型的泛化能力。
    
    layers.Flatten(),                       # Flatten层,连接卷积层与全连接层
    layers.Dense(128, activation='relu'),   # 全连接层,特征进一步提取
    layers.Dense(num_classes)               # 输出层,输出预期结果
])

model.summary()  # 打印网络结构

 2. 编译模型

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
"""
这里设置优化器、损失函数以及metrics
"""
# model.compile()方法用于在配置训练方法时,告知训练时用的优化器、损失函数和准确率评测标准
opt = tf.keras.optimizers.Adam(learning_rate=0.001)

model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

 3. 训练模型

"""
这里设置输入训练数据集(图片及标签)、验证数据集(图片及标签)以及迭代次数epochs
关于model.fit()函数的具体介绍可参考我的博客:
https://blog.csdn.net/qq_38251616/article/details/122321757
"""
epochs = 10

history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)

四、模型评估

# Evaluate the model
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2248838.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

✨系统设计时应时刻考虑设计模式基础原则

目录 💫单一职责原则 (Single Responsibility Principle, SRP)💫开放-封闭原则 (Open-Closed Principle, OCP)💫依赖倒转原则 (Dependency Inversion Principle, DIP)💫里氏代换原则 (Liskov Substitution Principle, LSP)&#x…

fatal error in include chain (rtthread.h):rtconfig.h file not found

项目搜索这个文件 rtconfig 找到后将其复制粘贴到 你的目录\Keil\ARM\ARMCC\include 应该还有cJSON,rtthread.h和 等也复制粘贴下

【回文数组——另类递推】

题目 代码 #include <bits/stdc.h> using namespace std; using ll long long; const int N 1e510; int a[N], b[N]; int main() {int n;cin >> n;for(int i 1; i < n; i)cin >> a[i];for(int i 1; i < n / 2; i)b[i] a[i] - a[n1-i];ll ans 0;…

SQL基础入门—— 简单查询与条件筛选

在SQL中&#xff0c;查询是从数据库中获取数据的核心操作&#xff0c;而条件筛选是查询中不可或缺的一部分。通过使用条件筛选&#xff0c;我们可以精准地从大量数据中提取我们需要的信息。本节将详细讲解如何使用SQL进行简单查询与条件筛选&#xff0c;包含常见的条件运算符和…

反向代理模块

1 概念 1.1 反向代理概念 反向代理是指以代理服务器来接收客户端的请求&#xff0c;然后将请求转发给内部网络上的服务器&#xff0c;将从服务器上得到的结果返回给客户端&#xff0c;此时代理服务器对外表现为一个反向代理服务器。 对于客户端来说&#xff0c;反向代理就相当于…

英语知识网站:Spring Boot技术构建

6系统测试 6.1概念和意义 测试的定义&#xff1a;程序测试是为了发现错误而执行程序的过程。测试(Testing)的任务与目的可以描述为&#xff1a; 目的&#xff1a;发现程序的错误&#xff1b; 任务&#xff1a;通过在计算机上执行程序&#xff0c;暴露程序中潜在的错误。 另一个…

TMS FNC UI Pack 5.4.0 for Delphi 12

TMS FNC UI Pack是适用于 Delphi 和 C Builder 的多功能 UI 控件的综合集合&#xff0c;提供跨 VCL、FMX、LCL 和 TMS WEB Core 等平台的强大功能。这个统一的组件集包括基本工具&#xff0c;如网格、规划器、树视图、功能区和丰富的编辑器&#xff0c;确保兼容性和简化的开发。…

【AIGC】国内AI工具复现GPTs效果详解

博客主页&#xff1a; [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: AIGC | GPTs应用实例 文章目录 &#x1f4af;前言&#x1f4af;本文所要复现的GPTs介绍&#x1f4af;GPTs指令作为提示词在ChatGPT实现类似效果&#x1f4af;国内AI工具复现GPTs效果可能出现的问题解决方法解决后的效…

网络原理(一):应用层自定义协议的信息组织格式 HTTP 前置知识

目录 1. 应用层 2. 自定义协议 2.1 根据需求 > 明确传输信息 2.2 约定好信息组织的格式 2.2.1 行文本 2.2.2 xml 2.2.3 json 2.2.4 protobuf 3. HTTP 协议 3.1 特点 4. 抓包工具 1. 应用层 在前面的博客中, 我们了解了 TCP/IP 五层协议模型: 应用层传输层网络层…

【es6】原生js在页面上画矩形及删除的实现方法

画一个矩形&#xff0c;可以选中高亮&#xff0c;删除自己效果的实现&#xff0c;后期会丰富下细节&#xff0c;拖动及拖动调整矩形大小 实现效果 代码实现 class Draw {constructor() {this.x 0this.y 0this.disX 0this.disY 0this.startX 0this.startY 0this.mouseDo…

12 —— Webpack中向前端注入环境变量

环境变量的作用&#xff1a;根据不同环境&#xff0c;执行不同的配置 需求&#xff1a;开发模式下打印语句生效&#xff0c;生产模式下打印语句失效 —— 使用Webpack内置的DefinePlugin插件 const webpack require(webpack) module.exports { plugins: [ new webpack.Def…

Learn Git Branching 学习笔记

网址&#xff1a;Learn Git Branching 一、基础篇 1.1 git commit 1.1.1 示例&#xff08;git commit&#xff09; git commit 1.1.2 题目&#xff08;两次提交记录&#xff09; git commit git commit 前 后 1.2 git branch 1.2.1 示例&#xff08;git branch <>、git …

Unity类银河战士恶魔城学习总结(P145 Save Skill Tree 保存技能树)

【Unity教程】从0编程制作类银河恶魔城游戏_哔哩哔哩_bilibili 教程源地址&#xff1a;https://www.udemy.com/course/2d-rpg-alexdev/ 本章节实现了技能树的保存 警告&#xff01;&#xff01;&#xff01; 如果有LoadData&#xff08;&#xff09;和SaveData(&#xff09;…

从 App Search 到 Elasticsearch — 挖掘搜索的未来

作者&#xff1a;来自 Elastic Nick Chow App Search 将在 9.0 版本中停用&#xff0c;但 Elasticsearch 拥有你构建强大的 AI 搜索体验所需的一切。以下是你需要了解的内容。 生成式人工智能的最新进展正在改变用户行为&#xff0c;激励开发人员创造更具活力、更直观、更引人入…

社团管理新工具:SpringBoot框架

3系统分析 3.1可行性分析 通过对本社团管理系统实行的目的初步调查和分析&#xff0c;提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本社团管理系统采用SSM框架&#xff0c;JAVA作为开发语言&#…

Vue3 源码解析(三):静态提升

什么是静态提升 Vue3 尚未发布正式版本前&#xff0c;尤大在一次关于 Vue3 的分享中提及了静态提升&#xff0c;当时笔者就对这个亮点产生了好奇&#xff0c;所以在源码阅读时&#xff0c;静态提升也是笔者的一个重点阅读点。 那么什么是静态提升呢&#xff1f;当 Vue 的编译器…

8款Pytest插件助力Python自动化测试

当测试用例变得复杂&#xff0c;或者需要处理大量测试数据时&#xff0c;插件通过使测试更加简洁和结构化而变得非常有用。Python凭借其简洁性和多功能性&#xff0c;成为自动化测试的热门选择&#xff0c;而pytest是最广泛使用的测试框架之一。虽然pytest本身功能强大&#xf…

【spark-spring boot】学习笔记

目录 说明RDD学习RDD介绍RDD案例基于集合创建RDDRDD存入外部文件中 转换算子 操作map 操作说明案例 flatMap操作说明案例 filter 操作说明案例 groupBy 操作说明案例 distinct 操作说明案例 sortBy 操作说明案例 mapToPair 操作说明案例 mapValues操作说明案例 groupByKey操作说…

Spring Boot 3 集成 Spring Security(2)授权

文章目录 授权配置 SecurityFilterChain基于注解的授权控制自定义权限决策 在《Spring Boot 3 集成 Spring Security&#xff08;1&#xff09;》中&#xff0c;我们简单实现了 Spring Security 的认证功能&#xff0c;通过实现用户身份验证来确保系统的安全性。Spring Securit…

Apache OFBiz xmlrpc XXE漏洞(CVE-2018-8033)

目录 1、漏洞描述 2、EXP下载地址 3、EXP利用 1、漏洞描述 Apache OFBiz是一套企业资源计划&#xff08;ERP&#xff09;系统。它提供了广泛的功能&#xff0c;包括销售、采购、库存、财务、CRM等。 Apache OFBiz还具有灵活的架构和可扩展性&#xff0c;允许用户根据业务需求…