笔记02----重新思考轻量化视觉Transformer中的局部感知CloFormer(即插即用)

news2025/1/13 5:50:20

1. 基本信息

  • 论文标题: 《Rethinking Local Perception in Lightweight Vision Transformer》
  • 中文标题: 《重新思考轻量化视觉Transformer中的局部感知》
  • 作者单位: 清华大学
  • 发表时间: 2023
  • 论文地址: https://arxiv.org/abs/2303.17803
  • 代码地址: https://github.com/qhfan/CloFormer

2. 应用场景

  • 图像分类、目标检测、语义分割等领域。

3. 研究背景

  • 现阶段,Transformer在图像分类、目标检测、语义分割等领域表现出优异的性能。然而Transformer参数量和计算量太大,不适合部署到移动设备。
  • 在现有的轻量级Transformer模型中,大多数方法只注重设计稀疏注意力以有效处理低频全局信息,而处理高频局部信息的方法相对简单。

4. 方法概述

为了同时利用共享权重和上下文感知权重的优势,提出了CloFormer,这是一种具有上下文感知局部增强功能的轻量级视觉转换器,具体贡献如下:

  1. CloFormer 中,引入了一种名为 AttnConv 的卷积算子,它采用注意力机制,充分利用共享权重和上下文感知权重的优势来实现局部感知。 此外,它使用了一种新方法,该方法结合了比普通局部自注意力更强的非线性来生成上下文感知权重。
  2. CloFormer 中,采用双分支架构,其中一个分支使用AttnConv 捕获高频信息,而另一个分支使用带有下采样的普通注意力捕获低频信息。 双分支结构使 CloFormer 能够同时捕获高频和低频信息。
  3. 该方法在图像分类、目标检测和语义分割方面的广泛实验证明了 CloFormer 的优越性。 CloFormerImageNet1k 上仅用 4.2M 参数和 0.6G FLOP 就实现了 77.0% 的准确率,明显优于其他模型。

4.1 整体网络结构

在这里插入图片描述

如上图所示,CloFormer包含一个卷积主干和四个阶段。每个阶段由Clo block和ConvFFN组成, 先通过卷积主干传递输入图像以获得tokens。 该系统由四个卷积组成,每个卷积的步幅分别为2、2、1和1。 随后,标记经过四个阶段的Clo block和ConvFFN来提取层次特征。 最后,利用全局平均池化和全连接层来生成预测。

  • ConvFFN

为了将局部信息整合到FFN过程中,用ConvFFN取代了传统的FFN。 ConvFFN和常用的FFN之间的主要区别是,ConvFFN在GELU激活后使用深度卷积(DWconv),这使得ConvFFN能够聚合局部信息。 由于DWconv,下行采样可以直接在ConvFFN中执行,而无需引入PatchMerge模块。 CloFormer使用了两种类型的ConvFFN。 第一种是级内ConvFFN,它直接利用跳过连接。 另一个是连接两个阶段的ConvFFN。 在这种类型的ConvFFN的跳过连接中,使用DWconv和全连接层分别对输入进行下采样和上维。

在这里插入图片描述

  • Clo block

每个块由一个本地分支和一个全局分支组成。 在全局分支中,首先对K和V进行下采样,然后对Q、K和V进行标准attention处理,提取低频全局信息。

在这里插入图片描述

4.2 AttnConv模块

全局分支的模式有效地减少了需要注意的flop的数量,也产生了一个全局接受野。 然而,它在有效捕获低频全局信息的同时,对高频局部信息的处理能力不足。
在AttnConv中,首先应用线性变换得到Q,K, V,这与标准注意力相同,在进行线性变换后,首先对V进行共享权值的局部特征聚合处理,然后基于处理后的V和Q, K进行上下文感知的局部增强。具体分为为如下三个步骤:

  • Local Feature Aggregation

使用一个简单的深度卷积(DWconv)来对 V 进行局部信息聚合。

  • Context-aware Local Enhancement

使用两个DWconv分别聚合Q和K的本地信息。 然后,计算Q和K的Hadamard积,并对结果进行一系列变换,以获得−1到1之间的上下文感知权重。 最后,利用生成的权值对局部特征进行增强。

  • Fusion with Global Branch

使用简单的方法将局部分支的输出与全局分支的输出融合。

4.3 代码

可以将Clo block当作注意力机制使用,具体代码如下:

import torch  
import torch.nn as nn  
from efficientnet_pytorch.model import MemoryEfficientSwish  # 从 EfficientNet 的库中引入高效激活函数 Swish  
class AttnMap(nn.Module):  
    def __init__(self, dim):  
        super().__init__()  
        # 定义一个包含两层卷积和激活函数的块,用于生成注意力映射  
        self.act_block = nn.Sequential(  
            nn.Conv2d(dim, dim, 1, 1, 0),  # 1x1 卷积,保持通道数不变  
            MemoryEfficientSwish(),       # Swish 激活函数  
            nn.Conv2d(dim, dim, 1, 1, 0)  # 再次使用 1x1 卷积  
        )  
  
    def forward(self, x):  
        return self.act_block(x)  # 前向传播,返回处理后的张量  
  
