RepVGG论文理解与代码分析

news2024/10/7 4:27:27

最近,看到很多轻量化工作是基于RepVGG改进而来,决定重新回顾一下RepVGG,并在此记录一些理解与心得。

论文地址:https://arxiv.org/abs/2101.03697

Introduction

RepVGG通过结构重参数化思想,让训练网络的多路结构(多分支模型训练时的优势——性能高)转换为推理网络的单路结构(模型推理时的好处——速度快、省内存),从而达到推理速度快与模型性能高兼备的效果。

Method

根据Resnet与Inception等论文论证,与单路直筒结构相比,增加shortcut与muti-branch可以提高模型的性能,RepVGG在VGG的基础上做了相关实验,从下表中可以发现,Identity branch与1x1 branch能够提高准确率,但同时推理时间也延长了许多。
在这里插入图片描述

为了在保证模型精度的基础上,降低模型推理时间,RepVGG采用了结构重构的方法,即在训练时使用复杂结构,而测试时将复杂模型重构成单路直筒型模型,如下图所示。图B中,RepVGG在training时加入了shortcut和1x1 branch,而在test时,RepVGG重构成直筒型VGG模型,如图c。

在这里插入图片描述

Re-param for Plain Inference-time Model
如下图所示,我们的目的是将带有BN层的3x3卷积与1x1卷积以及Identity mapping融合成一个3x3卷积,如何融合呢?操作分为两部:1.融合conv与bn;2.合并卷积。
在这里插入图片描述
1.Merge BN and Conv together

在这里插入图片描述

2. Fuse 3x3conv 1x1conv and identity

假设输入特征尺寸为(1x2x3x3),输出特征尺寸为输入相同,stride=1,那么conv3x3的卷积过程如下图所示,conv的特征尺寸为(2x2x3x3),首先对输入扩边,padding=kernel_size // 2,然后从左上角开始滑动窗口,最后获得右边的输出特征。
在这里插入图片描述

conv_1x1的卷积相较于3x3卷积更为简单,只需要将卷积参数与对应的单个输入参数相乘再相加。

在这里插入图片描述
为了使conv_1x1可以与conv_3x3线性相加,我们可以将conv_1x1的1x1卷积核扩边成3x3的卷积核,这样形式上就可以变成conv_3x3,同时,结果与conv_1x1一致。
在这里插入图片描述
同样,Identity mapping可以看作特殊的conv_1x1,参数是固定的,即输入特征的某一层对应的卷积参数为1,其余均为0。我们可以按照上面介绍的conv_1x1与conv_3x3转化方式,将Identity mapping转化成conv_3x3的形式,如下图所示。
在这里插入图片描述
如此,我们就可以将带有BN层的3x3卷积与1x1卷积以及Identity mapping融合成一个3x3卷积。
在这里插入图片描述

code

我们看一下代码如何实现RepVGG,首先,我们看到deploy的控制Flag,当depoly为False时,模型处于training状态,结构没有重参化,当deploy为True时,模型处于test状态,需要对结构重参。

结构重参时,需要调用switch_to_deploy(),该函数的作用是调用self.get_equivalent_kernel_bias()获得重构后的kernel, bias, 对重构卷积赋值–self.rbr_reparam = nn.Conv2d,其中Conv2d的参数是kernel,偏移是bias,将参数detach脱离计算图,并删除其余卷积操作内存(否则无法正常运行)。

get_equivalent_kernel_bias(self):逻辑很简单,分为两部:1.调用self._fuse_bn_tensor()将conv与bn融合。需要注意的是,identity没有conv,只有BN操作,所以在self._fuse_bn_tensor()函数中需要构造卷积kernel,kernel_value[i, i % input_dim, 1, 1] = 1,需要注意分组卷积的形式。函数的返回值 return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

