卷积神经网络(CNN)识别神奇宝贝小智一伙

news2024/11/27 22:27:39

文章目录

  • 一、前言
  • 二、前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
    • 2. 导入数据
    • 3. 查看数据
  • 二、数据预处理
    • 1.加载数据
    • 2. 可视化数据
    • 4. 配置数据集
  • 三、调用官方网络模型
  • 四、设置动态学习率
  • 五、编译
  • 六、训练模型
  • 七、模型评估
  • 八、保存and加载模型
  • 九、预测

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别
  • 卷积神经网络(AlexNet)鸟类识别
  • 卷积神经网络(CNN)识别验证码
  • 卷积神经网络(CNN)车牌识别

来自专栏:机器学习与深度学习算法推荐

二、前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")

2. 导入数据

import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

import os,PIL

# 设置随机种子尽可能使结果可以重现
import numpy as np
np.random.seed(1)

# 设置随机种子尽可能使结果可以重现
import tensorflow as tf
tf.random.set_seed(1)

import pathlib
data_dir = "Pokemon"
data_dir = pathlib.Path(data_dir)

3. 查看数据

image_count = len(list(data_dir.glob('*/*')))

print("图片总数为:",image_count)
图片总数为: 219

二、数据预处理

1.加载数据

batch_size = 8
img_height = 224
img_width = 224

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 219 files belonging to 10 classes.
Using 176 files for training.
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 219 files belonging to 10 classes.
Using 43 files for validation.

我们可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。

class_names = train_ds.class_names
print(class_names)
['Alcremie', 'Eevee', 'Furfrou', 'Kyurem', 'Minior', 'Pikachu', 'Rotom', 'Squirtle', 'Vivillon', 'Zygarde']

2. 可视化数据

plt.figure(figsize=(10, 5))  # 图形的宽为10高为5
plt.suptitle("数据展示")

for images, labels in train_ds.take(1):
    for i in range(8):
        
        ax = plt.subplot(2, 4, i + 1)  
        
        ax.patch.set_facecolor('yellow')
        
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

在这里插入图片描述

  1. 再次检查数据
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
(8, 224, 224, 3)
(8,)
  • Image_batch是形状的张量(8, 224, 224, 3)。这是一批形状240x240x3的8张图片(最后一维指的是彩色通道RGB)。
  • Label_batch是形状(8,)的张量,这些标签对应8张图片

4. 配置数据集

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

三、调用官方网络模型

tf.keras.applications API地址:https://www.tensorflow.org/api_docs/python/tf/keras/applications

接口示例:VGG-16官方模型接口

# 这是函数原型,请勿运行该代码
tf.keras.applications.vgg16.VGG16(
    include_top=True, weights='imagenet', input_tensor=None,
    input_shape=None, pooling=None, classes=1000,
    classifier_activation='softmax'
)
<tensorflow.python.keras.engine.functional.Functional at 0x220df3e35f8>

常用的三个参数解释如下:

  • include_top:是否包括网络顶部的 3 个全连接层。
  • weights:默认不加载权重文件,"imagenet"加载官方权重文件,或者输入自己的权重文件路径。
  • classes:分类图像的类别数

其他参数暂时不建议大家了解,模型接口都是非常相似的,大家可以从上面的众多模型中选择自己想要的调用,接口函数如下:

  1. tf.keras.applications.xception.Xception()
  2. tf.keras.applications.vgg16.VGG16()
  3. tf.keras.applications.vgg19.VGG19()
  4. tf.keras.applications.resnet50.ResNet50()
  5. tf.keras.applications.inception_v3.InceptionV3()
  6. tf.keras.applications.inception_resnet_v2.InceptionResNetV2()
  7. tf.keras.applications.mobilenet.MobileNet()
  8. tf.keras.applications.mobilenet_v2.MobileNetV2()
  9. tf.keras.applications.densenet.DenseNet121()
  10. tf.keras.applications.densenet.DenseNet169()
  11. tf.keras.applications.densenet.DenseNet201()
  12. tf.keras.applications.nasnet.NASNetMobile()
  13. tf.keras.applications.nasnet.NASNetLarge()
model = tf.keras.applications.DenseNet121(weights='imagenet')
model.summary()

四、设置动态学习率

这里先罗列一下学习率大与学习率小的优缺点。

  • 学习率大
    • 优点: 1、加快学习速率。 2、有助于跳出局部最优值。
    • 缺点: 1、导致模型训练不收敛。 2、单单使用大学习率容易导致模型不精确。
  • 学习率小
    • 优点: 1、有助于模型收敛、模型细化。 2、提高模型精度。
    • 缺点: 1、很难跳出局部最优值。 2、收敛缓慢。

