T4周:猴痘病识别

news2024/9/20 19:49:15

>- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg) 中的学习记录博客**
>- **🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)**

1. 设置GPU

如果使用的是CPU可以忽略这步

from tensorflow       import keras
from tensorflow.keras import layers,models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow        as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]                                        #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")
    
gpus

 2. 导入数据

data_dir = "./45-data/"

data_dir = pathlib.Path(data_dir)

3. 查看数据

image_count = len(list(data_dir.glob('*/*.jpg')))

print("图片总数为:",image_count)

图片总数为: 2142 

4.加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

测试集与验证集的关系:

  1. 验证集并没有参与训练过程梯度下降过程的,狭义上来讲是没有参与模型的参数训练更新的。
  2. 但是广义上来讲,验证集存在的意义确实参与了一个“人工调参”的过程,我们根据每一个epoch训练之后模型在valid data上的表现来决定是否需要训练进行early stop,或者根据这个过程模型的性能变化来调整模型的超参数,如学习率,batch_size等等。
  3. 因此,我们也可以认为,验证集也参与了训练,但是并没有使得模型去overfit验证集
batch_size = 32
img_height = 224
img_width = 224
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 2142 files belonging to 2 classes.
Using 1714 files for training.
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 2142 files belonging to 2 classes.
Using 428 files for validation.
class_names = train_ds.class_names
print(class_names)
['Monkeypox', 'Others']

5. 可视化数据

plt.figure(figsize=(20, 10))

for images, labels in train_ds.take(1):
    for i in range(20):
        ax = plt.subplot(5, 10, i + 1)

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

 

6. 再次检查数据

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
(32, 224, 224, 3)
(32,)

7. 配置数据集

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

8.构建CNN网络

num_classes = 2

"""
关于卷积核的计算不懂的可以参考文章:https://blog.csdn.net/qq_38251616/article/details/114278995

layers.Dropout(0.4) 作用是防止过拟合,提高模型的泛化能力。
在上一篇文章花朵识别中,训练准确率与验证准确率相差巨大就是由于模型过拟合导致的

关于Dropout层的更多介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/115826689
"""

model = models.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
    
    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)), # 卷积层1,卷积核3*3  
    layers.AveragePooling2D((2, 2)),               # 池化层1,2*2采样
    layers.Conv2D(32, (3, 3), activation='relu'),  # 卷积层2,卷积核3*3
    layers.AveragePooling2D((2, 2)),               # 池化层2,2*2采样
    layers.Dropout(0.3),  
    layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层3,卷积核3*3
    layers.Dropout(0.3),  
    
    layers.Flatten(),                       # Flatten层,连接卷积层与全连接层
    layers.Dense(128, activation='relu'),   # 全连接层,特征进一步提取
    layers.Dense(num_classes)               # 输出层,输出预期结果
])

model.summary()  # 打印网络结构
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
rescaling (Rescaling)        (None, 224, 224, 3)       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 222, 222, 16)      448       
_________________________________________________________________
average_pooling2d (AveragePo (None, 111, 111, 16)      0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 109, 109, 32)      4640      
_________________________________________________________________
average_pooling2d_1 (Average (None, 54, 54, 32)        0         
_________________________________________________________________
dropout (Dropout)            (None, 54, 54, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 52, 52, 64)        18496     
_________________________________________________________________
dropout_1 (Dropout)          (None, 52, 52, 64)        0         
_________________________________________________________________
flatten (Flatten)            (None, 173056)            0         
_________________________________________________________________
dense (Dense)                (None, 128)               22151296  
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 258       
=================================================================
Total params: 22,175,138
Trainable params: 22,175,138
Non-trainable params: 0
_________________________________________________________________

9.编译

# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)

model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

10.训练模型

from tensorflow.keras.callbacks import ModelCheckpoint

epochs = 50

checkpointer = ModelCheckpoint('best_model.h5',
                                monitor='val_accuracy',
                                verbose=1,
                                save_best_only=True,
                                save_weights_only=True)

history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=epochs,
                    callbacks=[checkpointer])

Epoch 1/50
54/54 [==============================] - 4s 18ms/step - loss: 0.6969 - accuracy: 0.5408 - val_loss: 0.6763 - val_accuracy: 0.6098

Epoch 00001: val_accuracy improved from -inf to 0.60981, saving model to best_model.h5
Epoch 2/50
54/54 [==============================] - 1s 12ms/step - loss: 0.6672 - accuracy: 0.5858 - val_loss: 0.6423 - val_accuracy: 0.6612

......

