VanillaNet 原理与代码解读

news2024/10/6 8:35:20

paper:VanillaNet: the Power of Minimalism in Deep Learning

official implementation: GitHub - huawei-noah/VanillaNet

存在的问题 

虽然复杂网络的性能很好,但它们日益增加的复杂性给部署带来了挑战。例如,ResNets中的shortcut操作在合并不同层的特征时耗费了大量的off-chip memory traffic。再比如AS-MLP中的axial shift操作以及Swin Transformer中的shift window self-attention操作都需要复杂的工程实现,包括重写CUDA代码。

本文的创新点

本文提出了VanillaNet,一种新的神经网络架构,有着简单而优雅的设计,同时在视觉任务中保持了显著的性能。VanillaNet通过舍弃过多的深度、shortcut以及self-attention等复杂的操作,解决了复杂度的问题,非常适合资源有限的环境。

方法介绍

A Vanilla Neural Architecture

大多数SOTA分类网络的架构都包含三个部分:一个stem block将输入图片由3个通道转换为多通道并进行下采样,一个main body提取特征,一个全连接层用来输出分类结果。其中main body通常包含4个stages,每个stage堆叠多个相同的blocks,每个stage后特征图的分辨率降低通道数增加。不同网络的区别主要在于blocks的设计不同。

本文提出的VanillaNet也遵循这种流行的设计架构,不同的是,每个stage只包含一层网络层从而构建一个极度简洁的网络。

VanillaNet-6的结构如图(1)所示,具体包括:stem部分是一个stride=4的4x4x3xC的卷积层将3通道的输入图片映射为C通道的feature map。stage1,2,3中,用一个stride=2的maxpooling进行下采样同时通道数翻倍。在stage4中保持通道数不变,因为它后面是一个average pooling层。最后一个全连接层输出分类结果。为了使用最小的计算量所有的卷积层都是1x1大小,每个卷积层后跟一个BN层和一个激活函数。

尽管VanillaNet的结构简单且层数很少,但其较弱的非线性限制了其性能,接下来作者又提出了一系列方法来解决该问题。

Training of Vanilla Networks

Deep Training Strategy

深度训练策略的核心思想是在训练初期训练两个卷积层和一个激活函数而不是只训练一个卷积层。随着训练的进行,激活函数逐渐演变成一个恒等映射。在训练结束时,通过结构重参数化,可以将两个卷积层合并为一个,从而减少推理时间。

对于一个激活函数 \(A(x)\)(例如常见的ReLU和Tanh),我们将它和一个恒等映射结合起来,如下

