YOLO11改进|注意力机制篇|引入并行分块注意力机制PPA

news2024/10/7 9:34:14

在这里插入图片描述

目录

    • 一、【PPA】注意力机制
      • 1.1【PPA】注意力介绍
      • 1.2【PPA】核心代码
    • 二、添加【PPA】注意力机制
      • 2.1STEP1
      • 2.2STEP2
      • 2.3STEP3
      • 2.4STEP4
    • 三、yaml文件与运行
      • 3.1yaml文件
      • 3.2运行成功截图

一、【PPA】注意力机制

1.1【PPA】注意力介绍

在这里插入图片描述

下图是PPA并行分块注意力机制结构图,大致分析一下工作流程和优势

  • 工作流程

  • 1.1 Patch-Aware 分支
    输入的特征图大小为 𝐻′×𝑊′×𝐶。特征图通过 Pointwise Conv (PW Conv),它是一个 1×1 卷积,用来处理通道维度,改变每个通道上的信息。接下来,进入两个并行的 Patch-Aware 模块,分别使用不同的感受野大小 𝑝=2p=2 和 𝑝=4。这两个并行分支是为了捕获不同尺度下的局部信息。Patch-Aware 模块 通过以下步骤处理输入:
    首先,输入的特征图被通过一个 Unfold 操作,这相当于将输入特征图划分成多个 𝑝×𝑝的局部块。然后对这些局部块进行求平均操作(Mean),获得局部感知特征。接着通过一个全连接层 FFN 和 Softmax 操作,为每个局部块生成注意力权重,最后进行 Feature Selection,即选择最重要的局部信息用于后续处理。

  • 1.2 Attention 模块
    Patch-Aware 模块 的输出与卷积层的输出进行相加,之后进入 Attention 模块 进行进一步处理。Attention 模块 包含 Channel Attention 和 Spatial Attention 两个部分:Channel Attention 计算每个通道的重要性,通过一个 1×1的卷积来生成权重,并与原特征进行逐元素相乘。Spatial Attention 则是关注空间维度的特征,通过计算每个空间位置的注意力权重来进一步加强特征。

  • 1.3 输出
    最终的特征图经过 Attention 模块后,生成的输出大小为 𝐻′×𝑊′×𝐶′,可以用于下游任务如分类或目标检测。

  • 优势分析

  • 2.1 多尺度感知
    PPA 模块通过并行的 Patch-Aware 机制引入了不同的感受野,分别捕捉到小块和大块的局部特征。这种多尺度的感知能力使得模型能够同时关注局部细节和全局上下文信息,有助于处理多样化的输入数据。

  • 2.2 强化特征选择
    通过 Feature Selection,模块能够有效地选择重要的局部特征,并通过 Attention 机制在通道和空间维度进一步增强特征。这样能够显著提高模型的特征表达能力,尤其是在处理复杂场景时。

  • 2.3 并行计算加速
    该模块采用并行的 Patch-Aware 分支,同时处理不同尺度的信息,能够提高计算效率。与串行处理相比,并行结构使得模型在不增加太多计算负担的前提下,获得更丰富的特征表示。

  • 2.4 通道和空间的自适应权重
    Attention 模块提供了自适应的通道权重和空间权重,能够根据输入数据的不同动态调整特征的重要性。这种自适应性能够增强模型在处理不同任务时的灵活性,使其能够对特定任务或数据分布更具鲁棒性。

在这里插入图片描述

1.2【PPA】核心代码

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialAttentionModule(nn.Module):
    def __init__(self):
        super(SpatialAttentionModule, self).__init__()
        self.conv2d = nn.Conv2d(
            in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avgout, maxout], dim=1)
        out = self.sigmoid(self.conv2d(out))
        return out * x


