T3打卡-天气识别

news2024/11/16 20:52:07
  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

1.导入数据:

#设置GPU
import tensorflow as tf
gpus=tf.config.list_physical_devices("GPU")
if gpus:
    gpu0=gpus[0]
    tf.config.experimental.set_memory_growth(gpu0,True)
    tf.config.set_visibel_devices([gpu0],"GPU")

#导入数据
import os,PIL,pathlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tensorflow import keras
from tensorflow.keras import layers,models

data_dir="data/weather_photos"
data_dir=pathlib.Path(data_dir)

image_count=len(list(data_dir.glob('*/*.jpg')))
print("图片总数为:",image_count)

roses=list(data_dir.glob('sunrise/*.jpg'))
PIL.Image.open(str(roses[0]))

2.数据预处理:

(1)加载数据:

batch_size=32
img_height=180
img_width=180

train_ds=tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height,img_width),
    batch_size=batch_size)

val_ds=tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height,img_width),
    batch_size=batch_size)

class_names=train_ds.class_names
print(class_names)

(2)数据可视化:

plt.figure(figsize=(40,20))
for images,labels in train_ds.take(1):
    for i in range(20):
        ax=plt.subplot(5,10,i+1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis('off')

(3)检查数据并配置数据集:

for image_batch,labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

AUTOTUNE=tf.data.AUTOTUNE
train_ds=train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds=val_ds.cache().prefetch(buffer_size=AUTOTUNE)

3.构建模型(CNN)

num_classes=4
model=models.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255,input_shape=(img_height,img_width,3)),
    layers.Conv2D(16,(3,3),activation='relu',input_shape=(img_height,img_width,3)),
    layers.AveragePooling2D((2,2)),
    layers.Conv2D(32,(3,3),activation='relu'),
    layers.AveragePooling2D((2,2)),
    layers.Conv2D(64,(3,3),activation='relu'),
    layers.Dropout(0.3),
    layers.Flatten(),
    layers.Dense(128,activation='relu'),
    layers.Dense(num_classes)
])
model.summary()

4.编译并训练模型

opt=tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

epochs=10
history=model.fit(train_ds,validation_data=val_ds,epochs=epochs)

5.模型评估:

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

总结:

1. tf.keras.preprocessing.image_dataset_from_directory():

tf.keras.preprocessing.image_dataset_from_directory(
    directory,
    labels="inferred",
    label_mode="int",
    class_names=None,
    color_mode="rgb",
    batch_size=32,
    image_size=(256, 256),
    shuffle=True,
    seed=None,
    validation_split=None,
    subset=None,
    interpolation="bilinear",
    follow_links=False,
)

directory: 数据所在目录。如果标签是inferred(默认),则它应该包含子目录,每个目录包含一个类的图像。否则,将忽略目录结构。
labels: inferred(标签从目录结构生成),或者是整数标签的列表/元组,其大小与目录中找到的图像文件的数量相同。标签应根据图像文件路径的字母顺序排序(通过Python中的os.walk(directory)获得)。
label_mode:
int:标签将被编码成整数(使用的损失函数应为:sparse_categorical_crossentropy loss)。
categorical:标签将被编码为分类向量(使用的损失函数应为:categorical_crossentropy loss)。
binary:意味着标签(只能有2个)被编码为值为0或1的float32标量(例如:binary_crossentropy)。
None:(无标签)。
class_names: 仅当labels为inferred时有效。这是类名称的明确列表(必须与子目录的名称匹配)。用于控制类的顺序(否则使用字母数字顺序)。
color_mode: grayscale、rgb、rgba之一。默认值:rgb。图像将被转换为1、3或者4通道。
batch_size: 数据批次的大小。默认值:32
image_size: 从磁盘读取数据后将其重新调整大小。默认:(256,256)。由于管道处理的图像批次必须具有相同的大小,因此该参数必须提供。
shuffle: 是否打乱数据。默认值:True。如果设置为False,则按字母数字顺序对数据进行排序。
seed: 用于shuffle和转换的可选随机种子。
validation_split: 0和1之间的可选浮点数,可保留一部分数据用于验证。
subset: training或validation之一。仅在设置validation_split时使用。
interpolation: 字符串,当调整图像大小时使用的插值方法。默认为:bilinear。支持bilinear, nearest, bicubic, area, lanczos3, lanczos5, gaussian, mitchellcubic。
follow_links: 是否访问符号链接指向的子目录。默认:False。

