论文阅读【2】-SepViT: Separable Vision Transformer论文结构漫谈与Python实现测试

news2024/11/16 6:34:23

可分离卷积+ViT实现轻量级transformer结构

  • 1. 论文主要工作
    • 1.1 摘要内容
    • 1.2 写作动机(Motivations)
      • 1.2.1 Transformer Patch结构的巨大计算量问题
      • 1.2.2 Swin:针对计算量的优化
      • 1.2.3 Twins:针对边缘端部署优化
      • 1.2.4 Cswin:存在吞吐量问题
      • 1.2.5 提出 SepViT结构
  • 2. 轻量级模型:从AlexNet的分组卷积 到 MobileNet系列
  • 3. Pytorch的demo测试
    • 3.1 模型关键结构
    • 3.2 模型的变体结构
    • 3.2 程序-mnist数据集测试
    • 3.3 测试结果
  • 参考文献资料

前言: 这篇论文第一作者是Bytedance实习生Li Wei,电子科技大学一名学生。论文主体是轻量级ViT。本博客简单解读了该论文的主要工作,并基于Python+Pytorch,做了一个结构的简单测试试验分享。


1. 论文主要工作

1.1 摘要内容

Vision Transformers在一系列的视觉任务中取得了巨大的成功。然而,它们通常都需要大量的计算来实现高性能,这在部署在资源有限的设备上这是一个负担。

为了解决这些问题,作者受深度可分离卷积启发设计了深度可分离Vision Transformers,缩写为SepViT。SepViT通过一个深度可分离Self-Attention促进Window内部和Window之间的信息交互。并群欣设计了新的Window Token Embedding和分组Self-Attention方法,分别对计算成本可忽略的Window之间的注意力关系进行建模,并捕获多个Window的长期视觉依赖关系。

在各种基准测试任务上进行的大量实验表明,SepViT可以在准确性和延迟之间的权衡方面达到最先进的结果。其中,SepViT在ImageNet-1K分类上的准确率达到了84.0%,而延迟率降低了40%。在下游视觉任务中,SepViT在ADE20K语义分割任务达到50.4%的mIoU,基于RetinaNet的COCO目标检测任务达到47.5AP,基于Mask R-CNCN的48.7 box AP检测和分割任务实现43.9 mask AP。

1.2 写作动机(Motivations)

近年来,许多计算机视觉(CV)研究人员致力于设计面向CV的Vision Transformers,以超过卷积神经网络(CNNs)的性能。Vision Transformers具有较高的远距离依赖建模能力,在图像分类、语义分割、目标检测等多种视觉任务中取得了显著的效果。然而,强大的性能通常是以计算复杂度为代价的。

1.2.1 Transformer Patch结构的巨大计算量问题

最初,ViT首先将Transformer引入图像识别任务中。它将整个图像分割为几个Patches,并将每个Patch作为一个Token提供给Transformer。然而,由于计算效率低下的Self-Attention,基于Patch的Transformer很难部署。

1.2.2 Swin:针对计算量的优化

为了解决这个问题,Swin提出了基于Window的Self-Attention,限制了非重叠子Window中Self-Attention的计算。显然,基于Window的Self-Attention在很大程度上降低了复杂性,但构建Window间连接的Shift操作符给ONNX或TensorRT的部署带来了困难。

1.2.3 Twins:针对边缘端部署优化

Twins利用基于Window的Self-Attention和PVT的Spatial Reduction Attention,提出了空间可分离Self-Attention。尽管Twins是部署友好型的,并且具有出色的性能,但它的计算复杂度几乎没有降低。

1.2.4 Cswin:存在吞吐量问题

CSWin通过Cross-Shaped Window Self-Attention得到了最先进的性能,但吞吐量较低。

尽管在这些著名的Transformer中取得了不同程度的进展,但它最近的大部分成功都伴随着巨大的资源需求。

1.2.5 提出 SepViT结构

为了克服上述问题本文提出了一种高效的Transformer Backbone,称为可分离Vision Transformers (SepViT),它可以按顺序捕获局部和全局依赖。SepViT的一个关键设计元素是深度可分离的Self-Attention模块,如图2所示。

受MobileNet中深度可分卷积的启发重新设计了Self-Attention模块,并提出了深度可分离Self-Attention,它由Depthwise Self-Attention和Pointwise Self-Attention组成,分别对应于MobileNet中的Depthwise和PointWise卷积。Depthwise Self-Attention用于捕获每个Window内的局部特征,Pointwise Self-Attention用于构建Window间的连接,提高表达能力。

2. 轻量级模型:从AlexNet的分组卷积 到 MobileNet系列

2.2 轻量化模型
针对移动端视觉任务,提出了许多轻量级和移动端友好的卷积方案。其中,分组卷积是由AlexNet首次提出的分组卷积,它对特征映射进行分组并进行分布式训练。

