T4—猴痘识别

news2025/1/20 22:10:11
  • 🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊

1.导入数据

#设置gpu
from tensorflow import keras
from tensorflow.keras import layers,models
import os, PIL, pathlib
import matplotlib.pyplot as plt
import tensorflow  as tf
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    gpu0 = gpus[0]                                       
    tf.config.experimental.set_memory_growth(gpu0, True)  
    tf.config.set_visible_devices([gpu0],"GPU")
gpus

data_dir="data/45-data/"
data_dir=pathlib.Path(data_dir)

image_count=len(list(data_dir.glob('*/*.jpg')))
print("图片的总数为:",image_count)

Monkeypox=list(data_dir.glob('Monkeypox/*.jpg'))
PIL.Image.open(str(Monkeypox[0]))

2.加载数据

batch_size=32
img_height=224
img_width=224

train_ds=tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height,img_width),
    batch_size=batch_size)

val_ds=tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height,img_width),
    batch_size=batch_size)

3.数据可视化

class_names=train_ds.class_names
print(class_names)

plt.figure(figsize=(20,10))

for images,labels in train_ds.take(1):
    for i in range(20):
        ax=plt.subplot(5,10,i+1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

4.检查数据

for image_batch,labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

5.配置数据与构建模型

AUTOTUNE=tf.data.AUTOTUNE
train_ds=train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds=val_ds.cache().prefetch(buffer_size=AUTOTUNE)

num_classes=2
model = models.Sequential([
    layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
    
    layers.Conv2D(16, (3, 3), activation='relu', input_shape=(img_height, img_width, 3)), 
    layers.AveragePooling2D((2, 2)),               
    layers.Conv2D(32, (3, 3), activation='relu'),  
    layers.AveragePooling2D((2, 2)),               
    layers.Dropout(0.3),  
    layers.Conv2D(64, (3, 3), activation='relu'),  
    layers.Dropout(0.3),  
    
    layers.Flatten(),                       
    layers.Dense(128, activation='relu'),  
    layers.Dense(num_classes)               
])

model.summary()  

6.编译并训练模型

opt=tf.keras.optimizers.Adam(learning_rate=1e-4)
model.compile(optimizer=opt,loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])

from tensorflow.keras.callbacks import ModelCheckpoint
epochs=50
checkpointer=ModelCheckpoint('best_model.h5',
                             monitor='val_accuracy',
                             verbose=1,
                             save_weights_only=True)
history=model.fit(train_ds,validation_data=val_ds,epochs=epochs,callbacks=[checkpointer])

7.结果可视化

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

8.预测

model.load_weights('best_model.h5')
from PIL import Image
import numpy as np

img=Image.open("data/45-data/Others/NM15_02_11.jpg")
image=tf.image.resize(img,[img_height,img_width])
img_array=tf.expand_dims(image,0)

predictions=model.predict(img_array)
print("预测结果为:",class_names[np.argmax(predictions)])

总结:

1.shuffle()函数:

首先,Dataset会取所有数据的前buffer_size数据项,填充 buffer,如下图

然后,从buffer中随机选择一条数据输出,比如这里随机选中了item 7,那么bufferitem 7对应的位置就空出来了

然后,从Dataset中顺序选择最新的一条数据填充到buffer中,这里是item 10

然后在从Buffer中随机选择下一条数据输出。

需要说明的是,这里的数据项item,并不只是单单一条真实数据,如果有batch size,则一条数据项item包含了batch size条真实数据。

shuffle是防止数据过拟合的重要手段,然而不当的buffer size,会导致shuffle无意义

2.prefetch() :预取数据,加速运行

    CPU 正在准备数据时,加速器处于空闲状态。相反,当加速器正在训练模型时,CPU 处于空闲状态。因此,训练所用的时间是 CPU 预处理时间和加速器训练时间的总和。prefetch()将训练步骤的预处理和模型执行过程重叠到一起。当加速器正在执行第 N 个训练步时,CPU 正在准备第 N+1 步的数据。这样做不仅可以最大限度地缩短训练的单步用时(而不是总用时),而且可以缩短提取和转换数据所需的时间。如果不使用prefetch(),CPU 和 GPU/TPU 在大部分时间都处于空闲状态。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2153551.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【刷题—双指针】复写0、三数之和、四数之和

目录 一、复写0二、三数之和三、四数之和 一、复写0 题目: 注意:题目要求是原数组上复写 思路: 一、确定最后一个复写的位置。定义两个变量cur等于0,dest等于-1,让cur去遍历数组。如果cur指向的元素是0,…

[SAP ABAP] 创建数据元素

我们可以使用事务码SE11创建数据元素 输入要创建的数据类型的名称,然后点击创建 选择数据元素并进行确定 输入简短描述并为数据元素分配一个域,会自动带出数据类型以及长度 创建域可参考该篇文章 创建域https://blog.csdn.net/Hudas/article/details/…

Hive企业级调优[6]——HQL语法优化之任务并行度

目录 HQL语法优化之任务并行度 优化说明 Map端并行度 Reduce端并行度 优化案例 HQL语法优化之任务并行度 优化说明 对于分布式计算任务来说,设置一个合理的并行度至关重要。Hive的计算任务依赖于MapReduce框架来完成,因此并行度的调整需要从Map端和…

YOLOv10 简介