class LocalGlobalAttention(nn.Module):
    def __init__(self, output_dim, patch_size):
        super().__init__()
        self.output_dim = output_dim
        self.patch_size = patch_size
        self.mlp1 = nn.Linear(patch_size * patch_size, output_dim // 2)
        self.norm = nn.LayerNorm(output_dim // 2)
        self.mlp2 = nn.Linear(output_dim // 2, output_dim)
        self.conv = nn.Conv2d(output_dim, output_dim, kernel_size=1)
        self.prompt = torch.nn.parameter.Parameter(
            torch.randn(output_dim, requires_grad=True)
        )
        self.top_down_transform = torch.nn.parameter.Parameter(
            torch.eye(output_dim), requires_grad=True
        )

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        B, H, W, C = x.shape
        P = self.patch_size

        # Local branch
        local_patches = x.unfold(1, P, P).unfold(2, P, P)  # (B, H/P, W/P, P, P, C)
        local_patches = local_patches.reshape(B, -1, P * P, C)  # (B, H/P*W/P, P*P, C)
        local_patches = local_patches.mean(dim=-1)  # (B, H/P*W/P, P*P)

        local_patches = self.mlp1(local_patches)  # (B, H/P*W/P, input_dim // 2)
        local_patches = self.norm(local_patches)  # (B, H/P*W/P, input_dim // 2)
        local_patches = self.mlp2(local_patches)  # (B, H/P*W/P, output_dim)

        local_attention = F.softmax(local_patches, dim=-1)  # (B, H/P*W/P, output_dim)
        local_out = local_patches * local_attention  # (B, H/P*W/P, output_dim)

        cos_sim = F.normalize(local_out, dim=-1) @ F.normalize(
            self.prompt[None, ..., None], dim=1
        )  # B, N, 1
        mask = cos_sim.clamp(0, 1)
        local_out = local_out * mask
        local_out = local_out @ self.top_down_transform

        # Restore shapes
        local_out = local_out.reshape(
            B, H // P, W // P, self.output_dim
        )  # (B, H/P, W/P, output_dim)
        local_out = local_out.permute(0, 3, 1, 2)
        local_out = F.interpolate(
            local_out, size=(H, W), mode="bilinear", align_corners=False
        )
        output = self.conv(local_out)

        return output


class ECA(nn.Module):
    def __init__(self, in_channel, gamma=2, b=1):
        super(ECA, self).__init__()
        k = int(abs((math.log(in_channel, 2) + b) / gamma))
        kernel_size = k if k % 2 else k + 1
        padding = kernel_size // 2
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.conv = nn.Sequential(
            nn.Conv1d(
                in_channels=1,
                out_channels=1,
                kernel_size=kernel_size,
                padding=padding,
                bias=False,
            ),
            nn.Sigmoid(),
        )

    def forward(self, x):
        out = self.pool(x)
        out = out.view(x.size(0), 1, x.size(1))
        out = self.conv(out)
        out = out.view(x.size(0), x.size(1), 1, 1)
        return out * x


class conv_block(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding=(1, 1),
        dilation=(1, 1),
        norm_type="bn",
        activation=True,
        use_bias=True,
        groups=1,
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=use_bias,
            groups=groups,
        )

        self.norm_type = norm_type
        self.act = activation

        if self.norm_type == "gn":
            self.norm = nn.GroupNorm(
                32 if out_features >= 32 else out_features, out_features
            )
        if self.norm_type == "bn":
            self.norm = nn.BatchNorm2d(out_features)
        if self.act:
            # self.relu = nn.GELU()
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.conv(x)
        if self.norm_type is not None:
            x = self.norm(x)
        if self.act:
            x = self.relu(x)
        return x


class PPA(nn.Module):
    def __init__(self, in_features, filters) -> None:
        super().__init__()

        self.skip = conv_block(
            in_features=in_features,
            out_features=filters,
            kernel_size=(1, 1),
            padding=(0, 0),
            norm_type="bn",
            activation=False,
        )
        self.c1 = conv_block(
            in_features=in_features,
            out_features=filters,
            kernel_size=(3, 3),
            padding=(1, 1),
            norm_type="bn",
            activation=True,
        )
        self.c2 = conv_block(
            in_features=filters,
            out_features=filters,
            kernel_size=(3, 3),
            padding=(1, 1),
            norm_type="bn",
            activation=True,
        )
        self.c3 = conv_block(
            in_features=filters,
            out_features=filters,
            kernel_size=(3, 3),
            padding=(1, 1),
            norm_type="bn",
            activation=True,
        )
        self.sa = SpatialAttentionModule()
        self.cn = ECA(filters)
        self.lga2 = LocalGlobalAttention(filters, 2)
        self.lga4 = LocalGlobalAttention(filters, 4)

        self.bn1 = nn.BatchNorm2d(filters)
        self.drop = nn.Dropout2d(0.1)
        self.relu = nn.ReLU()

        self.gelu = nn.GELU()

    def forward(self, x):
        x_skip = self.skip(x)
        x_lga2 = self.lga2(x_skip)
        x_lga4 = self.lga4(x_skip)
        x1 = self.c1(x)
        x2 = self.c2(x1)
        x3 = self.c3(x2)
        x = x1 + x2 + x3 + x_skip + x_lga2 + x_lga4
        x = self.cn(x)
        x = self.sa(x)
        x = self.drop(x)
        x = self.bn1(x)
        x = self.relu(x)
        return x


if __name__ == "__main__":
    block = PPA(64, 64)
    input = torch.rand(3, 64, 128, 128)
    output = block(input)
    print(input.size())
    print(output.size())



二、添加【PPA】注意力机制

2.1STEP1

首先找到ultralytics/nn文件路径下新建一个Add-module的python文件包【这里注意一定是python文件包,新建后会自动生成_init_.py】,如果已经跟着我的教程建立过一次了可以省略此步骤,随后新建一个PPA.py文件并将上文中提到的注意力机制的代码全部粘贴到此文件中,如下图所示在这里插入图片描述

2.2STEP2

在STEP1中新建的_init_.py文件中导入增加改进模块的代码包如下图所示在这里插入图片描述

2.3STEP3

找到ultralytics/nn文件夹中的task.py文件,在其中按照下图添加在这里插入图片描述

2.4STEP4

定位到ultralytics/nn文件夹中的task.py文件中的def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)函数添加如图代码,【如果不好定位可以直接ctrl+f搜索定位】

在这里插入图片描述

三、yaml文件与运行

3.1yaml文件

以下是添加【PPA】注意力机制在最小层检测头中的yaml文件,大家可以注释自行调节,效果以自己的数据集结果为准

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1,1,PPA,[256]]

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)



  - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

