Selective Kernel Networks论文总结和代码实现

news2024/11/22 14:59:58

论文:https://arxiv.org/abs/1903.06586?context=cs

中文版:(CVPR-2019)选择性的内核网络_sk卷积

源码:GitHub - implus/SKNet: Code for our CVPR 2019 paper: Selective Kernel Networks

目录

一、论文出发点

二、论文主要工作

三、SK模块的具体实现

四、实验

五、总结

六、代码实现


SKNet是SENet的加强版,是attention机制中的与SE同等地位的一个模块,可以方便地添加到现有的网络模型中,对分类问题,分割问题有一定的提升。

一、论文出发点

在神经科学界,众所周知,视觉皮层神经元的感受野大小受刺激的调节,这在构建CNN时很少被考虑。作者设计一个叫做选择性核(SK)单元的构件,其中具有不同核大小的多个分支在这些分支的信息指导下,使用softmax注意力进行融合。对这些分支的不同关注产生了融合层中神经元的有效感受野的不同大小。

引用博文(SKNet)Selective Kernel Network 解析 - 知乎中的一句话:虽然在此之前,自适应感受野大小的机制还没有人提出,或者说很少被考虑到。但是有一点是存在共识的,那就是结合不同感受野大小能够提升神经元的适应能力。

二、论文主要工作

1.提出了一种非线性方法的聚合方法,从多个内核中聚合信息,实现神经元的自适应RF大小。

2.引入了 “选择性内核”(SK)卷积,它由三组运算符组成。分裂、融合和选择。

三、SK模块的具体实现

SK模块的整体结构图:

在该例中,只有两个分支,但是实际上是可以扩展到多个分支的情况。

​​主要通过三个运算符实现 SK 卷积——Split、Fuse 和 Select。

1. Split

​​对输入的特征图\mathbf{X} \in \mathbb{R}^{H^{\prime} \times W^{\prime} \times C^{\prime}}​​,分别进行两次卷积变换\tilde{\mathcal{F}}​​、\widehat{\mathcal{F}}​​。

(1)\tilde{\mathcal{F}}: \mathbf{X} \rightarrow \tilde{\mathbf{U}} \in\mathbb{R}^{H \times W \times C}​​:\tilde{\mathcal{F}}​​过程中卷积核大小为3x3,特征图X经过卷积变换为\tilde{\mathbf{U}}​​。

(2)\widehat{\mathcal{F}}: \mathbf{X} \rightarrow \widehat{\mathbf{U}} \in \mathbb{R}^{H \times W \times C}​​:\widehat{\mathcal{F}}​​过程中卷积核大小为5x5,特征图X经过卷积变换为\widehat{\mathbf{U}}​​。

特征图X经过Split操作,输出两个新的特征图\tilde{\mathbf{U}}​​、\widehat{\mathbf{U}}​​。

2. Fuse

目的:使神经元能够根据刺激内容自适应地调整其RF大小,因此需要使用门来控制来自多个分支的信息流,这些分支携带不同规模的信息进入下一层的神经元

方法:使用门整合来自所有分支的信息,也就是将来自多个分支的特征进行融合

步骤

(1)\mathbf{U}=\tilde{\mathbf{U}}+\widehat{\mathbf{U}}​​:特征图\tilde{\mathbf{U}}​​、\widehat{\mathbf{U}}​​相加,得到新的特征图U,U中融合了多个感受野的信息

(2)\mathcal{F}_{g p}​​:U通过全局平均池化生成\mathbf{s} \in \mathbb{R}^{C}​​来嵌入全局信息,\mathbf{s} \in \mathbb{R}^{C\times 1}​​是一个有C个元素的列向量。对应的算子公式如下:

​​通过将U的第C个feature map缩小空间尺寸H×W来计算得到s的第C个元素。

