ICLR 2022)ODConv:即插即用的动态卷积 (附代码)

news2024/11/28 23:33:31

论文地址:Omni-Dimensional Dynamic Convolution | OpenReview

代码地址:https://github.com/OSVAI/ODConv/blob/main/modules/odconv.py

1.是什么?

ODConv是一种动态卷积算法,它的原理是在卷积过程中,根据输入数据的特征动态地调整卷积核的形状和大小,以适应不同的输入数据。具体来说,ODConv通过引入一个可学习的形变模块,根据输入数据的特征动态地调整卷积核的形状和大小,从而提高了卷积神经网络的性能。与CondConv和DyConv不同,ODConv不仅考虑了空间维度、输入通道维度和输出通道维度,还考虑了卷积核的形状和大小,因此可以更好地适应不同的输入数据。

2.为什么?

常规卷积只有一个静态卷积核且与输入样本无关。对于动态卷积来说,它对多个卷积核进行线性加权,而加权值则与输入有关,这就使得动态卷积具有输入依赖性。它可以描述如下:

尽管动态卷积的定义很简单,但CondConv与DyConv的实现是不相同的,主要体现在计算a_{wi}的结构\pi_{wi}(x)训练策略以及实施动态卷积的层,这些实现上的差异导致了不同的模型精度、模型大小以及推理效率。

  • 两者均为\pi_{wi}(x)采用了类SE架构,但CondConv采用的是Sigmoid,而DyConv采用的是Softmax;
  • DyConv采用的退化策略进行训练以抑制Softmax的one-hot输出;
  • 对于他们嵌入的CNN架构,CondConv替换了最后几个模块的卷积与全连接层,而DyConv则对除第一个卷积外的其他卷积均进行了替换。

根据动态卷积的公式来看,动态卷积有两个基本元素:

  • 卷积核{W_{1},..,W_{n}}
  • 用于计算注意力{a_{w1,...,a_{wn}}}的注意力函数\phi _{wi}(x)

给定n个卷积核,其对应的核空间有以下四个维度:

  • 空间核尺寸k×k;
  • 输入通道数c_{in}
  • 输出通道数c_{out}
  • 卷积核数量n

然而,对于CondConv与DyConv来说,\phi _{wi}(x)均采用单个注意力标量a_{wi},这就意味着它的的输出滤波器W_{i}^{m}R^{k*k*c_{in}}对于输入具有相同的注意力值。换句话说,卷积核 W_{i}的空间维度、输入通道维度以及输出通道维度均被CondConv与DyConv所忽视了。这就导致了关于核空间的粗糙探索。这可能就是为什么CondConv与DyConv对于大网络的性能增益较低的原因。

此外,相比常规卷积,动态卷积的卷积核参数往往是其n倍。比如CondConv中的n=8,DyConv中的n=4。当动态卷积使用过多时无疑会极大程度提升模型大小。我们发现:当 移除掉CondConv/DyConv中的注意力机制(即a_{wi}=1)后,其性能提升接近于零。比如,对于ResNet18,其性能增益从1.78%/2.51%下降到了0.08%/0.14。

上述发现意味着:动态卷积中的注意力机制起关键性作用,更有效的设计也许可以在模型精度与大小之间得到更好的平衡。

一定程度上讲,ODConv可以视作CondConv的延续,将CondConv中一个维度上的动态特性进行了扩展,同时了考虑了空域、输入通道、输出通道等维度上的动态性,故称之为全维度动态卷积。ODConv通过并行策略采用多维注意力机制沿核空间的四个维度学习互补性注意力。作为一种“即插即用”的操作,它可以轻易的嵌入到现有CNN网络中。ImageNet分类与COCO检测任务上的实验验证了所提ODConv的优异性:即可提升大模型的性能,又可提升轻量型模型的性能,实乃万金油是也!值得一提的是,受益于其改进的特征提取能力,ODConv搭配一个卷积核时仍可取得与现有多核动态卷积相当甚至更优的性能

3 怎么样?

3.1 网络结构

基于前述讨论,ODConv通过并行策略引入一种多维注意力机制以对卷积核空间的四个维度学习更灵活的注意力。上图给出CondConv、DyConv以及ODConv的差异图。

延续动态卷积的定义,ODConv可以描述成如下形式:

其中,a_{wi}表示卷积核W_{i}的注意力标量,a_{si}\epsilon R^{k*k},a_{ci}\epsilon R^{c_{in}},a_{fi}\epsilon R^{c_{out}}表示新引入的三个注意力,分别沿空域维度、输入通道维度以及输出通道维度。这四个注意力采用多头注意力模块\pi_{i}(x)计算得到。

