YOLO11改进|注意力机制篇|引入MLCA轻量级注意力机制

news2024/11/20 11:39:51

在这里插入图片描述

目录

    • 一、MLCA注意力机制
      • 1.1MLCA注意力介绍
      • 1.2MLCA核心代码
    • 五、添加MLCA注意力机制
      • 5.1STEP1
      • 5.2STEP2
      • 5.3STEP3
      • 5.4STEP4
    • 六、yaml文件与运行
      • 6.1yaml文件
      • 6.2运行成功截图

一、MLCA注意力机制

1.1MLCA注意力介绍

在这里插入图片描述

MLCA(Multi-Level Channel Attention,多级通道注意力)是一种用于提升卷积神经网络(CNN)性能的注意力机制,主要通过在多个层次上捕捉不同通道间的依赖关系,来增强网络对重要特征的关注。MLCA的工作流程:

  • 特征提取: 首先,输入图像经过卷积网络的多层卷积提取出多层次的特征图,这些特征图代表了从不同尺度和不同深度捕捉到的特征信息。
  • 通道权重计算: 在每个层次的特征图上,MLCA 分别计算出每个通道的重要性权重,使用类似 SE 模块的全局池化操作得到全局特征表示,然后通过一系列非线性操作(如全连接、激活函数)生成通道权重。
  • 多级加权: 通过通道权重对各层次的特征图进行加权操作,增强重要特征通道的响应,抑制无关或冗余的特征。
  • 融合与输出: 对不同层次的特征进行融合,形成最终的特征表达,然后将其传递给后续的网络模块(如分类器或检测头)。
    工作流程图如下所示:
    在这里插入图片描述

1.2MLCA核心代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
 
class MLCA(nn.Module):
    def __init__(self, in_size, local_size=5, gamma=2, b=1, local_weight=0.5):
        super(MLCA, self).__init__()
 
        # ECA 计算方法
        self.local_size = local_size
        self.gamma = gamma
        self.b = b
        t = int(abs(math.log(in_size, 2) + self.b) / self.gamma)  # eca  gamma=2
        k = t if t % 2 else t + 1
 
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.conv_local = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
 
        self.local_weight = local_weight
 
        self.local_arv_pool = nn.AdaptiveAvgPool2d(local_size)
        self.global_arv_pool = nn.AdaptiveAvgPool2d(1)
 
    def forward(self, x):
        local_arv = self.local_arv_pool(x)
        global_arv = self.global_arv_pool(local_arv)
 
        b, c, m, n = x.shape
        b_local, c_local, m_local, n_local = local_arv.shape
 
        # (b,c,local_size,local_size) -> (b,c,local_size*local_size)-> (b,local_size*local_size,c)-> (b,1,local_size*local_size*c)
        temp_local = local_arv.view(b, c_local, -1).transpose(-1, -2).reshape(b, 1, -1)
        temp_global = global_arv.view(b, c, -1).transpose(-1, -2)
 
        y_local = self.conv_local(temp_local)
        y_global = self.conv(temp_global)
 
        # (b,c,local_size,local_size) <- (b,c,local_size*local_size)<-(b,local_size*local_size,c) <- (b,1,local_size*local_size*c)
        y_local_transpose = y_local.reshape(b, self.local_size * self.local_size, c).transpose(-1, -2).view(b, c,
                                                                                                            self.local_size,
                                                                                                            self.local_size)
        # y_global_transpose = y_global.view(b, -1).transpose(-1, -2).unsqueeze(-1)
        y_global_transpose = y_global.view(b, -1).unsqueeze(-1).unsqueeze(-1)  # 代码修正
        # print(y_global_transpose.size())
        # 反池化
        att_local = y_local_transpose.sigmoid()
        att_global = F.adaptive_avg_pool2d(y_global_transpose.sigmoid(), [self.local_size, self.local_size])
        # print(att_local.size())
        # print(att_global.size())
        att_all = F.adaptive_avg_pool2d(att_global * (1 - self.local_weight) + (att_local * self.local_weight), [m, n])
        # print(att_all.size())
        x = x * att_all
        return x
 
 
def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p
 
 
class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
    default_act = nn.SiLU()  # default activation
 
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
 
    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))
 
    def forward_fuse(self, x):
        """Perform transposed convolution of 2D data."""
        return self.act(self.conv(x))
 
 
