卷积神经网络(Inception-ResNet-v2)交通标志识别

news2025/4/19 16:56:57

文章目录

  • 一、前言
  • 二、前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
    • 2. 导入数据
    • 3. 查看数据
  • 二、构建一个tf.data.Dataset
    • 1.加载数据
    • 2. 配置数据集
  • 三、构建Inception-ResNet-v2网络
    • 1.自己搭建
    • 2.官方模型
  • 五、设置动态学习率
  • 六、训练模型
  • 七、模型评估
  • 八、模型的保存与加载
  • 九、预测

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

  • 卷积神经网络(CNN)实现mnist手写数字识别
  • 卷积神经网络(CNN)多种图片分类的实现
  • 卷积神经网络(CNN)衣服图像分类的实现
  • 卷积神经网络(CNN)鲜花识别
  • 卷积神经网络(CNN)天气识别
  • 卷积神经网络(VGG-16)识别海贼王草帽一伙
  • 卷积神经网络(ResNet-50)鸟类识别
  • 卷积神经网络(AlexNet)鸟类识别
  • 卷积神经网络(CNN)识别验证码

来自专栏:机器学习与深度学习算法推荐

二、前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")

2. 导入数据

import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

import os,PIL,pathlib

# 设置随机种子尽可能使结果可以重现
import pandas as pd
import numpy  as np
np.random.seed(1)

# 设置随机种子尽可能使结果可以重现
import tensorflow as tf
tf.random.set_seed(1)

from tensorflow import keras
from tensorflow.keras import layers,models
# 导入图片数据
pictures_dir = "images"
pictures_dir = pathlib.Path("pictures_dir")

# 导入训练数据的图片路径名及标签
train = pd.read_csv("annotations.csv")

3. 查看数据

image_count = len(list(pictures_dir.glob('*.png')))
print("图片总数为:",image_count)
图片总数为: 5998
train.head()
file_namecategory
0000_0001.png0
1000_0002.png0
2000_0003.png0
3000_0010.png0
4000_0011.png0

二、构建一个tf.data.Dataset

1.加载数据

数据集中已经划分好了测试集与训练集,这次只需要进行分别加载就好了。

def preprocess_image(image):
    image = tf.image.decode_jpeg(image, channels=3)  # 编码解码处理
    image = tf.image.resize(image, [299,299])        # 图片调整
    return image/255.0                               # 归一化处理

def load_and_preprocess_image(path):
    image = tf.io.read_file(path)
    return preprocess_image(image)
AUTOTUNE = tf.data.experimental.AUTOTUNE
common_paths = "images/"

# 训练数据的标签
train_image_label = [i for i in train["category"]]
train_label_ds = tf.data.Dataset.from_tensor_slices(train_image_label)

# 训练数据的路径
train_image_paths = [ common_paths+i for i in train["file_name"]]
# 加载图片路径
train_path_ds = tf.data.Dataset.from_tensor_slices(train_image_paths)
# 加载图片数据
train_image_ds = train_path_ds.map(load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
# 将图片与标签进行对应打包
image_label_ds = tf.data.Dataset.zip((train_image_ds, train_label_ds))
image_label_ds
plt.figure(figsize=(20,4))

for i in range(20):
    plt.subplot(2,10,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    
    # 显示图片
    images = plt.imread(train_image_paths[i])
    plt.imshow(images)
    # 显示标签
    plt.xlabel(train_image_label[i])

plt.show()

在这里插入图片描述

2. 配置数据集

BATCH_SIZE = 6

# 将训练数据集拆分成训练集与验证集
train_ds = image_label_ds.take(5000).shuffle(1000)  # 前1500个batch
val_ds   = image_label_ds.skip(5000).shuffle(1000)  # 跳过前1500,选取后面的

train_ds = train_ds.batch(BATCH_SIZE)
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)

val_ds = val_ds.batch(BATCH_SIZE)
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)
val_ds
# 查看数据 shape 进行检查
for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break
(6, 299, 299, 3)
(6,)
# 再次查看数据,确认是否被打乱
plt.figure(figsize=(8,8))

for images, labels in train_ds.take(1):
    for i in range(6):
        
        ax = plt.subplot(4, 3, i + 1)  
        plt.imshow(images[i])
        plt.title(labels[i].numpy())  #使用.numpy()将张量转换为 NumPy 数组
        
        plt.axis("off")

在这里插入图片描述

三、构建Inception-ResNet-v2网络

1.自己搭建

