Mamba-yolo|结合Mamba注意力机制的视觉检测

news2025/1/17 23:10:47

一、本文介绍

 PDF地址:https://arxiv.org/pdf/2405.16605v1

代码地址:GitHub - LeapLabTHU/MLLA: Official repository of MLLA

Demystify Mamba in Vision: A Linear AttentionPerspective一文中引入Baseline Mamba,指明Mamba在处理各种高分辨率图像的视觉任务有着很好的效率。发现了强大的Mamba和线性注意力Transformer( linear attention Transformer)非常相似,然后就分析了两者之间的异同。将Mamba模型重述为linear attention Transformer的变体,并且主要有六大差异,分别是:input gate, forget gate,shortcut, no attention normalization, single-head, and modified block design。作者对每个设计都细致的分析了优缺点,评估了性能,最终发现forget gate和block design是Mamba这么给力的主要贡献点。基于以上发现,作者提出了一个类似mamba的线性注意力模型,Mamba-Like Linear Attention (MLLA) ,相当于取其精华,去其糟粕,把mamba两个最为关键的优点设计结合到线性注意力模型当中,具有可并行计算和快速推理的特点。本文将结合YOlOV8检测模型通过添加MLLA模块提升检测精度。

二、宏观架构设计

线性注意 Transformer 模型通常采用图 (a) 中的设计,它由线性注意力模块和 MLP 模块组成。相比之下,Mamba 通过结合 H3和 Gated Attention这两个设计来改进,得到如图 (b) 所示的架构。改进的 Mamba Block 集成了多种操作,例如选择性 SSM、深度卷积、线性映射、激活函数、门控机制等,并且往往比传统的 Transformer 设计更有效。

MLLA (Mamba-Like Linear Attention)的则是通过将Mamba模型的一些核心设计融入线性注意力机制,从而提升模型的性能。具体来说,MLLA主要整合了Mamba中的"忘记门”(forget gate9)和模块设计(block design)这两个关键因素,这些因素被认为是Mamba成功的主要原因。
以下是对MLLA原理的详细分析:
1.忘记门(Forget Gate)
1.忘记门提供了局部偏差和位置信息。所有的忘记门元素严格限制在0到1之间,这意味着模型在接收到当前输入后会持续衰减失前的隐藏状态。这种特性确保了模型对输入序列的顺序敏感。
2.忘记门的局部偏差和位置信息对于图像处理任务来说非常重要,尽管引入忘记门会导致计算需要采用递归的形式,从而降低并行计算的效率。
2.模块设计(Block Design)
1.Mamba的模块设计在保持相似的浮点运算次数(FLOPS)的同时,通过替换注意力子模块为线性注意力来提升性能。结果表明,采用这种模块设计能够显著提高模型的表现。
3.线性注意力的改进:
1.线性注意力被重新设计以整合忘记门和模块设计,这种改进后的模型被称为MLLA。实验结果显示,MLLA在图像分类和高分辨率密集预测任务中均优于各种视觉Mamba模型
4.并行计算和快速推理速度:
1.MLLA通过使用位置编码(ROPE)来替代忘记门,从而在保持并行计算和快速推理速度的同时,提供必要的位置信息。这使得MLLA在处理非自回归的视觉任务时更加有效

结合yolov8改进

核心代码
 

import torch
import torch.nn as nn
 
__all__ = ['MLLAttention']
 
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
 
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
 
 
class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, dilation=1, groups=1,
                 bias=True, dropout=0, norm=nn.BatchNorm2d, act_func=nn.ReLU):
        super(ConvLayer, self).__init__()
        self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, kernel_size),
            stride=(stride, stride),
            padding=(padding, padding),
            dilation=(dilation, dilation),
            groups=groups,
            bias=bias,
        )
        self.norm = norm(num_features=out_channels) if norm else None
        self.act = act_func() if act_func else None
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.conv(x)
        if self.norm:
            x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x
 
 
