YOLOv5算法改进(4)— 添加CA注意力机制

news2024/11/23 8:47:07

前言:Hello大家好,我是小哥谈。注意力机制是近年来深度学习领域内的研究热点,可以帮助模型更好地关注重要的特征,从而提高模型的性能。在许多视觉任务中,输入数据通常由多个通道组成,例如图像中的RGB通道或视频中的时间序列帧。传统的卷积神经网络(CNN)在处理这些通道时通常是独立地对每个通道进行操作,忽略了通道之间的相互作用。CA注意力机制通过引入通道注意力来解决这个问题。它能够自动学习到不同通道之间的关联性和重要性,从而增强模型对输入数据的建模能力。具体来说,CA注意力机制通过计算每个通道的权重,使得模型能够更加关注重要的通道,并抑制不重要的通道。这样可以提高模型在处理多通道输入数据时的表达能力和性能。🌈 

前期回顾:

          YOLOv5算法改进(1)— 如何去改进YOLOv5算法

          YOLOv5算法改进(2)— 添加SE注意力机制

          YOLOv5算法改进(3)— 添加CBAM注意力机制

          目录

🚀1.论文

🚀2.CA注意力机制的原理及实现

🚀3.添加CA注意力机制的好处 

🚀4.添加CA注意力机制的方法

💥💥步骤1:在common.py中添加CA模块

💥💥步骤2:在yolo.py文件中加入类名

💥💥步骤3:创建自定义yaml文件

💥💥步骤4:修改yolov5s_CA.yaml文件 

💥💥步骤5:验证是否加入成功

💥💥步骤6:修改train.py中的'--cfg'默认参数

🚀5.添加C3_CA注意力机制的方法(在C3模块中添加)

💥💥步骤1:在common.py中添加CABottleneck和C3_CA模块

💥💥步骤2:在yolo.py文件里parse_model函数中加入类名

​💥💥步骤3:创建自定义yaml文件

​💥💥步骤4:验证是否加入成功

​💥💥步骤5:修改train.py中的'--cfg'默认参数 

🚀1.论文

目前,轻量级网络的注意力机制大都采用 SE 模块,仅考虑了通道间的信息,忽略了位置信息。尽管后来的 BAM 和 CBAM 尝试在降低通道数后通过卷积来提取位置注意力信息,但卷积只能提取局部关系,缺乏长距离关系提取的能力。为此,论文提出了新的高效注意力机制CA(coordinate attention),能够将横向和纵向的位置信息编码到 channel attention 中,使得移动网络能够关注大范围的位置信息又不会带来过多的计算量。🌴

论文题目:Coordinate Attention for Efficient Mobile Network Design

论文地址:https://arxiv.org/abs/2103.02907

代码实现:GitHub - houqb/CoordAttention: Code for our CVPR2021 paper coordinate attention 


🚀2.CA注意力机制的原理及实现

CA(Channel Attention)注意力机制是一种在深度学习中常用的注意力机制之一,用于增强模型对于不同通道(channel)之间的特征关联性。📚

其原理如下:👇

(1)输入特征经过卷积等操作得到中间特征表示。

(2)中间特征表示经过两个并行的操作:全局平均池化和全局最大池化,得到全局特征描述。

(3)全局特征描述通过两个全连接层生成注意力权重。

(4)注意力权重与中间特征表示相乘,得到加权后的特征表示。

(5)加权后的特征表示经过适当的调整(如残差连接)后,作为下一层的输入。

CA注意力的实现如图所示,可以认为分为两个并行阶段

将输入特征图分别在为宽度高度两个方向分别进行全局平均池化,分别获得在宽度和高度两个方向的特征图。假设输入进来的特征层的形状为[C, H, W],在经过宽方向的平均池化后,获得的特征层shape为[C, H, 1],此时我们将特征映射到了高维度上;在经过高方向的平均池化后,获得的特征层shape为[C, 1, W],此时我们将特征映射到了宽维度上。

然后将两个并行阶段合并,将宽和高转置到同一个维度,然后进行堆叠,将宽高特征合并在一起,此时我们获得的特征层为:[C, 1, H+W],利用卷积+标准化+激活函数获得特征。

之后再次分开为两个并行阶段,再将宽高分开成为:[C, 1, H][C, 1, W],之后进行转置。获得两个特征层[C, H, 1][C, 1, W]

然后利用1x1卷积调整通道数后取sigmoid获得宽高维度上的注意力情况,乘上原有的特征就是CA注意力机制


