YOLO11改进|注意力机制篇|引入线性注意力机制FLAttention

news2024/10/9 11:45:28

在这里插入图片描述

目录

    • 一、【FLA】注意力机制
      • 1.1【FLA】注意力介绍
      • 1.2【FLA】核心代码
    • 二、添加【FLA】注意力机制
      • 2.1STEP1
      • 2.2STEP2
      • 2.3STEP3
      • 2.4STEP4
    • 三、yaml文件与运行
      • 3.1yaml文件
      • 3.2运行成功截图

一、【FLA】注意力机制

1.1【FLA】注意力介绍

在这里插入图片描述

下图是【FLA】的结构图,让我们简单分析一下运行过程和优势,以及和Softmax Attention的对比

  1. Softmax Attention(左侧)
  • 处理流程:
  • 输入矩阵:查询矩阵 𝑄的大小为 𝑁×𝑑,键矩阵 𝐾𝑇的大小为 𝑑×𝑁,值矩阵 𝑉 的大小为 𝑁×𝑑,其中 𝑁是序列长度,𝑑是特征维度。
  • 计算注意力分数:通过矩阵乘法 𝑄𝐾𝑇,得到一个大小为 𝑁×𝑁的注意力权重矩阵。
  • Softmax 归一化:通过 Softmax 函数对注意力权重进行归一化。
    应用到值矩阵 𝑉:然后将归一化的注意力权重乘以值矩阵 𝑉,最终输出为 𝑁×𝑑。
  • 复杂度:该计算的复杂度是 𝑂(𝑁2𝑑)。其中主要的计算代价在于矩阵 𝑄𝐾𝑇 的乘法,这个操作产生一个 𝑁×𝑁的注意力矩阵。因此,当序列长度 𝑁较大时,计算开销显著增加,尤其在长序列处理时表现较差。
  1. Linear Attention(右侧)
  • 处理流程:

  • 输入矩阵:查询矩阵 𝑄的大小为 𝑁×𝑑,键矩阵 𝐾𝑇为 𝑑×𝑁,值矩阵 𝑉为 𝑁×𝑑。先计算 𝐾𝑇𝑉:与 Softmax Attention 不同,Linear Attention 先将键矩阵 𝐾𝑇和值矩阵 𝑉相乘,得到一个大小为 𝑑×𝑑的矩阵。再计算 𝑄(𝐾𝑇𝑉):然后将查询矩阵 𝑄与该 𝑑×𝑑矩阵相乘,最终得到输出为 𝑁×𝑑。

  • 复杂度:

  • Linear Attention 的计算复杂度是 𝑂(𝑁𝑑2),相较于 Softmax Attention 的 𝑂(𝑁2𝑑),降低了一个 𝑁维度。这种结构的计算复杂度不再依赖于序列长度 𝑁,因此适合处理长序列任务。在这里插入图片描述

1.2【FLA】核心代码

import torch.nn as nn
import torch
from einops import rearrange
 
def autopad(k, p=None, d=1):  # kernel, padding, dilation
    # Pad to 'same' shape outputs
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p
 
 
class Conv(nn.Module):
    # Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
    default_act = nn.SiLU()  # default activation
 
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
 
    def forward(self, x):
        return self.act(self.bn(self.conv(x)))
 
    def forward_fuse(self, x):
        return self.act(self.conv(x))
 
 