下面是本文的重点 InceptionResNetV2 网络模型的构建,可以试着按照上面的图自己构建一下 InceptionResNetV2,这部分我主要是参考官网的构建过程,将其单独拎了出来。

from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Dropout,BatchNormalization,Activation
from tensorflow.keras.layers import MaxPooling2D, AveragePooling2D, Concatenate, Lambda,GlobalAveragePooling2D
from tensorflow.keras import backend as K

def conv2d_bn(x,filters,kernel_size,strides=1,padding='same',activation='relu',use_bias=False,name=None):
    
    x = Conv2D(filters,kernel_size,strides=strides,padding=padding,use_bias=use_bias,name=name)(x)
    
    if not use_bias:
        bn_axis = 1 if K.image_data_format() == 'channels_first' else 3
        bn_name = None if name is None else name + '_bn'
        x = BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x)
    if activation is not None:
        ac_name = None if name is None else name + '_ac'
        x = Activation(activation, name=ac_name)(x)
    return x

def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
    if block_type == 'block35':
        branch_0 = conv2d_bn(x, 32, 1)
        branch_1 = conv2d_bn(x, 32, 1)
        branch_1 = conv2d_bn(branch_1, 32, 3)
        branch_2 = conv2d_bn(x, 32, 1)
        branch_2 = conv2d_bn(branch_2, 48, 3)
        branch_2 = conv2d_bn(branch_2, 64, 3)
        branches = [branch_0, branch_1, branch_2]
    elif block_type == 'block17':
        branch_0 = conv2d_bn(x, 192, 1)
        branch_1 = conv2d_bn(x, 128, 1)
        branch_1 = conv2d_bn(branch_1, 160, [1, 7])
        branch_1 = conv2d_bn(branch_1, 192, [7, 1])
        branches = [branch_0, branch_1]
    elif block_type == 'block8':
        branch_0 = conv2d_bn(x, 192, 1)
        branch_1 = conv2d_bn(x, 192, 1)
        branch_1 = conv2d_bn(branch_1, 224, [1, 3])
        branch_1 = conv2d_bn(branch_1, 256, [3, 1])
        branches = [branch_0, branch_1]
    else:
        raise ValueError('Unknown Inception-ResNet block type. '
                         'Expects "block35", "block17" or "block8", '
                         'but got: ' + str(block_type))

    block_name = block_type + '_' + str(block_idx)
    mixed = Concatenate(name=block_name + '_mixed')(branches)
    up = conv2d_bn(mixed,K.int_shape(x)[3],1,activation=None,use_bias=True,name=block_name + '_conv')

    x = Lambda(lambda inputs, scale: inputs[0] + inputs[1] * scale,
               output_shape=K.int_shape(x)[1:],
               arguments={'scale': scale},
               name=block_name)([x, up])
    if activation is not None:
        x = Activation(activation, name=block_name + '_ac')(x)
    return x