以上添加位置仅供参考,具体添加位置以及模块效果以自己的数据集结果为准 ,同时中目标检测层中参数为[512],大目标检测层中参数为[1024]

3.2运行成功截图

在这里插入图片描述

OK 以上就是添加【PPA】注意力机制的全部过程了,后续将持续更新尽情期待

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2193895.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Taipy:AI应用全栈开发神器

Taipy 是一个免费使用的 Python 库,任何具备基本 Python 技能的人都可以使用。它是数据科学家、机器学习工程师和 Python 程序员的得力工具。使用 Taipy,你可以轻松地将数据和机器学习模型转换为功能齐全的 Web 应用程序。在我们生活的瞬息万变的世界中&…

消费者Rebalance机制

优质博文:IT-BLOG-CN 一、消费者Rebalance机制 在Apache Kafka中,消费者组 Consumer Group会在以下几种情况下发生重新平衡Rebalance: 【1】消费者加入或离开消费者组: 当一个新的消费者加入消费者组或一个现有的消费者离开消费…

springboot工程中使用tcp协议

文章目录 一、概述二、实现思路三、代码结构四、代码放送五、运行界面六. 主要技术点 一、概述 在上文JAVA TCP协议初体验 中,我们使用java实现了tcp协议的一个雏形,实际中大部分项目都已采用springboot,那么,怎么在springboot中…

【机器学习】知识总结1(人工智能、机器学习、深度学习、贝叶斯、回归分析)

目录 一、机器学习、深度学习 1.人工智能 1.1人工智能概念 1.2人工智能的主要研究内容与应用领域 1.2.1主要研究内容: 1.2.2应用领域 2.机器学习 2.1机器学习的概念 2.2机器学习的基本思路 2.3机器学习的分类 3.深度学习 3.1深度学习的概念 3.2人工智能…

Cocos_鼠标滚轮放缩地图

文章目录 前言一、环境二、版本一_code2.分析类属性方法详细分析详细分析onLoad()onMouseWheel(event)详细分析 总结 前言 学习笔记,请多多斧正。 一、环境 通过精灵rect放置脚本实现鼠标滚轮放缩地图。 二、版本一_code import { _decorator, Component, Node }…

task【XTuner微调个人小助手认知】

1 微调前置基础 本节主要重点是带领大家实现个人小助手微调,如果想了解微调相关的基本概念,可以访问XTuner微调前置基础。 2 准备工作 环境安装:我们想要用简单易上手的微调工具包 XTuner 来对模型进行微调的话,第一步是安装 XTu…

Trie树之最大异或对问题

这是C算法基础-数据结构专栏的第二十八篇文章,专栏详情请见此处。 从这篇博客开始,文章将会于每周一更新,望周知! 引入 上次,我们学习了Trie树之字符串统计问题,字符串统计问题中的Trie树节点存储的是字符…

面试官:如何实现分布式系统的限流?

