YOLOv5、YOLOv8改进:添加ShuffleAttention注意力机制

news2025/2/26 8:13:39

 广泛应用的注意力机制主要有空间注意力机制和通道注意力机制,其目的分别是捕捉像素级的成对关系和通道依赖关系。虽然将两种机制融合在一起可以获得比单独更好的性能,但计算开销不可避免。因而,本文提出Shuffle Attetion,即SA,以解决这个问题。SA采用Shuffle Units有效地结合两种注意力机制。具体来说,SA首先将通道维度分组为多个子特征,然后并行处理。对于每个子特征,SA利用Shuffle Units描述空间和通道维度上的特征依赖关系,然后对所有子特征进行聚合,采用channel shuffle算子实现不同子特征间的信息通信。

1.1 背景
注意力机制分为两种,通道注意和空间注意,都是通过不同的聚集策略、转换和强化函数,从所有位置聚集相同的特征来增强原始特征。
GCNet和CBAM将通道注意和空间注意整合到一个模块中,在取得显著结果的同时,也存在存在收敛困难或计算负担过重的问题。
ECA-Net通过使用一维卷积简化了SE模块中信道权重的计算过程;SGE将通道的维度划分为多个子特征表示不同的语义,使用注意掩码在所有位置缩放特征向量,对每个特征组应用空间机制。然而这些网络没有充分利用空间注意和通道注意之间的相关性,致使其效率较低。

本文提出Shuffle Attention 模块,即SA。该模块按通道维度分组为子特征。对于每个子特征,SA采用Shuffle Unit同时构建通道注意和空间注意。对于每个注意模块,本文设计了一个覆盖所有位置的注意掩模,以抑制可能的噪声并突出正确的语义特征区域。

2.1 多分支架构
多分支架构背后的原理是“分裂-转换-合并”,即split-transform-merge,这降低了训练数百层网络的难度。
InceptionNet系列是成功的多分支架构,其中每个分支都精心配置了自定义的内核过滤器,以便聚合更多信息和多种功能;ResNets也可以看作是两个分支网络,其中一个分支是身份映射;SKNets和ShuffleNet系列都遵循了InceptionNets的思想,为多个分支提供了各种过滤器,但SKNets利用自适应选择机制来实现神经元的自适应感受野大小,ShuffleNets进一步将channel split和channel shuffle操作合并为单元素操作,以在速度和准确性之间进行权衡。

2.2 特征分组
特征分组学习可追溯至AlexNet,其动机是将模型分配至更多的GPU上。
MobileNets和ShuffleNets将每个通道视为一个组;CapsuleNets将每个分组神经元建模为一个capsule,其中活动capsule中的神经元代表图像中特定实体的各种属性;SGE优化了CapsuleNets,按通道维度划分多个子特征来学习不同的语义。

2.3 注意力机制
注意力机制突出信心量大的特征,而抑制不太有用的特征;自注意力机制会计算一个位置的上下文,作为图像中所有位置的加权和。
SE使用2个FC层对通道关系进行建模;ECA-Net采用一维卷积生成通道权值,降低了SE的复杂度;NL(non-local)通过计算特征图中空间点的关系矩阵生成注意力图;CBAM、GCNet和SGE将空间注意和通道注意顺序结合;DANet通过将来自不同分支的两个注意力模块相加,自适应地将局部特征与其全局依赖结合起来。


 

下面附上修改代码

有不明白的可以评论区留言

首先在common中添加一下模块