那么,移动端友好卷积的代表性工作必须是具有深度可分离卷积的MobileNet。深度可分离卷积包括用于空间信息通信的Depthwise卷积和用于跨通道信息交换的Pointwise卷积。随着时间的推移,许多基于上述工作的变体被开发出来。从v1堆叠可分离卷积结构到v2引入linear bottleneck与inverted residual,再到v3的结构搜索,轻量级模型似乎已经发展到了尽头。


而这篇论文的工作中,作者将深度可分离卷积的思想应用到Transformer中【1】,旨在在不牺牲性能的情况下降低Transformer的计算复杂度。

3. Pytorch的demo测试

3.1 模型关键结构

如下图所示,SepViT遵循了广泛使用的层次体系结构和基于Window的Self-Attention。

此外,SepViT还采用了条件位置编码(CPE)。对于每个阶段,都有一个重叠的Patch合并层用于特征图降采样,然后是一系列的SepViT Block。空间分辨率将以stride=4步或stride=2步逐步进行下采样,最终达到32倍下采样,通道尺寸也逐步增加一倍。

值得注意的是,局部上下文和全局信息都可以在单个SepViT Block中捕获,而其他工作应该使用2个连续的Block来完成这种局部-全局建模。

在SepViT块中,每个Window内的局部信息通信是通过DepthWise Self-Attention(DWA)实现的,Window间的全局信息交换是通过PointWise Self-Attention(PWA)进行。

3.2 模型的变体结构

在这里插入图片描述

此外,还设计了SepViT-Lite变体与一个非常轻的模型尺寸。SepViT变体的具体配置如表1所示,由于SepViT的效率更高,因此在某些阶段,SepViT的Embedding块深度比竞争对手要小。

DSSA和GSA分别表示具有Depthwise Separable Self-Attention和Grouped Self-Attention Block。此外,每个MLP层的扩展比设置为4,在所有SepViT变体中,DSSA和GSA的Window Sizes分别为7×7和14×14。

在这里插入图片描述


PS: 但是,Github代码这个库的所有者提到了一个问题:它只将SepViT的版本与这个特定的自我注意层结合起来,因为他觉得分组的注意层既不显著也不新颖,作者也不清楚他们是如何处理组自我注意层的窗口标记的。此外,似乎仅凭DSSA层,他们就能够击败Swin。

3.2 程序-mnist数据集测试

包安装

pip install vit-pytorch -i http://pypi.douban.com/simple --trusted-host pypi.douban.com

demo程序测试

from torch.optim.lr_scheduler import StepLR
from vit_pytorch.sep_vit import SepViT
import torch
from torch import nn
from torch import optim
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import torchvision
'''
1. sep_vit.py模型中定义的是3通道数据张量数据输入,现在修改为了单通道的数据输入
2. window_size参数修改为1
'''
##########################################################################################################
# Training Hyper-parameters settings
batch_size = 49
epochs = 20
lr = 3e-5
gamma = 0.7
seed = 42
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# inputs, labels = inputs.to(device), labels.to(device)
###########################################################################################################
# 模型定义
model = SepViT(
    num_classes=10,
    dim=32,               # dimensions of first stage, which doubles every stage (32, 64, 128, 256) for SepViT-Lite
    dim_head=32,          # attention head dimension
    heads=(1, 2, 4, 8),   # number of heads per stage
    depth=(1, 2, 6, 2),   # number of transformer blocks per stage
    window_size=1,        # window size of DSS Attention block (报错:height 64 and width 64 must be divisible by window size 7)
    dropout=0.1           # dropout
)

# loss function
loss_func = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
######################################################################################################################
# 其余超参数设置
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
# model.to(device)
model.eval()
############################################################################################################
# Mnist数据加载
train_data = datasets.MNIST(
    root = 'data',
    train = True,
    transform = torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]),
    download = True,
)
test_data = datasets.MNIST(
    root = 'data',
    train = False,
    transform = torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))


loaders = {
    'train': torch.utils.data.DataLoader(train_data,
                                         batch_size=100,
                                         shuffle=True,
                                         num_workers=1),

    'test': torch.utils.data.DataLoader(test_data,
                                        batch_size=100,
                                        shuffle=True,
                                        num_workers=1),
}
#############################################################################################################
# 模型训练
from torch.autograd import Variable
num_epochs = 10

def train(num_epochs, cnn, loaders):
    cnn.train()

    # Train the model
    total_step = len(loaders['train'])

    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(loaders['train']):
            # gives batch data, normalize x when iterate train_loader
            b_x = Variable(images)  # batch x
            b_y = Variable(labels)  # batch y

            output = model(b_x)
            loss = loss_func(output, b_y)

            # clear gradients for this training step
            optimizer.zero_grad()

            # backpropagation, compute gradients
            loss.backward()
            # apply gradients
            optimizer.step()

            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
            pass

        pass

    pass
