YOLOv5/v7 添加注意力机制,30多种模块分析④,CA模块,ECA模块

news2025/1/9 12:05:14

目录

    • 一、注意力机制介绍
      • 1、什么是注意力机制?
      • 2、注意力机制的分类
      • 3、注意力机制的核心
    • 二、CA模块
      • 1、CA模块的原理
      • 2、实验结果
      • 3、应用示例
    • 三、ECA模块
      • 1、ECA模块的原理
      • 2、实验结果
      • 3、应用示例

大家好,我是哪吒。

🏆本文收录于,目标检测YOLO改进指南。

本专栏均为全网独家首发,内附代码,可直接使用,改进的方法均是2023年最近的模型、方法和注意力机制。每一篇都做了实验,并附有实验结果分析,模型对比。


在机器学习和自然语言处理领域,随着数据的不断增长和任务的复杂性提高,传统的模型在处理长序列或大型输入时面临一些困难。传统模型无法有效地区分每个输入的重要性,导致模型难以捕捉到与当前任务相关的关键信息。为了解决这个问题,注意力机制(Attention Mechanism)应运而生。

一、注意力机制介绍

1、什么是注意力机制?

注意力机制(Attention Mechanism)是一种在机器学习和自然语言处理领域中广泛应用的重要概念。它的出现解决了模型在处理长序列或大型输入时的困难,使得模型能够更加关注与当前任务相关的信息,从而提高模型的性能和效果。

本文将详细介绍注意力机制的原理、应用示例以及应用示例。

2、注意力机制的分类

类别描述
全局注意力机制(Global Attention)在计算注意力权重时,考虑输入序列中的所有位置或元素,适用于需要全局信息的任务。
局部注意力机制(Local Attention)在计算注意力权重时,只考虑输入序列中的局部区域或邻近元素,适用于需要关注局部信息的任务。
自注意力机制(Self Attention)在计算注意力权重时,根据输入序列内部的关系来决定每个位置的注意力权重,适用于序列中元素之间存在依赖关系的任务。
Bahdanau 注意力机制全局注意力机制的一种变体,通过引入可学习的对齐模型,对输入序列的每个位置计算注意力权重。
Luong 注意力机制全局注意力机制的另一种变体,通过引入不同的计算方式,对输入序列的每个位置计算注意力权重。
Transformer 注意力机制自注意力机制在Transformer模型中的具体实现,用于对输入序列中的元素进行关联建模和特征提取。

3、注意力机制的核心

注意力机制的核心思想是根据输入的上下文信息来动态地计算每个输入的权重。这个过程可以分为三个关键步骤:计算注意力权重、对输入进行加权和输出。首先,计算注意力权重是通过将输入与模型的当前状态进行比较,从而得到每个输入的注意力分数。这些注意力分数反映了每个输入对当前任务的重要性。对输入进行加权是将每个输入乘以其对应的注意力分数,从而根据其重要性对输入进行加权。最后,将加权后的输入进行求和或者拼接,得到最终的输出。注意力机制的关键之处在于它允许模型在不同的时间步或位置上关注不同的输入,从而捕捉到与任务相关的信息。

🏆YOLOv5/v7 添加注意力机制,30多种模块分析①,SE模块,SK模块

🏆YOLOv5/v7 添加注意力机制,30多种模块分析②,BAM模块,CBAM模块

🏆YOLOv5/v7 添加注意力机制,30多种模块分析③,GCN模块,DAN模块

二、CA模块

1、CA模块的原理

CA(Coordinate Attention)模块是一种基于位置坐标的注意力机制,它可以在不同空间尺度上对特征图进行自适应的调整。

CA模块通过计算每个像素点的空间坐标信息,将其转换为一个与输入特征维度相同的向量,并利用这些向量来计算空间注意力权重。

CA模块首先使用一个可学习的线性变换将每个像素点的坐标映射到一个低维空间中,通过对该空间中所有坐标向量的点积操作,得到每个像素点的空间注意力权重,利用这些权重对输入特征进行加权求和,实现了自适应的特征图调整。