限流的概念以及作用我前一篇文章已经做了介绍:并发限流算法的实践 目录 限流的几种算法 : 1、令牌桶算法 2、漏桶算法 3. 滑动时间窗口计数器算法 5. 全局限流 6. 客户端限流 7. API网关限流 8. 熔断与降级 本篇重点: 具体实现: 限流的几种算法 : 这里主要讲在分…

快速熟悉Nginx

一、Nginx是什么? ‌Nginx是一款高性能、轻量级的Web服务器和反向代理服务器。‌ ‌特点‌:Nginx采用事件驱动的异步非阻塞处理框架,内存占用少,并发能力强,资源消耗低。‌功能‌:Nginx主要用作静态文件服…

Arduino UNO R3自学笔记22 之 Arduino基础篇学习总结

注意:学习和写作过程中,部分资料搜集于互联网,如有侵权请联系删除。 前言:目前将Arduino的大多数基础内容学习了,做个总结。 1.编程语言 学习单片机,在面向单片机编程时,语言是最基础的&#…

给Linux操作系统命令取个别名

一个Linux终端命令的别名通常是其命令的缩写,用来减少键盘输入。命令格式为: alias [alias-name‘original-command’] 其中,alias-name是用户给命令取的别名(新名),original-comm…

whisper 实现语音识别 ASR - python 实现

语音识别(Speech Recognition),同时称为自动语音识别(英语:Automatic Speech Recognition, ASR),将语音音频转换为文字的技术。 whisper是一个通用的语音识别模型,由OpenAI公司开发。…

浸没边界 直接强迫法 圆球绕流验证 阅读笔记

Combined multi-direct forcing and immersed boundary method for simulating flows with moving particles https://doi.org/10.1016/j.ijmultiphaseflow.2007.10.004 他的意思是,不止需要一次的直接强迫 直接强迫的次数与误差成低于二阶的关系 不知道是不是一阶…

输电线路悬垂线夹检测无人机航拍图像数据集,总共1600左右图片,悬垂线夹识别,标注为voc格式

输电线路悬垂线夹检测无人机航拍图像数据集,总共1600左右图片,悬垂线夹识别,标注为voc格式 输电线路悬垂线夹检测无人机航拍图像数据集介绍 数据集名称 输电线路悬垂线夹检测数据集 (Transmission Line Fittings Detection Dataset) 数据集…

sv标准研读第十二章-过程性编程语句

书接上回: sv标准研读第一章-综述 sv标准研读第二章-标准引用 sv标准研读第三章-设计和验证的building block sv标准研读第四章-时间调度机制 sv标准研读第五章-词法 sv标准研读第六章-数据类型 sv标准研读第七章-聚合数据类型 sv标准研读第八章-class sv标…

使用链地址法实现哈希表(哈希函数为除留余数法)

该代码实现了一个哈希表,使用拉链法(链地址法)来解决哈希冲突,核心思想是通过链表存储哈希冲突的数据。哈希表的大小被设置为 MAX_SIZE,其中哈希函数采用除留余数法。以下是代码的详细解释和总结: #includ…

C++关于链表基础知识

单链表 // 结点的定义 template <class T> struct Node { T data ; Node <T> *next; //指向下一个node 的类型与本node相同 } // 最后一个node指针指向Null 生成结点&#xff1a; Node <T> * p new Node < T>; 为结点赋值: p-> data …

LLM+知识图谱新工具! iText2KG:使用大型语言模型构建增量知识图谱

iText2KG是一个基于大型语言模型的增量知识图谱构建工具&#xff0c;通过从文本文档中提取实体和关系来逐步构建知识图谱。该工具具有零样本学习能力&#xff0c;能够在无需特定训练的情况下&#xff0c;在多个领域中进行知识提取。它包括文档提炼、实体提取和关系提取模块&…

BM1 反转链表

要求 代码 /*** struct ListNode {* int val;* struct ListNode *next;* };*/ /*** 代码中的类名、方法名、参数名已经指定&#xff0c;请勿修改&#xff0c;直接返回方法规定的值即可*** param head ListNode类* return ListNode类*/ struct ListNode* ReverseList(struct …

【LeetCode-热题100-128题】官方题解好像有误

最长连续序列 题目链接&#xff1a;https://leetcode.cn/problems/longest-consecutive-sequence/?envTypestudy-plan-v2&envIdtop-100-liked 给定一个未排序的整数数组 nums &#xff0c;找出数字连续的最长序列&#xff08;不要求序列元素在原数组中连续&#xff09;的…