YOLOv10改进策略【注意力机制篇】| CVPR2024 CAA上下文锚点注意力机制

news2024/10/9 9:54:53

一、本文介绍

本文记录的是基于CAA注意力模块的YOLOv10目标检测改进方法研究在远程遥感图像或其他大尺度变化的图像中目标检测任务中,为准确提取其长距离上下文信息,需要解决大目标尺度变化和多样上下文信息时的不足的问题CAA能够有效捕捉长距离依赖,并且参数量和计算量更少。

文章目录

  • 一、本文介绍
  • 二、CAA原理
    • 2.1 原理
    • 2.2 优势
  • 三、CAA的实现代码
  • 四、创新模块
    • 4.1 改进点⭐
  • 五、添加步骤
    • 5.1 修改ultralytics/nn/modules/block.py
    • 5.2 修改ultralytics/nn/modules/__init__.py
    • 5.3 修改ultralytics/nn/modules/tasks.py
  • 六、yaml模型文件
    • 6.1 模型改进版本⭐
  • 六、成功运行结果


二、CAA原理

Poly Kernel Inception Network for Remote Sensing Detection

CAA(Context Anchor Attention)注意力的设计原理和优势如下:

2.1 原理

  • 采用平均池化1×1卷积来获取局部区域特征:对输入特征进行平均池化,然后通过1×1卷积得到局部区域特征。
  • 使用深度可分离的条形卷积来近似标准大核深度可分离卷积:通过两个深度可分离的条形卷积来扩大感受野,并且这种设计基于两个考虑。首先,条形卷积是轻量级的,与传统的大核2D深度可分离卷积相比,使用几个1D深度可分离核可以达到类似的效果,同时参数减少了 k b / 2 kb/2 kb/2。其次,条形卷积有助于识别和提取细长形状物体(如桥梁)的特征。
  • 随着CAA模块所属的PKI块深度增加,增大条形卷积的核大小( k b = 11 + 2 × l kb = 11 + 2×l kb=11+2×l),以增强PKINet建立长距离像素间关系的能力,同时由于条形深度可分离设计,不会显著增加计算成本。
  • 最后,CAA模块产生一个注意力权重,用于增强PKI模块的输出特征。具体来说,通过Sigmoid函数确保注意力图在范围 ( 0 , 1 ) (0, 1) (0,1)内,然后通过元素点乘和元素求和操作来增强特征。

在这里插入图片描述

2.2 优势

  • 有效捕捉长距离依赖:通过合适的核大小设置,能够更好地捕捉长距离像素间的依赖关系,相比于较小核大小的情况,能提升模型性能,因为较小核无法有效捕获长距离依赖,而较大核可以包含更多上下文信息。
  • 轻量化:条形卷积的设计使得CAA模块具有轻量化的特点,减少了参数数量和计算量。
  • 增强特征提取:当在PKINet的任何阶段使用CAA模块时,都能带来性能提升,当在所有阶段部署CAA模块时,性能增益达到 1.03 % 1.03\% 1.03%,这表明CAA模块能够有效地增强模型对特征的提取能力。

论文:https://arxiv.org/pdf/2403.06258
源码:https://github.com/NUST-Machine-Intelligence-Laboratory/PKINet

三、CAA的实现代码

CAA模块的实现代码如下:

from mmcv.cnn import ConvModule
from mmengine.model import BaseModule