Epoch 00047: val_accuracy did not improve from 0.87850
Epoch 48/50
54/54 [==============================] - 1s 12ms/step - loss: 0.0953 - accuracy: 0.9691 - val_loss: 0.4090 - val_accuracy: 0.8715

Epoch 00048: val_accuracy did not improve from 0.87850
Epoch 49/50
54/54 [==============================] - 1s 12ms/step - loss: 0.0699 - accuracy: 0.9819 - val_loss: 0.3922 - val_accuracy: 0.8832

Epoch 00049: val_accuracy improved from 0.87850 to 0.88318, saving model to best_model.h5
Epoch 50/50
54/54 [==============================] - 1s 12ms/step - loss: 0.0714 - accuracy: 0.9772 - val_loss: 0.4151 - val_accuracy: 0.8785

Epoch 00050: val_accuracy did not improve from 0.88318

 11. Loss与Accuracy图

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

 

12. 指定图片进行预测

# 加载效果最好的模型权重
model.load_weights('best_model.h5')
from PIL import Image
import numpy as np

# img = Image.open("./45-data/Monkeypox/M06_01_04.jpg")  #这里选择你需要预测的图片
img = Image.open("./45-data/Others/NM15_02_11.jpg")  #这里选择你需要预测的图片
image = tf.image.resize(img, [img_height, img_width])

img_array = tf.expand_dims(image, 0) 

predictions = model.predict(img_array) # 这里选用你已经训练好的模型
print("预测结果为:",class_names[np.argmax(predictions)])
预测结果为: Others

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2130769.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Eclipse折叠if、else、try catch的{}

下载插件com.cb.eclipse.folding_1.0.6.jar。将插件放到eclipse的dropins文件夹中。修改配置,然后保存,重启Eclipse即可。

Flink快速上手

Flink快速上手 批处理Maven配置pom文件java编写wordcount代码 有界流处理无界流处理 批处理 Maven配置pom文件 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://ww…

《深度学习》深度学习 框架、流程解析、动态展示及推导

目录 一、深度学习 1、什么是深度学习 2、特点 3、神经网络构造 1&#xff09;单层神经元 • 推导 • 示例 2&#xff09;多层神经网络 3&#xff09;小结 4、感知器 神经网络的本质 5、多层感知器 6、动态图像示例 1&#xff09;一个神经元 相当于下列状态&…

通信原理:绪论

1、消息、信号与信息 消息&#xff1a; 通信系统要传输的对象&#xff0c;是具体的、物理上存在的东西。也是信息的载体。形式多种&#xff1a; 连续消息&#xff1a;语音、温度、活动图片.离散消息&#xff1a;数据、符号、文字. 信息&#xff1a; 消息中所蕴含的内容&…

proteus+51单片机+实验(LCD1620、定时器)

目录 1.LCD1602液晶显示屏 1.1基本概念 1.1.1LCD的简介 1.1.2LCD的显示原理 ​​​1.1.3LCD的硬件电路 1.1.4LCD的常见指令 1.1.5LCD的时序 ​​​​​​​1.2代码 1.2.1写命令和写数据操作 1.2.2初始化和测试代码 1. 3.3功能函数 1.3proteus代码 1.3.1器件代码 1.…

几种手段mfc140u.dll丢失的解决方法,了解mfc140u.dll

在使用Windows操作系统时&#xff0c;许多用户可能会遇到“找不到mfc140u.dll”或“mfc140u.dll未找到”的错误提示。这个错误通常是由于该文件丢失或损坏所致。本文将详细介绍mfc140u.dll文件的作用、丢失的原因及其解决方法&#xff0c;帮助您快速恢复系统的正常运行。 一、m…

无人机视角的道路损害数据集,2400张图像,包括纵向裂缝(LC)、横向裂缝(TC)、鳄鱼裂缝(AC)、斜裂(OC)、修补(RP)和坑洞(PH),共2.3GB

数据集名称 无人机视角的道路损害数据集 数据集描述 这是一个专注于道路损害检测的数据集&#xff0c;包含了从无人机视角拍摄的2400张高清图像&#xff0c;涵盖了六种典型的道路损害类型&#xff1a;纵向裂缝&#xff08;LC&#xff09;、横向裂缝&#xff08;TC&#xff0…

c++ 点云生成二维俯视图

🙋 结果预览 一、代码实现 #include <pcl/io/pcd_io.h> #include <pcl/point_types.h> #include

S7_1200配方功能快速入门

配方数据文件按照标准 CSV 格式存储在 S7-1200 CPU 装载存储器或 S7-1200 SIMATIC 存储卡“程序卡”中。分别可通过 PLC Web 服务器或对于存储卡文件操作&#xff0c;将数据文件传送到 PC 进行管理和查看。也可将修改过后的配方数据文件上传至PLC&#xff0c;再通过“RecipeImp…

