华为诺亚 VanillaNet

news2024/7/4 4:55:44

文章标题:《VanillaNet: the Power of Minimalism in Deep Learning》
文章地址:https://arxiv.org/abs/2305.12972
github地址:https://github.com/huawei-noah/VanillaNet

华为诺亚方舟实验室和悉尼大学,2023年5月代码刚开源的文章


作者说,在卷积网络中加入人为设计的模块,达到了更好的效果,复杂度也增加了。尽管这些很深很复杂的神经网络被优化得很好,达到了令人满意的性能,但是这给部署带来了挑战。

比方说 ResNets 里的 shortcut 操作大量的芯片内存。另外,像 AS-MLPaxial shiftSwin Transformershift window self attention 这些复杂的操作需要复杂的工程实现,包括重写 CUDA 的代码。

ResNet 的发展看起来让大家放弃了用纯的卷积层来构造网络。就像 ResNet 它自己说的:没有 shortcut 的普通网络将出现梯度消失,导致 34 34 34 层的普通网络性能比 18 18 18 层的差。 另外,像 AlexNetVGGNet 这种简单网络的性能被 ResNetsViT 等深度复杂网络所超越,于是更少人花心思去设计和优化简单的网络。

于是提出了 VanillaNet,这是一种新颖的神经网络架构,强调设计的优雅和简单,同时在计算机视觉任务中保持卓越的性能。VanillaNet 通过避免过多的 depthshortcuts 和复杂的操作(如self-attention)来实现这一点,从而产生了一系列精简的网络,这些网络解决了固有的复杂性问题,非常适合资源有限的环境。

(1)

为了训练这个 VanillaNet,作者对面临的挑战进行了全面分析,并且制定了叫做 deep training 的策略。简单来说就是准备好网络之后,在训练的时候逐渐消除卷积层之间的非线性层(激活函数),最后把卷积层也合并成一个。

假设激活函数(通常可以是ReLUTanh)表示为 A ( x ) A(x) A(x),再结合一个恒等映射(identity mapping),写成如下形式:

A ′ ( x ) = ( 1 − λ ) A ( x ) + λ x (1) A'(x) = (1 - \lambda) A(x) + \lambda x \tag{1} A(x)=(1λ)A(x)+λx(1)
其中 λ \lambda λ 是个超参数,用于调整这个函数 A ′ ( x ) A'(x) A(x) 的非线性能力。

设总的 epochs 数是 E E E,当前是第 e e e 个 epoch,则 λ = e E \lambda = \dfrac{e}{E} λ=Ee

所以开始的时候 λ = 0 \lambda = 0 λ=0,这表现为一个完整的激活函数,没有恒等映射。
随着训练的进行,最后 λ = 1 \lambda = 1 λ=1,两个卷积之间没有激活函数了。

画个图给你看:

在这里插入图片描述

最后把这两层卷积也合并起来,bn层也融合进来,BN融合的公式如下:

在这里插入图片描述

代码如下:

    def _fuse_bn_tensor(self, conv, bn):
        kernel = conv.weight
        bias = conv.bias
        running_mean = bn.running_mean
        running_var = bn.running_var
        gamma = bn.weight
        beta = bn.bias
        eps = bn.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta + (bias - running_mean) * gamma / std

(2)

这么弄了之后,为了增强网络的非线性的能力,又提出了一种有效的,基于级数的激活函数。它包含多个可学习的仿射变换。
原文如下所示:

在这里插入图片描述

公式写得很复杂,根据代码的理解,简单来说就是设计了一组卷积核,参数是可学习的,对激活后的数据做一次卷积,再加上BN。
把这些操作包装起来叫做自己的 Activation。画个图给你看:
在这里插入图片描述
他说这个实际上比真正的卷积的计算量要小,并且给了一堆证明:

在这里插入图片描述

这个 Activation 是作者封装的激活函数,代码如下:

