爆改YOLOv8|利用全新的聚焦式线性注意力模块Focused Linear Attention 改进yolov8(v1)

news2024/12/23 12:37:36

1,本文介绍

全新的聚焦线性注意力模块(Focused Linear Attention)是一种旨在提高计算效率和准确性的注意力机制。传统的自注意力机制在处理长序列数据时通常计算复杂度较高,限制了其在大规模数据上的应用。聚焦线性注意力模块则通过优化注意力计算的方式,显著降低了计算复杂度。

核心特点:

  1. 线性时间复杂度:与传统的自注意力机制不同,聚焦线性注意力模块采用了线性时间复杂度的计算方法,这使得处理长序列数据时更加高效。

  2. 聚焦机制:该模块专注于关键的上下文信息,通过聚焦策略来提高注意力计算的准确性,同时减少不必要的计算开销。

  3. 改进的性能:通过优化注意力计算,聚焦线性注意力模块能够在保持高性能的同时,显著提升模型的计算效率,尤其适用于大规模数据处理。

应用场景:

  • 自然语言处理:在长文本或大规模语料库的处理上,聚焦线性注意力模块能够提供更高的效率和更低的延迟。
  • 计算机视觉:在处理高分辨率图像或视频数据时,能够加速计算过程,提升模型的实时性。

总体而言,聚焦线性注意力模块为处理大规模和长序列数据提供了一种高效且精确的解决方案,适用于各种需要高效注意力机制的应用场景。

关于Focused Linear Attention的详细介绍可以看论文:https://arxiv.org/pdf/2308.00442

本文将讲解如何将Focused Linear Attention融合进yolov8

话不多说,上代码!

2, 将Focused Linear Attention融合进yolov8


2.1 步骤一

找到如下的目录'ultralytics/nn/modules',然后在这个目录下创建一个FLA.py文件,文件名字可以根据你自己的习惯起,然后将Focused Linear Attention的核心代码复制进去

