即插即用hilo注意力机制,捕获低频高频特征

news2024/12/23 20:11:34

题目:Fast Vision Transformers with HiLo Attention

论文地址:  https://arxiv.org/abs/2205.13213

创新点

  • HiLo自注意力机制:作者提出了一种新的自注意力机制,称为HiLo注意力,旨在同时捕捉图像中的高频和低频信息。该方法通过将自注意力分为两个分支,高频分支(Hi-Fi)处理局部的高分辨率细节,低频分支(Lo-Fi)处理全局的低分辨率结构。这样可以提高计算效率,特别是在高分辨率图像上,同时保持准确性。

  • LITv2模型:基于HiLo注意力机制,文献引入了LITv2模型,该模型在多个主流计算机视觉任务(如图像分类、物体检测和语义分割)上表现优越。LITv2通过在早期阶段删除多头自注意力(MSA)层,并在后期阶段使用高效的HiLo注意力机制,提升了模型的速度和内存效率。

  • 速度优化:作者通过实际平台上的速度评估(而非通常的FLOPs计算)设计了该模型,以确保其在GPU和CPU上的实际速度更快。例如,HiLo机制在CPU上比局部窗口注意力机制快1.6倍,比空间缩减注意力机制快1.4倍。

  • 相对位置编码优化:文献还对相对位置编码进行了优化,采用了3×3的深度卷积层代替传统的固定相对位置编码,这大大加快了密集预测任务(如分割)的训练和推理速度。

方法

整体结构

       LITv2模型基于HiLo注意力机制,分离处理高频和低频信息,通过局部窗口自注意力捕捉细节、高效全局注意力处理全局结构。此外,模型采用3×3深度卷积层替代位置编码,减少计算复杂度并扩大感受野。整体架构分为多阶段,生成金字塔特征图,适用于密集预测任务,结合残差连接和全局自注意力确保性能与效率的平衡。

  • Patch Embedding层:模型首先将输入图像切分为固定大小的图像块(patch),然后通过线性变换将每个patch映射到一个高维特征空间,这与大多数Vision Transformer类似。

  • HiLo注意力机制:这是模型的核心创新点。HiLo注意力机制将多头自注意力(MSA)分成两个部分:

  • 高频(Hi-Fi)注意力:处理局部的高频细节信息,使用的是局部窗口自注意力(例如2×2窗口),能够高效捕获图像中的细节信息。

  • 低频(Lo-Fi)注意力:处理全局的低频信息,先通过平均池化获得低频特征,再进行全局自注意力计算,从而减少计算复杂度。

  • 深度卷积层(Depthwise Convolution Layer):为了进一步提高效率,LITv2引入了3×3的深度卷积层用于代替传统的多层感知机(MLP)中的位置编码。这种设计不仅减少了位置编码的计算负担,还扩大了早期阶段特征的感受野。

  • 多阶段结构:模型通常分为多个阶段(例如4个阶段),在每个阶段生成金字塔结构的特征图(pyramid feature maps),用于处理不同分辨率的特征。这使得模型在图像分类之外的密集预测任务(如物体检测和语义分割)中更具优势。

  • 残差连接和归一化:在每个Transformer模块中,模型使用标准的残差连接和LayerNorm层。这些是标准的ViT组件,用于稳定训练并保持特征的传递。

  • 后期的全局自注意力:在模型的后期阶段,虽然早期阶段使用了高效的局部自注意力和低频注意力机制,但后期阶段会使用标准的多头自注意力机制来处理下采样后的低分辨率特征图,以进一步提升性能。

即插即用模块

将HiLo注意力机制提取为即插即用模块,主要适用于以下场景:

  • 高分辨率图像处理:在需要处理高分辨率图像的任务中,例如图像分类、目标检测、语义分割等,HiLo通过高效分离高频和低频信息,显著减少计算复杂度和内存占用,提升推理速度和处理能力。

  • 低延迟应用场景:HiLo能够在实际硬件平台(如GPU和CPU)上加快推理速度,特别适用于需要低延迟的场景,例如无人机图像处理、自动驾驶中的实时感知系统等。

  • 视觉任务中的密集预测:在需要对每个像素进行精细预测的任务中,如语义分割和实例分割,HiLo能够高效处理局部细节和全局结构,提升预测的准确性和速度。

