YOLOv8 | 注意力机制 | ShuffleAttention注意力机制 提升检测精度

news2025/1/18 19:07:11

YOLOv8成功添加ShuffleAttention


⭐欢迎大家订阅我的专栏一起学习⭐

🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀
       YOLOv5涨点专栏:http://t.csdnimg.cn/1Aqzu

YOLOv8涨点专栏:http://t.csdnimg.cn/jMjHb

YOLOv7专栏:http://t.csdnimg.cn/yhXBl

💡魔改网络、复现论文、优化创新💡

 

目录

原理 

代码实现

yaml文件实现

完整代码分享

启动命令

注意事项


注意力机制使神经网络能够准确地关注输入的所有相关元素,已成为提高深度神经网络性能的重要组成部分。计算机视觉研究中广泛使用的注意力机制主要有两种:空间注意力和通道注意力,其目的分别是捕获像素级的成对关系和通道依赖性。虽然将它们融合在一起可能会比它们单独的实现获得更好的性能,但它不可避免地会增加计算开销。高效的ShuffleAttention(SA)模块可以解决这个问题,它采用ShuffleAttention单元来有效地结合两种类型的注意机制。具体来说,SA 首先将通道维度分组为多个子特征,然后并行处理它们。然后,对于每个子特征,SA 利用洗牌单元来描述空间和通道维度上的特征依赖性。之后,所有子特征被聚合,并采用“通道洗牌”算子来实现不同子特征之间的信息通信。 SA 模块高效且有效,例如,SA 针对主干 ResNet50 的参数和计算量分别为 300 vs. 25.56M 和 2.76e-3 GFLOPs vs. 4.12 GFLOPs,并且性能提升超过 1.34% Top-1 准确度方面。

使用 ResNets 作为主干的 ImageNet-1k 上最近的 SOTA 注意力模型(包括 SENet、CBAM、ECA-Net、SGE-Net 和 SA-Net)在准确性、网络参数和 GFLOP 方面的比较。圆圈的大小表示 GFLOP。显然,所提出的 SA-Net 实现了更高的精度,同时模型复杂度更低 

原理 

首先介绍构建SA模块的过程,该模块将输入特征图分组,并使用Shuffle Unit将通道注意力和空间注意力整合到每个组的一个块中。之后,所有子特征被聚合,并利用“通道洗牌”算子来实现不同子特征之间的信息通信。然后,我们展示如何在深度 CNN 中采用 SA。最后,我们可视化效果并验证所提出的 SA 的可靠性。 SA模块整体架构如图所示

Shuffle Attention结构图

它采用“通道分割”并行处理各组的子特征。对于通道注意力分支,使用 GAP 生成通道统计量,然后使用一对参数来缩放和移动通道向量。对于空间注意力分支,采用群范数生成空间统计量,然后创建类似于通道分支的紧凑特征。然后将两个分支连接起来。之后,所有子特征被聚合,最后我们利用“通道洗牌”运算符来实现不同子特征之间的信息通信。 

完全捕获通道依赖性的一个选项是利用SE块。然而,它会带来太多的参数,这不利于在速度和准确性之间进行权衡,设计更轻量级的注意力模块。此外,像 ECA 一样,不适合通过执行大小为 k 的更快一维卷积来生成通道权重,因为 k 往往会更大。为了改进,我们提供了一种替代方案,首先通过简单地使用全局平均池化(GAP)来嵌入全局信息,生成通道统计量 s ∈ RC/2G×1×1,可以通过空间维度缩小Xk1 来计算高×宽此外,还创建了一个紧凑的功能来指导精确和自适应的选择。这是通过带有 sigmoid 激活的简单门控机制来实现的。

与通道注意力不同,空间注意力关注的是“哪里”,是信息性的部分,与通道注意力是互补的。首先,我们在 Xk2 上使用群范数 (GN)来获得空间统计数据。然后,采用Fc(·)来增强^ Xk2的表示。

之后,所有子特征都被聚合。最后,与ShuffleNet v2 类似,我们采用“channel shuffle”算子来实现沿通道维度的跨组信息流。 SA模块的最终输出与X的大小相同,使得SA非常容易与现代架构集成