2.shuffle():

        打乱数据

3.prefetch():

        预取数据,加速运行

        CPU 正在准备数据时,加速器处于空闲状态。相反,当加速器正在训练模型时,CPU 处于空闲状态。因此,训练所用的时间是 CPU 预处理时间和加速器训练时间的总和。prefetch()将训练步骤的预处理和模型执行过程重叠到一起。当加速器正在执行第 N 个训练步时,CPU 正在准备第 N+1 步的数据。这样做不仅可以最大限度地缩短训练的单步用时(而不是总用时),而且可以缩短提取和转换数据所需的时间。如果不使用prefetch(),CPU 和 GPU/TPU 在大部分时间都处于空闲状态:

4.Dropout():   

tf.keras.layers.Dropout(
    rate, noise_shape=None, seed=None, **kwargs
)

rate:0~1之间的小数。让神经元以一定的概率rate停止工作,提高模型的泛化能力。
noise_shape:1D张量类型,int32表示将与输入相乘的二进制丢失掩码的形状;例如,如果您的输入具有形状(batch_size, timesteps, features),并且你希望所有时间步长的丢失掩码相同,则可以使用noise_shape=[batch_size, 1, features],就是哪一个是1,那么就在哪一维度按照相同的方式dropout,如果没有1就是普通的。
seed:随机种子

 作用:防止过拟合,提高模型的泛化能力。

5.卷积的计算:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2143877.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android OkHttp源码分析(一):为什么OkHttp的请求速度很快?为什么可以高扩展?为什么可以高并发

目录 一、为什么要使用OkHhttp? 在不使用OkHhttp之前,我们都是在使用什么?使用HttpURLConnection,那么我们看看HttpURLConnection发起一次请求,两次请求要花多长时间,而OkHttp花多长时间。HttpURLConnection会比okht…

【银河麒麟高级服务器操作系统实例】tcp_mem分析处理全过程内核参数调优参考

了解更多银河麒麟操作系统全新产品,请点击访问 麒麟软件产品专区:https://product.kylinos.cn 开发者专区:https://developer.kylinos.cn 文档中心:https://documentkylinos.cn 现象描述 系统中出现大量的TCP: out of memory…

Mina protocol - 体验教程

Mina protocol - 体验教程 一、零知识证明( Zero Knowledge Proof )1、零知识证明(ZKP)的基本流程工作流程: 2、zkApp 的优势:3、zkApp 每个方法的编译过程: 二、搭建第一个zkapp先决条件1、下载或者更新 zkApp CLI​2…

基于Springboot美食推荐小程序的设计与实现(源码+数据库+文档)

一.项目介绍 pc端: 支持用户、餐厅老板注册 支持管理员、餐厅老板登录 管理员: 管理员模块维护、 餐厅管理模块维护、 用户管理模块维护、 商品管…

Qt:NULL与nullptr的区别(手写nullptr)

前言 发现还是有人不知道NULL 与nullptr的区别,故写此文章。 正文 对于NULL 先看NULL的源码 我们可以看出这段代码是一个典型的预处理器宏定义块,用于处理 NULL 宏的定义。 先看开头 #if defined (_STDDEF_H) || defined (__need_NULL)这行代码检…

git报错,error: bad signature 0x00000000fatal: index file corrupt

报错 git -c diff.mnemonicprefixfalse -c core.quotepathfalse --no-optional-locks checkout daily --progress error: bad signature 0x00000000 fatal: index file corrupt 原因 git 仓库中索引文损坏 处理 1.该备份的先备份 2.删除索引并重置 rm -f .git/index git r…

医学数据分析实训 项目五 分类分析--乳腺癌数据分析与诊断

文章目录 项目六:分类分析实践目的实践平台实践内容(一)数据理解及准备(二)模型建立、预测及优化任务一:使用 KNN算法进行分类预测任务二:使用贝叶斯分类算法进行分类预测任务三:使用…

Linux基础3-基础工具4(git,冯诺依曼计算机体系结构)

上篇文章:Linux基础3-基础工具3(make,makefile,gdb详解)-CSDN博客 本章重点: 1. git简易使用 2. 冯诺依曼计算机体系结构介绍 一. git使用 1.1 什么是git? git是用于管理代码版本的一种工具,我们在如GitHub&#xf…

C++ | (二)类与对象(上)