def InceptionResNetV2(input_shape=[299,299,3],classes=1000):

    inputs = Input(shape=input_shape)

    # Stem block
    x = conv2d_bn(inputs, 32, 3, strides=2, padding='valid')
    x = conv2d_bn(x, 32, 3, padding='valid')
    x = conv2d_bn(x, 64, 3)
    x = MaxPooling2D(3, strides=2)(x)
    x = conv2d_bn(x, 80, 1, padding='valid')
    x = conv2d_bn(x, 192, 3, padding='valid')
    x = MaxPooling2D(3, strides=2)(x)

    # Mixed 5b (Inception-A block)
    branch_0 = conv2d_bn(x, 96, 1)
    branch_1 = conv2d_bn(x, 48, 1)
    branch_1 = conv2d_bn(branch_1, 64, 5)
    branch_2 = conv2d_bn(x, 64, 1)
    branch_2 = conv2d_bn(branch_2, 96, 3)
    branch_2 = conv2d_bn(branch_2, 96, 3)
    branch_pool = AveragePooling2D(3, strides=1, padding='same')(x)
    branch_pool = conv2d_bn(branch_pool, 64, 1)
    branches = [branch_0, branch_1, branch_2, branch_pool]

    x = Concatenate(name='mixed_5b')(branches)

    # 10次 Inception-ResNet-A block
    for block_idx in range(1, 11):
        x = inception_resnet_block(x, scale=0.17, block_type='block35', block_idx=block_idx)

    # Reduction-A block
    branch_0 = conv2d_bn(x, 384, 3, strides=2, padding='valid')
    branch_1 = conv2d_bn(x, 256, 1)
    branch_1 = conv2d_bn(branch_1, 256, 3)
    branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding='valid')
    branch_pool = MaxPooling2D(3, strides=2, padding='valid')(x)
    branches = [branch_0, branch_1, branch_pool]
    x = Concatenate(name='mixed_6a')(branches)

    # 20次 Inception-ResNet-B block
    for block_idx in range(1, 21):
        x = inception_resnet_block(x, scale=0.1, block_type='block17', block_idx=block_idx)


    # Reduction-B block
    branch_0 = conv2d_bn(x, 256, 1)
    branch_0 = conv2d_bn(branch_0, 384, 3, strides=2, padding='valid')
    branch_1 = conv2d_bn(x, 256, 1)
    branch_1 = conv2d_bn(branch_1, 288, 3, strides=2, padding='valid')
    branch_2 = conv2d_bn(x, 256, 1)
    branch_2 = conv2d_bn(branch_2, 288, 3)
    branch_2 = conv2d_bn(branch_2, 320, 3, strides=2, padding='valid')
    branch_pool = MaxPooling2D(3, strides=2, padding='valid')(x)
    branches = [branch_0, branch_1, branch_2, branch_pool]
    x = Concatenate(name='mixed_7a')(branches)

    # 10次 Inception-ResNet-C block
    for block_idx in range(1, 10):
        x = inception_resnet_block(x, scale=0.2, block_type='block8', block_idx=block_idx)
    x = inception_resnet_block(x, scale=1., activation=None, block_type='block8', block_idx=10)


    x = conv2d_bn(x, 1536, 1, name='conv_7b')

    x = GlobalAveragePooling2D(name='avg_pool')(x)
    x = Dense(classes, activation='softmax', name='predictions')(x)

    # 创建模型
    model = Model(inputs, x, name='inception_resnet_v2')

    return model

model = InceptionResNetV2([299,299,3],58)
model.summary()

2.官方模型

# import tensorflow as tf
# # 如果使用官方模型需要将图片shape调整为 [299,299,3],目前图片的shape是 [150,150,3]
# model = tf.keras.applications.inception_resnet_v2.InceptionResNetV2()
# model.summary()

五、设置动态学习率

这里先罗列一下学习率大与学习率小的优缺点。

  • 学习率大
    • 优点: 1、加快学习速率。 2、有助于跳出局部最优值。
    • 缺点: 1、导致模型训练不收敛。 2、单单使用大学习率容易导致模型不精确。
  • 学习率小
    • 优点: 1、有助于模型收敛、模型细化。 2、提高模型精度。
    • 缺点: 1、很难跳出局部最优值。 2、收敛缓慢。

注意:这里设置的动态学习率为:指数衰减型(ExponentialDecay)。在每一个epoch开始前,学习率(learning_rate)都将会重置为初始学习率(initial_learning_rate),然后再重新开始衰减。计算公式如下:

learning_rate = initial_learning_rate * decay_rate ^ (step / decay_steps)

在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:

  • 损失函数(loss):用于衡量模型在训练期间的准确率。
  • 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
  • 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

六、训练模型

Inception-ResNet-v2 模型相对之前的模型较为复杂,故而运行耗时也更长,我这边每一个epoch运行时间是130s左右。我的GPU配置是 NVIDIA GeForce RTX 3080。建议大家先将 epochs 调整为1跑通程序。

epochs = 10

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)
Epoch 1/10
834/834 [==============================] - 154s 163ms/step - loss: 2.5214 - accuracy: 0.3563 - val_loss: 1.3834 - val_accuracy: 0.6168
Epoch 2/10
834/834 [==============================] - 133s 159ms/step - loss: 0.9230 - accuracy: 0.7522 - val_loss: 0.5457 - val_accuracy: 0.8531
Epoch 3/10
834/834 [==============================] - 133s 159ms/step - loss: 0.3952 - accuracy: 0.9105 - val_loss: 0.3391 - val_accuracy: 0.9064
Epoch 4/10
834/834 [==============================] - 134s 160ms/step - loss: 0.1876 - accuracy: 0.9655 - val_loss: 0.2481 - val_accuracy: 0.9296
Epoch 5/10
834/834 [==============================] - 131s 156ms/step - loss: 0.1071 - accuracy: 0.9862 - val_loss: 0.1265 - val_accuracy: 0.9716
Epoch 6/10
834/834 [==============================] - 128s 153ms/step - loss: 0.0587 - accuracy: 0.9954 - val_loss: 0.0911 - val_accuracy: 0.9794
Epoch 7/10
834/834 [==============================] - 132s 158ms/step - loss: 0.0429 - accuracy: 0.9976 - val_loss: 0.0941 - val_accuracy: 0.9777
Epoch 8/10
834/834 [==============================] - 132s 158ms/step - loss: 0.0306 - accuracy: 0.9980 - val_loss: 0.0955 - val_accuracy: 0.9777
Epoch 9/10
834/834 [==============================] - 133s 158ms/step - loss: 0.0248 - accuracy: 0.9997 - val_loss: 0.0864 - val_accuracy: 0.9794
Epoch 10/10
834/834 [==============================] - 132s 158ms/step - loss: 0.0216 - accuracy: 0.9988 - val_loss: 0.0750 - val_accuracy: 0.9794

