TensorFlow系列:第四讲:MobileNetV2实战

news2025/1/11 2:41:53

一. 加载数据集

编写工具类,实现数据集的加载

![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/808da38d6ad74628b869c28e937b02d9.png


import keras

"""
加载数据集工具类
"""


class DatasetLoader:
    def __init__(self, path_url, image_size=(224, 224), batch_size=32, class_mode='categorical'):
        self.path_url = path_url
        self.image_size = image_size
        self.batch_size = batch_size
        self.class_mode = class_mode

    # 不使用图像增强
    def load_data(self):
        # 加载训练数据集
        train_data = keras.preprocessing.image_dataset_from_directory(
            self.path_url + '/train',  # 训练数据集的目录路径
            image_size=self.image_size,  # 调整图像大小
            batch_size=self.batch_size,  # 每批次的样本数量
            label_mode=self.class_mode,  # 类别模式:返回one-hot编码的标签
        )

        # 加载验证数据集
        val_data = keras.preprocessing.image_dataset_from_directory(
            self.path_url + '/validation',  # 验证数据集的目录路径
            image_size=self.image_size,  # 调整图像大小
            batch_size=self.batch_size,  # 每批次的样本数量
            label_mode=self.class_mode  # 类别模式:返回one-hot编码的标签
        )
        # 加载测试数据集
        test_data = keras.preprocessing.image_dataset_from_directory(
            self.path_url + '/test',  # 验证数据集的目录路径
            image_size=self.image_size,  # 调整图像大小
            batch_size=self.batch_size,  # 每批次的样本数量
            label_mode=self.class_mode  # 类别模式:返回one-hot编码的标签
        )
        class_names = train_data.class_names
        return train_data, val_data, test_data, class_names

二. 训练模型完整代码

import keras
from keras import layers

from utils.dataset_loader import DatasetLoader

"""
使用MobileNetV2,实现图像多分类
"""

# 模型训练地址
PATH_URL = '../data/fruits'
# 训练曲线图
RESULT_URL = '../results/fruits'
# 模型保存地址
SAVED_MODEL_DIR = '../saved_model/fruits'

#  图片大小
IMG_SIZE = (224, 224)
# 定义图像的输入形状
IMG_SHAPE = IMG_SIZE + (3,)
# 数据加载批次,训练轮数
BATCH_SIZE, EPOCH = 32, 16


# 训练模型
def train():
    # 实例化数据集加载工具类
    dataset_loader = DatasetLoader(PATH_URL, IMG_SIZE, BATCH_SIZE)
    train_ds, val_ds, test_ds, class_total = dataset_loader.load_data()

    # 构建 MobileNet 模型
    base_model = keras.applications.MobileNetV2(input_shape=IMG_SHAPE, include_top=False)
    # 将模型的主干参数进行冻结
    base_model.trainable = False
    model = keras.Sequential([
        layers.Rescaling(1. / 127.5, offset=-1, input_shape=IMG_SHAPE),
        # 设置主干模型
        base_model,
        # 对主干模型的输出进行全局平均池化
        layers.GlobalAveragePooling2D(),
        # 通过全连接层映射到最后的分类数目上
        layers.Dense(len(class_total), activation='softmax')
    ])
    # 编译模型
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    # 模型结构
    model.summary()
    # 指明训练的轮数epoch,开始训练
    model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)
    # 测试
    loss, accuracy = model.evaluate(test_ds)
    # 输出结果
    print('Mobilenet test accuracy :', accuracy, ',loss :', loss)
    # 保存模型 savedModel格式
    model.export(filepath=SAVED_MODEL_DIR)


if __name__ == '__main__':
    train()

训练模型输出如下:

模型结构:

在这里插入图片描述
训练进度:主要看最下边一行输出,一轮训练完成会显示训练集和验证集的正确率。
在这里插入图片描述
验证正确率:

在这里插入图片描述
保存的模型:

在这里插入图片描述

三. 函数式调用方式

以后的所有讲解,都基于函数式方式进行,因为函数式调用比较灵活。

# 函数式调用方式
def train1():
    # 实例化数据集加载工具类
    dataset_loader = DatasetLoader(PATH_URL, IMG_SIZE, BATCH_SIZE)
    train_ds, val_ds, test_ds, class_total = dataset_loader.load_data()

    inputs = keras.Input(shape=IMG_SHAPE)
    # 加载预训练的 MobileNetV2 模型,不包括顶层分类器,并在 Rescaling 层之后连接
    base_model = keras.applications.MobileNetV3Large(weights='imagenet', include_top=False, input_tensor=inputs)

    # 冻结 MobileNetV2 的所有层,以防止在初始阶段进行权重更新
    for layer in base_model.layers:
        layer.trainable = False
    # 在 MobileNetV2 之后添加自定义的顶层分类器
    x = layers.GlobalAveragePooling2D()(base_model.output)
    predictions = layers.Dense(len(class_total), activation='softmax')(x)
    # 构建最终模型
    model = keras.Model(inputs=base_model.input, outputs=predictions)
    # 编译模型
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    # 查看模型结构
    model.summary()
    model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)
    # 测试
    loss, accuracy = model.evaluate(test_ds)
    # 输出结果
    print('Mobilenet test accuracy :', accuracy, ',loss :', loss)
    # 保存模型 savedModel格式
    model.export(filepath=SAVED_MODEL_DIR)

