YOLO11改进|注意力机制篇|引入全局上下文注意力机制GCA

news2024/11/18 6:15:33

在这里插入图片描述

目录

    • 一、【】注意力机制
      • 1.1【GCA】注意力介绍
      • 1.2【GCA】核心代码
    • 二、添加【GCA】注意力机制
      • 2.1STEP1
      • 2.2STEP2
      • 2.3STEP3
      • 2.4STEP4
    • 三、yaml文件与运行
      • 3.1yaml文件
      • 3.2运行成功截图

一、【】注意力机制

1.1【GCA】注意力介绍

在这里插入图片描述

下图是【GCA】的结构图,让我们简单分析一下运行过程和优势

处理过程

  • 输入特征图
  • 输入的特征图大小为 𝐶×𝐻×𝑊,其中 𝐶 是通道数𝐻和 𝑊分别是特征图的高度和宽度。
  • Context Modeling(上下文建模):
  • 首先,输入特征图通过一个 1×1 卷积层(𝑊𝑘),这一步将每个像素点的特征压缩并降低维度,卷积输出的大小仍为 𝐶×𝐻×𝑊。
  • 之后,将特征图进行降维操作,应用 Softmax 函数,生成全局的上下文注意力图。这一步的结果是对特征图的全局加权,使得每个位置能够捕捉到整个图像范围内的全局上下文信息。注意力图应用于输入特征图,进行全局加权调整,得到全局上下文的增强特征图。
  • Transform(特征变换):
  • 第一步卷积操作(1×1 Conv):首先应用另一个 1×1 卷积层(𝑊1 )来重新映射全局上下文的特征,这步操作不会改变特征图的空间维度。
  • LayerNorm 和 ReLU:然后通过 Layer Normalization(层归一化)和 ReLU 激活函数来标准化并引入非线性变换,从而增强特征的表达能力。
  • 第二步卷积操作(1×1 Conv):接着,经过一个第二次的 1×1 卷积层(𝑊2)来进一步处理和压缩特征。
  • 残差连接:
  • 经过全局上下文增强的特征与原始输入特征通过 残差连接(Skip Connection) 进行叠加,这种操作保留了原始特征的同时,融入了全局上下文的增强信息,从而提高模型的鲁棒性和特征表达能力。
  • 输出特征图:
  • 输出特征图仍然是 𝐶×𝐻×𝑊,与输入保持相同的形状,但已经通过全局上下文增强和特征变换的处理。
    优势
  • 全局上下文捕捉:GC 模块的最大优势在于能够通过 Softmax 注意力机制对整个特征图进行全局建模。相比传统的局部卷积操作,该模块可以捕捉到图像全局的上下文信息,有助于识别跨越远距离的依赖关系,尤其是在处理复杂场景时,能够提高目标识别的精度。
  • 计算效率高:
    虽然该模块引入了注意力机制,但它通过 1×1 卷积来进行维度的压缩,降低了计算复杂度。相比标准的多头自注意力机制(如 Transformer),GC 模块在计算复杂度上更加友好,非常适合嵌入在卷积神经网络中。
  • 残差连接增强特征表达:
    残差连接使得原始特征得以保留,确保全局上下文信息的增强不会损失原有的特征表达能力。这种机制可以缓解梯度消失的问题,促进更深层次的网络结构训练。
  • 通用性强:
    由于 GC 模块使用的是标准的卷积和简单的注意力机制,它可以方便地嵌入到各种神经网络结构中,例如 ResNet、VGG 等,用于提升模型的全局信息感知能力。
  • 适应性好:
    该模块通过 Softmax 注意力机制学习全局上下文的加权方式,能够动态适应不同输入特征,从而使模型在不同场景下有更好的泛化能力和适应性。
    在这里插入图片描述

1.2【GCA】核心代码

import torch  # 导入 PyTorch
from torch import nn  # 从 PyTorch 导入神经网络模块

