注意力机制篇 | YOLO11改进 | 即插即用的高效多尺度注意力模块EMA

news2025/1/6 19:53:41

前言:Hello大家好,我是小哥谈。与传统的注意力机制相比,多尺度注意力机制引入了多个尺度的注意力权重,让模型能够更好地理解和处理复杂数据。这种机制通过在不同尺度上捕捉输入数据的特征,让模型同时关注局部细节和全局结构,以提高对细节和上下文信息的理解,达到提升模型的表达能力、泛化性、鲁棒性和定位精度,优化资源使用效率的效果。🌈

     目录

🚀1.基础概念

🚀2.网络结构

🚀3.添加步骤

🚀4.改进方法

🍀🍀步骤1:创建EMA.py文件

🍀🍀步骤2:修改tasks.py文件

🍀🍀步骤3:创建自定义yaml文件

🍀🍀步骤4:新建train.py文件

🚀1.基础概念

这篇论文的作者旨在通过提出一种高效多尺度注意力模块 (Efficient Multi-Scale Attention Module, EMA) 来改进现有的注意力机制。他们的核心思想是:在不进行通道降维的情况下,通过分组和多尺度并行子网络来有效地捕捉全局和局部的空间依赖关系。此外,他们通过跨空间学习方法,将全局与局部特征进行融合,以提高像素级的配对关系捕捉能力,增强模型对复杂视觉任务(如图像分类和目标检测)的表现,同时保持较低的计算开销。这种设计不仅提高了模型的准确性,还提升了其计算效率。

创新点:

  • 多尺度并行子网络设计论文提出了多尺度的并行子网络结构,用于捕捉图像中的短程和长程依赖关系。这个设计使得模型可以在多个尺度上学习更丰富的特征表示。

  • 通道维度的重组与传统的通道降维方式不同,该方法通过将部分通道重组到批次维度中,避免了通道降维带来的信息损失,从而保持每个通道的完整信息。

  • 跨空间学习方法创新性地提出了跨空间学习方法,用来融合并行子网络的输出。该方法通过捕捉像素级的成对关系,强化了全局上下文信息的表达,有助于提升特征的聚合效果。

  • 高效注意力机制与CBAM、坐标注意力等其他注意力机制相比,EMA模块在使用较少参数的情况下,显著提高了图像分类和目标检测任务的性能,并降低了计算开销。

整体结构:

EMA模型通过将输入特征按通道维度分组,并采用两个并行分支:一个分支使用1D全局池化和1×1卷积处理全局信息,另一个分支使用3×3卷积捕捉局部空间信息。两分支的输出通过矩阵乘法融合,生成注意力图,并与输入特征结合,最终提升模型对全局和局部信息的捕捉能力,同时降低计算复杂度。

论文题目:《Efficient Multi-Scale Attention Module with Cross-Spatial Learning》 

论文地址:  https://arxiv.org/abs/2305.13563v2

代码实现:  https://github.com/YOLOonMe/EMA-attention-module


🚀2.网络结构

本文的改进是基于YOLO11,关于其网络结构具体如下图所示:

本文所做的改进是在YOLO11的网络结构中加入EMA注意力机制关于改进后的网络结构图具体如下图所示:


🚀3.添加步骤

针对本文的改进,具体步骤如下所示:👇

步骤1:创建EMA.py新文件

步骤2:修改tasks.py文件

步骤3:创建自定义yaml文件

步骤4:新建train.py文件


🚀4.改进方法

🍀🍀步骤1:创建EMA.py文件

在目录:ultralytics/nn/modules文件下创建EMA.py文件,该文件代码如下:

import torch
from torch import nn
# By CSDN 小哥谈

class EMA_attention(nn.Module):
    def __init__(self, channels, factor=8):
        super(EMA_attention, self).__init__()
        self.groups = factor
        assert channels // self.groups > 0
        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
        self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        b, c, h, w = x.size()
        group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,w
        x_h = self.pool_h(group_x)
        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
        x_h, x_w = torch.split(hw, [h, w], dim=2)
        x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
        x2 = self.conv3x3(group_x)
        x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
        return (group_x * weights.sigmoid()).reshape(b, c, h, w)
🍀🍀步骤2:修改tasks.py文件

首先,找到parse_model函数(935行左右),在下图所示位置加入EMA_attention

关于所加位置如下图所示:

然后,在该文件头部导入代码:

from ultralytics.nn.modules.EMA import EMA_attention
🍀🍀步骤3:创建自定义yaml文件

在目录:ultralytics/cfg/models/11下创建yolo11_EMA.yaml文件,该文件代码如下:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# By CSDN 小哥谈

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)

  - [16, 1, EMA_attention, [256]] # 23
  - [19, 1, EMA_attention, [512]] # 24
  - [22, 1, EMA_attention, [1024]] # 25

  - [[23, 24, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)
🍀🍀步骤4:新建train.py文件

在根目录下,新建train.py文件,该文件代码如下:

# -*- coding: utf-8 -*-
# By CSDN 小哥谈
import warnings
import os
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    # model.load('yolo11n.pt') 
    model = YOLO(model=r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\models\11\yolo11_EMA.yaml')
    model.train(data=r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\datasets\helmet.yaml',
                imgsz=640,
                epochs=50,
                batch=4,
                workers=0,
                device='',
                optimizer='SGD',
                close_mosaic=10,
                resume=False,
                project='runs/train',
                name='exp',
                single_cls=False,
                cache=False,
                )

点击“运行”,代码可以正常运行。

关于其他添加位置:

添加位置1:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# By CSDN 小哥谈

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10
  - [-1, 1, EMA_attention, [1024]] # 11

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 14], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 11], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)

  - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

 添加位置2:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# By CSDN 小哥谈

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 1, EMA_attention, [256]] # 17

  - [-1, 1, Conv, [256, 3, 2]]  # 18
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 20 (P4/16-medium)
  - [-1, 1, EMA_attention, [512]] # 21

  - [-1, 1, Conv, [512, 3, 2]] #22
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 24 (P5/32-large)
  - [-1, 1, EMA_attention, [1024]] # 25

  - [[17, 21, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2238043.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

dell服务器安装ESXI8

1.下载镜像在官网 2.打开ipmi(idrac),将esxi镜像挂载,然后服务器开机 3.进入bios设置cpu虚拟化开启,进入boot设置启动选项为映像方式 4..进入安装引导界面3.加载完配置进入安装 系统提示点击继 5.选择安装磁盘进行…

Linux -- 操作系统(软件)

目录 什么是操作系统? 计算机的层状结构 为什么要有操作系统 操作系统到底层硬件 驱动程序 操作系统如何管理硬件? 操作系统到用户 系统调用接口 库函数 回到问题 什么是操作系统? 操作系统(Operating System&#xf…

【大数据算法】MapReduce算法概述之:MapReduce基础模型

MapReduce基础模型 1、引言2、MapReduce基础模型2.1 定义2.2 核心原理2.3 优点2.4 缺点2.5 局限性2.6 实例 3、总结 1、引言 小屌丝:鱼哥,鱼哥, 不得了啊 小鱼:啥事情这么慌慌张张的 小屌丝:这不是慌张啊 小鱼&#x…

深入解析 Transformers 框架(四):Qwen2.5/GPT 分词流程与 BPE 分词算法技术细节详解

前面我们已经通过三篇文章,详细介绍了 Qwen2.5 大语言模型在 Transformers 框架中的技术细节,包括包和对象加载、模型初始化和分词器技术细节: 深入解析 Transformers 框架(一):包和对象加载中的设计巧思与…

商品详情 API 接口的返回结果通常包含哪些信息?

商品详情 API 接口的返回结果通常包含以下几类信息: 一、商品基本信息: 商品 ID:唯一标识商品的编号,在电商平台的数据库中具有唯一性,用于区分不同的商品。商品标题:对商品的简要描述,通常包…

探索 Seata 分布式事务

Seata(Simple Extensible Autonomous Transaction Architecture)是阿里巴巴开源的一款分布式事务解决方案,旨在帮助开发者解决微服务架构下的分布式事务问题。它提供了高效且易于使用的分布式事务管理能力,支持多种事务模式&#…

AI写作(七)的核心技术探秘:情感分析与观点挖掘

一、AI 写作中的关键技术概述 情感分析与观点挖掘在 AI 写作中起着至关重要的作用。情感分析能够帮助 AI 理解文本中的情感倾向,无论是正面、负面还是中性。在当今信息时代,准确把握用户情绪对于提供个性化体验和做出明智决策至关重要。例如,…

容器化技术入门:Docker详解

💓 博客主页:瑕疵的CSDN主页 📝 Gitee主页:瑕疵的gitee主页 ⏩ 文章专栏:《热点资讯》 容器化技术入门:Docker详解 容器化技术入门:Docker详解 容器化技术入门:Docker详解 引言 Doc…

Flutter运行App时出现“Running Gradle task ‘assembleDebug“问题解决

在参考了众多解决办法中最有用并且最快的方法 Gradle Distributions 在这个地方下载对应你这个文件中的gradle版本 然后将 最后一行本来不是这样的,我们把下载好的zip包保存到本地,然后用这个代替网址,最后成功运行

Spark中的shuffle

Shuffle的本质基于磁盘划分来解决分布式大数据量的全局分组、全局排序、重新分区【增大】的问题。 1、Spark的Shuffle设计 Spark Shuffle过程也叫作宽依赖过程,Spark不完全依赖于内存计算,面临以上问题时,也需要Shuffle过程。 2、Spark中哪…

window11安装elasticsearch+Kibana

1、下载elasticsearch与elasticsearch 下载elasticsearch 查看elasticsearch对应的Kibana版本 下载elasticsearch解压后文件目录如下 可执行脚本文件,包括启动elasticsearch服务、插件管理、函数命令等 bin配置文件目录,如elasticsearch配置、角色配置、jvm配置等 conf 默认…

linux rocky 9.4部署和管理docker harbor私有源

文章目录 Harbor简介安装Harbor技术细节1.安装系统(略),设置主机名和IP2.安装docker3.安装docker-compose4.安装Harbor私有源仓库5 测试登录1.本机登录2.客户端登录Harbor服务器配置docker源1. 下载镜像2.把镜像上传到Harbor私有仓库源3.客户端下载镜像,并且启动容器linux …

【计算机网络五】HTTP协议!网站运行的奥秘!

目录 HTTP协议 1.HTTP是什么? 2.Fiddler抓包查看HTTP协议格式 3.HTTP请求 4.HTTP响应 HTTP协议 1.HTTP是什么? HTTP ( 全称为 " 超文本传输协议 ") 诞生与 1991 年 . 目前已经发展为最主流使用的一种应用层协议 . HTTP 的前几个版本…

嵌入式硬件实战基础篇(一)-STM32+DAC0832 可调信号发生器-产生方波-三角波-正弦波

引言:本内容主要用作于学习巩固嵌入式硬件内容知识,用于想提升下述能力,针对学习STM32与DAC0832产生波形以及波形转换,对于硬件的降压和对于前面硬件篇的实际运用,针对仿真的使用,具体如下: 设…

ubuntu 24.04运行chattts时cuda安装错误原因分析

使用ubuntu 24.04,按照2noise/ChatTTS官方流程安装依赖时报错。ChatTTShttps://github.com/2noise/ChatTTS 这是因为cuda版本不对,ChatTTS目前的版本,要求支持cuda 12.4及以上,但是如果nvidia显卡驱动版本较老,无法支…

【动态规划】斐波那契数列模型总结

一、第 N 个泰波那契数 题目链接: 第 N 个泰波那契数 题目描述: 题目分析: 1、状态表示: dp[i] 表示:第 i 个斐波那契数的值 2、状态转移方程: 由题意可知第 i 个数等于其前三个数之和 dp[i] dp[i-…

2024 第五次周赛

A: 直接遍历即可 #include<bits/stdc.h> using namespace std;typedef long long ll; typedef pair<ll, ll>PII; const int N 2e6 10; const int MOD 998244353; const int INF 0X3F3F3F3F;int n, m; int main() {cin >> n;int cnt 0;for(int i 0; i …

【数据库系列】postgresql链接详解

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

Uniapp底部导航栏设置(附带PS填充图标教程)

首先需要注册和登录ifconfont官网&#xff0c;然后创建项目添加需要的图标 创建和添加图标库请参考&#xff1a;Uniapp在Vue环境中引入iconfont图标库&#xff08;详细教程&#xff09; 打开iconfont官网&#xff0c;找到之前添加的图标库&#xff0c;下载png图片 如果需要的…

提取神经网络数学表达式

来自《老饼讲解神经网络》 www..bbbdata.com 当我们在matlab训练好网络后&#xff0c;可以使用神经网络工具箱的sim(net,x)函数进行预测输出。但往往想提取出它的数学表达式&#xff0c;该怎么提取呢&#xff1f; 下面以《一个简单的神经网络例子》中的模型为例&#xff0c;提取…