class CAA(BaseModule):
    """Context Anchor Attention"""
    def __init__(
            self,
            channels: int,
            h_kernel_size: int = 11,
            v_kernel_size: int = 11,
            norm_cfg: Optional[dict] = dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg: Optional[dict] = dict(type='SiLU'),
            init_cfg: Optional[dict] = None,
    ):
        super().__init__(init_cfg)
        self.avg_pool = nn.AvgPool2d(7, 1, 3)
        self.conv1 = ConvModule(channels, channels, 1, 1, 0,
                                norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.h_conv = ConvModule(channels, channels, (1, h_kernel_size), 1,
                                 (0, h_kernel_size // 2), groups=channels,
                                 norm_cfg=None, act_cfg=None)
        self.v_conv = ConvModule(channels, channels, (v_kernel_size, 1), 1,
                                 (v_kernel_size // 2, 0), groups=channels,
                                 norm_cfg=None, act_cfg=None)
        self.conv2 = ConvModule(channels, channels, 1, 1, 0,
                                norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.act = nn.Sigmoid()

    def forward(self, x):
        attn_factor = self.act(self.conv2(self.v_conv(self.h_conv(self.conv1(self.avg_pool(x))))))
        return attn_factor

四、创新模块

4.1 改进点⭐

模块改进方法
1️⃣ 加入CAA模块CAA模块添加后如下:

在这里插入图片描述

注意❗:在5.2和5.3小节中需要声明的模块名称为:CAA

2️⃣:加入基于CAA模块C2f。利用CAA改进C2f模块,使模型能够更好地捕捉长距离像素间的依赖关系。

改进代码如下:

class C2f_CAA(nn.Module):
    """Faster Implementation of CSP Bottleneck with 2 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
        expansion.
        """
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
        self.att = CAA(c2)

    def forward(self, x):
        """Forward pass through C2f layer."""
        y = list(self.cv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.att(self.cv2(torch.cat(y, 1)))

    def forward_split(self, x):
        """Forward pass using split() instead of chunk()."""
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.att(self.cv2(torch.cat(y, 1)))

在这里插入图片描述

注意❗:在5.2和5.3小节中需要声明的模块名称为:C2f_CAA


五、添加步骤

5.1 修改ultralytics/nn/modules/block.py

此处需要修改的文件是ultralytics/nn/modules/block.py

block.py中定义了网络结构的通用模块,我们想要加入新的模块就只需要将模块代码放到这个文件内即可。

CAAC2f_CAA模块代码添加到此文件下。

5.2 修改ultralytics/nn/modules/init.py

此处需要修改的文件是ultralytics/nn/modules/__init__.py

__init__.py文件中定义了所有模块的初始化,我们只需要将block.py中的新的模块命添加到对应的函数即可。

CAAC2f_CAAblock.py中实现,所有要添加在from .block import

from .block import (
    C1,
    C2,
    ...
    CAA,
    C2f_CAA
)

在这里插入图片描述

5.3 修改ultralytics/nn/modules/tasks.py

tasks.py文件中,需要在两处位置添加各模块类名称。

首先:在函数声明中引入CAAC2f_CAA

在这里插入图片描述

在这里插入图片描述

其次:在parse_model函数中注册CAAC2f_CAA模块

在这里插入图片描述

在这里插入图片描述


六、yaml模型文件

6.1 模型改进版本⭐

此处以ultralytics/cfg/models/v10/yolov10m.yaml为例,在同目录下创建一个用于自己数据集训练的模型文件yolov10m-C2f_CAA.yaml

yolov10m.yaml中的内容复制到yolov10m-C2f_CAA.yaml文件下,修改nc数量等于自己数据中目标的数量。

📌 模型的修改方法是将骨干网络中的所有C2f模块替换成C2f_CAA模块

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 3, C2f_CAA, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 6, C2f_CAA, [256, True]]
  - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
  - [-1, 6, C2f_CAA, [512, True]]
  - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
  - [-1, 3, C2fCIB, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 1, PSA, [1024]] # 10

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)

  - [-1, 1, SCDown, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)

  - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)




六、成功运行结果

分别打印网络模型可以看到C2f_CAA已经加入到模型中,并可以进行训练了。

yolov10m-C2f_CAA

                   from  n    params  module                                       arguments                     
  0                  -1  1      1392  ultralytics.nn.modules.conv.Conv             [3, 48, 3, 2]                 
  1                  -1  1     41664  ultralytics.nn.modules.conv.Conv             [48, 96, 3, 2]                
  2                  -1  2    172416  ultralytics.nn.modules.block.C2f_CAA         [96, 96, True]                
  3                  -1  1    166272  ultralytics.nn.modules.conv.Conv             [96, 192, 3, 2]               
  4                  -1  4   1353216  ultralytics.nn.modules.block.C2f_CAA         [192, 192, True]              
  5                  -1  1     78720  ultralytics.nn.modules.block.SCDown          [192, 384, 3, 2]              
  6                  -1  4   5360640  ultralytics.nn.modules.block.C2f_CAA         [384, 384, True]              
  7                  -1  1    228672  ultralytics.nn.modules.block.SCDown          [384, 576, 3, 2]              
  8                  -1  2   1689984  ultralytics.nn.modules.block.C2fCIB          [576, 576, 2, True]           
  9                  -1  1    831168  ultralytics.nn.modules.block.SPPF            [576, 576, 5]                 
 10                  -1  1   1253088  ultralytics.nn.modules.block.PSA             [576, 576]                    
 11                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 12             [-1, 6]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 13                  -1  2   1993728  ultralytics.nn.modules.block.C2f             [960, 384, 2]                 
 14                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 15             [-1, 4]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 16                  -1  2    517632  ultralytics.nn.modules.block.C2f             [576, 192, 2]                 
 17                  -1  1    332160  ultralytics.nn.modules.conv.Conv             [192, 192, 3, 2]              
 18            [-1, 13]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 19                  -1  2    831744  ultralytics.nn.modules.block.C2fCIB          [576, 384, 2, True]           
 20                  -1  1    152448  ultralytics.nn.modules.block.SCDown          [384, 384, 3, 2]              
 21            [-1, 10]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 22                  -1  2   1911168  ultralytics.nn.modules.block.C2fCIB          [960, 576, 2, True]           
 23        [16, 19, 22]  1   2282134  ultralytics.nn.modules.head.v10Detect        [1, [192, 384, 576]]          
YOLOv10m-C2f_CAA summary: 707 layers, 19198246 parameters, 19198230 gradients, 80.9 GFLOPs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2198714.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spark算子使用-Map,FlatMap,Filter,diatinct,groupBy,sortBy

目录 Map算子使用 FlatMap算子使用 Filter算子使用-数据过滤 Distinct算子使用-数据去重 groupBy算子使用-数据分组 sortBy算子使用-数据排序 Map算子使用 # map算子主要使用长场景,一个转化rdd中每个元素的数据类型,拼接rdd中的元素数据&#xf…

NUKE 15有哪些新的改进功能?影视后期特效合成NUKE 15 安装包分享 【Mac/win】

Nuke 15是一款由英国The Foundry公司开发的专业的合成软件,被广泛用于电影、电视和广告制作中的后期合成和特效制作。 Nuke 15拥有强大的功能和灵活性,可以帮助用户处理各种复杂的合成任务,包括图像修复、色彩校正以及粒子特效等。它具备高效…

sql注入第8关

手工注入麻烦 目录 判断闭合方式 判断注入类型 手工注入 1、获取数据库名 2、爆破数据库的名字(security) 3、爆破表的数量 4、判断表名的长度 5、判断表的列名数量 6、判断表的列名的名字 7、获取表的数据 8、判断数据的长度 9、判断数据的…

在 Hugging Face MTEB 排行榜上比较 ELSER 的检索相关性

作者:来自 Elastic Aris Papadopoulos 及 Serena Chou 本博客对 ELSER 在 Hugging Face MTEB 排行榜上的检索相关性进行了比较。 在 Hugging Face MTEB 排行榜上比较 ELSER 的检索相关性 ELSER(Elastic Learned Sparse EncodeR)是 Elastic …

WMS 智慧仓储管理系统的可视化管理_SunWMS

【大家好,我是唐Sun,唐Sun的唐,唐Sun的Sun。一站式数智工厂解决方案服务商】 WMS 智慧仓储管理系统的可视化管理主要表现在以下几个方面: 首先是库存可视化。通过系统,仓库管理人员能够以直观的图表、图形等形式清晰地…

pdf怎么加密码怎么设置密码?这几种pdf设置密码的方法简单!

pdf怎么加密码怎么设置密码?PDF格式作为现代办公和学习中频繁使用的文档类型,其身影遍布于各类场景,然而,在享受PDF带来的便利之余,不少用户对其安全性产生了疑虑,尽管PDF文件相较于其他格式更难被直接编辑…

如何查看是否是ip转发?

一、什么是ip转发 ip转发指的是路由器或者其他网络设备把接受的ip数据包从一个接口转发到另一个ip的过程。在ip转发的过程中,如果某个设备接收到某个数据包时发现该设备不是此数据包的最终目的地,它就会根据路由表中的信息将此数据包转发到下一个适合的…

10.8摩尔学习知识点

今天学习获取数据 在摩尔云平台找到要修改的主视图,然后点击操作功能,点击新增,直接输入名字获取数据,然后,显示顺序15,显示是,点击确定,然后就是自定义类上面输入创建的类名&#…

006集—— CAD锁文档的用法(CAD—C#二次开发入门)

CAD 二开中,当要在除当前文档外的其它文档的模型空间或图纸空间中添加图元时,需要先锁定其文档。用户可用要锁定的Document对象的LockDocument方法进行锁定。在调用LockDocument方法后,将返回一个DocumentLock对象。 本例创建一个新的文档然…

文章解读与仿真程序复现思路——电网技术EI\CSCD\北大核心《面向电动汽车用户的电价套餐模块化设计 》

本专栏栏目提供文章与程序复现思路,具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…

数学建模算法与应用 第1章 线性规划

第1章 线性规划 线性规划是数学规划领域的重要分支,广泛应用于资源配置、生产计划、物流管理等领域。它主要用于解决如何在满足一定约束条件下,使目标函数(如成本、利润等)达到最大或最小的问题。第一章将介绍线性规划的基本概念…

点可云ERP进销存V8版本——其他支出单使用说明

其他支出单用于记录除采购内容外其支出资金,如:人工运输费、安装维修服务、差旅报销等。新增保存之后,对应资金账户将减少金额额度,并做存储记录,可在现金银行报表中体现。 新增操作 接下来我们讲解新增单据步骤。如上…

【CSS】flex配合margin实现元素均匀分布

现有代码如下&#xff0c;要求不使用网格布局&#xff0c;根据剩余空间设置margin <div className"container">{Array.from({ length: 12 }, (_, i) > i).map((item) > (<div className"box">{item}</div>))} </div>.conta…

《CUDA编程》6.CUDA的内存组织

前面几章讲了一些编写高性能CUDA程序的要点&#xff0c;但还有很多其他需要注意的&#xff0c;其中最重要的就是合理的使用设备内存 1 CUDA的内存组织简介 现代计算机中的内存存在一种组织结构(hierachy)&#xff0c;即不同类型的内存具有不同的容量和访问延迟&#xff08;可以…

从新开始,轻松搭建陪玩系统!线下线上陪玩平台搭建系统,选购线下线上陪玩小程序APP系统时,这点不能忽视!

在搭建线下线上陪玩平台系统&#xff0c;以及选购线下线上陪玩小程序APP系统时&#xff0c;以下几点是至关重要的&#xff0c;不容忽视&#xff1a; 一、明确需求与规划 目标用户定位&#xff1a; 确定陪玩系统的目标用户群体&#xff0c;如游戏玩家、技能服务需求者等。 功能…

使用C# winform 开发一个任务管理器

前言 为啥要开发这个呢 ,系统自带的关闭有些程序就关不了,它有好多线程,你关一其中一个它后台又重新开了一个,关不完,使用我这个呢 就把所有相同名称进程看作一个,一关就关 下载软件 v1 Form1.cs using System; using System.Windows.Forms;namespace TaskMaster {public pa…

learn C++ NO.21——AVL树

简单介绍一下AVL树 AVL树是一种自平衡的二叉搜索树&#xff08;Balanced Binary Search Tree, BBST&#xff09;&#xff0c;由俄罗斯数学家G. M. Adelson-Velsky和E. M. Landis在1962年发明&#xff0c;因此以其名字首字母命名。AVL树通过保持任何节点的两个子树的高度最大差…

养生健康:从日常细节中寻觅长寿之钥

养生健康&#xff1a;从日常细节中寻觅长寿之钥 在这个快节奏的时代&#xff0c;健康似乎成了一种奢侈品&#xff0c;但实则不然。养生之道&#xff0c;不在于繁复的仪式&#xff0c;而在于融入日常的点点滴滴。今天&#xff0c;就让我们一起探讨几个简单却至关重要的养生习惯…

N1从安卓盒子刷成armbian

Release Armbian_noble_save_2024.10 ophub/amlogic-s9xxx-armbian (github.com) armbian下载&#xff0c;这里要选择905d adb 下载地址 https://dl.google.com/android/repository/platform-tools-latest-windows.zip 提示信息 恩山无线论坛 使用usb image tool restet a…

Java项目实战II基于Java+Spring Boot+MySQL的高校学科竞赛平台

目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发&#xff0c;CSDN平台Java领域新星创作者&#xff0c;专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 随着高等教…