class RoPE(torch.nn.Module):
    r"""Rotary Positional Embedding.
    """
 
    def __init__(self, base=10000):
        super(RoPE, self).__init__()
        self.base = base
 
    def generate_rotations(self, x):
        # 获取输入张量的形状
        *channel_dims, feature_dim = x.shape[1:-1][0], x.shape[-1]
        k_max = feature_dim // (2 * len(channel_dims))
 
        assert feature_dim % k_max == 0, "Feature dimension must be divisible by 2 * k_max"
 
        # 生成角度
        theta_ks = 1 / (self.base ** (torch.arange(k_max, dtype=x.dtype, device=x.device) / k_max))
        angles = torch.cat([t.unsqueeze(-1) * theta_ks for t in
                            torch.meshgrid([torch.arange(d, dtype=x.dtype, device=x.device) for d in channel_dims],
                                           indexing='ij')], dim=-1)
 
        # 计算旋转矩阵的实部和虚部
        rotations_re = torch.cos(angles).unsqueeze(dim=-1)
        rotations_im = torch.sin(angles).unsqueeze(dim=-1)
        rotations = torch.cat([rotations_re, rotations_im], dim=-1)
 
        return rotations
 
    def forward(self, x):
        # 生成旋转矩阵
        rotations = self.generate_rotations(x)
 
        # 将 x 转换为复数形式
        x_complex = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
 
        # 应用旋转矩阵
        pe_x = torch.view_as_complex(rotations) * x_complex
 
        # 将结果转换回实数形式并展平最后两个维度
        return torch.view_as_real(pe_x).flatten(-2)
 
 
class MLLAttention(nn.Module):
    r""" Linear Attention with LePE and RoPE.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
    """
 
    def __init__(self, dim=3, input_resolution=[160, 160], num_heads=4, qkv_bias=True, **kwargs):
 
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.elu = nn.ELU()
        self.lepe = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.rope = RoPE()
 
    def forward(self, x):
        """
        Args:
            x: input features with shape of (B, N, C)
        """
        x = x.reshape((x.size(0), x.size(2) * x.size(3), x.size(1)))
        b, n, c = x.shape
        h = int(n ** 0.5)
        w = int(n ** 0.5)
        # self.rope = RoPE(shape=(h, w, self.dim))
        num_heads = self.num_heads
        head_dim = c // num_heads
 
        qk = self.qk(x).reshape(b, n, 2, c).permute(2, 0, 1, 3)
        q, k, v = qk[0], qk[1], x
        # q, k, v: b, n, c
 
        q = self.elu(q) + 1.0
        k = self.elu(k) + 1.0
        q_rope = self.rope(q.reshape(b, h, w, c)).reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        k_rope = self.rope(k.reshape(b, h, w, c)).reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
 
        z = 1 / (q @ k.mean(dim=-2, keepdim=True).transpose(-2, -1) + 1e-6)
        kv = (k_rope.transpose(-2, -1) * (n ** -0.5)) @ (v * (n ** -0.5))
        x = q_rope @ kv * z
 
        x = x.transpose(1, 2).reshape(b, n, c)
        v = v.transpose(1, 2).reshape(b, h, w, c).permute(0, 3, 1, 2)
        x = x + self.lepe(v).permute(0, 2, 3, 1).reshape(b, n, c)
        x = x.transpose(2, 1).reshape((b, c, h, w))
        return x
 
    def extra_repr(self) -> str:
        return f'dim={self.dim}, num_heads={self.num_heads}'
 
 
if __name__ == "__main__":
    # Generating Sample image
    image_size = (1, 64, 160, 160)
    image = torch.rand(*image_size)
 
    # Model
    model = MLLAttention(64)
 
    out = model(image)
    print(out.size())

修改一

第一还是建立文件,我们找到如下ultralvtics/n文件夹下建立一个目录名字呢就是'Addmodules文件夹(用群内的文件的话已经有了无需新建)!然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。

修改二

第二步我们在该目录下创建一个新的py文件名字为'  __init__ .py,然后在其内部导入我们的检测头如
下图所示。

修改三 

第三步我门中到如下文件uitralytics/nn/tasks.py进行导入和注册我们的模块

修改四

按照我的添加在parse model里添加即可。

修改5

修改6 配置yolov8-MLLA.yaml文件

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
 
# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOP
 
# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9
 
# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12
 
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
 
 
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)
 
 
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)
  - [-1, 1, MLLAttention, []]  # 22 (P5/32-large) # 添加在大目标检测层后!
 
  - [[15, 18, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

7. 训练代码

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO
 
if __name__ == '__main__':
    model = YOLO('yolov8-MLLA.yaml')
    # 如何切换模型版本, 上面的ymal文件可以改为 yolov8s.yaml就是使用的v8s,
    # 类似某个改进的yaml文件名称为yolov8-XXX.yaml那么如果想使用其它版本就把上面的名称改为yolov8l-XXX.yaml即可(改的是上面YOLO中间的名字不是配置文件的)!
    # model.load('yolov8n.pt') # 是否加载预训练权重,科研不建议大家加载否则很难提升精度
    model.train(data=r"C:\Users\Administrator\PycharmProjects\yolov5-master\yolov5-master\Construction Site Safety.v30-raw-images_latestversion.yolov8\data.yaml",
                # 如果大家任务是其它的'ultralytics/cfg/default.yaml'找到这里修改task可以改成detect, segment, classify, pose
                cache=False,
                imgsz=640,
                epochs=150,
                single_cls=False,  # 是否是单类别检测
                batch=16,
                close_mosaic=0,
                workers=0,
                device='0',
                optimizer='SGD', # using SGD
                # resume='runs/train/exp21/weights/last.pt', # 如过想续训就设置last.pt的地址
                amp=True,  # 如果出现训练损失为Nan可以关闭amp
                project='runs/train',
                name='exp',
                )

8.开启训练

专栏推荐

专栏将持续收集整理市场上深度学习的相关项目,旨在为准备从事深度学习工作或相关科研活动的伙伴,储备、提升更多的实际开发经验,每个项目实例都可作为实际开发项目写入简历,且都附带完整的代码与数据集。可通过百度云盘进行获取,实现开箱即用

正在跟新中~

深度学习落地实战_机 _ 长的博客-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1943435.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

零基础入门:创建一个简单的Python爬虫管理系统

摘要: 本文将手把手教你,从零开始构建一个简易的Python爬虫管理系统,无需编程基础,轻松掌握数据抓取技巧。通过实战演练,你将学会设置项目、编写基本爬虫代码、管理爬取任务与数据,为个人研究或企业需求奠…

回溯题目的套路总结

前言 昨天写完了LeeCode的7,8道回溯算法的题目,写一下总结,这类题目的共同特点就是暴力搜索问题,排列组合或者递归,枚举出所有可能的答案,思路很简单,实现起来的套路也很通用,一…

poi库简单使用(java如何实现动态替换模板Word内容)

目录 Blue留言: Blue的推荐: 什么是poi库? 实现动态替换 第一步:依赖 第二步:实现word模板中替换文字 模板word: 通过以下代码:(自己建一个类,随意取名&#xf…

SpringBoot框架学习笔记(五):静态资源访问、Rest风格请求处理、配置视图解析器、接收参数的相关注解详解

1 WEB开发-静态资源访问 1.1 基本介绍 (1)只要静态资源放在类路径的以下目录:/static、/public、/resources、/META-INF/resources 可以被直接访问。maven项目的类路径即为main/resources目录--对应SpringBoot源码为WebProperties.java类 …

nginx如何开启优先访问压缩文件

nginx输出gzip有很多条件: 开启了gzip:gzip on;gzip_types定义了content-type,需要注意的是text/html是强制性的,不需要也不能再添加这个响应输出的content-type在gzip_types里输出的content-length大于等于nginx配置的gzip_min_…

【TypeScript 一点点教程】

文章目录 一、开发环境搭建二、基本类型2.1 类型声明2.2 基本类型 三、编译3.1 tsc命令3.2 tsconfig.json3.2.1 基本配置项includeexcludeextendsfiles 3.2.2 compilerOptions编译器的配置项 四、面向对象4.1 类4.2 继承4.3 抽象类4.4 接口 一、开发环境搭建 下载Node.js《Nod…

操作系统——进程与线程(死锁)

1)为什么会产生死锁?产生死锁有什么条件? 2)有什么办法解决死锁? 一、死锁 死锁:多个程序因竞争资源而造成的一种僵局(互相等待对方手里的资源),使得各个进程都被阻塞,…

02.C++入门基础(下)

1.函数重载 C支持在同一作用域中出现同名函数,但是要求这些同名函数的形参不同,可以是参数个数不同或者类型不同。这样C函数调用就表现出了多态行为,使用更灵活。C语言是不支持同一作用域中出现同名函数的。 1、参数类型不同 2、参数个数不同…

volatile,最轻量的同步机制

目录 一、volatile 二、如何使用? 三、volatile关键字能代替synchronized关键字吗? 四、总结: 还是老样子,先来看一段代码: 我们先由我们自己的常规思路分析一下代码:子线程中,一直循环&…

DocRED数据集

DocRED数据集文件夹包含多个JSON文件,每个文件都有不同的用途。以下是这些文件的用途解释以及哪个文件是训练集: 文件解释 dev.json:包含开发集(验证集)的数据,通常用于模型调优和选择超参数。 label_map…

java面向对象进阶进阶篇--《包和final》

一、前言 今天还是面向对象相关知识点的分享,包是写小型项目时不可或缺的存在,final关键字用的地方不算太多。idea会提示我们导包,有时会自动导包,确实十分方便。但是我们也不能不会自己去导包。 面向对象篇不出意外的话本周就要…

【线性代数】矩阵变换

一些特殊的矩阵 一,对角矩阵 1,什么是对角矩阵 表示将矩阵进行伸缩(反射)变换,仅沿坐标轴方向伸缩(反射)变换。 2,对角矩阵可分解为多个F1矩阵,如下: 二&a…

python打包exe文件-实现记录

1、使用pyinstaller库 安装库: pip install pyinstaller打包命令标注主入库程序: pyinstaller -F.\程序入口文件.py 出现了一个问题就是我在打包运行之后会出现有一些插件没有被打包。 解决问题: 通过添加--hidden-importcomtypes.strea…

“微软蓝屏”事件引发的深度思考:网络安全与系统稳定性的挑战与应对

“微软蓝屏”事件暴露了网络安全哪些问题? 近日,一次由微软视窗系统软件更新引发的全球性“微软蓝屏”事件,不仅成为科技领域的热点新闻,更是一次对全球IT基础设施韧性与安全性的深刻检验。这次事件,源于美国电脑安全…

【Vue3】工程创建及目录说明

【Vue3】工程创建及目录说明 背景简介开发环境开发步骤及源码 背景 随着年龄的增长,很多曾经烂熟于心的技术原理已被岁月摩擦得愈发模糊起来,技术出身的人总是很难放下一些执念,遂将这些知识整理成文,以纪念曾经努力学习奋斗的日…

全网最全最详细的C++23 标准详解:核心语言改进与新特性

1. 简介 C23 是由 C 标准委员会最新发布的标准,旨在进一步提升 C 语言的功能和开发效率。作为一项重要的编程语言标准更新,C23 引入了多个关键的新特性和改进,使开发者能够编写更高效、简洁和安全的代码。 与 C20 相比,C23 的变…

3112.力扣每日一题7/18 Java 迪杰斯特拉(Dijkstra)算法

博客主页:音符犹如代码系列专栏:算法练习关注博主,后期持续更新系列文章如果有错误感谢请大家批评指出,及时修改感谢大家点赞👍收藏⭐评论✍ 目录 迪杰斯特拉(Dijkstra)算法 解题思路 解题过…

C++学习指南(三)——模板

欢迎来到繁星的CSDN。本期内容主要包括模板template。 目录 一、什么是模板? 二、函数模板 模板的定义方式 模板的实例化(确定参数的类型) 隐式实例化 显式实例化 实例化顺序 三、类模板和模板类 类模板的实例化 一、什么是模板&#xff1…

智慧职校就业管理:开启校园招聘会新模式

在智慧职校的就业管理系统中,校园招聘会的出现,为学生们提供了一个展示自我、探寻职业道路的舞台,同时也为企业搭建了一座直面未来之星的桥梁。这一功能,凭借其独特的优势与前沿的技术,正在重新定义校园与职场之间的过…

2024中国大学生算法设计超级联赛(1)

🚀欢迎来到本文🚀 🍉个人简介:陈童学哦,彩笔ACMer一枚。 🏀所属专栏:杭电多校集训 本文用于记录回顾总结解题思路便于加深理解。 📢📢📢传送门 A - 循环位移解…