深度学习网络模型——RepVGG网络详解

news2024/11/24 13:05:42

深度学习网络模型——RepVGG网络详解

  • 0 前言
  • 1 RepVGG Block详解
  • 2 结构重参数化
    • 2.1 融合Conv2d和BN
    • 2.2 Conv2d+BN融合实验(Pytorch)
    • 2.3 将1x1卷积转换成3x3卷积
    • 2.4 将BN转换成3x3卷积
    • 2.5 多分支融合
    • 2.6 结构重参数化实验(Pytorch)
  • 3 模型配置

论文名称: RepVGG: Making VGG-style ConvNets Great Again
论文下载地址: https://arxiv.org/abs/2101.03697
官方源码(Pytorch实现): https://github.com/DingXiaoH/RepVGG

0 前言

在这里插入图片描述

1 RepVGG Block详解

在这里插入图片描述

2 结构重参数化

在这里插入图片描述

2.1 融合Conv2d和BN

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

2.2 Conv2d+BN融合实验(Pytorch)

在这里插入图片描述

from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn


def main():
    torch.random.manual_seed(0)

    f1 = torch.randn(1, 2, 3, 3)

    module = nn.Sequential(OrderedDict(
        conv=nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=1, padding=1, bias=False),
        bn=nn.BatchNorm2d(num_features=2)
    ))

    module.eval()

    with torch.no_grad():
        output1 = module(f1)
        print(output1)

    # fuse conv + bn
    kernel = module.conv.weight 
    running_mean = module.bn.running_mean
    running_var = module.bn.running_var
    gamma = module.bn.weight
    beta = module.bn.bias
    eps = module.bn.eps
    std = (running_var + eps).sqrt()
    t = (gamma / std).reshape(-1, 1, 1, 1)  # [ch] -> [ch, 1, 1, 1]
    kernel = kernel * t
    bias = beta - running_mean * gamma / std
    fused_conv = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=1, padding=1, bias=True)
    fused_conv.load_state_dict(OrderedDict(weight=kernel, bias=bias))

    with torch.no_grad():
        output2 = fused_conv(f1)
        print(output2)

    np.testing.assert_allclose(output1.numpy(), output2.numpy(), rtol=1e-03, atol=1e-05)
    print("convert module has been tested, and the result looks good!")


if __name__ == '__main__':
    main()

终端输出结果:
在这里插入图片描述

2.3 将1x1卷积转换成3x3卷积

在这里插入图片描述

2.4 将BN转换成3x3卷积

在这里插入图片描述
代码截图如下所示:
在这里插入图片描述
在这里插入图片描述

2.5 多分支融合

在这里插入图片描述
代码截图:
在这里插入图片描述
图像演示:
在这里插入图片描述

2.6 结构重参数化实验(Pytorch)

import time
import torch.nn as nn
import numpy as np
import torch


def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    result = nn.Sequential()
    result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=kernel_size, stride=stride, padding=padding,
                                        groups=groups, bias=False))
    result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return result


class RepVGGBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False):
        super(RepVGGBlock, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels
        self.nonlinearity = nn.ReLU()

        if deploy:
            self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=kernel_size, stride=stride,
                                         padding=padding, dilation=dilation, groups=groups,
                                         bias=True, padding_mode=padding_mode)

        else:
            self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) \
                if out_channels == in_channels and stride == 1 else None
            self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                     stride=stride, padding=padding, groups=groups)
            self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                                   stride=stride, padding=0, groups=groups)

    def forward(self, inputs):
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.rbr_reparam(inputs))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)

    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels,
                                     out_channels=self.rbr_dense.conv.out_channels,
                                     kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
                                     padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation,
                                     groups=self.rbr_dense.conv.groups, bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True

def main():
    f1 = torch.randn(1, 64, 64, 64)
    block = RepVGGBlock(in_channels=64, out_channels=64)
    block.eval()
    with torch.no_grad():
        output1 = block(f1)
        start_time = time.time()
        for _ in range(100):
            block(f1)
        print(f"consume time: {time.time() - start_time}")

        # re-parameterization
        block.switch_to_deploy()
        output2 = block(f1)
        start_time = time.time()
        for _ in range(100):
            block(f1)
        print(f"consume time: {time.time() - start_time}")

        np.testing.assert_allclose(output1.numpy(), output2.numpy(), rtol=1e-03, atol=1e-05)
        print("convert module has been tested, and the result looks good!")


