基于 Tensorflow 2.x 使用 MobileNetV2 微调模型优化训练花卉图像识别模型

news2025/1/11 12:53:55

一、模型微调

上篇文章我们通过搭建三层卷积模型,训练了花卉图像识别模型,最后经验证集验证后准确率大约为 75% ,本篇文章对该数据集进行优化,提高识别的准确度。本篇文章中对于数据集的读取强化不做过多的介绍了,大家可以参考上篇文章中的介绍,下面是上篇文章的地址:

https://blog.csdn.net/qq_43692950/article/details/128518757

深度学习模型应用于小型图像数据集场景下,一般由于数据量的局限性,导致模型提取特征有限,进而影响识别的准确度,一种常用且非常高效的优化方式便是使用预训练网络模型。将一个已在大型数据集(例如ImageNet)上训练好的模型作为基础模型。在此基础上对自己的数据集再进行训练,即使新的数据集和原始数据集完全不同的情况下也可以得到很好的特诊提取。

使用预训练网络有两种方法:特征提取(feature extraction)微调模型(fine-tuning)

  • 特征提取(feature extraction):使用预训练的优秀模型和权重来从新样本中提取特征,最后给到一个新的分类器,从头开始训练,之前的权重会随着反向传播进行修改。

  • 微调模型(fine-tuning):使用预训练的优秀模型和权重来从新样本中提取特征,最后同样给到一个新的分类器,但不同的是预训练模型的全部或某些层的权重被冻结,不会随着反向传播进行修改,只是略微调整了模型结构,这种方式不会破坏训练模型。

一般我们选取模型,都是基于ImageNet 上预训练过的用于图像分类的模型,在 keras 中可以在 keras.applications 下拿到各种网络模型结构,例如 VGG16、VGG19、Xception、ResNet、ResNetV2、 ResNeXt、InceptionV3、MobileNet、MobileNetV2、DenseNet 等。这些模型都是基于ImageNet 1000 分类的模型。

下面是 keras 中对各个模型的介绍:

https://keras.io/zh/applications/

其中各个模型在 ImageNet 上的Top1Top5 的准确率如下

模型大小Top-1 准确率Top-5 准确率参数数量深度
Xception88 MB0.7900.94522,910,480126
VGG16528 MB0.7130.901138,357,54423
VGG19549 MB0.7130.900143,667,24026
ResNet5098 MB0.7490.92125,636,712-
ResNet101171 MB0.7640.92844,707,176-
ResNet152232 MB0.7660.93160,419,944-
ResNet50V298 MB0.7600.93025,613,800-
ResNet101V2171 MB0.7720.93844,675,560-
ResNet152V2232 MB0.7800.94260,380,648-
ResNeXt5096 MB0.7770.93825,097,128-
ResNeXt101170 MB0.7870.94344,315,560-
InceptionV392 MB0.7790.93723,851,784159
InceptionResNetV2215 MB0.8030.95355,873,736572
MobileNet16 MB0.7040.8954,253,86488
MobileNetV214 MB0.7130.9013,538,98488
DenseNet12133 MB0.7500.9238,062,504121
DenseNet16957 MB0.7620.93214,307,880169
DenseNet20180 MB0.7730.93620,242,984 201
NASNetMobile23 MB0.7440.9195,326,716-
NASNetLarge343 MB0.8250.96088,949,818-

例如:查看 VGG16 模型的结构,其中如果权重文件不存在则会自动下载:

from tensorflow import keras

model = keras.applications.VGG16(weights='imagenet', include_top=True)
model.summary()

打印出下面内容,可以看出输入是 (224, 224, 3) 大小的三维图片,输出是一个 1000 分类的模型:

Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0         
_________________________________________________________________
fc1 (Dense)                  (None, 4096)              102764544 
_________________________________________________________________
fc2 (Dense)                  (None, 4096)              16781312  
_________________________________________________________________
predictions (Dense)          (None, 1000)              4097000   

使用 VGG16 预测图片分类,图片如下:
请添加图片描述

from tensorflow import keras
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing.image import load_img

plt.rcParams['font.sans-serif'] = ['SimHei']