🚀3.添加CA注意力机制的好处 

作者通过将位置信息嵌入到通道注意力中提出了一种新颖的移动网络注意力机制,将其称为“Coordinate Attention”。其为即插即用的注意力模块,能插入任何经典网络🍉

加入CA注意力机制的好处包括:

 (1)增强特征表达:CA注意力机制能够自适应地选择和调整不同通道的特征权重,从而更好地表达输入数据。它可以帮助模型发现和利用输入数据中重要的通道信息,提高特征的判别能力和区分性。

 (2)减少冗余信息:通过抑制不重要的通道,CA注意力机制可以减少输入数据中的冗余信息,提高模型对关键特征的关注度。这有助于降低模型的计算复杂度,并提高模型的泛化能力。

 (3)提升模型性能:加入CA注意力机制可以显著提高模型在多通道输入数据上的性能。它能够帮助模型更好地捕捉到通道之间的相关性和依赖关系,从而提高模型对输入数据的理解能力。

综上所述,加入CA注意力机制可以有效地增强模型对多通道输入数据的建模能力,提高模型性能和泛化能力。它在图像处理、视频分析等任务中具有重要的应用价值。🌿


🚀4.添加CA注意力机制的方法

💥💥步骤1:在common.py中添加CA模块

将下面的CA模块的代码复制粘贴到common.py文件的末尾。

# CA
class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)
    def forward(self, x):
        return self.relu(x + 3) / 6
class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)
    def forward(self, x):
        return x * self.sigmoid(x)
 
class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        mip = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
    def forward(self, x):
        identity = x
        n, c, h, w = x.size()
        #c*1*W
        x_h = self.pool_h(x)
        #c*H*1
        #C*1*h
        x_w = self.pool_w(x).permute(0, 1, 3, 2)
        y = torch.cat([x_h, x_w], dim=2)
        #C*1*(h+w)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()
        out = identity * a_w * a_h
        return out

具体如下图所示:

💥💥步骤2:在yolo.py文件中加入类名

首先在yolo.py文件中找到parse_model函数,然后将 CoordAtt 添加到这个注册表里。

💥💥步骤3:创建自定义yaml文件

models文件夹中复制yolov5s.yaml粘贴并命名为yolov5s_CA.yaml

💥💥步骤4:修改yolov5s_CA.yaml文件 

本步骤是修改yolov5s_CA.yaml,将CA模块添加到我们想添加的位置。在这里,我将[-1,1,CoordAtt,[1024]]添加到SPPF的上一层,即下图中所示位置。

说明:♨️♨️♨️

注意力机制可以加在Backbone、Neck、Head等部分,常见的有两种:一种是在主干的SPPF前面添加一层;二是将Backbone中的C3全部替换。不同的位置效果可能不同,需要我们去反复测试。

这里需要注意一个问题,当在网络中添加新的层之后,那么该层网络后面的层的编号会发生变化。原本Detect指定的是[17,20,23]层,所以,我们在添加了CA模块之后,也要对这里进行修改,即原来的17层,变成18层,原来的20层,变成21层,原来的23层,变成24层;所以这里需要改为[18,21,24]。同样的,Concat的系数也要修改,这样才能保持原来的网络结构不会发生特别大的改变,我们刚才把CA模块加到了第9层,所以第9层之后的编号都需要加1,这里我们把后面两个Concat的系数分别由[-1,14][-1,10]改为[-1,15][-1,11]。🌻

具体如下图所示:

💥💥步骤5:验证是否加入成功

yolo.py文件里,将配置改为我们刚才自定义的yolov5s_CA.yaml

 然后运行yolo.py,得到结果。

找到了CA模块,说明我们添加成功了。🎉🎉🎉

💥💥步骤6:修改train.py中的'--cfg'默认参数

train.py文件中找到 parse_opt函数,然后将第二行'--cfg'的default改为 'models/yolov5s_CA.yaml',然后就可以开始进行训练了。🎈🎈🎈


🚀5.添加C3_CA注意力机制的方法(在C3模块中添加)

上面是单独添加注意力层,接下来的方法是在C3模块中加入注意力层。这个策略是将CA注意力机制添加到Bottleneck,替换Backbone中所有的C3模块。🌳

💥💥步骤1:在common.py中添加CABottleneck和C3_CA模块

将下面的代码复制粘贴到common.py文件的末尾。

# CA
class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)
 
    def forward(self, x):
        return self.relu(x + 3) / 6
 
 
