传知代码-轻量注意力网络实现苹果叶片识别

news2024/9/27 21:08:25

代码以及视频讲解

本文所涉及所有资源均在传知代码平台可获取

引言

该系统基于EfficientNet多头自注意力机制,构建了一个高效、精准的苹果叶片识别模型,能够对不同种类的苹果叶片进行准确分类。通过结合EfficientNet的强大特征提取能力和多头注意力机制的全局信息捕捉能力,系统在处理复杂背景和不同光照条件下的叶片图像时表现出色。此外,系统还集成了一个可视化平台,用户可以直观地查看叶片分类结果,并通过简便的界面上传图像进行预测。这使得该系统在实际农业生产中具有重要的应用价值,如病害监测和农业自动化管理等。

效果展示

在这里插入图片描述
在这里插入图片描述

本文核心创新点

1. 高效网络与注意力机制的集成:
本系统创新性地将 EfficientNet 与 多头自注意力机制 融合,形成了一个集成化的模型结构。EfficientNet以其精细化的网络架构设计和卓越的参数效率著称,而多头自注意力机制则增强了模型在捕捉全局信息和复杂特征方面的表现。通过这种集成,系统显著提升了苹果叶片分类任务的精准度和稳健性,特别是在处理大规模、高维度数据时展现了卓越的性能。

2. 模块化与用户友好的界面设计:
可视化界面的设计体现了高度的用户友好性与灵活性。用户可以通过图形界面直观地选择 EfficientNet 的不同模型结构(如B0到B7),这不仅满足了不同计算资源环境下的需求,还使得复杂的模型选择过程变得简洁和易于操作。通过模块化的设计,界面能够适应多种使用场景,从研究人员到农业工作者均可受益。

3. 智能化的模型选择与调整:
本系统引入了智能化的模型选择机制,允许用户根据任务的复杂性和计算资源的可用性动态调整模型架构。这种创新性设计提升了系统的可扩展性,使得在不同的应用环境中均能达到优化的性能表现。同时,系统支持对注意力机制参数的自适应调整,进一步提高了模型的灵活性和适应性。

EfficientNet核心原理

1. 网络深度
网络深度指的是神经网络中层的数量。在EfficientNet中,网络深度是一个关键的维度,用于捕获图像中的复杂特征。通过增加网络深度,模型能够学习到更多层次的特征表示,从而提高模型的性能。然而,过深的网络也可能导致梯度消失或梯度爆炸的问题,使得模型难以训练。因此,EfficientNet在设计时通过引入残差连接(Residual Connection)等技术来缓解这些问题,同时利用复合缩放策略来平衡网络深度与其他维度的关系。如下图所示:
在这里插入图片描述
在这里插入图片描述

图中第一部分就是这个网络的baselinelayer_i就是有多少层,可以理解为深度H*W就是网络输入图像的分辨率#channels就是网络的宽度。那么图中第二部分就是增加了网络的深度,可以看到“蓝,绿,黄”三部分都增加了一个。单独对网络深度进行变换之后,作者经过消融实验得到最右边的图,这里的“d”表示网络加深的倍数,可以看到当d=6和d=8时结果已经几乎没有提升,这说明并非一直增加网络深度就可以是的网络效果变好。

2. 网络宽度
网络宽度指的是神经网络中每一层的通道数(或称为特征图的数量)。增加网络宽度可以提高模型的特征表示能力,使得模型能够捕捉到更细粒度的特征。然而,过宽的网络也会增加计算量和参数量,从而增加模型的复杂度和训练难度。EfficientNet通过引入宽度乘数(Width Multiplier)来控制每一层的通道数,从而实现网络宽度的灵活调整。较小的宽度乘数可以减少参数量和计算量,使模型更轻量;而较大的宽度乘数则可以增加模型的性能。如下图所示:

在这里插入图片描述
在这里插入图片描述

那么图中的第一部分同样也是模型的baseline,由图中的第二部分可知,这是在网络的宽度上进行改进,可以看到“蓝,绿,黄”三部分都在横向增加。网络宽度的增加也不是越大越好的,因为上面最右图可知,当宽度增大到w=2.6时几乎是最佳的结果了,因为当再继续增大,计算量会增加很多且准确率增加的并不明显。