在这里插入图片描述

论文中提出的Coordinate Attention模块(c)与经典的SE通道注意力模块 [18] (a)和CBAM模块 [44] (b)进行了比较。其中,“GAP”和“GMP”分别是全局平均池化和全局最大池化,‘X Avg Pool’和‘Y Avg Pool’分别指1D水平全局池化和1D垂直全局池化。

2、实验结果

在这里插入图片描述

不同注意力方法在三个经典视觉任务中的性能表现。从左到右的y轴标签分别是top-1准确度、平均IoU和AP。显然,我们的方法不仅在ImageNet分类任务中优于SE模块和CBAM ,而且在下游任务如语义分割和COCO物体检测中表现更佳。结果基于MobileNetV2 。

在这里插入图片描述

通过可视化工具,我们展示了使用不同注意力方法的模型在最后一个构建块中生成的特征图。我们可视化了每个注意力模块之前和之后的特征图。显然,我们提出的Coordinate Attention(CA)可以比其他注意力方法更精确地定位感兴趣的对象。

3、应用示例

在YOLOv5中,CA模块指的是Channel Attention模块,用于增强卷积神经网络的特征提取能力。这里提供一个使用PyTorch实现的示例代码,演示如何将CA模块添加到YOLOv5的骨干网络中。

首先,在YOLOv5的骨干网络中找到需要添加CA模块的Conv层,例如C3层。然后,在该Conv层之后添加一个CA模块。

以下是应用示例:

import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1   = nn.Conv2d(in_planes, in_planes // reduction_ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(in_planes // reduction_ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

class YOLOv5Backbone(nn.Module):
    def __init__(self):
        super(YOLOv5Backbone, self).__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1),

            nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1),

            nn.Conv2d(64, 64, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1),

            nn.Conv2d(64, 128, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1),
            ChannelAttention(128),
        )

        self.layer1 = nn.Sequential(
            nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1),

            nn.Conv2d(256, 128, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1),

            nn.Conv2d(128, 256, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1),

            nn.Conv2d(256, 128, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1),

            nn.Conv2d(128, 256, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1),
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.1),

            nn.Conv2d(512, 256, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1),

            nn.Conv2d(256, 512, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.1),

            nn.Conv2d(512, 256, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1),

            nn.Conv2d(256, 512, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.1),
        )
          self.layer3 = nn.Sequential(
        nn.Conv2d(512, 1024, 3, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(1024),
        nn.LeakyReLU(0.1),

        nn.Conv2d(1024, 512, 1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(0.1),

        nn.Conv2d(512, 1024, 3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(1024),
        nn.LeakyReLU(0.1),

        nn.Conv2d(1024, 512, 1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(0.1),

        nn.Conv2d(512, 1024, 3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(1024),
        nn.LeakyReLU(0.1),

        nn.Conv2d(1024, 512, 1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(0.1),

        nn.Conv2d(512, 1024, 3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(1024),
        nn.LeakyReLU(0.1),
    )

        def forward(self, x):
            x = self.stem(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            return x

在YOLOv5Backbone类的初始化方法中,首先定义了一个名为“stem”的Sequential对象,包含3个卷积层和BatchNorm2d和LeakyReLU激活函数;然后在第4个卷积层后添加了一个ChannelAttention模块。接下来定义了3个Sequential对象,分别代表着YOLOv5的C3、C4和C5阶段,在其中也都添加了若干个卷积层和BatchNorm2d和LeakyReLU激活函数。最后在forward方法中将输入特征图依次传入各个Sequential对象中,并返回输出特征图。

三、ECA模块

1、ECA模块的原理

ECA模块的主要思想是利用通道间的关联性来调整各通道的权重,从而增强网络的语义表达能力。具体而言,给定通过全局平均池化(GAP)获得的聚合特征,ECA模块通过执行大小为k的快速1D卷积生成通道权重,其中k通过通道维度C的映射自适应确定。如图所示,ECA模块的结构相对简单,计算效率高,且参数量较小,因此较适合嵌入到深层网络结构中使用。

在这里插入图片描述

2、实验结果

在这里插入图片描述

上述内容是一篇关于不同注意力机制的比较研究的论文中的部分结果展示。该研究分别使用了SENet、CBAM、A2-Nets和ECA-Net等四种不同的注意力模块,以ResNets作为骨干网络,在分类准确率、网络参数和FLOPs等方面进行了比较,结果如图所示。

可以看到,ECA-Net在具有更少的模型复杂性的同时,获得了更高的分类准确率,这表明它是一种非常有效的注意力模块,可以用于提高深度神经网络的性能。同时,与其他注意力模块相比,ECA-Net具有更低的网络参数和FLOPs,这意味着它可以在保持性能的同时,更加高效地运行。

在这里插入图片描述

上述内容是一篇关于ECA模块在使用ResNet-50和ResNet-101作为骨干网络时,不同k值的结果展示。该研究还比较了自适应选择核大小的ECA模块与基线SENet的结果。

具体来说,研究人员通过对比不同的k值(包括1、2、4、8和16),发现当k=16时,ECA模块可以在性能和计算效率之间取得最佳平衡。此外,他们还使用了自适应选择核大小的ECA模块,并通过与SENet进行比较,发现ECA模块在准确率和参数数量方面都表现更好。

因此,这些结果表明,在使用ResNet-50和ResNet-101作为骨干网络时,ECA模块可以提高分类准确率,并且通过使用自适应选择核大小,可以进一步提高模型性能。

3、应用示例

以下是使用ECA模块的YOLOv5应用示例:

import torch
import torch.nn as nn

class ECAModule(nn.Module):
    def __init__(self, channels, k=16):
        super(ECAModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1,-2)).transpose(-1,-2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)

class YOLOv5(nn.Module):
    def __init__(self):
        super(YOLOv5, self).__init__()
        
        self.backbone = ...
        
        self.neck = nn.ModuleList([
            ECAModule(256),
            ECAModule(512),
            ECAModule(1024)
        ])
        
        self.head = ...
        
    def forward(self, x):
        # backbone
        ...
        
        # neck
        for i in range(len(self.neck)):
            x[i] = self.neck[i](x[i])
        
        # head
        ...
        
        return output

在上述代码中,我们定义了一个ECAModule类来实现ECA模块。该模块首先将输入特征图进行全局平均池化,并通过一个1D卷积层来对通道维度进行处理。然后,将得到的权重值通过Sigmoid函数进行归一化,并与输入特征图相乘,得到加权后的特征图。

在YOLOv5中,我们将ECAModule模块应用于neck部分的特征图融合中,以提高目标检测性能。具体来说,对于每个尺度的特征图,我们都使用一个ECAModule模块来进行权值计算。最后,我们将所有特征图进行通道上的拼接,并将其输入到head中进行预测。

参考论文:

  1. https://arxiv.org/abs/2103.02907
  2. https://arxiv.org/pdf/1910.03151.pdf

在这里插入图片描述

🏆本文收录于,目标检测YOLO改进指南。

本专栏均为全网独家首发,🚀内附代码,可直接使用,改进的方法均是2023年最近的模型、方法和注意力机制。每一篇都做了实验,并附有实验结果分析,模型对比。

🏆华为OD机试(JAVA)真题(A卷+B卷)

每一题都有详细的答题思路、详细的代码注释、样例测试,订阅后,专栏内的文章都可看,可加入华为OD刷题群(私信即可),发现新题目,随时更新,全天CSDN在线答疑。

🏆哪吒多年工作总结:Java学习路线总结,搬砖工逆袭Java架构师。

🏆往期回顾:

YOLOv5/v7 添加注意力机制,30多种模块分析①,SE模块,SK模块

YOLOv5/v7 添加注意力机制,30多种模块分析②,BAM模块,CBAM模块

YOLOv5/v7 添加注意力机制,30多种模块分析③,GCN模块,DAN模块

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/644539.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

阿里云企业邮箱设置教程(新手指南)

阿里云企业邮箱怎么使用?企业邮箱快速入门教程,从购买、设置管理员账号密码、创建组织架构账号邮件组、邮箱迁移、切换解析、钉邮绑定与同步,最后启用邮箱,阿里云百科分享阿里云企业邮箱使用教程快速入门: 目录 阿里…

2--Gradle入门 - Groovy简介、基本语法

2--Gradle入门 - Groovy简介、基本语法 Gradle 需要 Groovy 语言的支持,所以本章节主要来介绍 Groovy 的基本语法。 1.Groovy 简介 在某种程度上,Groovy 可以被视为Java 的一种脚本化改良版,Groovy 也是运行在 JVM 上,它可以很好地与 Java 代…

【OpenCV DNN】Flask 视频监控目标检测教程 08

欢迎关注『OpenCV DNN Youcans』系列,持续更新中 【OpenCV DNN】Flask 视频监控目标检测教程 08 3.8 OpenCVFlask实时监控人脸识别控制按钮新建 Flask 项目 cvFlask08cPython程序文件视频流的网页模板程序运行 本系列从零开始,详细讲解使用 Flask 框架构…

Nature子刊:生物合成硝基化蛋白,助力解决药物免疫耐受!

氨基酸是蛋白质的单个构建模块,对生物系统的正常运转至关重要。所有生物系统中的蛋白质都是由20种标准氨基酸组成的,自然界中还发现了超过500种不同类型的其他氨基酸,以及大量的人造氨基酸。其中一些替代氨基酸有助于创造新类型的药物和治疗方…

阿里企业邮箱注册流程(新手指南)

阿里云企业邮箱购买流程,企业邮箱分为免费版、标准版、集团版和尊享版,阿里云百科分享企业邮箱版本区别,企业邮箱收费标准价格表,以及阿里企业邮箱详细购买流程: 目录 阿里云企业邮箱购买流程 一、阿里云账号注册及…

驱动开发:内核ShellCode线程注入

还记得《驱动开发:内核LoadLibrary实现DLL注入》中所使用的注入技术吗,我们通过RtlCreateUserThread函数调用实现了注入DLL到应用层并执行,本章将继续探索一个简单的问题,如何注入ShellCode代码实现反弹Shell,这里需要…

ChatGPT 背后的技术重点:RLHF、IFT、CoT、红蓝对抗

近段时间,ChatGPT 横空出世并获得巨大成功,使得 RLHF、SFT、IFT、CoT 等这些晦涩的缩写开始出现在普罗大众的讨论中。这些晦涩的首字母缩略词究竟是什么意思?为什么它们如此重要?我们调查了相关的所有重要论文,以对这些…

Go1.21 速览:go.mod 的 Go 版本号将会约束 Go 程序构建,要特别注意了!

大家好,我是煎鱼。 之前 Go 核心团队的负责人 Russ Cox 针对 Go 的向前兼容(指的是旧版本的 Go 编译新的 Go 代码),进行了进一步的设计。 重点内容如下: 新增 GOTOOLCHAIN 环境变量的设置。改变在工作模块(…

阿里云弹性公网EIP收费价格表

阿里云弹性公网EIP怎么收费?EIP地域不同价格不同,EIP计费模式分为包年包月和按量付费,弹性公网IP可以按带宽收费也可以按使用流量收费,阿里云百科分享阿里云弹性公网IP不同地域、不同计费模式、按带宽和按使用流量详细收费价格表&…

cpp新小点1

这里写目录标题 argc argv继承虚继承多态override不加override overload纯虚函数和抽象类虚析构和纯虚析构 static和 constexternself前置 后置默认构造 析构继承构造函数不能是虚函数派⽣类的override虚函数定义必须和⽗类完全⼀致。 有特列何时共享虚函数地址表 智能指针arrm…

【数据库必备知识】上手表设计

目录 📖前言 1. 基本步骤 1.1 梳理清楚需求中的实体 1.2 梳理清楚实体间的关系 2. 实体间的三种关系 2.1 一对一 2.2 一对多 2.3 多对多 🎉小结ending 📖前言 本文讲解的是基本的表设计, 设计一般只有在有一定实际项目经验后, 才能…

MAVEN - 使用maven-dependency-plugin的应用场景是什么?

简述 maven-dependency-plugin是MAVEN的一个插件。 作用 该插件主要用于管理项目中的依赖,使用该插件可以方便地查看、下载、复制和解压缩依赖,还支持生成依赖树和依赖报告。 功能 该插件有很多可用的GOAL,大部分与依赖构建、依赖分析和依…

《面试1v1》Map

我是 javapub,一名 Markdown 程序员从👨‍💻,八股文种子选手。 《面试1v1》 连载中… 面试官: 小伙子,又来挑战你了。听说你对Java集合中的Map也很在行? 候选人: 谢谢夸奖,Map这个接口的确非常重要且强大…

SpringMVC原理分析 | JSON、Jackson、FastJson

💗wei_shuo的个人主页 💫wei_shuo的学习社区 🌐Hello World ! JSON JSON(JavaScriptObject Notation,JS对象简谱)是一种轻量级的数据交换格式。它基于 ECMAScript(European Computer…

无自注意力照样高效!RIFormer开启无需token mixer的Transformer结构新篇章

©PaperWeekly 原创 作者 | 岳廷 研究方向 | 计算机视觉 引言 论文地址: https://openaccess.thecvf.com/content/CVPR2023/papers/Wang_RIFormer_Keep_Your_Vision_Backbone_Effective_but_Removing_Token_Mixer_CVPR_2023_paper.pdf 问题:Vision …

如何将代码中的相关调试信息输出到对应的日志文件中

一、将调试信息输出到屏幕中 1.1 一般写法 我们平常在写代码时&#xff0c;肯定会有一些调试信息的输出&#xff1a; #include <stdio.h> #include <stdlib.h>int main() {char szFileName[] "test.txt";FILE *fp fopen(szFileName, "r")…

R语言 tidyverse系列学习笔记(系列5)dplyr 数据分析之across

成绩单 score install.packages("dplyr") library(dplyr)install.packages("tibble") library(tibble)install.packages("stringr") library(stringr)score tibble(IDc("1222-1","2001-0","3321-1","4898-…

MySQL(八):排序与分页

排序与分页 前言一、排序数据1、排序规则2、单列排序3、多列排序 二、分页1、背景2、实现规则3、拓展 前言 本博主将用CSDN记录软件开发求学之路上亲身所得与所学的心得与知识&#xff0c;有兴趣的小伙伴可以关注博主&#xff01;也许一个人独行&#xff0c;可以走的很快&…

从零开始Vue项目中使用MapboxGL开发三维地图教程(三)添加全屏,缩放旋转和比例控制面板以及自定义图标、标记点击弹窗、地图平移等功能

文章目录 1、添加各种控制面板1.1、添加全屏1.2、缩放旋转控制1.3、比例尺 2、获取并显示鼠标移动位置的经纬度坐标3、添加图标3.1、添加图片图层的图标3.2、添加带有标记的自定义图标3.3、悬停时显示弹出窗口 1、添加各种控制面板 1.1、添加全屏 //添加全屏控制this.map.addC…

管理类联考——逻辑——知识篇——第一章 性质命题

第一章 性质命题&#xff08;最基础&#xff0c;最难*****&#xff09; 一、性质命题定义&#xff08;必考&#xff09; 判断事物具有或不具有某种性质的命题。 二、性质命题的四种基本形式 全称肯定&#xff1a;①所有的A都是B 全称否定&#xff1a;②所有的A不是B 特称肯…