在ODConv中,对于卷积核W_{i}a_{si}对k*k空域位置上的卷积参数赋予不用的注意力值,见上图a;a_{ci}对不同输入通道的卷积滤波器赋予不同的注意力值,见上图b;a_{fi}对不同输出通道的卷积滤波器赋予不同的注意力值,见上图c;而a_{wi}则对n个整体卷积核赋予不同的值,见上图d。

原则上来讲,这四种类型的注意力是互补的,通过渐进式对卷积W_{i}沿位置、通道、滤波器以及核等维度乘以不同的注意力将使得卷积操作对于输入存在各个维度的差异性,提供更好的性能以捕获丰富上下文信息。因此,ODCOnv可以大幅提升卷积的特征提取能力;更重要的是,采用更少卷积核的ODConv可以取得与CondConv、DyConv相当甚至更优的性能。

对比前面两种动态卷积的公式可以发现:ODConv是一种更广义的动态卷积。此外,当设置n=1,a_{s1}=a_{c1}=a_{w1}=1时,ODConv则退化为仅具有滤波器层面的注意力,基于输入对卷积滤波器进行调制后再进行卷积,类似于SE。故SE是ODConv的一个特例。

那么如何实现ODConv的四种类型的注意力值呢?延续CondConv与DyConv,我们同样采用SE风格的注意力模块,但使其具有多个头以计算多种类型注意力,整体结构见上图。具体来说,对于输入先通过GAP收缩为长度为c_{in}的特征向量,然后采用FC与四个头生成不同类型的注意力值。对于四个头,其维度分别为k*k,c_{in}×1,c_{out}×1,n×1。

在训练方面,我们采用了DyConv中的退化策略以加速训练。在具体架构嵌入方面,我们参考DyConv对除第一个卷积外的其他所有卷积进行替换。

3.2 代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd


class Attention(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, groups=1, reduction=0.0625, kernel_num=4, min_channel=16):
        super(Attention, self).__init__()
        attention_channel = max(int(in_planes * reduction), min_channel)
        self.kernel_size = kernel_size
        self.kernel_num = kernel_num
        self.temperature = 1.0

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(in_planes, attention_channel, 1, bias=False)
        self.bn = nn.BatchNorm2d(attention_channel)
        self.relu = nn.ReLU(inplace=True)

        self.channel_fc = nn.Conv2d(attention_channel, in_planes, 1, bias=True)
        self.func_channel = self.get_channel_attention

        if in_planes == groups and in_planes == out_planes:  # depth-wise convolution
            self.func_filter = self.skip
        else:
            self.filter_fc = nn.Conv2d(attention_channel, out_planes, 1, bias=True)
            self.func_filter = self.get_filter_attention

        if kernel_size == 1:  # point-wise convolution
            self.func_spatial = self.skip
        else:
            self.spatial_fc = nn.Conv2d(attention_channel, kernel_size * kernel_size, 1, bias=True)
            self.func_spatial = self.get_spatial_attention

        if kernel_num == 1:
            self.func_kernel = self.skip
        else:
            self.kernel_fc = nn.Conv2d(attention_channel, kernel_num, 1, bias=True)
            self.func_kernel = self.get_kernel_attention

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def update_temperature(self, temperature):
        self.temperature = temperature

    @staticmethod
    def skip(_):
        return 1.0

    def get_channel_attention(self, x):
        channel_attention = torch.sigmoid(self.channel_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)
        return channel_attention

    def get_filter_attention(self, x):
        filter_attention = torch.sigmoid(self.filter_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)
        return filter_attention

    def get_spatial_attention(self, x):
        spatial_attention = self.spatial_fc(x).view(x.size(0), 1, 1, 1, self.kernel_size, self.kernel_size)
        spatial_attention = torch.sigmoid(spatial_attention / self.temperature)
        return spatial_attention

    def get_kernel_attention(self, x):
        kernel_attention = self.kernel_fc(x).view(x.size(0), -1, 1, 1, 1, 1)
        kernel_attention = F.softmax(kernel_attention / self.temperature, dim=1)
        return kernel_attention

    def forward(self, x):
        x = self.avgpool(x)
        x = self.fc(x)
        x = self.bn(x)
        x = self.relu(x)
        return self.func_channel(x), self.func_filter(x), self.func_spatial(x), self.func_kernel(x)


class ODConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 reduction=0.0625, kernel_num=4):
        super(ODConv2d, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.kernel_num = kernel_num
        self.attention = Attention(in_planes, out_planes, kernel_size, groups=groups,
                                   reduction=reduction, kernel_num=kernel_num)
        self.weight = nn.Parameter(torch.randn(kernel_num, out_planes, in_planes//groups, kernel_size, kernel_size),
                                   requires_grad=True)
        self._initialize_weights()

        if self.kernel_size == 1 and self.kernel_num == 1:
            self._forward_impl = self._forward_impl_pw1x
        else:
            self._forward_impl = self._forward_impl_common

    def _initialize_weights(self):
        for i in range(self.kernel_num):
            nn.init.kaiming_normal_(self.weight[i], mode='fan_out', nonlinearity='relu')

    def update_temperature(self, temperature):
        self.attention.update_temperature(temperature)

    def _forward_impl_common(self, x):
        # Multiplying channel attention (or filter attention) to weights and feature maps are equivalent,
        # while we observe that when using the latter method the models will run faster with less gpu memory cost.
        channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
        batch_size, in_planes, height, width = x.size()
        x = x * channel_attention
        x = x.reshape(1, -1, height, width)
        aggregate_weight = spatial_attention * kernel_attention * self.weight.unsqueeze(dim=0)
        aggregate_weight = torch.sum(aggregate_weight, dim=1).view(
            [-1, self.in_planes // self.groups, self.kernel_size, self.kernel_size])
        output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups * batch_size)
        output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
        output = output * filter_attention
        return output

    def _forward_impl_pw1x(self, x):
        channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
        x = x * channel_attention
        output = F.conv2d(x, weight=self.weight.squeeze(dim=0), bias=None, stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups)
        output = output * filter_attention
        return output

    def forward(self, x):
        return self._forward_impl(x)

 参考:

ODConv详解

ICLR 2022 | 涨点神器!Intel提出ODConv:即插即用的动态卷积

致敬CondConv!Intel提出即插即用的“万金油”动态卷积ODConv

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1160042.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第五章 I/O管理 十一、减少磁盘延迟时间的方法

目录 一、交替编号 1、定义: 二、磁盘地址结构的设计 三、错位命名 四、总结 一、交替编号 1、定义: 若采用交替编号的策略,即让逻辑上相邻的扇区在物理上有一定的间隔,可以使读取连续的逻辑扇区所需要的延迟时间更小。 二、…

AI智能语音识别模块(二)——基于Arduino的语音控制MP3播放器

文章目录 简介离线语音控制模块Mini MP3模块0.96寸 OLED模块实验准备安装库接线定义主要程序实验效果注意事项总结 简介 在前面一篇文章里我们对AI智能语音识别模块进行了介绍,并对离线语音模组下载固件的过程进行了一个简单描述,不知道大家还记不记得&…

Nginx+keepalived实现七层的负载均衡

1.keepalived VRRP 介绍 keepalived是什么? keepalived是集群管理中保证集群高可用的一个服务软件,用来防止单点故障。 keepalived工作原理 keepalived是以VRRP协议为实现基础的,VRRP全称Virtual Router Redundancy Protocol&…

视频AI智剪,批量剪辑助力高效创作

视频AI智剪是一种基于人工智能技术的自动化剪辑工具,它可以自动对视频素材进行分析、筛选、剪辑和优化,从而生成一部高质量的视频作品。而批量剪辑则是指利用AI智剪技术,同时对多个视频素材进行自动化剪辑,大大提高了剪辑效率。本…

NI‑9237国产化50 kS/s/ch,桥模拟输入,4通道C系列应变/桥输入模块

50 kS/s/ch,桥模拟输入,4通道C系列应变/桥输入模块 NI‑9237提供了所有的信号调理功能来实现多达四个基于桥的传感器的供电和测量。该模块提供通道间零相位延迟的应变或负载测量。它还具有60 VDC隔离和1,000 Vrms瞬态隔离,提供高…

操作系统 day03(运行机制)

机器指令 二进制机器指令就是处理器(CPU)能识别、执行的最基本命令 程序运行的过程就是CPU执行一条一条的机器指令的过程 应用程序和内核程序 操作系统的最重要角色是:系统资源的管理者,而操作系统的对系统资源的管理工作就是…

Mysql系列 -索引模型数据结构

索引就是排好序的数据结构,可以帮助我们快速的查找到数据,那么底层的数据到底是如何存储的呢? 为什么InnoDB 用的是Btree 存储结构? 大家可以看看这个可视化的网站 数据结构和算法的可视化工具 可以看到数据结构里面有链表&…

JVM虚拟机:垃圾回收算法和垃圾回收器之间的关系

GC垃圾回收算法 在前面的课程中我们学习了GC垃圾回收算法,分别为: 引用回收算法 复制算法 标记清除算法 标记整理算法 这些垃圾回收算法是理论,有多种垃圾回收器可以实现这些理论。目前为止没有最完美的垃圾回收器,只能针对具体的情况选择最合适的垃圾回收器,进行分代收集…

高校教务系统登录页面JS分析——天津大学

高校教务系统密码加密逻辑及JS逆向 本文将介绍天津大学教务系统的密码加密逻辑以及使用JavaScript进行逆向分析的过程。通过本文,你将了解到密码加密的基本概念、常用加密算法以及如何通过逆向分析来破解密码。 本文仅供交流学习,勿用于非法用途。 一、密…

MFC网络通信-Udp服务端

目录 1、UI的布局 2、代码的实现: (1)、自定义的子类CServerSocket (2)、重写OnReceive事件 (3)、在CUdpServerDlg类中处理 (4)、在OnInitDialog函数中 &#xff0…

Leetcode刷题详解——Pow(x, n)

1. 题目链接:50. Pow(x, n) 2. 题目描述: 实现 pow(x, n) ,即计算 x 的整数 n 次幂函数(即,xn )。 示例 1: 输入:x 2.00000, n 10 输出:1024.00000示例 2:…

Mozilla Firefox 119 现已可供下载

Mozilla Firefox 119 开源网络浏览器现在可以下载了,是时候先看看它的新功能和改进了。 Firefox 119 改进了 Firefox View 功能,现在可以提供更多内容,如最近关闭的标签页和浏览历史,你可以按日期或网站排序,还支持查…

项目知识点总结-住房图片信息添加-Excel导出

(1)住房信息添加 Controller: RequestMapping("/add")public String add(Home home, Model model) throws IOException{String sqlPath null;//定义文件保存的本地路径String localPath"D:\\AnZhuang\\Java项目\\选题\\Xin-…

YOLOv5 分类模型的加载

YOLOv5 分类模型的加载 flyfish 版本 6.2 yolov5s分类模型 python classify/train.py --model resnet18.pt --data cifar100 --epochs 5 --img 224resnet18模型 python classify/train.py --model resnet18.pt --data cifar100 --epochs 5 --img 128导出模型看一下结构 p…

基于SSM的理发店管理系统

基于SSM的理发店管理系统的设计与实现~ 开发语言:Java数据库:MySQL技术:SpringSpringMVCMyBatis工具:IDEA/Ecilpse、Navicat、Maven 系统展示 主页 公告信息 管理员界面 用户界面 摘要 基于SSM(Spring、Spring MVC、…

水库大坝可视化智能远程监管方案,助力安全监测智能巡检

一、背景需求 水库大坝作为防洪度汛的重要设施,其安全问题直接关系到人民群众的生命财产安全。因此,必须加强对大坝水库的安全管理,对水库除险加固和运行管护要消除存量隐患,实现常态化管理,同时要配套完善重点小型水…

windows使用FindWindow函数查找窗口句柄

理解什么是句柄? 对于“句柄”,之前一直停留在一知半解的认识层面,也说不清具体概念,只知道它是一个标识符,用来标记对象或者说某个东西的。只知其名不知其意。目前学习windows编程,对“句柄”做一个完整的…

GPT与人类共生:解析AI助手的兴起

随着GPT模型的崭新应用,如百度的​1​和CSDN的​2​,以及AI助手的普及,人们开始讨论AI对就业市场和互联网公司的潜在影响。本文将探讨GPT和AI助手的共生关系,以及我们如何使用它们,以及使用的平台和动机。 GPT和AI助手…

【AI视野·今日Robot 机器人论文速览 第六十一期】Tue, 24 Oct 2023

AI视野今日CS.Robotics 机器人学论文速览 Tue, 24 Oct 2023 Totally 50 papers 👉上期速览✈更多精彩请移步主页 Daily Robotics Papers Robot Fine-Tuning Made Easy: Pre-Training Rewards and Policies for Autonomous Real-World Reinforcement Learning Autho…

常用编程语言排行与应用场景汇总(2023.10)

文章目录 编程语言排行一、Python二、C三、C四、Java五、C#六、JavaScript七、VB(Visual Basic)八、PHP九、SQL十、ASM(Assembly Language)十一、Go十二、Scratch十三、Delphi/Object Pascal十四、MATLAB十五、Swift十六、Fortran…