RepConv是Yolov7,YoloV9中一个重要的结构,其主要用于在保持精度不退化的情况下提升推理速度。RepConv在学习阶段和推理阶段拥有不同的结构,这使得其推理阶段的复杂度大大降低。这项技术的核心在于重参数化。
一、重参数化
重参数化的主要思想是将卷积(Conv)和归一化(BN)融合在一起以减少运算量。Conv和BN的公式如下:
...①
...②,其中是可学习参数,BN可以拆分为下列公式:
均值:
方差:
归一化:
一般的卷积结构如上图所示,由Conv和BN叠加而成,其表达可由式①和式②融合:
,并将之化简后可得:
,其中可以视为w(),可以视为b(),即可将两个式子简化为一个卷积形式:
二、RepConv
RepConv的简化思路是:将不同尺寸的卷积统一为3x3->将3x3卷积和BN融合起来。就像下图展示的一样。
RepConv的重参数代码如下:
def fuse_repvgg_block(self):
if self.deploy:
return
print(f"RepConv.fuse_repvgg_block")
self.rbr_dense = self.fuse_conv_bn(self.rbr_dense[0], self.rbr_dense[1]) # 3*3卷积
self.rbr_1x1 = self.fuse_conv_bn(self.rbr_1x1[0], self.rbr_1x1[1])
rbr_1x1_bias = self.rbr_1x1.bias # ch_out
# 填充[1,1,1,1]表示在左右上下个填充1个单位,即第三四维(h,w)各增加2
weight_1x1_expanded = torch.nn.functional.pad(self.rbr_1x1.weight, [1, 1, 1, 1]) # co*ci*(ks+2)*(ks+2)
# Fuse self.rbr_identity
if (isinstance(self.rbr_identity, nn.BatchNorm2d) or isinstance(self.rbr_identity, nn.modules.batchnorm.SyncBatchNorm)):
# 0*0支路是BatchNorm2d或SyncBatchNorm的前提是out_channels=in_channels,在RepConv的__init__()中可以看到
identity_conv_1x1 = nn.Conv2d(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=1,
stride=1,
padding=0,
groups=self.groups,
bias=False) # (co, ci, 1, 1)
identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.to(self.rbr_1x1.weight.data.device)
identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.squeeze().squeeze() # (co, ci)
identity_conv_1x1.weight.data.fill_(0.0)
identity_conv_1x1.weight.data.fill_diagonal_(1.0) # 变成一个单位阵,每个元素可看成一个1*1卷积片
identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.unsqueeze(2).unsqueeze(3) # (co, ci, 1, 1), 现在我们得到了一个0*0卷积
identity_conv_1x1 = self.fuse_conv_bn(identity_conv_1x1, self.rbr_identity) # 与BN融合
bias_identity_expanded = identity_conv_1x1.bias
weight_identity_expanded = torch.nn.functional.pad(identity_conv_1x1.weight, [1, 1, 1, 1]) # 将每个1*1卷积片pad一圈0,即在第3、4维各加2
else:
# channels_out不等于channels_in,零矩阵
bias_identity_expanded = torch.nn.Parameter(torch.zeros_like(rbr_1x1_bias))
weight_identity_expanded = torch.nn.Parameter(torch.zeros_like(weight_1x1_expanded))
self.rbr_dense.weight = torch.nn.Parameter(
self.rbr_dense.weight + weight_1x1_expanded + weight_identity_expanded)
self.rbr_dense.bias = torch.nn.Parameter(self.rbr_dense.bias + rbr_1x1_bias + bias_identity_expanded)
self.rbr_reparam = self.rbr_dense
self.deploy = True
if self.rbr_identity is not None:
del self.rbr_identity
self.rbr_identity = None
if self.rbr_1x1 is not None:
del self.rbr_1x1
self.rbr_1x1 = None
if self.rbr_dense is not None:
del self.rbr_dense
self.rbr_dense = None