深度学习训练营之识别宝可梦人物和角色

news2025/1/15 12:51:26

深度学习训练营之识别宝可梦人物和角色

  • 原文链接
  • 环境介绍
  • 前置工作
    • 设置GPU
    • 数据加载
    • 数据查看
  • 数据预处理
    • 加载数据
    • 可视化数据
    • 检查数据
    • 配置数据集
      • `prefetch()`功能详细介绍:
  • 调用官方的网络的模型
  • 模型训练
    • 官方模型调用
  • 设置动态学习率
  • 模型训练
  • 模型评估
  • 结果分析
  • 参考链接

原文链接

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第P1周:实现mnist手写数字识别
  • 🍖 原作者:K同学啊|接辅导、项目定制

环境介绍

  • 语言环境:Python3.9.13
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2

前置工作

设置GPU

如果

# K同学啊深度学习练习
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")

# 打印显卡信息,确认GPU可用
print(gpus)

数据加载

import os,math
from tensorflow.keras.layers import Dropout, Dense, SimpleRNN
from sklearn.preprocessing   import MinMaxScaler
from sklearn                 import metrics
import numpy             as np
import pandas            as pd
import tensorflow        as tf
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
data = pd.read_csv('SH600519.csv')  # 读取股票文件

调整数据集所在的位置
在这里插入图片描述

将运行的代码和需要运行的数据集放在同一个文件夹,方便数据的导入
在这里插入图片描述

data_dir = "016_Pokemon"

data_dir = pathlib.Path(data_dir)

数据查看

数据集一共分为10个角色,每个角色都会有单独的自带的文件夹保存图片

image_count = len(list(data_dir.glob('*/*')))

print("图片总数为:",image_count)

图片总数为: 219

数据预处理

加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset
这三行代码主要是先对图片进行预处理,设定图片的高度和宽度还有batch_size的大小

batch_size = 8
img_height = 224
img_width = 224

这三行代码划分数据集,train_ds,val_ds

"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

在这里插入图片描述

通过class_names输出数据集的标签,标签按照字母顺序对应于目录名称

class_names = train_ds.class_names
print(class_names)

在这里插入图片描述

可视化数据

plt.figure(figsize=(20, 10))  # 图形的宽为10高为5
plt.suptitle("数据展示")

for images, labels in train_ds.take(1):
    for i in range(8):
        
        ax = plt.subplot(2, 4, i + 1)  
        
        ax.patch.set_facecolor('yellow')
        
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

请添加图片描述

检查数据

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

在这里插入图片描述

  • Image_batch是形状的张量(8,224,224,3)。这是一批形状224x224x3的8张图片(最后一维指的是彩色通道RGB)。
  • Label_batch是形状(8,)的张量,这些标签对应32张图片

配置数据集

  • shuffle():打乱数据,关于此函数的详细介绍可以参考:https://zhuanlan.zhihu.com/p/42417456
  • prefetch():预取数据,加速运行

prefetch()功能详细介绍:

CPU 正在准备数据时,加速器处于空闲状态。相反,当加速器正在训练模型时,CPU 处于空闲状态。因此,训练所用的时间是 CPU 预处理时间和加速器训练时间的总和。prefetch()将训练步骤的预处理和模型执行过程重叠到一起。当加速器正在执行第 N 个训练步时,CPU 正在准备第 N+1 步的数据。这样做不仅可以最大限度地缩短训练的单步用时(而不是总用时),而且可以缩短提取和转换数据所需的时间。如果不使用prefetch(),CPU 和 GPU/TPU 在大部分时间都处于空闲状态
使用该函数的作用就在于尽可能的提高CPU等的使用性能,提高模型训练时候的速度
在这里插入图片描述
使用该函数可以减少空闲时间
在这里插入图片描述

  • cache():将数据集缓存到内存当中,加速运行
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

调用官方的网络的模型

在训练模型当中,选取不同的训练方法都会有不同的训练结果,但是要进行模型的比较的时候,难道要一个一个的训练,一个一个模型进行构建吗?这样实在是太费时间了,看到K同学所分享的内容,我看到了,
我们可以使用官方所给出的方法对模型进行直接的调用.
这样的一个神器就叫做 tf.keras.applications
简单的讲 tf.keras.applications 就是它把我们常用的一些模型,比如VGG-16等进行了封装,我假设想调用某一个模型时(例如:VGG-16),直接调用函数接口就OK了,无需再自己进行构建,当然在这里还是希望我们可以取了解一些模型具体的代码书写过程,这样对我们后期学期有很大的帮助(这话是将给我自己听的)
tf.keras.applications 目前支持的模型及该模型的性能参数如下(知道有这些内容就好):

在这里插入图片描述
TOP-5准确率可以简单的理解为在正确的标签(分类)在这5个类别里之中的概率
这是官网当中模型的相关介绍
常用的三个参数解释如下:

  • include_top:是否包括网络顶部的 3 个全连接层。
  • weights:默认不加载权重文件,"imagenet"加载官方权重文件,或者输入自己的权重文件路径。
  • classes:分类图像的类别数

