YOLO11改进|注意力机制篇| 引入SpatialGroupEnhance注意力机制

news2025/1/11 18:40:02

在这里插入图片描述

目录

    • 一、【 SpatialGroupEnhance】注意力机制
      • 1.1【 SpatialGroupEnhance】注意力介绍
      • 1.2【SpatialGroupEnhance】核心代码
    • 二、添加【SpatialGroupEnhance】注意力机制
      • 2.1STEP1
      • 2.2STEP2
      • 2.3STEP3
      • 2.4STEP4
    • 三、yaml文件与运行
      • 3.1yaml文件
      • 3.2运行成功截图

一、【 SpatialGroupEnhance】注意力机制

1.1【 SpatialGroupEnhance】注意力介绍

在这里插入图片描述

下图是【SpatialGroupEnhance】的结构图,让我们简单分析一下运行过程和优势

处理过程分析

  • 输入特征图:

  • 输入的特征图大小为 𝐶×𝐻×𝑊,其中 𝐶是通道数,𝐻和 𝑊
    分别是特征图的高度和宽度。图片中展示了从图像输入到特征图处理的整个过程。

  • 全局平均池化:

  • 特征图首先经过 Global Average Pooling(全局平均池化) 操作,这一步将输入特征图沿空间维度(高度和宽度)进行池化,得到一个形状为 𝐶×1×1的全局特征向量。这个向量可以看作是整个特征图的全局描述,提取了通道维度上的全局信息。

  • 特征加权生成:

  • 全局池化的特征向量 𝑔 经过一系列处理,用来生成每个通道的注意力权重。该过程包括了以下几个步骤:
    归一化:通过某种形式的归一化操作(例如 Batch Normalization 或 Layer Normalization),对特征进行标准化处理。
    逐像素点积:然后通过点积操作来计算每个通道上的注意力权重,从而使每个通道得到独立的响应权重。
    Sigmoid 激活:为了保证加权系数在 0 到 1 之间,输出通过 Sigmoid 激活函数,将权重映射到标准化范围内。

  • 通道注意力生成:

  • 通过上述步骤生成的权重 𝛼被应用于原始的特征图 𝑋上。这个过程通过对每个通道进行重新加权,增强了那些对当前任务重要的特征,抑制了噪声或不相关的信息。

  • 特征更新:

  • 特征图通过全局上下文信息的调制,得到更新后的特征图 𝑋^ 。这使得模型能够更好地捕捉到全局的上下文信息,同时保留了空间上的细节。

优势分析

  • 全局上下文信息的有效捕捉:

  • 通过全局平均池化,模型可以获得整个输入特征图的全局表示。这种全局上下文信息在特征重新加权时起到关键作用,尤其是在图像分类、检测等任务中,能够帮助模型识别出全局结构上的特征,提高对图像整体内容的理解。

  • 通道间的依赖建模:

  • 模块通过生成特征图每个通道的注意力权重,增强了通道间的依赖性。这意味着模型能够自动为重要的特征分配更多的权重,而减少对无关特征的依赖,从而使得模型更加鲁棒。

  • 计算效率高:

  • 与传统的自注意力机制(如 Transformer 中的多头自注意力)相比,这种全局注意力机制计算复杂度较低。因为池化后的向量尺寸显著减小,使得后续的计算(如归一化、点积、Sigmoid)代价很小,适合大规模特征图的处理。

  • 轻量化且易于集成:

  • 该模块结构简单,依赖基本的卷积、池化和激活操作,易于集成到现有的卷积神经网络中,如 ResNet、VGG 等。它不会大幅增加模型的参数量,因此能够在不明显增加计算资源的情况下,提升模型的表现。

  • 动态自适应特征增强:

  • 该模块能够根据输入特征的内容,自适应地调整每个通道的权重。这意味着它可以根据不同输入动态调节网络的响应能力,使得模型在各种复杂场景下都能具有较强的适应性。在这里插入图片描述

1.2【SpatialGroupEnhance】核心代码

