GCNet论文总结和代码实现

news2024/11/22 22:45:00

GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond(当非局部网络遇到挤压激励网络)

论文:GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond

源码:https://gitcode.net/mirrors/xvjiarui/GCNet/-/blob/master/mmdet/ops/gcb/context_block.py

目录

一、背景、出发点和主要工作

二、Non-local Networks分析

2.1 重温Non-local Networks

2.2 分析

三、Methods(方法) 

3.1 简化Non-local模块(SNL模块)

3.2 Global Context Modeling Framework(GC框架)

3.3 Global Context Block(GC模块)

四、实验

4.1 COCO上的对象检测/分割

(1)消融实验

(2)backbones增强实验

4.2 ImageNet图像分类

五、结论

六、代码实现


一、背景、出发点和主要工作

背景: GCNet是NLNet的衍生网络。

出发点:通过严格的实证分析,作者发现对于图像中不同的查询位置,non-local network所捕获的全局上下文几乎是相同的

主要工作:在本文中,作者受上述发现启发,创建了一个基于query-independent(查询无关)公式的简化网络,它保持了NLNet的精度,但显著减少了计算量。进一步观察发现,这种简化的设计与SENet具有相似的结构,因此作者将SE模块与NL模块一同整合入了一个捕获全局上下文的三步式通用框架中。在通用框架中,作者设计了一个轻量级的模块,称为全局上下文(GC)块,可以有效地捕获全局上下文,并且轻量级的属性允许将其应用于骨干网络中的多个层(通用性)。

  • 1. 作者基于query-independent公式精简了NLNet,减少了计算量。
  • 2. 作者将SE模块与NL模块共同整合入多头注意力机制中,命名为GC模块。
  • 3. GC模块是一个轻量级的模块,因此它允许被应用于骨干网络中的多个层中。(通用性,即插即用,估计是因为残差连接的作用)

二、Non-local Networks分析

2.1 重温Non-local Networks

(1)Non-local Networks的定义:

其中,i 是输入特征中要被计算的位置,jx_i 所有可能关联到位置的索引。f(\ ,\ ) 用于计算两位置间的相关度。N_p 是特征图中所有的像素数。\textup{W}_\textup{z}​ 和 \textup{W}_\textup{v}表示线性变换矩阵(例如,1x1卷积),C(x)是归一化因子。

(2)计算像素邻域间的相似度的四种方法:

 分别是高斯函数嵌入式高斯函数点积Concatenation(维度拼接操作)。

2.2 分析

作者分别从可视化数学统计两个方面上对NLNet进行了分析。

1. 可视化

目的: 观察不同像素点的注意力图特征。

过程:从COCO数据集中随机选择六张图像,对每张图像分别可视化三个不同的query位置(红点)通过NL模块生成地注意力图(热图)。作者惊奇地发现,对于不同的query位置,它们的注意力图几乎是相同的。如下图所示:

生成不同的query位置注意力图的公式推测可为:

由这个公式得到的矩阵是query位置处的像素与图像上其他像素之前相似度矩阵。

2. 统计分析

目的:比较不同的query位置生成的注意力图的差异大小。

方法:采用余弦距离Jensen-Shannon散度(JSD)两种方法进行比较。

由上表可知,NL模块产生的attention map的余弦差与JSD差都非常小,这再次验证了可视化的观察结果。换句话来说,虽然NL模块打算计算特定于每个query位置的全局上下文,但训练后的全局上下文实际上与query位置无关。

三、Methods(方法) 

NL模块最初的定义,采用嵌入式高斯函数计算相似度:

3.1 简化Non-local模块(SNL模块)

1. 去除 W_qx_iW_\textup{z}

基于上述观察,作者认为全局上下文的捕获实际上与query位置无关,作者设计直接生成一个全局attention map,所有的位置共享这一个attention map,去除生成查询的卷积操作( W_qx_i )来简化Non-local模块。此外,作者根据[12]得出的结论,有和没有W_\textup{z} ​的变体可以达到相当的性能,因此在剔除了残差连接中卷积(W_\textup{z})。定义如下:

Q:百思不得其解,公式中嵌入式高斯函数为什么只有一个输入?

A:实际上,据推测这里不是嵌入式高斯函数,而是softmax函数,结合GCNet代码观察,生成全局attention map的是经过一个1x1卷积 + view操作 + softmax函数实现的。softmax函数定义如下:

2. 变换 W_\textup{v} 位置

为了进一步降低计算成本,作者应用分配定律W_\textup{v} (卷积)移到注意力池之外。 定义如下:

1x1的W_\textup{v}卷积的计算消耗从\mathcal O(HWC^2) 降低到到\mathcal O(C^2)

 