class CloAttention(nn.Module):  
    def __init__(self, dim, num_heads=8, group_split=[4, 4], kernel_sizes=[5], window_size=4,  
                 attn_drop=0., proj_drop=0., qkv_bias=True):  
        super().__init__()  
        # 参数初始化和断言检查  
        assert sum(group_split) == num_heads  # 确保分组的头总数等于注意力头总数  
        assert len(kernel_sizes) + 1 == len(group_split)  # 核大小和分组数一致  
  
        self.dim = dim  # 输入通道数  
        self.num_heads = num_heads  # 总的多头注意力头数  
        self.dim_head = dim // num_heads  # 每个头的通道数  
        self.scalor = self.dim_head ** -0.5  # 注意力缩放因子  
        self.kernel_sizes = kernel_sizes  # 高频分支的卷积核大小  
        self.window_size = window_size  # 低频分支窗口大小  
        self.group_split = group_split  # 每个分支分配的头数  
  
        # 创建高频和低频分支的模块  
        convs = []  # 高频卷积  
        act_blocks = []  # 高频注意力模块  
        qkvs = []  # 高频分支的 QKV 卷积  
  
        for i in range(len(kernel_sizes)):  
            kernel_size = kernel_sizes[i]  
            group_head = group_split[i]  
            if group_head == 0:  
                continue  # 如果分组头数为 0,跳过此分支  
            convs.append(nn.Conv2d(3 * self.dim_head * group_head, 3 * self.dim_head * group_head, kernel_size,  
                                   1, kernel_size // 2, groups=3 * self.dim_head * group_head))  # 高频卷积  
            act_blocks.append(AttnMap(self.dim_head * group_head))  # 注意力映射模块  
            qkvs.append(nn.Conv2d(dim, 3 * group_head * self.dim_head, 1, 1, 0, bias=qkv_bias))  # QKV 卷积  
  
        # 定义低频全局注意力分支  
        if group_split[-1] != 0:  
            self.global_q = nn.Conv2d(dim, group_split[-1] * self.dim_head, 1, 1, 0, bias=qkv_bias)  # Q 卷积  
            self.global_kv = nn.Conv2d(dim, group_split[-1] * self.dim_head * 2, 1, 1, 0, bias=qkv_bias)  # KV 卷积  
            self.avgpool = nn.AvgPool2d(window_size, window_size) if window_size != 1 else nn.Identity()  # 平均池化  
  
        # 将模块添加到 ModuleList 中  
        self.convs = nn.ModuleList(convs)  
        self.act_blocks = nn.ModuleList(act_blocks)  
        self.qkvs = nn.ModuleList(qkvs)  
        self.proj = nn.Conv2d(dim, dim, 1, 1, 0, bias=qkv_bias)  # 投影层  
        self.attn_drop = nn.Dropout(attn_drop)  # 注意力权重的 dropout        self.proj_drop = nn.Dropout(proj_drop)  # 输出的 dropout  
    def high_fre_attntion(self, x: torch.Tensor, to_qkv: nn.Module, mixer: nn.Module, attn_block: nn.Module):  
        '''  
        高频分支的注意力计算  
        x: (b c h w) 输入特征  
        '''        b, c, h, w = x.size()  
        qkv = to_qkv(x)  # 计算 QKV,得到 (b, 3*m*d, h, w)        qkv = mixer(qkv).reshape(b, 3, -1, h, w).transpose(0, 1).contiguous()  # 混合后得到 (3, b, m*d, h, w)        q, k, v = qkv  # 分解为 Q、K、V  
        attn = attn_block(q.mul(k)).mul(self.scalor)  # 计算缩放后的注意力  
        attn = self.attn_drop(torch.tanh(attn))  # 使用 tanh 激活并应用 dropout        res = attn.mul(v)  # 应用注意力权重到 V        return res  
  
  
    def low_fre_attention(self, x: torch.Tensor, to_q: nn.Module, to_kv: nn.Module, avgpool: nn.Module):  
        '''  
        低频分支的注意力计算  
        x: (b c h w) 输入特征  
        '''        b, c, h, w = x.size()  
        q = to_q(x).reshape(b, -1, self.dim_head, h * w).transpose(-1, -2).contiguous()  # 计算 Q 并调整形状为 (b, m, h*w, d)        kv = avgpool(x)  # 对输入特征进行平均池化  
        kv = to_kv(kv).view(b, 2, -1, self.dim_head, (h * w) // (self.window_size ** 2)).permute(1, 0, 2, 4, 3).contiguous()  # 计算 KV        k, v = kv  # 分解为 K、V  
        attn = self.scalor * q @ k.transpose(-1, -2)  # 计算缩放后的注意力  
        attn = self.attn_drop(attn.softmax(dim=-1))  # 对注意力进行 softmax 和 dropout        res = attn @ v  # 应用注意力权重到 V        res = res.transpose(2, 3).reshape(b, -1, h, w).contiguous()  # 调整形状为原始形状  
        return res  
  
  
    def forward(self, x: torch.Tensor):  
        '''  
        x: (b c h w) 输入特征  
        '''        res = []  # 保存各分支的输出  
        for i in range(len(self.kernel_sizes)):  
            if self.group_split[i] == 0:  
                continue  
            res.append(self.high_fre_attntion(x, self.qkvs[i], self.convs[i], self.act_blocks[i]))  # 高频分支输出  
        if self.group_split[-1] != 0:  
            res.append(self.low_fre_attention(x, self.global_q, self.global_kv, self.avgpool))  # 低频分支输出  
        return self.proj_drop(self.proj(torch.cat(res, dim=1)))  # 合并分支输出并应用投影  
  
  
# 输入 N C HW,  输出 N C H W
if __name__ == '__main__':  
    block = CloAttention(64).cuda()  # 初始化 CloAttention 模块  
    input = torch.rand(1, 64, 64, 64).cuda()  # 创建一个随机输入  
    output = block(input)  # 前向传播  
    print(f"Input_Size:{input.size()}\nOutput_Size:{output.size()}")  # 打印输入和输出的张量形状

5. 结果

表中报告了ImageNet1K分类结果。 结果表明,当模型大小和FLOPs相似时,模型比以前的模型性能更好。 其中,CloFormer-XXS仅使用4.2万个参数和0.6G FLOPs, Top-1准确率达到77.0%,分别超过ShuffleNetV22x、MobileViT-XS和EdgeViT-XXS 1.6%、2.2%和2.6%

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2244559.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

光猫、路由器、交换机之连接使用(Connection and Usage of Optical Cats, Routers, and Switches)

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 本人主要分享计算机核心技…

群核科技首次公开“双核技术引擎”,发布多模态CAD大模型

11月20日,群核科技在杭州举办了第九届酷科技峰会。现场,群核科技首次正式介绍其技术底层核心:基于GPU高性能计算的物理世界模拟器。并对外公开了两大技术引擎:群核启真(渲染)引擎和群核矩阵(CAD…

目录背景缺少vscode右键打开选项

目录背景缺少vscode右键打开选项 1.打开右键管理 下载地址:https://wwyz.lanzoul.com/iZy9G2fl28uj 2.开始搜索框搜索vscode, 找到其源目录 3.目录背景里面, 加入vscode.exe 3.然后在目录背景下, 右键, code就可以打…

【已解决】“EndNote could not connect to the online sync service”问题的解决

本人不止一次在使用EndNote软件时遇到过“EndNote could not connect to the online sync service”这个问题。 过去遇到这个问题都是用这个方法来解决: 这个方法虽然能解决,但工程量太大,每次做完得歇半天身体才能缓过来。 后来再遇到该问…

Java小白成长记(创作笔记一)

目录 序言 思维导图 开发流程 新建SpringBoot并整合MybatisPlus 新建SpringBoot 整合MybatisPlus 统一结果封装 全局异常处理 引入数据库 序言 在一个充满阳光的早晨,一位对编程世界充满好奇的年轻人小小白,怀揣着梦想与激情,踏上了学习…

vue--响应式数据

1、建一个vue应用程序&#xff08;简约&#xff09; 引入外链式 <title>第一个Vue程序</title><script src"../vue.global.js"></script> </head> {{}}插值表达式 <body><!-- {{ }} 插值表达式, 可以将 Vue 实例中定义的…

网络安全,文明上网(2)加强网络安全意识

前言 在当今这个数据驱动的时代&#xff0c;对网络安全保持高度警觉已经成为每个人的基本要求。 网络安全意识&#xff1a;信息时代的必备防御 网络已经成为我们生活中不可或缺的一部分&#xff0c;信息技术的快速进步使得我们对网络的依赖性日益增强。然而&#xff0c;网络安全…

怎么做好白盒测试?

白盒测试 一、什么是白盒测试&#xff1f;二、白盒测试特点三、白盒测试的设计方法1、逻辑覆盖法1、测试设计方法—语句覆盖a、用例设计如下&#xff1a;b、语句覆盖的局限性 2、测试设计方法—判定覆盖a、测试用例如下&#xff1a;b、判定覆盖的局限性 3、测试设计方法—条件覆…

阻尼Newton方法-数值最优化方法-课程学习笔记-5

这篇文章我们继续来学习数值最优化方法第三章的后续内容 阻尼Newton方法 这一章我们以及在之前了解过了&#xff1a;最速下降法&#xff0c;基本Newton方法&#xff0c;这一节我们来了解阻尼newton方法 之前我们提到的基本Newton方法是以一固定步长和Newton方向进行迭代的&a…

力扣 只出现一次的数字-136

只出现一次的数字-136 class Solution { public:int singleNumber(vector<int>& nums) {//按位异或的规则是&#xff1a;两个二进制位相同时结果为0&#xff0c;不同时结果为1//具有自反性&#xff0c;两个二进制位相同时结果为0&#xff0c;一个数(a)和0按位异或的…

Vue.js 自定义指令:从零开始创建自己的指令

vue使用directive 前言vue2使用vue3使用 前言 关于使用自定义指令在官网中是这样描述的 vue2:对普通 DOM 元素进行底层操作&#xff0c;这时候就会用到自定义指令。 vue3:自定义指令主要是为了重用涉及普通元素的底层 DOM 访问的逻辑。 在 Vue.js 中使用自定义指令&#xf…

MySQL-关键字执行顺序

&#x1f496;简介 在MySQL中&#xff0c;SQL查询语句的执行遵循一定的逻辑顺序&#xff0c;即使这些关键字在SQL语句中的物理排列可能有所不同。 &#x1f31f;语句顺序 (8) SELECT (9) DISTINCT<select_list> (1) FROM <left_table> (3) <join_type> JO…

Odoo :免费且开源的农牧行业ERP管理系统

文 / 开源智造Odoo亚太金牌服务 引言 提供农牧企业数字化、智能化、无人化产品服务及全产业链高度协同的一体化解决方案&#xff0c;提升企业智慧种养、成本领先、产业互联的核心竞争力。 行业典型痛点 一、成本管理粗放&#xff0c;效率低、管控弱 产品研发过程缺少体系化…

labview记录系统所用月数和天数

在做项目时会遇到采集系统的记录&#xff0c;比如一个项目测试要跑很久这个时候就需要在软件系统上显示项目运行了多少天&#xff0c;从开始测试开始一共用了多少年多少月。 年的话还好计算只需要把年份减掉就可以了&#xff0c;相比之下月份和天数就比较难确定&#xff0c;一…

【C++笔记】list使用详解及模拟实现

前言 各位读者朋友们大家好&#xff01;上期我们讲了vector的使用以及底层的模拟实现&#xff0c;这期我们来讲list。 目录 前言一. list的介绍及使用1.1 list的介绍1.2 list的使用1.2.1 list的构造1.2.2 list iterator的使用1.2.3 list capacity1.2.4 list element access1.…

window 中安装 php 环境

window 中安装 php 环境 一、准备二、下载三、安装四、测试 一、准备 安装前需要安装 Apache &#xff0c;可以查看这篇博客。 二、下载 先到这里下载 这里选择版本为“VS16 x64 Thread Safe”&#xff0c;这个版本不要选择线程安全的&#xff0c;我试过&#xff0c;会缺少文…

【大模型】LLaMA: Open and Efficient Foundation Language Models

链接&#xff1a;https://arxiv.org/pdf/2302.13971 论文&#xff1a;LLaMA: Open and Efficient Foundation Language Models Introduction 规模和效果 7B to 65B&#xff0c;LLaMA-13B 超过 GPT-3 (175B)Motivation 如何最好地缩放特定训练计算预算的数据集和模型大小&…

流程图图解@RequestBody @RequestPart @RequestParam @ModelAttribute

RequestBody 只能用一次&#xff0c;因为只有一个请求体 #mermaid-svg-8WZfkzl0GPvOiNj3 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-8WZfkzl0GPvOiNj3 .error-icon{fill:#552222;}#mermaid-svg-8WZfkzl0GPvOiNj…

论文阅读--supervised learning with quantum enhanced feature spaces

简略摘要 量子算法实现计算加速的核心要素是通过可控纠缠和干涉利用指数级大的量子态空间。本文在超导处理器上提出并实验实现了两种量子算法。这两种方法的一个关键组成部分是使用量子态空间作为特征空间。只有在量子计算机上才能有效访问的量子增强特征空间的使用为量子优势提…

django+boostrap实现注册

一、django介绍 Django 是一个高级的 Python 网络框架&#xff0c;可以快速开发安全和可维护的网站。由经验丰富的开发者构建&#xff0c;Django 负责处理网站开发中麻烦的部分&#xff0c;因此你可以专注于编写应用程序&#xff0c;而无需重新开发。 它是免费和开源的&#x…