【深度学习】RepVGG解析和学习体会,结构重参数化的后的速度比较,代码实现

news2024/11/13 10:10:58

文章目录

  • 前言
  • 0. Vgg
  • 1.RepVGG Block 详解


前言

论文名称:RepVGG: Making VGG-style ConvNets Great Again
论文下载地址:https://arxiv.org/abs/2101.03697
官方源码(Pytorch实现):https://github.com/DingXiaoH/RepVGG
大神的讲解:
bilibili视频讲解:https://www.bilibili.com/video/BV15f4y1o7QR
https://blog.csdn.net/qq_37541097/article/details/125692507


0. Vgg

VGG网络是2014年由牛津大学著名研究组VGG (Visual Geometry Group) 提出的。在2014到2016年(ResNet提出之前),VGG网络可以说是当时最火并被广泛应用的Backbone。后面由于各种新的网络提出,论精度VGG比不上ResNet,论速度和参数数量VGG比不过MobileNet等轻量级网络,慢慢的VGG开始淡出人们的视线-但因为其堆叠结构简单,是很多网络结构的backbone。当VGG已经被大家遗忘时,2021年清华大学、旷视科技以及香港科技大学等机构共同提出了RepVGG网络,希望能够让VGG-style网络Great Again。

在这里插入图片描述
通过论文的图一可以看出,RepVGG无论是在精度还是速度上都已经超过了ResNet、EffcientNet以及ReNeXt等网络。那RepVGG究竟用了什么方法使得VGG网络能够获得如此大的提升呢,在论文的摘要中,作者提到了structural re-parameterization technique方法,即结构重参数化。实际上就是在训练时,使用一个类似ResNet-style的多分支模型,而推理时转化成VGG-style的单路模型。如下图所示,图(B)表示RepVGG训练时所采用的网络结构,而在推理时采用图(C)的网络结构。关于如何将图(B)转换到图(C)以及为什么要这么做后面再细说,如果对模型优化部署有了解就会发现这和做网络图优化或者说算子融合非常类似。
在这里插入图片描述

1.RepVGG Block 详解

其实关于RepVGG整个模型没太多好说的,就是在不断堆叠RepVGG Block,只要之前看过VGG以及ResNet的代码,那么RepVGG也不在话下。这里主要还是聊下RepVGG Block中的一些细节。由于论文中的图都是简化过的,于是我自己根据源码绘制了下图的RepVGG Block(注意是针对训练时采用的结构)。其中图(a)是进行下采样(stride=2)时使用的RepVGG Block结构,图(b)是正常的(stride=1)RepVGG Block结构。通过图(b)可以看到训练时RepVGG Block并行了三个分支:一个卷积核大小为3x3的主分支,一个卷积核大小为1x1的shortcut分支以及一个只连了BN的shortcut分支。
在这里插入图片描述
这里首先抛出一个问题,为什么训练时要采用多分支结构。如果之前看过像Inception系列、ResNet以及DenseNet等模型,我们能够发现这些模型都并行了多个分支。至少根据现有的一些经验来看,并行多个分支一般能够增加模型的表征能力。所以你会发现一些论文喜欢各种魔改网络并行分支。在论文的表6中,作者也做了个简单的消融实验,在使用单路结构时(不使用其他任何分支)Acc大概为72.39,在加上Identity branch以及1x1 branch后Acc达到了75.14。
在这里插入图片描述
接着再问另外一个问题,为什么推理时作者要将多分支模型转换成单路模型。根据论文3.1章节的内容可知,采用单路模型会更快、更省内存并且更加的灵活。

更快:主要是考虑到模型在推理时硬件计算的并行程度以及MAC(memory access cost),对于多分支模型,硬件需要分别计算每个分支的结果,有的分支计算的快,有的分支计算的慢,而计算快的分支计算完后只能干等着,等其他分支都计算完后才能做进一步融合,这样会导致硬件算力不能充分利用,或者说并行度不够高。而且每个分支都需要去访问一次内存,计算完后还需要将计算结果存入内存(不断地访问和写入内存会在IO上浪费很多时间)。
更省内存:在论文的图3当中,作者举了个例子,如图(A)所示的Residual模块,假设卷积层不改变channel的数量,那么在主分支和shortcut分支上都要保存各自的特征图或者称Activation,那么在add操作前占用的内存大概是输入Activation的两倍,而图(B)的Plain结构占用内存始终不变。
在这里插入图片描述
更加灵活:作者在论文中提到了模型优化的剪枝问题,对于多分支的模型,结构限制较多剪枝很麻烦,而对于Plain结构的模型就相对灵活很多,剪枝也更加方便。
其实除此之外,在多分支转化成单路模型后很多算子进行了融合(比如Conv2d和BN融合),使得计算量变小了,而且算子减少后启动kernel的次数也减少了(比如在GPU中,每次执行一个算子就要启动一次kernel,启动kernel也需要消耗时间)。而且现在的硬件一般对3x3的卷积操作做了大量的优化,转成单路模型后采用的都是3x3卷积,这样也能进一步加速推理。如下图多分支模型(B)转换成单路模型图(C)。
在这里插入图片描述
2 结构重参数化
在简单了解RepVGG Block的训练结构后,接下来再来聊聊怎么将训练好的RepVGG Block转成推理时的模型结构,即structural re-parameterization technique过程。 根据论文中的图4(左侧)可以看到,结构重参数化主要分为两步,第一步主要是将Conv2d算子和BN算子融合以及将只有BN的分支转换成一个Conv2d算子,第二步将每个分支上的3x3卷积层融合成一个卷积层。关于参数具体融合的过程可以看图中右侧的部分,如果你能看懂图中要表达的含义,那么ok你可以跳过本文后续所有内容干其他事去了,如果没看懂可以接着往后看。
在这里插入图片描述
2.1 融合Conv2d和BN
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

