【论文解读】ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoder

news2024/10/7 4:22:27

1. 本文贡献

提出了一个全卷积掩码的自动编码器框架和一个新的全局响应归一化(GRN)层

1.1 想法

本文的想法是希望能在 ConvNeXt  中使用MAE,但是MAE的设计架构是基于vision transformer的,与使用密集滑动窗口的标准ConvNets不兼容,因此作者的建议是在同一框架下共同设计网络架构和掩蔽自动编码器

1.1.1 操作梗要

将屏蔽输入视为一组稀疏补丁,并使用稀疏卷积仅处理可见部分。在实践中,我们可以用稀疏卷积实现ConvNeXt,在微调时,权重被转换回标准的密集层,而不需要特殊处理。

当直接在掩蔽输入上训练ConvNeXt时,我们发现了MLP层的特征崩溃的潜在问题。为了解决这个问题,我们建议添加一个全局响应规范化层来增强信道间的特征竞争

2. FCMA( Fully Convolutional Masked Autoencoder)

2.1 掩码

2.1.1 MAE的掩码

在正式进入ConvNeXt V2的掩码设计之前,我觉得有必要先看一下MAE的掩码是怎么实现的。

(1)掩码

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))#保留量
        
        #拉一条同样长度的噪声,在大噪声处上掩码
        noise = torch.rand(N, L, device=x.device)  
        
        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1) #升序排列
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
        #从原tensor中获取指定dim和指定index的数据

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore
           #x_masked:图像上掩码的仍然保留的数据
           #mask:在原始图像中的掩码,0:保留,1:掩码
           #ids_restore:noise从打乱ids_shuffle到恢复的序号

 参考下面这段代码就比较直观的看出两个argsort在实际上是在获得噪音和原始坐标从小到大的序号

import torch
len_keep=2#保留两个
x=torch.rand((1,4,1))#对应 N,L,D
print("x",x)
noise = torch.rand(1, 4)
print("noise",noise)
ids_shuffle = torch.argsort(noise, dim=1)#noise从小到大的序号
print("ids_shuffle",ids_shuffle)
ids_restore = torch.argsort(ids_shuffle, dim=1)#noise从打乱ids_shuffle到恢复的序号
print("ids_restore",ids_restore)
N, L, D = x.shape

# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
print("ids_keep",ids_keep)
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
print("x_masked",x_masked)
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
print("mask",mask)
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
print("mask",mask)#在原始图像中的掩码位置

(2)掩码的使用

在encoder中patch_embed之后

2.1.2 ConvNeXt V2的掩码

由于卷积模型具有分层设计,其中在不同阶段对特征进行下采样,因此在最后阶段生成掩码,并递归地上采样,直到达到最佳分辨率。

