基于EfficientNet(B0-B7)全系列不同参数量级模型开发构建中草药图像识别分析系统,实验量化对比不同模型性能

news2024/9/25 1:16:01

EfficientNet系列的模型在我们前面开发识别类项目或者是检测类项目都是比较少去使用的,一方面是技术本身迭代发展的速度是比较快的,可能新的东西还没学习更新的东西就出来了,另一方面是EfficientNet本身实际业务使用度并不高,可能真正项目开发落地过程中还需要解决额外的问题。

最近正好项目中在做一些识别相关的内容,我也陆陆续续写了一些实验性质的博文用于对自己使用过的模型进行真实数据的评测对比分析,感兴趣的话可以自行移步阅读即可:

《移动端轻量级模型开发谁更胜一筹,efficientnet、mobilenetv2、mobilenetv3、ghostnet、mnasnet、shufflenetv2驾驶危险行为识别模型对比开发测试》

《图像识别模型哪家强?19款经典CNN模型实践出真知【以眼疾识别数据为基准,对比MobileNet系列/EfficientNet系列/VGG系列/ResNet系列/i、xception系列】》

《基于轻量级卷积神经网络模型实践Fruits360果蔬识别——自主构建CNN模型、轻量化改造设计lenet、alexnet、vgg16、vgg19和mobilenet共六种CNN模型实验对比分析》

《基于轻量级模型GHoshNet开发构建眼球眼疾识别分析系统,构建全方位多层次参数对比分析实验》

本文的主要目的是想要以基准数据集【中草药图像数据集】为例,开发构建EfficientNet全系列不同参数量级的模型,之后在同样的测试数据集上进行评测对比分析。

数据集中共包含23种类别数据,清单如下:

aiye
baibiandou
baibu
baidoukou
baihe
cangzhu
cansha
dangshen
ezhu
foshou
gancao
gouqi
honghua
hongteng
huaihua
jiangcan
jingjie
jinyinhua
mudanpi
niubangzi
zhuling
zhuru
zhuye
zicao

简单看下部分数据实例:

EfficientNet是由谷歌研究团队提出的一种高效的卷积神经网络(CNN)架构,其构建原理基于网络深度、网络宽度和分辨率缩放的均衡策略:

1、网络深度:EfficientNet采用了复合系数(compound scaling)的思想,通过增加网络深度来提高其表达能力。复合系数是一个复合的缩放因子,将网络的深度、宽度和分辨率进行关联调整,以实现更好的性能。增加网络深度可以提高模型的表示能力,帮助模型更好地学习复杂的特征和模式。

2、网络宽度:在网络宽度方面,EfficientNet采用了通道缩放(channel scaling)的方法,通过调整每个卷积层的通道数来提高模型的表达能力。通道缩放可以在不增加过多参数的情况下提高模型的性能,使其更有效地利用计算资源。

3、分辨率缩放:EfficientNet通过分辨率缩放来调整输入图像的分辨率,以改善模型对不同尺度下的特征的学习能力。将输入图像的分辨率调整为适当的大小,可以使模型更好地适应不同尺度的特征,并提高在输入图像分辨率较高时的性能。

在构建EfficientNet时,研究团队通过对网络深度、宽度和分辨率的均衡性进行优化,提出了一种更高效的神经网络模型。经过复合系数的调整,EfficientNet在提高性能的同时也考虑了模型的计算效率,使其在训练速度和推断速度方面都表现出了很好的性能。总的来说,EfficientNet的构建原理可以概括为通过复合系数调整网络深度、宽度和分辨率,以实现在提高性能的同时保持计算效率。这种均衡策略使EfficientNet成为一种高效的深度学习模型,在图像识别等任务中取得了优秀的表现。

EfficientNet系列模型共构建了从B0到B7八个不同参数量级的模型,开源社区里面也有很多优秀的实现可以根据自己的实际需求选择即可,下面是我自己使用的keras实现的版本,如下所示:

def EfficientNet(
    input_shape,
    block_args_list: List[BlockArgs],
    width_coefficient: float,
    depth_coefficient: float,
    include_top=True,
    weights=None,
    input_tensor=None,
    pooling=None,
    classes=1000,
    dropout_rate=0.0,
    drop_connect_rate=0.0,
    batch_norm_momentum=0.99,
    batch_norm_epsilon=1e-3,
    depth_divisor=8,
    min_depth=None,
    data_format=None,
    default_size=None,
    **kwargs
):
    if data_format is None:
        data_format = K.image_data_format()
    if data_format == "channels_first":
        channel_axis = 1
    else:
        channel_axis = -1
    if default_size is None:
        default_size = 224
    if block_args_list is None:
        block_args_list = get_default_block_list()
    stride_count = 1
    for block_args in block_args_list:
        if block_args.strides is not None and block_args.strides[0] > 1:
            stride_count += 1
    min_size = int(2**stride_count)
    input_shape = _obtain_input_shape(
        input_shape,
        default_size=default_size,
        min_size=min_size,
        data_format=data_format,
        require_flatten=include_top,
        weights=weights,
    )
    if input_tensor is None:
        inputs = layers.Input(shape=input_shape)
    else:
        if not K.is_keras_tensor(input_tensor):
            inputs = layers.Input(tensor=input_tensor, shape=input_shape)
        else:
            inputs = input_tensor
    x = inputs
    x = layers.Conv2D(
        filters=round_filters(32, width_coefficient, depth_divisor, min_depth),
        kernel_size=[3, 3],
        strides=[2, 2],
        kernel_initializer=EfficientNetConvInitializer(),
        padding="same",
        use_bias=False,
    )(x)
    x = layers.BatchNormalization(
        axis=channel_axis, momentum=batch_norm_momentum, epsilon=batch_norm_epsilon
    )(x)
    x = Swish()(x)
    num_blocks = sum([block_args.num_repeat for block_args in block_args_list])
    drop_connect_rate_per_block = drop_connect_rate / float(num_blocks)
    for block_idx, block_args in enumerate(block_args_list):
        assert block_args.num_repeat > 0
        block_args.input_filters = round_filters(
            block_args.input_filters, width_coefficient, depth_divisor, min_depth
        )
        block_args.output_filters = round_filters(
            block_args.output_filters, width_coefficient, depth_divisor, min_depth
        )
        block_args.num_repeat = round_repeats(block_args.num_repeat, depth_coefficient)
        x = MBConvBlock(
            block_args.input_filters,
            block_args.output_filters,
            block_args.kernel_size,
            block_args.strides,
            block_args.expand_ratio,
            block_args.se_ratio,
            block_args.identity_skip,
            drop_connect_rate_per_block * block_idx,
            batch_norm_momentum,
            batch_norm_epsilon,
            data_format,
        )(x)
        if block_args.num_repeat > 1:
            block_args.input_filters = block_args.output_filters
            block_args.strides = [1, 1]
        for _ in range(block_args.num_repeat - 1):
            x = MBConvBlock(
                block_args.input_filters,
                block_args.output_filters,
                block_args.kernel_size,
                block_args.strides,
                block_args.expand_ratio,
                block_args.se_ratio,
                block_args.identity_skip,
                drop_connect_rate_per_block * block_idx,
                batch_norm_momentum,
                batch_norm_epsilon,
                data_format,
            )(x)
    x = layers.Conv2D(
        filters=round_filters(1280, width_coefficient, depth_coefficient, min_depth),
        kernel_size=[1, 1],
        strides=[1, 1],
        kernel_initializer=EfficientNetConvInitializer(),
        padding="same",
        use_bias=False,
    )(x)
    x = layers.BatchNormalization(
        axis=channel_axis, momentum=batch_norm_momentum, epsilon=batch_norm_epsilon
    )(x)
    x = Swish()(x)
    if include_top:
        x = layers.GlobalAveragePooling2D(data_format=data_format)(x)
        if dropout_rate > 0:
            x = layers.Dropout(dropout_rate)(x)
        x = layers.Dense(classes, kernel_initializer=EfficientNetDenseInitializer())(x)
        x = layers.Activation("softmax")(x)
    else:
        if pooling == "avg":
            x = layers.GlobalAveragePooling2D()(x)
        elif pooling == "max":
            x = layers.GlobalMaxPooling2D()(x)
    outputs = x
    if input_tensor is not None:
        inputs = get_source_inputs(input_tensor)
    model = Model(inputs, outputs)
    return model

