CVPR 2023 | 主干网络FasterNet 核心解读 代码分析

news2024/11/19 7:43:48

本文分享来自CVPR 2023的论文,提出了一种快速的主干网络,名为FasterNet

论文提出了一种新的卷积算子,partial convolution,部分卷积(PConv),通过减少冗余计算内存访问来更有效地提取空间特征。

创新在于部分卷积(PConv),它选择一部分通道的特性进行常规卷积剩余部分通道的特性保持不变,降低了计算复杂度,从而实现了快速高效的神经网络。

区别于常规卷积:PConv只对输入通道的一部分应用卷积,而保留其余部分不变。

论文地址:Run, Don’t Walk: Chasing Higher FLOPS for Faster Neural Networks

代码地址:https://github.com/JierunChen/FasterNet/tree/master

目录

一、PConv算子设计原理

二、PConv算子的代码解析 

三、FasterNet模型原理

四、FasterNet模型测试

五、实验分析


背景:

  • MobileNet、ShuffleNet和GhostNet等利用深度卷积(DWConv)或 组卷积(GConv)来提取空间特征。
  • 然而,在减少FLOPs的过程中,算子经常会受到内存访问增加的副作用的影响
  • MicroNet进一步分解和稀疏网络,将其FLOPs推至极低水平。尽管这种方法在FLOPs方面有所改进,但其碎片计算效率很低。
  • 上述网络通常伴随着额外的数据操作,如级联、Shuffle和池化这些操作的运行时间对于小型模型来说往往很重要

一、PConv算子设计原理

 1、这种部分卷积的核心思想对输入特征图的部分通道应用卷积操作而保留其他通道不变。这种操作可以有效地减少计算冗余,提高计算效率。

对于连续或规则的内存访问,将第一个或最后一个连续的通道视为整个特征图的代表进行计算。

在不丧失一般性的情况下认为输入和输出特征图具有相同数量的通道

设计原因

通过利用特征图的冗余度可以进一步优化成本。

如下图所示,特征图在不同通道之间具有高度相似性。许多其他著作也涵盖了这种冗余,但很少有人以简单而有效的方式充分利用它。

于是出了PConv,对输入特征图的部分通道应用卷积操作而保留其他通道不变,同时减少计算冗余和内存访问。

2、为了充分有效地利用来自所有通道的信息,进一步将逐点卷积(PWConv)附加到PConv

它们在输入特征图上的有效感受野看起来像一个T形Conv,与均匀处理补丁的常规Conv相比,它更专注于中心位置。

通过实验表明:中心位置是卷积操作中最常见的突出位置,即中心位置的权重比周围的更重。这与集中于中心位置的T形计算一致。

虽然T形卷积可以直接用于高效计算,但作者表明,将T形卷积分解为PConv和PWConv更好,因为该分解利用了卷积操作间冗余并进一步节省了FLOPs。

二、PConv算子的代码解析 

PConv算子的代码:

'''
输入三个参数:dim(输入特征图的通道数),n_div(分割的组数)和forward(前向传播的方法)
输出:卷积后的特征图
'''
class Partial_conv3(nn.Module):
    def __init__(self, dim, n_div, forward):
        super().__init__()
        self.dim_conv3 = dim // n_div # 计算出卷积部分的通道数
        self.dim_untouched = dim - self.dim_conv3 # 计算出不需要卷积部分的通道数

        # 定义一个3*3卷积,输入通道数为self.dim_conv3,输出通道数也为self.dim_conv3,步长为1,填充为1,且不使用bias。
        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)

        if forward == 'slicing':
            self.forward = self.forward_slicing
        elif forward == 'split_cat':
            self.forward = self.forward_split_cat
        else:
            raise NotImplementedError

    # 只适合推理
    def forward_slicing(self, x: Tensor) -> Tensor:
        # 对输入x进行深拷贝,以保持原始输入的完整性。后面的操作不会改变原始输入x。
        x = x.clone()   
        # 对输入x中前self.dim_conv3个通道应用卷积操作,并将结果保存回x中对应的位置。
        x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])
        return x

    # 适合训练/推理
    def forward_split_cat(self, x: Tensor) -> Tensor:
        # 使用torch.split函数将输入x沿着通道维度(即第1维,索引从0开始)分割成两个部分,
        # 分别为x1和x2。分割的长度为[self.dim_conv3, self.dim_untouched],
        # 即分割后的x1的通道数为self.dim_conv3,x2的通道数为self.dim_untouched。
        x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
        x1 = self.partial_conv3(x1)
        x = torch.cat((x1, x2), 1)
        return x