代码实现
class ShuffleAttention(nn.Module):

    def __init__(self, channel=512, reduction=16, G=8):
        super().__init__()
        self.G = G
        self.channel = channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)

        # flatten
        x = x.reshape(b, -1, h, w)

        return x

    def forward(self, x):
        b, c, h, w = x.size()
        # group into subfeatures
        x = x.view(b * self.G, -1, h, w)  # bs*G,c//G,h,w

        # channel_split
        x_0, x_1 = x.chunk(2, dim=1)  # bs*G,c//(2*G),h,w

        # channel attention
        x_channel = self.avg_pool(x_0)  # bs*G,c//(2*G),1,1
        x_channel = self.cweight * x_channel + self.cbias  # bs*G,c//(2*G),1,1
        x_channel = x_0 * self.sigmoid(x_channel)

        # spatial attention
        x_spatial = self.gn(x_1)  # bs*G,c//(2*G),h,w
        x_spatial = self.sweight * x_spatial + self.sbias  # bs*G,c//(2*G),h,w
        x_spatial = x_1 * self.sigmoid(x_spatial)  # bs*G,c//(2*G),h,w

        # concatenate along channel axis
        out = torch.cat([x_channel, x_spatial], dim=1)  # bs*G,c//G,h,w
        out = out.contiguous().view(b, -1, h, w)

        # channel shuffle
        out = self.channel_shuffle(out, 2)
        return out