if __name__ == '__main__':
    main()

终端输出结果如下:
在这里插入图片描述
通过对比能够发现,结构重参数化后推理速度翻倍了,并且转换前后的输出保持一致。

3 模型配置

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/346562.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java实现定时发送邮件

特别说明&#xff1a;邮件所采用的均为QQ邮件 一、邮箱准备 作为发送方&#xff0c;需要开启相关服务。 首先打开邮箱&#xff0c;然后选择设置&#xff0c;再选择账户 开启以下服务 我们可以在这里获取邮箱的授权码。 二、项目准备 2.1、依赖引入 <dependencies>…

二分法-蓝桥杯

一、二分法引入-猜数游戏二分法:折半搜索。二分的效率:很高&#xff0c;O(logn)例如猜数游戏&#xff0c;若n1000万&#xff0c;只需要猜log10 7 24次猜数游戏的代码&#xff1a;bin_search------>二分搜索把一个长度为n的有序序列上O(n)的查找时间&#xff0c;优化到了O(lo…

【java】Spring Boot --Spring Boot 集成 MyBatis

文章目录1. 前言2. 实例场景3. 数据库模块实现4. Spring Boot 后端实现4.1 使用 Spring Initializr 创建项目4.2 引入项目依赖4.3 数据源配置4.4 开发数据对象类4.5 开发数据访问层4.6 添加 MyBatis 映射文件5. 测试6. 小结1. 前言 企业级应用数据持久层框架&#xff0c;最常见…

【项目】Vue3+TS CMS 基本搭建相关配置

&#x1f4ad;&#x1f4ad; ✨&#xff1a;Vue3 TS   &#x1f49f;&#xff1a;东非不开森的主页   &#x1f49c;: today beginning&#x1f49c;&#x1f49c;   &#x1f338;: 如有错误或不足之处&#xff0c;希望可以指正&#xff0c;非常感谢&#x1f609;   基本…

2023爱分析 · 数据科学与机器学习平台厂商全景报告 | 爱分析报告

报告编委 黄勇 爱分析合伙人&首席分析师 孟晨静 爱分析分析师 目录 1. 研究范围定义 2. 厂商全景地图 3. 市场分析与厂商评估 4. 入选厂商列表 1. 研究范围定义 研究范围 经济新常态下&#xff0c;如何对海量数据进行分析挖掘以支撑敏捷决策、适应市场的快…

Milvus 新版本来啦!首席工程师带你划重点:安全、稳定、升级友好

Milvus 又又又又出新版本了&#xff01;Milvus 2.2.3 版本是 2.2 系列的小版本升级&#xff0c;尽管是小版本的更新&#xff0c;但是依然干货满满&#xff1a;首先是带来了社区中呼声很高的 coordinator 节点的高可用能力&#xff1b;其次还新增了不停机滚动升级的功能&#xf…

第八章:DNS解析服务器搭建

今天先讲一下DNS的简单配置。 Windows server DNS&#xff1a; 点击工具选择DNS 右击正向查找区域&#xff0c;然后选择新建&#xff0c;如果是根域就可以勾AD储存&#xff0c;不是的话就别勾 名字可以随便 最后点击完成&#xff0c;这是正向解析 右击反向查找区域&#xff0c…

面试浅谈之十大排序算法

面试浅谈之十大排序算法 HELLO&#xff0c;各位博友好&#xff0c;我是阿呆 &#x1f648;&#x1f648;&#x1f648; 这里是面试浅谈系列&#xff0c;收录在专栏面试中 &#x1f61c;&#x1f61c;&#x1f61c; 本系列将记录一些阿呆个人整理的面试题 &#x1f3c3;&…

【QA】[vue/element-ui] 日期输入框的表单验证问题

引入&#xff1a;element-ui的表单验证是使用rules来定义规则&#xff0c;其中日期类型的表单输入框如图所示&#xff0c;一般会使用 format 来设置自己需要的日期格式&#xff1a; <el-form-item label"生日" prop"birthday"><el-col :span&quo…

Java高手速成 | 图说重定向与转发

我们先回顾一下Servlet的工作原理&#xff0c;Servlet的工作原理跟小猪同学食堂就餐的过程很类似。小猪同学点了烤鸡腿&#xff08;要奥尔良风味的&#xff09;&#xff0c;食堂窗口的服务员记下了菜单&#xff0c;想了想后厨的所有厨师&#xff0c;然后将菜单和餐盘交给专门制…

