深度学习训练营之鸟类识别

news2025/1/13 13:40:42

深度学习训练营之鸟类识别

  • 原文链接
  • 环境介绍
  • 前置工作
    • 设置GPU
    • 导入数据并进行查找
  • 数据处理
  • 可视化数据
    • 配置数据集
  • 残差网络的介绍
  • 构建残差网络
  • 模型训练
    • 开始编译
  • 结果可视化
    • 训练样本和测试样本
    • 预测

原文链接

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第P8周:实现鸟类识别
  • 🍖 原作者:K同学啊|接辅导、项目定制

环境介绍

  • 语言环境:Python3.9.13
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2

前置工作

设置GPU

使用gpu可以加快运行速度

#设置gpu
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")

导入数据并进行查找

# 导入数据
import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

import os,PIL

# 设置随机种子尽可能使结果可以重现
import numpy as np
np.random.seed(1)

# 设置随机种子尽可能使结果可以重现
import tensorflow as tf
tf.random.set_seed(1)

from tensorflow import keras
from tensorflow.keras import layers,models

import pathlib

查看数据

data_dir = "D:/BaiduNetdiskDownload/第8天-没有加密版本/第8天/bird_photos"

data_dir = pathlib.Path(data_dir)
#3. 查看数据
image_count = len(list(data_dir.glob('*/*')))

print("图片总数为:",image_count)

我们可以知道一共有565张照片

数据处理

文件夹数量
Bananaquit166 张
Black Throated Bushtiti111 张
Black skimmer122 张
Cockatoo166张

对数据进行加载操作
使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset

batch_size = 8
img_height = 224
img_width = 224
"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
#使用当中的452张进行训练

使用当中的452张进行训练

"""
关于image_dataset_from_directory()的详细介绍可以参考文章:https://mtyjkh.blog.csdn.net/article/details/117018789
"""
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

使用113张进行预测操作

# 使用class_name输出数据集的标签
class_names = train_ds.class_names
print(class_names)

在这里插入图片描述

可视化数据

plt.figure(figsize=(10, 5))  # 图形的宽为10高为5
plt.suptitle("无你想你的学习训练")

for images, labels in train_ds.take(1):
    for i in range(8):
        
        ax = plt.subplot(2, 4, i + 1)  

        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        
        plt.axis("off")

在这里插入图片描述

plt.imshow(images[4].numpy().astype("uint8"))

在这里插入图片描述
再次对数据进行检查

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

在这里插入图片描述

  • Image_batch是形状的张量(8, 224, 224, 3)。这是一批形状240x240x3的8张图片(最后一维指的是彩色通道RGB)。
  • Label_batch是形状(8,)的张量,这些标签对应8张图片

配置数据集

  • shuffle() :打乱数据,关于此函数的详细介绍可以参考:https://zhuanlan.zhihu.com/p/42417456
  • prefetch() :预取数据,加速运行,其详细介绍可以参考我前两篇文章,里面都有讲解。
  • cache() :将数据集缓存到内存当中,加速运行
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

残差网络的介绍

残差网络是为了解决神经网络隐藏层过多时,而引起的网络退化问题。退化(degradation)问题是指:当网络隐藏层变多时,网络的准确度达到饱和然后急剧退化,而且这个退化不是由于过拟合引起的。

构建残差网络

from keras import layers

from keras.layers import Input,Activation,BatchNormalization,Flatten
from keras.layers import Dense,Conv2D,MaxPooling2D,ZeroPadding2D,AveragePooling2D
from keras.models import Model

def identity_block(input_tensor, kernel_size, filters, stage, block):

    filters1, filters2, filters3 = filters

    name_base = str(stage) + block + '_identity_block_'

    x = Conv2D(filters1, (1, 1), name=name_base + 'conv1')(input_tensor)
    x = BatchNormalization(name=name_base + 'bn1')(x)
    x = Activation('relu', name=name_base + 'relu1')(x)

    x = Conv2D(filters2, kernel_size,padding='same', name=name_base + 'conv2')(x)
    x = BatchNormalization(name=name_base + 'bn2')(x)
    x = Activation('relu', name=name_base + 'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)
    x = BatchNormalization(name=name_base + 'bn3')(x)

    x = layers.add([x, input_tensor] ,name=name_base + 'add')
    x = Activation('relu', name=name_base + 'relu4')(x)
    return x


def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):

    filters1, filters2, filters3 = filters

    res_name_base = str(stage) + block + '_conv_block_res_'
    name_base = str(stage) + block + '_conv_block_'

    x = Conv2D(filters1, (1, 1), strides=strides, name=name_base + 'conv1')(input_tensor)
    x = BatchNormalization(name=name_base + 'bn1')(x)
    x = Activation('relu', name=name_base + 'relu1')(x)

    x = Conv2D(filters2, kernel_size, padding='same', name=name_base + 'conv2')(x)
    x = BatchNormalization(name=name_base + 'bn2')(x)
    x = Activation('relu', name=name_base + 'relu2')(x)

    x = Conv2D(filters3, (1, 1), name=name_base + 'conv3')(x)
    x = BatchNormalization(name=name_base + 'bn3')(x)

    shortcut = Conv2D(filters3, (1, 1), strides=strides, name=res_name_base + 'conv')(input_tensor)
    shortcut = BatchNormalization(name=res_name_base + 'bn')(shortcut)

    x = layers.add([x, shortcut], name=name_base+'add')
    x = Activation('relu', name=name_base+'relu4')(x)
    return x