这段代码定义了一个名为 Partial_conv3 的 PyTorch 模块,它是nn.Module的子类。这个模块主要实现了一种部分卷积(Partial Convolution); 

这种部分卷积的核心思想对输入特征图的部分通道应用卷积操作而保留其他通道不变。这种操作可以有效地减少计算冗余,提高计算效率。

方式1:slicing

 # 只适合推理
    def forward_slicing(self, x: Tensor) -> Tensor:
        # 对输入x进行深拷贝,以保持原始输入的完整性。后面的操作不会改变原始输入x。
        x = x.clone()   
        # 对输入x中前self.dim_conv3个通道应用卷积操作,并将结果保存回x中对应的位置。
        x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])
        return x

方式2:split_cat

    # 适合训练/推理
    def forward_split_cat(self, x: Tensor) -> Tensor:
        # 使用torch.split函数将输入x沿着通道维度(即第1维,索引从0开始)分割成两个部分,
        # 分别为x1和x2。分割的长度为[self.dim_conv3, self.dim_untouched],
        # 即分割后的x1的通道数为self.dim_conv3,x2的通道数为self.dim_untouched。
        x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
        x1 = self.partial_conv3(x1)
        x = torch.cat((x1, x2), 1)
        return x

三、FasterNet模型原理

基于部分卷积算子PConv逐点卷积PWConv,作为主要的算子,进一步提出FasterNet。

这是一个新的神经网络家族,运行速度非常快,对许多视觉任务有效。模型架构如下:

它有4个层次级,每个层次级前面都有一个嵌入层(步长为4的常规4×4卷积)或一个合并层(步长为2的常规2×2卷积),用于空间下采样和通道数量扩展。每个阶段都有一堆FasterNet块。

每个FasterNet块有一个PConv层,后跟2个PWConv(或Conv 1×1)层。它们一起显示为倒置残差块,其中中间层具有扩展的通道数量,并且放置了Shorcut以重用输入特征。

最后两个阶段中的块消耗更少的内存访问,并且倾向于具有更高的FLOPS,因此,放置了更多FasterNet块,并相应地将更多计算分配给最后两个阶段。

补充一下标准化和激活层

标准化和激活层对于高性能神经网络也是不可或缺的。

然而,许多先前的工作在整个网络中过度使用这些层,这可能会限制特征多样性,从而损害性能。它还可以降低整体计算速度。

相比之下,只将它们放在每个中间PWConv之后,以保持特征多样性并实现较低的延迟。

四、FasterNet模型测试

使用默认的参数构建FasterNet

        mlp_ratio=2.0,

        embed_dim=96,

        depths=(1, 2, 8, 2),

        drop_path_rate=0.10,

看一下的模型参数 :

感觉模型也不小的。。。。。。。

测试代码分享给大家(代码存放路径:models/model_summary.py)

import torch.nn as nn
from fasternet import FasterNet
from torchsummary import summary

# 默认参数
def fasternet(**kwargs):
    model = FasterNet(**kwargs)
    return model

# S
def fasternet_s(**kwargs):
    model = FasterNet(
        mlp_ratio=2.0,
        embed_dim=128,
        depths=(1, 2, 13, 2),
        drop_path_rate=0.15,
        act_layer='RELU',
        fork_feat=True,
        **kwargs
        )

    return model

# M
def fasternet_m(**kwargs):
    model = FasterNet(
        mlp_ratio=2.0,
        embed_dim=144,
        depths=(3, 4, 18, 3),
        drop_path_rate=0.2,
        act_layer='RELU',
        fork_feat=True,
        **kwargs
        )

    return model

# L
def fasternet_l(**kwargs):
    model = FasterNet(
        mlp_ratio=2.0,
        embed_dim=192,
        depths=(3, 4, 18, 3),
        drop_path_rate=0.3,
        act_layer='RELU',
        fork_feat=True,
        **kwargs
        )

    return model

