Nas-FPN(CVPR 2019)原理与代码解析

news2024/9/28 7:20:36

paper:NAS-FPN: Learning Scalable Feature Pyramid Architecture for Object Detection

third-party implementation:https://github.com/open-mmlab/mmdetection/tree/main/configs/nas_fpn

本文的创新点

本文采用神经网络结构搜索(Neural Architecture Search, NAS),在一个覆盖所有跨尺度连接的新型可扩展搜索空间中发现了一个新的特征金字塔结构,NAS-FPN。与原始FPN相比,NAS-FPN显著提高了目标检测的性能,并取得了更好了速度-精度的平衡。

方法介绍

考虑到其简单而高效的结构,目标检测模型采用RetinaNet,如图2所示

作者提出了merging cell作为FPN的basic building block,将任何两层的输入特征融合为一层的输出特征。如图3所示

其中Binary Op包括两种候选方案,如图4所示 

最终搜索到的NAS-FPN的完整结构如图6所示 

图7展示的搜索到的所有结构,其中(a)是原始FPN结构,(b)-(f)的精度逐渐变高,(f)是最终的NAS-FPN结构。

因为是搜索到的结构,并且图示非常清晰,这里就不过多介绍具体结构了。接下来结合代码和图(6)(7)的结构介绍一下实现细节 

代码解析

这里以mmdetection中的实现为例,实现代码在mmdet/models/necks/nas_fpn.py中,下面是完整的forward函数。其中self.fpn_stages=7是nas-fpn重复的次数,每个nas-fpn的输出是下一个nas-fpn的输入。forward最开始的输入是backbone的输出C2~C5,这里只取C3~C5通过lateral_conv得到P3~P5,然后进行下采样得到P6和P7,完整的P3~P7作为第一个nas-fpn的输入。

def forward(self, inputs: Tuple[Tensor]) -> tuple:
    # [(8,256,160,160),
    #  (8,512,80,80),
    #  (8,1024,40,40),
    #  (8,2048,20,20)]
    """Forward function.

     Args:
        inputs (tuple[Tensor]): Features from the upstream network, each
            is a 4D-tensor.

    Returns:
        tuple: Feature maps, each is a 4D-tensor.
    """
    # build P3-P5
    feats = [
        lateral_conv(inputs[i + self.start_level])
        for i, lateral_conv in enumerate(self.lateral_convs)
    ]  # [(8,256,80,80),(8,256,40,40),(8,256,20,20)]
    # build P6-P7 on top of P5
    for downsample in self.extra_downsamples:
        feats.append(downsample(feats[-1]))  # [..., (8,256,10,10),(8,256,5,5)]

    p3, p4, p5, p6, p7 = feats

    for stage in self.fpn_stages:
        # gp(p6, p4) -> p4_1
        # print(stage['gp_64_4'])
        p4_1 = stage['gp_64_4'](p6, p4, out_size=p4.shape[-2:])  # (8,256,40,40)
        # sum(p4_1, p4) -> p4_2
        p4_2 = stage['sum_44_4'](p4_1, p4, out_size=p4.shape[-2:])  # (8,256,40,40)
        # sum(p4_2, p3) -> p3_out
        p3 = stage['sum_43_3'](p4_2, p3, out_size=p3.shape[-2:])  # (8,256,80,80)
        # sum(p3_out, p4_2) -> p4_out
        p4 = stage['sum_34_4'](p3, p4_2, out_size=p4.shape[-2:])  # (8,256,40,40)
        # sum(p5, gp(p4_out, p3_out)) -> p5_out
        p5_tmp = stage['gp_43_5'](p4, p3, out_size=p5.shape[-2:])  # (8,256,20,20)
        p5 = stage['sum_55_5'](p5, p5_tmp, out_size=p5.shape[-2:])  # (8,256,20,20)
        # sum(p7, gp(p5_out, p4_2)) -> p7_out
        p7_tmp = stage['gp_54_7'](p5, p4_2, out_size=p7.shape[-2:])  # (8,256,5,5)
        p7 = stage['sum_77_7'](p7, p7_tmp, out_size=p7.shape[-2:])  # (8,256,5,5)
        # gp(p7_out, p5_out) -> p6_out
        p6 = stage['gp_75_6'](p7, p5, out_size=p6.shape[-2:])  # (8,256,10,10)

    return p3, p4, p5, p6, p7

在for循环中,从上到到下分别对应图6中从左到右按顺序所有的GP和Sum。其中GP对应GlobalPoolingCell,Sum对应SumCell,具体实现都在MMCV中。

GlobalPoolingCell的实现如下,其中self.input1_conv和self.input2_conv是空的,self._resize通过双线性插值进行上采样,通过max pooling进行下采样。