import torch
import torch.nn as nn
from einops import rearrange
 
 
class FocusedLinearAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """
 
    def __init__(self, dim, window_size=[20, 20], num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
                 focusing_factor=3, kernel_size=5):
 
        super().__init__()
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.focusing_factor = focusing_factor
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.window_size = window_size
        self.positional_encoding = nn.Parameter(torch.zeros(size=(1, window_size[0] * window_size[1], dim)))
 
        self.softmax = nn.Softmax(dim=-1)
 
        self.dwc = nn.Conv2d(in_channels=head_dim, out_channels=head_dim, kernel_size=kernel_size,
                             groups=head_dim, padding=kernel_size // 2)
        self.scale = nn.Parameter(torch.zeros(size=(1, 1, dim)))
 
    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, C).permute(2, 0, 1, 3)
        q, k, v = qkv.unbind(0)
        k = k + self.positional_encoding[:, :k.shape[1], :]
        focusing_factor = self.focusing_factor
        kernel_function = nn.ReLU()
        q = kernel_function(q) + 1e-6
        k = kernel_function(k) + 1e-6
        scale = nn.Softplus()(self.scale)
        q = q / scale
        k = k / scale
        q_norm = q.norm(dim=-1, keepdim=True)
        k_norm = k.norm(dim=-1, keepdim=True)
        if float(focusing_factor) <= 6:
            q = q ** focusing_factor
            k = k ** focusing_factor
        else:
            q = (q / q.max(dim=-1, keepdim=True)[0]) ** focusing_factor
            k = (k / k.max(dim=-1, keepdim=True)[0]) ** focusing_factor
        q = (q / q.norm(dim=-1, keepdim=True)) * q_norm
        k = (k / k.norm(dim=-1, keepdim=True)) * k_norm
        q, k, v = (rearrange(x, "b n (h c) -> (b h) n c", h=self.num_heads) for x in [q, k, v])
 
        i, j, c, d = q.shape[-2], k.shape[-2], k.shape[-1], v.shape[-1]
 
        z = 1 / (torch.einsum("b i c, b c -> b i", q, k.sum(dim=1)) + 1e-6)
        if i * j * (c + d) > c * d * (i + j):
            kv = torch.einsum("b j c, b j d -> b c d", k, v)
            x = torch.einsum("b i c, b c d, b i -> b i d", q, kv, z)
        else:
            qk = torch.einsum("b i c, b j c -> b i j", q, k)
            x = torch.einsum("b i j, b j d, b i -> b i d", qk, v, z)
 
        num = int(v.shape[1] ** 0.5)
 
        feature_map = rearrange(v, "b (w h) c -> b c w h", w=num, h=num)
        feature_map = rearrange(self.dwc(feature_map), "b c w h -> b (w h) c")
        x = x + feature_map
        x = rearrange(x, "(b h) n c -> b n (h c)", h=self.num_heads)
        x = self.proj(x)
        x = self.proj_drop(x)
        x = rearrange(x, "b (w h) c -> b c w h", b=B, c=self.dim, w=num, h=num)
        return x

2.2 步骤二

在task.py导入我们的模块

2.3 步骤三

在task.py的parse_model方法里面注册我们的模块

到此注册成功,复制后面的yaml文件直接运行即可

yaml文件

# Ultralytics YOLO 🚀, GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
 
# Parameters
nc: 1  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
 
# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9
  - [-1, 1, FocusedLinearAttention, [256]]  # 10
 
# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 13
 
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)
 
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)
 
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)
 
  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

可能会出现如下所示报错

 raise EinopsError(message + "\n {}".format(e))
einops.EinopsError:  Error while processing rearrange-reduction pattern "b (w h) c -> b c w h".
 Input tensor shape: torch.Size([128, 294, 32]). Additional info: {'w': 17, 'h': 17}.
 Shape mismatch, 294 != 289

则按以下方法进行修改

找到ultralytics/models/yolo/detect/train.py的如下所示代码

 def build_dataset(self, img_path, mode='train', batch=None):
        """
        Build YOLO Dataset.
        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
        """
        gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
     
        return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == 'val', stride=gs)

将该方法的最后一句代码修改为下面的代码

return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=(False if mode == 'val' else False), stride=gs)

不知不觉已经看完了哦,动动小手留个点赞收藏吧--_--

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2083777.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

EmguCV学习笔记 C# 7.1 角点检测

版权声明&#xff1a;本文为博主原创文章&#xff0c;转载请在显著位置标明本文出处以及作者网名&#xff0c;未经作者允许不得用于商业目的。 EmguCV是一个基于OpenCV的开源免费的跨平台计算机视觉库,它向C#和VB.NET开发者提供了OpenCV库的大部分功能。 教程VB.net版本请访问…

Excel中让第一行始终显示

要在Excel中让第一行始终显示&#xff0c;你可以使用冻结窗格功能。具体步骤如下&#xff1a; 打开需要设置第一行一直显示的工作表。将光标定位在工作表内任意一个单元格内。选择“视图”菜单&#xff0c;单击工具栏中的“冻结窗格”命令。在弹出的下拉菜单中选择“冻结首行”…

字母的大小写转换(tolower、toupper、transform)

字母的大小写转换&#xff08;tolower、toupper、transform&#xff09; 1. tolower&#xff08;&#xff09;、toupper&#xff08;&#xff09;函数 &#xff08;这个在之前的一篇文章 “字符串中需要掌握的函数总结&#xff08;1&#xff09;”中有较为详细的介绍。&#…

时利和:如何提升工装夹具的加工质量?

在机械加工领域&#xff0c;工装夹具起着至关重要的作用。它不仅能够提高生产效率&#xff0c;还能保证加工精度&#xff0c;确保产品质量的稳定性。那么&#xff0c;如何提升工装夹具的加工质量呢?以下是时利和整理分享的几个关键因素。 一、精准的设计 工装夹具的设计是决定…

使用物联网卡访问萤石云的常见问题

使用物联网卡接入萤石开放平台时经常遇到各种问题&#xff0c;这边总结了常见的一些 用的是哪家运营商的卡&#xff1f; 电信 移动 联通&#xff08;申请的时候可以自主选择&#xff09; 卡有什么限制&#xff1f; 定向流量卡&#xff0c;只能访问萤石云平台&#xff0c;只能…

完美解决Jenkins重启后自动杀掉衍生进程(子进程)问题

完美解决Jenkins重启后自动杀掉衍生进程(子进程)问题 本文中使用的Jenkins版本为Version 2.452.3 先罗列一下前置问题&#xff1a;Jenkins任务构建完成自动杀掉衍生进程 用过Jenkins的都知道&#xff0c;Jenkins任务构建完成后&#xff0c;是会自动杀掉衍生进程&#xff0c;这…

安卓AppBarLayout与ViewPager2里的fragment里的webview滑动冲突

今天开发遇见一个头痛的问题&#xff0c;就是AppBarLayout和webview会存在一个冲突问题。如图下 问题出现在webview推到顶端的时候&#xff0c;AppBarLayout并不会跟着响应伸缩&#xff0c;解决办法是 在 webview 包 一个 父的 NestedScrollView 就能解决了。 运行效果入下 更改…

单向链表和双向链表的一些基本算法

单向链表头插尾插 单向链表的销毁与反转 反转原理&#xff1a;将头节点与后面的节点分开&#xff0c;然后从第一个节点开始对每个节点使用头插法 冒泡排 选排 链表环&#xff1a; 判断是否有环&#xff1a;弗洛伊德快慢指针&#xff08;快指针一般是慢指针的2倍&#xff0c;差为…

Selenium(HTML基础)

一、HTML基础 &#xff08;在学习自动化时&#xff0c;保证能看懂&#xff09; 1.1.HTML介绍 英文是HyperText Markup Language&#xff0c;译为:超文本标记语言是Internet上用于设计网页的主要语言&#xff0c;2008年发布了HTML5.0,是目前互联网的标准&#xff0c;并作为互联…

数据结构之 “单链表“

&#xff08;1&#xff09;在顺表表中&#xff0c;如果是头插/删的时间复杂度是O(1)&#xff1b;尾插/删的时间复杂度是O(N) &#xff08;2&#xff09;增容一般是呈2倍的增长&#xff0c;势必会有一定的空间浪费。比如&#xff1a;申请了50个空间&#xff0c;只用了两个&#…

【Matlab】SSA-BP麻雀搜索算法优化BP神经网络回归预测 可预测未来(附代码)

资源下载&#xff1a; 资源合集&#xff1a; 目录 一&#xff0c;概述 传统的BP神经网络存在一些问题&#xff0c;比如容易陷入局部最优解、训练速度慢等。为了解决这些问题&#xff0c;我们引入了麻雀算法作为优化方法&#xff0c;将其与BP神经网络相结合&#xff0c;提出了…

精准高效,省时省力——2024年翻译工具新趋势

如果你收到一份外文文档&#xff0c;但是你看不懂急需翻译成中文&#xff0c;将文档内容复制粘贴到翻译的的窗口获得中文内容是不是很麻烦。我了解到了deepl翻译文档之后还认识了不少的翻译神器&#xff0c;这次分享给你一起试试吧。 第一款福晰在线翻译 链接直达>>htt…

3D工艺大师:精准助力医疗设备远程维修

医疗设备是现代医疗体系中不可或缺的重要工具&#xff0c;它们帮助医生更准确地诊断病情&#xff0c;更有效地进行治疗。但随着技术的进步&#xff0c;这些设备变得越来越复杂&#xff0c;维修起来也面临诸多挑战。 首先&#xff0c;医疗设备结构复杂&#xff0c;零件众多&…

【Material-UI】Rating组件中的Rating precision属性

文章目录 一、Rating组件概述1. 组件简介2. precision属性的作用 二、Rating precision的基本用法三、Rating precision属性详解1. 精度选择的意义2. 如何在项目中选择合适的精度 四、Rating precision属性的实际应用场景1. 电商平台中的应用2. 电影评分应用3. 专业评测网站 五…

[Tomcat源码解析]——热部署和热加载原理

热部署 在Tomcat中可以通过Host标签设置热部署,当 autoDeploy为true时,在运行中的Tomcat中丢入一个war包,那么Tomcat不需要重启就可以自动加载该war包。 <Host name="localhost" appBase="webapps"unpackWARs="true" autoDeploy="…

Ubuntu18.04 下安装CUDA

安装步骤 1.查看是否安装了cuda # 法1 cat /usr/local/cuda/version.txt # 法2 nvcc --version 2.若没有安装&#xff0c;则查看是否有N卡驱动&#xff0c;若无N卡驱动&#xff0c;则到软件与更新 -> 附加驱动中安装驱动 3.查看N卡驱动支持的cuda版本 nvidia-smi 如下…

RabbitMQ 集群与高可用性

目录 单节点与集群部署 1.1. 单节点部署 1.2. 集群部署 镜像队列 1.定义与工作原理 2. 配置镜像队列 3.应用场景 4. 优缺点 5. Java 示例 分布式部署 1. 分布式部署的主要目标 2. 典型架构设计 3. RabbitMQ 分布式部署的关键技术 4. 部署策略和实践 5. 分布式部署…

图像变换——等距变换、相似变换、仿射变换、投影变换

%%图像变换 % I imread(cameraman.tif); I imread(F:\stitching\imagess\or\baiyun2.jpg); figure; imshow(I); title(原始图像); [w,h]size(I); thetapi/4;%旋转角 t[200,80];%平移tx,ty s0.3;%缩放尺度 %% 等距变换平移变换旋转变换 H_eprojective2d([cos(theta) sin(theta…

体育风尚杂志体育风尚杂志社体育风尚编辑部2024年第8期目录

体讯 体育产业“破圈” 3-7 成都大运会获2023年度最佳体育赛事媒体设施奖 8-10 斯巴达勇士赛 2024斯巴达勇士赛深圳站漫威主题赛在深圳光明欢乐田园揭幕 11-15 篮球 CBA季后赛 深圳马可波罗挺进季后赛八强 16-24 游泳 周六福2024年全国游泳冠军赛4月深圳激…

集成电路与电路基础之-二极管

二极管是什么 二极管&#xff0c;又称肖特基二极管或晶体二极管&#xff0c;是一种最基本的半导体器件之一。它由半导体材料&#xff08;如硅、硒、锗等&#xff09;制成&#xff0c;其内部结构是一个PN结&#xff0c;即由一个P型半导体区和一个N型半导体区组成。这种结构赋予…