自定义卷积实现卷积的重参数【手撕代码】

news2025/1/17 6:01:23

在我的上篇文章中主要对RepVGG进行了解析【RepVGG网络中重参化网络结构解读】,里面详细的对论文中的代码进行了解析,展示了RepVGG在重参数时是如何将训练分支进行合并的,总的一句话就是在推理阶段,会将1x1以及identity分支以padding的方法变成3x3的卷积后再与主干中的3x3卷积进行合并。

而这篇文章的目的仿照RepVGG如何自定义一个含有分支的卷积,并采用重参数技术进行合并推理


目录

卷积块的定义

重参数化

BN、identity以及卷积的融合 

1x1变成3x3 

 完整代码:


卷积块的定义

下面的代码是先定义一个基础的卷积块CB(一个conv和一个BN层的结合)

def conv_bn(in_channels, out_channels, kernel_size, stride=1, padding=0):
    resluts = nn.Sequential()
    resluts.add_module('conv', nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding))
    resluts.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return resluts

 下面的代码就是我们的定义的含有分支的卷积块。

参数说明:

in_channels:输入通道数

out_channels:输出通道数

stride:卷积步长

groups:分组卷积组数

padding_mode:padding模式,以0补pad

deploy:在重参数的时候将会设置为True将网络分支合并

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, groups=1, padding_mode='zeros', deploy=False):
        super(ConvBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.deploy = deploy
        self.identity = nn.Identity()
        self.relu = nn.ReLU()

        if deploy:
            self.rbr_reparam = nn.Conv2d(self.in_channels, self.out_channels, 3, stride=1, padding=1, padding_mode=padding_mode)
        else:
            self.rbr_identity = nn.BatchNorm2d(num_features=self.in_channels) if self.in_channels == self.out_channels and stride == 1 else None
            self.conv3_3 = conv_bn(self.in_channels, self.out_channels, 3, stride, padding=1)
            self.conv1_1 = conv_bn(self.in_channels, self.out_channels, 1, 1)
            print('RepConv Block, identity = ', self.rbr_identity)

    def forward(self, x):
        if hasattr(self, 'rbr_reparam'):
            return self.relu(self.identity(self.rbr_reparam(x)))
        out1 = self.conv3_3(x)
        out2 = self.conv1_1(x)
        out3 = self.identity(x)
        return self.relu(out1 + out2 + out3)

然后我们可以直接看forward函数,【这里先暂时不看if hasattr(self,'rbr_reparam')这一段】,可以看到产生三个out,进行相加后再经过relu激活函数,打印的网络以及网络结构如下,能很清楚的看到卷积块的分支,分别是identity、3x3和1x1卷积。

 

 ConvBlock(
  (identity): Identity()
  (relu): ReLU()
  (rbr_identity): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3_3): Sequential(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv1_1): Sequential(
    (conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

按照RepVGG的思想,在训练过程中网络基本单元如上,推理阶段将会合并分支,下面就看看怎么实现的。

重参数化

下面这些代码是定义在上面的类ConvBlock中的。

get_equivalent_kernel_bias这个函数获取3x3,1x1以及identity中的权值和bias,这个内部实现的核心是fuse_bn_tensor这个函数。

    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv3_3)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv1_1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

BN、identity以及卷积的融合 

fuse_bn_tensor是获取权值、方差、均值等,将卷积和BN层进行融合。

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight  # 卷积的权值
            running_mean = branch.bn.running_mean  # bn的均值
            running_var = branch.bn.running_var  # bn的方差
            gamma = branch.bn.weight  # bn的权值
            beta = branch.bn.bias  # bn的bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)  # 创建一个全零32*32*3*3的矩阵用来记录权值
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1  # 获得一个中间值为1,周边值为0的新卷积核
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std =(running_var + eps).sqrt()
        t = (gamma/std).reshape(-1, 1, 1, 1)
        return kernel * t, beta -running_mean * gamma / std

当第一次遍历的时候branch为Sequential时,获取Conv层,以及BN层的参数(如果你这里没有BN层可以考虑把这部分代码删去,只获取Conv)

Sequential(
  (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

其中在fuse_bn_tensor这里面有个核心的代码是下面这一行,这个是在identity这个分支中,会创建一个全0的矩阵,矩阵大小为in_channels*input_dim*3*3。

kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32) 

 然后在上面的kernel_value中每个通道的中间那个元素位置为1。

                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1  # 获得一个中间值为1,周边值为0的新卷积核

 

1x1变成3x3 

下面的代码是将1x1的卷积补成3x3卷积

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return F.pad(kernel1x1, [1, 1, 1, 1])  # 在1x1四周补padding

最后是看switch_to_deploy函数,我们前面通过get_equivalent_kernel_bias获得kernel和bias,创建重参数化卷积rbr_reparam,大小为3x3.

然后把之前融合的卷积权值和bias传入给新键的rbr_reparam中,在将deploy设置为True就得到了我们想要的卷积了。 

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(self.conv3_3.conv.in_channels, self.conv3_3.conv.out_channels,
                                     kernel_size=self.conv3_3.conv.kernel_size,
                                     stride=self.conv3_3.conv.stride,
                                     padding=self.conv3_3.conv.padding,
                                     dilation=self.conv3_3.conv.dilation,
                                     groups=self.conv3_3.conv.groups,
                                     bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for param in self.parameters():
            param.detach_()
        self.__delattr__('conv3_3')
        self.__delattr__('conv1_1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True

 

 

 完整代码:

 

import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def conv_bn(in_channels, out_channels, kernel_size, stride=1, padding=0):
    resluts = nn.Sequential()
    resluts.add_module('conv', nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding))
    resluts.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return resluts


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, groups=1, padding_mode='zeros', deploy=False):
        super(ConvBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.deploy = deploy
        self.identity = nn.Identity()
        self.relu = nn.ReLU()

        if deploy:
            self.rbr_reparam = nn.Conv2d(self.in_channels, self.out_channels, 3, stride=1, padding=1, padding_mode=padding_mode)
        else:
            self.rbr_identity = nn.BatchNorm2d(num_features=self.in_channels) if self.in_channels == self.out_channels and stride == 1 else None
            self.conv3_3 = conv_bn(self.in_channels, self.out_channels, 3, stride, padding=1)
            self.conv1_1 = conv_bn(self.in_channels, self.out_channels, 1, 1)
            print('RepConv Block, identity = ', self.rbr_identity)

    def forward(self, x):
        if hasattr(self, 'rbr_reparam'):
            return self.relu(self.identity(self.rbr_reparam(x)))
        out1 = self.conv3_3(x)
        out2 = self.conv1_1(x)
        out3 = self.identity(x)
        return self.relu(out1 + out2 + out3)

    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv3_3)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv1_1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return F.pad(kernel1x1, [1, 1, 1, 1])  # 在1x1四周补padding

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight  # 卷积的权值
            running_mean = branch.bn.running_mean  # bn的均值  没有BN层可以考虑删除改部分
            running_var = branch.bn.running_var  # bn的方差
            gamma = branch.bn.weight  # bn的权值
            beta = branch.bn.bias  # bn的bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)  # identity
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)  # 创建一个全零32*32*3*3的矩阵用来记录权值
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1  # 获得一个中间值为1,周边值为0的新卷积核
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std =(running_var + eps).sqrt()
        t = (gamma/std).reshape(-1, 1, 1, 1)
        return kernel * t, beta -running_mean * gamma / std

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(self.conv3_3.conv.in_channels, self.conv3_3.conv.out_channels,
                                     kernel_size=self.conv3_3.conv.kernel_size,
                                     stride=self.conv3_3.conv.stride,
                                     padding=self.conv3_3.conv.padding,
                                     dilation=self.conv3_3.conv.dilation,
                                     groups=self.conv3_3.conv.groups,
                                     bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for param in self.parameters():
            param.detach_()
        self.__delattr__('conv3_3')
        self.__delattr__('conv1_1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True

def repconv(model:nn.Module, save_path=None, do_copy=True):
    if do_copy:
        model =copy.deepcopy(model)
    for module in model.modules():
        if hasattr(module, 'switch_to_deploy'):
            module.switch_to_deploy()
    if save_path is not None:
        torch.save(model.state_dict(), save_path)
    return model
model = ConvBlock(32, 32)
print(model)
x = torch.randn(1, 32, 24, 24)
torch.onnx.export(model, x, "Conv.onnx", verbose=True, input_names=['images'], output_names=['output'],
                  opset_version=12)
rep_model = repconv(model)
torch.onnx.export(rep_model,x,'rep_model.onnx', verbose=True,input_names=['images'], output_names=['output'],
                  opset_version=12)

 有些细节后面再慢慢补充,最近阳了有些不太舒服~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/117135.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vivo 游戏中心低代码平台的提效秘诀

作者:vivo 互联网服务器团队- Chen Wenyang 本文根据陈文洋老师在“2022 vivo开发者大会"现场演讲内容整理而成。公众号回复【2022 VDC】获取互联网技术分会场议题相关资料。 在互联网流量见顶和用户需求分层的背景下,如何快速迭代产品功能&#xf…

函数模板-C11/17/14

函数模板 文章目录函数模板定义函数模板使用函数模板样例两阶段翻译 Two-Phase Translation模板的编译和链接问题多模板参数引入额外模板参数作为返回值类型让编译器自己找出返回值类型将返回值声明为两个模板参数的公共类型样例默认模板参数样例重载函数模板模板函数特化非类型…

cocoapods的使用

swift开发之cocoapods的使用 之前介绍了cocoapods的使用,我们可以知道通过pod search XXX(三方依赖库名称)可以就搜索到想要的第三方是否存在。 这次主要简单介绍cocoapods如何引入第三方库的,以BluetoothKit为例。 首先,我们终端中通过cd命令定位到要…

二十二、shiro安全框架基础

一、简介 1. shiro简介 Apache Shiro 是 Java 的一个安全(权限)框架。Shiro 可以非常容易的开发出足够好的应用,其不仅可以用在JavaSE 环境,也可以用在 JavaEE 环境。Shiro 可以完成:认证、授权、加密、会话管理、与…

“智慧”控漏 削减产销差-城镇供水管网分区计量管理系统

平升电子城镇供水管网分区计量管理系统根据国际国内分区计量的要求和标准研发,专门针对水司漏损控制和产销差管理而设计。系统涵盖分区管理、管网流量和压力监控、水量统计分析、产销差分析、漏损评估、夜间最小流量分析、用水异常报警等功能。核心目标是找到整个管…

ReactJS入门

目录 一:前端开发的演变 二:ReactJS简介 三:搭建环境 四:React快速入门 一:前端开发的演变 到目前为止,前端的开发经历了四个阶段,目前处于第四个阶段。这四个阶段分别是: 阶段一…

equals()与hashcode()之间的关系

1、equals简介 被用来检测两个对象是否相等,即两个对象的内容是否相等; equals 方法(是String类从它的超类Object中继承的)用于比较引用和比较基本数据类型时具有不同的功能: 比较基本数据类型,如果两个值…

马哥SRE第11周课程作业

ansible role zabbix相关话题1. ansible 常用指令总结,并附有相关示例。1.1 Ansible相关工具1.1.1 ansible-doc1.1.2 ansible 命令用法1.1.3 ansible-console1.1.4 ansible-playbook1.1.5 ansible-vault1.1.5 ansible-galaxy2. 总结ansible playbook目录结构及文件用…

javaee之Spring4

之前说到AccountDao需要继承JdbcDaoSupport这个类,那么现在来看一下这个类的内容 JdbcDaoSupport.java package com.itheima.dao.impl;/*** 此类用于抽取dao中的重复代码 */public class JdbcDaoSupport {private JdbcTemplate jdbcTemplate;public void setJdbcT…

人大金仓数据库备份应用sys_dump的使用

人大金仓数据库软件给数据库管理员用户提供了管理维护数据库的多个客户端应用,更多参考:《KingbaseES客户端应用参考手册》。 我们可以看到备份的应用有两个: 1、sys_dump:将KingbaseES数据库备份为一个脚本文件或者其他归档文件 2、sys_d…

表单校验重要性和多规则校验

表单校验分类 校验位置: 客户端校验 服务端校验 表单校验框架 JSR:java规范提案 303:提供bean属性相关校验规则 JCP:java社区 Hibernate框架中包含一套独立的校验框架hibernate-validator 实际的校验规则 同一个字段有多个约束条件 引用…

股权转让项目:沈阳派尔化学有限公司55%股权转让

股权转让项目:沈阳派尔化学有限公司55%股权转让;该项目由 广州产权交易所 发布,于2022年12月25日被塔米狗平台收录。 该公司在 2021 年最新一期财务报告中, 披露的资产总额(万元):7148.98 &…

装修半包包括哪些内容呢?极家精工装修好不好

​装修半包包括哪些内容呢?极家精工装修好不好。在装修房子的时候,很多人都会选择半包装修,因为可以自己挑选材料,自己跟工程比较放心。另外一边比较重要的原因就是能省钱,对于预算有限的小伙伴真的再适合不过啦&#…

唐玄奘把 「JWT 令牌」玩到了极致

唐玄奘把 「JWT 令牌」玩到了极致 你好,我是悟空。 西游记的故事想必大家在暑假看过很多遍了,为了取得真经,唐玄奘历经苦难,终于达成。 在途经各国的时候,唐玄奘都会拿出一个通关文牒交给当地的国王进行盖章&#x…

基于线性表的图书管理系统(java)

目录 1、简介 2、代码 (1)ManageSystem类 (2)book类 3、测试程序运行结果截图 (1)登录和创建 (2)输出 (3)查找 (4)插入 &a…

如何用乐高积木式操作让 ChatGPT 变得更强大?

需求这些日子,很多小伙伴儿玩儿 ChatGPT 不亦乐乎,甚至陷入了沉迷。他们尝试了各种 ChatGPT 的功能。不少功能强悍到不可思议;当然,也有些功能尝试因遇到障碍无法完成。于是很多用户非常失望,觉得 ChatGPT 好像啥都干不…

20221227:Rockchip-RK模型转换

Tips: 不同芯片对应的NPU和toolkit是不同的,注意区分! 平台 RK1808/RK1806 RV1109/RV1126 RKNPU:本工程主要为Rockchip NPU提供驱动、示例等。 GitHub - rockchip-linux/rknpuContribute to rockchip-linux/rknpu development by creating an account on GitHub.https://gi…

小程序项目开发

目录 一,flex弹性布局 1.什么是flex布局? 2.flex属性 3.视图层 View WXML 1数据绑定 2.列表渲染 3.条件渲染 4.模板 5. 数据处理 二,轮播图--组件的使用 1.WXSS 样式导入 内联样式 选择器 全局样式与局部样式 WXS 页面渲染 三&…

zabbix常用监控项解读

CPU来源模板:Template Module Linux CPU by Zabbix agent 内存(memory)来源模板:Template Module Linux memory by Zabbix agent 磁盘空间(disk) 数据来源:Get /proc/diskstats 监控项原型&am…

【小5聊】ElementUI-Vue3-TS项目简单创建

vue2升级到vue3,不管任何框架,升级总有它改进的地方和原因,否则升级就毫无意义,技术变化日新月异,必须保持与时俱进,否则就很容易在技术的浪潮中被淘汰! vue3相比以前版本,最大一个变…