def forward(self,
            x1: torch.Tensor,
            x2: torch.Tensor,
            out_size: Optional[tuple] = None) -> torch.Tensor:
    assert x1.shape[:2] == x2.shape[:2]
    assert out_size is None or len(out_size) == 2
    if out_size is None:  # resize to larger one
        out_size = max(x1.size()[2:], x2.size()[2:])

    x1 = self.input1_conv(x1)
    x2 = self.input2_conv(x2)

    x1 = self._resize(x1, out_size)
    x2 = self._resize(x2, out_size)

    x = self._binary_op(x1, x2)
    if self.with_out_conv:
        x = self.out_conv(x)
    return x

self._binary_op的实现如下,其中self.global_pool是全局平均池化。

def _binary_op(self, x1, x2):
    x2_att = self.global_pool(x2).sigmoid()
    return x2 + x2_att * x1

最后的self.out_conv是Conv+BN+ReLU的组合,对应图6中的R-C-B。注意图6中只有第一个GP和最后一个GP后有R-C-B,中间两个GP后没有,即上面代码中self.with_out_conv=False。

SumCell和GlobalPoolingCell继承自同一个基类,forward函数是一样的。区别在于SumCell中的self._binary_op就是sum操作,如下

class SumCell(BaseMergeCell):
    def __init__(self, in_channels: int, out_channels: int, **kwargs):
        super().__init__(in_channels, out_channels, **kwargs)

    def _binary_op(self, x1, x2):
        return x1 + x2

此外,5个SumCell后都有R-C-B。

实验结果

和其他SOTA模型的对比如表1所示

其中7@256表示NAS-FPN堆叠7次,通道数为256。

文中特别提到由于NAS-FPN的结构堆叠多层引入了更多的参数,需要一个合适的正则化方法来防止过拟合。本文采用DropBlock,具体介绍见DropBlock(NeurIPS 2018)论文与代码解析-CSDN博客。图10展示了DropBlock显著提升了NAS-FPN的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1407997.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

bash 5.2中文修订4

Compound Commands 复合命令 复合命令是 shell 编程语言的结构。每个构造都以保留字或控制运算符开始,并以相应的保留字或运算符终止。与复合命令关联的任何重定向(请参阅 Redirections )都适用于该复合命令中的所有命令,除非显式…

高质量简历模板网站,免费、免费、免费

你们在制作简历时,是不是基本只关注两件事:简历模板,还有基本信息的填写。 当你再次坐下来更新你的简历时,可能会发现自己不自觉地选择了那个“看起来最好看的模板”,填写基本信息,却没有深入思考如何使简历…

搜维尔科技:【简报】元宇宙数字人赛道,《莉思菱娜》

个性有些古灵精怪时儿安静时而吵闹,虽然以人类寿命来算已经200多岁但在 吸血鬼中还只是个小毛头,从中学开始喜欢打扮偏爱黑白灰色系的服装喜欢时 尚圈,立志想成为美妆或时尚网红不过目前还是学生,脸上的浅色血迹是纹身 贴纸&#…

Spark读取kafka(流式和批数据)

spark读取kafka(批数据处理) # 按照偏移量读取kafka数据 from pyspark.sql import SparkSessionss SparkSession.builder.getOrCreate()# spark读取kafka options {# 写kafka配置信息# 指定kafka的连接的broker服务节点信息kafka.bootstrap.servers: n…

《幻兽帕鲁》被指AI缝合,开发过程疑点重重,最后附游戏安装教程

由日本游戏工作室Pocketpair开发的《Palworld / 幻兽帕鲁》毫无疑问成为了2024年的首个巨热游戏!上周五(2024年1月19日)游戏上线抢先体验,仅在3天内销量就已突破400万!并于2024年1月21日创下了1291967名同时在线玩家的…

[ACM学习] 树形dp之换根

算法概述 总的来说: 题目描述:一棵树求哪一个节点为根时,XXX最大或最小 分为两步:1. 树形dp 2. 第二次dfs 问题引入 如果暴力就是 O(n^2) , 当从1到2的时候,2及其子树所有的深度都减一,其它…

手把手教你快速掌握连接远程git仓库or赋值远程仓库到本地并上传代码到gitee

1. 先去官网安装Git ,这里不多赘述网上教程很多 2.1去gitee注册一个账号,然后去我的新建一个仓库,这里是演示一下新手第一次操作的流程 2.2设置仓库名称完成创建(这里的库名随便输入看自己): 2.3 打开git bash 配置用户名&#x…

Kubernetes-Taint (污点)和 Toleration(容忍)