(3)\mathcal{F}_{f c}​​:通过一个简单的全连接 (fc) 层,将向量s压缩为特征向量\mathbf{z} \in \mathbb{R}^{d\times 1}​​。对应的算子公式如下:

​​δ是ReLU函数,\mathcal{B}​​表示批量归一化,W\in \mathbb{R}^{d \times C}​​为权重矩阵。

很明显这里,向量s先是通过了一个全连接层将c个通道变成d个​,减少参数量,再经过批量归一化函数,最后通过ReLU函数得到特征向量z

(4)论文中还研究了 d 对模型效率的影响,其算子公式如下

​​这里C/r,可以看出SENet论文的痕迹,目的与SENet论文中的一致,因为一个全连接层无法同时应用relu和sigmoid两个非线性函数,但是两者又缺一不可。为了减少参数,所以设置了r比率。

3. Select

​​目的:在紧凑的特征描述符z的引导下,利用跨通道的软注意来自适应地选择不同的信息空间尺度。

方法:softmax 运算符应用于通道数,得到各分支上特征图的软注意力向量。这里的特征图示例为\tilde{\mathbf{U}}​​、\widehat{\mathbf{U}}​​,因此得到软注意力向量为\tilde{\mathbf{U}}​​、\widehat{\mathbf{U}}​​的软注意力。

步骤:

(1)分别两次对特征向量z使用softmax函数得到软注意力向量\mathbf{a}, \mathbf{b}​​ ,这时向量中的每一个数值对应一个channel的分数,代表其channel的重要程度,同时将\mathbf{a}, \mathbf{b}​​再次变回了c个维度,这里又可以看出SENet论文的痕迹。算子公式如下:

​​其中\mathbf{A}, \mathbf{B} \in \mathbb{R}^{C \times d}​​,\mathbf{a}, \mathbf{b}​​ 分别表示\widetilde{\mathbf{U}}​​和\widehat{\mathbf{U}}​​的软注意力向量。请注意,\mathbf{A}_{c} \in \mathbb{R}^{1 \times d}​​是\mathbf{A}​​的第 c行,a_{c}​​ ​是\mathbf{a}​​的第c个元素,\mathbf{B}_{c}​​​和b_{c}​​也是如此。

(2)各特征图与对应的注意力权重相乘,再相加,最终得到特征图\mathbf{V}​​。

​其中\mathbf{V}=\left[\mathbf{V}_{1}, \mathbf{V}_{2}, \ldots, \mathbf{V}_{C}\right]​​,\mathbf{V}_{c} \in \mathbb{R}^{H \times W}​​。

这个特征图V是通过各种内核上的注意力权重获得的,融合了多个感受野的信息,具有了多尺度的信息。

关于特征图V最终达到的效果,这里博文【CV中的Attention机制】SKNet-SENet的提升版 - 知乎总结比较好,需要进一步理解,可参考。

四、实验

五、总结

在SKNet中可以看到许多SE模块的痕迹,总的来说SK模块先是通过不同的卷积核将输入特征图进行划分为几个不同的子特征图,再将子特征图相加融合,再经过压缩和softmax处理,得到各种内核上的注意力权重,再与对应子特征图相乘,再相加得到最终的输出特征图V,这个特征图既融合了多个感受野的信息,具有了多尺度的信息。

六、代码实现

1.SK卷积