四. 保存训练过程曲线图

在训练模型时,我们不可能时时盯着训练数据结果,如果把训练过程曲线保存成图片,这样就比较方便查看。

在项目中编写一个工具类如下:
在这里插入图片描述
上边代码简单改造:

    # 训练模型
    history = model.fit(train_ds, validation_data=val_ds, epochs=EPOCH)
    # 保存曲线图
    Utils.trainResult(history, RESULT_URL)

曲线图如下:训练集和验证集准确率上升,损失率下降,这是完美的表现。

在这里插入图片描述

五. 模型可视化批量测试

在这里插入图片描述
编写可视化批量测试工具类:

import keras
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import FancyBboxPatch

from utils.dataset_loader import DatasetLoader

"""
模型工具类
"""


class ModelUtil:
    def __init__(self, saved_model_dir, path_url):
        self.save_model_dir = saved_model_dir  # savedModel 模型保存地址
        self.path_url = path_url  # 模型训练数据地址

    # 批量识别 进行可视化显示
    def batch_evaluation(self, class_mode='categorical', image_size=(224, 224), num_images=25):
        dataset_loader = DatasetLoader(self.path_url, image_size=image_size, class_mode=class_mode)
        train_ds, val_ds, test_ds, class_names = dataset_loader.load_data()
        # 加载savedModel模型
        tfs_layer = keras.layers.TFSMLayer(self.save_model_dir)
        # 创建一个新的 Keras 模型,包含 TFSMLayer
        model = keras.Sequential([
            keras.Input(shape=image_size + (3,)),  # 根据你的模型的输入形状
            tfs_layer
        ])

        plt.figure(figsize=(10, 10))
        for images, labels in test_ds.take(1):
            # 使用模型进行预测
            outputs = model.predict(images)
            for i in range(num_images):
                plt.subplot(5, 5, i + 1)
                image = np.array(images[i]).astype("uint8")
                plt.imshow(image)
                index = int(np.argmax(outputs[i]))
                prediction = outputs[i][index]
                percentage_str = "{:.2f}%".format(prediction * 100)
                plt.title(f"{class_names[index]}: {percentage_str}")
                plt.axis("off")
        plt.subplots_adjust(hspace=0.5, wspace=0.5)
        plt.show()

使用工具类:

if __name__ == '__main__':
    # train()
    model_util = ModelUtil(SAVED_MODEL_DIR, PATH_URL)
    model_util.batch_evaluation()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1925277.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PostgreSQL日志文件配置,记录所有操作记录

为了更详细的记录PostgreSQL 的运行日志,我们一般需要修改PostgreSQL 默认的配置文件,这里整理了一些常用的配置 修改配置文件 打开 PostgreSQL 配置文件 postgresql.conf。该文件通常位于 PostgreSQL 安装目录下的 data 文件夹中。 找到并修改以下配…

1.10-改进CBOW模型的学习

文章目录 0引言1 cupy包的安装2解决VScode中matplotlib绘图不显示的问题3 CBOW模型学习的实现4 CBOW模型对更复杂模式的捕捉5单词向量的评价方法6总结 0引言 本节将前面实现的改进的CBOW模型在PTB数据集上跑一遍由于希望跟书上一样调用GPU,因此需要安装cupy包&…

前端Canvas入门——一些注意事项

创建渐变的三种方法: createLinearGradient() - 线性渐变 createRadialGradient() - 径向渐变(放射性渐变) createConicGradient() - 锥形渐变 这三种的核心观点都是: 创建一个gradient对象,然后调用addColorStop()方法…

【软件测试】自动化测试常用函数 -- 详解

一、WebDriver API 一个简单自动化脚本的构成: 脚本解析 # coding utf-8 from selenium import webdriver import time browser webdriver.Firefox() time.sleep(3) browser.get("http://www.baidu.com") time.sleep(3) browser.find_element_by_id(…

Photoshop

彩色转灰度:ctrlshiftu 背景转黑色: 魔术棒容差10 shift连选 shiftF5(填充)钢笔选择 路径 工作路径 将路径作为选区载入 点回图层 按ctrlx删除选区 待更新

[C++]——同步异步日志系统(5)

同步异步日志系统 一、日志消息格式化设计1.1 格式化子项类的定义和实现1.2 格式化类的定义和实现 二、日志落地类设计2.1 日志落地模块功能实现与测试2.2 日志落地模块功能功能扩展 一、日志消息格式化设计 日志格式化模块的作用:对日志消息进行格式化&#xff0c…

