YOLO11改进|注意力机制篇|引入ELA注意力机制

news2025/1/24 5:41:38

在这里插入图片描述

目录

    • 一、【ELA】注意力机制
      • 1.1【ELA】注意力介绍
      • 1.2【ELA】核心代码
    • 二、添加【ELA】注意力机制
      • 2.1STEP1
      • 2.2STEP2
      • 2.3STEP3
      • 2.4STEP4
    • 三、yaml文件与运行
      • 3.1yaml文件
      • 3.2运行成功截图

一、【ELA】注意力机制

1.1【ELA】注意力介绍

在这里插入图片描述

这篇论文的作者通过分析Coordinate Attention(CA) method的局限性,确定了Batch Normalization中泛化能力的缺乏、降维对通道注意力的不利影响以及注意力生成过程的复杂性。为了克服这些挑战,提出了结合一维卷积和Group Normalization特征增强技术。这种方法通过有效地编码两个一维位置特征图,无需降维即可精确定位感兴趣区域,同时允许轻量级实现。与2D卷积相比,1D卷积更适合处理序列信号,并且更轻量、更快。GN与BN相比,展现出可比较的性能和更好的泛化能力。与 CA 类似,ELA 采用strip pooling在空间维度上获取水平和垂直方向的特征向量,保持窄核形状以捕获长程依赖关系,防止不相关区域影响标签预测,从而在各自方向上产生丰富的目标位置特征。ELA 针对每个方向独立处理上述特征向量以获得注意力预测,然后使用点乘操作将其组合在一起,从而确保感兴趣区域的准确位置信息。由下图可以看出ELA相比于CA显得更简单,所以模型也就更轻
在这里插入图片描述

1.2【ELA】核心代码

import torch
import torch.nn as nn


