0基础深度学习项目13:基于TensorFolw实现天气识别

news2024/11/13 15:30:52
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目录

  • 一、创建环境
  • 二、前期准备
    • 2.1 设置GPU
    • 2.2 导入数据
    • 2.3 数据预处理
      • 2.3.1 加载数据
      • 2.3.2 查看图像的标签
    • 2.4 数据可视化
  • 三、构建简单的CNN网络(卷积神经网络)
    • 3.1 网络结构
    • 3.2 编译模型
  • 四、训练模型
  • 五、模型预测
  • 六、总结

一、创建环境

🏡我的环境:
● 语言环境:Python3.8
● 深度学习环境:TensorFlow2
运行环境: colab

二、前期准备

2.1 设置GPU

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]                                        #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")

2.2 导入数据

import os,PIL,pathlib
import matplotlib.pyplot as plt
import numpy             as np
from tensorflow          import keras
from tensorflow.keras    import layers,models

data_dir = "/content/drive/MyDrive/k-data/weather_photos"
data_dir = pathlib.Path(data_dir)

image_count = len(list(data_dir.glob('*/*.jpg')))
print("图片总数为:",image_count)

输出:
在这里插入图片描述

roses = list(data_dir.glob('sunrise/*.jpg'))
PIL.Image.open(str(roses[0]))

输出:
在这里插入图片描述

2.3 数据预处理

2.3.1 加载数据

image_dataset_from_directory 函数是 TensorFlow 库中的一个工具函数,用于从文件系统中的目录创建图像数据集。这个函数可以自动读取指定目录中的图像文件,并按照一定的规则将它们组织成批次(batch)

batch_size = 32
img_height = 180
img_width = 180

# 使用 TensorFlow 库中的 `tf.keras.preprocessing.image_dataset_from_directory()` 函数来创建一个测试图像数据集
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir, # 图像文件的目录路径
    validation_split=0.2, # 指定了从数据集中分割出20%的数据用作验证集
    subset="training", # 返回的数据集是用于训练的子集
    seed=123, # 用于确保数据集的分割是可重复的,即每次运行代码时,数据的分割方式都是相同的
    image_size=(img_height, img_width),
    batch_size=batch_size)

# 验证集
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

2.3.2 查看图像的标签

class_names = train_ds.class_names
print(class_names)

输出:
在这里插入图片描述

2.4 数据可视化

plt.figure(figsize=(20, 10))

for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

输出:
在这里插入图片描述

三、构建简单的CNN网络(卷积神经网络)

3.1 网络结构

在这里插入图片描述

⭐卷积层
卷积层是卷积神经网络(CNN)中的基础组件,它通过一组可学习的过滤器(或内核)在输入数据(如图像)上进行卷积操作,以提取局部特征和模式;这些过滤器在整个输入数据上滑动,并通过元素相乘和求和生成特征图,每个特征图代表输入数据在特定过滤器下的特征响应,整个过程通过权值共享和局部连接减少参数数量,使得网络能够高效地学习图像的层次化表示。

⭐池化层
在图像处理中,由于图像中存在较多冗余信息,可用某一区域子块的统计信息(如最大值或均值等)来刻画该区域中所有像素点呈现的空间分布模式,以替代区域子块中所有像素点取值,这就是卷积神经网络中池化(pooling)操作。

池化层可对提取到的特征信息进行降维,实现下采样,同时保留了特征图中主要信息,一方面使特征图变小,简化网络计算复杂度;另一方面进行特征压缩,提取主要特征,增加平移不变性,减少过拟合风险。但其实池化更多程度上是一种计算性能的一个妥协,强硬地压缩特征的同时也损失了一部分信息。

num_classes = 4

model = models.Sequential([
    
    layers.Rescaling(1./255, input_shape=(img_height, img_width, 3)),

    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)), # 卷积层1,卷积核3*3  
    layers.AveragePooling2D((2, 2)),               # 池化层1,2*2采样
    layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3
    layers.AveragePooling2D((2, 2)),               # 池化层2,2*2采样
    layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3
    layers.Dropout(0.3),                           # 让神经元以一定的概率停止工作,防止过拟合,提高模型的泛化能力。
    
    layers.Flatten(),                       # Flatten层,连接卷积层与全连接层
    layers.Dense(128, activation='relu'),   # 全连接层,特征进一步提取
    layers.Dense(num_classes)               # 输出层,输出预期结果
])

model.summary()  # 打印网络结构

输出:
在这里插入图片描述

3.2 编译模型

● 损失函数(loss):用于衡量模型在训练期间的准确率。
● 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
● 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。

# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=0.001)

model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