训练集占比75%,测试集占比25%,所有模型按照相同的数据集配比进行实验对比分析,计算准确率、精确率、召回率和F1值四种指标。结果详情如下所示:

{
	"EfficientNetB0": {
		"accuracy": 0.7323568575233023,
		"precision": 0.7342840951689299,
		"recall": 0.7401447749844765,
		"f1": 0.721134606264584
	},
	"EfficientNetB1": {
		"accuracy": 0.6711051930758988,
		"precision": 0.674815403573669,
		"recall": 0.6814607103125642,
		"f1": 0.654513293295677
	},
	"EfficientNetB2": {
		"accuracy": 0.6804260985352862,
		"precision": 0.6887820209389782,
		"recall": 0.693270915541022,
		"f1": 0.6667803713033284
	},
	"EfficientNetB3": {
		"accuracy": 0.6577896138482025,
		"precision": 0.6575815350378532,
		"recall": 0.6697388006766206,
		"f1": 0.6394242217002369
	},
	"EfficientNetB4": {
		"accuracy": 0.607190412782956,
		"precision": 0.6009338789436689,
		"recall": 0.6201496013702704,
		"f1": 0.575676812729165
	},
	"EfficientNetB5": {
		"accuracy": 0.3754993342210386,
		"precision": 0.4126647746126044,
		"recall": 0.3840962856358272,
		"f1": 0.33776556111562969
	},
	"EfficientNetB6": {
		"accuracy": 0.29427430093209058,
		"precision": 0.3007763995135063,
		"recall": 0.3057837376695735,
		"f1": 0.23015536725242517
	},
	"EfficientNetB7": {
		"accuracy": 0.19573901464713715,
		"precision": 0.1291924114976619,
		"recall": 0.1938956570492365,
		"f1": 0.10902900005697842
	}
}

简单介绍下上述使用的四种指标:

准确率(Accuracy):即分类器正确分类的样本数占总样本数的比例,通常用于评估分类模型的整体预测能力。计算公式为:准确率 = (TP + TN) / (TP + TN + FP + FN),其中 TP 表示真正例(分类器将正例正确分类的样本数)、TN 表示真负例(分类器将负例正确分类的样本数)、FP 表示假正例(分类器将负例错误分类为正例的样本数)、FN 表示假负例(分类器将正例错误分类为负例的样本数)。

精确率(Precision):即分类器预测为正例中实际为正例的样本数占预测为正例的样本数的比例。精确率评估分类器在预测为正例时的准确程度,可以避免过多地预测假正例。计算公式为:精确率 = TP / (TP + FP)。

召回率(Recall):即分类器正确预测为正例的样本数占实际为正例的样本数的比例。召回率评估分类器在实际为正例时的识别能力,可以避免漏掉过多的真正例。计算公式为:召回率 = TP / (TP + FN)。

F1 值(F1-score):综合考虑精确率和召回率,是精确率和召回率的调和平均数。F1 值在评估分类器综合表现时很有用,因为它同时关注了分类器的预测准确性和识别能力。计算公式为:F1 值 = 2 * (精确率 * 召回率) / (精确率 + 召回率)。 F1 值的取值范围在 0 到 1 之间,值越大表示分类器的综合表现越好。

为了能够直观清晰地对比不同模型的评测结果,这里对其进行可视化分析,如下所示:

这个结果着实是没有预想到的,参数量更大的B7模型反而得到的效果是最差的,这个可能也跟我的显存太小跑B7的时候调小了很多Batch_size,但是感觉这个也不应该会差这么多,总之就是结果反映出来的问题很奇怪,后面有时间选择别的数据集再去尝试一下看看是不是都是这个情况。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1408783.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

maptalks 右键删除多边形 电子围栏