class SKConv(nn.Module):
    def __init__(self, features, WH, M, G, r, stride=1, L=32):
        """ Constructor
        Args:
            features: 输入通道维度
            WH: 输入特征图的空间维度
            M: 分支的数量
            G: 卷积组的数量
            r: 计算d,向量s的压缩倍数,C/r
            stride: 步长,默认为1
            L: 矢量z的最小维度,默认为32
        """
        super(SKConv, self).__init__()
        d = max(int(features / r), L)
        self.M = M
        self.features = features
        self.convs = nn.ModuleList([])
        # 使用不同kernel size的卷积,增加不同的感受野
        for i in range(M):
            self.convs.append(nn.Sequential(
                nn.Conv2d(features, features, kernel_size=3 + i * 2, stride=stride, padding=1 + i, groups=G),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=False)
            ))
        # 全局平均池化
        self.gap = nn.AvgPool2d(int(WH / stride))
        self.fc = nn.Linear(features, d)
        self.fcs = nn.ModuleList([])
        # 全连接层
        for i in range(M):
            self.fcs.append(
                nn.Linear(d, features)
            )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        ''' Split操作'''
        for i, conv in enumerate(self.convs):
            fea = conv(x).unsqueeze_(dim=1)
            if i == 0:
                feas = fea
            else:
                feas = torch.cat([feas, fea], dim=1)

        ''' Fuse操作'''
        fea_U = torch.sum(feas, dim=1)
        fea_s = self.gap(fea_U).squeeze_()
        fea_z = self.fc(fea_s)

        ''' Select操作'''
        for i, fc in enumerate(self.fcs):
            # fc-->d*c维
            vector = fc(fea_z).unsqueeze_(dim=1)
            if i == 0:
                attention_vectors = vector
            else:
                attention_vectors = torch.cat([attention_vectors, vector], dim=1)
        # 计算attention权重
        attention_vectors = self.softmax(attention_vectors)
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)
        # 最后一步,各特征图与对应的注意力权重相乘,得到输出特征图V
        fea_v = (feas * attention_vectors).sum(dim=1)
        return fea_v

2.SKNet完整源码

SKConv替代了ResNext中3*3卷积部分,用两个或多个不同卷积核大小的卷积操作加学习通道权重全连接层替代。

import torch
from torch import nn

# conv = SKConv(64, 32, 3, 8, 2)
class SKConv(nn.Module):
    def __init__(self, features, WH, M, G, r, stride=1, L=32):
        """ Constructor
        Args:
            features: 输入通道维度
            WH: 输入特征图的空间维度
            M: 分支的数量
            G: 卷积组的数量
            r: 计算d,向量s的压缩倍数,C/r
            stride: 步长,默认为1
            L: 矢量z的最小维度,默认为32
        """
        super(SKConv, self).__init__()
        d = max(int(features / r), L)
        self.M = M
        self.features = features
        self.convs = nn.ModuleList([])
        # 使用不同kernel size的卷积,增加不同的感受野
        for i in range(M):
            self.convs.append(nn.Sequential(
                nn.Conv2d(features, features, kernel_size=3 + i * 2, stride=stride, padding=1 + i, groups=G),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=False)
            ))
        # 全局平均池化
        self.gap = nn.AvgPool2d(int(WH / stride))
        self.fc = nn.Linear(features, d)
        self.fcs = nn.ModuleList([])
        # 全连接层
        for i in range(M):
            self.fcs.append(
                nn.Linear(d, features)
            )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        ''' Split操作'''
        for i, conv in enumerate(self.convs):
            fea = conv(x).unsqueeze_(dim=1)
            if i == 0:
                feas = fea
            else:
                feas = torch.cat([feas, fea], dim=1)

        ''' Fuse操作'''
        fea_U = torch.sum(feas, dim=1)
        fea_s = self.gap(fea_U).squeeze_()
        fea_z = self.fc(fea_s)

        ''' Select操作'''
        for i, fc in enumerate(self.fcs):
            # fc-->d*c维
            vector = fc(fea_z).unsqueeze_(dim=1)
            if i == 0:
                attention_vectors = vector
            else:
                attention_vectors = torch.cat([attention_vectors, vector], dim=1)
        # 计算attention权重
        attention_vectors = self.softmax(attention_vectors)
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)
        # 最后一步,各特征图与对应的注意力权重相乘,得到输出特征图V
        fea_v = (feas * attention_vectors).sum(dim=1)
        return fea_v