YOLOv10,由清华大学的研究人员基于 Ultralytics Python 包构建,引入了一种全新的实时目标检测方法,该方法解决了以往 YOLO 版本中后处理和模型架构方面的不足。通过消除非极大值抑制(NMS)并优化各种模型组件&#xff0…

C# 技巧在 foreach 循环中巧妙获取索引

目录 前言 使用 LINQ 和扩展方法 直接在 LINQ 查询中使用 使用 LINQ 的 Select() 与 Enumerable.Range() 总结 最后 前言 在C#中foreach 循环是处理集合的常见方式,因其简洁性和易读性而广受青睐。 但是在某些情况下,我们需要同时获取集合中元素的…

mysql中的json查询

首先来构造数据 查询department里面name等于研发部的数据 查询语句跟普通的sql语句差不多,也就是字段名要用到path表达式 select * from user u where u.department->$.name 研发部 模糊查询 select * from user u where u.department->$.name like %研发%…

论文阅读笔记:Sapiens: Foundation for Human Vision Models

Sapiens: Foundation for Human Vision Models 1 背景1.1 问题1.2 目标 2 方法3 创新点4 模块4.1 Humans-300M数据集4.2 预训练4.3 2D位姿估计4.4 身体部位分割4.5 深度估计4.6 表面法线估计 5 实验5.1 实现细节5.2 2D位姿估计5.3 身体部位分割5.4 深度估计5.5 表面法线估计5.6…

如何在 Ubuntu 上安装 OpenSSH Server ?

OpenSSH 是一组工具集合,它允许您使用 SSH 在网络上进行安全、加密的通信。它包括安全远程登录、文件传输和应用程序隧道的特性。OpenSSH 通常用于 Linux 系统上的安全远程访问和文件传输。由于其强大的安全措施,受到世界各地许多用户的信任。 本教程将…

【Rust练习】16.模式

文章题目来自:https://practice-zh.course.rs/pattern-match/patterns.html 1 🌟🌟 使用 | 可以匹配多个值, 而使用 … 可以匹配一个闭区间的数值序列 fn main() {} fn match_number(n: i32) {match n {// 匹配一个单独的值1 > println!(…

2024 研究生数学建模竞赛(E题)建模秘籍|高速公路应急车道紧急启用模型|文章代码思路大全

铛铛!小秘籍来咯! 小秘籍团队独辟蹊径,运用聚类分析,逻辑回归模型,决策树/规则,ARIMA等强大工具,构建了这一题的详细解答哦! 为大家量身打造创新解决方案。小秘籍团队,始…

Python基于TensorFlow实现Transformer分类模型(Transformer分类算法)项目实战

说明:这是一个机器学习实战项目(附带数据代码文档视频讲解),如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 随着深度学习技术的发展,自然语言处理(NLP)领域取得了显著的进步。…

人工智能开发实战推荐算法应用解析

内容导读 个性化推荐思路推荐算法分类推荐效果评估 一、个性化推荐思路 推荐系统能为你提供个性化的智能服务,是基于以下事实认知:人们倾向于喜欢那些与自己喜欢的东西相似的其它物品,或倾向于与自己趣味相投的人有相似的爱好,…

9.安卓逆向-安卓开发基础-安卓四大组件2

免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 内容参考于:图灵Python学院 本人写的内容纯属胡编乱造,全都是合成造假,仅仅只是为了娱乐,请不要盲目相信。 工…

mysql为什么建议创建字段的时候not null

相信大家在建表或者给表新加字段的时候,一些老司机们都会建议我们,字段要定义为not null,原因呢是一是占用存储空间,另一个是避免出现一些意料之外的错误。当然针对这个问题,大家可能也会在网上去搜下,不过…

C语言中的一些小知识(三)

一、你了解printf()吗? 你知道下面代码的输出结果吗? int a123; printf("%2d \n",a); printf() 函数是 C 语言中用于格式化输出的标准函数,它允许你将数据以特定的格式输出到标准输出设备(通常是屏幕)。p…

20240921全国计算机二级Python考试(大头博士计算二级)

一、背景需求: 20240921我在上海应用技术大学44号楼考场参加2024年9月的全国计算机二级(Python语言程序设计)考试。 时隔多年,再次来到大学校园,恍若隔世 扫码找考场在哪里 考场须知 1、进考场,先刷身份证…

有关elementui form验证问题,有值却仍然显示不通过

参考链接 有关elementui form验证问题,有值却仍然显示不通过 - 一棵写代码的柳树 - 博客园 需要保证表单上的 :model" "和prop" "对应的属性相同 el-form 绑定数据:model 和 规则:rules input 绑定 数据表单里的数据 其父组件提供校验所绑定的…

Mybatis的XML实现方法

Mybatis的开发有两种方式: 1、注解 2、XML 使用Mybatis的注解方式,主要是来完成一些简单的增删改查功能。如果需要实现复杂的SQL功能,建议使用XML来配置映射语句,也就是将SQL语句写在XML配置文件中。 Mybatis的XML的实现需要以下…

leetcode练习 二叉树的最大深度

给定一个二叉树 root ,返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 示例 1: 输入:root [3,9,20,null,null,15,7] 输出:3提示: 树中节点的数量在 [0, 104] 区间内。-100 …

java使用ByteBuffer进行多文件合并和拆分

1.背景 因为验证证书的需要,需要把证书文件和公钥给到客户,考虑到多个文件交互的不便性,所以决定将2个文件合并成一个文件交互给客户。刚开始采用字符串拼接2个文件内容,但是由于是加密文件,采用字符串形式合并后&…