消融实验

  • 该表展示了LITv1-S模型在引入不同结构修改后的性能变化,包括加入3×3深度卷积层(ConvFFN)、去除相对位置编码(RPE)、以及使用HiLo注意力机制后的影响。

  • 结果表明:引入深度卷积层后,模型在ImageNet分类和COCO检测任务中的性能提升显著,移除RPE后虽然有轻微的性能下降,但推理速度(FPS)显著提升,使用HiLo注意力机制后进一步提升了模型效率,特别是在FLOPs和推理速度上。

  • 该图展示了HiLo注意力机制中高频和低频头部分配比例(α)的影响。随着α值的增加(更多头部用于低频注意力),FLOPs逐渐减少,模型的Top-1准确率在α=0.9时达到最佳。

  • 该实验表明高频和低频信息在自注意力中的合理分配对模型效率和性能有重要影响。

 

  • 该图通过Fast Fourier Transform(FFT)可视化了Hi-Fi和Lo-Fi注意力机制输出特征中的频率成分。结果显示,Hi-Fi注意力捕捉更多的高频信息,而Lo-Fi主要关注低频信息。

  • 该实验验证了HiLo注意力机制在分离高频和低频特征时的有效性,符合文献提出的设计理念。

即插即用模块HiLo

import math
import torch
import torch.nn as nn
# 论文:Fast Vision Transformers with HiLo Attention
# 论文地址:https://arxiv.org/abs/2205.13213
class HiLo(nn.Module):

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=2, alpha=0.5):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        head_dim = int(dim/num_heads)
        self.dim = dim

        # self-attention heads in Lo-Fi
        self.l_heads = int(num_heads * alpha)
        # token dimension in Lo-Fi
        self.l_dim = self.l_heads * head_dim

        # self-attention heads in Hi-Fi
        self.h_heads = num_heads - self.l_heads
        # token dimension in Hi-Fi
        self.h_dim = self.h_heads * head_dim

        # local window size. The `s` in our paper.
        self.ws = window_size

        if self.ws == 1:
            # ws == 1 is equal to a standard multi-head self-attention
            self.h_heads = 0
            self.h_dim = 0
            self.l_heads = num_heads
            self.l_dim = dim

        self.scale = qk_scale or head_dim ** -0.5

        # Low frequence attention (Lo-Fi)
        if self.l_heads > 0:
            if self.ws != 1:
                self.sr = nn.AvgPool2d(kernel_size=window_size, stride=window_size)
            self.l_q = nn.Linear(self.dim, self.l_dim, bias=qkv_bias)
            self.l_kv = nn.Linear(self.dim, self.l_dim * 2, bias=qkv_bias)
            self.l_proj = nn.Linear(self.l_dim, self.l_dim)

        # High frequence attention (Hi-Fi)
        if self.h_heads > 0:
            self.h_qkv = nn.Linear(self.dim, self.h_dim * 3, bias=qkv_bias)
            self.h_proj = nn.Linear(self.h_dim, self.h_dim)

    def hifi(self, x):
        B, H, W, C = x.shape
        h_group, w_group = H // self.ws, W // self.ws

        total_groups = h_group * w_group

        x = x.reshape(B, h_group, self.ws, w_group, self.ws, C).transpose(2, 3)

        qkv = self.h_qkv(x).reshape(B, total_groups, -1, 3, self.h_heads, self.h_dim // self.h_heads).permute(3, 0, 1, 4, 2, 5)
        q, k, v = qkv[0], qkv[1], qkv[2] # B, hw, n_head, ws*ws, head_dim

        attn = (q @ k.transpose(-2, -1)) * self.scale # B, hw, n_head, ws*ws, ws*ws
        attn = attn.softmax(dim=-1)
        attn = (attn @ v).transpose(2, 3).reshape(B, h_group, w_group, self.ws, self.ws, self.h_dim)
        x = attn.transpose(2, 3).reshape(B, h_group * self.ws, w_group * self.ws, self.h_dim)

        x = self.h_proj(x)
        return x

    def lofi(self, x):
        B, H, W, C = x.shape

        q = self.l_q(x).reshape(B, H * W, self.l_heads, self.l_dim // self.l_heads).permute(0, 2, 1, 3)

        if self.ws > 1:
            x_ = x.permute(0, 3, 1, 2)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            kv = self.l_kv(x_).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.l_kv(x).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.l_dim)
        x = self.l_proj(x)
        return x

    def forward(self, x, H, W):
        B, N, C = x.shape

        x = x.reshape(B, H, W, C)

        if self.h_heads == 0:
            x = self.lofi(x)
            return x.reshape(B, N, C)

        if self.l_heads == 0:
            x = self.hifi(x)
            return x.reshape(B, N, C)

        hifi_out = self.hifi(x)
        lofi_out = self.lofi(x)

        x = torch.cat((hifi_out, lofi_out), dim=-1)
        x = x.reshape(B, N, C)

        return x

    def flops(self, H, W):
        # pad the feature map when the height and width cannot be divided by window size
        Hp = self.ws * math.ceil(H / self.ws)
        Wp = self.ws * math.ceil(W / self.ws)

        Np = Hp * Wp

        # For Hi-Fi
        # qkv
        hifi_flops = Np * self.dim * self.h_dim * 3
        nW = (Hp // self.ws) * (Wp // self.ws)
        window_len = self.ws * self.ws
        # q @ k and attn @ v
        window_flops = window_len * window_len * self.h_dim * 2
        hifi_flops += nW * window_flops
        # projection
        hifi_flops += Np * self.h_dim * self.h_dim

        # for Lo-Fi
        # q
        lofi_flops = Np * self.dim * self.l_dim
        kv_len = (Hp // self.ws) * (Wp // self.ws)
        # k, v
        lofi_flops += kv_len * self.dim * self.l_dim * 2
        # q @ k and attn @ v
        lofi_flops += Np * self.l_dim * kv_len * 2
        # projection
        lofi_flops += Np * self.l_dim * self.l_dim

        return hifi_flops + lofi_flops

if __name__ == '__main__':
    block = HiLo(dim=128)
    input = torch.rand(32, 128, 128) # input with shape (B, N, C)
    output = block(input, 16, 8) # H = 16, W = 8, since H * W should equal N
    print(input.size())
    print(output.size())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2212780.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构 ——— 顺序表oj题:有效的括号

目录 题目要求 代码实现 题目要求 给定一个只包括 (,),{,},[,] 的字符串 s ,判断字符串是否有效 有效字符串需满足: 左括号必须用相同类型的右括号闭合。左括号必须以正确的顺序闭合。每个…

深入解析网络流量回溯分析:如何有效进行网络故障排除

目录 什么是网络流量回溯分析? 网络流量回溯分析的核心优势 网络流量回溯分析如何助力网络故障排除? 1. 快速定位故障节点 真实案例:解决网络延迟问题 2. 精准分析流量异常 真实案例:识别恶意流量 3. 优化网络性能 为什么…

【Linux指令策】❤️基本必备指令❤️——打开Linux大门,带你快速上手Linux(超详细,收藏这一篇就足够啦~!!!)

【Linux入门】——基本指令 目录 一:认识操作系统 1.1:操作系统是什么? 1.2:操作系统 ——管理 1.3:操作系统——贯穿 二:Linux基本指令 2.1-指令学习(上篇) 2.1.1 > ls …

Chromium 前端form表单提交过程分析c++

一、本文以一个简单的 HTML 表单&#xff0c;包含两个文本输入框和一个提交按钮&#xff1a; <form action"demo_form.php">First name: <input type"text" name"fname"><br>Last name: <input type"text" name…

Unreal5从入门到精通之 如何使用事件分发器EventDispather

文章目录 前言1.创建事件分发器设置属性2.创建Bind、Unbind及Unbind All节点在蓝图类中创建在关卡蓝图中创建3.创建事件分发器事件节点4.调用事件分发器在蓝图类中进行调用在关卡蓝图中进行调用精彩推荐前言 事件分发器是 Unreal Engine(UE)中一个重要的概念,它负责在游戏运…

【C++】右值引用和移动语义(带你理解C++为何如此高效)

1.右值引用和移动语义 左值和右值的重点区分是能否取地址。 能取地址的是左值&#xff08;可以是值&#xff0c;也可以是表达式&#xff09;&#xff0c;不能取地址的是右值。 1.1 什么是左值 1.2 什么是右值 1.2.1 常见的右值 常见右值&#xff1a;常数&#xff08;10&…

【C/C++】速通某站上的经典“笔试”编程题

【C/C】速通某站上的经典“笔试”编程题 一. 题目描述&#xff1a;解题思路&#xff1a;代码实现&#xff1a; 二. 题目描述&#xff1a;解题思路&#xff1a;代码实现&#xff1a; 三. 题目描述&#xff1a;解题思路&#xff1a;代码实现&#xff1a; 一. 题目描述&#xff1a…

DS线性表之栈的讲解和实现(4)

文章目录 前言一、栈的概念及结构二、关于实现栈的分析关于栈顶指针top关于结构体栈的初始化入栈出栈获取栈顶元素获取栈元素个数判断栈是否为空栈的销毁 总结 前言 栈就是一个比较实用的数据结构了&#xff0c;且大致逻辑就是套用之前的两种线性表 具体选择哪种呢&#xff1f;…

综合布线研究实训室建设方案

1、 引言 随着信息技术的飞速发展&#xff0c;综合布线系统作为信息传输的基础设施&#xff0c;在各类建筑及信息化项目中发挥着越来越重要的作用。为了满足职业院校及企业对综合布线技术人才培养和研究的需求&#xff0c;本方案旨在建设一个集教学、实训、研究于一体的综合布…

ARM base instruction -- smull

有符号乘法运算 Signed Multiply Long multiplies two 32-bit register values, and writes the result to the 64-bit destination register. 将两个32位寄存器值相乘&#xff0c;并将结果写入64位目标寄存器。 64-bit variant SMULL <Xd>, <Wn>, <Wm>…

Linux破解root用户密码

在Linux启动菜单界面按【e】进入编辑启动菜单项 在LANGzh_CN.UTF-8&#xff08;或LANGen_US.UTF-8&#xff09;后面空出一格输入 rd.break consoletty0,再按【ctrlx】键启动Linux系统 以可读写的方式重新挂载文件系统 mount -o remount,rw /sysroot 改变根目录为/sysro…

Attention Is All You Need论文翻译

论文名称 注意力即是全部 论文地址 https://user.phil.hhu.de/~cwurm/wp-content/uploads/2020/01/7181-attention-is-all-you-need.pdf 摘要 主流的序列转导模型基于复杂的递归或卷积神经网络&#xff0c;这些网络包含编码器和解码器。性能最好的模型通过注意力机制将编码器和…

快速学习一个算法,Transformer模型架构

今天给大家分享一个超强的算法模型&#xff0c;Transformer Transformer 模型是目前自然语言处理&#xff08;NLP&#xff09;以及计算机视觉等领域中应用非常广泛的深度学习模型架构。 它由 Vaswani 等人在 2017 年的论文《Attention is All You Need》中提出&#xff0c;并…

【智能大数据分析 | 实验三】Storm实验:实时WordCountTopology

【作者主页】Francek Chen 【专栏介绍】 ⌈ ⌈ ⌈智能大数据分析 ⌋ ⌋ ⌋ 智能大数据分析是指利用先进的技术和算法对大规模数据进行深入分析和挖掘&#xff0c;以提取有价值的信息和洞察。它结合了大数据技术、人工智能&#xff08;AI&#xff09;、机器学习&#xff08;ML&a…

并查集的实现(朴素版)

这是C算法基础-数据结构专栏的第二十九篇文章&#xff0c;专栏详情请见此处。 由于作者即将参加CSP&#xff0c;所以到比赛结束前将不再发表文章&#xff01; 引入 并查集是一种可以快速合并查找集合的一种数据结构&#xff0c;这次我们将通过三道题来详细讲解并查集&#xff…

迈普pnsr2900x DOWNLOAD_FILE 任意文件读取漏洞

0x01 产品描述&#xff1a; ‌ 迈普NSR2900X系列是一款专为军队、政府、金融、中小型企业分支机构和中小型企业总部设计的信创接入路由器。‌ 该路由器采用国产核心元器件&#xff0c;基于国产操作系统运行迈普自主研发的网络操作系统及应用软件。它全面支持IPv4、IPv6、OS…

insert into values 语句优化

insert into values插入单行数据 SQL语句&#xff0c;insert into values插入单行数据&#xff0c;执行10万次&#xff0c;执行时间1279秒&#xff0c;优化总体执行耗时。 SQL文本&#xff0c;单行insert values&#xff0c;没有select部分。需要进一步分析执行过程消耗。 ins…

软考《信息系统运行管理员》- 5.1 信息系统数据资源维护体系

5.1 信息系统数据资源维护体系 文章目录 5.1 信息系统数据资源维护体系数据资源维护的管理对象数据资源维护的管理类型运行监控故障响应数据备份归档检索数据优化 数据资源维护的管理内容维护方案例行管理应急响应数据资源的开发与利用 数据是信息系统管理的对象与结果&#xf…

7-基于国产化FT-M6678+JFM7K325T的6U CPCI信号处理卡

一、板卡概述 本板卡系我公司自主研发&#xff0c;基于6U CPCI的通用高性能信号处理平台。板卡采用一片国产8核DSP FT-C6678和一片国产FPGA JFM7K325T-2FFG900作为主处理器。为您提供了丰富的运算资源。如下图所示&#xff1a; 二、设计参考标准 ● PCIMG 2.0 R3.0 CompactP…

Python酷库之旅-第三方库Pandas(147)

目录 一、用法精讲 666、pandas.Timestamp.astimezone方法 666-1、语法 666-2、参数 666-3、功能 666-4、返回值 666-5、说明 666-6、用法 666-6-1、数据准备 666-6-2、代码示例 666-6-3、结果输出 667、pandas.Timestamp.ceil方法 667-1、语法 667-2、参数 667…