3. 图像分辨率
图像分辨率是指输入图像的大小或尺寸。在EfficientNet中,图像分辨率也是一个重要的维度,用于影响模型对图像细节的捕捉能力。使用更高分辨率的输入图像可以使模型捕捉到更细粒度的特征,但也会增加计算量和参数量。因此,EfficientNet通过引入分辨率乘数(Resolution Multiplier)来控制输入图像的分辨率,以实现不同计算资源条件下的灵活调整较小的分辨率乘数可以降低输入图像的分辨率,从而减少计算复杂度;而较大的分辨率乘数则可以增加模型的视觉表征能力。如下图:

在这里插入图片描述
在这里插入图片描述

该图是从图像分辨率的角度去进行改进,如右图所示,当图像分辨率增大到一定程度时,带来的效果并不明显了。这是因为随着图像分辨率的增加,模型能够捕捉到更多的细节信息。然而,这种提升并不是线性的。当分辨率增加到一定程度后,模型从中获取的新增信息量逐渐减少,导致模型性能的提升也变得不那么显著。这是因为图像中的很多细节信息在达到一定分辨率后已经足够清晰,再提高分辨率对于模型来说可能只是冗余信息的增加。
4. 复合缩放
那么上述两个方面都不是一直呈现出线性关系,那么使用复合缩放的方法来找到三者的平衡以达到最优的结果。它通过对网络深度、宽度和分辨率三个维度进行统一的缩放,来实现模型性能的最优化。具体来说,EfficientNet使用了一个复合系数(Compound Scaling Factor)来同时调整这三个维度,以确保网络各个部分之间的平衡。复合缩放策略不仅考虑了每个维度对模型性能的独立影响,还考虑了它们之间的相互作用和权衡。通过系统地研究不同维度对模型性能的影响,EfficientNet找到了一个最优的缩放比例,使得模型在保持高效的同时达到了最优的性能。
在这里插入图片描述

由图中可以看出当同时增加深度和分辨率时效果增加的非常明显。图中红色线,这是卷积层数是36层,分辨率为299*299.

多头自注意力机制

多头自注意力机制(Multi-Head Self-Attention)是Transformer模型中的一个核心组件,其原理主要通过并行地使用多个自注意力头来捕捉输入序列中的不同上下文信息。它的原理可以概括为:
在这里插入图片描述

  1. 分割输入
    输入序列的每个词向量(或更一般地,输入序列的嵌入向量)首先被分割成多个较小的部分,每个部分对应一个“头”(Head)。例如,如果输入词向量的维度是512,可以选择创建8个头那么在本文中使用了4个头,可以根据自己的需要更改头的数量。每个头的维度就是64(512/8)。
  2. 计算注意力
    对于每个头,分别执行自注意力计算。这通常涉及三个步骤:
    线性变换:首先,通过三个不同的线性变换(或全连接层)分别生成查询(Query, Q)、键(Key, K)和值(Value, V)向量。这些变换允许模型学习输入数据的不同方面。
    点积注意力:然后,计算每个查询向量与所有键向量的点积,得到一个分数矩阵,表示查询与键之间的相关性。为了防止梯度消失或爆炸,通常会对点积结果进行缩放(Scaled Dot-Product Attention)。
    Softmax归一化:接着,使用Softmax函数对分数矩阵进行归一化,得到注意力权重分布。这些权重表示序列中每个位置对当前位置的重要性。
  3. 加权求和
    使用注意力权重对值向量进行加权求和,得到每个头的输出。这一步骤聚合了序列中所有位置的信息,并根据权重赋予不同的重要性。
  4. 拼接输出
    最后,通过一个输出权重矩阵对拼接后的表示进行线性变换,得到最终的输出矩阵。这一步骤可能用于进一步整合信息,并准备输出用于后续的网络层。在本文中这里后续的网络层主要为全连接层。

本文模型结构

在这里插入图片描述