class SKUnit(nn.Module):
    def __init__(self, in_features, out_features, WH, M, G, r, mid_features=None, stride=1, L=32):
        """ Constructor
        Args:
            in_features: 输入通道维度
            out_features: 输出通道维度
            WH: 输入特征图的空间维度
            M: 分支的数量
            G: 卷积组的数量
            r: 计算d,论文中向量s的压缩倍数,C/r
            mid_features: 步长不为1的中间卷积的通道维度,默认为out_features/2
            stride: 步长,默认为1
            L: 论文中矢量z的最小维度,默认为32
        """
        super(SKUnit, self).__init__()
        if mid_features is None:
            mid_features = int(out_features / 2)
        self.feas = nn.Sequential(
            nn.Conv2d(in_features, mid_features, 1, stride=1),
            nn.BatchNorm2d(mid_features),
            SKConv替代了ResNext中3*3卷积部分
            SKConv(mid_features, WH, M, G, r, stride=stride, L=L),
            nn.BatchNorm2d(mid_features),
            nn.Conv2d(mid_features, out_features, 1, stride=1),
            nn.BatchNorm2d(out_features)
        )
        if in_features == out_features:  # when dim not change, in could be added diectly to out
            self.shortcut = nn.Sequential()
        else:  # when dim not change, in should also change dim to be added to out
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_features, out_features, 1, stride=stride),
                nn.BatchNorm2d(out_features)
            )

    def forward(self, x):
        fea = self.feas(x)
        return fea + self.shortcut(x)


class SKNet(nn.Module):
    def __init__(self, class_num):
        super(SKNet, self).__init__()
        self.basic_conv = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64)
        )  # 32x32
        self.stage_1 = nn.Sequential(
            SKUnit(64, 256, 32, 2, 8, 2, stride=2),
            nn.ReLU(),
            SKUnit(256, 256, 32, 2, 8, 2),
            nn.ReLU(),
            SKUnit(256, 256, 32, 2, 8, 2),
            nn.ReLU()
        )  # 32x32
        self.stage_2 = nn.Sequential(
            SKUnit(256, 512, 32, 2, 8, 2, stride=2),
            nn.ReLU(),
            SKUnit(512, 512, 32, 2, 8, 2),
            nn.ReLU(),
            SKUnit(512, 512, 32, 2, 8, 2),
            nn.ReLU()
        )  # 16x16
        self.stage_3 = nn.Sequential(
            SKUnit(512, 1024, 32, 2, 8, 2, stride=2),
            nn.ReLU(),
            SKUnit(1024, 1024, 32, 2, 8, 2),
            nn.ReLU(),
            SKUnit(1024, 1024, 32, 2, 8, 2),
            nn.ReLU()
        )  # 8x8
        self.pool = nn.AvgPool2d(8)
        self.classifier = nn.Sequential(
            nn.Linear(1024, class_num),
            # nn.Softmax(dim=1)
        )

    def forward(self, x):
        fea = self.basic_conv(x)
        fea = self.stage_1(fea)
        fea = self.stage_2(fea)
        fea = self.stage_3(fea)
        fea = self.pool(fea)
        fea = torch.squeeze(fea)
        fea = self.classifier(fea)
        return fea


if __name__ == '__main__':
    # 随机生成8个(64,32,32)的特征图
    x = torch.rand(8, 64, 32, 32)
    conv = SKConv(64, 32, 3, 8, 2)
    out = conv(x)
    criterion = nn.L1Loss()
    loss = criterion(out, x)
    loss.backward()
    # 最终输出特征图V的size和损失值
    print('out shape : {}'.format(out.shape))
    print('loss value : {}'.format(loss))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/575580.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

洛谷——树

洛谷——树 文章目录 洛谷——树树的重心会议题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 提示数据范围 思路 树的直径【XR-3】核心城市题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 提示思路 [NOI2003] 逃学的小孩题目描述输入格式输出格式样例 #1样例…

Cocos creator实现《滑雪趣挑战》滑雪小游戏资源及代码