3.2 Global Context Modeling Framework(GC框架)

如下图所示,简化的non-local block可以抽象为三个过程:

(a) 全局注意力池化:采用 1x1 卷积 W_k ​和softmax函数获得注意力权重,然后进行注意力池化获取全局上下文特征。

(这个过程便是注意力池化,被池化为了一个通道向量。)

(b) 特征转换(transform):通过 1x1 卷积 W_\textup{v} 进行特征转换。

(c) 特征聚合:它采用加法将全局上下文特征聚合到每个位置的特征。

作者将上述过程抽象视为一个全局上下文建模框架,定义为:

其中, (a) \sum_j \alpha _j x_j 表示通过权重 \alpha _j 的加权平均将所有位置的特征组合在一起以获得全局上下文特征(SNL模块中的全局注意力池);

(b) \delta (\cdot ) 表示捕获通道依赖关系的特征转换(SNL模块中的 1x1 卷积);

(c) F (·,·) 表示将全局上下文特征聚合到每个位置的特征的融合函数(SNL模块中的广播元素加法)。

 

3.3 Global Context Block(GC模块)

为了进一步优化训练参数,将特征转换部分中简单的1x1卷积操作替换为bottleneck transform模块bottleneck transform模块由一个1x1卷积、一个ReLU层、一个1x1卷积和一个 sigmoid函数组成(与SENet中excitation操作基本一致)。这样做可以将参数数量从C⋅C减少到2⋅C⋅C/r。

Q:SNL为什么用1x1卷积代替(SENet)线性层?

A:作用是一样的,都是为了减少运算过程中参数数量。

替换完bottleneck transform模块之后,GC模块的定义如下:

其中,\alpha _j = \frac{e^{W_kx_j}}{\sum_{m}e^{W_kx_m}} 是全局注意力池的权重, \delta (\cdot ) = W_{v2}ReLU(LN(W_{v1}(·)) 表示bottleneck transform。具体来说,我们的 GC 块包括:

  • (a) 用于上下文建模的全局注意力池。
  • (b) bottleneck transform以捕获通道相关性。
  • (c) 用于特征融合的广播元素加法。

四、实验

4.1 COCO上的对象检测/分割

评价指标:average-precision分数(AP)。

backbones:Mask R-CNN,FPN或ResNet/ResNeXt。

实验细节:所有模型使用Synchronized SGD进行12个epoch的训练,学习速率初始化为0.02。

(1)消融实验

分别在ResNet/ResNeXt的c4位置的最后一个剩余块之前插入添加一个SE模块、SNL模块、GC模块。

由上表(a)显示,SNL和GC在参数更少、计算更少的情况下都可以达到与NL相当的性能,这表明原来的non-local设计在计算和参数方面存在冗余。

上表(f)列出了池化和融合的不同选择,表明在融合阶段加法比缩放更有效,集中注意力只比普通集中效果好一点点。

(2)backbones增强实验

在效果更佳的backbones上评估我们的GCNet,方法是用ResNet-101和ResNeXt-101替换ResNet-50,向多个层(c3+c4+c5)添加GC模块,并采用级联策略 。

值得注意的是,即使采用了更强的backbones,与基线相比,GCNet的收益仍然很大,这表明GC模块与GC框架是对当前模型能力的补充。

4.2 ImageNet图像分类

与在CoCo数据集上的实验设计,分别在ResNet/ResNeXt的c4位置添加一个SE模块、SNL模块、GC模块,此外在c3+c4+c5位置插入GC模块。

 表a报告了不同块的结果。GC块的性能略优于NL块和SNL块,参数少,计算量少,表明了模块设计的通用性和泛化能力。

五、结论

non-local networks作为研究远程依赖的先驱工作,打算建模特定于查询的全局上下文,但只建模与查询无关的上下文。在此基础上,作者对non-local networks进行了简化,并将简化后的模型抽象为全局上下文建模框架。然后作者提出了一个新的实例化框架,GC模块,它是轻量级的,可以有效地捕获远程依赖。CNet是通过将GC块应用到多个层来构建的,它通常在各种识别任务的主要基准上优于简化的NLNet和SENet。

六、代码实现

 GCNet/mmdet/ops/gcb/context_block.py

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
 
import torch
from torch import nn
 
from ..utils import constant_init, kaiming_init
from .registry import PLUGIN_LAYERS
 
 
def last_zero_init(m: Union[nn.Module, nn.Sequential]) -> None:
    if isinstance(m, nn.Sequential):
        constant_init(m[-1], val=0)
    else:
        constant_init(m, val=0)
 
 
@PLUGIN_LAYERS.register_module()
class ContextBlock(nn.Module):
    """ContextBlock module in GCNet.
    See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
    (https://arxiv.org/abs/1904.11492) for details.
    Args:
        in_channels (int): Channels of the input feature map.
        ratio (float): Ratio of channels of transform bottleneck
        pooling_type (str): Pooling method for context modeling.
            Options are 'att' and 'avg', stand for attention pooling and
            average pooling respectively. Default: 'att'.
        fusion_types (Sequence[str]): Fusion method for feature fusion,
            Options are 'channels_add', 'channel_mul', stand for channelwise
            addition and multiplication respectively. Default: ('channel_add',)
    """
 
    _abbr_ = 'context_block'
 
    def __init__(self,
                 in_channels: int,
                 ratio: float,
                 pooling_type: str = 'att',
                 fusion_types: tuple = ('channel_add', )):
        super().__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        valid_fusion_types = ['channel_add', 'channel_mul']
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'
        self.in_channels = in_channels
        self.ratio = ratio
        self.planes = int(in_channels * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
        else:
            self.channel_mul_conv = None
        self.reset_parameters()
 
    def reset_parameters(self):
        if self.pooling_type == 'att':
            kaiming_init(self.conv_mask, mode='fan_in')
            self.conv_mask.inited = True
 
        if self.channel_add_conv is not None:
            last_zero_init(self.channel_add_conv)
        if self.channel_mul_conv is not None:
            last_zero_init(self.channel_mul_conv)
 
    def spatial_pool(self, x: torch.Tensor) -> torch.Tensor:
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)
 
        return context
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [N, C, 1, 1]
        context = self.spatial_pool(x)
 
        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1]
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term
 
        return out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/940190.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【LeetCode】227. 基本计算器 II