import numpy as np
import torch
from torch import nn
from torch.nn import init
 
 
class SpatialGroupEnhance(nn.Module):
    def __init__(self, groups=8):
        super().__init__()
        self.groups = groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.weight = nn.Parameter(torch.zeros(1, groups, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, groups, 1, 1))
        self.sig = nn.Sigmoid()
        self.init_weights()
 
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
 
    def forward(self, x):
        b, c, h, w = x.shape
        x = x.view(b * self.groups, -1, h, w)  # bs*g,dim//g,h,w
        xn = x * self.avg_pool(x)  # bs*g,dim//g,h,w
        xn = xn.sum(dim=1, keepdim=True)  # bs*g,1,h,w
        t = xn.view(b * self.groups, -1)  # bs*g,h*w
 
        t = t - t.mean(dim=1, keepdim=True)  # bs*g,h*w
        std = t.std(dim=1, keepdim=True) + 1e-5
        t = t / std  # bs*g,h*w
        t = t.view(b, self.groups, h, w)  # bs,g,h*w
 
        t = t * self.weight + self.bias  # bs,g,h*w
        t = t.view(b * self.groups, 1, h, w)  # bs*g,1,h*w
        x = x * self.sig(t)
        x = x.view(b, c, h, w)
        return x
 
 
def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p
 
 
class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
    default_act = nn.SiLU()  # default activation
 
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
 
    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))
 
    def forward_fuse(self, x):
        """Perform transposed convolution of 2D data."""
        return self.act(self.conv(x))
 
 