Cocos creator实现《滑雪趣挑战》滑雪小游戏资源及代码 最近在学习Cocos Creator,作为新手,刚刚开始学习Cocos Creator,上线了两个微信小游戏,刚刚入门,这里记录一下《滑雪趣挑战》实现及上线过程的过程。 ](https://…

vue实现深拷贝的方法

在 vue中,深拷贝是一个很有用的功能,在不改变原来对象状态的情况下,进行对象的复制。 但要实现深拷贝,需要两个对象具有相同的属性。如果两个对象不同,深拷贝也不能实现。 1.我们将变量A的属性赋给变量B,但…

springboot+java医院门诊挂号系统设计与实现ssm008

本课题的目标是使医院门诊信息管理清晰化,透明化,便于操作,易于管理。通过功能模块的优化组合实现不同的管理细节,使管理过程实现最大程度的自动化与信息化,并能自动对人工操作环节进行复查,使医院门诊挂号系统出错率降至最低。 主…

3、mqtt客户端演示(MQTT通信协议(mosquitto)发布订阅 C语言实现)

可订阅可发布模式 具体代码 客户端1代码&#xff1a;pub.c #include <stdio.h> #include <stdlib.h> #include <mosquitto.h> #include <string.h>#define HOST "localhost" #define PORT 1883 #define KEEP_ALIVE 60 #define MSG_MAX_S…

ChatGPT提示词工程进阶教学

ChatGPT提示词工程 1 两种大型语言模型LLM1.1 基础大模型&#xff08;base LLM&#xff09;1.2 指令调优大模型(Instruction Tuned LLM) 2 如何更清晰、具体地书写提示词2.1 在提示词中使用“定界符”2.2 向模型请求结构化的输出2.3 要求模型检查任务条件是否满足2.4 输入多范例…

uCOSii中的互斥信号量

uCOSii中的互斥信号量 一、互斥型信号量项管理 (MUTUAL EXCLUSION SEMAPHORE MANAGEMENT) OSMutexAccept() 无条件等待地获取互斥型信号量 OSMutexCreate() 建立并初始化一个互斥型信号量 OSMutexDel() 删除互斥型信号量 OSMutexPend() 等待一个互斥型信号量 OSMutexPost…

扬帆起航——Qt自定义控件介绍

文章目录 前言自定义控件的定义自定义控件的好处如何实现自定义控件实现没有自带的控件 如何使用自定义控件测试和优化常见的自定义控件总结 前言 Qt 提供了丰富的控件、工具和库&#xff0c;可以帮助开发人员快速创建现代化的跨平台应用程序。但是对于某些特殊的需求&#xf…

【数据结构】冒泡,快速,直接插入,归并,选择排序

&#x1f38a;专栏【数据结构】 &#x1f354;喜欢的诗句&#xff1a;更喜岷山千里雪 三军过后尽开颜。 &#x1f386;音乐分享【Dream It Possible】 大一同学小吉&#xff0c;欢迎并且感谢大家指出我的问题&#x1f970; 目录 &#x1f381;冒泡排序 &#x1f3f3;️‍&…

CentOS7.4安装OpenVPN

系统环境 [rootvpn ~]# cat /etc/redhat-release CentOS Linux release 7.4.1708 (Core) 一. 准备工作 [rootvpn ~]# yum -y install openssl-devel openssl pam pam-devel lzo lzo-devel pkcs11-helper pkcs11-helper-devel 二. 安装OpenVPN服务 1. 下载openvpn源码包 [r…

【计算机网络 - 第六章】链路层

目录 一、概述 1、数据链路层提供的服务&#xff1f; 二、差错检测 1、奇偶校验 2、循环冗余校验CRC 三、多路访问链路和协议 1、概述 &#xff08;1&#xff09;多路访问协议 2、信道划分协议 ① 频分多路复用FDM ② 时分多路复用TDM ③ 波分多路复用WDM ④ 码分…

更好看的国产蓝牙耳机,音质也没问题,哈氪零度青春版体验

夏天躲在空调房里戴着耳机听音乐、玩游戏是很多人的日常&#xff0c;这两年国产耳机做得越来越好了&#xff0c;设计也很有新意&#xff0c;像是我现在用的这款哈氪零度青春版&#xff0c;就采用了一种冰封造型设计&#xff0c;视觉效果很新颖&#xff0c;看起来很有立体感&…

【一个简单的前后端交互页面】

&#x1f389;&#x1f389;&#x1f389;点进来你就是我的人了博主主页&#xff1a;&#x1f648;&#x1f648;&#x1f648;戳一戳,欢迎大佬指点! 欢迎志同道合的朋友一起加油喔&#x1f93a;&#x1f93a;&#x1f93a; 目录 客户端与服务器之间的通信流程 理解当前案例…

chatgpt赋能python:Python文件拆分技巧详解

Python 文件拆分技巧详解 随着数据量的不断增大&#xff0c;我们经常需要处理非常大的数据文件&#xff0c;这时候就需要用到文件拆分技巧。在Python中&#xff0c;文件拆分可以帮助我们提高数据处理的效率&#xff0c;这是一个非常实用的技巧。在本篇文章中&#xff0c;我们将…

奇巴布Feed流性能优化

01 项目背景 “爱奇艺奇巴布”是爱奇艺为0-8岁孩子和家长定制化设计的寓教于乐平台&#xff0c;为儿童量身打造精致的观看体验&#xff0c;精彩内容解锁寓教于乐新方式。为儿童提供优质动画内容的同时&#xff0c;我们更关注APP用户体验。在产品交互设计上我们立足儿童视角&…

抖音SEO矩阵系统开发分享及搭建流程

目录 产品功能亮点 产品介绍及开发背景 开发要求及实现流程 产品功能亮点 1. 支持多账号多平台一键 授权管理 2.支持矩阵视频批量剪辑&#xff0c;批量发布 3. 多平台关键词布局&#xff0c;提升企业及产品曝光 4. 评论区关键词自动回复&#xff0c;意向线索智能挖掘 5…

RTOS专栏(一) —— rt-thread简单介绍和qemu使用

本期主题&#xff1a; 简单介绍rt-thread介绍qemu和rt-thread怎么配合使用qemu的简单例子 rt-thread & qemu 1.rt-thread介绍2.qemu介绍3.搭建rt-thread和qemu开发环境4.简单例子 1.rt-thread介绍 RT-Thread 是一款完全由国内团队开发维护的嵌入式实时操作系统&#xff0…

《操作系统》期末主观题梳理

操作系统简答题 文章目录 操作系统简答题第一章第二章第三章第四章第五章第六章第七章第八章第九章 第一章 在计算机系统上配置OS(operating system, 操作系统)的目标是什么?作用主要表现在哪几个方面? 在计算机系统上配置OS, 主要目标是实现&#xff1a;方便性、有效性、可…

Error: Flash Download failed - Target DLL has been cancelled

文章目录 背景参考 背景 在使用keilv5进行STM32开发时&#xff0c;配置用JLink进行文件烧录&#xff0c;出现如下错误&#xff1a; 查阅资料&#xff0c;是因为Keil未识别烧录工具&#xff0c;需要进行下面的操作&#xff1a; 1.打开工程配置窗口&#xff0c;点开Debug选项卡…

并查集专题

⭐️前言⭐️ 本篇文章主要介绍与并查集相关的题目。 &#x1f349;欢迎点赞 &#x1f44d; 收藏 ⭐留言评论 &#x1f4dd;私信必回哟&#x1f601; &#x1f349;博主将持续更新学习记录收获&#xff0c;友友们有任何问题可以在评论区留言 &#x1f349;博客中涉及源码及博主…