注意:这里设置的动态学习率为:指数衰减型(ExponentialDecay)。在每一个epoch开始前,学习率(learning_rate)都将会重置为初始学习率(initial_learning_rate),然后再重新开始衰减。计算公式如下:

learning_rate = initial_learning_rate * decay_rate ^ (step / decay_steps)

# 设置初始学习率
initial_learning_rate = 1e-4

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, 
        decay_steps=5,      # 敲黑板!!!这里是指 steps,不是指epochs
        decay_rate=0.96,     # lr经过一次衰减就会变成 decay_rate*lr
        staircase=True)

# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

五、编译

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
model.compile(optimizer=optimizer,
              loss     ='sparse_categorical_crossentropy',
              metrics  =['accuracy'])

六、训练模型

epochs = 20

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)

七、模型评估

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

八、保存and加载模型

这是最简单的模型保存与加载方法哈

# 保存模型
model.save('model/16_model.h5')
# 加载模型
new_model = tf.keras.models.load_model('model/16_model.h5')

九、预测

九、预测
# 采用加载的模型(new_model)来看预测结果

plt.figure(figsize=(10, 5))  # 图形的宽为10高为5
plt.suptitle("预测结果展示")

for images, labels in val_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i + 1)  
        
        # 显示图片
        plt.imshow(images[i].numpy().astype("uint8"))
        
        # 需要给图片增加一个维度
        img_array = tf.expand_dims(images[i], 0) 
        
        # 使用模型预测图片中的人物
        predictions = new_model.predict(img_array)
        plt.title(class_names[np.argmax(predictions)])

        plt.axis("off")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1259435.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux网络——数据链路层

目录 一.认识以太网 二.以太网帧格式 三.认识MAC地址 四.认识MTU 五.以太局域网的通信原理 六.其他重要协议 1.DNS协议 2.域名简介 3.ICMP协议 4.NAT技术 5.NAT技术的缺陷 6.NAT和代理服务器 一.认识以太网 "以太网" 不是一种具体的网络, 而是一种技术标…

C语言入门---位操作

目录 1. 两个数不同的二进制位个数 2.原码、反码、补码 3.不创建临时变量实现两个数的交换 4.求一个整数存储在内存中的二进制中1的个数 5. 特例-1 6.将指定的位置置1 7.将指定位置置1 8.a与a 9.||与&& 10.逗号表达式 11.srand与rand 12.sizeof 13.结构体初始…

一文了解低代码平台

随着数字化转型的加速&#xff0c;企业需要更快速地开发和交付应用程序&#xff0c;以适应市场需求和客户需求的变化。在这种情况下&#xff0c;低代码平台成为了企业的首选方案之一。 想象一下&#xff0c;你可以用一个可视化工具构建自己的应用程序&#xff0c;而无需编写繁琐…

保护IP地址不被窃取的几种方法

随着互联网的普及和信息技术的不断发展&#xff0c;网络安全问题日益凸显。其中&#xff0c;保护个人IP地址不被窃取成为了一个重要的问题。IP地址是我们在互联网上的身份标识&#xff0c;如果被他人获取&#xff0c;就可能导致个人隐私泄露、计算机受到攻击等一系列问题。因此…

你“瞧不起”的拼多多,原来还有这样的一面

有人说&#xff0c;自私是天性&#xff0c;刻印在基因里的本能。也有人持不同意见。 人类学家玛格丽特米德在授课中问学生&#xff0c;文明最早的标志是什么&#xff1f;有人说是陶罐&#xff0c;石器&#xff0c;或者武器&#xff0c;米德告诉他们&#xff0c;是一根愈合的股…

python scoket 多人聊天室 带界面

前言 本来是为了局域网内能够复制段儿代码方便远程调试用的&#xff0c;ssh当然也可以&#xff0c;当然还是我头脑风暴散发&#xff0c;想到这里了。于是从网上拉了一个&#xff0c;改通之后&#xff0c;留一个备份。 期望还是很好的&#xff0c;以后用来支持ubuntu聊天之类的…

新能源钠离子电池污废水如何处理

钠离子电池作为一种新能源电池&#xff0c;已经展示出了广阔的应用前景。然而&#xff0c;随着其生产和使用规模的不断扩大&#xff0c;对其产生的污废水问题也变得越来越重要。如何处理新能源钠离子电池的污废水&#xff0c;已经成为一个必须解决的问题。 首先&#xff0c;我…

第二十五章 解析cfg文件及读取获得网络结构

网络结构 以YOLOv3_SPP为例 cfg文件 部分&#xff0c;只是用来展示&#xff0c;全部的代码在文章最后 [net] # Testing # batch1 # subdivisions1 # Training batch64 subdivisions16 width608 height608 channels3 momentum0.9 de…

基于STM32 +(NVIC)中断概念应用和控制方案