本文的模型结构结合了EfficientNet和多头注意力网络共同搭建了这样一个模型,该模型在EfficientNet的最后一个卷积层后接多头自注意力网络。这样做的好处是:EfficientNet的卷积层已经对图像进行了有效的多尺度特征提取,但卷积操作本质上具有局部性,即每个卷积核只能看到一部分图像。因此,在卷积层后引入多头自注意力机制,能够捕获图像特征中的长距离依赖关系。这意味着模型不仅能关注局部的细节特征,还能够有效地整合来自图像不同区域的全局信息,进一步提升了特征表示的丰富性和表达能力。

EfficientNet b0–b7结构差异

EfficientNet-B0: 基准模型,使用了基于MobileNetV2架构的基础网络,并在此基础上进行优化。
在这里插入图片描述

EfficientNet-B1在宽度和深度上有所不同,具体不同的差异如下图:
在这里插入图片描述

EfficientNet-B2它的结构与EfficientNet-B1相同,它们之间唯一的区别是feature maps(通道)的数量不同,从而增加了参数的数量。
EfficientNet-B3到EfficientNet-B7分别在宽度和深度上均有差异,如下图:
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

主要区别在于网络的深度、宽度和输入分辨率的不同,这些差异使得每个模型在计算能力和精度上有所不同,从而适应不同的应用场景和资源限制。
不同的模型呈现出来的效果和精度差异如下:
在这里插入图片描述

实现过程

1. 数据集获取
本文使用的数据集是从飞浆平台下载的:AppleLeaf9数据集包括健康的苹果叶子和8类苹果叶部病害。当然也可以从我的附件中的readme给出的百度网盘链接下载,这里下载后是划分好训练验证集的。图片如图所示:
在这里插入图片描述

2. 数据集划分
本文按照8:2的比例划分了训练和验证集

# 遍历每个类别文件夹并进行划分
for class_folder in os.listdir(data_dir):
    class_path = os.path.join(data_dir, class_folder)
 
    if os.path.isdir(class_path):
        images = os.listdir(class_path)
        random.shuffle(images)
        
        train_size = int(len(images) * split_ratio)
        train_images = images[:train_size]
        test_images = images[train_size:]
        
        # 创建类别目录
        train_class_dir = os.path.join(train_dir, class_folder)
        test_class_dir = os.path.join(test_dir, class_folder)
        os.makedirs(train_class_dir, exist_ok=True)
        os.makedirs(test_class_dir, exist_ok=True)
        ......

3. 模型训练
数据集划分好后需要在训练脚本中读取划分好的数据集。在训练时可以选择EfficientNet的类型b0–b7,会产生不同的七个模型,用于后续的预测时选择相对应的最佳模型。也可以选择多头注意力中头的数量

# 定义苹果叶片分类模型
class AppleLeafClassifier(nn.Module):
    def __init__(self, model_name='efficientnet-b3', num_classes=9, attention_heads=4):
        super(AppleLeafClassifier, self).__init__()
        self.backbone = EfficientNet.from_pretrained(model_name)
        .......

    def forward(self, x):
        x = self.backbone.extract_features(x)
        b, c, h, w = x.size()
        
        embed_dim = c
        if self.multihead_attn is None or self.multihead_attn.multihead_attn.embed_dim != embed_dim:
            self.multihead_attn = MultiHeadSelfAttention(embed_dim=embed_dim, num_heads=4).to(device)

开始模型训练

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    
    epoch_loss = running_loss / train_size

    model.eval()
    correct = 0
    total = 0

保存最佳模型

   # 保存最好的模型
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        torch.save(model.state_dict(), 'best_model3.pth')

print('训练完成,最佳验证准确率: {:.4f}'.format(best_acc))

模型预测与可视化平台

def predict_image(self, file_path):
        image = Image.open(file_path)
        self.display_image(image)

        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        image = transform(image).unsqueeze(0).to(device)

        model_name = self.model_selector.currentText()
        ........
        model.load_state_dict(torch.load(model_path, map_location=device))  # 使用 map_location 来匹配设备
        model.eval()

