YOLO11改进|注意力机制篇|引入Mamba注意力机制MLLAttention

news2024/10/17 13:27:37

在这里插入图片描述

目录

    • 一、【MLLAttention】注意力机制
      • 1.1【MLLAttention】注意力介绍
      • 1.2【MLLAttention】核心代码
    • 二、添加【MLLAttention】注意力机制
      • 2.1STEP1
      • 2.2STEP2
      • 2.3STEP3
      • 2.4STEP4
      • 2.5STEP5
    • 三、yaml文件与运行
      • 3.1yaml文件
      • 3.2运行成功截图

一、【MLLAttention】注意力机制

1.1【MLLAttention】注意力介绍

在这里插入图片描述

下图是【MLLAttention】的结构图,让我们简单分析一下运行过程和优势

处理过程

  • MLLA Block:
  • MLLA 模块的核心是其线性注意力机制,输入首先通过一个 线性变换(Linear),对输入特征进行降维或转化。
  • 接着,线性注意力(Linear Attention) 被应用在变换后的特征上。线性注意力的主要特点是其计算复杂度低,能够快速处理长序列特征。这一层通过简单的线性变换替代了复杂的自注意力机制,提升了效率。
  • 之后,特征会进一步通过 卷积(Conv) 和其他线性层进行细粒度的特征提取。这些操作确保特征在局部范围内也能够被有效捕捉和处理。
  • 前馈网络(MLP):
  • 在经过 MLLA Block 的处理后,特征会经过一个多层感知器(MLP)网络。MLP 网络通过非线性激活函数(例如 ReLU 或 GELU)来增强模型的非线性表达能力。
  • 归一化(Norm) 层则用于对每一层输出的特征进行标准化处理,确保特征的稳定性,有助于模型更快收敛。
  • 残差连接:
  • 每个模块之间存在残差连接(Residual Connection),保证了在深层网络中,输入信息不会因为层数过深而消失。这种设计使得梯度更容易传播,提升了训练的稳定性和效率。
    优势
  • 计算效率高:
  • 线性注意力机制的引入极大地降低了计算复杂度。相比传统的自注意力(如 Transformer 中的全局自注意力)依赖矩阵乘法和大规模计算,线性注意力能够以更低的计算成本完成类似的操作,非常适合处理大规模数据或长序列输入。
  • 结合局部和全局特征:
  • 通过将卷积层和线性层结合,MLLA 模块能够同时处理局部和全局特征。卷积层擅长捕捉局部空间关系,而线性层则负责处理全局特征的交互,从而提升了模型对多维度信息的捕捉能力。
  • 增强的非线性表达能力:
  • MLP 部分通过多层非线性变换,增强了模型的表达能力。这使得模型不仅能够处理线性关系,还能够处理更复杂的特征模式,从而提升模型的预测精度。
  • 稳定的梯度传播:
  • 残差连接的使用确保了模型在深度网络中的训练稳定性,有助于缓解梯度消失问题。对于深层网络来说,这种设计可以有效提高收敛速度和模型性能。
    在这里插入图片描述

1.2【MLLAttention】核心代码

import torch
import torch.nn as nn
 
 
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
 
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
 
 
class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, dilation=1, groups=1,
                 bias=True, dropout=0, norm=nn.BatchNorm2d, act_func=nn.ReLU):
        super(ConvLayer, self).__init__()
        self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, kernel_size),
            stride=(stride, stride),
            padding=(padding, padding),
            dilation=(dilation, dilation),
            groups=groups,
            bias=bias,
        )
        self.norm = norm(num_features=out_channels) if norm else None
        self.act = act_func() if act_func else None
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.conv(x)
        if self.norm:
            x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x
 
 