Windows 子系统WSL2 Ubuntu使用事项

Windows 子系统WSL2 Ubuntu使用事项 要使外部设备能够访问运行在 Windows 上的 WSL2 实例,你可以端口转发的方法。由于 WSL2 是在虚拟化环境中运行,直接访问比 WSL1 更为复杂. 1 如何实现子系统可以被外部系统SSH 1.1 端口转发: 通过windows代理WSL2的…

微信视频号的视频怎么下载到本地?快速教你下载视频号视频

天来说说市面上常见的微信视频号视频下载工具,教大家快速下载视频号视频! 方法一:缓存方法 该方法来源早期视频技术,因早期无法将大量视频通过网络存储,故而会有缓存视频文件到手机,其目的为了提高用户体验…

stm32入门-----初识stm32

目录 前言 ARM stm32 1.stm32家族 2.stm32的外设资源 3.命名规则 4.系统结构 5.引脚定义 6.启动配置 7.STM32F103C8T6芯片 8.STM32F103C8T6芯片原理图与最小系统电路 前言 已经很久没跟新了,上次发文的时候是好几个月之前了,现在我是想去学习st…

C++继承和多态

目录 继承 继承的意义 访问限定符、继承方式 赋值兼容规则(切片) 子类的默认成员函数 多继承 继承is a和组合has a 多态 什么是多态 形成多态的条件 函数重载,隐藏,重写的区别 override和final 多态原理 继承 继承的…

FinalShell介绍,安装与应用

目录 一、什么是finalshell 二、finalshell功能 三、为什么要用finalshell 四、安装finalshell 五、finalshell使用 1.添加连接 获取虚拟ip地址 2.启动连接 一、什么是finalshell FinalShell是一体化的的服务器,网络管理软件,不仅是ssh客户端,还是功能强大的开发,运维工…

在RHEL9.4上启用SFTP服务

FTP存在的不足: 明文传输 FTP传输的数据(包括用户名、密码和文件内容)都是明文的,这意味着数据可以被网络上的任何人截获并读取。没有内置的加密机制,容易受到中间人攻击。 被动模式下的端口问题 FTP的被动模式需要…

server nat表和会话表的作用及NAT地址转换详细

本章节主要讲nat技术的基础 -会话表的建立也是看5元组 -状态检测技术的回包一样也看5元组,但是状态检测技术会看的除开5元组还有更多东西 老哥,你真的应该好好注意一个东西:我们的会话表只是为了后续包的转发,会话表是记录的首…

C++:哈希表

哈希表概念 哈希表可以简单理解为:把数据转化为数组的下标,然后用数组的下标对应的值来表示这个数据。如果我们想要搜索这个数据,直接计算出这个数据的下标,然后就可以直接访问数组对应的位置,所以可以用O(1)的复杂度…

澳门建筑插画:成都亚恒丰创教育科技有限公司

澳门建筑插画:绘就东方之珠的斑斓画卷 在浩瀚的中华大地上,澳门以其独特的地理位置和丰富的历史文化,如同一颗璀璨的明珠镶嵌在南国海疆。这座城市,不仅是东西方文化交融的典范,更是建筑艺术的宝库。当画笔轻触纸面&a…

能源园区可视化管理系统

利用图扑 HT 可视化打造能源园区管理系统,实时监控和优化能源分配,提升园区运行效率,增强安全管理,推动绿色和可持续发展。

信立方大模型 | 以AI之钥,开拓智能守护新疆界

在当前网络安全形势日益复杂的背景下,技术的进步不仅带来了便利,也使得网络攻击手段更加多样化和隐蔽化。据悉,国外某研究团队已成功利用GPT技术开发出一种黑客智能体框架,该框架能够深入研读CVE(通用漏洞披露&#xf…

MATLAB激光通信和-积消息传递算法(Python图形模型算法)模拟调制

🎯要点 🎯概率论和图论数学形式和图结构 | 🎯数学形式、图结构和代码验证贝叶斯分类器算法:🖊多类型:朴素贝叶斯,求和朴素贝叶斯、高斯朴素贝叶斯、树增强贝叶斯、贝叶斯网络增强贝叶斯和半朴素…

Android12 MultiMedia框架之GenericSource extractor

前面两节学习到了各种Source的创建和extractor service的启动,本节将以本地播放为例记录下GenericSource是如何创建一个extractor的。extractor是在PrepareAsync()方法中被创建出来的,为了不过多赘述,我们直接从GenericSource的onPrepareAsyn…

LeetCode刷题笔记第3011题:判断一个数组是否可以变为有序

LeetCode刷题笔记第3011题:判断一个数组是否可以变为有序 题目: 想法: 使用冒泡排序进行排序,在判断大小条件时加入判断二进制下数位为1的数目是否相同,相同则可以进行互换。最后遍历数组,相邻两两之间是…