即插即用模块之DO-Conv(深度过度参数化卷积层)详解

news2024/11/24 8:49:59

目录

一、摘要

二、核心创新点

三、代码详解

四、实验结果

4.1Image Classification

4.2Semantic Segmentation

4.3Object Detection 

五、总结


论文:DOConv论文

代码:DOConv代码

一、摘要

卷积层是卷积神经网络(cnn)的核心组成部分。在本文中,我们建议用额外的深度卷积来增强卷积层,其中每个输入通道与不同的二维核进行卷积。这两个卷积的组合构成了一个过度参数化,因为它增加了可学习的参数,而结果的线性操作可以用单个卷积层来表示。我们把这个深度过度参数化的卷积层称为DO-Conv。我们通过大量的实验表明,仅仅用DO-Conv层替换传统的卷积层就可以提高cnn在许多经典视觉任务上的性能,例如图像分类、检测和分割。此外,在推理阶段,深度卷积被折叠成常规卷积,将计算量减少到完全等同于卷积层的计算量,而没有过度参数化。由于DO-Conv在不增加推理计算复杂度的情况下引入了性能提升,我们主张将其作为传统卷积层的替代方案。

二、核心创新点

深度过参数化卷积层(DO-Conv)是一个具有可训练kernel深度卷积和一个具有可训练常规卷积的组合。给定一个输入, DO-Conv算子的输出与卷积层相同,是一个同维特征。DO-Conv算子是深度卷积算子和卷积算子的复合,如图所示,有两种数学上等价的方法来实现复合:特征复合(a)和核复合(b)。

三、代码详解

# 使用 utf-8 编码
# 导入必要的库
import math  # 导入数学库
import torch  # 导入 PyTorch 库
import numpy as np  # 导入 NumPy 库
from torch.nn import init  # 导入 PyTorch 中的初始化函数
from itertools import repeat  # 导入 itertools 库中的 repeat 函数
from torch.nn import functional as F  # 导入 PyTorch 中的函数式接口
from torch._jit_internal import Optional  # 导入 PyTorch 中的可选模块
from torch.nn.parameter import Parameter  # 导入 PyTorch 中的参数类
from torch.nn.modules.module import Module  # 导入 PyTorch 中的模块类
import collections  # 导入 collections 库