class RoPE(torch.nn.Module):
 
    def __init__(self, base=10000):
        super(RoPE, self).__init__()
        self.base = base
 
    def generate_rotations(self, x):
        # 获取输入张量的形状
        *channel_dims, feature_dim = x.shape[1:-1][0], x.shape[-1]
        k_max = feature_dim // (2 * len(channel_dims))
 
        assert feature_dim % k_max == 0, "Feature dimension must be divisible by 2 * k_max"
 
        theta_ks = 1 / (self.base ** (torch.arange(k_max, dtype=x.dtype, device=x.device) / k_max))
        angles = torch.cat([t.unsqueeze(-1) * theta_ks for t in
                            torch.meshgrid([torch.arange(d, dtype=x.dtype, device=x.device) for d in channel_dims],
                                           indexing='ij')], dim=-1)
 
        rotations_re = torch.cos(angles).unsqueeze(dim=-1)
        rotations_im = torch.sin(angles).unsqueeze(dim=-1)
        rotations = torch.cat([rotations_re, rotations_im], dim=-1)
 
        return rotations
 
    def forward(self, x):
        rotations = self.generate_rotations(x)
 
        x_complex = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
 
        pe_x = torch.view_as_complex(rotations) * x_complex
 
        return torch.view_as_real(pe_x).flatten(-2)
 
 
class MLLAttention(nn.Module):
 
    def __init__(self, dim=3, input_resolution=[160, 160], num_heads=4, qkv_bias=True, **kwargs):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.elu = nn.ELU()
        self.lepe = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.rope = RoPE()
 
    def forward(self, x):
        x = x.reshape((x.size(0), x.size(2) * x.size(3), x.size(1)))
        b, n, c = x.shape
        h = int(n ** 0.5)
        w = int(n ** 0.5)
        # self.rope = RoPE(shape=(h, w, self.dim))
        num_heads = self.num_heads
        head_dim = c // num_heads
 
        qk = self.qk(x).reshape(b, n, 2, c).permute(2, 0, 1, 3)
        q, k, v = qk[0], qk[1], x
        # q, k, v: b, n, c
 
        q = self.elu(q) + 1.0
        k = self.elu(k) + 1.0
        q_rope = self.rope(q.reshape(b, h, w, c)).reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        k_rope = self.rope(k.reshape(b, h, w, c)).reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
 
        z = 1 / (q @ k.mean(dim=-2, keepdim=True).transpose(-2, -1) + 1e-6)
        kv = (k_rope.transpose(-2, -1) * (n ** -0.5)) @ (v * (n ** -0.5))
        x = q_rope @ kv * z
 
        x = x.transpose(1, 2).reshape(b, n, c)
        v = v.transpose(1, 2).reshape(b, h, w, c).permute(0, 3, 1, 2)
        x = x + self.lepe(v).permute(0, 2, 3, 1).reshape(b, n, c)
        x = x.transpose(2, 1).reshape((b, c, h, w))
        return x
 
    def extra_repr(self) -> str:
        return f'dim={self.dim}, num_heads={self.num_heads}'
 
 
if __name__ == "__main__":
    # Generating Sample image
    image_size = (1, 64, 160, 160)
    image = torch.rand(*image_size)
 
    # Model
    model = MLLAttention(64)
 
    out = model(image)
    print(out.size())

二、添加【MLLAttention】注意力机制

2.1STEP1

首先找到ultralytics/nn文件路径下新建一个Add-module的python文件包【这里注意一定是python文件包,新建后会自动生成_init_.py】,如果已经跟着我的教程建立过一次了可以省略此步骤,随后新建一个MLLAttention.py文件并将上文中提到的注意力机制的代码全部粘贴到此文件中,如下图所示在这里插入图片描述

2.2STEP2

在STEP1中新建的_init_.py文件中导入增加改进模块的代码包如下图所示在这里插入图片描述

2.3STEP3

找到ultralytics/nn文件夹中的task.py文件,在其中按照下图添加在这里插入图片描述

2.4STEP4