def initUI(self):
        self.setWindowTitle(self.title)

        layout = QVBoxLayout()

        self.model_selector = QComboBox(self)
        for model_name in ['efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3', 'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7']:
            self.model_selector.addItem(model_name)
        layout.addWidget(self.model_selector)

        self.button = QPushButton('选择图像', self)
        self.button.clicked.connect(self.open_file_dialog)
        layout.addWidget(self.button)

        self.label = QLabel(self)
        layout.addWidget(self.label)

        self.result_label = QLabel('分类结果: ', self)
        self.result_label.setStyleSheet("font-size: 18px;")  
        layout.addWidget(self.result_label)

        self.setLayout(layout)
        self.setGeometry(300, 300, 400, 300)

训练过程
在这里插入图片描述

实验结果展示
1.基线模型结果
在这里插入图片描述

2.加多头注意力结果
在这里插入图片描述

可以看到在添加多头注意力后,我们的模型结果是要优于基线模型的。

参考文献

1.Mingxing Tan and Quoc V. Le. EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. ICML 2019. Arxiv link: https://arxiv.org/abs/1905.11946.
2.Vaswani A. Attention is all you need[J]. arxiv preprint arxiv:1706.03762, 2017.

源码下载

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2171058.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Ks渲染做汽车动画吗?汽车本地渲染与云渲染成本分析

Keyshot是一款强大的实时光线追踪和全域光渲染软件,它确实可以用于制作汽车动画,包括汽车模型的渲染和动画展示。Keyshot的动画功能允许用户创建相机移动、物体变化等动态效果,非常适合用于汽车动画的制作。 至于汽车动画的渲染成本&#xff…

Power Platform开发小技巧,一天一个APP, 如何快速搭建二维码识别器

之前,给大家分享了微软Power Platform开发课程——手把手教你搭建二维码生成器,很多小伙伴反馈真好用。这期我们继续为大家分享Power Platform的开发能力与技巧。 今天介绍如何开发⼀个⼆维码识别器。 该应用包含如下功能: 1.⼆维码图片的…

尾矿库安全监测系统:守护矿山安全的关键技术

尾矿库是矿山企业用于存放尾矿的重要设施,其安全状况直接关系到周边环境和人民生命财产安全。近年来,随着技术的不断进步,尾矿库安全监测系统应运而生,为尾矿库的安全管理提供了强有力的技术支持。本文将详细介绍尾矿库安全监测系…

基于spi机制构造的webshell

前言 最近在翻阅yzddmr6师傅博客的时候,发现师傅还有个github的地址 https://github.com/yzddmr6/MyPresentations 里面发现师傅去补天白帽子大会上讲解了一些webshell的攻防,特此进行了学习,然后发现了一个很有意思的webshell&#xff0c…

YOLOv9改进,YOLOv9主干网络替换为PP-LCNetV2(百度飞浆视觉团队自研,轻量化架构),全网独发

摘要 PP-LCNetV2 是在图像分类任务中提出的一种轻量级卷积神经网络,用于在边缘设备上实现高效的推理。PP-LCNet 系列模型的设计旨在提高移动和边缘设备上的推理性能,同时保持较高的准确率。PP-LCNetV2 是在 PP-LCNetV1 基础上改进的。 理论介绍 PP-LCNetV2模型结构如下: …

数据库存储加密技术有哪些 TDE透明加密和列表级加密