class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)
 
    def forward(self, x):
        return x * self.sigmoid(x)
 
 
class CABottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, ratio=32):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2
        # self.ca=CoordAtt(c1,c2,ratio)
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        mip = max(8, c1 // ratio)
        self.conv1 = nn.Conv2d(c1, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        self.conv_h = nn.Conv2d(mip, c2, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, c2, kernel_size=1, stride=1, padding=0)
 
    def forward(self, x):
        x1 = self.cv2(self.cv1(x))
        n, c, h, w = x.size()
        # c*1*W
        x_h = self.pool_h(x1)
        # c*H*1
        # C*1*h
        x_w = self.pool_w(x1).permute(0, 1, 3, 2)
        y = torch.cat([x_h, x_w], dim=2)
        # C*1*(h+w)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()
        out = x1 * a_w * a_h
 
        # out=self.ca(x1)*x1
        return x + out if self.add else out
 
 
class C3_CA(C3):
    # C3 module with CABottleneck()
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*(CABottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

💥💥步骤2:在yolo.py文件里parse_model函数中加入类名

yolo.py文件parse_model函数中,加入CABottleneckC3_CA这两个模块。

​💥💥步骤3:创建自定义yaml文件

按照上面的步骤创建yolov5s_C3_CA.yaml文件,替换4个C3模块。

​💥💥步骤4:验证是否加入成功

yolo.py文件里配置刚才我们自定义的yolov5s_C3_CA.yaml,然后运行。 

​💥💥步骤5:修改train.py中的'--cfg'默认参数 

train.py文件中找到parse_opt函数,然后将第二行'--cfg'的default改为 'models/yolov5s_C3_CA.yaml',然后就可以开始进行训练了。🎈🎈🎈


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/923302.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

村口的人家排放污水,污水浸染了整个村子,怎么办

从前有一个很不错的村子里,村子里有很多户人家,随着生活水平越来越好,房子也修起来了,柏油马路也宽敞了,大家进出村子,都要走那条马路,要不就出不去。 目录 1. 修厕所 2. 村口的日家 3. 告诉…

商城的TPS与并发用户数是如何换算的?

商城的TPS与并发用户数的换算关系可以通过以下公式计算: TPS 并发用户数 / 平均事务响应时间 其中,平均事务响应时间是指系统处理一个事务所需的平均时间。 下面是商城性能测试的一些用例示例: 用户登录: 目标:测…

4.物联网LWIP之C/S编程,stm32作为服务器,stm32作为客户端,代码的优化

LWIP配置 服务器端实现 客户端实现 错误分析 一。LWIP配置(FREERTOS配置,ETH配置,LWIP配置) 1.FREERTOS配置 为什么要修改定时源为Tim1?不用systick? 原因:HAL库与FREERTOS都需要使用systi…

【Python原创毕设|课设】基于Python Flask的上海美食信息与可视化宣传网站项目-文末附下载方式以及往届优秀论文,原创项目其他均为抄袭

基于Python Flask的上海美食信息与可视化宣传网站(获取方式访问文末官网) 一、项目简介二、开发环境三、项目技术四、功能结构五、运行截图六、功能实现七、数据库设计八、源码获取 一、项目简介 随着大数据和人工智能技术的迅速发展,我们设…

【JavaEE进阶】MyBatis表查询

文章目录 一. 使用MyBatis完成数据库的操作1. MyBatis程序中sql语句的即时执行和预编译1.1 即时执行(${})1.2 预编译(#{})1.3 即时执行和预编译的优缺点 2. 单表的增删改等操作2.1 增加操作2.2 修改操作2.3 删除操作2.4 like(模糊…

星际争霸之小霸王之小蜜蜂(六)--让子弹飞

目录 前言 一、添加子弹设置 二、创建子弹 三、创建绘制和移动子弹函数 四、让子弹飞 五、效果 总结 前言 小蜜蜂的基本操作已经完成了,现在开始编写子弹的代码了。 一、添加子弹设置 在我的预想里,我们的小蜜蜂既然是一只猫,那么放出的子弹…

基于小波神经网络的短时交通流量预测Matlab代码

1案例背景 1.1小波理论 小波分析是针对傅里叶变换的不足发展而来的。傅里叶变换是信号处理领域中应用最广泛的一种分析手段,然而它有一个严重不足,就是变换时抛弃了时间信息,通过变换结果无法判断某个信号发生的时间,即傅里叶变换在时域中没有分辨能力。小波是长度有限、平均为…

分布式与微服务相关知识

分布式与微服务 1.zookeeper是什么2.zookeeper保证数据一致性3.zookeeper的快速领导者选举是怎么实现的4.CAP理论5.BASE理论6.分布式id生成方案(1)UUID(2)数据库自增序列(3)Leaf-segment(4&…

Linux下的系统编程——vim/gcc编辑(二)

前言: 在Linux操作系统之中有很多使用的工具,我们可以用vim来进行程序的编写,然后用gcc来生成可执行文件,最终运行程序。下面就让我们一起了解一下vim和gcc吧 目录 一、vim编辑 1.vim的三种工作模式 2.基本操作之跳转字符 &a…

实现外网访问本地服务

最近开发需要其他项目组的人访问我本地服务测试,但又不在同一个地方,不能使用内网访问,所以需要外网访问本地服务功能. 条件: 1.需要一台具备公网IP的服务器 我用的服务器是windows,电脑也是Windows系统 2.下载frp 软件,只需要下载一份就可以了,分别放到服务器上和本地目录既…

2011-2021年全国各省绿色创新效率数据(原始数据+测算结果)

2011-2021年全国各省绿色创新效率数据(原始数据测算结果) 2011-2021年全国各省绿色创新效率 1、时间:2011-2021年 2、范围:全国31省市 3、来源:各省年鉴、科技年鉴、环境年鉴 4、指标:地区、编号、年份、R&D人…

设计模式大白话——命令模式

命令模式 一、概述二、经典举例三、代码示例(Go)四、总结 一、概述 ​ 顾名思义,命令模式其实和现实生活中直接下命令的动作类似,怎么理解这个命令是理解命令模式的关键!!!直接说结论是很不负责…

树形结构的快速生成

背景 相信大家都遇到过树形结构,像是文件列表、多级菜单、评论区的设计等等,我们都发现它有很多层级,第一级可以有多个,下边的每一个层级也可以有多个;有的可以设计成无限层级的,有的只能设计成两级。那么…

工程师使用IT服务台软件可以解决哪些问题?

现如今企业数字化建设已初具规模,业务系统基本已告一段落,而下一步关注的重点则从技术转向管理,如何能让这些系统更好运行起来,如何提高管理效率已是重中之重。在此向您推荐一款高效的IT服务管理工具——ServiceDesk Plus&#xf…

elementui的el-tabs标签页样式修改

一、官网样式: 二、修改样式 1.去掉下划线 效果: 代码: /* 去掉tabs标签栏下的下划线 */ ::v-deep .el-tabs__nav-wrap::after {position: static !important;/* background-color: #fff; */ } 2.改变下划线颜色 效果: 代码:…

使用VisualStudio制作上位机(三)

文章目录 使用VisualStudio制作上位机(三)第三部分:GUI内部函数设计使用VisualStudio制作上位机(三) Author:YAL 第三部分:GUI内部函数设计 这一部分,主要实现CAN设备的打开 将CAN厂家的二次开发文件添加到工程里调用相关函数打开或关闭CAN首先,添加“类文件”,类主…

死锁的典型情况、产生的必要条件和解决方案

前言 死锁:多个线程同时被阻塞,他们中的一个或全部都在等待某个资源被释放。由于线程被无限期地阻塞,因此程序不可能正常终止。 目录 前言 一、死锁的三种典型情况 (一)一个线程一把锁 (二)…

聊一聊a_bogus

前言 可以关注我哟,一起学习,主页有更多练习例子 如果哪个练习我没有写清楚,可以留言我会补充 如果有加密的网站可以留言发给我,一起学习共享学习路程 如侵权,联系我删除 此文仅用于学习交流,请勿于商用&a…

保护隐私为先的话,最好是不登录用ChatGPT,6种方法助你轻松接入-纯分享

ChatGPT是OpenAI研发的强大AI语言模型,用户可以通过它进行有意义的对话,并获取问题解答。但是,一些用户可能更倾向于在不需要创建账号或不登录的情况下使用ChatGPT。在这篇指南中,我们将探讨各种无需账号即可访问ChatGPT的方法。无…

续二:《你的医书是假的!批评付施威的《DDD诊所——聚合过大综合症》

DDD领域驱动设计批评文集 “软件方法建模师”不再考查基础题 《软件方法》各章合集 我写了一篇文章,批评付施威的《DDD诊所——聚合过大综合症》(以下简称《DDD诊所》),文章是《你的医书是假的!批评付施威的《DDD诊…