YOLO11改进|注意力机制篇|引入局部注意力HaloAttention

news2024/12/24 8:16:22

在这里插入图片描述

目录

    • 一、【HaloAttention】注意力机制
      • 1.1【HaloAttention】注意力介绍
      • 1.2【HaloAttention】核心代码
    • 二、添加【HaloAttention】注意力机制
      • 2.1STEP1
      • 2.2STEP2
      • 2.3STEP3
      • 2.4STEP4
    • 三、yaml文件与运行
      • 3.1yaml文件
      • 3.2运行成功截图

一、【HaloAttention】注意力机制

1.1【HaloAttention】注意力介绍

在这里插入图片描述

下图是【HaloAttention】的结构图,让我们简单分析一下运行过程和优势

处理过程

  • 图像分块:

  • 输入图像大小为 4×4×𝑐,其中 𝑐
    是通道数。该图像首先被分割为多个小块(如图所示被分为 4 个 2×2×𝑐的小块),每个块称为一个“block”。

  • Haloing 操作:

  • 在图像分块后,使用 haloing 操作扩展每个小块的边界。图中显示的是一个 halo 值为 1 的情况,即每个小块在其原有区域上扩展了 1 个像素的边界,形成了带有额外边界信息的邻域窗口。这一操作目的是为了在计算注意力时捕获块与块之间的上下文信息。

  • 邻域窗口计算:

  • Haloing 之后,每个小块拥有邻近区域的信息,即在扩展后的邻域窗口中包含了来自周围小块的部分信息。图中显示了每个小块及其周围邻域的窗口(如红色小块与其邻域的相关部分)。

  • 查询与注意力机制:

  • 在邻域窗口中应用 注意力机制。以每个小块作为查询(Query),与其扩展后的邻域窗口进行注意力计算,从中提取重要的上下文特征。注意力机制的引入使得每个小块不仅能够学习到自身的特征,还能从周围的块中获取相关的上下文信息,从而增强特征表达。

  • 输出:

  • 通过注意力机制的加权输出每个小块的结果,形成新的特征图。输出的特征图大小仍然是分块前的大小,但每个块内的特征已经经过上下文增强和融合。
    优势

  • 降低计算复杂度:

  • 通过将图像分割成小块并只在局部区域内应用注意力机制,减少了全局自注意力带来的高计算开销。这种方法可以大幅度降低计算复杂度,特别适合处理高分辨率图像或大规模数据集。

  • 局部上下文捕获:

  • Haloing 操作的引入使得每个块在计算注意力时能够感知到其邻域的上下文信息,克服了仅依赖自身区域的局限性。因此,它能够更好地捕捉局部细节和相关性,特别是在需要高精度定位的任务中(如图像分割或检测任务)。

  • 有效的特征增强:

  • 通过分块后的注意力机制,模型可以集中计算各个小块的注意力权重,并在局部范围内提升特征表达能力。这样可以避免全局注意力在大图像上计算时引入的冗余信息,同时仍能保证特征的有效整合。

  • 灵活性强:

  • 该方法可广泛应用于图像分类、目标检测、语义分割等任务中,并且可以根据实际需求调整分块大小和 halo 值,灵活适应不同的计算资源和任务要求。在这里插入图片描述

1.2【HaloAttention】核心代码

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat


def to(x):
    return {"device": x.device, "dtype": x.dtype}


def pair(x):
    return (x, x) if not isinstance(x, tuple) else x


def expand_dim(t, dim, k):
    t = t.unsqueeze(dim=dim)
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)


def rel_to_abs(x):
    b, l, m = x.shape
    r = (m + 1) // 2

    col_pad = torch.zeros((b, l, 1), **to(x))
    x = torch.cat((x, col_pad), dim=2)
    flat_x = rearrange(x, "b l c -> b (l c)")
    flat_pad = torch.zeros((b, m - l), **to(x))
    flat_x_padded = torch.cat((flat_x, flat_pad), dim=1)
    final_x = flat_x_padded.reshape(b, l + 1, m)
    final_x = final_x[:, :l, -r:]
    return final_x