class activation(nn.ReLU):
    def __init__(self, dim, act_num=3, deploy=False):
        super(activation, self).__init__()
        self.act_num = act_num
        self.deploy = deploy
        self.dim = dim
        self.weight = torch.nn.Parameter(torch.randn(dim, 1, act_num*2 + 1, act_num*2 + 1))
        if deploy:
            self.bias = torch.nn.Parameter(torch.zeros(dim))
        else:
            self.bias = None
            self.bn = nn.BatchNorm2d(dim, eps=1e-6)
        weight_init.trunc_normal_(self.weight, std=.02)

    def forward(self, x):
        if self.deploy:
            return torch.nn.functional.conv2d(
                super(activation, self).forward(x), 
                self.weight, self.bias, padding=self.act_num, groups=self.dim)
        else:
            return self.bn(torch.nn.functional.conv2d(
                super(activation, self).forward(x),
                self.weight, padding=self.act_num, groups=self.dim))

    def _fuse_bn_tensor(self, weight, bn):
        kernel = weight
        running_mean = bn.running_mean
        running_var = bn.running_var
        gamma = bn.weight
        beta = bn.bias
        eps = bn.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta + (0 - running_mean) * gamma / std
    
    def switch_to_deploy(self):
        kernel, bias = self._fuse_bn_tensor(self.weight, self.bn)
        self.weight.data = kernel
        self.bias = torch.nn.Parameter(torch.zeros(self.dim))
        self.bias.data = bias
        self.__delattr__('bn')
        self.deploy = True

VanillaNet 的主要卖点就是以上两个东西。



整体网络很简单,看起来像 VGG 或是 AlexNet

在这里插入图片描述

实验部分自己看。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/617473.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于HAL库的STM32的单定时器的多路输入捕获测量脉冲频率(外部时钟实现)

目录 写在前面 一般的做法(定时器单通道输入捕获) 以外部时钟的方式(高低频都适用) 测试效果 写在前面 STM32的定时器本身有输入捕获的功能。可选择双端捕获,上升沿捕获或者是下降沿捕获。对应捕获频率来说,连续捕获上升沿或下降沿的时间间隔就是其脉…

手把手教你F103工程文件的创建并且通过protesu仿真验证创建工程文件的正确性(低成本)

目录 一、新建工程文件夹 二、新建一个工程框架 三、添加文件 四、仿真验证 五、仿真调试中遇到的问题并解决 一、新建工程文件夹 新建工程文件夹分为 2 个步骤:1,新建工程文件夹;2,拷贝工程相关文件。 1.新建工程文件 首先…

【04】STM32·HAL库开发-MDK5使用技巧 |文本美化 | 代码编辑技巧 | 查找与替换技巧 | 编译问题定位 | 窗口视图初始化

目录 1.文本美化(熟悉)1.1编辑器设置1.2字体和颜色设置1.3用户关键字设置1.4代码提示&语法检测1.5global.prop文件妙用 2.代码编辑技巧(熟悉)2.1Tab键的妙用2.2快速定位函数或变量被定义的地方2.3快速注释&快速取消注释 3…

python面向对象操作2(速通版)

目录 一、私有和公有属性的定义和使用 1.公有属性定义和使用 2.私有属性 二、继承 1.应用 2.子类不能用父类的私有方法 3.子类初始化父类 4.子类重写和调用父类方法 5.多层继承 6.多层继承-初始化过程 7.多继承基本格式 8.多层多继承时的初始化问题 9.多继承初始化…

云原生Docker Cgroups资源控制操作