# 定义 ContextBlock 类,继承自 nn.Module
class ContextBlock(nn.Module):
    def __init__(self, inplanes, ratio, pooling_type='att', fusion_types=('channel_add',)):
        super(ContextBlock, self).__init__()
        
        # 验证 fusion_types 是否有效
        valid_fusion_types = ['channel_add', 'channel_mul']

        # 检查 pooling_type 是否在有效选项 ['avg', 'att'] 中
        assert pooling_type in ['avg', 'att']
        # 确认 fusion_types 是列表或元组
        assert isinstance(fusion_types, (list, tuple))
        # 确认 fusion_types 中的所有元素都在 valid_fusion_types 中
        assert all([f in valid_fusion_types for f in fusion_types])
        # 确保至少有一个 fusion 类型被指定
        assert len(fusion_types) > 0, 'at least one fusion should be used'

        self.inplanes = inplanes  # 输入通道数
        self.ratio = ratio  # 缩减比例
        self.planes = int(inplanes * ratio)  # 缩减后的通道数
        self.pooling_type = pooling_type  # 池化类型('avg' 或 'att')
        self.fusion_types = fusion_types  # 融合类型

        # 如果池化类型为 'att',定义注意力池化的卷积层和 softmax
        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)  # 在维度 2 上做 softmax
        else:
            # 如果池化类型为 'avg',定义自适应平均池化层
            self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # 定义 'channel_add' 融合类型的卷积层序列
        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),  # 1x1 卷积
                nn.LayerNorm([self.planes, 1, 1]),  # 层归一化
                nn.ReLU(inplace=True),  # 激活函数 ReLU
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1)  # 1x1 卷积,恢复到原通道数
            )
        else:
            self.channel_add_conv = None

        # 定义 'channel_mul' 融合类型的卷积层序列
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),  # 1x1 卷积
                nn.LayerNorm([self.planes, 1, 1]),  # 层归一化
                nn.ReLU(inplace=True),  # 激活函数 ReLU
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1)  # 1x1 卷积,恢复到原通道数
            )
        else:
            self.channel_mul_conv = None

    # 定义空间池化方法
    def spatial_pool(self, x):
        batch, channel, height, width = x.size()  # 获取输入张量的形状
        if self.pooling_type == 'att':  # 如果池化类型为 'att'
            input_x = x.view(batch, channel, height * width)  # 展平 H 和 W
            input_x = input_x.unsqueeze(1)  # 增加一个维度
            context_mask = self.conv_mask(x)  # 应用 1x1 卷积层
            context_mask = context_mask.view(batch, 1, height * width)  # 展平 H 和 W
            context_mask = self.softmax(context_mask)  # 在 H * W 上应用 softmax
            context_mask = context_mask.unsqueeze(-1)  # 增加一个维度
            context = torch.matmul(input_x, context_mask)  # 计算加权和
            context = context.view(batch, channel, 1, 1)  # 恢复形状
        else:
            context = self.avg_pool(x)  # 如果池化类型为 'avg',直接应用平均池化

        return context  # 返回上下文张量

    # 定义前向传播方法
    def forward(self, x):
        context = self.spatial_pool(x)  # 获取上下文信息
        out = x  # 初始化输出
        if self.channel_mul_conv is not None:
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))  # 应用通道乘法融合
            out = out * channel_mul_term  # 输出乘以融合结果

        if self.channel_add_conv is not None:
            channel_add_term = self.channel_add_conv(context)  # 应用通道加法融合
            out = out + channel_add_term  # 输出加上融合结果

        return out  # 返回最终输出

# 测试代码块
if __name__ == "__main__":
    in_tensor = torch.ones((1, 64, 128, 128))  # 创建一个全1的输入张量,形状为 (1, 64, 128, 128)
    cb = ContextBlock(inplanes=64, ratio=0.25, pooling_type='att')  # 创建 ContextBlock 实例
    out_tensor = cb(in_tensor)  # 传递输入张量进行前向传播
    print(in_tensor.shape)  # 打印输入张量的形状
    print(out_tensor.shape)  # 打印输出张量的形状

二、添加【GCA】注意力机制

2.1STEP1

首先找到ultralytics/nn文件路径下新建一个Add-module的python文件包【这里注意一定是python文件包,新建后会自动生成_init_.py】,如果已经跟着我的教程建立过一次了可以省略此步骤,随后新建一个GCA.py文件并将上文中提到的注意力机制的代码全部粘贴到此文件中,如下图所示在这里插入图片描述

