蓝图分离卷积BSConv 学习笔记 (附代码)

news2024/11/27 0:50:08

论文地址:https://arxiv.org/abs/2003.13549

代码地址:https://github.com/zeiss-microscopy/BSConv

1.是什么?

BSConv是深度可分离卷积DSConv的升级版本,它更好地利用内核内部相关性来实现高效分离。具体而言,BSConvU是将一个标准的卷积分解为1x1卷积(PW)和一个逐通道卷积,是深度可分离卷积(DSConv—逐通道、逐点)的逆向版本。此外,BSConv还有一个变体操作—BSConvS。

2.为什么?

受启发于预训练模型的核属性的量化分析:深度方向的强相关性。作者提出一种“蓝图分离卷积”(blueprint separable convolutions, BSConv)作为高效CNN的构建模块。

基于该发现,作者构建了一套理论基础并用于推导如何采用标准OP进行高效实现。更进一步,所提方法为深度分离卷积的应用(深度分离卷积已成为当前主流网络架构的核心模块)提供了系统的理论推导、可解释性以及原因分析。最后,作者揭示了基于深度分离卷积的网络架构(如MobileNet)隐式的依赖于跨核相关性;而所提BSConv则基于核内相关性,故可以为常规卷积提供一种更有效的拆分。

作者通过充分的实验(大尺度分类与细粒度分类)验证了所提BSConv可以明显的提升MobileNet以及其他基于深度分离卷积的架构的性能,而不会引入额外的复杂度。对于细粒度问题,所提方法取得13.7%的性能提升;在ImageNet分类任务,BSConv在ResNet的“即插即用”取得了9.5%的性能提升。

3.怎么样?

3.1网络结构

在标准卷积中,每个卷积层对输入张量U\epsilon R^{M*Y*X}进行变化得到输出张量V\epsilon R^{N*Y*X},相应的卷积核F^{(1)},...,F^{(N)},每个卷积核的尺寸为M*K*K。相应的公式可以描述为(图示见下图):

这些卷积核将通过反向传播方式进行优化训练。

预训练CNN中的卷积核可以通过一个模板以及M个因子进行近似。该发现也是本文提的(blueprint separable convolutions,BSConv)的驱动源泉,它滤波器卷积提供另一种定义方式。

尽管上述定义为滤波器添加了硬约束,但作者通过实验表明:相比标准卷积,所提方法可以达到相同甚至更优的性能。另外,需要注意的是:标准卷积的可训练参数为M\cdot N\cdot K^{2},而所提方法仅具有N\cdot K^{2}+ M\cdot N个可训练参数。

3.2 Variants and Implementations

前面已经介绍了BSConv的卷积核信息,它的权值M\cdot N可以组合为矩阵W=(w_{n,m})。此时根据W的学习方式不同,又有两种不同的变种。

  • BSConv-U:在大多场景下,权值W可以不进行任何约束进行训练学习。此时,公式(1)可以转换为如下公式。此时,常规卷积1*1可以解耦为卷积K*K深度卷积,见下图。

 对于这种形式的CNN架构,作者发现:权值W在行方向存在高度相关性。这为进一步的正则化与参数降低提供了可能。也就引出了下面将要介绍的BSConv-S变种。

  • BSConv-S:基于前述发现,作者对权值W进行低秩分解:W = W^{A}*W^{B}。其中W^{A}=N*M',W^{B}=M'*M,M'=[p\cdot M],p\epsilon (0.0,1.0).而后,经过一些列的变换处理,最终BSConv的公式转换为下面的公式。此时,常规卷积可以解耦为1*1卷积+1*1卷积+K*K深度卷积,见上图。

3.3  Discussion

前面已经介绍了BSConv的两种变种,这里将对比分析一下上述两种变种与已有模块的区别和联系。

  • BSConv-U是一种逆深度分类卷积。两者的出发点有一些区别:DSConv实施了跨核相关性,而BSConv-U则实施了核内相关性。已有研究表明:尽管跨核相关性与核内相关性都是有效假设,但核内相关性更有优势,对于高效分离更具潜力。需要注意的是:卷积后不跟激活函数或者规范化函数。

  • BSConv-S是一种具有正交正则化功能的转移线性瓶颈模块。线性瓶颈层是当前高效网络MobileNet的核心模块,它由pointwise、depthwise、pointwise级联构成,而BSConv-S则是由pointwise, pointwise, depthwise级联构成。从中可以看到两者之间的紧密联系。此外,需要注意的是:与前者相同,激活函数与规范化函数不在模块内添加

3.4代码实现

class BSConvU(torch.nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True, padding_mode="zeros", with_bn=False, bn_kwargs=None):
        super().__init__()

        # check arguments
        if bn_kwargs is None:
            bn_kwargs = {}

        # pointwise
        self.add_module("pw", torch.nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=1,
                padding=0,
                dilation=1,
                groups=1,
                bias=False,
        ))

        # batchnorm
        if with_bn:
            self.add_module("bn", torch.nn.BatchNorm2d(num_features=out_channels, **bn_kwargs))

        # depthwise
        self.add_module("dw", torch.nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=out_channels,
                bias=bias,
                padding_mode=padding_mode,
        ))