yaml文件实现
# Ultralytics YOLO  , GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 6  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)
  - [-1, 3, ShuffleAttention, [1024]]

  - [[15, 18, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)
完整代码分享

链接: https://pan.baidu.com/s/1NPb6C6svuNGqyZIYUVgsVw?pwd=yjnw 提取码: yjnw

启动命令
yolo detect train model=/path/yolov8_ShuffleAttention.yaml data=/path/coco128.com
注意事项

如果报错,查看这篇文章

YOLOv8 | 添加注意力机制报错KeyError:已解决,详细步骤_yolov8 keyerroe-CSDN博客

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1538921.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

力扣236 二叉树的最近公共祖先 Java版本

文章目录 题目描述代码 题目描述 给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为:“对于有根树 T 的两个节点 p、q,最近公共祖先表示为一个节点 x,满足 x 是 p、q 的祖先且 x 的深度尽可能大&…

PDF文件如何以数字进行批量重命名?以数字重命名的PDF文件

在日常生活和工作中,我们经常需要处理大量的PDF文件,如文档、报告、合同等。为了更高效地管理这些文件,一个有效的方式就是对它们进行批量命名。批量命名不仅能提高文件的组织性,还能节省大量时间。下面,我们将详细介绍…

【数据分析案列】--- 北京某平台二手房可视化数据分析

一、引言 本案列基于北京某平台的二手房数据,通过数据可视化的方式对二手房市场进行分析。通过对获取的数据进行清冼(至关重要),对房屋价格、面积、有无电梯等因素的可视化展示,我们可以深入了解北京二手房市场的特点…

docker学习笔记 三-----docker安装部署

我使用的部署环境是centos 7.9 1、安装依赖工具 yum install -y yum-utils device-mapper-persistent-data lvm2 安装完成如下图 2、添加docker的软件信息源 yum-config-manager --add-repo https://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo url地址为如…

飞桨AI应用@riscv OpenKylin

在riscv编译安装飞桨PaddlePaddle参见: 算能RISC-V通用云编译飞桨paddlepaddleopenKylin留档_在riscv下进行paddlelite源码编译-CSDN博客 安装好飞桨,就可以用飞桨进行推理了。刚开始计划用ONNX推理,但是在算能云没有装上,所以最…

C语言——程序拷贝文件

问题如下: 写一个程序拷贝文件: 使用所学文件操作,在当前目录下放一个文件data.txt,写一个程序,将data.txt文件拷贝一份,生成data_copy.txt文件。 基本思路: 打开文件data.txt,读…

服务器中了.[hpssupfast@mailfence.com].Elbie勒索病毒,数据还能恢复吗?

引言: .[hpssupfastmailfence.com].Elbie勒索病毒是一种网络攻击病毒,它会在感染用户的计算机系统中放置恶意软件,该软件会对用户的文件进行加密并要求支付赎金以解密文件。这种病毒通常通过网络钓鱼、木马植入等方式传播,利用用户…

【Ubuntu 22.04 LTS】安装vmware提示没有兼容的gcc

在ubuntu 22.04 上运行wmware时显示找不到兼容的gcc 这里要求的是12.3.0版本,我查看了自己的gcc版本是上面的11.4.0 在ask ubuntu上找到了解决方法 尝试了这一条 三条命令执行完成之后,再次运行vm,没有提示gcc的问题 点击install下载相应模…

Unity vision pro模拟器开发教程-附常见问题解决方案

前言 庄生晓梦迷蝴蝶,望帝春心托杜鹃 废话 去年苹果发布会上,推出了Vision Pro这一款XR产品。并且宣布Unity作为其主要合作伙伴,负责开发XR的开发产品。 这消息一出,当晚Unity的股价直接被熔断。产品发布之后,一直等…

深度强化学习(九)(改进策略梯度)

深度强化学习(九)(改进策略梯度) 一.带基线的策略梯度方法 Theorem: 设 b b b 是任意的函数, b b b与 A A A无关。把 b b b 作为动作价值函数 Q π ( S , A ) Q_\pi(S, A) Qπ​(S,A) 的基线, 对策略梯度没有影响: ∇ θ J …

实例:NX二次开发使用链表进行拉伸功能(链表相关功能练习)

一、概述 在进行批量操作时经常会利用链表进行存放相应特征的TAG值,以便后续操作,最常见的就是拉伸功能。这里我们以拉伸功能为例子进行说明。 二、常用链表相关函数 UF_MODL_create_list 创建一个链表,并返回链表的头指针。…

Codeforces Round #936 (Div. 2)B~D

1946B - Maximum Sum 可以想到,每次都将最大连续子序列放到该子序列的最后,也就是每一轮都能将最大连续子序列倍增一次填到数组中,最终求结果 // Problem: B. Maximum Sum // Contest: Codeforces - Codeforces Round 936 (Div. 2) // URL: …

【Flink】Flink 中的时间和窗口之窗口其他API的使用

1. 窗口的其他API简介 对于一个窗口算子而言,窗口分配器和窗口函数是必不可少的。除此之外,Flink 还提供了其他一些可选的 API,可以更加灵活地控制窗口行为。 1.1 触发器(Trigger) 触发器主要是用来控制窗口什么时候…

算法系列--动态规划--子序列(2)

💕"你可以说我贱,但你不能说我的爱贱。"💕 作者:Mylvzi 文章主要内容:算法系列–动态规划–子序列(2) 今天带来的是算法系列--动态规划--子序列(2),包含了关于子序列问题中较难的几道题目(尤其是通过二维状…

uni-app打包证书android

Android平台打包发布apk应用,需要使用数字证书(.keystore文件)进行签名,用于表明开发者身份。 Android证书的生成是自助和免费的,不需要审批或付费。 可以使用JRE环境中的keytool命令生成。 以下是windows平台生成证…

springboot实现文件上传

SpringBoot默认静态资源访问方式 首先想到的就是可以通过SpringBoot通常访问静态资源的方式,当访问:项目根路径 / 静态文件名时,SpringBoot会依次去类路径下的四个静态资源目录下查找(默认配置)。 在资源文件resour…

极大提高工作效率的 Linux 命令

作为一名软件开发人员,掌握 Linux 命令是必不可少的技能。即使你使用 Windows 或 macOS,你总会遇到需要使用 Linux 命令的场合。例如,大多数 Docker 镜像都基于 Linux 系统。要进行 DevOps 工作,你需要熟悉Linux,至少要…

Redis中的缓存穿透

缓存穿透 缓存穿透是指客户端请求的数据在缓存中和数据库中都不存在,导致这些请求直接到了数据库上,对数据库造成了巨大的压力,可能造成数据库宕机。 常见的解决方案: 1)缓存无效 key 如果缓存和数据库中都查不到某…

【漏洞复现】WordPress Plugin NotificationX 存在sql注入CVE-2024-1698

漏洞描述 WordPress和WordPress plugin都是WordPress基金会的产品。WordPress是一套使用PHP语言开发的博客平台。该平台支持在PHP和MySQL的服务器上架设个人博客网站。WordPress plugin是一个应用插件。 WordPress Plugin NotificationX 存在安全漏洞,该漏洞源于对用户提供的…

校招免费资料大集合

通过以下资料,你可以免费获取到大量的校招资料和相关信息,帮助你更好地准备校园招聘。 学习交流群:进行计算机知识分享和交流,提供内推机会,QQ群号:325280438 夏沫Coding:致力于分享计算机干货…