class ELA(nn.Module):
    def __init__(self, in_channels, phi='T'):
        super(ELA, self).__init__()
        '''
        ELA-T 和 ELA-B 设计为轻量级,非常适合网络层数较少或轻量级网络的 CNN 架构
        ELA-B 和 ELA-S 在具有更深结构的网络上表现最佳
        ELA-L 特别适合大型网络。
        '''
        Kernel_size = {'T': 5, 'B': 7, 'S': 5, 'L': 7}[phi]
        groups = {'T': in_channels, 'B': in_channels, 'S': in_channels // 8, 'L': in_channels // 8}[phi]
        num_groups = {'T': 32, 'B': 16, 'S': 16, 'L': 16}[phi]
        pad = Kernel_size // 2
        self.con1 = nn.Conv1d(in_channels, in_channels, kernel_size=Kernel_size, padding=pad, groups=groups, bias=False)
        self.GN = nn.GroupNorm(num_groups, in_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        b, c, h, w = input.size()
        x_h = torch.mean(input, dim=3, keepdim=True).view(b, c, h)
        x_w = torch.mean(input, dim=2, keepdim=True).view(b, c, w)
        x_h = self.con1(x_h)  # [b,c,h]
        x_w = self.con1(x_w)  # [b,c,w]
        x_h = self.sigmoid(self.GN(x_h)).view(b, c, h, 1)  # [b, c, h, 1]
        x_w = self.sigmoid(self.GN(x_w)).view(b, c, 1, w)  # [b, c, 1, w]
        return x_h * x_w * input


if __name__ == "__main__":
    # 创建一个形状为 [batch_size, channels, height, width] 的虚拟输入张量
    input = torch.randn(2, 256, 40, 40)
    ela = ELA(in_channels=256, phi='T')
    output = ela(input)
    print(output.size())

同时大家可以根据下面代码替换自己想要的型号,只需要将phi=''中的字母替换即可
在这里插入图片描述

二、添加【ELA】注意力机制

2.1STEP1

首先找到ultralytics/nn文件路径下新建一个Add-module的python文件包【这里注意一定是python文件包,新建后会自动生成_init_.py】,如果已经跟着我的教程建立过一次了可以省略此步骤,随后新建一个ELA.py文件并将上文中提到的注意力机制的代码全部粘贴到此文件中,如下图所示在这里插入图片描述

2.2STEP2

在STEP1中新建的_init_.py文件中导入增加改进模块的代码包如下图所示

在这里插入图片描述

2.3STEP3

找到ultralytics/nn文件夹中的task.py文件,在其中按照下图添加在这里插入图片描述

2.4STEP4

定位到ultralytics/nn文件夹中的task.py文件中的def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)函数添加如图代码,【如果不好定位可以直接ctrl+f搜索定位】

在这里插入图片描述

三、yaml文件与运行

3.1yaml文件

以下是添加【ELA】注意力机制在大目标检测层中的yaml文件,大家可以注释自行调节,效果以自己的数据集结果为准

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  - [-1,1,ELA,[]]


  - [[16, 19, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

以上添加位置仅供参考,具体添加位置以及模块效果以自己的数据集结果为准

3.2运行成功截图

在这里插入图片描述

OK 以上就是添加【ELA】注意力机制的全部过程了,后续将持续更新尽情期待

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2192577.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python中的数据可视化:从入门到进阶

数据可视化是数据分析和科学计算中的重要环节,它通过图形化的方式呈现数据,使复杂的统计信息变得直观易懂。Python提供了多种强大的库来支持数据可视化,如Matplotlib、Seaborn、Plotly等。本文将从基础到进阶,详细介绍如何使用这些…

如何构建LSTM神经网络模型

一、了解LSTM 1. 核心思想 首先,LSTM 是 RNN(循环神经网络)的变体。它通过引入细胞状态 C(t) 贯穿于整个网络模型,达到长久记忆的效果,进而解决了 RNN 的长期依赖问题。 2. 思维导图 每个LSTM层次都有三个重要的门结构…

贝尔曼公式

为什么return 非常重要 在选择哪个策略更好的时候,此时需要使用到return,比如下面三个策略的返回值。 策略1: 策略2:策略3:涉及到两个policys path How to calculate return 定义 上图定义了不同的起点下的return value 递推…

优化销售漏斗建立高效潜在客户生成策略的技巧

如何建立有效的潜在客户生成策略?建立有效潜在客户生成策略需要准确定义目标受众,利用内容营销、SEO、社交媒体、邮件营销和定向广告吸引客户,参加行业会议并跟踪分析数据。借助Zoho CRM系统,企业能够更加高效地管理客户信息&…

Windows上 minGW64 编译 libssh2库

下载libssh2库:https://libssh2.org/download/libssh2-1.11.0.zip 继续下载OpenSSL库: https://codeload.github.com/openssl/openssl/zip/refs/heads/OpenSSL_1_0_2-stable

算法讲解—最小生成树(Kruskal 算法)

算法讲解—最小生成树(Kruskal 算法) 简介 根据度娘的解释我们可以知道,最小生成树(Minimum Spanning Tree, MST)就是:一个有 n n n 个结点的连通图的生成树是原图的极小连通子图,且包含原图中的所有 n n n 个结点…

【Diffusion分割】CTS:基于一致性的医学图像分割模型

CTS: A Consistency-Based Medical Image Segmentation Model 摘要: 在医学图像分割任务中,扩散模型已显示出巨大的潜力。然而,主流的扩散模型存在采样次数多、预测结果慢等缺点。最近,作为独立生成网络的一致性模型解决了这一问…

【Python】数据可视化之聚类图

目录 clustermap 主要参数 参考实现 clustermap sns.clustermap是Seaborn库中用于创建聚类热图的函数,该函数能够将数据集中的样本按照相似性进行聚类,并将聚类结果以矩阵的形式展示出来。 sns.clustermap主要用于绘制聚类热图,该热图通…

云计算第四阶段 CLOUD2周目 01-03

国庆假期前,给小伙伴们更行完了云计算CLOUD第一周目的内容,现在为大家更行云计算CLOUD二周目内容,内容涉及K8S组件的添加与使用,K8S集群的搭建。最重要的主体还是资源文件的编写。 (*^▽^*) 环境准备: 主机清单 主机…

CUDNN下载配置

目录 简介 下载 配置 简介 cuDNN(CUDA Deep Neural Network library)是NVIDIA开发的一个深度学习GPU加速库,旨在提供高效、标准化的原语(基本操作)来加速深度学习框架(如TensorFlow、PyTorch等&#xf…

Rust 快速入门(一)

Rust安装信息解释 cargo:Rust的编译管理器、包管理器、通用工具。可以用Cargo启动新的项目,构建和运行程序,并管理代码所依赖的所有外部库。 Rustc:Rust的编译器。通常Cargo会替我们调用此编译器。 Rustdoc:是Rust的…

Java 面向对象设计一口气讲完![]~( ̄▽ ̄)~*(上)

目录 Java 类实例 Java面向对象设计 - Java类实例 null引用类型 访问类的字段的点表示法 字段的默认初始化 Java 访问级别 Java面向对象设计 - Java访问级别 Java 导入 Java面向对象设计 - Java导入 单类型导入声明 按需导入声明 静态导入声明 例子 Java 方法 J…

decltype推导规则

decltype推导规则 当用decltype(e)来获取类型时,编译器将依序判断以下四规则: 1.如果e是一个没有带括号的标记符表达式(id-expression)或者类成员访问表达式,那么decltype(e)就是e所命名的实体的类型。此外,如果e是一个被重载的函…

k8s 之安装metrics-server

作者:程序那点事儿 日期:2024/01/29 18:25 metrics-server可帮助我们查看pod的cpu和内存占用情况 kubectl top po nginx-deploy-56696fbb5-mzsgg # 报错,需要Metrics API 下载 Metrics 解决 wget https://github.com/kubernetes-sigs/metri…

基于auth2的单点登录原理理解

创作背景:基于auth2实现企业门户与业务系统的单点登录跳转。 架构组成:4A统一认证中心,门户系统,业务系统,用户; 实现目标:用户登录门户系统后,可通过点击业务系统菜单&#xff0c…

字符串数学专题

粗心的小可 题目描述 小可非常粗心,打字的时候将手放到了比正确位置偏右的一个位置,因此,Q打成了W,E打成了R,H打成了J等等。键盘如下所示 现在给你若干行小可打字的结果,请你还原成正确的文本。 输入描述…

嵌入式面试八股文(五)·一文带你详细了解程序内存分区中的堆与栈的区别

目录 1. 栈的工作原理 1.1 内存分配 1.2 地址生长方向 1.3 生命周期 2. 堆的工作原理 2.1 动态内存分配 2.1.1 malloc函数 2.1.2 calloc函数 2.1.3 realloc函数 2.1.4 free函数 2.2 生命周期管理 2.3 地址生长方向 3. 堆与栈区别 3.1 管理方式不同…

海南聚广众达电子商务咨询有限公司助力商家业绩飙升

在这个短视频与直播风靡的时代,抖音电商无疑成为了众多商家竞相追逐的新风口。作为电商服务领域的佼佼者,海南聚广众达电子商务咨询有限公司凭借其专业的团队、创新的策略与丰富的实战经验,正引领着一批又一批商家在抖音平台上破浪前行&#…

顺序表及其代码实现

目录 前言1.顺序表1.1 顺序表介绍1.2 顺序表基本操作代码实现 总结 前言 顺序表一般不会用来单独存储数据,但自身的优势,很多时候不得不使用顺序表。 1.顺序表 1.1 顺序表介绍 顺序表是物理结构连续的线性表,支持随机存取(底层…

Leetcode—139. 单词拆分【中等】

2024每日刷题&#xff08;173&#xff09; Leetcode—139. 单词拆分 dp实现代码 class Solution { public:bool wordBreak(string s, vector<string>& wordDict) {int n s.size();unordered_set<string> ust(wordDict.begin(), wordDict.end());vector<b…