# 定义自定义模块 DOConv2d
class DOConv2d(Module):
    """
       DOConv2d 可以作为 torch.nn.Conv2d 的替代。
       接口与 Conv2d 类似,但有一个例外:
            1. D_mul:超参数的深度乘法器。
       请注意,groups 参数在 DO-Conv(groups=1)、DO-DConv(groups=in_channels)、DO-GConv(其他情况)之间切换。
    """
    # 常量声明
    __constants__ = ['stride', 'padding', 'dilation', 'groups',
                     'padding_mode', 'output_padding', 'in_channels',
                     'out_channels', 'kernel_size', 'D_mul']
    # 注解声明
    __annotations__ = {'bias': Optional[torch.Tensor]}

    # 初始化函数
    def __init__(self, in_channels, out_channels, kernel_size, D_mul=None, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super(DOConv2d, self).__init__()

        # 将 kernel_size、stride、padding、dilation 转化为二元元组
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)

        # 检查 groups 是否合法
        if in_channels % groups != 0:
            raise ValueError('in_channels 必须能被 groups 整除')
        if out_channels % groups != 0:
            raise ValueError('out_channels 必须能被 groups 整除')
        # 检查 padding_mode 是否合法
        valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
        if padding_mode not in valid_padding_modes:
            raise ValueError("padding_mode 必须为 {} 中的一种,但得到 padding_mode='{}'".format(
                valid_padding_modes, padding_mode))
        
        # 初始化模块参数
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode
        self._padding_repeated_twice = tuple(x for x in self.padding for _ in range(2))

        #################################### 初始化 D & W ###################################
        M = self.kernel_size[0]
        N = self.kernel_size[1]
        self.D_mul = M * N if D_mul is None or M * N <= 1 else D_mul
        self.W = Parameter(torch.Tensor(out_channels, in_channels // groups, self.D_mul))
        init.kaiming_uniform_(self.W, a=math.sqrt(5))

        if M * N > 1:
            self.D = Parameter(torch.Tensor(in_channels, M * N, self.D_mul))
            init_zero = np.zeros([in_channels, M * N, self.D_mul], dtype=np.float32)
            self.D.data = torch.from_numpy(init_zero)

            eye = torch.reshape(torch.eye(M * N, dtype=torch.float32), (1, M * N, M * N))
            d_diag = eye.repeat((in_channels, 1, self.D_mul // (M * N)))
            if self.D_mul % (M * N) != 0:  # 当 D_mul > M * N 时
                zeros = torch.zeros([in_channels, M * N, self.D_mul % (M * N)])
                self.d_diag = Parameter(torch.cat([d_diag, zeros], dim=2), requires_grad=False)
            else:  # 当 D_mul = M * N 时
                self.d_diag = Parameter(d_diag, requires_grad=False)
        ##################################################################################################

        # 初始化偏置参数
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)
        else:
            self.register_parameter('bias', None)

    # 返回模块配置的字符串表示形式
    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        return s.format(**self.__dict__)

    # 重新设置状态
    def __setstate__(self, state):
        super(DOConv2d, self).__setstate__(state)
        if not hasattr(self, 'padding_mode'):
            self.padding_mode = 'zeros'

    # 辅助函数,执行卷积操作
    def _conv_forward(self, input, weight):
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    # 前向传播函数
    def forward(self, input):
        M = self.kernel_size[0]
        N = self.kernel_size[1]
        DoW_shape = (self.out_channels, self.in_channels // self.groups, M, N)
        if M * N > 1:
            ######################### 计算 DoW #################
            # (input_channels, D_mul, M * N)
            D = self.D + self.d_diag
            W = torch.reshape(self.W, (self.out_channels // self.groups, self.in_channels, self.D_mul))

            # einsum 输出 (out_channels // groups, in_channels, M * N),
            # 重塑为
            # (out_channels, in_channels // groups, M, N)
            DoW = torch.reshape(torch.einsum('ims,ois->oim', D, W), DoW_shape)
            #######################################################
        else:
            # 在这种情况下 D_mul == M * N
            # 从 (out_channels, in_channels // groups, D_mul) 重塑为 (out_channels, in_channels // groups, M, N)
            DoW = torch.reshape(self.W, DoW_shape)
        return self._conv_forward(input, DoW)

# 定义辅助函数
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse

# 定义辅助函数,将输入转化为二元元组
_pair = _ntuple(2)

四、实验结果

4.1Image Classification

4.2Semantic Segmentation

4.3Object Detection 

五、总结

DO-Conv是一种深度过参数化卷积层,是一种新颖、简单、通用的提高cnn性能的方法。除了提高现有cnn的训练和最终精度的实际意义之外,在推理阶段不引入额外的计算,我们设想其优势的揭示也可以鼓励进一步探索过度参数化作为网络架构设计的一个新维度。

在未来,对这一相当简单的方法进行理论理解,以在一系列应用中实现令人惊讶的非凡性能改进,将是有趣的。此外,我们希望扩大这些过度参数化卷积层可能有效的应用范围,并了解哪些超参数可以从中受益更多。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1588455.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

外贸开发信必知技巧:高回复率不再是梦

外贸行业在Zoho的客户群体中占比较高。因为我们的国际化背景、丰富的产品组合、多语言多币种跨时区、高性价比等特点&#xff0c;成为外贸企业开展业务的选择。在和外贸客户沟通中&#xff0c;发现无论是外贸大拿还是新手小白&#xff0c;大家遇到一个共同的问题——发出去的开…

动态规划在矩阵链乘法中的应用:寻找最优括号化方案

动态规划在矩阵链乘法中的应用&#xff1a;寻找最优括号化方案 一、问题描述二、动态规划的基本概念三、矩阵链乘法问题的动态规划解法四、伪代码五、C语言代码示例六、计算括号化方案的数量七、结论 计算括号化方案的数量问题是计算机科学中的一个经典问题&#xff0c;它涉及到…

K8S node节点执行kubectl get pods报错

第一个问题是由第二个问题产生的&#xff0c;第二个问题也是最常见的 网上找的都是从master节点把文件复制过来&#xff0c;这样确实可以解决&#xff0c;但是麻烦&#xff0c;有一个node节点还好&#xff0c;如果有多个呢&#xff1f;每个都复制吗&#xff1f;下面是我从外网…

为什么要“挺”鸿蒙?

鸿蒙到底是什么&#xff1f; 随着5G、物联网等技术的快速发展&#xff0c;智能终端设备的应用场景也越来越广泛。为了满足不同设备间的互联互通需求&#xff0c;华为在2019年推出了自主研发的操作系统——鸿蒙OS。值得关注的是&#xff0c;这也是首款国产操作系统。 要了解鸿…

密码学 | 椭圆曲线 ECC 密码学入门(三)

目录 7 这一切意味着什么&#xff1f; 8 椭圆曲线密码学的应用 9 椭圆曲线密码学的缺点 10 展望未来 ⚠️ 原文地址&#xff1a;A (Relatively Easy To Understand) Primer on Elliptic Curve Cryptography ⚠️ 写在前面&#xff1a;本文属搬运博客&#xff0c;自己留…

YOLOv8草莓生长状态(灰叶病缺钙需要肥料)检测系统(python开发,带有训练模型,可以重新训练,并有Pyqt5界面可视化)

本次检测系统&#xff0c;不仅可以检测图片、视频或摄像头当中出现的草莓叶子是否有灰叶病&#xff0c;还可以检测出草莓叶是否缺钙、是否需要施肥等状态。基于最新的YOLO-v8训练的草莓生长状态检测模型和完整的python代码以及草莓的训练数据&#xff0c;下载后即可运行&#x…

CentOS7安装MySQL8.0教程

环境介绍 操作系统&#xff1a;Centos7.6 MySQL版本&#xff1a; 8.0.27 只要是8.0.*版本&#xff0c;那就可以按照本文说明安装 一、安装前准备 1、卸载MariaDB 安装MySQL的话会和MariaDB的文件冲突&#xff0c;所以需要先卸载掉MariaDB。 1.1、查看是否安装mariadb rpm -…

如何在Windows通过固定tcp公网地址ssh远程访问本地Kali Linux

文章目录 1. 启动kali ssh 服务2. kali 安装cpolar 内网穿透3. 配置kali ssh公网地址4. 远程连接5. 固定连接SSH公网地址6. SSH固定地址连接测试 本文主要介绍如何在Kali系统编辑SSH配置文件并结合cpolar内网穿透软件&#xff0c;实现公网环境ssh远程连接本地kali系统。 1. 启…

bilibili PC客户端架构设计——基于Electron

众所周知&#xff0c;bilibili是个学习的网站&#xff0c;网页端和粉版移动端都非常的好用&#xff0c;不过&#xff0c;相对其它平台来说bilibili的PC客户端也算是大器晚成了。在有些场景PC客户端的优势也是显而易见的&#xff0c;比如&#xff0c;跓留电脑桌面的快捷、独立的…

实战纪实 | 编辑器漏洞之Ueditor-任意文件上传漏洞 (老洞新谈)

UEditor 任意文件上传漏洞 前言 前段时间在做某政府单位的项目的时候发现存在该漏洞&#xff0c;虽然是一个老洞&#xff0c;但这也是容易被忽视&#xff0c;且能快速拿到shell的漏洞&#xff0c;在利用方式上有一些不一样的心得&#xff0c;希望能帮助到一些还不太了解的小伙…

JAVA中如何确保N个线程可以访问N个资源,但同时又不导致死锁?

使用多线程的时候&#xff0c;一种非常简单的避免死锁的方式&#xff1a;指定获取锁的顺序&#xff0c;并强制现场按照指定的顺序获取锁。因此&#xff0c;所有线程按照同样的顺序加锁和释放就不会出现死锁。 请问什么是死锁(deadlock)? 竞争不可抢占资源形成死锁 如果有两…

工业采集网关有何功能?可以带来哪些价值?-天拓四方

一、行业背景 随着工业领域的快速发展&#xff0c;尤其是智能制造的兴起&#xff0c;工业自动化、智能化和数字化已成为工业转型升级的必然趋势。在这一进程中&#xff0c;工业数据采集和处理扮演着至关重要的角色。作为连接工业现场设备、传感器与上层管理系统的桥梁&#xf…

2024年环境预防与新材料国际会议 (EPNM 2024)

2024年环境预防与新材料国际会议 (EPNM 2024) 2024 International Conference on Environmental Prevention and New Materials 【会议简介】 2024年环境预防与新材料国际会议即将在张家界召开。本次会议旨在汇聚全球环境预防与新材料领域的专家学者&#xff0c;共同探讨环境…

【MATLAB源码-第37期】matlab基于STBC(空时分组码)的MIMO系统误码率仿真。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 空时分组码&#xff08;Space-Time Block Code&#xff0c;简称STBC&#xff09;是一种在多输入多输出&#xff08;MIMO&#xff09;无线通信系统中用于提高数据传输可靠性的编码技术。MIMO技术利用多个发射和接收天线来同时…

RA4000CE为汽车动力传动系统提供解决方案

目前汽车电气化的水平越来越高&#xff0c;其中比较显著的一个发展方向就是将发动机管理系统和自动变速器控制系统&#xff0c;集成为动力传动系统的综合控制(PCM)。作为汽车动力的核心部件&#xff0c;通过电子系统的运用&#xff0c;将外部多个传感器和执行环节的数据进行统一…

私有化即时通讯软件,WorkPlus提供的私有化、安全通讯解决方案

在当今信息化快速发展的时代&#xff0c;安全问题已经成为各行各业关注的焦点。特别是在金融、政府单位和芯片等关键行业&#xff0c;信息安全的重要性不言而喻。这些行业涉及到大量的敏感数据和关键信息&#xff0c;一旦发生泄露&#xff0c;可能会对国家安全、企业利益甚至个…

上位机图像处理和嵌入式模块部署(改进的qmacvisual动态插件卸载)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 前面我们讨论过&#xff0c;qmacvisual虽然提供了很多的功能&#xff0c;包括的种类很多&#xff0c;但是总有一些功能是客户希望定制的。这些都是…

局域网内部使用的视频会议系统推荐

随着远程办公的普及和全球化的发展趋势&#xff0c;企业需要一个高效、灵活、安全的音视频会议解决方案&#xff0c;以支持远程办公的协同工作、跨地域沟通等需要。私有化音视频会议就是一个适合企业自身部署的解决方案。它不仅能够满足企业信息管理和保密的需求&#xff0c;而…

关于DNS解析那些事儿,了解DNS解析的基础知识

DNS&#xff0c;全称Domain Name System域名系统&#xff0c;是一个将域名和IP地址相互映射的一个分布于世界各地的分布式数据库&#xff0c;而DNS解析就是将域名转换为IP地址的过程&#xff0c;使人们可以轻松实现通过域名访问网站。DNS解析是网站建设非常关键的一步&#xff…