if __name__ == '__main__':
    train(num_epochs, model, loaders)

3.3 测试结果

Epoch [1/10], Step [100/1875], Loss: 1.1673
Epoch [1/10], Step [200/1875], Loss: 0.4501
Epoch [1/10], Step [300/1875], Loss: 0.6198
Epoch [1/10], Step [400/1875], Loss: 0.3469
Epoch [1/10], Step [500/1875], Loss: 0.3964
Epoch [1/10], Step [600/1875], Loss: 0.4239
Epoch [1/10], Step [700/1875], Loss: 0.4740
Epoch [1/10], Step [800/1875], Loss: 0.3085
Epoch [1/10], Step [900/1875], Loss: 0.6178
Epoch [1/10], Step [1000/1875], Loss: 0.6015
Epoch [1/10], Step [1100/1875], Loss: 0.1161
Epoch [1/10], Step [1200/1875], Loss: 0.2946
Epoch [1/10], Step [1300/1875], Loss: 0.2689
Epoch [1/10], Step [1400/1875], Loss: 0.3746
Epoch [1/10], Step [1500/1875], Loss: 0.2356
Epoch [1/10], Step [1600/1875], Loss: 0.0904
Epoch [1/10], Step [1700/1875], Loss: 0.2226
...

参考文献资料

【1】SepViT: Separable Vision Transformer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/431391.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【c#串口通信(2)】串口相关参数介绍

1、端口号(Port) 我们使用一个串口的时候,首先是要打开这个串口,那么我们怎么知道电脑上现在支持几个串口呢?对应的端口号又是什么呢? 由于我的电脑系统是window11,下面就以window11为例介绍如…

部分国产水文水动力模型介绍

一、HydroMPM模型 1、模型介绍 2016年度自立项目HydroMPM系统开发与集成完成的洪水分析模拟软件等成果经权威专家鉴定整体达到国际领先水平,HydroMPM_FloodRisk入选国家防总《全国重点地区洪水风险图编制项目可选软件名录》。成果应用项目100余项,累计…

spring自定义命名空间

命名空间 如果你曾经在配置datasource是用过properties文件储存我们的数据库连接信息&#xff0c;那么一定在xml文件中配置过这样的语句。 <context:property-placeholder location"classpath:jdbc.properties"/>而我们的spring当中很明显是没有这个context的…

【git】git的一些基础操作

文章目录一.git下载二.git初次操作1.生成公钥2.修改全局用户名和邮箱地址&#xff1a;3.本地仓库关联远端仓库4.本地初始化5.将项目上所有的文件添加到本地仓库6.提交到本地仓库7.创建main分支8.推送到main分支三.git其他操作1.develop分支2.查看分支3.切换分支4.查看分支历史一…

python wannier90 基于wannier90的*_hr.dat文件选取截断hopping绘制能带图

我们知道wannier90可以根据选取TMDs的轨道信息生成详细的hopping energy *_hr.dat文件&#xff0c;选取所有的hopping绘制起来的时候比较简单&#xff0c;但是我们发现取几圈的近似hopping也可以将band表示出来&#xff0c;类似的思想有Pybinding的三带近似&#xff08;DOI: 10…

区块链技术在软件开发中的应用

如果你是一名软件开发者或者IT从业者&#xff0c;你一定已经听说过区块链技术。区块链是一种基于密码学的分布式账本技术&#xff0c;被广泛应用于数字货币、金融、物联网等领域。但是&#xff0c;除了这些领域之外&#xff0c;区块链技术还可以在软件开发中发挥重要作用。本文…

CLIP 论文解读

文章目录模型训练推理实验与Visual N-Grams 相比较分布Shift的鲁棒性不足参考现有的计算机视觉系统用来预测一组固定的预订对象类别&#xff0c;比如ImageNet数据集有1000类&#xff0c;CoCo数据集有80类。这种受限的监督形式限制了模型的通用性和可用性。使用这种方法训练好的…

《花雕学AI》02:人工智能挺麻利,十分钟就为我写了一篇长长的故事

ChatGPT最近火爆全网&#xff0c;上线短短两个多月&#xff0c;活跃用户就过亿了&#xff0c;刷新了历史最火应用记录&#xff0c;网上几乎每天也都是ChatGPT各种消息。国内用户由于无法直接访问ChatGPT&#xff0c;所以大部分用户都无缘体验。不过呢&#xff0c;前段时间微软正…

Nginx实现会话保持,集群模式下session域共享