from collections import OrderedDict
import numpy as np
import torch 
import torch.nn as nn

def main():
    torch.random.manual_seed(0)
    f1 = torch.randn(1,2,3,3)
    module = nn.Sequential(OrderedDict(
        conv = nn.Conv2d(in_channels=2, 
                         out_channels=2,
                         kernel_size=3,
                         stride=1,
                         padding=1,
                         bias=False),
        bn = nn.BatchNorm2d(num_features=2)
    ))
    module.eval()
    with torch.no_grad():
        output1 = module(f1)
        print(output1)
    
    # fuse conv + bn
    # type: ignore
    kernel = module.conv.weight
    running_mean = module.bn.running_mean
    running_var = module.bn.running_var
    gamma = module.bn.weight
    beta = module.bn.bias
    eps = module.bn.eps
    std = (running_var + eps).sqrt()
    t = (gamma/std).reshape(-1,1,1,1) #  [ch] -> [ch, 1, 1, 1]
    kernel = kernel * t
    bias = beta - running_mean * gamma/ std
    fused_conv = nn.Conv2d(
        in_channels=2, 
        out_channels=2, 
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True
        )
    fused_conv.load_state_dict(OrderedDict(
            weight=kernel,
            bias=bias)),
    
    with torch.no_grad():
        output2 = fused_conv(f1)
        print(output2)
    
    np.testing.assert_allclose(output1.numpy(), output2.numpy(), rtol=1e-03, atol=1e-05)
    print("convert module has been tested, and the result looks good!")

if __name__ == "__main__":
    main()
(base) D:\code\python_project\learn_torch>python fuse_conv_bn.py
tensor([[[[ 0.2554, -0.0267,  0.1502],
          [ 0.8394,  1.0100,  0.5443],
          [-0.7252, -0.6889,  0.4716]],

         [[ 0.6937,  0.1421,  0.4734],
          [ 0.0168,  0.5665, -0.2308],
          [-0.2812, -0.2572, -0.1287]]]])
tensor([[[[ 0.2554, -0.0267,  0.1502],
          [ 0.8394,  1.0100,  0.5443],
          [-0.7252, -0.6889,  0.4716]],

         [[ 0.6937,  0.1421,  0.4734],
          [ 0.0168,  0.5665, -0.2308],
          [-0.2812, -0.2572, -0.1287]]]])
convert module has been tested, and the result looks good!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/676033.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

linux动态监控进程懂了没?

这里写目录标题 top交互模式监控网络状态 top top与ps类似,他们都是用来显示正在执行的进程。 两者最大的不同之处就是top在执行一段时间可以更新正在运行的进程。 基本语法: top 选项 选项功能-d 秒数指定top指令每隔几秒更新,默认为3秒-i…

【新星计划·2023】Linux图形、字符界面介绍与区别

作者:Insist-- 个人主页:insist--个人主页 作者会持续更新网络知识和python基础知识,期待你的关注 前言 本文将介绍图形界面与命令行界面以及它们的区别,登录方法。 目录 一、图形界面与命令行界面介绍 1、图形界面 2、命令行…

Oracle单机版升级(11.2.0.3升级到11.2.0.4)

📢📢📢📣📣📣 哈喽!大家好,我是【IT邦德】,江湖人称jeames007,10余年DBA及大数据工作经验 一位上进心十足的【大数据领域博主】!😜&am…

人工神经网络ANN

文章目录 1. 人工神经网络简介1.1 生物神经网络1.2 人工神经网络 2. 人工神经网络原理2.1 ANN的基本构造2.1.1 神经元的结构模型2.1.2 网络拓扑结构 2.2 学习规则2.3 学习算法 3. 人工神经网络特点4. 人工神经网络的Python应用5. 源码仓库地址 1. 人工神经网络简介 1.1 生物神…

北通阿修罗2 Pro 多模板 连接Cemu 支持体感

需要使用体感的游戏基本上都是任天堂的游戏,如塞尔达。所以接下来针对CEMU模拟器介绍如何使用体感。 先看CEMU的手柄配置文档。 https://cemu.cfw.guide/controller-configuration.html 运动控制支持可能因手柄而异。任天堂Switch、Dualshock 4和DualSense手柄都支持…

Nik Color Efex 滤镜详解(5/5)