这三个参数可以根据需要进行修改

常见的其他接口有

  • tf.keras.applications.xception.Xception()
  • tf.keras.applications.vgg16.VGG16()
  • tf.keras.applications.vgg19.VGG19()
  • tf.keras.applications.resnet50.ResNet50()
  • tf.keras.applications.inception_v3.InceptionV3()
  • tf.keras.applications.inception_resnet_v2.InceptionResNetV2()
  • tf.keras.applications.mobilenet.MobileNet()
  • tf.keras.applications.mobilenet_v2.MobileNetV2()
  • tf.keras.applications.densenet.DenseNet121()
  • tf.keras.applications.densenet.DenseNet169()
  • tf.keras.applications.densenet.DenseNet201()
  • tf.keras.applications.nasnet.NASNetMobile()
  • tf.keras.applications.nasnet.NASNetLarge()

这方面的内容还挺有意思,直接通过调用就可以进行训练

模型训练

官方模型调用

K同学啊在调用的过程当中使用的Densenet121的方法,我这里尝试使用其他接口,使用tf.keras.applications.nasnet.NASNetMobile()进行模型的预测,其他内容不变

model = tf.keras.applications.nasnet.NASNetMobile(weights='imagenet')
model.summary()

在这里插入图片描述

设置动态学习率

设置动态学习率也就相对于是说每一次模型优化之后,我都会调整我的学习率,这样的好处就在于,使得学习率的调整更加合理,及可以避免学习率过大而导致的结果不收敛,以及避免学习率过小时进入局部最优解,常见的就有RMSProp
这里使用的是指数衰减学习率导入到优化器当中

# 设置初始学习率
initial_learning_rate = 1e-4

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, 
        decay_steps=5,      # 敲黑板!!!这里是指 steps,不是指epochs
        decay_rate=0.96,     # lr经过一次衰减就会变成 decay_rate*lr
        staircase=True)

# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

模型训练

model.compile(optimizer=optimizer,
              loss     ='sparse_categorical_crossentropy',
              metrics  =['accuracy'])
epochs = 20

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)

在这里插入图片描述

模型评估

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

结果分析

不难看出,这里的运行出来的结果很不理想,说明该方法其实并不适合对于宝可梦角色的预测当中,可能的原因在于模型动态学习率的方法并不适合,毕竟是指数倍的计算,计算量比较大,nasnet本身是一个通过堆叠 网络单元的形式将其迁移到任意分类任务,乃至任意类型的任务中,还有就是nasnet本身更加适合用于强化学习?这方面不是很了解,可以后续进行学习

参考链接

tf.keras.applications API地址

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/387725.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Redis】Redis 如何实现分布式锁

Redis 如何实现分布式锁1. 什么是分布式锁1.1 分布式锁的特点1.2 分布式锁的场景1.3 分布式锁的实现方式2. Redis 实现分布式锁2.1 setnx expire2.2 set ex px nx2.3 set ex px nx 校验唯一随机值,再删除2.4 Redisson 实现分布式锁1. 什么是分布式锁 分布式锁其实…

【C语言进阶:指针的进阶】回调函数

本章重点内容: 字符指针指针数组数组指针数组传参和指针传参函数指针函数指针数组指向函数指针数组的指针回调函数指针和数组面试题的解析什么是回调函数: 回调函数就是一个通过函数指针调用的函数。如果你把函数的指针(地址)作…

Lenovo 联想-IdeaPad-Y530电脑 Hackintosh 黑苹果efi引导文件

原文来源于黑果魏叔官网,转载需注明出处。硬件型号驱动情况主板联想-IdeaPad-Y530处理器Intel 酷睿2双核 T9400已驱动内存2GB已驱动硬盘2TB HP EX950 PCI-E Gen3 x4 NVMe SSD已驱动显卡NVIDIA GeForce 9300M GS无法驱动声卡Realtek ALC888无法驱动网卡RTL8168H Giga…

【Java学习笔记】3.Java 基础语法

Java 基础语法 一个 Java 程序可以认为是一系列对象的集合,而这些对象通过调用彼此的方法来协同工作。下面简要介绍下类、对象、方法和实例变量的概念。 对象:对象是类的一个实例,有状态和行为。例如,一条狗是一个对象&#xff…

【NLP相关】从零开始理解BERT模型:NLP领域的突破(BERT详解与代码实现)

❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博…

Python异常处理更新,正常和不正常的都在这里

嗨害大家好鸭!我是小熊猫~ 异常处理篇嗨害大家好鸭!我是小熊猫~Python标准异常💨什么是异常?不正常异常处理💨使用except而不带任何异常类型使用except而带多种异常类型try-finally 语句异常的参数触发异常用户自定义异…

Lesson12---人工神经网络(1)

12.1 神经元与感知机 12.1.1 感知机 感知机: 1957, Fank Rosenblatt 由两层神经元组成,可以简化为右边这种,输入通常不参与计算,不计入神经网络的层数,因此感知机是一个单层神经网络 感知机 训练法则&am…