class RepVGGBlock(nn.Module):
    '''RepVGGBlock is a basic rep-style block, including training and deploy status
    This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
    '''
    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False):
        super(RepVGGBlock, self).__init__()
        """ Initialization of the class.
        Args:
            in_channels (int): Number of channels in the input image
            out_channels (int): Number of channels produced by the convolution
            kernel_size (int or tuple): Size of the convolving kernel
            stride (int or tuple, optional): Stride of the convolution. Default: 1
            padding (int or tuple, optional): Zero-padding added to both sides of
                the input. Default: 1
            dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
            groups (int, optional): Number of blocked connections from input
                channels to output channels. Default: 1
            padding_mode (string, optional): Default: 'zeros'
            deploy: Whether to be deploy status or training status. Default: False
            use_se: Whether to use se. Default: False
        """
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels
        self.out_channels = out_channels

        assert kernel_size == 3
        assert padding == 1

        padding_11 = padding - kernel_size // 2

        self.nonlinearity = nn.ReLU()

        if use_se:
            raise NotImplementedError("se block not supported yet")
        else:
            self.se = nn.Identity()

        if deploy:
            self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                                         padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)

        else:
            self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
            self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
            self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)

    def forward(self, inputs):
        '''Forward process'''
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.se(self.rbr_reparam(inputs)))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))

    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels,
                                     kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
                                     padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True

Drawback

在RepVGG中存在一个致命的缺点,我们知道RepVGG是为了轻量化而生,但是结构重参导致模型参数方差较大,从而引起量化误差,实验证明,RepVGG通过训练后量化会将准确率降低20%,与此同时,由于特殊的training与test结构差异,RepVGG很难进行感知量化训练。为了对量化友好,Repopt提出了很好的思路解决这个问题,我们下篇继续介绍Repopt。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/52106.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[附源码]JAVA毕业设计-高中辅助教学系统-(系统+LW)

[附源码]JAVA毕业设计-高中辅助教学系统-(系统LW) 目运行 环境项配置: Jdk1.8 Tomcat8.5 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技…

[附源码]Python计算机毕业设计Django电商小程序

项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等等。 环境需要 1.运行环境:最好是python3.7.7,…

「湖仓一体」释放全量数据价值!巨杉数据库亮相2022沙丘大会

近日,由数字化研究与知识服务平台沙丘社区主办的2022沙丘大会成功举办,巨杉数据库受邀出席大会,并在湖仓一体专场进行《湖仓一体释放全量数据价值》的主题演讲。 近日,由数字化研究与知识服务平台沙丘社区主办的2022沙丘大会以线上…

低代码开发平台助力生产管理:采购成本管理的优化

采购是企业经营活动中的重要环节,它处于企业生产经营活动的最前端,同时也是成本费用中占较大比重的环节。采购成本主要是指企业在生产过程中用于采购产品及服务等交易活动所产生的一系列支出,包括物资的购买价格、税费、运输费等,…

X电容和Y电容

X电容和Y电容 1安规电容 安规电容之所以称之为安规,它是指用于这样的场合:即电容器失效后,不会导致电击,也不危及人身安全。安规电容包含X电容和Y电容两种,它普通电容不一样的是,普通电容即使在外部电源断…

从0到1 Webpack搭建Vue3开发、生产环境

起步 创建项目目录 mkdir webpack-vue3-demo初始化 package.json npm init -y参考文档 安装 webpack webpack-cli webpack-dev-server webpack-merge npm install webpack webpack-cli webpack-dev-server webpack-merge --save-dev创建配置文件 mkdir build cd build …

vscode配置git和c++

vscode配置git和cvscode配置c1.必要配置2.可选配置配置git1.命令行使用git2.IDE使用git3.一点补充过滤文件设置别名之前一直在用vscodepython做实验,现在想利用vscode复习下c和git顺便做做力扣。vscode配置c 1.必要配置 由于vscode只是个编辑器,所以首…

JVM之运行时数据区 面试相关