淡对比度 Pro Contrast 分析图像并为该图像创建特定的颜色,在保持画面细节的同时,实现更高的对比度。 校正色偏 Correct Color Cast 用于纠正色偏。 校正对比度 Correct Contrast 根据光影纠正对比度。 动态对比度 Dynamic Contrast 根据画面对象自动校…

【JY】浅析时程分析中的阻尼设置

(非线性)直接积分法、快速非线性分析(FNA)法等时程分析方法中的阻尼设置尤为重要,以SAP2000为例,进行抛砖引玉,各类软件做法也大同小异,可借鉴与学习。 模态阻尼 模态阻尼是用非耦合…

模拟电路系列分享-频率失真

目录 概要 整体架构流程 技术名词解释 技术细节 1.基本问题简介 2.线性失真 3.频率失真的危害 小结 概要 提示:这里可以添加技术概要 继续接着上一节的内容继续分享和学习, 整体架构流程 分三个部分,仔细的分享了失真方面的知识 技术名词…

团体程序设计天梯赛-练习集L1篇④

🚀欢迎来到本文🚀 🍉个人简介:Hello大家好呀,我是陈童学,一个与你一样正在慢慢前行的普通人。 🏀个人主页:陈童学哦CSDN 💡所属专栏:PTA 🎁希望各…

Spring Boot 日志的主要组件及其特点

Spring Boot 日志的主要组件及其特点 在开发应用程序时,日志是非常重要的一部分。它可以帮助我们了解应用程序的运行情况,发现并解决问题。在 Spring Boot 中,有许多不同的日志框架可供选择。本文将介绍 Spring Boot 日志的主要组件及其特点…

用OpenCV进行模板匹配

1. 引言 今天我们来研究一种传统图像处理领域中对象检测和跟踪不可或缺的方法——模板匹配,其主要目的是为了在图像上找到我们需要的图案,这听起来十分令人兴奋。 所以,事不宜迟,让我们直接开始吧! 2. 概念 模板匹…

哈夫曼树——数组实现

构造n个给定值节点构成的森林; 选择权值最小的两个构成叶子节点,根节点权值为两叶子节点之和, 删除原有的两棵树,将这棵树加入森林中; 重复这两部直到只有一棵树为止,此树就是哈夫曼树; #pr…

警惕这些“挂羊头卖狗肉”的高科技培训!

最近真的被误人子弟的教育骗子给气到! 事情是这样的,6月11号,我在2023 开放原子全球开源峰会上,遇到了一位从广东来北京参会的老师。 这位老师透露,他来自一所职业技术学院,学校师资挺不错的,可…

Spring Boot 如何配置日志级别和输出格式

Spring Boot 如何配置日志级别和输出格式 在开发一个应用程序时,日志记录是非常重要的一环。Spring Boot 提供了多种日志输出方式和配置选项,本文将介绍如何在 Spring Boot 应用程序中配置日志级别和输出格式。 配置日志级别 在 Spring Boot 应用程序中…

【知识点随笔分享 | 第一篇】避不开的浮点误差

引入: 各位在大一初入C语言的时候,老师肯定说过浮点数之间的比较要用做差法,当二者的差值特别小甚至于接近0的时候,这两个数就相等,不知道各位是否会有疑惑?为什么浮点数不可以直接进行比较呢? …

Nacos-手写配置中心基本原理

本文已收录于专栏 《中间件合集》 目录 概念说明Nacos配置中心Naocs配置项Naocs配置集Naocs配置快照 需求分析核心功能代码实现AService模块BService模块NacosService模块NacosSDK模块 注意事项总结提升 概念说明 Nacos注册中心:https://blog.csdn.net/weixin_4549…

vs中运行时库简要说明

vs中右键单击工程 -->属性–>c/c->代码生成,进入如下菜单中: 可以看出有如下几个选项: 多线程(/MT):链接目标库为libcmt.lib 多线程调试(/MTd):链接目标库为libcmtd.lib 多线程DLL(/MD):链接目标…

02.GLM-130B

文章目录 前言泛读相关知识GPTBERTT5小结 背景介绍主要贡献和创新点GLM 6B 精读自定义Mask模型量化1TB 的中英双语指令微调RLHFPEFT训练策略 实验分析与讨论模型参数六个指标其他测评结果 代码复现(6B)环境准备运行调用代码调用网页服务命令行调用 模型微…

在 Python 中生成随机 4 位数字

文章目录 在 Python 中生成随机数使用 random 模块在 Python 中生成随机数使用 random.randint() 方法使用 random.randrange() 方法 使用替代方法在 Python 中生成随机数总结 Python 是一种高级解释型编程语言,全球大多数程序员都在使用它。 它在面向对象编程 (OOP…

SpringCloud Alibaba入门5之Hystrix的使用

我们继续在前一章的基础上进行学习。 SpringCloud Alibaba入门5之使用OpenFegin调用服务_qinxun2008081的博客-CSDN博客 上一节我们已经使用OpenFeign完成了服务间的调用,如果现在存在大量的服务,每个服务有若干个节点,其中一个节点发生故障…