model = keras.applications.VGG16(weights='imagenet', include_top=True)

img_path = './img/cat.jpg'
image = load_img(img_path, target_size=(224, 224))
image = img_to_array(image)
# 预测
y_predictions = model.predict(tf.expand_dims(image, 0))
# 解析标签
y_lable = keras.applications.vgg16.decode_predictions(y_predictions)

plt.imshow(image / 255.0)
plt.title('预测结果:' + y_lable[0][0][1] + ',概率:' + str("%.2f" % (y_lable[0][0][2] * 100)) + '%')
plt.show()

预测结果:

在这里插入图片描述

二、使用 MobileNetV2 微调模型优化训练花卉图像识别

关于图像的预处理可以参考上篇文章,这里基于 MobileNetV2 模型进行微调,MobileNetV2 模型是一个轻量级的分类模型,由google团队在2018年提出的,相比MobileNet V1网络,准确率更高,模型更小,可用于移动端设置计算。

这里对 MobileNetV2 模型去除最后的分类器层,其余权重均被冻结,并在最后,加上两层全连接层和一层分类层,其中在 keras 中冻结参数,也非常容易,只需 Model.trainable = False 即可,例如:

mobienet = keras.applications.MobileNetV2(weights='imagenet', include_top=False, input_shape=(180, 180, 3))
 # 冻结权重
mobienet.trainable = False

整体模型如下所示:

在这里插入图片描述
下面使用 keras 构建模型结构:

import tensorflow as tf
from tensorflow import keras

# 定义模型类
class Model():
    # 初始化结构
    def __init__(self, checkpoint_path, log_path, model_path, num_classes, img_width, img_height):
        # checkpoint 权重保存地址
        self.checkpoint_path = checkpoint_path
        # 训练日志保存地址
        self.log_path = log_path
        # 训练模型保存地址:
        self.model_path = model_path

        # 数据统一大小并归一处理
        resize_and_rescale = tf.keras.Sequential([
            keras.layers.Resizing(img_width, img_height),
            keras.layers.Rescaling(1. / 255)
        ])
        # 数据增强
        data_augmentation = tf.keras.Sequential([
            # 翻转
            keras.layers.RandomFlip("horizontal_and_vertical"),
            # 旋转
            keras.layers.RandomRotation(0.2),
            # 对比度
            keras.layers.RandomContrast(0.3),
            # 随机裁剪
            # tf.keras.layers.RandomCrop(IMG_SIZE, IMG_SIZE),
            # 随机缩放
            keras.layers.RandomZoom(height_factor=0.3, width_factor=0.3),
        ])
        # MobileNetV2 模型结构
        mobienet = keras.applications.MobileNetV2(weights='imagenet', include_top=False, input_shape=(180, 180, 3))
        # 冻结权重
        mobienet.trainable = False
        # 初始化模型结构
        self.model = keras.Sequential([
            resize_and_rescale,
            data_augmentation,
            mobienet,
            keras.layers.Flatten(),
            keras.layers.Dense(1024,
                               kernel_initializer=keras.initializers.truncated_normal(stddev=0.05),
                               kernel_regularizer=keras.regularizers.l2(0.001),
                               activation=tf.nn.relu),
            keras.layers.Dense(256,
                               kernel_initializer=keras.initializers.truncated_normal(stddev=0.05),
                               kernel_regularizer=keras.regularizers.l2(0.001),
                               activation=tf.nn.relu),
            keras.layers.Dense(num_classes, activation='softmax')
        ])

    # 编译模型
    def compile(self):
        # 输出模型摘要
        self.model.build(input_shape=(None, 180, 180, 3))
        self.model.summary()
        # 定义训练模式
        self.model.compile(optimizer='adam',
                           loss='sparse_categorical_crossentropy',
                           metrics=['accuracy'])

    # 训练模型
    def train(self, train_ds, val_ds):
        # tensorboard 训练日志收集
        tensorboard = keras.callbacks.TensorBoard(log_dir=self.log_path)

        # 训练过程保存 Checkpoint 权重,防止意外停止后可以继续训练
        model_checkpoint = keras.callbacks.ModelCheckpoint(self.checkpoint_path,  # 保存模型的路径
                                                           # monitor='val_loss',  # 被监测的数据。
                                                           verbose=0,  # 详细信息模式,0 或者 1
                                                           save_best_only=True,  # 如果 True, 被监测数据的最佳模型就不会被覆盖
                                                           save_weights_only=True,
                                                           # 如果 True,那么只有模型的权重会被保存 (model.save_weights(filepath)),否则的话,整个模型会被保存,(model.save(filepath))
                                                           mode='auto',
                                                           # {auto, min, max}的其中之一。 如果 save_best_only=True,那么是否覆盖保存文件的决定就取决于被监测数据的最大或者最小值。 对于 val_acc,模式就会是 max,而对于 val_loss,模式就需要是 min,等等。 在 auto模式中,方向会自动从被监测的数据的名字中判断出来。
                                                           period=3  # 每3个epoch保存一次权重
                                                           )
        # 填充数据,迭代训练
        self.model.fit(
            train_ds,  # 训练集
            validation_data=val_ds,  # 验证集
            epochs=30,  # 迭代周期
            verbose=2,  # 训练过程的日志信息显示,一个epoch输出一行记录
            callbacks=[tensorboard, model_checkpoint]
        )
        # 保存训练模型
        self.model.save(self.model_path)

    def evaluate(self, val_ds):
        # 评估模型
        test_loss, test_acc = self.model.evaluate(val_ds)
        return test_loss, test_acc