七、模型评估

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

八、模型的保存与加载

# 保存模型
model.save('model/14_model.h5')
# 加载模型
new_model = keras.models.load_model('model/14_model.h5')

九、预测

# 采用加载的模型(new_model)来看预测结果

plt.figure(figsize=(10, 5))  # 图形的宽为10高为5


for images, labels in val_ds.take(1):
    for i in range(6):
        ax = plt.subplot(2, 3, i + 1)  
        
        # 显示图片
        plt.imshow(images[i])
        
        # 需要给图片增加一个维度
        img_array = tf.expand_dims(images[i], 0) 
        
        # 使用模型预测路标
        predictions = new_model.predict(img_array)
        plt.title(np.argmax(predictions))

        plt.axis("off")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1252482.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【小沐学写作】免费在线AI辅助写作汇总

文章目录 1、简介2、文涌Effidit(腾讯)2.1 工具简介2.2 工具功能2.3 工具体验 3、PPT小助手(officeplus)3.1 工具简介3.2 使用费用3.3 工具体验 4、DeepL Write(仅英文)4.1 工具简介4.2 工具体验 5、天工AI…

数组题目:645. 错误的集合、 697. 数组的度、 448. 找到所有数组中消失的数字、442. 数组中重复的数据 、41. 缺失的第一个正数

645. 错误的集合 思路: 我们定义一个数组cnt,记录每个数出现的次数。然后我们遍历数组,从1开始,如果cnt[i] 0 那就说明这个是错误的数,如果 cnt[i] 2,那就说明是重复的数。 代码: class So…

嵌入式虚拟机原理

欢迎关注博主 Mindtechnist 或加入【智能科技社区】一起学习和分享Linux、C、C、Python、Matlab,机器人运动控制、多机器人协作,智能优化算法,滤波估计、多传感器信息融合,机器学习,人工智能等相关领域的知识和技术。关…

【数据集】全网最全的常见已公开医学影像数据集

目录 一,极市医学数据集汇总 1.CT 医学图像 ​编辑 2.恶性与良性皮肤癌 3.白内障数据集 4.胸部 X 光图像(肺炎) 5.用于图像增强的内窥镜真实合成曝光过度和曝光不足帧 6.医学家 7.乳房组织病理学图像 8.皮肤癌 MNIST:HA…

芯片安全和无线电安全底层渗透技术

和传统网络安全不同,硬件安全、芯片安全、无线电安全属于网络底层安全的重要细分领域,是网络安全的真正基石,更是国家安全的重要组成部分,“夯实网络底层安全基础,筑牢网络强国安全底座”,是底网安全重要性…

上手 Promethus - 开源监控、报警工具包

名词解释 Promethus 是什么 开源的【系统监控和警报】工具包 专注于: 1)可靠的实时监控 2)收集时间序列数据 3)提供强大的查询语言(PromQL),用于分析这些数据 功能: 1&#xff0…

红外遥控实验

本章,我们将介绍 STM32F103 对红外遥控器的信号解码。STM32 板子上标配的红外接收 头和一个小巧的红外遥控器。我们将利用 STM32 的输入捕获功能,解码开发板标配的红外遥控 器的编码信号,并将编码后的键值在 LCD 模块中显示出来。 红外遥控技…

AI换脸教程

方法一、MJ换脸大法 1.点击这个网站添加一个机器人到自己的服务器 https://discord.com/oauth2/authorize?client_id1090660574196674713&permissions274877945856&scopebot 2. /saveid 回车选择你自己的照片,并且在名字框命名身份,回车 3.…

Cesium-terrain-builder编译入坑详解