227. 基本计算器 II(中等) 方法:双栈解法 思路 我们可以使用两个栈 nums 和 ops 。 nums : 存放所有的数字ops :存放所有的数字以外的操作 然后从前往后做,对遍历到的字符做分情况讨论: 空格 …

爬虫逆向实战(二十三)--某准网数据

一、数据接口分析 主页地址:某准网 1、抓包 通过抓包可以发现数据接口是api_to/search/company_v2.json 2、判断是否有加密参数 请求参数是否加密? 通过查看“载荷”模块可以发现b参数和kiv参数是加密参数 请求头是否加密? 无响应是否加…

torch.mul()函数使用说明,含高维张量实例及运行结果

torch.mul函数使用说明,含实例及运行结果 torch.mul() 函数torch.mul() 函数定义参数及功能高维数据实例解释 参考博文及感谢 torch.mul() 函数 对输入的张量或数做点积运算,如果维度不统一会想进行维度统一(广播机制)&#xff0…

PConv : Run, Don’t Walk: Chasing Higher FLOPS for Faster Neural Networks

摘要 为了设计快速的神经网络,**许多研究都集中在减少浮点运算(FLOPs)**的数量。然而,我们观察到这种FLOPs的减少并不一定会导致相同程度的延迟减少。这主要是由于浮点运算每秒效率较低的问题所致。为了实现更快的网络,我们重新审视了流行的操作算子,并证明这种低FLOPS主…

docker高级(DockerFile解析)

1、构建三步骤 编写Dockerfile文件 docker build命令构建镜像 docker run依镜像运行容器实例 2、DockerFile构建过程解析 Dockerfile内容基础知识 1:每条保留字指令都必须为大写字母且后面要跟随至少一个参数 2:指令按照从上到下,顺序执行…

异地访问Oracle数据库的解决方案:利用内网穿透实现PL/SQL远程连接的建议与步骤

文章目录 前言1. 数据库搭建2. 内网穿透2.1 安装cpolar内网穿透2.2 创建隧道映射 3. 公网远程访问4. 配置固定TCP端口地址4.1 保留一个固定的公网TCP端口地址4.2 配置固定公网TCP端口地址4.3 测试使用固定TCP端口地址远程Oracle ​ 小月糖糖主页 在强者的眼中,没有最…

SolidWorks软件安装包分享(附安装教程)

目录 一、软件简介 二、软件下载 一、软件简介 SolidWorks是一款由达索系统(Dassault Systmes)开发的三维计算机辅助设计(CAD)软件,被广泛应用于机械、电子、建筑和航空航天等领域。它以易学易用、强大的功能和良好的…

Michael.W基于Foundry精读Openzeppelin第32期——SignatureChecker.sol