class C2f_MLCA(nn.Module):
    """Faster Implementation of CSP Bottleneck with 2 convolutions."""
 
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
        expansion.
        """
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
 
    def forward(self, x):
        """Forward pass through C2f layer."""
        y = list(self.cv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))
 
    def forward_split(self, x):
        """Forward pass using split() instead of chunk()."""
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))
 
 
class Bottleneck(nn.Module):
    """Standard bottleneck."""
 
    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
        """Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, and
        expansion.
        """
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, k[0], 1)
        self.cv2 = Conv(c_, c2, k[1], 1, g=g)
        self.add = shortcut and c1 == c2
        self.MLCA = MLCA(c2)
 
    def forward(self, x):
        """'forward()' applies the YOLO FPN to input data."""
        return x + self.MLCA(self.cv2(self.cv1(x))) if self.add else self.MLCA(self.cv2(self.cv1(x)))
 
 
if __name__ == "__main__":
    attention = MLCA(in_size=64)
    inputs = torch.randn((2, 55, 16, 16))
    result = attention(inputs)
    print(result.shape)

五、添加MLCA注意力机制

5.1STEP1

首先找到ultralytics/nn文件路径下新建一个Add-module的python文件包【这里注意一定是python文件包,新建后会自动生成_init_.py】,如果已经跟着我的教程建立过一次了可以省略此步骤,随后新建一个MLCA.py文件并将上文中提到的注意力机制的代码全部粘贴到此文件中,如下图所示在这里插入图片描述

5.2STEP2

在STEP1中新建的_init_.py文件中导入增加改进模块的代码包如下图所示在这里插入图片描述

5.3STEP3

找到ultralytics/nn文件夹中的task.py文件,在其中按照下图添加在这里插入图片描述

5.4STEP4

定位到ultralytics/nn文件夹中的task.py文件中的def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)函数添加如图代码,【如果不好定位可以直接ctrl+f搜索定位】

在这里插入图片描述

六、yaml文件与运行

6.1yaml文件

以下是添加MLCA注意力机制在Backbone中的yaml文件,大家可以注释自行调节,效果以自己的数据集结果为准

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 1, MLCA, []]
  - [-1, 2, C2PSA, [1024]] # 11

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 14

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 17 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 14], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 20 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 11], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 23 (P5/32-large)

  - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

6.2运行成功截图

在这里插入图片描述

OK 以上就是添加MLCA注意力机制的全部过程了,后续将持续更新尽情期待

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2183352.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

简单的微信小程序登录 注册 页面及逻辑

一、示例 二、示例代码 1.wxml <!--pages/login.wxml--> <!-- 登录注册文字 --> <view class"title">{{TitleText}}</view> <!-- 登录框 --> <view class"inputBox"><input type"text" placeholder&qu…

Nature Machine Intelligence 基于强化学习的扑翼无人机机翼应变飞行控制

尽管无人机技术发展迅速&#xff0c;但复制生物飞行的动态控制和风力感应能力&#xff0c;仍然遥不可及。生物学研究表明&#xff0c;昆虫翅膀上有机械感受器&#xff0c;即钟形感受器campaniform sensilla&#xff0c;探测飞行敏捷性至关重要的复杂气动载荷。 近日&#xff0…

国庆普及模拟赛-1 赛后总结

题目链接&#xff1a; file:///D:/C/%E9%9B%86%E8%AE%AD%E6%B5%8B%E8%AF%95/1001/2022%20-%20J2.pdf T1&#xff1a;隔离 题意如图。需要求所有时间的最短。 思路&#xff1a; 不需要进行一次次枚举&#xff0c;先算出总共要办事的总时间sum&#xff0c;如果某一次时间超过2…

Mysql数据库~~条件查询、分页查询、修改操作

目录 1.表的其他操作 1.1创建一个表 1.2对于表的排序 1.3修改某一列的名字 1.4使用表达式 1.5删除列的重复项 1.6多个列进行排序 2.条件查询 2.1条件查询语句 2.2比较运算符 2.3条件查询展示 2.4条件查询的先后问题 2.5逻辑运算符使用 2.6模糊查询匹配 2.7对于nu…

【2022工业3D异常检测文献】BTF: 结合手工制作的3D描述和颜色特征的异常检测方法

BACK TO THE FEATURE: CLASSICAL 3D FEATURES ARE (ALMOST) ALL YOU NEED FOR 3D ANOMALY DETECTION 1、Background BTF(Back to the Feature)&#xff0c;一种 结合手工制作的3D表示&#xff08;FPFH&#xff09;和基于深度颜色特征提取&#xff08;PatchCore&#xff09; 的…

关于未知物检测设备和方法(测未知物成分含量)

未知物检测是一项涉及多个学科和技术的复杂工作&#xff0c;它对于新材料的研究、开发、生产以及质量控制具有重要意义。以下是一些常用的未知物检测方法和设备&#xff1a; 光谱分析&#xff1a;包括红外光谱&#xff08;IR&#xff09;、核磁共振&#xff08;NMR&#xff09;…

【Android 13源码分析】Activity生命周期之onCreate,onStart,onResume-2

忽然有一天&#xff0c;我想要做一件事&#xff1a;去代码中去验证那些曾经被“灌输”的理论。                                                                                  – 服装…

无源码实现免登录功能

因项目要求需要对一个没有源代码的老旧系统实现免登录功能&#xff0c;系统采用前后端分离的方式部署&#xff0c;登录时前端调用后台的认证接口&#xff0c;认证接口返回token信息&#xff0c;然后将token以json的方式存储到cookie中&#xff0c;格式如下&#xff1a; 这里有…

10月1日星期二今日早报简报微语报早读

10月1日星期二&#xff0c;国庆节&#x1f1e8;&#x1f1f3;&#xff0c;农历八月廿九&#xff0c;早报#微语早读。 1、A股暴涨刷新多项历史纪录&#xff1a;两市成交总额近2.6万亿元&#xff0c;创指涨逾15%&#xff1b; 2、文旅部&#xff1a;常年不超过最高承载量的旅游景…

Docker 安装 Citus 单节点集群:全面指南与详细操作

Docker 安装 Citus 单节点集群&#xff1a;全面指南与详细操作 文章目录 Docker 安装 Citus 单节点集群&#xff1a;全面指南与详细操作一 服务器资源二 部署图三 安装部署1 创建网络2 运行脚本1&#xff09;docker-compose.cituscd1.yml2&#xff09;docker-compose.cituswk1.…

zi2zi-chain: 中国书法字体图片生成和字体制作的一站式开发

在zi2zi-pytorch的基础上&#xff0c;做了进一步的修复和完善。本项目github对应网址为https://github.com/not-bald-owl/zi2zi-chain/tree/master。 修复部分为&#xff1a;针对预处理部分的函数弃用、生僻字无法生成、训练和推理部分单卡支持改为多卡并行、以及扩展从本地的…

过去8年,编程语言的流行度发生了哪些变化?PHP下降,Objective-C已过时

前天有一个汇总9个不同排名数据的“地表最强”编程语言排行榜&#xff0c;为了更好地理解语言流行度的变化&#xff0c;作者将2016年的类似调查结果与2024年的数据进行了比较。 虽然2016年的调查只包含6个排名&#xff0c;但它仍然提供了宝贵的参考数据。 我们来看看详细的情…

C++之String类(下)

片头 嗨喽~ 我们又见面啦&#xff0c;在上一篇C之String类&#xff08;上&#xff09;中&#xff0c;我们对string类的函数有了一个初步的认识&#xff0c;这一篇中&#xff0c;我们将继续学习string类的相关知识。准备好了吗&#xff1f;咱们开始咯~ 二、标准库中的string类 …

业务封装与映射 -- AMP BMP GMP

概述 不同单板支持不同的封装模式&#xff0c;主要包括: AMP (Asynchronous Mapping Procedure&#xff0c;异步映射规程)BMP (Bit-synchronous Mapping Procedure&#xff0c;比特同步映射规程)GMP (Generic Mapping Procedure&#xff0c;通用映射规程) AMP/BMP&#xff1a…

Qt_绘图

目录 1、绘图核心类 2、QPainter类的使用 2.1 绘制线段 2.2 绘制矩形 2.3 绘制圆形 2.4 绘制文本 3、QPen类的使用 3.1 使用画笔 4、QBrush类的使用 4.1 使用画刷 5、绘制图片 5.1 测试QPixmap 5.1.1 图片移动 5.1.2 图标缩小 5.1.3 旋转图片 5.1.4 将…

【逐行注释】MATLAB下的粒子滤波代码(三维状态与观测,可直接复制粘贴到MATLAB上面运行)

文章目录 程序设计1. 介绍2. 系统模型3. 算法步骤源代码(直接复制到MATLAB上面可以运行)运行结果程序设计 1. 介绍 粒子滤波(Particle Filter, PF)是一种基于贝叶斯理论的递归估计方法,广泛用于动态系统状态的估计和跟踪。该方法通过一组粒子(即假设的状态)及其权重来…

【Android 13源码分析】Activity生命周期之onCreate,onStart,onResume-1

忽然有一天&#xff0c;我想要做一件事&#xff1a;去代码中去验证那些曾经被“灌输”的理论。                                                                                  – 服装…

5款惊艳全网的AI写作论文神器!从此告别写作烦恼!

在当今的学术研究和写作领域&#xff0c;撰写高质量的论文是一项挑战性的任务。幸运的是&#xff0c;随着人工智能技术的发展&#xff0c;AI论文写作工具逐渐成为帮助学者和学生提高写作效率的重要工具。这些工具不仅能够提高写作效率&#xff0c;还能帮助研究者生成高质量的论…

ECharts 快速使用

最终效果 使用介绍 echarts图表的绘制&#xff0c;大体分为三步&#xff1a; 根据 DOM实例&#xff0c;通过 echarts.init方法&#xff0c;生成 echarts实例构建 options配置对象&#xff0c;整个echarts的样式&#xff0c;皆有该对象决定最后通过实例.setOption方法&#xf…

【测试-BUG篇】软件测试的BUG知识你了解多少呢?

文章目录 1. 软件测试的生命周期2. BUG3. BUG的生命周期4. 与开发人员起争执怎么办 1. 软件测试的生命周期 &#x1f34e;软件测试 贯穿整个软件的生命周期&#xff1b; &#x1f34e;软件测试的生命周期是指测试流程&#xff1b; ①需求分析 用户角度&#xff1a;软件需求是…