(1)掩码

    def gen_random_mask(self, x, mask_ratio):
        N = x.shape[0]
        L = (x.shape[2] // self.patch_size) ** 2
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.randn(N, L, device=x.device)

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # generate the binary mask: 0 is keep 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return mask

基本上也是大差不差

(2)掩码的使用:编码器encoder设计

两个挑战:

  1. 防止模型学习允许其从掩蔽区域复制和粘贴信息的shortcut。
  2. 保留2D图像结构

在MAE中图片被拉成长条(N, L, D),因此使用Mask其实是非常得心应手的,但是在 ConvNeXt V2 中始终是一个四维(N,C,H,W)的,即保留了2d的图像结构

解决办法:

稀疏卷积纳入框架中,以促进掩蔽自动编码器的预训练

具体的代码直接参考SparseConvNeXtV2

class Block(nn.Module):
    """ Sparse ConvNeXtV2 Block. 

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, dim, drop_path=0., D=3):
        super().__init__()
        self.dwconv = MinkowskiDepthwiseConvolution(dim, kernel_size=7, bias=True, dimension=D)
        self.norm = MinkowskiLayerNorm(dim, 1e-6)
        self.pwconv1 = MinkowskiLinear(dim, 4 * dim)   
        self.act = MinkowskiGELU()
        self.pwconv2 = MinkowskiLinear(4 * dim, dim)
        self.grn = MinkowskiGRN(4  * dim)
        self.drop_path = MinkowskiDropPath(drop_path)
    
    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = input + self.drop_path(x)
        return x

2.2  解码器 和 损失函数 

解码器:使用一个轻量级的、普通的ConvNeXt块作为解码器

损失函数 :MSE重建损失

3.全局响应归一化

3.1 引入全局响应归一化的原因:“特征崩溃”现象

“特征崩溃”现象:有许多停滞或饱和的特征图,并且激活在通道之间变得多余。

3.2 全局响应归一化

在这项工作中,我们引入了一种新的响应归一化层,称为全局响应归一化(GRN),旨在提高通道的对比度和选择性。所提出的GRN单元由三个步骤组成:

1)全局特征聚合

全局函数G(·)将空间特征地图 x_i 聚合为向量gx,实验结果是使用L2范数效果最好

2)特征归一化

 

3)特征校准

使用计算的特征归一化分数来校准原始输入响应

 

3.3 实现

 没有使用稀疏卷积的实现:

class GRN(nn.Module):
    """ GRN (Global Response Normalization) layer
    """
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/412787.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

upload 通关pass16-pass20

1.pass16 白名单 二次渲染 需要先上传一个正常图片,然后下载下来,跟原图片进行比对,用010 16进制编辑器,把php代码放到没有改变的位置,即一样的地方 访问: 2.pass17 白名单 条件竞争 这题先是上传文件并…

音质蓝牙耳机哪款好用?2023公认音质好的四款蓝牙耳机推荐

现如今,蓝牙耳机越来越受欢迎,不少人在听歌、追剧、甚至是玩游戏的时候都会戴着它。最近看到很多人问,音质蓝牙耳机哪款好用?针对这个问题,我来给大家推荐四款公认音质好的蓝牙耳机,一起来看看吧。 一、南…

Nginx connect req access 模块

Nginx connect req access 模块演练 limig_conn模块:限制TCP连接数limit_req模块:限制请求频率access 模块(allow/deny):限制ip段访问auth_request: 基于HTTP响应状态码做权限控制压测可以使用 postman的 run collect…

AndroidStudio如何进行手机应用开发?

文章目录0、引言1、AndroidStudio开发环境配置2、创建第一个手机应用0、引言 Android手机应用因其搭载于手机,使用便捷,应用被大量开发使用。笔者使用手机多年,用过许多手机软件,在使用的过程中,虽然手机软件能解决大部…

C++开发必知的内存问题及常用的解决方法-经典文章

1. 内存管理功能问题 由于C语言对内存有主动控制权,内存使用灵活和效率高,但代价是不小心使用就会导致以下内存错误: • memory overrun:写内存越界 • double free:同一块内存释放两次 • use after free&#xff1…

【数据结构】二叉树顺序结构及实现

🚀write in front🚀 📜所属专栏:初阶数据结构 🛰️博客主页:睿睿的博客主页 🛰️代码仓库:🎉VS2022_C语言仓库 🎡您的点赞、关注、收藏、评论,是对…

Surfshark下载到使用完整教程|2023最新

2023年3月16日更新 在正式介绍surfshark的教程( 教程直达学习地址: qptool.net/shark.html )之前,我们可以来看看最近surfshark的服务与产品退化到什么程度了。我曾经是Surshark两年的忠实用户,但是,现在,作为一个负责人的测评&a…

PostMan动态参数及循环调用

最近需要在测试环境批量创建es索引,也就是某个接口需要循环调用且参数还是变化的,但是我又不想写代码和脚本,于是研究了一下postman一些好玩的功能,希望能节约大家的开发时间 一.设置请求参数 1.获取创建索引的请求以及参数&…

ELK+Filebeat日志分析系统

目录 一.ELK基本介绍 1.ELK是什么? 2.组件简介 2.1 ELK组件介绍 2.2 ELFK组件介绍 2.3 其它组件 4.使用ELK的原因 5.完整日志系统的基本特征 二.Elasticsearch的介绍 三.Logstash的介绍 四.Kibana的介绍 五.ELK的工作原理 六.部署ELK日志分析系统 1.环…

0基础学习软件测试有哪些建议

其实现在基础的资料和视频到处都是,就是看你有没有认真的去找学习资源了,去哪里学习都是要看你个人靠谱不靠谱,再好的教程和老师,你自己学习不进去也是白搭在正式选择之前,大可以在各种学习网站里面找找学习资源先自己…

springboot+vue动物园管理系统java

本系统使用的角色主要有系统管理员、注册用户,本系统分为系统前台和系统后台,首先在系统前台,游客用户可以经过账号注册,管理员审核通过后,用账号密码登录系统前台,查看论坛交流、动物展览、原生动物展览、…

HTML5 <head> 标签、HTML5 <i> 标签

HTML5 <head> 标签 实例 HTML5 <head> 标签表示文档的头部&#xff0c;其中包含了与该文档有关的信息&#xff01; 一份在头部带有 <title> 标签的 HTML 文档&#xff1a; <!DOCTYPE html> <html> <head> <meta charset"utf-8&…

Linux信号sigaction / signal

Linux信号sigaction / signal 文章目录Linux信号sigaction / signal目的函数原型struct sigaction信号枚举值ISO C99 signals.Historical signals specified by POSIX.New(er) POSIX signals (1003.1-2008, 1003.1-2013).Nonstandard signals found in all modern POSIX system…

虹科教您 | 基于Linux系统的RELY-TSN-KIT套件操作指南(1)——硬件设备与操作环境搭建

RELY-TSN-KIT是一款针对TSN的开箱即用的解决方案&#xff0c;它可以无缝实施确定性以太网网络&#xff0c;并从这些技术复杂性中抽象出用户设备和应用。该套件可评估基于IEEE 802.1AS同步的时间常识的重要性&#xff0c;并借助时间感知整形器来确定性地交付实时流量&#xff0c…

判断完全二叉树(层序遍历)| C

层序遍历 基本思路&#xff1a;利用队列&#xff0c;出上一层&#xff0c;带下一层&#xff08;NULL不入队列&#xff09; &#xff08;C语言需要自己构建队列→【队列】&#xff1c;用链表实现队列&#xff1e; | [数据结构] | C语言&#xff09; 代码 #include "Queu…

代码自动发布系统

之前是jenkins发现gitlab代码更新了就自动获取直接部署到服务器 现在是jenkins自动获取Code之后打包成镜像上传到仓库然后通知docker去拉取更新的镜像 分析 旧∶ 代码发布环境提前准备&#xff0c;以主机为颗粒度静态 新: 代码发布环境多套&#xff0c;以容器为颗粒度编译 …

Typora设置修改字体颜色快捷键

目录 1.typora如何设置修改字体颜色快捷键 2. AutoHotKey软件安装 3.typora关于AutoHotKey的具体操作 1.typora如何设置修改字体颜色快捷键 typora本身是不能直接修改字体颜色的&#xff0c;不过若是想修改还是可以用一些代码去改变的&#xff0c;但是每次都修改一次实在麻烦…

mysql常用的基础命令

通过学习mysql命令提高数据处理和工作效率 基础命令 1.登录MySQL mysql -u root -p 2.查看当前系统所有数据库 show databases; 3.切换数据库 use 数据库名称 4.查看数据库下的所有表 show tables; 5.查看表结构&#xff1b; desc 表名&#xff1b; 6.创建数据库 crea…

MAC OS(M1)安装配置miniconda

一、下载安装miniconda miniconde官网&#xff1a;Miniconda — Conda documentation M1最低只能适配到python3.8 打开终端,进入安装包所在文件夹&#xff0c;使用命令进行安装 bash Miniconda3-latest-MacOSX-arm64.sh一路回车 二、配置环境 安装完成后重启终端&#xf…

Unity ads广告插件的使用

介绍 Unity Ads SDK 由领先的移动游戏引擎创建,无论您在 Unity、Xcode 还是 Android Studio 中进行开发,都能为您的游戏提供全面的货币化框架。 使用 Unity Ads 将各种广告格式合并到游戏中的自然呈现点中。例如,您可以实施激励视频广告来构建更强大的游戏经济,同时为您的…