def relative_logits_1d(q, rel_k):
    b, h, w, _ = q.shape
    r = (rel_k.shape[0] + 1) // 2

    logits = einsum("b x y d, r d -> b x y r", q, rel_k)
    logits = rearrange(logits, "b x y r -> (b x) y r")
    logits = rel_to_abs(logits)

    logits = logits.reshape(b, h, w, r)
    logits = expand_dim(logits, dim=2, k=r)
    return logits


class RelPosEmb(nn.Module):
    def __init__(self, block_size, rel_size, dim_head):
        super().__init__()
        height = width = rel_size
        scale = dim_head**-0.5

        self.block_size = block_size
        self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale)
        self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale)

    def forward(self, q):
        block = self.block_size

        q = rearrange(q, "b (x y) c -> b x y c", x=block)
        rel_logits_w = relative_logits_1d(q, self.rel_width)
        rel_logits_w = rearrange(rel_logits_w, "b x i y j-> b (x y) (i j)")

        q = rearrange(q, "b x y d -> b y x d")
        rel_logits_h = relative_logits_1d(q, self.rel_height)
        rel_logits_h = rearrange(rel_logits_h, "b x i y j -> b (y x) (j i)")
        return rel_logits_w + rel_logits_h


class HaloAttention(nn.Module):
    def __init__(self, dim, block_size, halo_size, dim_head=64, heads=8):
        super().__init__()
        assert halo_size > 0, "halo size must be greater than 0"

        self.dim = dim
        self.heads = heads
        self.scale = dim_head**-0.5

        self.block_size = block_size
        self.halo_size = halo_size

        inner_dim = dim_head * heads

        self.rel_pos_emb = RelPosEmb(
            block_size=block_size,
            rel_size=block_size + (halo_size * 2),
            dim_head=dim_head,
        )

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        b, c, h, w, block, halo, heads, device = (
            *x.shape,
            self.block_size,
            self.halo_size,
            self.heads,
            x.device,
        )
        assert (
            h % block == 0 and w % block == 0
        ), "fmap dimensions must be divisible by the block size"
        assert (
            c == self.dim
        ), f"channels for input ({c}) does not equal to the correct dimension ({self.dim})"

        # get block neighborhoods, and prepare a halo-ed version (blocks with padding) for deriving key values

        q_inp = rearrange(
            x, "b c (h p1) (w p2) -> (b h w) (p1 p2) c", p1=block, p2=block
        )

        kv_inp = F.unfold(x, kernel_size=block + halo * 2, stride=block, padding=halo)
        kv_inp = rearrange(kv_inp, "b (c j) i -> (b i) j c", c=c)

        # derive queries, keys, values

        q = self.to_q(q_inp)
        k, v = self.to_kv(kv_inp).chunk(2, dim=-1)

        # split heads

        q, k, v = map(
            lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=heads), (q, k, v)
        )

        # scale

        q *= self.scale

        # attention

        sim = einsum("b i d, b j d -> b i j", q, k)

        # add relative positional bias

        sim += self.rel_pos_emb(q)

        # mask out padding (in the paper, they claim to not need masks, but what about padding?)

        mask = torch.ones(1, 1, h, w, device=device)
        mask = F.unfold(
            mask, kernel_size=block + (halo * 2), stride=block, padding=halo
        )
        mask = repeat(mask, "() j i -> (b i h) () j", b=b, h=heads)
        mask = mask.bool()

        max_neg_value = -torch.finfo(sim.dtype).max
        sim.masked_fill_(mask, max_neg_value)

        # attention

        attn = sim.softmax(dim=-1)

        # aggregate

        out = einsum("b i j, b j d -> b i d", attn, v)

        # merge and combine heads

        out = rearrange(out, "(b h) n d -> b n (h d)", h=heads)
        out = self.to_out(out)

        # merge blocks back to original feature map

        out = rearrange(
            out,
            "(b h w) (p1 p2) c -> b c (h p1) (w p2)",
            b=b,
            h=(h // block),
            w=(w // block),
            p1=block,
            p2=block,
        )
        return out


if __name__ == "__main__":
    input = torch.rand(3, 32, 64, 64).cuda()
    model = HaloAttention(
        dim=32,
        block_size=2,
        halo_size=1,
    ).cuda()
    output = model(input)
    print(input.size(), output.size())


二、添加【HaloAttention】注意力机制

2.1STEP1

首先找到ultralytics/nn文件路径下新建一个Add-module的python文件包【这里注意一定是python文件包,新建后会自动生成_init_.py】,如果已经跟着我的教程建立过一次了可以省略此步骤,随后新建一个HaloAttention.py文件并将上文中提到的注意力机制的代码全部粘贴到此文件中,如下图所示在这里插入图片描述

2.2STEP2

在STEP1中新建的_init_.py文件中导入增加改进模块的代码包如下图所示在这里插入图片描述

2.3STEP3

找到ultralytics/nn文件夹中的task.py文件,在其中按照下图添加在这里插入图片描述

2.4STEP4

定位到ultralytics/nn文件夹中的task.py文件中的def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)函数添加如图代码,【如果不好定位可以直接ctrl+f搜索定位】

在这里插入图片描述

三、yaml文件与运行

3.1yaml文件

以下是添加【HaloAttention】注意力机制在Backbone中的yaml文件,大家可以注释自行调节,效果以自己的数据集结果为准

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128,3,2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256,3,2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512,3,2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024,3,2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, HaloAttention, [2, 1]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 14], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 11], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)

  - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