燕子去了,有再来的时候;杨柳枯了,有再青的时候;桃花谢了,有再开的时候。但是,聪明的,你告诉我,我们的假期为什么一去不复返呢? 目录 一、初识类 1.1 类的定义 1.2 C中…

面试真题-TCP的三次握手

TCP的基础知识 TCP头部 面试题:TCP的头部是多大? TCP(传输控制协议)的头部通常是固定的20个字节长,但是根据TCP选项(Options)的不同,这个长度可以扩展。TCP头部包含了许多关键的字…

depcheck 检查项目中依赖的使用情况 避免幽灵依赖的产生

depcheck 检查项目中依赖的使用情况 避免幽灵依赖的产生 什么是幽灵依赖 (幻影依赖) 形成原因 幽灵依赖是指node_modules中存在 而package.json中没有声明过的依赖 但却能够在项目的依赖树中找到并使用的模块 Node.js 的模块解析规则: Node.js 采用了一种非传统的模…

C++速通LeetCode简单第20题-多数元素

方法一&#xff1a;暴力解法&#xff0c;放multiset中排序&#xff0c;然后依次count统计&#xff0c;不满足条件的值erase清除。 class Solution { public:int majorityElement(vector<int>& nums) {int ans 0;multiset<int> s;for(int i 0;i < nums.s…

「iOS」viewController的生命周期

iOS学习 ViewController生命周期有关方法案例注意 ViewController生命周期有关方法 init - 初始化程序&#xff1b;loadView - 在UIViewController对象的view被访问且为空的时候调用&#xff1b;viewDidLoad - 视图加载完成后调用&#xff1b;viewWillAppear - UIViewControll…

给大模型技术从业者的建议,入门转行必看!!

01—大模型技术学习建议‍‍‍ 这个关于学习大模型技术的建议&#xff0c;也可以说是一个学习技术的方法论。 首先大家要明白一点——(任何)技术都是一个更偏向于实践的东西&#xff0c;具体来说就是学习技术实践要大于理论&#xff0c;要以实践为主理论为辅&#xff0c;而不…

换个手机IP地址是不是不一样?

在当今这个信息爆炸的时代&#xff0c;手机已经成为我们生活中不可或缺的一部分。而IP地址&#xff0c;作为手机连接网络的桥梁&#xff0c;也时常引起我们的关注。你是否曾经好奇&#xff0c;换个手机&#xff0c;IP地址会不会也跟着变呢&#xff1f;本文将深入探讨这个问题&a…

Android15之编译Cuttlefish模拟器(二百三十一)

简介&#xff1a; CSDN博客专家、《Android系统多媒体进阶实战》一书作者 新书发布&#xff1a;《Android系统多媒体进阶实战》&#x1f680; 优质专栏&#xff1a; Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a; 多媒体系统工程师系列【…

直流斩波电路

目录 1. 降压斩波电路&#xff08;Buck Converter&#xff09; 2. 升压斩波电路&#xff08;Boost Converter&#xff09; 3. 升降压斩波电路&#xff08;Buck-Boost Converter&#xff09; 4. Cuk斩波电路&#xff08;Cuk Converter&#xff09; 直流斩波电路是一种将直流…

Unity3D下如何播放RTSP流?

技术背景 在Unity3D中直接播放RTSP&#xff08;Real Time Streaming Protocol&#xff09;流并不直接支持&#xff0c;因为Unity的内置多媒体组件&#xff08;如AudioSource和VideoPlayer&#xff09;主要设计用于处理本地文件或HTTP流&#xff0c;而不直接支持RTSP。所以&…

上海人工智能实验室开源视频生成模型Vchitect 2.0 可生成20秒高清视频

上海人工智能实验室日前推出的Vchitect2.0视频生成模型正在悄然改变视频创作的游戏规则。这款尖端AI工具不仅简化了视频制作流程&#xff0c;还为创作者提供了前所未有的灵活性和高质量输出。 Vchitect2.0的核心优势在于其强大的生成能力和高度的可定制性。用户只需输入文字描…

用Matlab求解绘制2D散点(x y)数据的最小外接圆、沿轴外接矩形

用Matlab求解绘制2D散点&#xff08;x y&#xff09;数据的最小外接圆、沿轴外接矩形 0 引言1 原理概述即代码实现1.1 最小外接圆1.2 沿轴外接矩形 2 完整代码3 结语 0 引言 本篇简单介绍下散点数据最小外接圆、沿轴外接矩形的简单原理和matlab实现过程。 1 原理概述即代码实现…