处理数据集,其中使用 80% 的图像进行训练,20% 的图像进行验证。:

import tensorflow as tf
import pathlib
from tensorflow import keras

def getData():
    # 加载数据集
    path = "F:/Tensorflow/datasets/flower/flower_photos"
    # 解析目录
    data_dir = pathlib.Path(path)

    # keras 加载数据集
    batch_size = 20
    img_height = 180
    img_width = 180

    # 使用 80% 的图像进行训练,20% 的图像进行验证。
    class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    train_ds = keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="training",
        image_size=(img_height, img_width),
        batch_size=batch_size,
        shuffle=True,
        seed=123,
        interpolation='bilinear',
        crop_to_aspect_ratio=True,
        labels='inferred',
        class_names=class_names,
        color_mode='rgb'
    )

    val_ds = keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="validation",
        image_size=(img_height, img_width),
        batch_size=batch_size,
        shuffle=True,
        seed=123,
        interpolation='bilinear',
        crop_to_aspect_ratio=True,
        labels='inferred',
        class_names=class_names,
        color_mode='rgb'
    )

    AUTOTUNE = tf.data.AUTOTUNE
    train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
    val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
    return train_ds, val_ds, len(class_names), img_width, img_height

开始训练模型:

def main():
    # 加载数据集
    train_ds, val_ds, num_classes, img_width, img_height = getData()

    checkpoint_path = './checkout/'
    log_path = './log'
    model_path = './model/model.h5'

    # 构建模型
    model = Model(checkpoint_path, log_path, model_path, num_classes, img_width, img_height)
    # 编译模型
    model.compile()
    # 训练模型
    model.train(train_ds, val_ds)
    # 评估模型
    test_loss, test_acc = model.evaluate(val_ds)
    print(test_loss, test_acc)

if __name__ == '__main__':
    main()

运行后可以看到打印的网络结构:

在这里插入图片描述

从训练日志中,可以看到 loss 一直在减小:

在这里插入图片描述

训练结束后评估模型的结果,最终在验证集上的准确率为: 85.96% ,比上篇文章整整高出了 10%

在这里插入图片描述
最后看下 tensorboard 中可视化的损失及准确率:

tensorboard --logdir=log/train

在这里插入图片描述
使用浏览器访问:http://localhost:6006/ 查看结果:

在这里插入图片描述

三、模型预测

训练后会在 model 下生成 model.h5 模型,下面直接加载该模型进行预测:

import tensorflow as tf
import pathlib
from tensorflow import keras
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']