以上添加位置仅供参考,具体添加位置以及模块效果以自己的数据集结果为准

3.2运行成功截图

在这里插入图片描述

OK 以上就是添加【HaloAttention】注意力机制的全部过程了,后续将持续更新尽情期待

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2203896.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于组合模型的公交交通客流预测研究

摘 要 本研究致力于解决公交客流预测问题,旨在通过融合多种机器学习模型的强大能力,提升预测准确性,为城市公交系统的优化运营和交通管理提供科学依据。研究首先回顾了公交客流预测领域的相关文献,分析了传统统计方法在处理大规…

企业大文件传输之:镭速如何提升上传文件浏览器压缩效率?

互联网技术的迅猛发展带来了文件传输需求的激增,尤其是在需要在浏览器中上传大文件的场景下。为了提升传输效率并减少服务器的带宽和资源消耗,文件压缩变得尤为重要。许多开发者选择使用JSZip等开源工具来实现浏览器端的文件压缩。 不过,这些…

运动耳机选哪个品牌比较好?盘点五大高品质运动耳机推荐!

在骨传导耳机日益普及的同时,一个不容忽视的问题也逐渐暴露在大众视野之中。根据可靠消息,有超过九成的运动爱好者反馈在使用骨传导耳机时感到佩戴不适!作为一名有着5年经验的运动达人,我秉持着对消费者负责的态度,同时…

LLM 何时需要检索增强? 减轻 LLM 的过度自信有助于检索增强

洞见 检索增强(RA)技术作为减轻大语言模型(LLMs)幻觉问题的一种手段,已经受到了广泛的关注。然而,由于其带来的额外计算成本以及检索结果质量的不确定性,持续不断地应用RA并非总是最优的解决方…

【Python】Conda离线执行命令

以下链接证明了想要离线使用conda命令的方法 启用离线模式 — Anaconda documentation 基本上大部分的命令都会提供网络选项 例如creat命令 conda create — conda 24.7.1 文档 - Conda 文档

PCL 将点云投影到拟合平面

PCL点云算法汇总及实战案例汇总的目录地址链接: PCL点云算法与项目实战案例汇总(长期更新) 一、概述 点云投影到拟合平面是指将三维点云数据中的点投影到与其最接近的二维平面上。通过投影到平面,可以消除数据的高度变化或Z轴信息…

小程序会取代APP吗?——零工市场小程序和APP的区别

小程序在某些场景下有着取代了APP的潜力,特别是零工市场这样的领域中,单其中能不能完全取代还有待分析。 1.小程序无需下载,想用的时候随时打开,在零工市场领域,小程序可以快速连接求职者和雇主,满足临时工…

秋天来临,猫咪又到换毛季,掉毛严重怎么办?宠物空气净化器有用吗?

秋天到了,新一轮的宠物换毛季又来了。谁能想到这只胖猫和之前刚接回来时的皮包骨小猫是同一只!除了养了一年长了些肉外,更多的都是换毛季掉毛”膨胀“的。每天下班回家都要搞卫生,家里衣服上、地板上,目光所及之处都有…