RabbitMQ运行机制

消息的TTL&#xff08;Time To Live&#xff09; 消息的TTL就是消息的存活时间。 • RabbitMQ可以对队列和消息分别设置TTL。 • 对队列设置就是队列没有消费者连着的保留时间&#xff0c;也可以对每一个单独的消息做单独的 设置。超过了这个时间&#xff0c;我们认为这个消息…

什么是溶血症?什么是ABO溶血?溶血检查些什么?

什么是溶血症&#xff0c;什么是ABO溶血&#xff1f;女人是O型血&#xff0c;男人是其他血型的夫妻配对&#xff0c;最担心的是胎儿溶血症。从理论上讲&#xff0c;只要夫妻双方血型不同&#xff0c;母亲一定缺乏胎儿从父亲那里遗传的抗原。当任何人接触到他们缺乏的抗原时&…

Vue+node.js火车票订票系统vscode开发的

该系统的基本功能包括管理员、用户二个角色功能模块。 对于管理员可以使用的功能模块主要有&#xff0c;首页、个人中心&#xff0c;用户管理、系统公告管理、车次管理、车票信息管理、订票信息管理、系统管理等功能。 对于用户所使用的功能模块的操作主要是首页、个人中心、订…

【python百炼成魔】手把手带你学会python数据类型

文章目录前言一. python的基本数据类型1.1 如何查看数据类型1.2 数值数据类型1.2.1 整数类型1.2.2 浮点数类型1.2.3 bool 布尔数值类型1.2.4 字符串类型二. 数据类型强制转换2.1 强制转换为字符串类型2.2 强制转换为int类型2.3 强制转换函数之float() 函数三. 拓展几个运算函数…

2023年华为HCIA-Datacom认证视频课

一、下载地址&#xff1a;https://edu.csdn.net/learn/38282/607342?spm1003.2001.3001.4157 一、课程大纲 2023年华为考试大纲 考试分数章目录小节80第1章&#xff1a;网络参考模型1.1OSI网络参考模型介绍1.2OSI网络参考模型各层的作用1.3 OSI与TCP/IP模型的比较1.4 TCP与U…

【转载】通过HAL库实现MODBUS从机程序编写与调试-----STM32CubeMX操作篇

通过HAL库实现MODBUS从机程序编写与调试-----STM32CubeMX操作篇[【STM32】RS485 Modbus协议 采集传感器数据](https://blog.csdn.net/qq_33033059/article/details/106935583)基于STM32的ModbusRtu通信--ModbusRtu协议(一)基于STM32的ModbusRtu通信--终极Demo设计(二)STM32RS48…

TensorRT的C++接口解析

TensorRT的C接口解析 文章目录TensorRT的C接口解析3.1. The Build Phase3.1.1. Creating a Network Definition3.1.2. Importing a Model using the ONNX Parser3.1.3. Building an Engine注意&#xff1a;序列化引擎不能跨平台或 TensorRT 版本移植。引擎特定于它们构建的确切…

“黑铁时代”,地产人如何以客户视角加速房企数字化转型

本文从行业洞察、业务设计、数据建设以及实践探索四个部分详细阐述地产行业数字化的实践、思考和理解。点击文末“阅读原文”&#xff0c;观看完整版直播回放并下载演讲文档。一、洞察&#xff1a;房企经营思路的变化企业的转型都是围绕着业务经营变化进行的&#xff0c;房企数…

P1307 [NOIP2011 普及组] 数字反转

[NOIP2011 普及组] 数字反转 题目描述 给定一个整数 NNN&#xff0c;请将该数各个位上数字反转得到一个新数。新数也应满足整数的常见形式&#xff0c;即除非给定的原数为零&#xff0c;否则反转后得到的新数的最高位数字不应为零&#xff08;参见样例 2&#xff09;。 输入…

电源口防雷器电路设计方案

电源口防雷电路的设计需要注意的因素较多&#xff0c;有如下几方面&#xff1a;1、防雷电路的设计应满足规定的防护等级要求&#xff0c;且防雷电路的残压水平应能够保护后级电路免受损坏。2、在遇到雷电暂态过电压作用时&#xff0c;保护装置应具有足够快的动作响应速度&#…