MyBatis - 13 - MyBatis逆向工程

文章目录1.准备工作1.1 建表1.2 创建Maven工程1.2.1 在pom.xml中添加依赖和插件,更新maven1.2.2 在src/main/resources下创建mybatis-config.xml1.2.3 在src/main/resources下创建jdbc.properties1.2.4 在src/main/resources下创建log4j.xml文件1.2.5 在src/main/re…

搭建zabbix4.0监控服务实例

一.Zabbix服务介绍 1.1服务介绍 Zabbix是基于WEB界面的分布式系统监控的开源解决方案,Zabbix能够监控各种网络参数,保证服务器系统安全稳定的运行,并提供灵活的通知机制让SA快速定位并解决存在的各种问题。 1.2 Zabbix优点 Zabbix分布式监…

python用openpyxl包操作xlsx文件,统计表中合作电影数目最多的两个演员

题目🎉🎉🎉:编程完成下面任务:已知excel文件“电影导演演员信息表.xlsx”如下图所示:🍳🍳🍳要求:使用 openpyxl 包操作打开此文件,编写程序统计在…

sqlli-labs基本使用

1.安装hackbar插件 链接:https://pan.baidu.com/s/1-QIYmAU-BV_DEONfxovizQ 提取码:dc66 2.SQL注入表信息解析(案例使用的sqlli-labs自带的数据库security) 2.1 通过order by 判断表有多少列 分析表有多少列(通过…

【Storm】【六】Storm 集成 Redis 详解

Storm 集成 Redis 详解 一、简介二、集成案例三、storm-redis 实现原理四、自定义RedisBolt实现词频统计一、简介 Storm-Redis 提供了 Storm 与 Redis 的集成支持&#xff0c;你只需要引入对应的依赖即可使用&#xff1a; <dependency><groupId>org.apache.storm…

红日(vulnstack)2 内网渗透ATTCK实战

环境配置 链接&#xff1a;百度网盘 请输入提取码 提取码&#xff1a;wmsi 攻击机&#xff1a;kali2022.03 web 192.168.111.80 10.10.10.80 自定义网卡8&#xff0c;自定义网卡18 PC 192.168.111.201 10.10.10.201 自定义网卡8&#xff0c;自定义网卡18 DC 192.168.52.1…

【Word/word2007】将标题第1章改成第一章

问题&#xff1a;设置多级列表没有其他格式选的解决办法和带来的插入图注解的问题&#xff0c;将标题第1章改成第一章的问题其他方案。 按照百度搜索的方法设置第一章&#xff0c;可以是没有相应的样式可以选。 那就换到编号选项 设置新的编号值 先选是 然就是变得很丑 这时打开…

数据结构(一)(嵌入式学习)

数据结构干货总结&#xff08;一&#xff09;基础线性表的顺序表示线性表的链式表示单链表双链表循环链表循环单链表循环双链表栈顺序存储链式存储队列队列的定义队列的常见基本操作队列的顺序存储结构顺序队列循环队列队列的链式存储结构树概念二叉树二叉树的创建基础 数据&a…

项目实战典型案例14——代码结构混乱 逻辑边界不清晰 页面美观设计不足

代码结构混乱 逻辑边界不清晰 页面美观设计不足一&#xff1a;背景介绍问题1 代码可读性差&#xff0c;代码结构混乱问题2 逻辑边界不清晰&#xff0c;封装意识缺乏示例3.展示效果上的美观设计二&#xff1a;思路&方案问题一&#xff0c;代码可读性差&#xff0c;代码结构混…

tun驱动之ioctl

struct ifreq ifr; ifr.ifr_flags | IFF_TAP | IFF_NO_PI; ioctl(fd, TUNSETIFF, (void *)&ifr); 上面的代码的意思是设置网卡信息&#xff0c;并将tun驱动设置为TAP模式。在TAP模式下&#xff0c;在用户空间下调用open打开/dev/net/tun驱动文件&#xff0c;发送(调用send函…

C语言不踩坑: 自动类型转换规则

先看一个例程&#xff1a; # include <stdio.h> int main(void) {int a -10;unsigned b 5;if ((ab) > 0){printf("(ab) > 0\n");printf("(ab) %d\n",ab);}else{printf("(ab) < 0\n");}return 0; }运行的结果是&#xff1a; …

svn 分支(branch)和标签(tag)管理

版本控制的一大功能是可以隔离变化在某个开发线上&#xff0c;这个开发线就是分支&#xff08;branch&#xff09;。分支通常用于开发新功能&#xff0c;而不会影响主干的开发。也就是说分支上的代码的编译错误、bug不会对主干&#xff08;trunk&#xff09;产生影响。然后等分…

实现echarts主题随项目主题切换

前言 项目中很多时候都带有dark/light两中主题类型&#xff0c;通过switch标签控制&#xff0c;但是echarts图形是通过canvas标签绘制&#xff0c;其背景颜色和字体样式并不会随着项目主题类型的切换而切换。所以需要额外设置监听主题事件&#xff0c;主要实现思路如下&#x…