四、训练模型

  • history: 用来存储训练过程中的相关信息,比如损失(loss)和准确率(accuracy)等指标。在训练结束后,可以通过这个变量来分析模型的训练效果。

  • model.fit(...): 调用模型的fit方法,用于训练模型。fit方法接收多个参数,其中最重要的包括训练数据和标签,以及一些训练配置。

  • train_imagestrain_labels: 这两个参数分别代表训练数据和对应的标签。train_images是模型训练时使用的特征数据,而train_labels是这些特征数据对应的正确输出标签。

  • epochs=10: 这个参数指定了训练过程中要进行的迭代次数,也就是整个训练数据集要被模型遍历多少次。

  • validation_data=(test_images, test_labels): 这个参数提供了在每个epoch结束后用于验证模型性能的数据。test_imagestest_labels分别代表测试数据和标签。这样,模型在每个epoch训练结束后都会在测试集上评估一次,以监控模型的泛化能力。

epochs = 10

history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)

输出:
在这里插入图片描述

五、模型预测

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

输出:
在这里插入图片描述

六、总结

本文将采用CNN实现多云、下雨、晴、日出四种天气状态的识别。较上篇文章,本文为了增加模型的泛化能力,新增了Dropout层并且将最大池化层调整成了平均池化层。

Dropout是一种正则化技术,用于通过在训练过程中随机丢弃(即暂时移除)网络中的一些神经元来减少过拟合。

Dropout的工作原理

  • 随机丢弃:在每次训练迭代中,Dropout以一定的概率(通常是一个超参数,如0.5)随机选择网络中的神经元,并将其暂时从网络中移除。这意味着这些神经元在当前迭代中不会对损失函数的计算和梯度更新产生影响。
  • 保留激活:被保留的神经元将正常参与到前向传播和反向传播过程中。
  • 权重重缩放:由于Dropout在训练过程中减少了神经元的数量,因此通常需要对剩余神经元的权重进行重缩放,以保持网络的表达能力。这通常通过在前向传播过程中对激活值进行缩放来实现。

Dropout的优点

  • 减少过拟合:通过随机丢弃神经元,模型不能过度依赖任何单一的神经元,这迫使网络学习更加鲁棒的特征表示。
  • 模型平均:Dropout可以看作是训练多个不同架构的神经网络并进行集成的一种形式,因为每次迭代中被丢弃的神经元都不同,从而相当于训练了多个不同的模型。
  • 计算效率:尽管Dropout增加了一些计算复杂性,但它不需要额外的参数或存储空间,因此在计算资源有限的情况下仍然是一种实用的技术。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2069143.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

NVDLA专题12:具体模块介绍——RUBIK

概述 RUBIK类似于BDMA,它无需任何数据计算对数据映射格式进行转换。RUBIK有3种工作模式,分别是: 合并(Contract)数据立方体将特征数据立方体分割为多平面(multi-planar)格式将多平面(multi-planar)格式合并到数据立方体 由于该…

第三十八篇-TeslaP40-SenseVoice部署,速速杠杠的

环境 系统:CentOS-7 CPU: 14C28T 内存:32G 显卡:Tesla P40 24G 驱动: 535 CUDA: 12.2创建环境 conda create -n sv python3.11 -y conda activate sv克隆 git clone https://github.com/FunAudioLLM/SenseVoice.git cd SenseVoice配置镜像…

React antd Table表格动态合并单元格