2.2STEP2

在STEP1中新建的_init_.py文件中导入增加改进模块的代码包如下图所示在这里插入图片描述

2.3STEP3

找到ultralytics/nn文件夹中的task.py文件,在其中按照下图添加在这里插入图片描述

2.4STEP4

定位到ultralytics/nn文件夹中的task.py文件中的def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)函数添加如图代码,【如果不好定位可以直接ctrl+f搜索定位】

在这里插入图片描述

三、yaml文件与运行

3.1yaml文件

以下是添加【GCA】注意力机制在Backbone中的yaml文件,大家可以注释自行调节,效果以自己的数据集结果为准

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1,1,ContextBlock,[256]]

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)

  - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

以上添加位置仅供参考,具体添加位置以及模块效果以自己的数据集结果为准

3.2运行成功截图

在这里插入图片描述

OK 以上就是添加【GCA】注意力机制的全部过程了,后续将持续更新尽情期待

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2200542.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SQL优化 where谓词条件is null优化

1.创建测试表及谓词条件中包含is null模拟语句 create table t641 as select * from dba_objects; set autot trace select SUBOBJECT_NAME,OBJECT_NAME from t641 where OBJECT_NAMEWRI$_OPTSTAT_SYNOPSIS$ and SUBOBJECT_NAME is null; 2.全表扫描逻辑读1237 3.创建等值谓词条…

PE结构之导出表

导出表结构中各种值的意义 ​​​​​​ 根据函数地址表遍历函数名称RVA表,和上面的图是逆过程 //函数地址表 和当前内存中的位置DWORD AddressOfFunctionsFOA RVAToFOA(LPdosHeader, LPexprotDir->AddressOfFunctions);PDWORD LPFunctionsAddressInMemary (PDWORD)((cha…

flask发送邮件

开通邮件IMAP/SMTP服务 以网易邮箱为例 点击开启发送验证后会收到一个密钥,记得保存好 编写代码 安装flask-mail pip install flask-mail在config.py文件中配置邮件信息 MAIL_SERVER:邮件服务器 MAIL_USE_SSL:使用SSL MAIL_PORT&#…

【计算机网络】网络相关技术介绍

文章目录 NAT概述NAT的基本概念NAT的工作原理1. **基本NAT(静态NAT)**2. **动态NAT**3. **NAPT(网络地址端口转换,也称为PAT)** 底层实现原理1. **数据包处理**2. **转换表**3. **超时机制** NAT的优点NAT的缺点总结 P…

Linux:多线程中的生产消费模型

多线程 生产消费模型三种关系两个角色一个交易场所交易场所的实现(阻塞队列)pthread_cond_wait 接口判断阻塞队列的空或满时,需要使用while测试一:单消费单生产案例测试二:多生产多消费案例 生产消费模型 消费者与生产…

鸿蒙网络管理模块05——数据流量统计

如果你也对鸿蒙开发感兴趣,加入“Harmony自习室”吧!扫描下方名片,关注公众号,公众号更新更快,同时也有更多学习资料和技术讨论群。 1、概述 HarmonyOS供了基于物理网络的数据流量统计能力,支持基于网卡/U…

贪心,CF 865B - Ordering Pizza

目录 一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 二、解题报告 1、思路分析 2、复杂度 3、代码详解 一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 865B - Ordering Pizza 二、解题报告 1、思路分析 如果我们不考虑披萨数…

读懂MySQL事务隔离

什么是事务 事务就是一组原子性的SQL查询,或者说一个独立的工作单元。事务内的语句,要么全部执行成功,要么全部执行失败。 关于事务银行系统的应用是解释事务必要性的一个经典例子。 假设一个银行的数据库有两张表:支票表&#x…

MySql数据库---存储过程

存储过程概念 MySQL 5.0 版本开始支持存储过程。 简单的说,存储过程就是一组SQL语句集,功能强大,可以实现一些比较复杂的逻辑功能,类似于JAVA语言中的方法,类似Python中的函数; 存储过就是数据库 SQL 语言…