Michael.W基于Foundry精读Openzeppelin第32期——SignatureChecker.sol 0. 版本0.1 SignatureChecker.sol 1. 目标合约2. 代码精读2.1 isValidSignatureNow(address signer, bytes32 hash, bytes memory signature) 0. 版本 [openzeppelin]:v4.8.3,[for…

华为OD机试 - 云短信平台优惠活动 - 回溯(Java 2023 B卷 200分)

目录 专栏导读一、题目描述二、输入描述三、输出描述四、解题思路五、Java算法源码六、效果展示1、输入2、输出3、说明 华为OD机试 2023B卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《华为OD机试(JAVA)真题(A卷B卷&#…

教师如何有效地发放开学通知并收集签名回执?

老师在即将开学时,希望能够向家长发送开学通知,并确认家长已经收到通知。接下来教给各位老师如何完成这个需求的步骤: 好消息!博主给大家争取到的易查分福利,只需要在注册时输入邀请码:xmt66,就…

从AD迁移至AAD,看体外诊断领军企业如何用网络准入方案提升内网安全基线

摘要: 某医用电子跨国集团中国分支机构在由AD向AzureAD Global迁移时,创新使用宁盾网络准入,串联起上海、北京、无锡等国内多个职场与海外总部,实现平滑、稳定、全程无感知的无密码认证入网体验,并通过合规基线检查,确…

CAD的清除命令如何使用?CAD的清除命令使用方法

CAD广泛应用于土木建筑、装饰装潢、城市规划、园林设计、电子电路、机械设计、服装鞋帽、航空航天、轻工化工等诸多领域,因此CAD越来越成为一项基本技能,很多用人岗位都会要求会使用CAD,为帮助更多人快速学会CAD,而且CAD的使用本身…

taobao.trade.fullinfo.get(获取单笔交易的详细信息)天猫国际站店铺订单接口方法

淘宝交易API taobao.trade.fullinfo.get(获取单笔交易的详细信息) 获取单笔交易的详细信息 1. 只有单笔订单的情况下Trade数据结构中才包含商品相关的信息 2. 获取到的Order中的payment字段在单笔子订单时包含物流费用,多笔子订单时不包含物…

STM32--SPI通信与W25Q64(2)

STM32–SPI通信与W25Q64(1) 文章目录 SPI外设特征 SPI框图传输模式主模式全双工连续传输 非连续传输硬件SPI读写W25Q64 SPI外设 STM32内部集成了硬件SPI收发电路,可以由硬件自动执行时钟生成、数据收发等功能,减轻CPU的负担。 特…

leetcode 496. 下一个更大元素 I

2023.8.28 这题提供暴力解法和单调栈法两种方法。 暴力解&#xff1a; class Solution { public:vector<int> nextGreaterElement(vector<int>& nums1, vector<int>& nums2) {vector<int> ans(nums1.size(),-1);for(int i0; i<nums1.size…

『FastGithub』一款.Net开源的稳定可靠Github加速神器,轻松解决GitHub访问难题

&#x1f4e3;读完这篇文章里你能收获到 如何使用FastGithub解决Github无法访问问题了解FastGithub的工作原理 文章目录 一、前言二、项目介绍三、访问加速原理四、FastGithub安装1. 项目下载2. 解压双击运行3. 运行效果4. GitHub访问效果 一、前言 作为开发者&#xff0c;会…

springMVC之拦截器

文章目录 前言一、拦截器的配置二、拦截器的三个抽象方法三、多个拦截器的执行顺序总结 前言 拦截器 一、拦截器的配置 SpringMVC中的拦截器用于拦截控制器方法的执行 SpringMVC中的拦截器需要实现HandlerInterceptor SpringMVC的拦截器必须在SpringMVC的配置文件中进行配置&…

数字化智能工厂信息化系统集成整合规划建设方案[150页word]

导读&#xff1a;原文《150页6万字数字化智能工厂信息化系统集成整合规划建设方案》&#xff08;获取来源见文尾&#xff09;&#xff0c;本文精选其中精华及架构部分&#xff0c;逻辑清晰、内容完整&#xff0c;为快速形成售前方案提供参考。 数字化智能工厂建设方案 设备智…

大数据学习:Hive常用函数

Hive常用函数 1. Hive的参数传递 1.1 Hive命令行 查看hive命令的参数 [hadoopnode03 ~]$ hive -help语法结构: hive [-hiveconf xy]* [<-i filename>]* [<-f filename>|<-e query-string>][-S] 说明&#xff1a; -i 从文件初始化HQL。-e从命令行执行指定…

LVDS 2-port RGB 转 MIPI参数计算

有一些显示器是只给了屏幕的参数&#xff0c;屏幕输入的参数不一定&#xff0c;可能是输出的MIPI 给显示器&#xff0c;显示内部转换后是LVDS RGB&#xff0c;因此需要转换。 屏幕参数 转换为MIPI参数