注意: ① 采用的是React antDsign 4.x版本 ② 需重新处理data数据 实现效果 代码实现 import React from react; import { Table } from antd;const data [{key: 0,name: 张三,age: 22,sex: 男,},{key: 1,name: 李四,age: 42,sex: 男,},{key: 2,name: 小丽,age: …

CAN的协议层介绍

一,CAN帧种类介绍 1. 数据帧(Data Frame):数据帧是CAN总线上用于传输用户数据的帧,包括必要的帧头、标识符、控制位、数据长度代码、数据域、CRC校验码和应答域等部分,是CAN通信中最基本和最重要的帧类型。…

Android Room DataBase

Room数据库是在Sqlite的基础上,进行了封装和优化。这让我们可以摆脱,繁琐的数据库操作 在module的gradle里面,加入: dependencies {annotationProcessor "androidx.room:room-compiler:2.3.0"implementation androidx.room:room-…

Selenium自动化测试 常见API的使用

本篇文章内容是关于 Selenium 自动化测试工具的常见 API 的使用 Selenium版本:4.23.1 编程语言:Java JDK22 编译器:IDEA 2024.2.0.2 浏览器版本:谷歌浏览器128.0.6613.36(正式版本) (64 位&…

【Hexo】hexo-butterfly主题添加装备展示页面

本文首发于 ❄️慕雪的寒舍 在翻开往的时候看到了一位老哥的博客里面正好有这个教程,整了一下发现效果还不错! Hexo的Butterfly魔改教程:我的装备,分享你在用的设备 | 张洪HeoHexo博客添加自定义css和js文件 | Leonus 注&#x…

Python个人收入影响因素模型构建:回归、决策树、梯度提升、岭回归|数据分享...

全文链接:https://tecdat.cn/?p37423 分析师:Greata Xie “你的命运早在出生那一刻起便被决定了。”这样无力的话语,无数次在年轻人的脑海中回响,尤其是在那些因地域差异而面临教育资源匮乏的年轻人中更为普遍。在中国&#xff0…

NRC-SIM:基于Node-RED的多级多核缓存模拟器

整理自: 《NRC-SIM: A NODE-RED Based Multi-Level, Many-Core Cache Simulator》,由 Ezequiel Trevio 撰写,作为他在德克萨斯大学里奥格兰德河谷分校攻读电气工程硕士学位的部分成果。以下是论文的详细主要内容: 摘要(Abstract…

全网最适合入门的面向对象编程教程:37 Python常用复合数据类型-列表和列表推导式

全网最适合入门的面向对象编程教程:37 Python 常用复合数据类型-列表和列表推导式 摘要: 在 Python 中,列表是一个非常灵活且常用的复合数据类型。它允许存储多个项,这些项可以是任意的数据类型,包括其他列表。列表推…

大话MoE混合专家模型

MoE(Mixture of Experts),专家混合,就像是人工智能界的超级团队。想象一下,每个专家都有自己的拿手好戏,比如医疗问题找医生,汽车故障找机械师,做饭找大厨。MoE也是这样,…

【前端面试】操作系统

进程与线程 进程线程定义是计算机中的程序关于某数据集合上的一次运行活动,是系统进行资源分配和调度的基本单位是进程中的一个实体,是CPU调度和分派的基本单位,共享进程的资源资源分配拥有独立的内存空间和系统资源共享进程的内存和资源开销…

【Harmony OS 4.0】像素单位 - px、vp、fp

1. px 物理像素,以像素个数来定义图像尺寸。弊端是,在不同像素密度的屏幕上,相同的像素个数对应的物理尺寸是不同的。就会导致我们的应用在不同设备上显示的尺寸可能不同。如下图: 2. vp(Virtual Pixel) 虚拟像素是一种可根据屏幕…

L-Eval:一个60k左右长文评测数据集

前言 L-Eval是复旦大学邱锡鹏老师团队在 2023 年 7 月左右发布的一个标准化的长文本语言模型(LCLMs)评估数据集,包含20个子任务、411篇长文档、平均长度为7217个单词,超过2000个人工标记的QA对。它分为封闭型任务和开放型任务&am…

Niushop商城第三方插件cps联盟_同城配送_上门预约上手教程配置方法适合单商户和多商户以及V6哈

Niushop商城第三方插件cps联盟_同城配送_上门预约上手教程配置方法 序言:Niushop里面插件比较多可以说有上百种, 不过大多数都是官方自研默认自带50余种剩余的是收费的价格在80-299不等,另外的插件就是和第三方合作,简单的说就是…

25届应届网安面试,默认页面信息泄露

吉祥知识星球http://mp.weixin.qq.com/s?__bizMzkwNjY1Mzc0Nw&mid2247485367&idx1&sn837891059c360ad60db7e9ac980a3321&chksmc0e47eebf793f7fdb8fcd7eed8ce29160cf79ba303b59858ba3a6660c6dac536774afb2a6330#rd 《网安面试指南》http://mp.weixin.qq.com/s?…

linux系统使用yum安装mysql5.6版本的流程

1.下载安装包及依赖包 MySQL :: Download MySQL Community Server (Archived Versions) [rootlocalhost localrepo]# ls MySQL-client-5.6.47-1.el7.x86_64.rpm MySQL-server-5.6.47-1.el7.x86_64.rpm MySQL-test-5.6.47-1.el7.x86_64.rpm MySQL-devel-5.6.47-1.…

如何关闭谷歌浏览器后台运行

当谷歌浏览器不再需要时仍处于后台运行的状态,这不仅消耗宝贵的系统资源,还会影响到多任务的处理效率。本文将为大家详细介绍关闭谷歌浏览器后台还在运行的原因,并提供详细步骤帮助大家禁用后台运行。(本文由https://www.liulanqi…

【FESCO福利专区-注册安全分析报告-无验证方式导致安全隐患】

前言 由于网站注册入口容易被黑客攻击,存在如下安全问题: 1. 暴力破解密码,造成用户信息泄露 2. 短信盗刷的安全问题,影响业务及导致用户投诉 3. 带来经济损失,尤其是后付费客户,风险巨大,造…

无线液位变送器的特点优势

无线液位变送器集成了多种先进功能,广泛应用于消防水车、水厂、污水处理厂、城市供水、高楼水池、水井、水塔、地热井、矿井等领域的液位监测,具有以下几个显著特点: 4G远程通信能力:无线液位变送器通过内置的4G模块,能…