【数据结构】红黑树相关知识详细梳理

1. 红黑树的概念 红黑树,是一种二叉搜索树,但在每个结点上增加一个存储位表示结点的颜色,可以是Red或 Black。 通过对任何一条从根到叶子的路径上各个结点着色方式的限制,红黑树确保没有一条路 径会比其他路径长出俩倍&#xff0c…

大数据行业应用实训室建设方案

摘要: 本文旨在探讨唯众针对当前大数据行业的人才需求,提出的《大数据行业应用实训室建设方案》。该方案旨在构建一个集理论教学、实践操作、技术创新与行业应用于一体的综合实训平台,以培养具备实战能力的大数据专业人才。 一、大数据课程体…

【AI知识点】机器学习中的常用优化算法(梯度下降、SGD、Adam等)

更多AI知识点总结见我的专栏:【AI知识点】 AI论文精读、项目和一些个人思考见我另一专栏:【AI修炼之路】 有什么问题、批评和建议都非常欢迎交流,三人行必有我师焉😁 1. 什么是优化算法? 在机器学习中优化算法&#x…

决策树随机森林-笔记

决策树 1. 什么是决策树? 决策树是一种基于树结构的监督学习算法,适用于分类和回归任务。 根据数据集构建一棵树(二叉树或多叉树)。 先选哪个属性作为向下分裂的依据(越接近根节点越关键)?…

【动态规划-最长递增子序列(LIS)】【hard】力扣1671. 得到山形数组的最少删除次数

我们定义 arr 是 山形数组 当且仅当它满足&#xff1a; arr.length > 3 存在某个下标 i &#xff08;从 0 开始&#xff09; 满足 0 < i < arr.length - 1 且&#xff1a; arr[0] < arr[1] < … < arr[i - 1] < arr[i] arr[i] > arr[i 1] > … &g…

【hot100-java】二叉搜索树中第 K 小的元素

二叉树 二叉搜索树的中序遍历是递增序列。 /*** Definition for a binary tree node.* public class TreeNode {* int val;* TreeNode left;* TreeNode right;* TreeNode() {}* TreeNode(int val) { this.val val; }* TreeNode(int val, TreeNode lef…

【C++】面向对象之继承

不要否定过去&#xff0c;也不要用过去牵扯未来。不是因为有希望才去努力&#xff0c;而是努力了&#xff0c;才能看到希望。&#x1f493;&#x1f493;&#x1f493; 目录 ✨说在前面 &#x1f34b;知识点一&#xff1a;继承的概念及定义 •&#x1f330;1.继承的概念 •&…

ECCV24高分论文:MVSplat稀疏视图下的高效的前馈三维重建模型

目录 一、概述 二、相关工作 1、稀疏视角场景重建 2、前馈NeRF 3、前馈3DGS 4、多视角立体视觉 三、MVSplat 1、多视角Transformer 一、概述 本文提出了一个MVSplat高效的前馈三维重建模型&#xff0c;可以从稀疏的多视图图像中预测3D高斯分布&#xff0c;并且相较于p…

三角形面积 python

题目&#xff1a; 计算三角形面积 代码&#xff1a; a int(input("请输入三角形的第一个边长&#xff1a;")) b int(input("请输入三角形的第二个边长&#xff1a;")) c int(input("请输入三角形的第三个边长&#xff1a;")) s (abc) / 2 #…

我谈均值平滑模板——给均值平滑模板上升理论高度

均值平滑&#xff08;Mean Smoothing&#xff09;&#xff0c;也称为盒状滤波&#xff08;Box Filter&#xff09;&#xff0c;通过计算一个像素及其周围像素的平均值来替换该像素的原始值&#xff0c;从而达到平滑图像的效果。 均值平滑通常使用一个模板&#xff08;或称为卷…

ISCC认证是什么?ISCC认证的申请流程有哪些注意事项?

ISCC认证&#xff0c;即国际可持续发展与碳认证&#xff08;International Sustainability & Carbon Certification&#xff09;&#xff0c;是一个全球通用的可持续发展认证体系。以下是对ISCC认证的详细介绍&#xff1a; 一、起源与背景 ISCC认证体系起源于德国&#x…