class FocusedLinearAttention(nn.Module):
    def __init__(self, dim, num_patches=64, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0.0, proj_drop=0.0, sr_ratio=1,
                 focusing_factor=3.0, kernel_size=5):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
 
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
 
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)
 
        self.focusing_factor = focusing_factor
        self.dwc = nn.Conv2d(in_channels=head_dim, out_channels=head_dim, kernel_size=kernel_size,
                             groups=head_dim, padding=kernel_size // 2)
        self.scale = nn.Parameter(torch.zeros(size=(1, 1, dim)))
        # self.positional_encoding = nn.Parameter(torch.zeros(size=(1, num_patches // (sr_ratio * sr_ratio), dim)))
 
 
    def forward(self, x):
        B, C, H, W = x.shape  # 输入为四维:[批次大小, 通道数, 高度, 宽度]
        dtype, device = x.dtype, x.device
        # 调整输入以匹配原始模块的预期格式
        x = rearrange(x, 'b c h w -> b (h w) c')
        q = self.q(x)
        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, C).permute(2, 0, 1, 3)
        else:
            kv = self.kv(x).reshape(B, -1, 2, C).permute(2, 0, 1, 3)
        k, v = kv[0], kv[1]
        N = H * W  # 序列长度
        # 重新生成位置编码
        positional_encoding = nn.Parameter(torch.zeros(size=(1, N, self.dim), device=device))
        k = k + positional_encoding
        focusing_factor = self.focusing_factor
        kernel_function = nn.ReLU()
        scale = nn.Softplus()(self.scale)
        q = kernel_function(q) + 1e-6
        k = kernel_function(k) + 1e-6
        q = q / scale
        k = k / scale
        q_norm = q.norm(dim=-1, keepdim=True)
        k_norm = k.norm(dim=-1, keepdim=True)
        q = q ** focusing_factor
        k = k ** focusing_factor
        q = (q / q.norm(dim=-1, keepdim=True)) * q_norm
        k = (k / k.norm(dim=-1, keepdim=True)) * k_norm
        bool = False
        if dtype == torch.float16:
            q = q.float()
            k = k.float()
            v = v.float()
            bool = True
        q, k, v = (rearrange(x, "b n (h c) -> (b h) n c", h=self.num_heads) for x in [q, k, v])
        i, j, c, d = q.shape[-2], k.shape[-2], k.shape[-1], v.shape[-1]
        z = 1 / (torch.einsum("b i c, b c -> b i", q, k.sum(dim=1)) + 1e-6)
        if i * j * (c + d) > c * d * (i + j):
            kv = torch.einsum("b j c, b j d -> b c d", k, v)
            x = torch.einsum("b i c, b c d, b i -> b i d", q, kv, z)
        else:
            qk = torch.einsum("b i c, b j c -> b i j", q, k)
            x = torch.einsum("b i j, b j d, b i -> b i d", qk, v, z)
        if self.sr_ratio > 1:
            v = nn.functional.interpolate(v.permute(0, 2, 1), size=x.shape[1], mode='linear').permute(0, 2, 1)
        if bool:
            v = v.to(torch.float16)
            x = x.to(torch.float16)
 
        num = int(v.shape[1] ** 0.5)
        feature_map = rearrange(v, "b (w h) c -> b c w h", w=num, h=num)
        feature_map = rearrange(self.dwc(feature_map), "b c w h -> b (w h) c")
        x = x + feature_map
        x = rearrange(x, "(b h) n c -> b n (h c)", h=self.num_heads)
 
        x = self.proj(x)
        x = self.proj_drop(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
        return x
 

二、添加【FLA】注意力机制

2.1STEP1

首先找到ultralytics/nn文件路径下新建一个Add-module的python文件包【这里注意一定是python文件包,新建后会自动生成_init_.py】,如果已经跟着我的教程建立过一次了可以省略此步骤,随后新建一个FLA.py文件并将上文中提到的注意力机制的代码全部粘贴到此文件中,如下图所示在这里插入图片描述

2.2STEP2

在STEP1中新建的_init_.py文件中导入增加改进模块的代码包如下图所示在这里插入图片描述

2.3STEP3

找到ultralytics/nn文件夹中的task.py文件,在其中按照下图添加在这里插入图片描述

2.4STEP4

定位到ultralytics/nn文件夹中的task.py文件中的def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)函数添加如图代码,【如果不好定位可以直接ctrl+f搜索定位】