class Bottleneck(nn.Module):
    """Standard bottleneck."""
 
    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
        """Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, and
        expansion.
        """
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, k[0], 1)
        self.cv2 = Conv(c_, c2, k[1], 1, g=g)
        self.add = shortcut and c1 == c2
 
    def forward(self, x):
        """'forward()' applies the YOLO FPN to input data."""
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
 
 
class C2f(nn.Module):
    """Faster Implementation of CSP Bottleneck with 2 convolutions."""
 
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
        expansion.
        """
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
 
    def forward(self, x):
        """Forward pass through C2f layer."""
        y = list(self.cv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))
 
    def forward_split(self, x):
        """Forward pass using split() instead of chunk()."""
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))
 
 
class C2f_SGE(nn.Module):
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, groups=8):
        """Initialize CSP bottleneck layer with two convolutions and SpatialGroupEnhance."""
        super().__init__()
        self.c2f = C2f(c1, c2, n, shortcut, g, e)
        self.sge = SpatialGroupEnhance(groups)
 
    def forward(self, x):
        """Forward pass through C2f layer and then SpatialGroupEnhance."""
        x = self.c2f(x)
        x = self.sge(x)
        return x
 
 
if __name__ == '__main__':
    input = torch.randn(50, 512, 7, 7)
    c2f_sge = C2f_SGE(c1=512, c2=512, n=1, shortcut=False, g=1, e=0.5, groups=8)
    output = c2f_sge(input)
    print(output.shape)
 
 
 

二、添加【SpatialGroupEnhance】注意力机制

2.1STEP1

首先找到ultralytics/nn文件路径下新建一个Add-module的python文件包【这里注意一定是python文件包,新建后会自动生成_init_.py】,如果已经跟着我的教程建立过一次了可以省略此步骤,随后新建一个SpatialGroupEnhance.py文件并将上文中提到的注意力机制的代码全部粘贴到此文件中,如下图所示在这里插入图片描述

2.2STEP2

在STEP1中新建的_init_.py文件中导入增加改进模块的代码包如下图所示在这里插入图片描述

2.3STEP3

找到ultralytics/nn文件夹中的task.py文件,在其中按照下图添加在这里插入图片描述

2.4STEP4

定位到ultralytics/nn文件夹中的task.py文件中的def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)函数添加如图代码,【如果不好定位可以直接ctrl+f搜索定位】

在这里插入图片描述

三、yaml文件与运行

3.1yaml文件

以下是添加【SpatialGroupEnhance】注意力机制在Backbone中的yaml文件,大家可以注释自行调节,效果以自己的数据集结果为准

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1,1,SpatialGroupEnhance,[]]
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)


  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 14], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 11], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)

  - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

以上添加位置仅供参考,具体添加位置以及模块效果以自己的数据集结果为准

3.2运行成功截图

在这里插入图片描述

OK 以上就是添加【SpatialGroupEnhance】注意力机制的全部过程了,后续将持续更新尽情期待

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2202011.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

leetcode:反转字符串II

题目链接 string reverse(string s1) {string s2;string::reverse_iterator rit s1.rbegin();while (rit ! s1.rend()){s2 *rit;rit;}return s2; } class Solution { public:string reverseStr(string s, int k) {string s1;int i 0;//标记字符串下标int j 0;int length …

程序员在AI时代扮演着多重角色:不仅是AI技术的创造者,也是使用者,更是AIGC的贡献者

程序员在AI时代扮演着多重角色:不仅是AI技术的创造者,也是使用者,更是AIGC的贡献者 引言 大家好,我是猫头虎,在当下的AI时代,程序员不仅是AI技术的推动者,更在这个生态中扮演着多重角色&#…

ARM base instruction -- sdiv

有符号除法运算 Signed Divide divides a signed integer register value by another signed integer register value, and writes the result to the destination register. The condition flags are not affected. 将一个有符号整数寄存器值除以另一个有符号整数寄存器值&am…

YOLOv11训练自己数据集_笔记1

一、前言 官网yolov11-main yolov11代码地址 分析YOLO11的关键改进点 YOLO11 相比之前版本,带来了五大关键改进: 增强特征提取:通过改进Backbone和Neck架构,新增了C3k2和C2PSA等组件,提升了目标检测的精度。 优化效率…

python使用装饰器来统计函数被调用次数、格式化dict以及Python-smtplib邮件发送的IP name possibly forged问题解决

一、python调试:使用装饰器来统计函数被调用次数及格式化dict 喜欢调试的时候显示数据并显示一些其它的信息,比如区分是哪次调用的调试信息,比如友好的显示dict等相对复杂的数据类型,所以这里涉及到两个方面。一是统计函数被调用次…

Centos再生龙系统迁移

Centos再生龙系统迁移 1.准备工作1.1rufus镜像刻录软件1.2再生龙镜像1.3硬盘和U盘2.准备u盘启动工具2.1刻录再生龙镜像3.备份系统3.1选择U盘启动3.2选择分辨率3.3选择中文3.4选择默认键盘配置3.5使用再升龙3.6选择第一个,device-image硬盘/分区存到/来自镜像文件3.7选择local_…

运维问题0004:MM模块-操作MIGO过账报错“对象OFN_YR 2840 WE2840 的编码范围没有找到”

1、问题分析 当在SAP系统MM模块的MIGO过账时出现“对象OFN_YR 2840 WE2840的编码范围没有找到”错误,这通常是因为系统配置中缺少对应的编码范围。先来分析一下报错消息号信息:OFN_YR是后台自动凭证编号范围配置的事务代码;2840是工厂名称 ;WE2840是指接…

地理空间数据共享资源,爱好者进

地理空间数据共享资源,爱好者进 1、1986–2021年中国30m逐年耕地数据集 由于农田的空间和时间模式对食品安全和环境可持续性的供应至关重要,因此长期和准确的农田监测非常重要。研究团队开发了一种新颖、经济的年度农田映射框架,该框架集成…

【C++】适配器stack/queue/priority_queue使用和实现

目录 容器适配器 什么是容器适配器 ​编辑stack stack的了解和使用 使用举例 题目加深 模拟实现 功能实现 测试文件 ​编辑queue queue的了解和使用 使用举例 题目加深 模拟实现 功能实现 测试文件 priority_queue priority_queue的了解和使用 使用举…

基于SpringBoot项目评审系统【附源码】

基于SpringBoot项目评审系统 效果如下: 系统首页界面 学生登录界面 项目信息页面 项目申报页面 专家注册界面 管理员登录界面 管理员功能界面 项目评审界面 评审结果界面 研究背景 在当今快速发展的信息时代,项目评审作为项目管理的关键环节&#xff…

【网络】初识https协议加密过程

初识https协议加密过程 为什么不用http而要使用https常见的加密方式对称加密非对称加密数据摘要&&数据指纹 https的工作过程探究方案一:只使用对称加密方案二:只使用非对称加密方案三:双方都使用对称加密方案四:非对称加密…

毕设分享 基于协同过滤的电影推荐系统

文章目录 0 简介1 设计概要2 课题背景和目的3 协同过滤算法原理3.1 基于用户的协同过滤推荐算法实现原理3.1.1 步骤13.1.2 步骤23.1.3 步骤33.1.4 步骤4 4 系统实现4.1 开发环境4.2 系统功能描述4.3 系统数据流程4.3.1 用户端数据流程4.3.2 管理员端数据流程 4.4 系统功能设计 …

四、Java 概念知识简单了解

一、Java 的类、对象、方法和实例变量 一个 Java 程序可以认为是一系列对象的集合,而这些对象通过调用彼此的方法来协同工作。下面简要介绍下类、对象、方法和实例变量的概念。对象:对象是类的一个实例,有状态(实例变量&#xff…

嵌入式面试——FreeRTOS篇(三) 消息队列和队列集

本篇为:消息队列和队列集篇 消息队列 1、FreeRTOS中的消息队列是什么 答: 消息队列是任务到任务、任务到中断、中断到任务数据交流的一种机制(消息传递)。 2、消息队列和全局变量的区别 答: 消息队列作用有点类似于全局变量,但消…

简单的网络爬虫爬取视频

示例代码爬取一个周杰伦相关视频 import requests# 自己想下载的视频链接 video_url https://vdept3.bdstatic.com/mda-qg8cnf4bw5x6bjs5/cae_h264/1720516251158906693/mda-qg8cnf4bw5x6bjs5.mp4?v_from_shkapp-haokan-hbf&auth_key1728497433-0-0-4a32e13f751e04754e4…

谷歌上架,应用明明没问题,咋就成了“恶意软件”?看看可能的原因

作为Google Play上架应用的开发者,大家的普遍感受:比起写代码,上架的过程简直更让人心力交瘁!特别是涉及用户数据和隐私保护的时候,稍有疏忽,就可能面临应用被下架、甚至账号被封的风险。 最近听到很多开发…

6.存储过程中的游标使用(6/10)

存储过程中的游标使用 引言 在数据库编程中,游标(Cursor)是一种重要的数据库对象,它允许开发者逐行处理查询结果集。这对于需要对每一行数据进行特定处理的场景非常有用,如数据转换、数据清洗、复杂计算等。本文将详…

Qt 与 GTK:跨平台 GUI 开发利器,可用Python助力高效GUI编程

在现代软件开发中,图形用户界面 (GUI) 至关重要,它直接影响用户体验和软件的易用性。Qt 和 GTK 作为两种主流的跨平台 GUI 库,为开发者提供了构建精美且功能强大的应用程序的强大工具。本文将深入介绍 Qt 和 GTK 的特性,并探讨如何…

SwiftUI 6.0(iOS 18)新增的网格渐变色 MeshGradient 解惑

概述 在 SwiftUI 中,我们可以借助渐变色(Gradient)来实现更加灵动多彩的着色效果。从 SwiftUI 6.0 开始,苹果增加了全新的网格渐变色让我们对其有了更自由的定制度。 因为 gif 格式图片自身的显示能力有限,所以上面的…

springboot网站开发-mysql数据库字段varchar类型存储汉字的长度关系

springboot网站开发-mysql数据库字段varchar类型存储汉字的长度关系! 如果你的数据表是utf-8编码,并且采用的是mysql数据库。设计自己的业务数据。那么,如果你采用是varchar类型的字段格式,一个汉字就是一个字节。 如图所示&#…