print("fasternet:", fasternet)
model = fasternet()
summary(model, input_size=(3, 224, 224))


print("fasternet_s:", fasternet_s)
model = fasternet_s()
summary(model, input_size=(3, 224, 224))


print("fasternet_m:", fasternet_m)
model = fasternet_m()
summary(model, input_size=(3, 224, 224))


print("fasternet_l:", fasternet_l)
model = fasternet_l()
summary(model, input_size=(3, 224, 224))

github有各个版本的预训练模型,大家可以测试一下。

nameresolutionacc#paramsFLOPsmodel
FasterNet-T0224x22471.93.9M0.34Gmodel
FasterNet-T1224x22476.27.6M0.85Gmodel
FasterNet-T2224x22478.915.0M1.90Gmodel
FasterNet-S224x22481.331.1M4.55Gmodel
FasterNet-M224x22483.053.5M8.72Gmodel
FasterNet-L224x22483.593.4M15.49Gmodel

官方给的数据:

五、实验分析

FasterNet在不同设备(CPU、GPU、ARM),精度-吞吐量和精度-延迟权衡方面具有最高的效率。

图像分类中,比较ImageNet-1k基准。具有类似TOP-1精度的模型被组合在一起。除MobileViT和EdgeNeXt的分辨率为256×256外,所有型号的分辨率均为224×224。OOM是内存不足的缩写。

关于COCO目标检测实例分割基准的结果,Flop是根据图像大小(1280,800)计算的。

分享完成~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1155300.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

xhadmin多应用SaaS框架怎么更新?

xhadmin是什么? xhadmin 是一套基于最新技术的研发的多应用 Saas 框架,支持在线升级和安装模块及模板,拥有良好的开发框架、成熟稳定的技术解决方案、提供丰富的扩展功能。为开发者赋能,助力企业发展、国家富强,致力于…

【设计模式】第13节:结构型模式之“享元模式”

一、简介 所谓“享元”,顾名思义就是被共享的单元。享元模式的意图是复用对象,节省内存,前提是享元对象是不可变对象。 实现:通过工厂模式,在工厂类中,通过一个Map或者List来缓存已经创建好的享元对象&am…

这样的软件测试报告模板你绝对没见过!!!

测试报告如此重要,那么我们应该如何撰写呢?为了让大家彻底掌握测试模板的撰写,所以本文结构如下: 1、测试报告写给谁看? 2、测试报告的基本骨架(通过|不通过)? 3、测试报告如何才能达…

超级搜索技术,普通人变强的唯一外挂

搜索效率:Google >微信公众号 >短视频 >百度 1、信息咨询搜索 在Google搜索栏前面加上 “” 限定关键词 intitle 限定标题 allintitle 限定标题多个关键词 intext 限定内容关键词 inurl 限定网址关键词 site 限定网址来源 imagesize 限定图片尺寸 filet…

[LeetCode]-27. 移除元素-26.删除有序数组中的重复项-88.合并两个有序数组

目录 27.移除元素 题目 思路 代码 26. 删除有序数组中的重复项 题目 思路 代码 88.合并两个有序数组 题目 思路 代码 总结 27.移除元素 27. 移除元素 - 力扣(LeetCode)https://leetcode.cn/problems/remove-element/description/ 题目 给你一…

【快报】正在把教学视频搬运到B站和油管

hello 大家好,我是老戴。 熟悉我的同学知道,我从14年开始录制GIS相关的教学视频,之前是放到优酷上给大家下载,后期发现很多人把视频弄下来淘宝上卖,然后我就把视频整体放到了我自己的网站上。 随着视频录制的数量越来…

C++归并排序算法的应用:计算右侧小于当前元素的个数

题目 给你一个整数数组 nums ,按要求返回一个新数组 counts 。数组 counts 有该性质: counts[i] 的值是 nums[i] 右侧小于 nums[i] 的元素的数量。 示例 1: 输入:nums [5,2,6,1] 输出:[2,1,1,0] 解释: 5 …

2024年湖北武汉建筑企业三类人员安全员ABC怎么报考

2024年湖北武汉建筑企业三类人员安全员ABC怎么报考 武汉建筑企业报考三类人员,建筑单位归属地在武汉,且有建筑相关的一些资Z,才可以申报一定数量的三类人员、安全员ABC、建筑安全员ABC、专职安全员C证、建设厅安全员ABC证。 建筑企业-报考建…