在这里插入图片描述

三、yaml文件与运行

3.1yaml文件

以下是添加【FLA】注意力机制在Backbone中的yaml文件,大家可以注释自行调节,效果以自己的数据集结果为准

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1,1,FocusedLinearAttention,[]]

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)

  - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

以上添加位置仅供参考,具体添加位置以及模块效果以自己的数据集结果为准

3.2运行成功截图

在这里插入图片描述

OK 以上就是添加【FLA】注意力机制的全部过程了,后续将持续更新尽情期待

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2198925.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java并发:同步工具类(信号量,等待完成,阶段同步,Exchanger,Phaser)

1,信号量(Stemaphore) Semaphore也就是信号量,提供了资源数量的并发访问控制,其使用代码很简单,如下所示: // 一开始有5份共享资源。第二个参数表示是否是公平 // 公平锁排队,非公…

人脸识别face-api.js应用简介

前阵子学习了一下face-api.js ,偶有心得,跟大家分享一下。 face-api.js的原始项目是https://github.com/justadudewhohacks/face-api.js ,最后一个release是2020年3月22日的0.22.2版,组件较老,API文档很全,…

AI产品经理面试,背烂这100个问题就稳了

❎传统的产品经理,侧重于用户体验与业务流程的优化,强调“以人为本” ✅而AI产品经理更加注重的,视如何将技术应用在业务问题上 ➡虽然不需要会写代码,但也要深入理解AI模型的运作原理,包括大模型技术(如…

【Linux系统编程】第二十九弹---深入探索Linux文件系统:从磁盘存储到inode结构与文件操作

✨个人主页: 熬夜学编程的小林 💗系列专栏: 【C语言详解】 【数据结构详解】【C详解】【Linux系统编程】 目录 1、文件系统相关知识 2、磁盘 2.1、理论补充 2.2、看看物理磁盘 2.3、磁盘的存储结构 2.3、对磁盘的存储进行逻辑抽象 3、…

【北京迅为】《STM32MP157开发板嵌入式开发指南》-第二十章 makefile 基本语法(上)

iTOP-STM32MP157开发板采用ST推出的双核cortex-A7单核cortex-M4异构处理器,既可用Linux、又可以用于STM32单片机开发。开发板采用核心板底板结构,主频650M、1G内存、8G存储,核心板采用工业级板对板连接器,高可靠,牢固耐…

头疼来袭?别急,这份自救指南让你秒变“不痛达人”!

在这个快节奏的时代,头疼仿佛成了我们生活中的“不速之客”,时不时就来敲敲门,让人措手不及。无论是工作压力山大、熬夜追剧后的疲惫,还是突如其来的偏头痛,都让人苦不堪言。但别怕,今天就来给大家送上一份…

基于SpringBoot 助农农产品销售平台小程序 【附源码】

基于SpringBoot 助农农产品销售平台小程序 效果如下: 管理员主界面 用户管理界面 农户管理界面 农户主界面 小程序首页界面 农产品详情界面 详情界面 研究背景 随着互联网技术的快速发展和智能手机的普及,传统的农产品销售模式面临着诸多挑战。信息不…

《RabbitMQ篇》交换机基本概览

生产者都是把消息给交换机,由交换机分发给消息队列。 routingKey:路由键,也可称为绑定,是交换机和队列之间的桥梁,交换机会根据routingKey来把消息转发到对应的队列。 Fanout 不处理路由键。你只需要简单的将队列绑定…

【业务场景】最全的购物车设计与实现

前言 博主最近在做一个购物商城,正好设计到购物车模块,于是乎全面的来聊一聊购物车模块实现的一些核心要点吧,很值得反复品味的设计,当需要实现购物车的时候,本文应该拿来就能用。 目录 1.需要解决的核心问题清单 2…

下一代电源管理:Modern Standby与S3睡眠的对比

Modern Standby与S3睡眠的对比 一、引言二、Modern Standby概述三、S3睡眠模式概述四、Modern Standby与S3睡眠的差异五、实际应用和适用场景六、测试Modern Standby的性能6.1、PowerCfg命令行工具6.2、Windows Performance Toolkit 七、总结 一、引言 电源管理在现代计算设备…

Midjourney中文版:解锁你的创意之旅

在创意与技术的交汇点,Midjourney中文版正等待着每一位热爱艺术、渴望表达的灵魂。这不仅仅是一款AI绘画工具,更是一个激发无限灵感、让创意自由翱翔的奇妙平台。 Midjourney AI超强绘画 (原生态系统)用户端:Ai Loadinghttps://w…

Linux操作系统——软件包的管理(实验报告)

实验——软件安装的基本操作 一、实验目的 熟悉软件安装流程,掌握java的安装流程,熟悉相关命令的操作。 二、实验环境 硬件:PC电脑一台,网络正常; 配置:win10系统,内存大于8G 硬盘500G及以上…

机器学习实战27-基于双向长短期记忆网络 BiLSTM 的黄金价格模型研究

大家好,我是微学AI,今天给大家介绍一下机器学习实战27-基于双向长短期记忆网络 BiLSTM 的黄金价格模型研究。本文针对黄金价格预测问题,展开基于改造后的长短期记忆网络BiLSTM的黄金价格模型研究。文章首先介绍了项目背景,随后详细…

LSTM的变体

一、GRU 1、什么是GRU 门控循环单元(GRU)是一种循环神经网络(RNN)的变体,它通过引入门控机制来控制信息的流动,从而有效地解决了传统RNN中的梯度消失问题。GRU由Cho等人在2014年提出,它简化了…

判断回文 python

题目&#xff1a; 输入一个四位数&#xff0c;判断该数是否为回文数&#xff0c;回文数是指正序&#xff08;从左向右&#xff09;和倒序&#xff08;从右向左&#xff09;读都是一样的整数&#xff0c;比如1221。 代码法1&#xff1a; ninput() nint(n) if n<1000 or n&g…

微积分复习笔记 Calculus Volume 1 - 2.2 The Limit of a Function

2.2 The Limit of a Function - Calculus Volume 1 | OpenStax

中控自动化测试实战和实车智能驾驶业务解析

一.中控自动化测试流程及环境搭建 1.中控自动化测试流程 2.中控自动化测试环境的搭建 1.JDK环境配置 安装 Java安装包.生成java\bin jre\bin JAVA_HOME: java目录 c:\java path:%JAVA_HOME%\bin jre\bin 为了后面appium server GUI客户端中的环境配置 2.SDK 配置 pal…

怎么编辑图片?这5款工具教你快速编辑

怎么编辑图片&#xff1f;编辑图片是一项既具创意又实用的技能&#xff0c;它不仅能够提升图片的视觉效果&#xff0c;增强信息的传达力&#xff0c;还能激发无限的创作灵感。通过编辑图片&#xff0c;我们可以轻松调整色彩、添加文字、裁剪构图&#xff0c;甚至创造出令人惊叹…

Oxygen Forensic Detective 17.0 发布,新增功能概览

Oxygen Forensic Detective 17.0 发布&#xff0c;新增功能概览 Oxygen Forensic Detective Windows 17 Multilingual - 领先的一体化数字取证软件 digital forensic software 请访问原文链接&#xff1a;https://sysin.org/blog/oxygen-forensic-detective/&#xff0c;查看…

【学习笔记】SquareLine Studio安装教程(LVGL官方工具)

一.简介与导航&#xff1a; SquareLine Studio是由LVGL官方开发的一款UI设计工具&#xff0c;采用图形化进行界面UI设计&#xff0c;轻易上手。 SquareLine Studio官方网址&#xff1a;https://squareline.io/SquareLine Studio官方文档&#xff1a;https://docs.squareline.io…