<!-- 地图组件 --> <template><div :id"id" class"container"></div> </template><script> import _ from "lodash"; import "maptalks/dist/maptalks.css"; import * as maptalks from "ma…

RDMA vs InfiniBand 网卡接口如何区分?

(该架构图来源于参考文献) 高性能计算网络&#xff0c;RoCE vs. InfiniBand该怎么选&#xff1f; 新 RoCEv2 标准可实现 RDMA 路由在第三层以太网网络中的传输。RoCEv2 规范将用以太网链路层上的 IP 报头和 UDP 报头替代 InfiniBand 网络层。这样&#xff0c;就可以在基于 IP…

虹科分享丨AR与AI融合加速,医疗护理更便捷!

来源&#xff1a;虹科数字化与AR 虹科分享丨AR与AI融合加速&#xff0c;医疗护理更便捷&#xff01; 原文链接&#xff1a;https://mp.weixin.qq.com/s/Fi0wNfk_TDXRo_1-6cSRNQ 欢迎关注虹科&#xff0c;为您提供最新资讯&#xff01; #AR眼镜 #医疗护理 根据Reports and Da…

【动态规划】【map】【C++算法】1289. 下降路径最小和 II

作者推荐 视频算法专题 本文涉及知识点 动态规划汇总 map LeetCode1289. 下降路径最小和 II 给你一个 n x n 整数矩阵 grid &#xff0c;请你返回 非零偏移下降路径 数字和的最小值。 非零偏移下降路径 定义为&#xff1a;从 grid 数组中的每一行选择一个数字&#xff0c;…

云计算项目五:部署数据库服务mysql |部署共享存储服务NFS | 配置网站服务

部署数据库服务mysql |部署共享存储服务NFS | 配置网站服务 案例1:配置逻辑卷步骤一:创建LV步骤二:格式化案例2:配置数据库服务器步骤一:安装软件MySQL服务软件(2台数据库服务器都要安装)步骤二:挂载lv设备步骤三:启动服务步骤四:管理员登录案例3:配置主从同步步骤一…

【自然语言处理】【深度学习】文本向量化、one-hot、word embedding编码

因为文本不能够直接被模型计算&#xff0c;所以需要将其转化为向量 把文本转化为向量有两种方式&#xff1a; 转化为one-hot编码转化为word embedding 一、one-hot 编码 在one-hot编码中&#xff0c;每一个token使用一个长度为N的向量表示&#xff0c;N表示词典的数量。 即&…

谷歌地球引擎Google Earth Engine针对不同地表类型分别自动生成随机采样点的方法

本文介绍在谷歌地球引擎&#xff08;Google Earth Engine&#xff0c;GEE&#xff09;中&#xff0c;按照给定的地表分类数据&#xff0c;对每一种不同的地物类型&#xff0c;分别加以全球范围内随机抽样点自动批量选取的方法。 本文是谷歌地球引擎&#xff08;Google Earth En…

05.Elasticsearch应用(五)

Elasticsearch应用&#xff08;五&#xff09; 1.目标 咱们这一章主要学习Mapping&#xff08;映射&#xff09; 2.介绍 Mapping是对索引库中文档的约束&#xff0c;类似于数据表结构&#xff0c;作用如下&#xff1a; 定义索引中的字段的名称定义字段的数据类型&#xff…

0124-2-算法题解析与总结(四)

5.5 如何去除有序数组的重复元素 本文对应的力扣题目&#xff1a; 26.删除排序数组中的重复项 83.删除排序链表中的重复元素 26.删除排序数组中的重复项&#xff1a; int removeDuplicates(int[] nums) {int n nums.length;if (n 0) return 0;int slow 0, fast 1;while…

Spring基于AbstractRoutingDataSource实现MySQL多数据源

目录 多数据源实现 yml配置文件 配置类 业务代码 案例演示 多数据源实现 yml配置文件 spring:datasource:type: com.alibaba.druid.pool.DruidDataSourcedatasource1:url: jdbc:mysql://127.0.0.1:3306/datasource1?serverTimezoneUTC&useUnicodetrue&characte…

第二百八十六回