资源控制 Docker 通过 Cgroup 来控制容器使用的资源配额,包括 CPU、内存、磁盘三大方面, 基本覆盖了常见的资源配额和使用量控制。 Cgroup 是 ControlGroups 的缩写,是 Linux 内核提供的一种可以限制、记录、隔离进程组所使用的物理资源(如…

Node服务器-express框架

1 Express认识初体验 2 Express中间件使用 3 Express请求和响应 4 Express路由的使用 5 Express的错误处理 6 Express的源码解析 一、手动创建express的过程: 1、在项目文件的根目录创建package.json文件 npm init 2、下载express npm install express 3、基本…

kafka3

分区副本机制 kafka 从 0.8.0 版本开始引入了分区副本;引入了数据冗余 用CAP理论来说,就是通过副本及副本leader动态选举机制提高了kafka的 分区容错性和可用性 但从而也带来了数据一致性的巨大困难! 6.6.2分区副本的数据一致性困难 kaf…

多模态学习

什么是多模态学习? 模态 模态是指一些表达或感知事物的方式,每一种信息的来源或者形式,都可以称为一种模态 视频图像文本音频 多模态 多模态即是从多个模态表达或感知事物 多模态学习 从多种模态的数据中学习并且提升自身的算法 多…

【k8s 系列】k8s 学习三,docker回顾,k8s 起航

k8s 逐渐已经作为一个程序员不得不学的技术,尤其是做云原生的兄弟们,若你会,那么还是挺难的 学习 k8s ,实践尤为重要,如果身边有自己公司就是做云的,那么云服务器倒是不用担心,若不是&#xff…

IMX6ULL PHY 芯片驱动

前言 之前使用 IMX6ULL 开发板时遇到过 nfs 挂载不上的问题,当时是通过更换官方新版 kernel 解决的,参考《挂载 nfs 文件系统》。 今天,使用自己编译的 kernel 又遇到了该问题,第二次遇到了,该正面解决了。 NFS 挂载…

18JS09——作用域

作用域 一、作用域1、作用域 二、变量的作用域1、变量作用域的分类2、全局变量3、局部变量4、全局变量和局部变量区别 三、作用域链 目标: 1、作用域 2、变量的作用域 3、作用域链 一、作用域 1、作用域 通常来说,一段程序代码中所用到的名字并不总是有…

python基础----06-----文件读写追加操作

一 文件编码概念 思考:计算机只能识别: 0和1,那么我们丰富的文本文件是如何被计算机识别,并存储在硬盘中呢? 答案:使用编码技术(密码本)将内容翻译成0和1存入。 常见编码有UTF8,gbk等等。不同的编码,将内…

vulnhub靶场之DC-3渗透教程(Joomla CMS)

目录 0x01靶机概述 0x02靶场环境搭建 0x03主机发现 0x04靶场渗透过程 ​ 0x05靶机提权 0x06渗透实验总结 0x01靶机概述 靶机基本信息: 靶机下载链接https://download.vulnhub.com/dc/DC-3-2.zip作者DCAU发布日期2020年4月25日难度中等 0x02靶场环…

【Flink】DataStream API使用之输出算子(Sink)

输出算子(Sink) Flink作为数据处理框架,最终还是需要把计算处理的结果写入到外部存储,为外部应用提供支持。Flink提供了很多方式输出到外部系统。 1. 连接外部系统 在Flink中我们可以在各种Fuction中处理输出到外部系统&#xf…

C#读写参数到APP.Config

C#读写参数到APP.Config 介绍程序Demo常见错误 介绍 系统在开发时,可能需要设置默认参数,比如数据库的链接参数,某个参数的默认数据等等。对于这些数据,可直接在app.config中读取。 在读写时,需要先了解configuratio…

echo命令在Unix中的作用以及其常见用法

在Unix系统中,"echo"是一个常用的命令,用于在终端或脚本中输出文本。它可以将指定的字符串或变量的值打印到标准输出,从而向用户提供信息或进行调试。 本文将详细介绍"echo"命令在Unix中的作用以及其常见用法。 基本语法…

Keras-3-实例1-二分类问题

1. 二分类问题 1.1 IMDB 数据集加载 IMDB 包含5w条严重两极分化的评论,数据集被分为 2.5w 训练数据 和 2.5w 测试数据,训练集和测试集中的正面和负面评论占比都是50% from keras.datasets import imdb(train_data, train_labels), (test_data, test_l…

UE5 Chaos破碎系统学习1

在UE5中,Chaos破碎系统被直接进行了整合,本篇文章就来讲讲chaos的基础使用。 1.基础破碎 1.首先选中需要进行破碎的模型,例如这里选择一个Box,然后切换至Fracture Mode(破碎模式): 2.点击右侧…

JAVA实现打字练习软件

转眼已经学了一学期的java了,老师让我们根据所学知识点写一个打字练习软件的综合练习。一开始我也不是很有思路,我找了一下发现csdn上关于这个小项目的代码也不算很多,所以我最后自己在csdn查了一些资料,写了这么一个简略版本的打…

【C++】——list的介绍及模拟实现

文章目录 1. 前言2. list的介绍3. list的常用接口3.1 list的构造函数3.2 iterator的使用3.3 list的空间管理3.4 list的结点访问3.5 list的增删查改 4. list迭代器失效的问题5. list模拟实现6. list与vector的对比7. 结尾 1. 前言 我们之前已经学习了string和vector&#xff0c…