透明数据加密(TDE)和列级加密是数据库加密中两种常见的加密方式,它们在加密范围、实现方式以及对应用程序的影响等方面存在明显的区别。 透明数据加密(TDE) 定义: 透明数据加密(Transparent …

稀土阻燃协效剂-氢氧化镁(氢氧化铝)的应用

稀土阻燃协效剂凭借独特的稀土4f电子层结构,在聚合物材料燃烧时可催化酯化成炭,迅速在高分子表面形成致密连续的碳层,隔绝聚合物材料内部的可燃性气体与氧气的接触,从而达到阻燃抑烟的效果,且燃烧时不产生有毒有害气体。 金士镧系列稀土阻燃剂是一种基于稀土协效阻燃的复合阻燃…

CTF竞赛介绍以及刷题网址(超详细)零基础入门到精通,收藏这一篇就够了

CTF(Capture The Flag)中文一般译作夺旗赛,在网络安全领域中指的是网络安全技术人员之间进行技术竞技的一种比赛形式。CTF起源于1996年DEFCON全球黑客大会,以代替之前黑客们通过互相发起真实攻击进行技术比拼的方式。发展至今&…

安全防护装备检测系统源码分享

安全防护装备检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer…

D18【python接口自动化学习】-python基础之内置数据类型

day18 综合练习:实现手机通讯录(下) 学习日期:20240925 学习目标:内置数据类型--27 小试牛刀:如何使用类型转换实现手机通讯录(下) 学习笔记: 实现手机通讯录 案例文…

以题为例浅谈反序列化漏洞

什么是反序列化漏洞 反序列化漏洞是基于序列化和反序列化的操作,在反序列化——unserialize()时存在用户可控参数,而反序列化会自动调用一些魔术方法,如果魔术方法内存在一些敏感操作例如eval()函数,而且参数是通过反序列化产生的…

占领矩阵-第15届蓝桥省赛Scratch中级组真题第5题

[导读]:超平老师的《Scratch蓝桥杯真题解析100讲》已经全部完成,后续会不定期解读蓝桥杯真题,这是Scratch蓝桥杯真题解析第190讲。 如果想持续关注Scratch蓝桥真题解读,可以点击《Scratch蓝桥杯历年真题》并订阅合集,…

【图像处理】多幅不同焦距的同一个物体的平面图象,合成一幅具有立体效果的单幅图像原理(一)

合成一幅具有立体效果的单幅图像,通常是利用多个不同焦距的同一物体的平面图像,通过图像处理技术实现的。以下是该过程的基本原理: 1. 立体视觉原理 人眼的立体视觉是通过双眼观察物体的不同视角而获得的。两只眼睛的位置不同,使…

【学习笔记】MIPI

MIPI介绍 MIPI是由ARM、Nokia、ST、IT等公司成立的一个联盟,旨在把手机内部的接口如存储接口,显示接口,射频/基带接口等标准化,减少兼容性问题并简化设计。 MIPI联盟通过不同的工作组,分别定义一系列手机内部的接口标…

猜拳数据集-石头-剪刀-布数据集

“石头-剪刀-布”计算机视觉项目是一个利用摄像头捕捉手势并识别出手势是石头、剪刀还是布的项目。这类项目通常用于学习和展示计算机视觉技术,如图像处理、特征提取以及机器学习或深度学习模型的应用。 数据介绍 rock-paper-scissors Computer Vision Project数…

基于状态机的流程编排架构设计

背景 xx产品侧规划了全新的能力升级, 主要思路为:改变之前通过xx等手工生成xx的方式,通过标准化流程尽可能的减少人工介入,提升产出效率。xx入库、xx生成链路存在链路长、链路不稳定问题,由于目前缺乏比较好的监控、检…

一文多图,彻底弄懂LSM-Tree

一文弄懂LSM-Tree LSM-Tree是什么? LSM-Tree(Log Structured Merge Tree)是一种数据结构,它被设计用于处理大量写入操作的场景,常见于许多NoSQL数据库中,如BigTable、Cassandra、RocksDB和LevelDB等。 L…

废品回收小程序:回收更加便捷!

在日常生活中,废品回收已经成为了一种常见事,随着电商的快速发展,居民难免会产生大量的废纸盒等可回收物,以及在日常生活中产生的其他回收物, 目前,废品回收市场也发生了改革,传统的“叫卖”方…

MySQL高阶1990-统计实验的数量

目录 题目 准备数据 分析数据 总结 题目 写一个 SQL 查询语句,以报告在给定三个实验平台中每种实验完成的次数。请注意,每一对(实验平台、实验名称)都应包含在输出中,包括平台上实验次数是零的。 结果可以以任意…

C++之STL—常用查找算法

- find //查找元素 - find_if //按条件查找元素 - adjacent_find //查找相邻重复元素 - binary_search //二分查找法 - count //统计元素个数 - count_if //按条件统计元素个数 find (iterator begin, …