YOLOv11改进 | 注意力篇 | YOLOv11引入MSDA多尺度空洞注意力

news2024/11/23 13:17:23

1. MSDA介绍

1.1  摘要:作为事实上的解决方案,鼓励使用普通视觉变换器(ViT)对任意图像块之间的远程依赖性进行建模,而全局参与感受野会导致二次计算成本。 Vision Transformers 的另一个分支利用了受 CNN 启发的局部注意力,它只模拟小邻域中斑块之间的相互作用。 尽管这样的解决方案降低了计算成本,但它自然会受到较小的参与感受野的影响,这可能会限制性能。 在这项工作中,我们探索有效的视觉变压器,以在计算复杂性和参与感受野的大小之间寻求更好的权衡。 通过分析 ViT 中全局注意力的补丁交互,我们观察到浅层中的两个关键属性,即局部性和稀疏性,表明 ViT 浅层中全局依赖建模的冗余。 因此,我们提出多尺度扩张注意力(MSDA)来模拟滑动窗口内的局部和稀疏补丁交互。 通过金字塔架构,我们通过在低级阶段堆叠 MSDA 块和在高层阶段堆叠全局多头自注意力块来构建多尺度扩张变压器(DilateFormer)。 我们的实验结果表明,我们的 DilateFormer 在各种视觉任务上实现了最先进的性能。 在 ImageNet-1K 分类任务中,DilateFormer 实现了与现有最​​先进模型相当的性能,但 FLOP 数减少了 70%。 我们的 DilateFormer-Base 在 ImageNet-1K 分类任务上实现了 85.6% 的 top-1 准确率,在 COCO 对象检测/实例分割任务上实现了 53.5% 的框 mAP/46.1% 的掩模 mAP,在 ADE20K 语义分割任务上实现了 51.1% 的 MS mIoU。

官方论文地址:https://arxiv.org/pdf/2302.01791

官方代码地址:https://github.com/JIAOJIAYUASD/dilateformer

1.2  简单介绍:  

          Multi-Scale Dilated Attention(MSDA)模块是DilateFormer中的核心组成部分,用于处理视觉任务中的长距离依赖关系。这一模块的设计灵感来源于自然语言处理领域,借鉴了序列模型中的成功经验。MSDA通过使用滑动窗口的方式来执行自注意力机制,与传统的全局注意力相比,它能够以较低的计算成本来模拟图像块之间的交互作用。

          在DilateFormer架构中,MSDA模块被集成在浅层阶段,与深层阶段的多头自注意力(MHSA)模块形成对比。MSDA利用不同的膨胀率(dilation rates)为不同的头部设置不同的尺度,这使得该模块能够在保持特征图分辨率的同时增加感受野,并有效减少自注意力机制中的冗余。

          具体来说,MSDA模块首先将输入的特征图划分为多个头部,然后在不同的头部中采用不同膨胀率的滑动窗口进行自注意力操作。这种多尺度策略允许模型同时捕获和融合多尺度的语义特征,从而提高了特征表示的丰富性和模型对细节的捕捉能力。

           此外,MSDA模块的设计还考虑了局部性和稀疏性属性,即在浅层的注意力矩阵中,相关图像块通常在查询图像块的邻域内稀疏分布。这一点对于减少计算复杂度和内存使用具有重要意义,同时也有助于模型专注于与当前查询最相关的信息。

1.3  MSDA模块结构图

2. 核心代码

import torch
import torch.nn as nn
 