JVM创建对象的方式创建对象的步骤内存布局对象访问定位![请添加图片描述](https://img-blog.csdnimg.cn/fa106bd4936440b28e1c359d57ba4d25.png)直接内存创建对象的方式 new 常见方式 Xxx静态方法 XxxBuilder/XxxFactory的静态方法Class的newInstance() 反射,只能空…

魔兽世界开服架设服务器搭建教程

魔兽世界开服架设服务器搭建教程 准备工具: 1、装有windows98/2000/xp/2003系统、内存至少256M的电脑一台 2、魔兽服务器端一个 3、服务器一台(魔兽世界对服务器的配置要求并不是很高,CPU 16核 、16线程 带宽最好是选择50M的,游戏…

美食杰项目 -- 发布菜谱(七)

目录前言:具体实现思路:步骤:1. 展示美食杰发布菜谱页效果2. 引入element-ui3. 代码总结:前言: 本文给大家讲解,美食杰项目中 实现发布菜谱页的效果,和具体代码。 具体实现思路: 按…

骑行运动耳机哪个好,列举五款适合在骑行过程中佩戴的耳机

谈起耳机,人们第一印象应该是传统的入耳式耳机,这种耳机在音质以及体积上确实占据了一定的优势,但还是存在着不少的缺点,特别是佩戴的过程中会让我们的耳道保持堵塞状态,导致中耳炎等疾病的频频发生,而这两…

ASEMI-KBL410是什么元器件,kbl410整流桥参数

编辑-Z 俗话说,时势造英雄,整流桥大军中有一款整流桥KBL410有哪些你所不知道的?KBL410是什么元器件?kbl410整流桥参数是多少? KBL410参数描述 型号:KBL410 封装:KBL-4 电性参数:…

ARC113D题解

ARC113D - Sky Reflector 题目大意 有一个nnn行mmm列的表格,你可以在每个表格中填入一个111到kkk之间的整数,定义序列A,BA,BA,B如下: 对于每一个i1,2,…,ni1,2,\dots,ni1,2,…,n,AiA_iAi​是第iii行的最小值对于每一个j1,2,…,…

强化学习:Actor-Critic、SPG、DDPG、MADDPG

马尔可夫决策过程(MDP) MDP 由元组 (S,A,P,R,γ)(S, A, P, R, \gamma)(S,A,P,R,γ) 描述,分别表示有限状态集、有限动作集、状态转移概率、回报函数、折扣因子 。与马尔可夫过程不同,MDP的状态转移概率是包含动作的,即…

Express 7 指南 - 开发中间件

Express Express 中文网 本文仅用于学习记录,不存在任何商业用途,如侵删 文章目录Express7 指南 - 开发中间件7.1 概述7.2 例子7.2.1 中间件函数 myLogger7.2.2 中间件函数 requestTime7.2.3 中间件函数 validateCookies7.3 可配置的中间件7 指南 - 开发…

中断系统中的设备树__Linux对中断处理的框架及代码流程简述

1 异常向量入口: arch\arm\kernel\entry-armv.S .section .vectors, "ax", %progbits .L__vectors_start: W(b) vector_rst W(b) vector_und W(ldr) pc, .L__vectors_start 0x1000 W(b) vector_pabt W(b) vector_dabt W(b) …

14 【接口规范和业务分层】

14 【接口规范和业务分层】 1.接口规范-RESTful架构 1.1 什么是REST REST全称是Representational State Transfer,中文意思是表述(编者注:通常译为表征)性状态转移。 它首次出现在2000年Roy Fielding的博士论文中,R…

教程九 在Go中使用Energy创建跨平台GUI应用 - Go绑定变量JS调用

介绍 Energy Go中定义的变量、结构和函数绑定,在JS中使用。 在Energy中不只可以调用 JS 和 事件机制,也可以通过Go绑定在Go中定义的一些变量函数在JS中调用,在使用的时候就如同在JS调用本身定义的函数一样方便。 运行此示例,需…

Flutter FlutterActivity找不到

Flutter FlutterActivity找不到1.大多数报错应该都是这个样子2.接下来找到我们自己安装的 flutterSDK 路径我放在下面 flutterSdk\flutter_windows_3.3.4-stable\flutter\bin\cache\artifacts\engine\android-arm 3.这个界面大家应该都很熟悉吧(这是快捷键 ctrlshiftalts) …