其中 \(\lambda\) 是一个超参用来平衡修改后的激活函数 \(A'(x)\) 的非线性。假设当前的epoch和训练完整的epoch数分别为 \(e\) 和 \(E\),我们设置 \(\lambda =\frac{e}{E} \)。这样,在训练刚开始时 \(e=0,A'(x)=A(x)\),这意味网络具有很强的非线性。当训练收敛完成后,\(A'(x)=x\),这意味着两个卷积层中间没有激活函数了,我们就可以通过结构重参数化方法将其合并成一个卷积。

Series Informed Activation Function

简单网络和浅层网络的性能较差主要是因为较差的非线性。有两种提高网络非线性的方法:叠加非线性激活层和提高每个激活层的非线性能力。大多数网络选择前者,而本文选择后者,不过也是通过堆叠的方式。(这里文中说前者是serially stacking,而后者是concurrently stacking,个人理解应该都是连续堆叠,只不过前者通常是卷积层和激活函数一起堆叠才导致网络越来越深,而本文是只堆叠激活函数)。

具体是通过加权堆叠的方式,其中 \(n\) 表示堆叠的个数,\(a_{i},b_{i}\) 分别是每个激活函数的scale和bias。通过这种堆叠,激活函数的非线性能力可以大大提高。

式(5)可以看作是数学中的级数series,为了进一步提高series的近似能力,作者使series-based的函数能够通过改变来自neighbors的输入来学习全局信息,类似于BNET。具体对于一个输入特征 \(x\in\mathbb{R}^{H\times W\times C}\),激活函数可以表示为

其中 \(h\in\left \{ 1,2,...,H \right \} ,w\in\left \{ 1,2,...,W \right \} ,c\in\left \{ 1,2,...,C \right \} \)。

实验结果

VanillaNet的具体结构如表6

在ImageNet数据集上,和其它一些SOTA模型的对比如表4

 可以看出,VanillaNet仅用10层就取得了80.57的top-1精度,不同的层数下,和相同精度的其它模型相比,具有显著的速度优势。

代码解读

首先是深度训练策略,在models/vanillanet.py中,class VanillaNet()包含了网络的具体实现。其中self.deploy用来表示是否为推理阶段,当self.deploy=False时表示是训练阶段,可以看到stem阶段包含self.stem1self.stem2,main body阶段每个block都包含self.conv1self.conv2,最后的全连接层也包含self.cls1self.cls2。当训练完成后即推理阶段,所有的阶段的1和2之间的激活函数都变成了恒等映射或者说操作1和2之间没有激活函数了,然后通过结构重参数化将operation 1,2合并成1个。

其中self.act_learn即为式(1)中的 \(\lambda\),在main.py中,act_learn随着训练的进行而变化。

act_learn = epoch / args.decay_epochs * 1.0
model.module.change_act(act_learn)

接着是激活函数的堆叠,这里作者将激活函数的简单加权堆叠即式(5)演变为可以学习临近输入的式(6)后,可以通过深度卷积来实现,其中堆叠个数超参 \(n=3\),即代码中的self.act_num

# Series informed activation function. Implemented by conv.
class activation(nn.ReLU):
    def __init__(self, dim, act_num=3, deploy=False):
        super(activation, self).__init__()
        self.act_num = act_num
        self.deploy = deploy
        self.dim = dim
        self.weight = torch.nn.Parameter(torch.randn(dim, 1, act_num*2 + 1, act_num*2 + 1))
        if deploy:
            self.bias = torch.nn.Parameter(torch.zeros(dim))
        else:
            self.bias = None
            self.bn = nn.BatchNorm2d(dim, eps=1e-6)
        weight_init.trunc_normal_(self.weight, std=.02)

    def forward(self, x):
        if self.deploy:
            return torch.nn.functional.conv2d(
                super(activation, self).forward(x), 
                self.weight, self.bias, padding=self.act_num, groups=self.dim)
        else:
            return self.bn(torch.nn.functional.conv2d(
                super(activation, self).forward(x),
                self.weight, padding=self.act_num, groups=self.dim))

    def _fuse_bn_tensor(self, weight, bn):
        kernel = weight
        running_mean = bn.running_mean
        running_var = bn.running_var
        gamma = bn.weight
        beta = bn.bias
        eps = bn.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta + (0 - running_mean) * gamma / std

    def switch_to_deploy(self):
        kernel, bias = self._fuse_bn_tensor(self.weight, self.bn)
        self.weight.data = kernel
        self.bias = torch.nn.Parameter(torch.zeros(self.dim))
        self.bias.data = bias
        self.__delattr__('bn')
        self.deploy = True

疑问

在作者的官方解读卷积的尽头不是Transformer,极简架构潜力无限 - 知乎下也有评论指出,本来式(5)连续堆叠激活函数增加非线性的想法很好,但演变为式(6)后,就又还原成卷积了,激活函数加权相加的权重 \(a_{i,j,c}\) 就是卷积核的权重,官方实现中也是通过depth convolution实现的series informed activation,这样把外面的卷积层移到激活函数里了,能叫层数减少了吗?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/680167.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

浏览器工作原理

浏览器(也称为网络浏览器或互联网浏览器)是安装在我们设备上的软件应用程序,使我们能够访问万维网。在阅读这篇文字时,你实际上正在使用一个浏览器。 有许多浏览器正在被使用,截至2022年,使用最多的是&…

为了找到好工作,花2个月时间整理了3.5W字的自动化测试面试题(答案+学习路线)!

从5月初开始找工作到现在,先后面试了阿里巴巴、字节跳动、网易、快手的测试开发岗。 大公司对于测试开发的要求相比来说高很多,要求掌握的知识点的广度和深度层次也比较高,遂整理了这两个月的面试题目文档供大家参考,同时也是为了…

基于java+swing+mysql商城购物系统

基于javaswingmysql商城购物系统 一、系统介绍二、功能展示1.项目骨架2.主界面3.用户登陆4.添加商品类别5、添加商品6、商品管理 四、其它1.其他系统实现五.获取源码 一、系统介绍 项目类型:Java SE项目 项目名称:商城购物系统 用户类型:双…

【C++学习】线程库 | IO流 | 空间配置器

🐱作者:一只大喵咪1201 🐱专栏:《C学习》 🔥格言:你只管努力,剩下的交给时间! 一、线程库 在C11之前,涉及到多线程问题,都是和平台相关的,比如w…

【Python爬虫开发基础⑥】计算机网络基础(Web和HTTP)

专栏:python网络爬虫从基础到实战 欢迎订阅!近期还会不断更新~ 另外:如果想要看更多的计算机网络知识,可以关注我的专栏:计算机网络 往期推荐: 【Python爬虫开发基础①】Python基础(变量及其命名…

【数据结构】特殊矩阵的压缩存储

🎇【数据结构】特殊矩阵的压缩存储🎇 🌈 自在飞花轻似梦,无边丝雨细如愁 🌈 🌟 正式开始学习数据结构啦~此专栏作为学习过程中的记录🌟 文章目录 🎇【数据结构】特殊矩阵的压缩存储&#x1f38…

C语言学习(二十六)---指针练习题(二)

在上节的内容中,我们进一步学习了有关指针的内容,并做了一些关于指针的题目,今天我们将继续练习一些指针的题目,以便大家更好的理解和掌握指针的知识,好了,话不多说,开整!&#xff0…

【c++11】 左值引用和右值引用

c11特性 右值引用左值引用和右值引用左值引用右值引用比较 右值引用的应用左值引用的短处右值引用解决问题移动构造 STL的改动move()函数结语 右值引用 c从出现就有着引用的语法,但是在c11后又新增了右值引用的新特性,以往所学的引用成了左值引用。非左…

代码随想录算法训练营第四十二天| 背包问题

标准背包问题 有n件物品和一个最多能背重量为w 的背包。 第i件物品的重量是weight[i],得到的价值是value[i] 。每件物品只能用一次,求解将哪些物品装入背包里物品价值总和最大。 举一个例子: 背包最大重量为4。 物品为: 重量价…

5.2.3目录与文件之权限意义

现在我们知道了Linux系统内文件的三种身份(拥有者、群组与其他人),知道每种身份都有三种权限(rwx), 已知道能够使用chown, chgrp, chmod去修改这些权限与属性,当然,利用ls -l去观察文…

《C++高级编程》读书笔记(十:揭秘继承技术)

1、参考引用 C高级编程(第4版,C17标准)马克葛瑞格尔 2、建议先看《21天学通C》 这本书入门,笔记链接如下 21天学通C读书笔记(文章链接汇总) 1. 使用继承构建类 1.1 扩展类 当使用 C 编写类定义时&#xf…

WMS中Choreographer 配合 VSYNC 中断信号

WMS中Choreographer 配合 VSYNC 中断信号 1、了解SurfaceFlinger中VSYNC信号刷新2、Choreographer 舞蹈编导2.1 Choreographer初始化2.2 FrameHandler中处理任务2.3 FrameDisplayEventReceiver初始化3.4 简易流程图 3、ViewRootImpl中scheduleTraversals3.1 postCallback 通过n…

java——IO与NIO

文章目录 1. 传统IO模型字节流字符流 2. NIO模型 Java中的IO(输入输出)是用于在程序中读取和写入数据的一种机制。Java提供了两种不同的IO模型:传统的IO模型和NIO(New IO)模型。 1. 传统IO模型 在传统的IO模型中&…

WPF本地化/国际化,多语言切换

之前写过winformwinform使用本地化,中英文切换_winform 中英文切换_故里2130的博客-CSDN博客 基本的技术差不多,但是后来又发现了一个ResXManager工具,可以更好方便快捷的使用。 首先下载,网络不好的话,去官网下载&a…

01背包简介

01背包问题(0/1 Knapsack problem)是一个经典的动态规划问题,它描述了在给定容量限制的情况下,如何选择一组物品放入背包,以使得物品的总价值最大化。 问题描述: 假设有一个背包,其容量为C。现…

VulnHub项目:Fawkes

1、靶机地址 HarryPotter: Fawkes ~ VulnHub 该篇为哈利波特死亡圣器系列最终部,也是最难的一个靶机,难度真的是逐步提升!!! 2、渗透过程 确认靶机IP,kali IP,探测靶机开放端口 详细的扫描…

ICLR 23 | 工业视觉小样本异常检测最新网络:Graphcore

来源:投稿 作者:橡皮 编辑:学姐 论文链接:https://openreview.net/pdf?idxzmqxHdZAwO 论文代码:尚未开源 1.背景 随着人工智能中深度视觉检测技术的快速发展,检测工业产品表面的异常/缺陷受到了前所未有…

scratch lenet(11): C语言实现 squashing function

文章目录 1. 目的2. Sigmoidal Function2.1 S2 用到 Sigmoidal Function2.2 Sigmoidal Function 的定义 3. Squashing Function3.1 改用 Sigmoid Suahsing function 术语3.2 具体到 hyperlolic tangent 这一 squahsing function 4. Squahsing function 的实现References 1. 目的…

设计模式之观察者模式笔记

设计模式之观察者模式笔记 说明Observer(观察者)目录观察者模式示例类图抽象主题角色类抽象观察者类具体主题角色类具体的观察者角色类测试类 说明 记录下学习设计模式-观察者模式的写法。JDK使用版本为1.8版本。 Observer(观察者) 意图:定义对象间的一种一对多的依赖关系&a…

Gradle构建系统macOS安装与使用

1.打开gradle.org并点击安装 2.先决条件 ,确认安装JDK1.8或者更高版本已安装 在终端输入brew install gradle进行安装 安装成功如下: 查看安装版本号gradle -v 使用gradle 1.创建目录demo并进入该目录 mkdir demo cd demo 2.gradle init 使用Gradle开始构建 输入2开始构建应…