前言 生产环境下&#xff0c;多数系统为了应对线上多种复杂情况而进行了集群架构的部署&#xff0c;保证系统的高性能、价格有效性、可伸缩性、高可用性等。通常将生产环境下的域名指向Nginx服务&#xff0c;通过它做HTTP协议的Web负载均衡。 session是什么 在计算机中&…

13.广度优先搜索

一、算法内容 1.简介 广度优先搜索BFS&#xff08;Breadth First Search&#xff09;按照广度优先的方式进行搜索&#xff0c;可以理解为“尝试所有下一步可能”地穷举所有可行的方案&#xff0c;并不断尝试&#xff0c;直到找到一种情况满足问题问题的要求。 BFS从起点开始…

C语言——学生信息管理系统(数组)

文章目录一、前言二、目的三、框架1.菜单1.1主菜单1.2子菜单2.流程图2.1总流程图2.2开始流程图2.3增加学生信息流程图2.4.删除学生信息流程图2.5修改学生信息流程图2.6查询学生信息流程图2.7对学生信息排序流程图3.思路四、代码五、演示视频一、前言 因为最近是在赶进度总结&a…

无人驾驶--工控机安装autoware

时隔好久&#xff0c;又来写文章了&#xff0c;这次有高人指点&#xff0c;要系统的学习一下无人驾驶了。 使用的是易咖的底盘车&#xff0c;工控机是米文动力Apex Xavier II&#xff0c;基于autoware框架 首先是在工控机上安装autoware&#xff0c;工控是ubuntu18环境。 参…

Python入门教程+项目实战-9.2节: 字符串的操作符

目录 9.2.1 字符串常用操作符 9.2.2 操作符&#xff1a;拼接字符串 9.2.3 *操作符&#xff1a;字符串的乘法 9.2.4 []操作符&#xff1a;索引访问 9.2.5 [:]操作符&#xff1a;分片字符串 9.2.6 in操作符&#xff1a;查找子串 9.2.7 %操作符&#xff1a;格式化字符串 9…

为什么要做软件测试

随着信息技术的发展和普及&#xff0c;人们对软件的使用越来越普及。但是在软件的使用过程中&#xff0c;软件的效果却不尽如人意。为了确保软件的质量&#xff0c;整个软件业界已经逐渐意识到测试的重要性&#xff0c;软件测试已经成为IT 领域的黄金行业。本篇文章将会带领大家…

使用Tensorboard多超参数随机搜索训练

文章目录1超参数训练代码2远端电脑启动tensorboard完整代码位置https://gitee.com/chuge325/base_machinelearning.git 这里还参考了tensorflow的官方文档 但是由于是pytorch训练的差别还是比较大的&#xff0c;经过多次尝试完成了训练 硬件是两张v100 1超参数训练代码 这个…

Android Studio升级Gradle Plugin升级导致项目运行失败问题

背景&错误 升级Android Studio 旧项目无法运行&#xff0c;奇奇怪怪什么错误都有 例如&#xff1a; java.lang.IllegalAccessError: class org.gradle.api.internal.tasks.compile.processing.AggregatingProcessingStrategy (in unnamed module 0x390ea9fb) cannot acce…

传智健康-day2

一.需求分析(预约管理功能开发) 预约管理功能&#xff0c;包括检查项管理、检查组管理、体检套餐管理、预约设置等、预约管理属于系统的基础功能&#xff0c;主要就是管理一些体检的基础数据。 检查组是检查项的集合 二.基础环境搭建 1导入预约管理模块数据表 需要用到的…

Ubuntu安装MySQL及常用操作

一、安装MySQL 使用以下命令即可进行mysql安装&#xff0c;注意安装前先更新一下软件源以获得最新版本&#xff1a; sudo apt-get update #更新软件源 sudo apt-get install mysql-server #安装mysql 上述命令会安装以下包&#xff1a; apparmor mysql-client-5.7 mysql-c…

不定期更新:我对 ChatGPT 进行多方位了解后的报告,超级全面,建议想了解的朋友看看

优质介绍视频&#xff1a; GPT4前端【AI编程新纪元】 【渐构】万字科普GPT4为何会颠覆现有工作流&#xff1b;为何你要关注微软Copilot、文心一言等大模型 此文章不定期更新&#xff08;一周应该会更新一次&#xff09; 最近一次更新&#xff1a;2023.4.16 12:00 ChatGPT 是什…

零基础搭建私人影音媒体平台【远程访问Jellyfin播放器】

文章目录1. 前言2. Jellyfin服务网站搭建2.1. Jellyfin下载和安装2.2. Jellyfin网页测试3.本地网页发布3.1 cpolar的安装和注册3.2 Cpolar云端设置3.3 Cpolar本地设置4.公网访问测试5. 结语1. 前言 随着移动智能设备的普及&#xff0c;各种各样的使用需求也被开发出来&#xf…