class BSConvS(torch.nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, bias=True, padding_mode="zeros", p=0.25, min_mid_channels=4, with_bn=False, bn_kwargs=None):
        super().__init__()

        # check arguments
        assert 0.0 <= p <= 1.0
        mid_channels = min(in_channels, max(min_mid_channels, math.ceil(p * in_channels)))
        if bn_kwargs is None:
            bn_kwargs = {}

        # pointwise 1
        self.add_module("pw1", torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=(1, 1),
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=False,
        ))

        # batchnorm
        if with_bn:
            self.add_module("bn1", torch.nn.BatchNorm2d(num_features=mid_channels, **bn_kwargs))

        # pointwise 2
        self.add_module("pw2", torch.nn.Conv2d(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=(1, 1),
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=False,
        ))

        # batchnorm
        if with_bn:
            self.add_module("bn2", torch.nn.BatchNorm2d(num_features=out_channels, **bn_kwargs))

        # depthwise
        self.add_module("dw", torch.nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=out_channels,
            bias=bias,
            padding_mode=padding_mode,
        ))

    def _reg_loss(self):
        W = self[0].weight[:, :, 0, 0]
        WWt = torch.mm(W, torch.transpose(W, 0, 1))
        I = torch.eye(WWt.shape[0], device=WWt.device)
        return torch.norm(WWt - I, p="fro")


class BSConvS_ModelRegLossMixin():
    def reg_loss(self, alpha=0.1):
        loss = 0.0
        for sub_module in self.modules():
            if hasattr(sub_module, "_reg_loss"):
                loss += sub_module._reg_loss()
        return alpha * loss

参考:

深度分离卷积重思考:BSConv

轻量化神经网络卷积设计研究进展

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1167048.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

lerna + vite + typescript 多库,多应用共存项目脚手架模板

最近想把多个代码仓进行合并&#xff0c;形成一个大的代码仓&#xff0c;需要将各个库以及应用放在一个项目下&#xff0c;统一打包管理。会形成如下文件结构&#xff1a; 在网上找了一圈&#xff0c;没有找到合适的脚手架模板。索性自己弄一个吧&#xff0c;开源一下&#x…

云产品OSS免费试用获取奖励步骤

文章目录 1、获取活动链接2、报名参加3、试用产品领取产品试用权限上传文件开启加速传输提交作品 4、提交任务获取奖励 1、获取活动链接 活动时间2023.11.1&#xff5e;2023.11.30 名额有限&#xff0c;先到先得 进群群主获取活动链接 2、报名参加 直接点击链接进入小程序进…

Centralized Feature Pyramid for Object Detection解读

Centralized Feature Pyramid for Object Detection 问题 主流的特征金字塔集中于层间特征交互&#xff0c;而忽略了层内特征规则。尽管一些方法试图在注意力机制或视觉变换器的帮助下学习紧凑的层内特征表示&#xff0c;但它们忽略了对密集预测任务非常重要的被忽略的角点区…

自动驾驶算法(四):RRT*算法讲解与代码实现(基于采样的路径规划)

目录 1 RRT算法和RRT*算法 2 RRT*代码相比于RRT的改进 3 RRT*完整代码 1 RRT算法和RRT*算法 从上篇博客我们可以看出&#xff0c;RRT算法找到最短路径特别快。因为它是一段一段的过去的&#xff0c;但同时它产生的路径也是非常糟糕、随机的只要找到了终点就会结束。 因此我们…

INFINI Labs 产品更新 | Agent 全新重构,优化指标采集,支持集中配置管理,支持动态下发等功能

INFINI Labs 产品又更新啦~ 本次更新主要有 Agent、Console、Loadgen 等产品&#xff0c;其中 Agent 进行全新重构升级&#xff0c;新版限制了 CPU 资源消耗&#xff0c;优化了内存&#xff0c;相比旧版内存使用率降低 10 倍&#xff0c;极大的降低了对宿主服务器造成资源占用…

Postgresql在linux环境下以源码方式安装

linux环境下源码方式的安装 1.下载安装包&#xff08;源码安装方式&#xff09; 安装包下载 https://www.postgresql.org/ftp/source/ 2.安装postgresql ① 创建安装目录 mkdir /opt/pgsql12② 解压下载的安装包 cd /opt/pgsql12 tar -zxvf postgresql-12.16.tar.gz ③编…

【UE5 Cesium】actor随着视角远近来变化其本身大小

效果 步骤 1. 首先我将“DynamicPawn”设置为默认的pawn类 2. 新建一个父类为actor的蓝图&#xff0c;添加一个静态网格体组件 当事件开始运行后添加一个定时器&#xff0c;委托给一个自定义事件&#xff0c;每2s执行一次&#xff0c;该事件每2s获取一下“DynamicPawn”和acto…