GNU链接器(LD):PROVIDE、PROVIDE_HIDDEN关键字介绍

0 参考资料 GNU-LD-v2.30-中文手册.pdf GNU linker.pdf1 前言 一个完整的编译工具链应该包含以下4个部分: (1)编译器 (2)汇编器 (3)链接器 (4)lib库 在GNU工具链中&…

用布尔表达式巧解数字电路图

1.前置知识 明确AND,OR,XOR,NOR,NOT运算的规则 参见:E25.【C语言】练习:修改二进制序列的指定位 这里再补充一个布尔运算符:NOR,即先进行OR运算,再进行NOT运算 如下图为其数字电路的符号 注意到在OR符号的基础上,在尾部加了一个(其实由简化而来) 附:NOR的真值表 2.R-S触发…

第二十章 番外 混淆矩阵

混淆矩阵(Confusion Matrix)是一种用于描述监督学习中分类模型性能的特定表格布局。它提供了直观的方式来理解分类器的性能,特别是对于多类别分类任务。混淆矩阵通过比较实际类别标签与分类器预测的类别标签来展示分类结果。 混淆矩阵的基本…

2-118 基于matlab的六面体建模和掉落仿真

基于matlab的六面体建模和掉落仿真,将对象建模为刚体来模拟将立方体扔到地面上。同时考虑地面摩擦力、刚度和阻尼所施加的力,在三个维度上跟踪平移运动和旋转运动。程序已调通,可直接运行。 下载源程序请点链接:2-118 基于matla…

Microsoft Edge 离线安装包制作或获取方法和下载地址分享

方法一:自制压缩包 进入目录 "C:\Program Files (x86)\Microsoft\Edge\Application" 或 "C:\Program Files (x86)\Microsoft\EdgeCore\Edge版本号",将所有文件打包,再放到没有安装到 Edge 的电脑里解压,运行…

打破常规,BD仓储物流的效能提升!

当前,随着国家战略的推进,JS与民用领域的融合不断加深,物流业也步入了军民融合的新时代。在智能仓储物流方面,JS物流的智能化进展受到了BD系统的高度关注和重视。 一、建设JS仓储物流RFID基础设施 JS物流领域引入RFID技术的基础工…

入门端到端第一步!最新综述回顾基于深度学习的规划方法发展历程

这篇新的综述,系统的回顾了基于深度学习的预测和规划方法, 端到端方法的发展历程, 非常适合初学者了解领域背景. The Integration of Prediction and Planning in Deep Learning Automated Driving Systems: A Review 0. 摘要 自动化驾驶系统有潜力彻底改变个人、公共和货物…

Cesium 获取当前视角信息

通过 浏览器控制台,直接获取到当前地球视角的信息,然后通过 flyTo 跳转视角。 方法: 控制台内输入下列代码,控制台就会输出视角信息: const camera viewer.camera; const position camera.positionCartographic; c…

Python:条件分支 if 语句全讲解

Python:条件分支 if 语句全讲解 如果我拿出下面的代码,阁下该做何应对? if not reset_excuted and (terminated or truncated):... else:...---- 前言: 消化论文代码的时候看到这个东西直接大脑冻结,没想过会在这么…

高含金量WebGIS学习教程?

智慧校园——适合0基础入门 智慧交通——适合0基础入门 VUE-适合前端进阶 Mapbox项目开发实例 Openlayers零基础入门 智慧机场——适合有前端基础 threejs三维开发入门 三维进阶:cesium零基础入门教程 面试讲解:剖析地信大厂技术面试真题&#x…

【算法】DP系列之 斐波那契数列模型

【ps】本篇有 4 道 leetcode OJ。 目录 一、算法简介 二、相关例题 1)第 N 个泰波那契数 .1- 题目解析 .2- 代码编写 2)三步问题 .1- 题目解析 .2- 代码编写 3)使用最小花费爬楼梯 .1- 题目解析 .2- 代码编写 4)解码…

vue3实现 长列表虚拟滚动

1、直接看代码 <template><!--定义一个大容器&#xff0c;此容器可以滚动--><div class"view" ref"viewRef" scroll"handleScroll"><!--定义一个可以撑满整个data的容器&#xff0c;主要是让父元素滚动起来--><div …