前言 本次我们学习一下STM32的中断控制器—— NVIC&#xff0c;控制着整个STM32芯片中断相关的功能&#xff0c;它跟Cortex-M3 内核紧密联系&#xff0c;是内核里面的一个外设。 本篇博客大部分是自己收集和整理&#xff0c;借鉴了很多大佬的图片和知识点整理&#xff0c;如有侵…

中科大蒋彬课题组开发 FIREANN,分析原子对外界场的响应

内容一览&#xff1a; 使用传统方法分析化学系统与外场的相互作用&#xff0c;具有效率低、成本高等劣势。中国科学技术大学的蒋彬课题组&#xff0c;在原子环境的描述中引入了场相关特征&#xff0c;开发了 FIREANN&#xff0c;借助机器学习对系统的场相关性进行了很好的描述。…

一文读懂:IOPS、延迟和吞吐量等存储性能指标

各位ICT的小伙伴们大家好呀&#xff0c; 在我们谈存储性能的时候&#xff0c;总会听到IOPS、延迟&#xff08;Latency&#xff09;、带宽&#xff08;Bandwidth&#xff09;、吞吐量&#xff08;Throughput&#xff09;以及响应时间&#xff08;Response Time&#xff09;等技…

Lighthouse(灯塔)—— Chrome浏览器强大的性能测试工具

本文浏览器版本参考如下&#xff1a; 一、认识Lighthouse Lighthouse 是 Google 开发的一款工具&#xff0c;用于分析网络应用和网页&#xff0c;收集现代性能指标并提供对开发人员最佳实践的意见。 为 Lighthouse 提供一个需要审查的网址&#xff0c;它将针对此页面运行一连…

Typora+PicGo+Minio搭建博客图床

文章目录 TyporaPicGoMinio搭建博客图床前言什么是图床?为什么需要图床?准备工作一、Typora二、Picgo1. 下载Picgo2. 下载node.js3. 下载minio插件 三、服务器端配置1. 添加端口到安全组2. 使用Docker安装minio3. 配置minio image-20231127175530696四、minio插件配置五、Typ…

Python入门04字符串

目录 1 字符串的定义2 转义字符3 字符串的常见方法4 分割字符串5 字符串反转6 字符串的链式调用7 格式化字符串8 多行字符串总结 1 字符串的定义 在Python中&#xff0c;字符串表示一个字符的序列&#xff0c;比如 str "hello,world"这里我们定义了一个字符串&…

SpringBoot 入门学习

开发环境配置 JDK 1.8、Maven 3.8.8、 IDEA CE 2023.2 框架介绍 Spring Boot 是由 Pivotal 团队提供的全新框架&#xff0c;其设计目的是用来简化 Spring 应用的初始搭建以及开发过程。该框架使用了特定的方式来进行配置&#xff0c;从而使开发人员不再需要定义样板化的配置…

STM32F103C8T6——4路PWM

//main()函数前面的extern TIM_HandleTypeDef htim2;extern TIM_HandleTypeDef htim3;//main()函数内部额外添加的HAL_TIM_Base_Start_IT(&htim2);HAL_TIM_PWM_Start(&htim2,TIM_CHANNEL_1);HAL_TIM_PWM_Start(&htim2,TIM_CHANNEL_2);HAL_TIM_PWM_Start(&htim2…

深度学习中小知识点系列(三) 解读Mosaic 数据增强

前言 Mosaic数据增强&#xff0c;这种数据增强方式简单来说就是把4张图片&#xff0c;通过随机缩放、随机裁减、随机排布的方式进行拼接。Mosaic有如下优点&#xff1a; &#xff08;1&#xff09;丰富数据集&#xff1a;随机使用4张图片&#xff0c;随机缩放&#xff0c;再随…

[ CSS ] 内容超出容器后 以...省略

内容超出容器后 以…省略 当前效果 代码 <template><div class"box">有志者&#xff0c;事竟成&#xff0c;破釜沉舟&#xff0c;百二秦关终属楚; 有心人&#xff0c;天不负&#xff0c;卧薪尝胆&#xff0c;三千越甲可吞吴</div> </templa…

【Proteus仿真】【Arduino单片机】蔬菜大棚温湿度控制系统设计

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真Arduino单片机控制器&#xff0c;使用PCF8574、LCD1602液晶、DHT11温湿度传感器、按键、继电器、蜂鸣器、加热、水泵电机等。 主要功能&#xff1a; 系统运行后&#xff0c;LCD160…

TikTok新媒体战略:数字时代的社交营销

引言 随着数字时代的来临&#xff0c;社交媒体已成为企业推广和品牌建设的关键平台之一。而在众多社交媒体中&#xff0c;TikTok以其独特的短视频形式和庞大的用户基数吸引了无数企业和个人创作者。本文将深入探讨TikTok新媒体战略&#xff0c;探讨在数字时代如何利用这一平台进…