目录 一、Taint(污点) 1.污点的组成 2.污点的设置、查看和去除 3.污点实验: 二、Toleration(容忍) 1.容忍设置的方案 2.容忍实验: Taint 和 toleration 相互配合,可以用来避免 pod 被分配…

VUE3好看的我的家乡网站模板源码

文章目录 1.设计来源1.1 首页界面1.2 旅游导航界面1.3 上海景点界面1.4 上海美食界面1.5 上海故事界面1.6 联系我们界面1.7 在线留言界面 2.效果和结构2.1 动态效果2.2 代码结构 源码下载 作者:xcLeigh 文章地址:https://blog.csdn.net/weixin_43151418/…

虹科方案丨湿热灭菌工艺验证解决方案,确保所有产品和容器达到无菌要求

来源:虹科环境监测技术 虹科方案丨湿热灭菌工艺验证解决方案,确保所有产品和容器达到无菌要求 原文链接:https://mp.weixin.qq.com/s/O-pKQdehB9mHSETpU8egbA 欢迎关注虹科,为您提供最新资讯! #蒸汽灭菌 #高压灭菌 …

小程序直播系统源码_报价与开发_OctShop

近几年,随着直播的火热,人们对于直播带货是相当的熟悉了,逐渐渗透到各行各业中,小程序直播可以实时的更全面的传递商品信息,同时还可以与主播进行互动,可以通过直播聚集的人气打造团购气氛,通过…

LSTM时间序列预测

本文借鉴了数学建模清风老师的课件与思路,可以点击查看链接查看清风老师视频讲解:【1】演示:基于LSTM深度学习网络预测时间序列(MATLAB工具箱)_哔哩哔哩_bilibili % Forecast of time series based on LSTM deep learn…

win 下使用 cmd 运行 jar 包

1、使用 Win R 输入 cmd 命令打开命令提示符 2、在 cmd 窗口中输入以下命令 java -jar xxxxxx.jar 运行 jar 包,控制台出现中文乱码 原因是 windows 默认使用 GBK 编码格式,程序使用 UTF-8 编码格式 将编码格式改为 UTF-8 编码,在 cmd 窗…

C#中IsNullOrEmpty和IsNullOrWhiteSpace的区别?

前言 今天我们一起来探讨C#中两个常用的字符串处理方法:IsNullOrEmpty和IsNullOrWhiteSpace。这两个方法在处理字符串时非常常见,但是它们之间存在一些细微的区别。在本文中,我们将详细解释这两个方法的功能和使用场景,并帮助您更…

Qt Quick程序的发布|Qt5中QML和Qt Quick 的更改

# Quick程序的发布旧版做法 # Qt5中QML和Qt Quick 的更改 1.QML语言的更改(Qt4->Qt5) 在QML语言中,只有少量更改会影响QML代码的迁移:无法直接导入单独的文件(例如:import"MyType.qml”),需要导人该文件所在的目录; JavaScript文件中的相对路径被解析…

webassembly003 whisper.cpp的python绑定实现+Cython+Setuptools

python绑定项目 官方未提供python的封装绑定,直接调用执行文件 https://github.com/stlukey/whispercpp.py提供了源码和Cpython结合的绑定 https://github.com/zhujun1980/whispercpp_py提供了ctype方式的绑定,需要先make libwhisper.so Pybind11 bi…

你真的会数据结构吗:顺序表

❀❀❀ 文章由不准备秃的大伟原创 ❀❀❀ ♪♪♪ 若有转载,请联系博主哦~ ♪♪♪ ❤❤❤ 致力学好编程的宝藏博主,代码兴国!❤❤❤ 又和大家见面啦!在大家看到这个标题的时候其实就已经发现了:我们的C语言的基础知识大…

Shell脚本的if条件语句

目录 1.单分支结构 2.双分支结构 3.多分支结构 4.例题 1.单分支结构 实际上使用“&&”和“||”逻辑测试已经可以完成简单的判断并执行相应的操作,但是当需要选择执行的命令语句较多时,这种方式将使执行代码显得很复杂,不好理解。…

gdzwfw某省公共资源交易平台逆向学习

声明:本文中网站仅为学习技术使用,请勿暴力爬取数据。 学习地址:aHR0cHM6Ly95Z3AuZ2R6d2Z3Lmdvdi5jbi8jLzQ0L2p5Z2c 此网站采用请求头反爬,难点是请求头中几个参数是如何生成的(别问为什么知道是请求头,一…

学单片机前先学什么?

学单片机前先学什么? 在开始前我有一些资料,是我根据网友给的问题精心整理了一份「单片机的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿共享给大家!!&#xff…