def ResNet50(input_shape=[224,224,3],classes=1000):

    img_input = Input(shape=input_shape)
    x = ZeroPadding2D((3, 3))(img_input)

    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)
    x = BatchNormalization(name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    x =     conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

    x =     conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

    x =     conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

    x =     conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')

    x = AveragePooling2D((7, 7), name='avg_pool')(x)

    x = Flatten()(x)
    x = Dense(classes, activation='softmax', name='fc1000')(x)

    model = Model(img_input, x, name='resnet50')
    
    # 加载预训练模型
    model.load_weights("resnet50_weights_tf_dim_ordering_tf_kernels.h5")

    return model

model = ResNet50()
model.summary()

在这里插入图片描述

模型训练

开始编译

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率
# 设置优化器,我这里改变了学习率。
opt = tf.keras.optimizers.Adam(learning_rate=1e-7)

model.compile(optimizer="adam",
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

开始训练

epochs = 10

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)

在这里插入图片描述

结果可视化

训练样本和测试样本

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.suptitle("无你想你的学习空间")

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

请添加图片描述

预测

# 保存模型
model.save('model/my_model.h5')
# 加载模型
new_model = keras.models.load_model('model/my_model.h5')
# 采用加载的模型(new_model)来看预测结果

plt.figure(figsize=(10, 5))  # 图形的宽为10高为5
plt.suptitle("无你想你的学习空间")

for images, labels in val_ds.take(1):
    for i in range(8):
        ax = plt.subplot(2, 4, i + 1)  
        
        # 显示图片
        plt.imshow(images[i].numpy().astype("uint8"))
        
        # 需要给图片增加一个维度
        img_array = tf.expand_dims(images[i], 0) 
        
        # 使用模型预测图片中的人物
        predictions = new_model.predict(img_array)
        plt.title(class_names[np.argmax(predictions)])

        plt.axis("off")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/145649.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习:如何解决类别不平衡问题

类别不平衡是一个常见问题,其中数据集中示例的分布是倾斜的或有偏差的。 1. 简介 类别不平衡是机器学习中的一个常见问题,尤其是在二元分类领域。当训练数据集的类分布不均时会发生这种情况,从而导致训练模型存在潜在偏差。不平衡分类问题的示…

【Unity云消散】理论基础:实现SDF的8SSEDT算法

距离元旦假期已经过去5天了(从31号算起!),接着开始学习! 游戏中的很多渲染效果都离不开SDF,那么SDF究竟是什么呢?到底是个怎么样的技术?为什么能解决那么多问题? 1 SD…

git介绍及环境搭建

git介绍及环境搭建Git介绍Git安装流程配置用户信息git工作流程与常用命令问题点总结主要工作流程git工作流程与原理总结Git介绍 1.Git是什么? Git版本控制系统是一个分布式的系统,是用来保存工程源代码历史状态(游戏存档)的命令行工具 GIT是一个命令行工具,用于版…

基于Java+Spring+vue+element社区疫情服务平台设计和实现

基于JavaSpringvueelement社区疫情服务平台设计和实现 博主介绍:5年java开发经验,专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 超级帅帅吴 Java毕设项目精品实战案例《500套》 欢迎点赞 收藏 ⭐留言 文末获取源…

Django+channels -> websocket

Django+channels -> websocket 学习视频: https://www.bilibili.com/video/BV1J44y1p7NX/?p=10 workon # 查看虚拟环境 mkvirtualenv web -p python3.10 # 创建虚拟环境 workon web # 进入虚拟环境pip insatll django channelsdjango-admin startproject ws_demo python …

【NI Multisim 14.0原理图环境设置——元器件库管理】

目录 序言 一、元器件库管理 🍉1.“元器件”工具栏 🍊(1)电源/信号源库 🍊(2)基本器件库 🍊(3)二极管库 🍊(4)晶体管…

seL4 背景知识

1 seL4 演变 1.1 微内核 微内核发展到目前为止经历了三代, 这里做一些归纳。参考《现代操作系统: 原理与实现》中操作系统结构一章, 关于微内核架构发展的介绍。 第一代微内核设计将许多内核态功能放到用户态, Mach 微内核是第一代微内核的代表。第二代微内核设计将对 IPC 优…

C++学习记录——일 C++入门(1)

C入门(1) 文章目录C入门(1)一、C关键字二、C第一个程序三、命名空间1、域作用限定符2、了解命名空间3、命名空间的使用四、C输入输出五、缺省参数六、函数重载七、引用1、引用符号2、引用的部分使用场景一、C关键字 关键字有98个&…

filebeat采集nginx日志

背景我们公司项目组用的是elastic的一整套技术栈,es,kibana,filebeat和apm,目前已经可以采集网关各个微服务的日志。架构图现在需要在原来的基础上把nginx这的日志也采集上来,方便做链路跟踪问题与思路原先traceId是在…

数字经济时代,“8K+”开拓行业新格局

2023深圳国际8K超高清视频产业发展大会召开,大会以“超清互联 数智创新”为主题,汇聚两院院士、产业领袖、领军企业共同深入探讨超高清产业发展现状、关键问题和未来趋势,并集中发布《深圳市超高清视频显示产业白皮书(2023版&…

「数据密集型系统搭建」开卷篇|什么是数据密集型系统

在我们开发的诸多系统,基本都可以视为“数据密集型系统”,数据是一切物质的载体,我们依靠数据做存储记录,通过数据进行信息传递交换,最终还要数据来呈现和展示等,从一定视角而言,系统中最核心、…

临时用网搞不定?别着急,5G网络“急救车”来啦

如何在1天时间内,用不超过5名装维人员,完成超过200间宿舍的网络覆盖,让即将踏上考场的高三学子们尽快用上网络? 近期,这个问题一直困扰着重庆电信客户经理周睿。原来,由于疫情原因,重庆市某中学…

WINDOWS安装Oracle11.2.0.4

(一)Oracle服务器端安装 1.运行Oracle11g服务器端安装程序setup.exe,弹出如下界面: 2.如上界面中,把默认打上的勾去掉,然后点击【下一步】,弹出如下界面: 3.如上界面中,选择跳过软件更新,然后点击【下一步…

指针进阶(三)再谈数组与串函数

🌞欢迎来到C语言的世界 🌈博客主页:卿云阁 💌欢迎关注🎉点赞👍收藏⭐️留言📝 🌟本文由卿云阁原创! 🌠本阶段属于练气阶段,希望各位仙友顺利完成…

【阶段二】Python数据分析数据可视化工具使用01篇:数据可视化工具介绍、数据可视化工具安装、折线图与柱形图

本篇的思维导图: 数据可视化工具介绍 Matplotlib是最著名的绘图库,主要用于二维绘图,当然也可以进行简单的三维绘图。它提供了一整套丰富的命令,让我们可以非常快捷地用Python可视化数据,而且允许输出达到出版质量的多种图像格式。 Seaborn是在matplo…

国内电容市场份额达七成,松下如何抢占高地?

01 电容市场发展 电容器是三大电子被动元器件之一,是电子线路中不可缺少的基础元件,约占全部电子元件用量的40%,产值的66%。中国电容器行业规模增速持续高于全球规模增速,中国市场的快速增长成为拉动全球电容器行业规模增长的主要…

【Python从入门到进阶】2、Python环境的安装

接上篇《1、初识Python》 上一篇我们对Python这门编程语言进行了一个基本的了解,本篇我们来学习如何下载安装Python编程环境,以及如何使用pip管理Python包。 本篇讲解的是Windows环境下安装Python编程环境的步骤。 一、Python安装包下载 想要使用Pyth…

vue框架、element-ui组件库、font awesome图表库

一、vue 创建一个新vue项目。 vue create ProjectName 然后cd到该目录下,npm run serve启动服务器,即可打开。 二、组件库 element-ui是饿了么的,ArcoDesign是字节的,有很多。 install见官方文档:组件 | Element 导入…

黑马学SpringAMQP

目录: (1)SpringAMQP的基本介绍 (2)SpringAMQP-入门案例的消息发送 (3) SpringAMQP-入门案例的消息接收 (4)SpringAMQP-WorkQueue模型 (5)Sp…

408数据结构考点总结

第一章 绪论 考点 1:时间复杂度与空间复杂度 时间复杂度 定义:将算法中基本运算的执行次数的数量级作为时间复杂度,记为O(n)O(n)O(n)。 计算原则 加法法则:T(n)T1(n)T2(n)O(f(n))O(g(n))O(max⁡(f(n),g(n)))T(n)T_{1}(n)T_{2…