【数据结构】详细介绍各种排序算法,包含希尔排序,堆排序,快排,归并,计数排序

目录 1. 排序 1.1 概念 1.2 常见排序算法 2. 插入排序 2.1 直接插入排序 2.1.1 基本思想 2.1.2 代码实现 2.1.3 特性 2.2 希尔排序(缩小增量排序) 2.2.1 基本思想 2.2.2 单个gap组的比较 2.2.3 多个gap组比较(一次预排序) 2.2.4 多次预排序 2.2.5 特性 3. 选择排…

【AcWing】869. 试除法求约数

约数&#xff1a;当前数能整除这个数。 和判断质数一样的道理&#xff0c;同样是试除法。 约数也一定是成对出现的。在枚举的时候也可以只枚举较小的那一个约数就可以了&#xff0c;较大的那个约数直接算。 #include<iostream> #include<algorithm> #include<…

无人机之处理器篇

无人机的处理器是无人机系统的核心部件之一&#xff0c;它负责控制无人机的飞行、数据处理、任务执行等多个关键功能。以下是对无人机处理器的详细解析&#xff1a; 一、处理器类型 无人机中使用的处理器主要包括以下几种类型&#xff1a; CPU处理器&#xff1a;CPU是无人机的…

JDBC API详解一

DriverManager 驱动管理类&#xff0c;作用&#xff1a;1&#xff0c;注册驱动&#xff1b;2&#xff0c;获取数据库连接 1&#xff0c;注册驱动 Class.forName("com.mysql.cj.jdbc.Driver"); 查看Driver类源码 static{try{DriverManager.registerDriver(newDrive…

中间件常见漏洞

文章目录 中间件漏洞IIS文件解析漏洞1&#xff1a;/xx.asp/xx.jpg 、/xx.asa/xx.jsp2&#xff1a;xx.asp;.jpg3&#xff1a;xx.asa、xx.cer、xx.cdx4&#xff1a;IIS.7/8 CGI配置不当解析漏洞 Apache文件解析漏洞1&#xff1a;apache2.2版本解析漏洞2&#xff1a;其余配置问题…

IMX6 L508EN 模块调试(4G)

一、概述 提起 4G 网络连接&#xff0c;大家可能会觉得是个很难的东西&#xff0c;其实对于嵌入式 Linux 而言&#xff0c;4G 网络连接恰恰相反&#xff0c;不难&#xff01;大家可以看一下其他的嵌入式 Linux 或者 Android 开发板&#xff0c;4G 模块都是 MiniPCIE 接口的&…

C++从入门到起飞之——继承上篇 全方位剖析!

&#x1f308;个人主页&#xff1a;秋风起&#xff0c;再归来~&#x1f525;系列专栏&#xff1a;C从入门到起飞 &#x1f516;克心守己&#xff0c;律己则安 目录 1、继承的概念 2、继承定义 2.1 定义格式 2.2 继承基类成员访问⽅式的变化 3、继承类模板 4、 基…

linux网络编程——UDP编程

写在前边 本文是B站up主韦东山的4_8-3.UDP编程示例_哔哩哔哩_bilibili视频的笔记&#xff0c;其中有些部分博主也没有理解&#xff0c;希望各位辩证的看。 UDP协议简介 UDP 是一个简单的面向数据报的运输层协议&#xff0c;在网络中用于处理数据包&#xff0c;是一种无连接的…

操作系统 ---- 处理机调度

一、处理机调度学习路线 二、调度要研究的问题&#xff1f; 当有一堆任务要处理&#xff0c;但由于资源有限&#xff0c;这些事情没法同时处理。这就需要确定某种规则来决定处理这些任务的顺序&#xff0c;这就是“调度”研究的问题。 三、调度的三个层次 3.1 高级调度&…

深入解读Docker核心原理:Namespace资源隔离机制详解

在容器技术中&#xff0c;资源隔离 是容器化能够实现轻量级虚拟化的关键技术之一。通过资源隔离&#xff0c;容器可以拥有自己的独立环境&#xff0c;确保容器之间互不干扰&#xff0c;从而实现应用的安全和稳定。Docker作为主流的容器平台&#xff0c;其核心的资源隔离机制依赖…

LabVIEW软件授权与分发要求

在LabVIEW开发中&#xff0c;将软件打包成安装程序并销售给其他公司&#xff08;例如对知识产权有严格要求的国外公司&#xff09;时&#xff0c;涉及授权和许可的多个关键环节。NI对LabVIEW的开发、分发、安装和使用都有明确的授权要求&#xff0c;以确保知识产权的合法性和软…