在线开发平台是什么?有哪些优势?

目录 一、什么是在线开发平台? 二、企业为什么选择在线开发平台? (1)风险低,回报高 (2)可视化操作更形象 (3)易维护 三、在线开发平台功能展示 技术介绍 随着互联网和信息…

Jetpack:024-Jetpack中的滚动事件

文章目录 1. 概念介绍2. 使用方法2.1 高级事件2.2 低级事件 3. 示例代码4. 内容总结 我们在上一章回中介绍了Jetpack中事件相关的内容,本章回中主要 介绍事件中的滚动事件。闲话休提,让我们一起Talk Android Jetpack吧! 1. 概念介绍 我们在…

三相马达的电机故障维护

目录 电机故障维护​编辑 更换电机操作 三相电路 热继电器 今天继续小编的工作经验的分享,今天就说说遇到的问题吧,今天组立熔接机出现故障,后面部分出现了“咕噜噜”的杂声,走到后面一听是电机发出的声音。没有办法了就开始拆…

Py之transformers_stream_generator:transformers_stream_generator的简介、安装、使用方法之详细攻略

Py之transformers_stream_generator:transformers_stream_generator的简介、安装、使用方法之详细攻略 目录 transformers_stream_generator的简介 1、Web Demo T1、original T2、stream transformers_stream_generator的安装 transformers_stream_generator的…

【Linux虚拟机】 JDK、Tomcat、MySQL安装配置讲解

目录 一、上传安装包到服务器 二、JDK与Tomcat安装 2.1 解压安装包 2.2 配置JDK环境变量 2.3 配置Tomcat环境 三、MySQL安装配置 3.1 删除默认数据库 3.2 安装mysql安装包 3.3 mysql初始化操作 四、后端接口部署 4.1 导入项目.war 4.2 新建数据库 4.3 运行服务器项目…

白票某度自媒体混剪剪辑视频素材/爬虫软件说明文档

大家好,我是淘小白~ 软件:某度自媒体混剪素材爬虫软件 语言:Python 说明文档: 1、自定义关键词采集 2、采集百度aigc视频素材,经过测试,使用剪映的文字成片某度视频素材,可过头条的原创检测…

SPSS单样本t检验

前言: 本专栏参考教材为《SPSS22.0从入门到精通》,由于软件版本原因,部分内容有所改变,为适应软件版本的变化,特此创作此专栏便于大家学习。本专栏使用软件为:SPSS25.0 本专栏所有的数据文件请点击此链接下…

OSFP基础实验

目录 题目:拓扑如下 实验步骤: 第一步:设计思路 第二步:搭建拓扑 第三步:配置命令 1)IP地址配置 2)OSPF配置 3)R3部分接口做静默接口 4)缺省路由 5&#xff09…

数据结构之“初窥门径”

目录 前言: 一,数据结构起源 二,基本概念和术语 2.1数据 2.2数据元素 2.3数据项 2.4数据对象 2.5数据结构 三,逻辑结构与物理结构 3.1逻辑结构 3.1.1集合结构 3.1.2线性结构 3.1.3树形结构 3.1.4图形结构 3.2物理结…

unity中meta文件GUID异常问题

错误信息: The .meta file Assets/Scripts/Editor/ConvertConfigToBinary/TxtConverter.cs.meta does not have a valid GUID and its corresponding Asset file will be ignored. If this file is not malformed, please add a GUID, or delete the .meta file and…

深度了解msvcr110.dll丢失的5个解决方法以及原因

在计算机使用过程中,我们经常会遇到一些错误提示,其中之一就是“msvcr110.dll丢失”。这个错误提示通常出现在运行某些程序时,它意味着计算机无法找到所需的动态链接库文件。本文将详细介绍msvcr110.dll丢失的原因以及5个解决方法。 一、msvc…

[Unity+智谱AI开放平台]调用ChatGLM Tuobo模型驱动AI小姐姐数字人

1.简述 本篇文章主要介绍一下,在Unity端,集成智谱AI开放平台提供的chatglm模型api,实现AI聊天互动相关的功能。从智谱AI官方站点上看到,提供有chatglm turbo的公共模型服务,能够实现32K超长上下文,应用到我…