Pruning and Re-parameterization, Lesson 9: DBB Re-parameterization


Contents

  • DBB Re-parameterization
    • Preface
    • 1. DBB
    • 2. The Six Transforms of DBB
      • 2.1 Transform I: a conv for conv-BN
      • 2.2 Transform II: a conv for branch addition
      • 2.3 Transform III: a conv for sequential convolutions
      • 2.4 Transform IV: a conv for depth concatenation
      • 2.5 Transform V: a conv for average pooling
      • 2.6 Transform VI: a conv for multi-scale convolutions
    • 3. Special Structures in DBB
      • 3.1 A 1x1 Conv2d with the Identity Property
      • 3.2 BN+Pad
    • 4. Building the DBB Network
      • 4.1 conv_bn
      • 4.2 branch
      • 4.3 forward
      • 4.4 Implementing the Re-parameterization
      • 4.5 Model Export
    • 5. Complete Example Code
    • Summary

DBB Re-parameterization

Preface

These are my personal study notes for the new model pruning and re-parameterization course from 手写AI, recorded for my own reference.

This lesson mainly covers DBB re-parameterization.

The course outline is summarized in the mind map below:

[Figure: course outline mind map]

1. DBB

Diverse Branch Block (DBB) is a follow-up to ACNet in the exploration of structural re-parameterization, effectively an ACNet v2. DBB designs an Inception-like block that enriches the feature space of a convolutional block through a multi-branch structure, with branches such as average pooling and multi-scale convolutions. Before the inference stage, the multi-branch structure is re-parameterized and fused into a single main branch, which speeds up inference and also brings a small accuracy gain.

[Figure: the DBB block structure]

The figure above shows the DBB structure. Similar to Inception, it augments the original KxK convolution with branch combinations such as 1x1, 1x1-KxK, and 1x1-AVG. For the 1x1-KxK branch, the number of internal channels is set equal to the number of input channels and the 1x1 convolution is initialized as an identity matrix; the other branches are initialized in the usual way.

In addition, a BN layer is added after every convolution to provide training-time nonlinearity, which is essential for the performance gain.

2. The Six Transforms of DBB

For a regular convolutional network, DBB defines six transforms used at the inference stage, as shown in the figure below:

[Figure: the six DBB transforms]

2.1 Transform I: a conv for conv-BN

Transform I: replacing conv + BN with a single conv

Folding BN into the preceding conv gives k' = k · (γ/σ) per output channel and b' = β − μ·γ/σ, where μ and σ² are the BN running mean and variance and γ, β are its affine weight and bias.

def transI_fusebn(kernel, bn):
    gamma = bn.weight
    std   = (bn.running_var + bn.eps).sqrt()
    k     = kernel * ((gamma / std).view(-1, 1, 1, 1))
    b     = bn.bias - bn.running_mean * gamma / std
    return k, b
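
A quick numerical check (a minimal sketch, assuming torch, torch.nn as nn, and torch.nn.functional as F are imported as in the complete listing in Section 5): fold a BN layer into its preceding conv and compare the outputs in eval mode.

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bn   = nn.BatchNorm2d(8).eval()
nn.init.uniform_(bn.running_mean, 0, 0.1)  # perturb the running stats so the check is non-trivial

x = torch.randn(1, 3, 16, 16)
k, b = transI_fusebn(conv.weight, bn)
y_fused = F.conv2d(x, k, b, padding=1)  # single conv with the fused kernel and bias
y_ref   = bn(conv(x))                   # original conv + BN
print((y_fused - y_ref).abs().max())    # expected to be on the order of 1e-6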

2.2 Transform II: a conv for branch addition

Transform II: merging parallel convolution branches by addition

def transII_addbranch(kernels, biases):
    k = sum(kernels)
    b = sum(biases)
    return k, b
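
Correctness follows from the linearity of convolution: two same-shape convs applied to the same input and summed equal one conv with the summed kernel and bias. A quick check (same import assumptions as above):

k1, b1 = torch.randn(8, 3, 3, 3), torch.randn(8)
k2, b2 = torch.randn(8, 3, 3, 3), torch.randn(8)
x = torch.randn(1, 3, 16, 16)

k, b = transII_addbranch((k1, k2), (b1, b2))
y_merged = F.conv2d(x, k, b, padding=1)
y_ref    = F.conv2d(x, k1, b1, padding=1) + F.conv2d(x, k2, b2, padding=1)
print((y_merged - y_ref).abs().max())  # ~0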

2.3 Transform III: a conv for sequential convolutions

Transform III: merging sequential convolutions

For a 1x1 conv (k1, b1) followed by a kxk conv (k2, b2), the merged kernel is obtained by convolving k2 with the channel-transposed k1, and b1 is pushed through k2 to form the merged bias.

def transIII_1x1_kxk(k1, b1, k2, b2, groups):
    if groups == 1:
        k     = F.conv2d(k2, k1.permute(1, 0, 2, 3))
        b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3))
    else:
        
        k_slices = []
        b_slices = []
        k1_T = k1.permute(1, 0, 2, 3)

        k1_group_width = k1.size(0) // groups
        k2_group_width = k2.size(0) // groups
        for g in range(groups):
            k1_T_slice = k1_T[:, g*k1_group_width:(g+1)*k1_group_width, :, :]
            k2_slice   = k2[g*k2_group_width:(g+1)*k2_group_width, :, :, :]
            k_slices.append(F.conv2d(k2_slice, k1_T_slice))
            b_slices.append((k2_slice * b1[g*k1_group_width:(g+1)*k1_group_width].reshape(1, -1, 1, 1)).sum((1, 2, 3)))
        
        k, b_hat = transIV_depthconcat(k_slices, b_slices)
    return k, b_hat + b2
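
Note that this merge is exact only when no zero-padding is applied between the two convolutions; in DBB, the padding between them is instead produced by the BNAndPadLayer introduced in Section 3.2. A quick check with groups=1 and no padding anywhere (same import assumptions as above):

k1, b1 = torch.randn(6, 3, 1, 1), torch.randn(6)   # 1x1 conv
k2, b2 = torch.randn(8, 6, 3, 3), torch.randn(8)   # 3x3 conv
x = torch.randn(1, 3, 16, 16)

k, b = transIII_1x1_kxk(k1, b1, k2, b2, groups=1)
y_merged = F.conv2d(x, k, b)                       # one equivalent 3x3 conv
y_ref    = F.conv2d(F.conv2d(x, k1, b1), k2, b2)   # 1x1 conv followed by 3x3 conv
print((y_merged - y_ref).abs().max())  # ~0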

2.4 Transform IV: a conv for depth concatenation

Transform IV: depth concatenation of convolutions

Branches concatenated along the output-channel dimension correspond to concatenating their kernels and biases along that same dimension.

def transIV_depthconcat(kernels, biases):
    return torch.cat(kernels, dim=0), torch.cat(biases)
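
Two convolutions applied to the same input and concatenated along the channel dimension are equivalent to one convolution whose kernel and bias are the concatenations. A quick check (same import assumptions as above):

k1, b1 = torch.randn(4, 3, 3, 3), torch.randn(4)
k2, b2 = torch.randn(6, 3, 3, 3), torch.randn(6)
x = torch.randn(1, 3, 16, 16)

k, b = transIV_depthconcat([k1, k2], [b1, b2])
y_merged = F.conv2d(x, k, b, padding=1)
y_ref    = torch.cat([F.conv2d(x, k1, b1, padding=1),
                      F.conv2d(x, k2, b2, padding=1)], dim=1)
print((y_merged - y_ref).abs().max())  # ~0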

2.5 Transform V: a conv for average pooling

Transform V: average pooling as a convolution

A kxk average pooling equals a kxk convolution whose kernel is 1/k² on the channel diagonal and 0 elsewhere.

def transV_avg(channels, kernel_size, groups):
    input_dim = channels // groups
    k = torch.zeros((channels, input_dim, kernel_size, kernel_size))  
    k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1. / kernel_size**2
    return k
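
transV_avg builds a conv kernel that reproduces average pooling: each output channel averages its own input channel over the kxk window. A quick check (assuming numpy is imported as np, as in the complete listing):

channels, ks = 4, 3
x = torch.randn(1, channels, 16, 16)

k = transV_avg(channels, ks, groups=1)
y_conv = F.conv2d(x, k)                             # averaging expressed as a conv
y_ref  = F.avg_pool2d(x, kernel_size=ks, stride=1)  # regular average pooling
print((y_conv - y_ref).abs().max())  # ~0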

2.6 Transform VI: a conv for multi-scale convolutions

Transform VI: multi-scale convolutions

def transVI_multiscale(kernel, target_kernel_size):
    H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2
    W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2
    return F.pad(kernel, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad])
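
Zero-padding a small kernel up to the target size leaves the computation unchanged, as long as the input padding is enlarged accordingly. A quick check that pads a 1x1 kernel to 3x3:

k1, b1 = torch.randn(8, 3, 1, 1), torch.randn(8)
x = torch.randn(1, 3, 16, 16)

k3 = transVI_multiscale(k1, 3)
y_pad = F.conv2d(x, k3, b1, padding=1)  # padded 3x3 kernel needs padding=1
y_ref = F.conv2d(x, k1, b1, padding=0)  # original 1x1 kernel
print((y_pad - y_ref).abs().max())  # ~0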

3. Special Structures in DBB

3.1 A 1x1 Conv2d with the Identity Property

The DBB network also uses a convolution module that acts as an identity mapping; its implementation is as follows:

class IdentityBasedConv1x1(nn.Conv2d):
    def __init__(self, channels, groups=1):
        super().__init__(in_channels=channels,
                         out_channels=channels,
                         kernel_size=1,
                         stride=1,
                         padding=0,
                         groups=groups,
                         bias=False)
        assert channels % groups == 0
        input_dim = channels // groups
        id_value = np.zeros((channels, input_dim, 1, 1))
        for i in range(channels):
            id_value[i, i % input_dim, 0, 0] = 1
        self.id_tensor = torch.from_numpy(id_value).type_as(self.weight)
        nn.init.zeros_(self.weight)
    
    def forward(self, input):
        kernel = self.weight + self.id_tensor.to(self.weight.device)
        result = F.conv2d(input, kernel, None, stride=1, padding=0, dilation=self.dilation, groups=self.groups)
        return result

    def get_actual_kernel(self):
        return self.weight + self.id_tensor.to(self.weight.device)
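
Because the learnable weight starts at zero while the fixed id_tensor contributes the identity mapping, the module returns its input unchanged right after construction. A quick check:

conv = IdentityBasedConv1x1(channels=4)
x = torch.randn(1, 4, 8, 8)
print((conv(x) - x).abs().max())  # ~0 at initialization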

3.2 BN+Pad

A BN layer followed by padding, where the border is filled not with zeros but with the values BN produces on zero input (so that later fusion stays exact). Its implementation is as follows:

class BNAndPadLayer(nn.Module):
    def __init__(self,
                 pad_pixels,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features=num_features,
                                 eps = eps,
                                 momentum=momentum,
                                 affine=affine,
                                 track_running_stats=track_running_stats)
        self.pad_pixels = pad_pixels
    
    def forward(self, input):
        output = self.bn(input)
        if self.pad_pixels > 0:
            if self.bn.affine:
                pad_values = self.bn.bias.detach() - self.bn.running_mean * self.bn.weight.detach() / torch.sqrt(
                    self.bn.running_var + self.bn.eps
                )
            else:
                pad_values = - self.bn.running_mean / torch.sqrt(self.bn.running_var + self.bn.eps)

            output = F.pad(output, [self.pad_pixels]*4)
            pad_values = pad_values.view(1, -1, 1, 1)
            output[:, :, 0:self.pad_pixels, :] = pad_values
            output[:, :, -self.pad_pixels:, :] = pad_values
            output[:, :, :, 0:self.pad_pixels] = pad_values
            output[:, :, :, -self.pad_pixels:] = pad_values
        return output
    
    @property
    def weight(self):
        return self.bn.weight
    
    @property
    def bias(self):
        return self.bn.bias

    @property
    def running_mean(self):
        return self.bn.running_mean
    
    @property
    def running_var(self):
        return self.bn.running_var
    
    @property
    def eps(self):
        return self.bn.eps
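
The reason for padding with BN(0) = beta - mu * gamma / std rather than with zeros is that the branch then behaves as if the input had been zero-padded before the conv + BN pair, which is exactly what the fused convolution assumes. A quick check in eval mode (a sketch, same import assumptions as above):

bnpad = BNAndPadLayer(pad_pixels=1, num_features=4).eval()
nn.init.uniform_(bnpad.bn.running_mean, 0, 0.1)  # make the pad values non-zero
x = torch.randn(1, 4, 8, 8)

y1 = bnpad(x)                     # BN first, then pad with BN(0)
y2 = bnpad.bn(F.pad(x, [1] * 4))  # zero-pad first, then BN
print((y1 - y2).abs().max())      # ~0 in eval mode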

4. Building the DBB Network

4.1 conv_bn

First, write a helper function that builds a conv + BN block:

def conv_bn(in_channels, out_channels, kernel_size, stride=1, padding=0, padding_mode='zeros', dilation=1, groups=1):
    conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                           stride=stride, padding=padding, dilation=dilation,
                           groups=groups, bias=False, padding_mode=padding_mode)
    bn_layer = nn.BatchNorm2d(num_features=out_channels, affine=True)
    se = nn.Sequential()
    se.add_module('conv', conv_layer)
    se.add_module('bn', bn_layer)
    return se

4.2 branch

The implementation of the branches:

class DiverseBranchBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, internal_channels_1x1_3x3=None, deploy=False, nonlinear=None):
        super().__init__()
        self.deploy = deploy

        if nonlinear is None:
            self.nonlinear = nn.Identity()
        else:
            self.nonlinear = nonlinear
    
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.groups = groups
        assert padding == kernel_size // 2

        if deploy:
            self.dbb_reparam = nn.Conv2d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding, dilation=dilation, groups=groups, bias=True
            )
        else:
            self.dbb_origin = conv_bn(
                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding, dilation=dilation, groups=groups
            )
        
            self.dbb_avg = nn.Sequential()
            if groups < out_channels:
                self.dbb_avg.add_module(
                    'conv', nn.Conv2d(
                        in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                        stride=1, padding=0, groups=groups, bias=False
                    )
                )

                self.dbb_avg.add_module(
                    'bn', BNAndPadLayer(pad_pixels=padding,
                                        num_features=out_channels)
                )

                self.dbb_avg.add_module(
                    'avg', nn.AvgPool2d(kernel_size=kernel_size,
                                        stride=stride,
                                        padding=0)
                )

                self.dbb_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                                       stride=stride, padding=0, groups=groups)
            else:
                self.dbb_avg.add_module(
                    'avg', nn.AvgPool2d(kernel_size=kernel_size,
                                        stride=stride,
                                        padding=padding)
                )
            self.dbb_avg.add_module(
                'avgbn', nn.BatchNorm2d(out_channels)
            )

            if internal_channels_1x1_3x3 is None:
                internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels

            self.dbb_1x1_kxk = nn.Sequential()
            if internal_channels_1x1_3x3 == in_channels:
                self.dbb_1x1_kxk.add_module('idconv1',
                                            IdentityBasedConv1x1(channels=in_channels, groups=groups))
            else:
                self.dbb_1x1_kxk.add_module('conv1',
                                            nn.Conv2d(in_channels=in_channels,
                                                      out_channels=internal_channels_1x1_3x3,
                                                      kernel_size=1,
                                                      stride=1,
                                                      padding=0,
                                                      groups=groups,
                                                      bias=False))
            self.dbb_1x1_kxk.add_module('bn1',
                                        BNAndPadLayer(pad_pixels=padding,
                                                      num_features=internal_channels_1x1_3x3,
                                                      affine=True))
            self.dbb_1x1_kxk.add_module('conv2',
                                        nn.Conv2d(in_channels=internal_channels_1x1_3x3,
                                                  out_channels=out_channels,
                                                  kernel_size=kernel_size,
                                                  stride=stride,
                                                  padding=0,
                                                  groups=groups,
                                                  bias=False))
            self.dbb_1x1_kxk.add_module('bn2', nn.BatchNorm2d(out_channels))

4.3 forward

The implementation of the forward pass:

class DiverseBranchBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, internal_channels_1x1_3x3=None, deploy=False, nonlinear=None):
        super().__init__()
        ...

    def forward(self, inputs):
        if hasattr(self, 'dbb_reparam'):
            return self.nonlinear(self.dbb_reparam(inputs))
        
        out = self.dbb_origin(inputs)
        if hasattr(self, 'dbb_1x1'):
            out += self.dbb_1x1(inputs)
        out += self.dbb_avg(inputs)
        out += self.dbb_1x1_kxk(inputs)
        return self.nonlinear(out)

4.4 Implementing the Re-parameterization

The re-parameterization procedure:

class DiverseBranchBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, internal_channels_1x1_3x3=None, deploy=False, nonlinear=None):
        super().__init__()
        ...

    def forward(self, inputs):
        ...

    def switch_to_deploy(self):
        if hasattr(self, 'dbb_reparam'):
            return

        kernel, bias = self.get_equivalent_kernel_bias()
        self.dbb_reparam = nn.Conv2d(in_channels=self.dbb_origin.conv.in_channels,
                                     out_channels=self.dbb_origin.conv.out_channels,
                                     kernel_size=self.dbb_origin.conv.kernel_size,
                                     stride=self.dbb_origin.conv.stride,
                                     padding=self.dbb_origin.conv.padding,
                                     dilation=self.dbb_origin.conv.dilation,
                                     groups=self.dbb_origin.conv.groups,
                                     bias=True)
        
        self.dbb_reparam.weight.data = kernel
        self.dbb_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        
        self.__delattr__('dbb_origin')
        self.__delattr__('dbb_avg')
        if hasattr(self, 'dbb_1x1'):
            self.__delattr__('dbb_1x1')
        self.__delattr__('dbb_1x1_kxk')
    
    def get_equivalent_kernel_bias(self):
        k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight,
                                           self.dbb_origin.bn)
        
        if hasattr(self, 'dbb_1x1'):
            k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight,
                                         self.dbb_1x1.bn)
            
            k_1x1 = transVI_multiscale(k_1x1,
                                       self.kernel_size)
        else:
            k_1x1, b_1x1 = 0, 0
        
        if hasattr(self.dbb_1x1_kxk, 'idconv1'):
            k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
        else:
            k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
        
        k_1x1_kxk_first, b_1x1_kxk_first   = transI_fusebn(k_1x1_kxk_first,
                                                           self.dbb_1x1_kxk.bn1)
        k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(self.dbb_1x1_kxk.conv2.weight,
                                                           self.dbb_1x1_kxk.bn2)

        k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(k_1x1_kxk_first,
                                                              b_1x1_kxk_first,
                                                              k_1x1_kxk_second,
                                                              b_1x1_kxk_second,
                                                              groups=self.groups)
        
        k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
        k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg.to(self.dbb_avg.avgbn.weight.device),
                                                           self.dbb_avg.avgbn)
        if hasattr(self.dbb_avg, 'conv'):
            k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(self.dbb_avg.conv.weight,
                                                             self.dbb_avg.bn)
            k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(k_1x1_avg_first,
                                                                  b_1x1_avg_first,
                                                                  k_1x1_avg_second,
                                                                  b_1x1_avg_second,
                                                                  groups=self.groups)
        else:
            k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second
        
        return transII_addbranch((k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged),
                                 (b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged))

4.5 Model Export

Exporting the DBB model and comparing the outputs before and after re-parameterization:

if __name__ == '__main__':
    
    x = torch.randn(1, 4, 224, 224)

    model = DiverseBranchBlock(in_channels=4, out_channels=32, kernel_size=3, stride=1, padding=1, groups=2, deploy=False)

    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            nn.init.uniform_(module.running_mean, 0, 0.1)
            nn.init.uniform_(module.running_var, 0, 0.2)
            nn.init.uniform_(module.weight, 0, 0.3)
            nn.init.uniform_(module.bias, 0, 0.4)
        
    model.eval()
    out = model(x)
    torch.onnx.export(model=model, args=x, f='./DBB.onnx', verbose=False)

    model.switch_to_deploy()
    deployout = model(x)

    torch.onnx.export(model=model, args=x, f='./DBB-deploy.onnx', verbose=False)

    print('\nDifference between the outputs of the origin-DBB and rep-DBB is: {}\n'.format(
        ((deployout - out) ** 2).sum()
    ))

5. Complete Example Code

The complete example code for DBB re-parameterization is as follows:

import torch
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import torch.nn as nn
import torch.nn.functional as F


def conv_bn(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
            padding_mode='zeros'):
    conv_layer = nn.Conv2d(in_channels  = in_channels, 
                           out_channels = out_channels, 
                           kernel_size  = kernel_size,
                           stride       = stride, 
                           padding      = padding, 
                           dilation     = dilation, 
                           groups       = groups,
                           bias         = False, 
                           padding_mode = padding_mode)
    bn_layer = nn.BatchNorm2d(num_features=out_channels, affine=True)
    se = nn.Sequential()
    se.add_module('conv', conv_layer)
    se.add_module('bn', bn_layer)
    return se


def transI_fusebn(kernel, bn):
    '''
    Transform I: fuse a conv kernel with the BN layer that follows it.
    Returns:
        k: the kernel scaled per output channel by gamma / std
        b: the equivalent bias, bn.bias - bn.running_mean * gamma / std
    '''
    gamma = bn.weight
    std = (bn.running_var + bn.eps).sqrt()
    k = kernel * ((gamma / std).view(-1, 1, 1, 1))
    b = bn.bias - bn.running_mean * gamma / std
    return k, b


def transII_addbranch(kernels, biases):
    '''
    Input:
        kernels: tuple
        biases : tuple
    '''
    k = sum(kernels)
    b = sum(biases)
    return k, b


def transIII_1x1_kxk(k1, b1, k2, b2, groups):
    if groups == 1:
        k = F.conv2d(k2, k1.permute(1, 0, 2, 3))
        b_hat = (k2 * b1.view(1, -1, 1, 1)).sum((1, 2, 3))
    else:
        # initializes an empty list for storing the results of the 1x1 convolutions.
        k_slices = []
        # initializes an empty list for storing the bias terms for the kxk convolutions
        b_slices = []
        # switch the in_channels and out_channels
        k1_T = k1.permute(1, 0, 2, 3)
        # Compute the numbers of k1-group out channels
        k1_group_width = k1.size(0) // groups
        # Compute the numbers of k2-group out channels
        k2_group_width = k2.size(0) // groups
        # loops over the number of groups
        for g in range(groups):
            # extracts a slice of k1_T that corresponds to the channels in the current group
            k1_T_slice = k1_T[:, g*k1_group_width:(g+1)*k1_group_width, :, :]
            # extracts a slice of k2 that corresponds to the channels in the current group
            k2_slice = k2[g*k2_group_width:(g+1)*k2_group_width, :, :, :]
            k_slices.append(F.conv2d(k2_slice, k1_T_slice))
            b_slices.append((k2_slice * b1[g*k1_group_width:(g+1)*k1_group_width].view(1, -1, 1, 1)).sum((1, 2, 3)))
        # concatenates the results of the 1x1 convolutions and 
        # the bias terms across the group dimension by calling the transIV_depthconcat function
        k, b_hat = transIV_depthconcat(k_slices, b_slices)
    # returns the concatenated results of the 1x1 convolutions and 
    # the bias terms, with the bias term for the kxk convolution added to b2
    return k, b_hat + b2


def transIV_depthconcat(kernels, biases):
    '''
    Parameters:
        kernels: list
        biases : list
    '''
    return torch.cat(kernels, dim=0), torch.cat(biases)


def transV_avg(channels, kernel_size, groups):
    # Calculate the number of input dimensions for each group
    input_dim = channels // groups
    # Create a tensor of zeros with dimensions (channels, input_dim, kernel_size, kernel_size)
    k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
    # Fill the diagonal blocks of the tensor with the average transform
    k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
    return k


#   This has not been tested with non-square kernels (kernel.size(2) != kernel.size(3)) nor even-size kernels
def transVI_multiscale(kernel, target_kernel_size):
    # Calculate the number of pixels to pad on the height dimension
    H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2
    # Calculate the number of pixels to pad on the width dimension
    W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2
    return F.pad(kernel, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad])


class IdentityBasedConv1x1(nn.Conv2d):
    '''
    This module implements a convolution operation that adds an identity matrix to the weight kernel, 
    allowing it to act as an identity operation in addition to the normal convolutional operation.
    '''
    def __init__(self, channels, groups=1):
        super().__init__(in_channels  = channels,
                         out_channels = channels,
                         kernel_size  = 1,
                         stride       = 1,
                         padding      = 0,
                         groups       = groups,
                         bias         = False)
        # Raises an assertion error if the number of input channels is not divisible by the number of groups
        assert channels % groups == 0
        # Calculates the size of input channel per group
        input_dim = channels // groups
        # Creates an identity matrix with the same size as the weight tensor with the value of 1 
        # for the diagonal elements and 0 for other elements.
        id_value  = np.zeros((channels, input_dim, 1, 1))
        for i in range(channels):
            id_value[i, i % input_dim, 0, 0] = 1
        # Initializes the id_tensor attribute with the identity matrix 
        # and initializes the weight attribute with zeros.
        self.id_tensor = torch.from_numpy(id_value).type_as(self.weight)
        nn.init.zeros_(self.weight)

    def forward(self, input):
        # By adding the identity matrix to the weight tensor, 
        # the IdentityBasedConv1x1 module can perform two operations simultaneously: 
        # normal convolution operation and identity operation. 
        # This makes the module more flexible and powerful, 
        # and it can be useful in many applications, such as in residual networks and in neural architecture search
        kernel = self.weight + self.id_tensor.to(self.weight.device)
        result = F.conv2d(input,
                          kernel,
                          None,
                          stride=1,
                          padding=0,
                          dilation=self.dilation,
                          groups=self.groups)
        return result

    def get_actual_kernel(self):
        return self.weight + self.id_tensor.to(self.weight.device)


class BNAndPadLayer(nn.Module):
    def __init__(self,
                 pad_pixels,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features,
                                 eps,
                                 momentum,
                                 affine,
                                 track_running_stats)
        self.pad_pixels = pad_pixels

    def forward(self, input):
        output = self.bn(input)
        if self.pad_pixels > 0:
            # If the BatchNorm2d layer is affine (i.e. has learnable weights)
            if self.bn.affine:
                # Calculate the padding values using the batch normalization statistics
                pad_values = self.bn.bias.detach() - self.bn.running_mean * self.bn.weight.detach() / torch.sqrt(
                    self.bn.running_var + self.bn.eps)
            # If the BatchNorm2d layer is not affine (i.e. has no learnable weights)
            else:
                # Calculate the padding values based on the batch normalization mean and variance
                pad_values = - self.bn.running_mean / torch.sqrt(self.bn.running_var + self.bn.eps)
            # Pad the output tensor with zeros on all sides
            output = F.pad(output, [self.pad_pixels] * 4)
            # Reshape the padding values to have a size of (1, num_features, 1, 1)
            pad_values = pad_values.view(1, -1, 1, 1)
            # Replace the top padding values with the calculated values
            output[:, :, 0:self.pad_pixels, :] = pad_values
            # Replace the bottom padding values with the calculated values
            output[:, :, -self.pad_pixels:, :] = pad_values
            # Replace the left padding values with the calculated values
            output[:, :, :, 0:self.pad_pixels] = pad_values
            # Replace the right padding values with the calculated values
            output[:, :, :, -self.pad_pixels:] = pad_values
        return output

    @property
    def weight(self):
        return self.bn.weight

    @property
    def bias(self):
        return self.bn.bias

    @property
    def running_mean(self):
        return self.bn.running_mean

    @property
    def running_var(self):
        return self.bn.running_var

    @property
    def eps(self):
        return self.bn.eps


class DiverseBranchBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride = 1, padding   = 0, dilation  = 1, groups = 1,
                 internal_channels_1x1_3x3 = None,
                 deploy = False, nonlinear = None
        ):
        super().__init__()
        self.deploy = deploy

        if nonlinear is None:
            self.nonlinear = nn.Identity()
        else:
            self.nonlinear = nonlinear

        self.kernel_size   = kernel_size
        self.out_channels  = out_channels
        self.groups        = groups
        assert padding == kernel_size // 2

        if deploy:
            self.dbb_reparam = nn.Conv2d(
                in_channels  = in_channels, out_channels = out_channels, kernel_size  = kernel_size,
                stride       = stride,      padding      = padding,      dilation     = dilation,
                groups       = groups,      bias         = True)
        else:
            self.dbb_origin = conv_bn(
                in_channels  = in_channels, out_channels = out_channels, kernel_size  = kernel_size,
                stride       = stride,      padding      = padding,
                dilation     = dilation,    groups       = groups)

            self.dbb_avg = nn.Sequential()
            if groups < out_channels:
                self.dbb_avg.add_module(
                    'conv', nn.Conv2d(in_channels  = in_channels,
                                      out_channels = out_channels,
                                      kernel_size  = 1,
                                      stride       = 1,
                                      padding      = 0,
                                      groups       = groups,
                                      bias         = False))

                self.dbb_avg.add_module(
                    'bn', BNAndPadLayer(pad_pixels   = padding,
                                        num_features = out_channels))

                self.dbb_avg.add_module(
                    'avg', nn.AvgPool2d(kernel_size = kernel_size,
                                        stride      = stride,
                                        padding     = 0))

                self.dbb_1x1 = conv_bn(in_channels  = in_channels,
                                       out_channels = out_channels,
                                       kernel_size  = 1,
                                       stride       = stride,
                                       padding      = 0,
                                       groups       = groups)
            else:
                self.dbb_avg.add_module('avg',
                                        nn.AvgPool2d(kernel_size = kernel_size,
                                                     stride      = stride,
                                                     padding     = padding))

            self.dbb_avg.add_module('avgbn',
                                    nn.BatchNorm2d(out_channels))

            if internal_channels_1x1_3x3 is None:
                # For mobilenet, it is better to have 2X internal channels
                # internal_channels = in_channels or 2*in_channels
                internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels

            self.dbb_1x1_kxk = nn.Sequential()
            if internal_channels_1x1_3x3 == in_channels:
                self.dbb_1x1_kxk.add_module('idconv1',
                                            IdentityBasedConv1x1(channels=in_channels, groups=groups))
            else:
                self.dbb_1x1_kxk.add_module('conv1',
                                            nn.Conv2d(in_channels=in_channels,
                                                      out_channels=internal_channels_1x1_3x3,
                                                      kernel_size=1,
                                                      stride=1,
                                                      padding=0,
                                                      groups=groups,
                                                      bias=False))
            self.dbb_1x1_kxk.add_module('bn1',
                                        BNAndPadLayer(pad_pixels=padding,
                                                      num_features=internal_channels_1x1_3x3,
                                                      affine=True))
            self.dbb_1x1_kxk.add_module('conv2',
                                        nn.Conv2d(in_channels=internal_channels_1x1_3x3, 
                                                  out_channels=out_channels,
                                                  kernel_size=kernel_size,
                                                  stride=stride,
                                                  padding=0,
                                                  groups=groups,
                                                  bias=False))
            self.dbb_1x1_kxk.add_module('bn2', nn.BatchNorm2d(out_channels))


    def forward(self, inputs):
    
        if hasattr(self, 'dbb_reparam'):
            return self.nonlinear(self.dbb_reparam(inputs))

        out = self.dbb_origin(inputs)
        if hasattr(self, 'dbb_1x1'):
            out += self.dbb_1x1(inputs)
        out += self.dbb_avg(inputs)
        out += self.dbb_1x1_kxk(inputs)
        return self.nonlinear(out)
    
    
    def switch_to_deploy(self):
        if hasattr(self, 'dbb_reparam'):
            return
        kernel, bias     = self.get_equivalent_kernel_bias()
        self.dbb_reparam = nn.Conv2d(in_channels  = self.dbb_origin.conv.in_channels,
                                     out_channels = self.dbb_origin.conv.out_channels,
                                     kernel_size  = self.dbb_origin.conv.kernel_size,
                                     stride       = self.dbb_origin.conv.stride,
                                     padding      = self.dbb_origin.conv.padding,
                                     dilation     = self.dbb_origin.conv.dilation,
                                     groups       = self.dbb_origin.conv.groups, 
                                     bias         = True)
        self.dbb_reparam.weight.data = kernel
        self.dbb_reparam.bias.data   = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('dbb_origin')
        self.__delattr__('dbb_avg')
        if hasattr(self, 'dbb_1x1'):
            self.__delattr__('dbb_1x1')
        self.__delattr__('dbb_1x1_kxk')
    
    
    def get_equivalent_kernel_bias(self):
    # ================== 1. k_origin, b_origin 
        k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight,
                                           self.dbb_origin.bn)
        
    # ================== 2. k_1x1, b_1x1
        if hasattr(self, 'dbb_1x1'):
            # Transform I: fuse conv + BN
            k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight,
                                         self.dbb_1x1.bn)
            # Transform VI: pad the 1x1 kernel up to kernel_size
            k_1x1 = transVI_multiscale(k_1x1,
                                       self.kernel_size)
        else:
            k_1x1, b_1x1 = 0, 0

    # ================== 3. k_1x1_kxk_merged, b_1x1_kxk_merged 
        if hasattr(self.dbb_1x1_kxk, 'idconv1'):
            k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
        else:
            k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
        # Transform I: fuse conv + BN
        k_1x1_kxk_first, b_1x1_kxk_first   = transI_fusebn(k_1x1_kxk_first,
                                                           self.dbb_1x1_kxk.bn1)
        # Transform I: fuse conv + BN
        k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(self.dbb_1x1_kxk.conv2.weight,
                                                           self.dbb_1x1_kxk.bn2)
        # Transform III: merge the 1x1 conv with the kxk conv
        k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(k_1x1_kxk_first,
                                                              b_1x1_kxk_first,
                                                              k_1x1_kxk_second,
                                                              b_1x1_kxk_second,
                                                              groups=self.groups)
        
    # ================== 4. k_1x1_avg_merged, b_1x1_avg_merged 
        k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
        # Transform I: fuse the avg-as-conv kernel with BN
        k_1x1_avg_second, b_1x1_avg_second     = transI_fusebn(k_avg.to(self.dbb_avg.avgbn.weight.device),
                                                               self.dbb_avg.avgbn)
        if hasattr(self.dbb_avg, 'conv'):
            # Transform I: fuse conv + BN
            k_1x1_avg_first, b_1x1_avg_first   = transI_fusebn(self.dbb_avg.conv.weight,
                                                               self.dbb_avg.bn)
            # Transform III: merge the 1x1 conv with the kxk (avg) conv
            k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(k_1x1_avg_first,
                                                                  b_1x1_avg_first,
                                                                  k_1x1_avg_second,
                                                                  b_1x1_avg_second,
                                                                  groups=self.groups)
        else:
            k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second
            
    # ================== 5. Final merge
        return transII_addbranch((k_origin,
                                  k_1x1,
                                  k_1x1_kxk_merged,
                                  k_1x1_avg_merged),
                                 (b_origin,
                                  b_1x1,
                                  b_1x1_kxk_merged,
                                  b_1x1_avg_merged))

    

if __name__ == '__main__':
    x = torch.randn(1, 4, 224, 224)

    model = DiverseBranchBlock(in_channels=4, out_channels=32, kernel_size=3, stride=1, padding=3//2,
                               groups=2, deploy=False)
    for module in model.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            nn.init.uniform_(module.running_mean, 0, 0.1)
            nn.init.uniform_(module.running_var, 0, 0.2)
            nn.init.uniform_(module.weight, 0, 0.3)
            nn.init.uniform_(module.bias, 0, 0.4)
      
            
    model.eval()
    out = model(x)
    # print(model)
    torch.onnx.export(model=model, args=x, f='../DBB.onnx', 
                      verbose=False)
    
    
    model.switch_to_deploy()
    deployout = model(x)
    # print(model)
    torch.onnx.export(
        model=model, args=x, f='../DBB-deploy.onnx', 
        verbose=False)

    print('\nDifference between the outputs of the origin-DBB and rep-DBB is: {}\n'.format(
        ((deployout - out) ** 2).sum()
    ))

Summary

In this lesson we studied the re-parameterization of the DBB network. Compared with ACNet's convolution replacement, DBB proposes a more elaborate, Inception-like multi-branch structure and, at the inference stage, applies six transforms to re-parameterize it into a single main branch, speeding up inference.
