论文地址:https://arxiv.org/abs/2003.13549
代码地址:https://github.com/zeiss-microscopy/BSConv
1.是什么?
BSConv是深度可分离卷积DSConv的升级版本,它更好地利用内核内部相关性来实现高效分离。具体而言,BSConvU是将一个标准的卷积分解为1x1卷积(PW)和一个逐通道卷积,是深度可分离卷积(DSConv—逐通道、逐点)的逆向版本。此外,BSConv还有一个变体操作—BSConvS。
2.为什么?
受启发于预训练模型的核属性的量化分析:深度方向的强相关性。作者提出一种“蓝图分离卷积”(blueprint separable convolutions, BSConv)作为高效CNN的构建模块。
基于该发现,作者构建了一套理论基础并用于推导如何采用标准OP进行高效实现。更进一步,所提方法为深度分离卷积的应用(深度分离卷积已成为当前主流网络架构的核心模块)提供了系统的理论推导、可解释性以及原因分析。最后,作者揭示了基于深度分离卷积的网络架构(如MobileNet)隐式的依赖于跨核相关性;而所提BSConv则基于核内相关性,故可以为常规卷积提供一种更有效的拆分。
作者通过充分的实验(大尺度分类与细粒度分类)验证了所提BSConv可以明显的提升MobileNet以及其他基于深度分离卷积的架构的性能,而不会引入额外的复杂度。对于细粒度问题,所提方法取得13.7%的性能提升;在ImageNet分类任务,BSConv在ResNet的“即插即用”取得了9.5%的性能提升。
3.怎么样?
3.1网络结构
在标准卷积中,每个卷积层对输入张量进行变化得到输出张量,相应的卷积核,每个卷积核的尺寸为M*K*K。相应的公式可以描述为(图示见下图):
这些卷积核将通过反向传播方式进行优化训练。
预训练CNN中的卷积核可以通过一个模板以及M个因子进行近似。该发现也是本文提的(blueprint separable convolutions
,BSConv)的驱动源泉,它滤波器卷积提供另一种定义方式。
尽管上述定义为滤波器添加了硬约束,但作者通过实验表明:相比标准卷积,所提方法可以达到相同甚至更优的性能。另外,需要注意的是:标准卷积的可训练参数为,而所提方法仅具有个可训练参数。
3.2 Variants and Implementations
前面已经介绍了BSConv的卷积核信息,它的权值可以组合为矩阵。此时根据W的学习方式不同,又有两种不同的变种。
-
BSConv-U:在大多场景下,权值W可以不进行任何约束进行训练学习。此时,公式(1)可以转换为如下公式。此时,常规卷积1*1可以解耦为卷积K*K深度卷积,见下图。
对于这种形式的CNN架构,作者发现:权值W在行方向存在高度相关性。这为进一步的正则化与参数降低提供了可能。也就引出了下面将要介绍的BSConv-S变种。
-
BSConv-S:基于前述发现,作者对权值W进行低秩分解:。其中.而后,经过一些列的变换处理,最终BSConv的公式转换为下面的公式。此时,常规卷积可以解耦为1*1卷积+1*1卷积+K*K深度卷积,见上图。
3.3 Discussion
前面已经介绍了BSConv的两种变种,这里将对比分析一下上述两种变种与已有模块的区别和联系。
-
BSConv-U是一种逆深度分类卷积。两者的出发点有一些区别:DSConv实施了跨核相关性,而BSConv-U则实施了核内相关性。已有研究表明:尽管跨核相关性与核内相关性都是有效假设,但核内相关性更有优势,对于高效分离更具潜力。需要注意的是:卷积后不跟激活函数或者规范化函数。
-
BSConv-S是一种具有正交正则化功能的转移线性瓶颈模块。线性瓶颈层是当前高效网络MobileNet的核心模块,它由
pointwise、depthwise、pointwise
级联构成,而BSConv-S则是由pointwise, pointwise, depthwise
级联构成。从中可以看到两者之间的紧密联系。此外,需要注意的是:与前者相同,激活函数与规范化函数不在模块内添加
3.4代码实现
class BSConvU(torch.nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True, padding_mode="zeros", with_bn=False, bn_kwargs=None):
super().__init__()
# check arguments
if bn_kwargs is None:
bn_kwargs = {}
# pointwise
self.add_module("pw", torch.nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(1, 1),
stride=1,
padding=0,
dilation=1,
groups=1,
bias=False,
))
# batchnorm
if with_bn:
self.add_module("bn", torch.nn.BatchNorm2d(num_features=out_channels, **bn_kwargs))
# depthwise
self.add_module("dw", torch.nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=out_channels,
bias=bias,
padding_mode=padding_mode,
))
class BSConvS(torch.nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True, padding_mode="zeros", p=0.25, min_mid_channels=4, with_bn=False, bn_kwargs=None):
super().__init__()
# check arguments
assert 0.0 <= p <= 1.0
mid_channels = min(in_channels, max(min_mid_channels, math.ceil(p * in_channels)))
if bn_kwargs is None:
bn_kwargs = {}
# pointwise 1
self.add_module("pw1", torch.nn.Conv2d(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=(1, 1),
stride=1,
padding=0,
dilation=1,
groups=1,
bias=False,
))
# batchnorm
if with_bn:
self.add_module("bn1", torch.nn.BatchNorm2d(num_features=mid_channels, **bn_kwargs))
# pointwise 2
self.add_module("pw2", torch.nn.Conv2d(
in_channels=mid_channels,
out_channels=out_channels,
kernel_size=(1, 1),
stride=1,
padding=0,
dilation=1,
groups=1,
bias=False,
))
# batchnorm
if with_bn:
self.add_module("bn2", torch.nn.BatchNorm2d(num_features=out_channels, **bn_kwargs))
# depthwise
self.add_module("dw", torch.nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=out_channels,
bias=bias,
padding_mode=padding_mode,
))
def _reg_loss(self):
W = self[0].weight[:, :, 0, 0]
WWt = torch.mm(W, torch.transpose(W, 0, 1))
I = torch.eye(WWt.shape[0], device=WWt.device)
return torch.norm(WWt - I, p="fro")
class BSConvS_ModelRegLossMixin():
def reg_loss(self, alpha=0.1):
loss = 0.0
for sub_module in self.modules():
if hasattr(sub_module, "_reg_loss"):
loss += sub_module._reg_loss()
return alpha * loss
参考:
深度分离卷积重思考:BSConv
轻量化神经网络卷积设计研究进展