def main():
    # 加载数据集
    path = "F:/Tensorflow/datasets/flower/flower_photos"
    # 解析目录
    data_dir = pathlib.Path(path)

    # keras 加载数据集
    batch_size = 32
    img_height = 180
    img_width = 180

    class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    class_names_cn = ['雏菊', '蒲公英', '玫瑰', '向日葵', '郁金香']
    val_ds = keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset="validation",
        image_size=(img_height, img_width),
        batch_size=batch_size,
        shuffle=True,
        seed=123,
        interpolation='bilinear',
        crop_to_aspect_ratio=True,
        labels='inferred',
        class_names=class_names,
        color_mode='rgb'
    )

    model = keras.models.load_model('./model/model.h5')

    for image_batch, labels_batch in val_ds.take(3):
        plt.figure(figsize=(10, 10))
        for i in range(9):
            plt.subplot(3, 3, i + 1)
            softmax = model.predict(tf.expand_dims(image_batch[i], 0))
            y_label = tf.argmax(softmax, axis=1).numpy()[0]
            plt.imshow(image_batch[i].numpy().astype("uint8"))
            plt.title('预测结果:' + class_names_cn[y_label] + ',概率:' + str("%.2f" % softmax[0][y_label]) + ',真实结果:' +
                      class_names_cn[labels_batch[i]])
            plt.axis('off')
        plt.show()


if __name__ == '__main__':
    main()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/134460.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

免拆机,Kindle固件版本5.10.3~5.13.3如何越狱?简单、易操作版

前言 之前有出过Kindle的越狱教程: 无需拆机,Kindle 全系列 5.12.2.2 ~ 5.14.2版本如何越狱?如何安装第三方插件 确实可以越狱,使用的漏洞也是: KindleDrip — From Your Kindle’s Email Address to Using Your C…

ubuntu18.04下mysql数据库C语言API封装