定位到ultralytics/nn文件夹中的task.py文件中的def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)函数添加如图代码,【如果不好定位可以直接ctrl+f搜索定位】

在这里插入图片描述

2.5STEP5

这个模块有点特殊,需要在moels/yolo/detect/val.py文件下做修改,不然验证的时候会出错
在这里插入图片描述

        return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=False, stride=self.stride)

注意替换后运行完这个模块后记得还原

三、yaml文件与运行

3.1yaml文件

以下是添加【MLLAttention】注意力机制在Backbone中的yaml文件,大家可以注释自行调节,效果以自己的数据集结果为准

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128,3,2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256,3,2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512,3,2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024,3,2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1,1,MLLAttention,[]]
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 14], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 11], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)

  - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

以上添加位置仅供参考,具体添加位置以及模块效果以自己的数据集结果为准

3.2运行成功截图

在这里插入图片描述

OK 以上就是添加【MLLAttention】注意力机制的全部过程了,后续将持续更新尽情期待

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2211825.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[Linux#66][TCP->IP] 面向字节流 | TCP异常 | filesocket | 网络层IP

目录 1. 面向字节流 思考:对于UDP协议来说,是否也存在“粘包问题”呢? 2.TCP 异常情况 3.知识 1.UDP实现可靠传输(经典面试题) 2. 网络抓包 | 爬虫 3.打通文件和 socket 的关系 4.网络层:IP 前置知识 1. 面向字节流 udp…

Java+vue部署版本反编译

🏆本文收录于《全栈Bug调优(实战版)》专栏,主要记录项目实战过程中所遇到的Bug或因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&am…

C++STL(2)

queue(队列) queue是一种先进先出的数据结构。 queue提供了一组函数来操作和访问元素,但它的功能相对较简单。 push(x):在队尾插入元素 x pop():弹出队首元素 front():返回队首元素 back():返回队尾元素 empty():检查队列是否为空 size0:返回队列中元素的个数 pri…

Android ViewModel

一问:ViewModel如何保证应用配置变化后能够自动继续存在,其原理是什么,ViewModel的生命周期和谁绑定的? ViewModel 的确能够在应用配置发生变化(例如屏幕旋转)后继续存在,这得益于 Android 系统的 ViewMod…

模拟电子电路基础(常见半导体+multisim学习1)

目录 1.半导体的基础 1.1.半导体基础知识 1.1.1本征半导体 1.1.2杂质半导体 1.1.3PN结 1.2半导体二极管 1.2.1半导体二极管的几种常见结构 1.2.2二极管的伏安特性曲线 1.2.3二极管的主要参数 1.2.4二级管的等效电路 1.2.5稳压二极管 1.2.其他类型二极管 2.multisim的…

双目视觉搭配YOLO实现3D测量

一、简介 双目(Stereo Vision)技术是一种利用两个相机来模拟人眼视觉的技术。通过对两个相机获取到的图像进行分析和匹配,可以计算出物体的深度信息。双目技术可以实现物体的三维重建、距离测量、运动分析等应用。 双目技术的原理是通过两…

Docker-nginx数据卷挂载

数据卷(volume)是一个虚拟目录,是容器内目录与宿主机目录之间映射的桥梁。 以Nginx为例,我们知道Nginx中有两个关键的目录: html:放置一些静态资源conf:放置配置文件 如果我们要让Nginx代理我们…

java项目之厨艺交流平台设计与实现(源码+文档)

项目简介 厨艺交流平台设计与实现实现了以下功能: 厨艺交流平台设计与实现的主要使用者管理员管理用户信息,可以添加,修改,删除用户信息信息。 💕💕作者:落落 💕💕个人…

分享一个从图片中提取色卡的实现

概述 最近在做“在线地图样式配置”的功能的时候,发现百度地图有个功能时上传一张图片,从图片中提取颜色并进行配图。本文就简单实现一下如何从图片中提取色卡。 效果 实现 实现思路 通过canvasdrawImage绘制图片,并通过getImageData获取…

主数据系统管理、运维的实践经验与建议

公司在预研一个新的主数据系统,领导问笔者给些建议。结合近两年的主数据系统管理、维护经验,给大致写了一些。 里面少数问题属于目前在运行的主数据系统的系统痛点所致,不过大多数笔者认为是通病,一口气写来已两千字,…

【验证码识别】Python+卷积神经网络算法+人工智能+深度学习+Django网页界面+计算机课设项目+TensorFlow+算法模型

一、介绍 验证码识别,使用Python作为开发语言,通过TensorFlow搭建CNN卷积神经网络算法模型,并通过对收集的几千张验证码图片作为数据集,然后进行迭代训练,最终得到一个识别精度较高的模型文件,然后使用Dja…

Cesium 区域高程图

Cesium 区域高程图 const terrainAnalyse new HeightMapMaterial({viewer,style: {stops: [0, 0.05, 0.5, 1],//颜色梯度设置colors: [green, yellow, blue , red],}});

JS 分支语句

目录 1. 表达式与语句 1.1 表达式 1.2 语句 1.3 区别 2. 程序三大流控制语句 3. 分支语句 3.1 if 分支语句 3.2 双分支 if 语句 3.3 双分支语句案例 3.3.1 案例一 3.3.2 案例二 3.4 多分支语句 1. 表达式与语句 1.1 表达式 1.2 语句 1.3 区别 2. 程序三大流控制语…

66 消息队列

66 消息队列 基础概念 参考资料:消息队列MQ快速入门(概念、RPC、MQ实质思路、队列介绍、队列对比、应用场景) 消息队列就是一个使用队列来通信的组件;为什么需要消息队列? 在实际的商业项目中,它这么做肯…

shell原理

shell 是个进程 , exe在user/bin/bash [用户名主机名 pwd] snprintf fflush(stdout),在没有\n情况下立马输出 strtok 第一个参数null表示传入上个有效参数 命令行中,有些命令必须由子进程执行, 如ls 有些…

【进阶OpenCV】 (11)--DNN板块--实现风格迁移

文章目录 DNN板块一、DNN特点二、DNN函数流程三、实现风格迁移1. 图像预处理2. 加载星空模型3. 输出处理 总结 DNN板块 DNN模块是 OpenCV 中专门用来实现 DNN(Deep Neural Networks,深度神经网络) 模块的相关功能,其作用是载入别的深度学习框架(如 TensorFlow、Caf…

考虑促销因素的医药电商平台需求预测研究

一、考虑促销因素的医药电商平台需求预测研究 一、引言 1. 互联网医疗健康的发展 内容:介绍了在互联网的大背景下,医疗健康行业如何迅速发展,举例了1药网和叮当快药等平台提供的服务。重点:互联网医疗用户规模和市场…

《人工智能(AI)和深度学习简史》

人工智能(AI)和深度学习在过去几十年里有了飞跃式的进步,彻底改变了像计算机视觉、自然语言处理、机器人这些领域。本文会带你快速浏览AI和深度学习发展的关键历史时刻,从最早的神经网络模型,一直到现在的大型语言模型…

【技术支持】家里智能电视不能联网重置小米路由器之路

问题现象 最近家里的路由器出现一点问题,现象是手机和电脑连接wifi后,都可以正常打开网页看视频。 但是小爱同学和小米盒子,都出现网络问题,不能正常播放音乐或者视频。 这是小米盒子的网络问题截图 这是和小米盒子连接的智能电…

骨架提取(持续更新)

一 什么是骨架提取 1.1 简介 骨架提取是图像处理或计算机视觉中的一种技术,用于从二值化图像中提取物体的中心线或轮廓,通常称为“骨架”或“细化图像”。这一技术主要用于简化形状表示,同时保留物体的拓扑结构。 这里我们强调了&#xff…