class ShuffleAttention(nn.Module):

    def __init__(self, channel=512,reduction=16,G=8):
        super().__init__()
        self.G=G
        self.channel=channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sigmoid=nn.Sigmoid()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)


    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)

        # flatten
        x = x.reshape(b, -1, h, w)

        return x

    def forward(self, x):
        b, c, h, w = x.size()
        #group into subfeatures
        x=x.view(b*self.G,-1,h,w) #bs*G,c//G,h,w

        #channel_split
        x_0,x_1=x.chunk(2,dim=1) #bs*G,c//(2*G),h,w

        #channel attention
        x_channel=self.avg_pool(x_0) #bs*G,c//(2*G),1,1
        x_channel=self.cweight*x_channel+self.cbias #bs*G,c//(2*G),1,1
        x_channel=x_0*self.sigmoid(x_channel)

        #spatial attention
        x_spatial=self.gn(x_1) #bs*G,c//(2*G),h,w
        x_spatial=self.sweight*x_spatial+self.sbias #bs*G,c//(2*G),h,w
        x_spatial=x_1*self.sigmoid(x_spatial) #bs*G,c//(2*G),h,w

        # concatenate along channel axis
        out=torch.cat([x_channel,x_spatial],dim=1)  #bs*G,c//G,h,w
        out=out.contiguous().view(b,-1,h,w)

        # channel shuffle
        out = self.channel_shuffle(out, 2)
        return out

 yolo.py中注册

  elif m in [S2Attention, SimSPPF, ACmix, CrissCrossAttention, SOCA, ShuffleAttention, SEAttention, SimAM, SKAttention]:

修改yaml文件  以yolov5s为列