【优秀毕设】基于vue+ssm+springboot的校园交友网站系统设计(附源码论文)

摘要 随着信息技术和网络技术的飞速发展&#xff0c;人类已进入全新信息化时代&#xff0c;传统管理技术已无法高效&#xff0c;便捷地管理信息。为了迎合时代需求&#xff0c;优化管理效率&#xff0c;各种各样的管理系统应运而生&#xff0c;各行各业相继进入信息管理时代&a…

智能客服系统应用什么技术?

随着科技的飞速发展&#xff0c;智能客服系统逐渐出现在我们的生活中。这些系统不仅能够提供即时的客户服务&#xff0c;还可以通过人工智能等技术实现更加高效和准确的服务。那么&#xff0c;智能客服系统究竟应用了哪些技术呢&#xff1f;本文将详细解析。 1、机器学习技术 …

专访虚拟人科技:如何利用 3DCAT 实时云渲染打造元宇宙空间

自古以来&#xff0c;人们对理想世界的探索从未停止&#xff0c;而最近元宇宙的热潮加速了这一步伐&#xff0c;带来了许多新的应用。作为元宇宙的关键入口&#xff0c;虚拟现实&#xff08;VR&#xff09;将成为连接虚拟和现实的桥梁。苹果发布的VISION PRO头戴设备将人们对VR…

YOLOv8改进之更换BiFPN并融合P2小目标检测层

目录 1. BiFPN 1.1 FPN的演进 2. YOLOv8改进之更换BiFPN并融合P2小目标检测层 1. BiFPN BiFPN&#xff08;Bi-directional Feature Pyramid Network&#xff09;是一种用于目标检测和语义分割任务的神经网络架构&#xff0c;旨在改善特征金字塔网络&#xff08;Feature Pyram…

京东平台3个热门API接口的分享【附代码实例】

应用程序接口API&#xff08;Application Programming Interface&#xff09;&#xff0c;是提供特定业务输出能力、连接不同系统的一种约定。这里包括外部系统与提供服务的系统&#xff08;中后台系统&#xff09;或后台不同系统之间的交互点。包括外部接口、内部接口&#xf…

微服务接口测试中的参数传递

这是一个微服务蓬勃发展的时代。在微服务测试中&#xff0c;最典型的一种场景就是接口测试&#xff0c;其目标是验证微服务对客户端或其他微服务暴露的接口是否能够正常工作。对于最常见的基于Restful风格的微服务来说&#xff0c;其对外暴露的接口就是HTTP端点(Endpoint)。 这…

基于SSM的餐饮掌上设备点餐系统

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;Vue 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#xff1a;是 目录…

纸白银可以双向交易吗?

纸白银是指以合约形式进行交易的白银&#xff0c;投资者可以通过期货市场进行买入和卖出操作。因此&#xff0c;纸白银是可以双向交易的。投资者既可以选择做多&#xff08;买入&#xff09;纸白银合约&#xff0c;也可以选择做空&#xff08;卖出&#xff09;纸白银合约&#…

Qt中正确的设置窗体的背景图片的几种方式

Qt中正确的设置窗体的背景图片的几种方式 QLabel加载图片方式之一Chapter1 Qt中正确的设置窗体的背景图片的几种方式一、利用styleSheet设置窗体的背景图片 Chapter2 Qt的主窗口背景设置方法一&#xff1a;最简单的方式是通过ui界面来设置&#xff0c;例如设置背景图片方法二 &…

默认路由配置

默认路由&#xff1a; 在末节路由器上使用。&#xff08;末节路由器是前往其他网络只有一条路可以走的路由器&#xff09; 默认路由被称为最后的关卡&#xff0c;也就是静态路由不可用并且动态路由也不可用&#xff0c;最后就会选择默认路由。有时在末节路由器上写静态路由时…

重磅发布|美创科技新一代 数据安全管理平台(DSM Cloud)全新升级

重磅发布 新一代 数据安全管理平台&#xff08;DSM Cloud&#xff09; 美创科技新一代 数据安全管理平台&#xff08;简称&#xff1a;DSM Cloud&#xff09;全新升级&#xff0c;正式发布。 在业务上云飞速发展过程中&#xff0c;快速应对数据激增&#xff0c;同时有效保障数…

基于LDA主题+协同过滤+矩阵分解算法的智能电影推荐系统——机器学习算法应用(含python、JavaScript工程源码)+MovieLens数据集(一)

目录 前言总体设计系统整体结构图系统流程图 运行环境Python环境Pycharm 环境数据库 相关其它博客工程源代码下载其它资料下载 前言 前段时间&#xff0c;博主分享过关于一篇使用协同过滤算法进行智能电影推荐系统的博文《基于TensorFlowCNN协同过滤算法的智能电影推荐系统——…