本以为编译cesium-terrian-tools编译应该没那么难,不想问题重重,不想后人重蹈覆辙,也记录下点点滴滴。 目前网上存在的cesium代码版本主要有两个分支: 原始网站【不能生成layer文件,且经久不更新,使用gdal…

Kotlin学习——kt里面的函数,高阶函数 函数式编程 扩展函数和属性

Kotlin 是一门现代但已成熟的编程语言,旨在让开发人员更幸福快乐。 它简洁、安全、可与 Java 及其他语言互操作,并提供了多种方式在多个平台间复用代码,以实现高效编程。 https://play.kotlinlang.org/byExample/01_introduction/02_Functio…

【STM32】GPIO输出

1 GPIO简介 (1)GPIO(General Purpose Input Output)通用输入输出口 (2)可配置为8种输入输出模式 (3)引脚电平:0V~3.3V,部分引脚可容忍5V(可以输…

mysql8下载与安装教程

文章目录 1. MySQL下载2. MySQL安装3. 添加环境变量4. 登录mysql 1. MySQL下载 以下两个网址二选一 官网:https://downloads.mysql.com/archives/community/阿里云镜像:https://mirrors.aliyun.com/mysql/?spma2c6h.13651104.d-5173.5.2e535dc8shSjIl…

centos7搭建ftp服务

一、安装 yum -y install vsftpd vi /etc/vsftpd/vsftpd.conf二、编辑配置文件 /etc/vsftpd/vsftpd.conf 内容如下 #是否允许匿名,默认no anonymous_enableNO#这个设定值必须要为YES 时,在/etc/passwd内的账号才能以实体用户的方式登入我们的vsftpd主机…

【Java程序员面试专栏 专业技能篇 】Java SE核心面试指引(四):Java新特性

关于Java SE部分的核心知识进行一网打尽,包括四部分:基础知识考察、面向对象思想、核心机制策略、Java新特性,通过一篇文章串联面试重点,并且帮助加强日常基础知识的理解,全局思维导图如下所示 本篇Blog为第四部分:Java新特性,子节点表示追问或同级提问 Java8新特性…

PyInstaller打包python程序为exe可执行文件

教程千千万,貌似我的window电脑就是打包不了,而且不同电脑的表现都不一致,很是奇怪。 文章目录 1 极简版1.1 生成文件spec详解1.2 是否变成一个exe主文件 2 虚拟环境打包3 其他打包需求3.1 加密打包3.2 Pyinstaller打包多个py文件为一个exe文…

代码随想录算法训练营第四十七天|198. 打家劫舍、213. 打家劫舍II、337. 打家劫舍III

LeetCode 198. 打家劫舍 题目链接:198. 打家劫舍 - 力扣(LeetCode) 第一次打家劫舍,来个简单一些的,无非就是偷了当前这家偷不了下一家,因此dp[n]代表,偷前n家的时候所能偷到的最高金额&#x…

区间预测 | Matlab实现BP-KDE的BP神经网络结合核密度估计多变量时序区间预测

区间预测 | Matlab实现BP-KDE的BP神经网络结合核密度估计多变量时序区间预测 目录 区间预测 | Matlab实现BP-KDE的BP神经网络结合核密度估计多变量时序区间预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.BP-KDE多变量时间序列区间预测,基于BP神经网络多…

rtsp点播异常出现‘circluar_buffer_size‘ option was set but it is xx

先说现象: 我使用potplay播放器来点播rtsp码流的时候可以点播成功,同事使用vlc和FFplay来点播rtsp码流的时候异常。 排查思路: 1.开始怀疑是oss账号问题,因为ts切片数据是保存在oss中的,我使用的是自己的oss账号,同事使用的是公司…

Kafka 如何实现顺序消息

版本说明 本文所有的讨论均在如下版本进行,其他版本可能会有所不同。 Kafka: 3.6.0Pulsar: 2.9.0RabbitMQ 3.7.8RocketMQ 5.0Go1.21github.com/segmentio/kafka-go v0.4.45 结论先行 Kafka 只能保证单一分区内的顺序消息,无法保证多分区间的顺序消息…

【数据结构】用C语言实现链队列(附完整运行代码)

🦄个人主页:修修修也 🎏所属专栏:数据结构 ⚙️操作环境:Visual Studio 2022 一.了解项目功能 在本次项目中我们的目标是实现一个链队列: 该链队列使用动态内存分配空间,可以用来存储任意数量的同类型数据. 队列结点(QNode)需要包含两个要素:数据域data,…