# YOLOv5 🚀 by YOLOAir, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
   [-1, 1, ShuffleAttention, [1024]],

   [[17, 20, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

以上就是完整的改法

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/862879.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

力扣:59. 螺旋矩阵 II(Python3)

题目: 给你一个正整数 n ,生成一个包含 1 到 n2 所有元素,且元素按顺时针顺序螺旋排列的 n x n 正方形矩阵 matrix 。 来源:力扣(LeetCode) 链接:力扣(LeetCode)官网 - 全…

【1572. 矩阵对角线元素的和】

来源:力扣(LeetCode) 描述: 给你一个正方形矩阵 mat,请你返回矩阵对角线元素的和。 请你返回在矩阵主对角线上的元素和副对角线上且不在主对角线上元素的和。 示例 1: 输入:mat [[1,2,3]…

AtCoder Beginner Contest 313D题题解

文章目录 [ Odd or Even](https://atcoder.jp/contests/abc313/tasks/abc313_d)问题建模问题分析1.分析每次查询的作用2.利用异或运算的性质设计查询方法 Odd or Even 问题建模 有n个数,每个数为0或者1,最多可以进行n次询问,每次询问选择k个…

第04天 Spring是如何解决循环依赖的

✅作者简介:大家好,我是Leo,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Leo的博客 💞当前专栏:每天一个知识点 ✨特色专栏&#xff1a…

推出稳定代码:人工智能辅助编码的新视野

推荐:使用 NSDT场景编辑器 快速助你搭建可二次编辑的3D应用场景 在不断发展的软件开发环境中,对效率和可访问性的追求导致了各种工具和平台的创建。最新的创新之一是StableCode,这是Stability AI的大型语言模型(LLM)生…

Flv格式视频怎么转MP4?视频格式转换方法分享

FLV格式的视频是一种早期的视频格式,不支持更高的分辨率和比特率,这意味着视频的清晰度和质量受限制,无法很好地保留细节和质量,这种格式的视频已经逐渐被更高质量的视频格式所替代,例如MP4格式,不仅具有很…

创新不辍,再结硕果 | 蓝奥声“无线联动监控技术”

随着无线电通信技术的迅速发展,无线远程监控系统也得到了技术上的更新,它将嵌入式产品与现代无线通信技术相结合,共同构成了一种新型的监测控制系统。物联网及其相关无线联动通信技术是智能科技快速发展的重要支撑技术之一,由此带…

主流国产GPU产品及规格概述(2023)

​ 美国对 AI 芯片出口管制,自主可控要求下国产芯片需求迫切。2022 年 10 月 7 日美国商务部工业安全局(BIS)发布《美国商务部对中华人民共和国(PRC)关于先进计算和半导体实施新的出口管制制造》细则,其中管…

复古游戏库管理器RomM

什么是 RomM ? RomM(代表 Rom Manager)是一个专注于复古游戏的游戏库管理器。通过 Web 浏览器管理和组织您的所有游戏。受 Jellyfin 的启发,允许您从现代界面管理所有游戏,同时使用 IGDB 元数据丰富它们。 RomM 支持的…

电脑自动重启是什么原因?这几个原因不可忽视!

“感觉我的电脑也没有用多久呀,怎么总是会出现自动重启的情况呢?由于我对电脑不是很熟悉,都不知道该如何解决这个问题,有没有朋友可以解释一下这是为什么呀?“ 在使用电脑时,如果电脑总是自动重启&#xff…

MySQL_索引的使用与设计

最左前缀法则 最左前缀法则适用于联合索引;查询从索引的最左列开始,不跳过其中的列,如果跳过其中的列将会导致索引失效(后面字段的索引失效)。 验证最左前缀法则 三个列的联合索引都同时使用 explain select * from u…

基于R做宏基因组的进化树ClusterTree分析

写在前面 同上一篇的PCoA分析,这个也是基于公司结果基础上的再次分析,重新挑选样本,在公司结果提供的csv结果表上进行删减,本地重新分析作图 步骤 表格预处理 在公司给的ClusterTree的原始表格数据里选取要保留的样本&#xf…

腾讯云轻量应用服务器CPU配置?主频性能

腾讯云轻量应用服务器CPU型号是什么?处理器主频多少?轻量应用服务器不支持指定CPU处理器型号,目前腾讯云服务器网账号下的轻量应用服务器,CPU采用2.5GHz主频的Intel(R) Xeon(R) Gold 6133 处理器,睿频 3.0GHz&#xff…

CTFSHOW php命令执行

目录 web29 过滤flag web30 过滤system php web31 过滤 cat|sort|shell|\. 这里有一个新姿势 可以学习一下 web32 过滤 ; . web33 web34 web35 web36 web37 data伪协议 web38 短开表达式 web39 web40 __FILE__命令的扩展 web41 web42 重定向…

对话即数据分析,网易数帆ChatBI做到了

大数据产业创新服务媒体 ——聚焦数据 改变商业 在当今数字化快速发展的时代,数据已经成为业务经营与管理决策的核心驱要素。无论是跨国大企业还是新兴创业公司,正确、迅速地洞察数据已经变得至关重要。然而,传统的BI工具往往对用户有一定的…

YOLOv5、YOLOv8改进:SEAttention 通道注意力机制

基于通道的注意力机制 源自于 CVPR2018: Squeeze-and-Excitation Networks 官方代码:GitHub - hujie-frank/SENet: Squeeze-and-Excitation Networks 如图所示,其实就是将不同的通道赋予相关的权重。Attention机制用到这里用朴素的话说就是,…

【论文阅读】基于深度学习的时序预测——FEDformer

系列文章链接 论文一:2020 Informer:长时序数据预测 论文二:2021 Autoformer:长序列数据预测 论文三:2022 FEDformer:长序列数据预测 论文地址:https://arxiv.org/abs/2201.12740 github地址&a…

Cobbler自定义yum源

再次了解下Cobbler的目录结构: 在/var/www/cobbler/ks_mirror目录下存放的是所有的镜像。 存放的是仓库镜像: 在/var/lib/cobbler/kickstarts目录下是存放的所有的kickstarts文件。 再有就是/etc/cobbler这个目录: [rootvm1 loaders]# cd /…

mysql延时问题排查

背景介绍 最近遇到一个奇怪的问题,有个业务,每天早上七点半产生主从延时,延时时间12.6K; 期间没有抽数/备份等任务;查看慢日志发现,期间有一个delete任务,在主库执行了161s delete from xxxx_…

LeetCode算法心得——故障键盘(StringBuilder)

大家好,我是晴天学长,很久都没有用StringBuilder类了,切记这个自带字符串反转的方法,会在实际比赛中节约不少的时间。 1 )故障键盘 2) .算法思路 故障键盘 1.首先把全部字母给你的了 2.只会反转前面的字符 1.字符串…