__all__ = ['MultiDilatelocalAttention']
 
 
class DilateAttention(nn.Module):
    "Implementation of Dilate-attention"
 
    def __init__(self, head_dim, qk_scale=None, attn_drop=0, kernel_size=3, dilation=1):
        super().__init__()
        self.head_dim = head_dim
        self.scale = qk_scale or head_dim ** -0.5
        self.kernel_size = kernel_size
        self.unfold = nn.Unfold(kernel_size, dilation, dilation * (kernel_size - 1) // 2, 1)
        self.attn_drop = nn.Dropout(attn_drop)
 
    def forward(self, q, k, v):
        # B, C//3, H, W
        B, d, H, W = q.shape
        q = q.reshape([B, d // self.head_dim, self.head_dim, 1, H * W]).permute(0, 1, 4, 3, 2)  # B,h,N,1,d
        k = self.unfold(k).reshape(
            [B, d // self.head_dim, self.head_dim, self.kernel_size * self.kernel_size, H * W]).permute(0, 1, 4, 2,
                                                                                                        3)  # B,h,N,d,k*k
        attn = (q @ k) * self.scale  # B,h,N,1,k*k
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        v = self.unfold(v).reshape(
            [B, d // self.head_dim, self.head_dim, self.kernel_size * self.kernel_size, H * W]).permute(0, 1, 4, 3,
                                                                                                        2)  # B,h,N,k*k,d
        x = (attn @ v).transpose(1, 2).reshape(B, H, W, d)
        return x
 
 
class MultiDilatelocalAttention(nn.Module):
    "Implementation of Dilate-attention"
 
    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None,
                 attn_drop=0., proj_drop=0., kernel_size=3, dilation=[1, 2, 3, 4]):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.scale = qk_scale or head_dim ** -0.5
        self.num_dilation = len(dilation)
        assert num_heads % self.num_dilation == 0, f"num_heads{num_heads} must be the times of num_dilation{self.num_dilation}!!"
        self.qkv = nn.Conv2d(dim, dim * 3, 1, bias=qkv_bias)
        self.dilate_attention = nn.ModuleList(
            [DilateAttention(head_dim, qk_scale, attn_drop, kernel_size, dilation[i])
             for i in range(self.num_dilation)])
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
 
    def forward(self, x):
        B, C, H, W = x.shape
        # x = x.permute(0, 3, 1, 2)# B, C, H, W
        y = x.clone()
        qkv = self.qkv(x).reshape(B, 3, self.num_dilation, C // self.num_dilation, H, W).permute(2, 1, 0, 3, 4, 5)
        # num_dilation,3,B,C//num_dilation,H,W
        y1 = y.reshape(B, self.num_dilation, C // self.num_dilation, H, W).permute(1, 0, 3, 4, 2)
        # num_dilation, B, H, W, C//num_dilation
        for i in range(self.num_dilation):
            y1[i] = self.dilate_attention[i](qkv[i][0], qkv[i][1], qkv[i][2])  # B, H, W,C//num_dilation
        y2 = y1.permute(1, 2, 3, 0, 4).reshape(B, H, W, C)
        y3 = self.proj(y2)
        y4 = self.proj_drop(y3).permute(0, 3, 1, 2)
        return y4
 

3. YOLOv11中添加MSDA

3.1 在ultralytics/nn下新建Extramodule

 3.2 在Extramodule里创建MSDA

在MSDA.py文件里添加给出的MSDA代码

添加完MSDA代码后,在ultralytics/nn/Extramodule/__init__.py文件中引用

3.3 在tasks.py里引用

在ultralytics/nn/tasks.py文件里引用Extramodule

在tasks.py找到parse_model(ctrl+f可以直接搜索parse_model位置

添加如下代码:

        elif m in {MultiDilatelocalAttention}:
            c2 = ch[f]
            args = [c2, *args]

4. 新建一个yolo11MSDA.yaml文件

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13
  - [-1, 1, MultiDilatelocalAttention, []]

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 1, MultiDilatelocalAttention, []]

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  - [-1, 1, MultiDilatelocalAttention, []]

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  - [-1, 1, MultiDilatelocalAttention, []]

  - [[17, 21, 26], 1, Detect, [nc]] # Detect(P3, P4, P5)

大家根据自己的数据集实际情况,修改nc大小。

5. 模型训练

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO(r'D:\yolo\yolov11\ultralytics-main\datasets\yolo11MSDA.yaml')
    model.train(data=r'D:\yolo\yolov11\ultralytics-main\datasets\data.yaml',
                cache=False,
                imgsz=640,
                epochs=100,
                single_cls=False,  # 是否是单类别检测
                batch=4,
                close_mosaic=10,
                workers=0,
                device='0',
                optimizer='SGD',
                amp=True,
                project='runs/train',
                name='exp',
                )

模型结构打印,成功运行:

6.本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv11改进有效涨点专栏,本专栏目前为新开的,后期我会根据各种前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

YOLOv11有效涨点专栏

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2186226.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Qt】控件概述 (1)

控件概述 1. QWidget核心属性1.1核心属性概述1.2 enable1.3 geometry——窗口坐标1.4 window frame的影响1.4 windowTitle——窗口标题1.5 windowIcon——窗口图标1.6 windowOpacity——透明度设置1.7 cursor——光标设置1.8 font——字体设置1.9 toolTip——鼠标悬停提示设置1…

征程6 工具链常用工具和 API 整理(含新手示例)

1.引言 征程6 工具链目前已经提供了比较丰富的集成化工具和接口来支持模型的移植和量化部署,本帖将整理常用的工具/接口以及使用示例来供大家参考,相信这篇文章会提升大家对 征程6 工具链的使用理解以及效率。 干货满满,欢迎访问 2.hb_con…

Windows 10再次成为Steam上最受欢迎的操作系统 Linux用户比例略有下降

上个月,Valve平台上使用Windows 11的玩家自三年前推出以来首次取代了 Windows 10。 然而,这一变化是短暂的–仅仅一个月后,Windows 10 又回到了第一的位置,但持续时间也可能不会太长。 根据 Valve 的数据,64 位 Window…

【动态规划】完全背包问题

完全背包问题 1. 完全背包 点赞👍👍收藏🌟🌟关注💖💖 你的支持是对我最大的鼓励,我们一起努力吧!😃😃 1. 完全背包 题目链接: DP42 【模板】完全背包 题目分…

微积分-反函数6.4(对数函数的导数)

在本节中,我们将找到对数函数 y log ⁡ b x y \log_b x ylogb​x 和指数函数 y b x y b^x ybx 的导数。我们从自然对数函数 y ln ⁡ x y \ln x ylnx 开始。我们知道它是可导的,因为它是可导函数 y e x y e^x yex 的反函数。 d d x ( ln ⁡ x…

计算机毕业设计 农场投入品运营管理系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…

【JAVA开源】基于Vue和SpringBoot的医院管理系统

本文项目编号 T 062 ,文末自助获取源码 \color{red}{T062,文末自助获取源码} T062,文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 查…

Stream流的中间方法

一.Stream流的中间方法 注意1:中间方法,返回新的Stream流,原来的Stream流只能使用一次,建议使用链式编程 注意2:修改Stream流中的数据,不会影响原来集合或者数组中的数据 二.filter filter的主要用法是…

Win10/11电脑怎么折腾都进不去Bios?看这!

前言 这段时间有小伙伴反馈,按照很早之前小白写的进Bios教程,死活进不去Bios。 所以这个教程今天又更新了! 咱们有一说一:有些机器的Bios是真的不知道快捷键是什么的。但按照今天的这篇,我想应该没有人进不去电脑的…

60 序列到序列学习(seq2seq)_by《李沐:动手学深度学习v2》pytorch版

系列文章目录 文章目录 系列文章目录一、理论知识比喻机器翻译Seq2seq编码器-解码器细节训练衡量生成序列的好坏的BLEU(值越大越好)总结 二、代码编码器解码器损失函数训练预测预测序列的评估小结练习 一、理论知识 比喻 seq2seq就像RNN的转录工作一样,非常形象的比…

YOLOv11改进 | 注意力篇 | YOLOv11引入Polarized Self-Attention注意力机制

1. Polarized Self-Attention介绍 1.1 摘要:像素级回归可能是细粒度计算机视觉任务中最常见的问题,例如估计关键点热图和分割掩模。 这些回归问题非常具有挑战性,特别是因为它们需要在低计算开销的情况下对高分辨率输入/输出的长期依赖性进行…

最新BurpSuite2024.9专业中英文开箱即用版下载

1、工具介绍 本版本更新介绍 此版本对 Burp Intruder 进行了重大改进,包括自定义 Bambda HTTP 匹配和替换规则以及对扫描 SOAP 端点的支持。我们还进行了其他改进和错误修复。 Burp Intruder 的精简布局我们对 Burp Intruder 进行了重大升级。现在,您可…

0基础学习CSS(十四)填充

CSS padding(填充) CSS padding(填充)是一个简写属性,定义元素边框与元素内容之间的空间,即上下左右的内边距。 padding(填充) 当元素的 padding(填充)内边距…

深入理解 Solidity 中的支付与转账:安全高效的资金管理攻略

在 Solidity 中,支付和转账是非常常见的操作,尤其是在涉及资金的合约中,比如拍卖、众筹、托管等。Solidity 提供了几种不同的方式来处理 Ether 转账,包括 transfer、send 和 call,每种方式的安全性、灵活性和复杂度各有…

【通配符】粗浅学习

1 背景说明 首先要注意,通配符中的符号和正则表达式中的特殊符号具备不同的匹配意义,例如:*在正则表达式中表示里面是指匹配前面的子表达式0次或者多次,而在通配符领域则是表示代表0个到无穷个任意字符。 此外,要注意…

大学城就餐推荐系统小程序的设计

管理员账户功能包括:系统首页,个人中心,用户管理,餐厅信息管理,美食类型管理,餐厅美食管理,评价信息管理,系统管理 微信端账号功能包括:系统首页,餐厅信息&a…

JavaScript for循环语句

for循环 循环语句用于重复执行某个操作,for语句就是循环命令,可以指定循环的起点、终点和终止条件。它的格式如下 for(初始化表达式;条件;迭代因子){语句} for语句后面的括号里面,有三个表达式 初始化表达式(initialize):确定循环变量的初始…

OpenAI开发者大会派礼包:大幅降低模型成本 AI语音加持App

美东时间10月1日周二,OpenAI举行了年度开发者大会DevDay,今年的大会并没有任何重大的产品发布,相比去年大会显得更低调,但OpenAI也为开发者派发了几个大“礼包”,对现有的人工智能(AI)工具和API…

Spring(学习笔记)

<context:annotation-config/>是 Spring 配置文件中的一个标签&#xff0c;用于开启注解配置功能。这个标签可以让 Spring 容器识别并处理使用注解定义的 bean。例如&#xff0c;可以使用 Autowired 注解自动装配 bean&#xff0c;或者使用 Component 注解将类标记为 bea…

四.网络层(上)

目录 4.1网络层功能概述 4.2 SDN基本概念 4.3 路由算法与路由协议 4.3.1什么是路由协议&#xff1f; 4.3.2什么是路由算法&#xff1f; 4.3.3路由算法分类 (1)静态路由算法 (2)动态路由算法 ①全局性 OSPF协议与链路状态算法 ②分散性 RIP协议与距离向量算法 4.3.…