文章目录 概念介绍实现方法示例代码 我们在上一章回中介绍了如何拦截路由相关的内容&#xff0c;本章回中将介绍页面转场动画.闲话休提&#xff0c;让我们一起Talk Flutter吧。 概念介绍 我们在上一章回中介绍了路由拦截相关的内容&#xff0c;本章回中将使用路由拦截实现转场…

《动手学深度学习(PyTorch版)》笔记3

注&#xff1a;书中对代码的讲解并不详细&#xff0c;本文对很多细节做了详细注释。另外&#xff0c;本书源代码是在Jupyter Notebook上运行的&#xff0c;较为分散&#xff0c;本文将代码集中起来&#xff0c;并加以完善&#xff0c;全部用vscode测试通过。 Chapter3 Linear …

长城资产信息技术岗24届校招面试面经

本文介绍2024届秋招中&#xff0c;中国长城资产管理股份有限公司的信息技术岗岗位一面的面试基本情况、提问问题等。 10月投递了中国长城资产管理股份有限公司的信息技术岗岗位&#xff0c;所在部门为长城新盛信托有限责任公司。目前完成了一面&#xff0c;在这里记录一下一面经…

Puppeteer结合Jest对网页进行测试

之前我们使用Puppeteer进行网页爬虫&#xff08;以及自动化操作&#xff09;&#xff0c;这篇文章主要验证一下Puppeteer测试的可实现性。 项目设置 让我们从设置一个基本的React应用程序开始。 我们将安装其他依赖项,如Puppeteer和Faker。 为了这篇文章的目的,我创建了一个…

ASP.NET Core WebAPI从HTTPS调整为HTTP启动

使用VS2022创建WebAPI项目时&#xff0c;默认勾选“配置HTTPS(H)”&#xff0c;这样启动WebAPI时以https方式启动。   如果要从HTTPS调整为HTTP启动&#xff0c;需要修改项目中以下几处&#xff0c;首先是Program.cs中删除app.UseHttpsRedirection()语句&#xff0c;删除后…

gitlab runner 安装、注册、配置、使用

天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物。 每个人都有惰性&#xff0c;但不断学习是好好生活的根本&#xff0c;共勉&#xff01; 文章均为学习整理笔记&#xff0c;分享记录为主&#xff0c;如有错误请指正&#xff0c;共同学习进步。…

GitLab升级版本(任意用户密码重置漏洞CVE-2023-7028)

目录 前言漏洞分析影响范围查看自己的GitLab版本升级路程 升级过程13.1.1113.8.8 - 14.0.1214.3.614.9.5 - 16.1.6 前言 最近GitLab发了个紧急漏洞需要修复&#xff0c;ok接到命令立刻着手开始修复&#xff0c;在修复之前先大概了解一下这个漏洞是什么东西 漏洞分析 1、组件…

免费的 UI 设计资源网站 Top 8

今日与大家分享8个优秀的免费 UI 设计资源网站。这些网站的资源包括免费设计材料站、设计工具、字体和其他网站&#xff0c;尤其是一些材料站。它们是免费下载的&#xff0c;材料的风格目前很流行&#xff0c;适合不同的项目。非常适合平面设计WEB/UI设计师收藏&#xff0c;接下…

.git 文件夹结构解析

.git 文件夹结构解析 在这篇文章就让我们来看看这个 Git 仓库里的文件分别都是用来干什么的&#xff0c;以及在执行了相关的 Git 命令后这些文件会如何响应。 hooks&#xff08;钩&#xff09;&#xff1a;存放一些shell脚本info&#xff1a;存放仓库的一些信息logs&#xff…

1分钟部署幻兽帕鲁联机服务,PalWorld服务器搭建教程(阿里云)

1分钟部署幻兽帕鲁联机服务&#xff0c;PalWorld服务器搭建教程 最近这游戏挺火&#xff0c;很多人想跟朋友联机&#xff0c;如果有专用服务器&#xff0c;就不需要房主一直开着电脑&#xff0c;稳定性也好得多。 概述 幻兽帕鲁是Pocketpair开发的一款开放世界生存制作游戏&…