mysql C语言API操作数据库比较繁琐,可以将其封装起来,这样使用比较方便,下面是一种封装方式。 目录 1.连接封装 2.连接池封装 3.测试代码 1.连接封装 将数据库连接进行封装,主要提供如下接口: (1&…

L2-030 冰岛人

2018年世界杯,冰岛队因1:1平了强大的阿根廷队而一战成名。好事者发现冰岛人的名字后面似乎都有个“松”(son),于是有网友科普如下: 冰岛人沿用的是维京人古老的父系姓制,孩子的姓等于父亲的名加后缀&#x…

torchnet.meter使用教程

前言 最近项目开发过程中遇到了torchnet.metertorchnet.metertorchnet.meter来记录模型信息,搜了好多篇博客潦潦草草,没有一点干货,于是根据官方代码和官方文档,基于自己的理解,制定了使用教程: torchnet简介 torch…

一句话实现报表生成PDF同时通过outlook发送

元旦节快乐 哈喽,大家2023年好呀! 今天,元旦最后一天,给大家分享什么好玩的示例呢? 让我来想想,嗯?这样可以吗?一句话就实现将报表生成PDF,同时可以编辑一些信息并通过…

【源码分享】java多用户B2B2C商城源码带WAP手机端源码

分享一款非常不错的java多用户B2B2C商城源码,带WAP手机端源码,源码地址在文末。 需要源码学习,可私信我获取。 一、技术构架: 开发语言: Java1.7 数 据 库 : MySQL5.5 数据库持久层:阿里巴巴…

车载诊断协议UDS——会话模式状态机Session

UDS之Session服务 会话模式管控是汽车电子诊断范畴很重要的两个状态机之一(另一个是安全访问),不同的会话模式是用来区分诊断服务执行权限。 一位非常尊敬的业内前辈曾举如下例子来形容这个状态机:不同的场景,喝对应的酒! 公司商务场合下,对应的酒是红酒;长辈酒桌上,对…

Redis 哨兵模式

哨兵是一个分布式系统,你可以在一个架构中运行多个哨兵进程,这些进程使用流言协议来接收关于Master主服务器是否下线的信息,并使用投票协议来决定是否执行自动故障迁移,以及选择哪个Slave作为新的Master。 一、哨兵模式概述 1.1…

ubuntu做系统常见出错处理方法1

1.不能分区解决办法(安装ubuntu没有出现安装选项,也就是找不到硬盘分区怎么办?-爱码网) 解决办法是进入bios模式(一般都是重启时反复按f12,不同电脑型号可自行查阅)把硬盘模式从raid调整为ahci(System configuration–&#xff…

方差和标准差的意义

文章目录案例:箭靶案例:身高案例:身高体重在此前一篇文章 《算法效果评估:均方根误差(RMSE)/ 标准误差》中,我们介绍了方差/标准差的计算方法,也点出了它们是用来“度量数据离散程度…

linux系统中wifi驱动的配置与编译实现方法

大家好,今天主要和大家聊一聊,如何使用linux系统中的WIFI驱动完成相应的实验。 目录 第一:WIFI驱动添加与编译方法 第二:将驱动代码添加到linux内核中 第三:配置Linux内核 第四:编译WIFI驱动 第一&…

YOLOv5更换骨干网络之 MobileNetV3

论文地址:https://arxiv.org/abs/1905.02244 代码地址:https://github.com/xiaolai-sqlai/mobilenetv3 我们展示了基于互补搜索技术和新颖架构设计相结合的下一代 MobileNets。MobileNetV3通过结合硬件感知网络架构搜索(NAS)和 N…

MySQL基础篇

MySQL数据库笔记 第一部分 MySQL基础篇 第01章 数据库概述 1. 为什么要使用数据库 持久化(persistence):把数据保存到可掉电式存储设备中以供之后使用。大多数情况下,特别是企业级应用,数据持久化意味着将内存中的数据保存到硬盘上加以“…

网络类型实验

1.先配ip [Huawei]sysname R1 [R1]interface GigabitEthernet 0/0/1 [R1-GigabitEthernet0/0/1]ip add 192.168.1.1 24 [R1-GigabitEthernet0/0/1]int s 4/0/0 [R1-Serial4/0/0]ip add 12.1.1.1 24 其他同理 2.写三条缺省指向R2来使网络通 [R1]ip route-static 0.0.0.0 0 12…

【王道操作系统】3.1.1 什么是内存?进程的基本原理,深入指令理解其过程

什么是内存?进程的基本原理,深入指令理解其过程 文章目录什么是内存?进程的基本原理,深入指令理解其过程1.什么是内存?有何作用?2.进程运行的基本原理2.1 指令的工作原理---操作码若干参数2.2 逻辑地址(相对…

C++类和对象3:关于类内部的更多细节

目录 初始化列表: explicit关键字 ​编辑 static成员 友元 内部类 匿名对象 拷贝对象时的一些编译器优化 我们已经接触过了构造函数,其功能可以很方便的帮助我们为变量赋值,但是在这里并不是初始化,因为一个构造函数可以为几…

02 Hadoop概述

Hadoop概述1、Hadoop是什么2、Hadoop版本3、HDFS、YARN、MapReduce(1) HDFS(2)YARN(3)MapReduce(3)Hadoop模块之间的关系1、Hadoop是什么 是一个由Apache基金会开发的分布式系统基础…

动态规划是个好东西:编辑距离

力扣:72. 编辑距离 这道题目让我狠狠的了解了动态规划,这玩意是真强。 题目描述很简单: 这道题正常来说,我们要考虑这个字符怎么换,长度不一怎么找…等等问题,但是这样做会发现很困难,显然这是…

Vert.x 核心概念及事件模型

Vert.x是基于事件的,提供一个事件驱动编程模型 使用Vert.x作为服务器时,程序员只要编写事件处理器event handler即可。(当TCP socket有数据时,event handler被创建调用) 另外它还可以在以下几种情况激活: …

反向迭代器

文章目录1. list的反向迭代器2. list的rbegin和rend3. 反向迭代器的实现3.1 复用vector反向迭代器3.2 反向迭代器的变化1. list的反向迭代器 我们先来看一看库里面的list的迭代器是如何写的: 这是list的正向迭代器。 这是list的反向迭代器。 其实大佬们是把正向迭…