MAE(Masked Autoencoders) 详解

news2024/12/27 13:11:14

MAE详解

  • 0. 引言
  • 1. 网络结构
  • 1.1 Mask 策略
  • 1.2 Encoder
  • 1.3 Decoder
  • 2. 关键问题解答
    • 2.1 进行分类任务怎么来做?
    • 2.2 非对称的编码器和解码器机制的介绍
    • 2.3 损失函数是怎么计算的?
    • 2.4 bert把mask放在编码端,为什么MAE加在解码端?
  • 3. 总结

0. 引言

masked autoencoders (MAE) 是用于CV的自监督学习方法,优点是扩展性强的(scalable),方法简单。在MAE方法中会随机mask输入图片的部分patches,然后重构这些缺失的像素。MAE基于两个核心设计:(1)不对称的(asymmetric)编码解码结构,编码器仅仅对可见的patches进行编码,不对mask tokens进行任何处理,解码器将编码器的输出(latent representation)和mask tokens作为输入,重构image;(2)使用较高的mask比例(如75%)。MAE展现了很强的迁移性能,在ImageNet-1K上取得了best accuracy(87.8%),且因为方法简单,可扩展性极强(scalable)。
下图展示了MAEImageNet验证集上的重建结果。对于每个三元组,左边的图像是被遮挡的,中间的图像是MAE重建的,右边的图像是实际的。其中掩蔽率为80%,即在196个patch中只剩下39个对模型可见。可以看出,经过MAE还原后的图像可以大致将原始图像还原出来。
在这里插入图片描述

论文名称:Masked Autoencoders Are Scalable Vision Learners
论文地址:https://arxiv.org/abs/2111.06377
代码地址:https://github.com/facebookresearch/mae

1. 网络结构

MAE 模型整体网络结构如下所示。包含一个encoder模块和一个decoder模块。
首先,输入图像被按照patch_size分割成patch集合。然后,patch集合中的一个大的随机子集mask,没有被maskpatch会被输入encoder模型得到编码补丁。随后,编码补丁masked token(被mask的部分,其中每个masked token都是共享的可被学习的向量)被合并输入decoder。经decoder得到还原后的图案。
在这里插入图片描述

1.1 Mask 策略

首先,沿袭 ViT 的做法,将图像分成一块块(ViT 中是 16x16 大小)不重叠的 patch,然后使用服从均匀分布(uniform distribution)的采样策略对这些 patches 随机采样一部分,同时 mask 掉余下的另一部分。被 mask 掉的 patches 占所有 patches 的大部分(实验效果发现最好的比例是 75%),它们不会输入到 Encoder。

OK,策略很简单,那么这样做有什么好处呢?

首先,patch 在图像中是服从均匀分布来采样的,这样能够避免潜在的“中心归纳偏好”(也就是避免 patch 的位置大多都分布在靠近图像中心的区域);其次,采用高掩码比例(mask 掉图中大部分 patches)能够防止模型轻易地根据邻近的可见 patches 推断(原文是 extrapolation,外推,这词有点高级…)出这些掩码块;最后,这种策略还造就了稀疏的编码器输入,因为 Encoder 只处理可见的 patches,于是能够以更低的代价训练较大规模的 Encoder,因为计算量和内存占用都减少了。

虽然 mask 策略好像挺简单的,但却是至关重要的一个部分,因为其决定了预训练代理任务是否具有足够的挑战性,从而影响着 Encoder 学到的潜在特征表示 以及 Decoder 重建效果的质量。

1.2 Encoder

记住最重要的一点,Encoder 仅处理可见(un-masked)的 patches。Encoder 本身可以是 ViTResNet(其它 backbone 也 ok,就等你去实现了,大神给了你机会),至于如何将图像划分成 patch 嘛,使用 ViT 时的套路是这样的:

作者首先将图片数据 X ∈ R H × W × C X\in R^{H\times W \times C} XRH×W×C 按照 patch_size 进行切分并进行一维展平,得到数据 X ∈ R N × ( P 2 × C ) X\in R^{N\times (P^2\times C)} XRN×(P2×C) 。其中, P P P 表示 patch_size N N N 表示图片被切分为多少块,即 N = H × W P 2 N=\frac{H\times W}{P^2} N=P2H×W 。然后,这批数据经过线性变换后与原始图像的位置编码进行合并(并在首部添加类别编码 class embedding)。

由于 un-masked patches 占所有 patches 的少数,计算消耗和空间需求都减少了,因此可以训练很大的 Encoder

1.3 Decoder

Decoder 不仅需要处理经过 Encoder 编码的 un-masked 的 tokens,还需要处理 masked tokens。但请注意,masked token 并非由之前 mask 掉的 patch 经过 embedding 转换而来,而是可学习的。所有 masked patches 都共享的1个向量,对,仅仅就是1个!

那么你会问:这样如何区分各个 masked patch 所对应的 token 呢?

别忘了,我们还有 position embedding 嘛!如同在 Encoder 中的套路一样,这里对于 masked token 也需要加入位置信息。position emebdding 是每个 masked patch 对应1个,shape 是 ( N ′ , d i m ) (N',dim) (N,dim),其中 N ′ N' N 是 masked patch 的数量。但 masked token 只有1个怎么办是不是?简单粗暴——“复制”多份即可,使得每个 masked patch 都对应1个 masked token,这样就可以和 position embedding 进行相加了。

另外,Decoder 仅仅是在预训练任务为了重建图像而存在,而我们的下游任务形式多种多样,因此实际应用时很可能没 Decoder 什么事了。所以,Decoder 的设计和 Encoder 是解耦的,Decoder 可以设计得简单、轻量一些(比 Encoder 更窄、更浅。窄:对应通道数;浅:对应深度),毕竟主要学习潜在特征表示的是 Encoder

这样,尽管 Decoder 要处理的 token 很多(全量token,而 Encoder 仅处理 un-masked 的部分),但其本身轻量,所以还是能够高效计算。再结合 Encoder 虽然本身结构重载(相对 Decoder 来说),但其处理的 token 较少,这样,整体架构就十分 efficient 了!

2. 关键问题解答

2.1 进行分类任务怎么来做?

看起来 MAE 是一个图像还原的项目,那么如何使用它来做图像分类任务呢?
虽然 MAE 整体结构是图像还原项目,但是也可以用来做图像分类。MAE 采用先预训练然后再微调的方法得到分类模型。具体操作步骤如下:

  1. 首先,使用MAE模型进行训练来得到预训练好的模型。
  2. 然后,将Encoder部分提取出来。
  3. 最后,在后面加上全连接层进行分类。

整体而言:使用预训练模型得到一个可以提取“完整”特征的Encoder模型,然后在后面加上线性层进行分类。

2.2 非对称的编码器和解码器机制的介绍

  1. 非对称是说编码器看到的和解码器看到的东西是不一样的,这里编码器只看到那些可见的块,解码器拿到编码器的输出之后,就去重构那些被遮挡住的块
  2. 为什么使用这些非对称的架构,因为大量的块都被遮住了,这样的话编码器只用看可见的那些块,可以极大地减轻计算的开销,也可以使得内存更小一点

2.3 损失函数是怎么计算的?

MAE 预训练任务的目标是重建像素值,并且仅仅是 masked patches 的像素值,也就是仅对 mask 掉的部分计算 loss,而 loss 就是很大众的 MSE。为何仅计算 mask 部分的 loss?实验结果发现这样做模型的性能会更好,而如果对所有 patches 都计算 loss 的话会掉点。
那么模型是如何去预测 masked patches 的像素值并计算 loss 的呢?具体来说,就是:

在 Decoder 解码后的所有 tokens 中取出 masked tokens(在最开始 mask 掉 patches 的时候可以先记录下这些 masked 部分的索引),将这些 masked tokens 送入全连接层,将输出通道映射到1个 patch 的像素数量(PxPxC),也就是输出的 shape 是:(B,N’,PxPxC),其中的每个值就代表预测的像素值。最后,以之前 mask 掉的 patches 的像素值作为 target,与预测结果计算 MSE loss。

另外,作者提到使用归一化的像素值作为 target 效果更好,能够提升学到的表征的质量。这里的归一化做法是:计算每个 patch 像素值的均值与标准差,然后用均值与标准差去归一化对应的 patch 像素。

代码如下所示:

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove,  mask记录了哪些patch被mask
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

2.4 bert把mask放在编码端,为什么MAE加在解码端?

bert在预训练中输入到encoder的里面有mask,但是在进行下游任务微调时没有mask,这样会使预训练和下游任务的微调存在一个gap,因为输入不一致会导致最终输出效果有影响,bert为了消除这个影响会对15%的词汇有8:11的比例,只有8份是真正mask,这样就缩小了两者的gap——bert是在缩小这个差距,MAE是在试图消除这个影响——让预训练和下游任务微调保持一致
MAE在decoder中加入了mask,是因为在下游任务只使用了encoder,所以在预训练和下游任务都不会出现mask——但是!在预训练时MAE看到的是25%patch,在下游任务看到的是100%patch,其实引入了另外一种gap。

3. 总结

MAE的算法还是非常简单的,就是利用vit来做和BERT一样的自监督学习,vit已经做了类似的事情了,但是本文在此基础之上提出了两点

  • 第一点是需要盖住更多的块,使得剩下的那些块,块与块之间的冗余度没有那么高,这样整个任务就变得复杂一点
  • 第二个是使用一个transformer架构的解码器,直接还原原始的像素信息,使得整个流程更加简单一点
  • 第三个是加上vit工作之后的各种技术,使得它的训练更加鲁棒一点

以上三点加起来,使得MAE能够在ImageNet-1k数据集上使用自监督训练的效果超过了之前的工作。
如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/599968.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

chatgpt赋能python:Python做词云:从入门到精通

Python做词云:从入门到精通 如果你对数据可视化有所追求,那么词云图一定会是你的首选之一。Python作为一种常用的编程语言,在这个领域也有着广泛的应用。本文将介绍Python做词云的方法,从入门到精通,让你轻松掌握这一…

伍尔特IT公司利用SNP软件实现SAP系统现代化

近日,SAP生态伙伴中的自动化数据迁移和数据管理软件提供商SNP公司,正在支持伍尔特IT公司(Wrth IT) 向SAP S/4HANA转型。伍尔特集团的全球IT服务提供商计划首先整合其现有的ECC系统,然后进行标准化,最后将清…

数据结构与算法课程设计---最小生成树的应用

文章目录 一.课题概述1.问题2.分析3.目标 二.图的实现1.图的存储结构2.图的基本操作2.1添加顶点2.2添加边弧2.3Kruskal算法2.4Prim算法 三.堆的实现1.堆的概念及结构2.堆的基本操作2.1入堆(向上调整算法)2.2出堆(向下调整算法) 四…

【PWN · ret2libc】[CISCN 2019东北]PWN2

虽然最近的ret2libc的做题基本一致(毕竟类型都是ret2libc嘛),但是对于本蒟蒻现阶段来说,还是有必要记录一下的 前言 持续巩固ret2libc的做题范式/基本套路能力,同时也发现,reverse与pwn密不可分的联系。 一…

chatgpt赋能python:Python做表格的优势及应用

Python做表格的优势及应用 在数据处理与可视化的领域,表格是最常见的形式之一,也是经常被用来展示数据的有效方式。Python作为一种流行的编程语言,在数据处理方面有着强大的功能,同时也提供了许多生成表格的库与工具。本文将会介…

Zotero的安装与数据同步

一、Zotero的下载与安装 对于需要通过大量阅读期刊论文的学生而言如何提高阅读的效率以及论文管理能力是及其重要的,这里我推荐科研萌新们从Zotero入手,因为Zotero相对于Endnote、NoteExpress这类付费文献管理工具(大多数的高校都购买了这类软…

python web开发(三)—— CSS样式

文章目录 概要1.快速了解2.使用方式3. CSS选择器4. 多个属性类联合使用 样式1. 高度和宽度2. 块级和行内标签3. 字体设置4. 文字对齐方式5. 浮动6. 内边距7.外边距8. 内容居中9.body标签10. hover(伪类)11. 设置透明度12. after(伪类)13. position14. 边框border15. 背景色back…

SSH服务详解

1 SSH服务 1.1 SSH服务协议 SSH 是 Secure Shell Protocol 的简写,由 IETF 网络工作小组(Network Working Group )制定;在进行数据传输之前,SSH先对联机数据包通过加密技术进行加密处理,加密后在进行数据传输。确保…

机器学习集成学习——Adaboost分离器算法

系列文章目录 机器学习之SVM分类器介绍——核函数、SVM分类器的使用 机器学习的一些常见算法介绍【线性回归,岭回归,套索回归,弹性网络】 机器学习相关概念思维导图 文章目录 系列文章目录 前言 Adaboost算法的简单介绍 Adaboost算法相…

如何将Chrome浏览器重置为默认设置?

如何将Chrome浏览器重置为默认设置? 将 Chrome 设置重置为默认设置 您可随时在 Chrome 中恢复您的浏览器设置。如果所安装的应用或扩展程序在您不知情的情况下更改了设置,那么您可能需要这样做。不过,您保存的书签和密码不会被清除或更改。 …

数据库 期末复习(4) 概念数据库的设计

参考资料 :邹老师数据库课件 程老师数据库课件 战老师数据库课件 第一部分 为啥要引入概念数据库 感觉只有一个重点 实体联系模型----ER模型 第二部分-----实体联系模型 这个例子可以全看完之后再来看 举个例子:根据COMPANY数据库的需求来构造数据库模式:The com…

工业控制系统的设备如何加密防勒索病毒

场景描述 信息化时代发展迅速,数据防泄露一词也频繁的出现在我们身边。无论企业或政府单位,无纸化办公场景越来越多,数据泄露的时间也层出不穷。例如:世界最大职业中介网站Monster遭到黑客大规模攻击,黑客窃取在网站注…

Flume的安装和使用

安装Flume 1.1访问Flume的官网(http://flume.apache.org/download.html),下载Flume安装apache-flume-1.9.0-bin.tar.gz。或者下载我的百度网盘资源。把安装文件解压缩到windows操作“D:\”目录下,然后执行如下命令测试是否安装成…

JavaEE Servlet的API详解

Servlet的API详解O(∩_∩)O~: 文章目录 JavaEE & Servlet的API详解1. HttpServlet抽象类1.1 init方法1.2 destroy方法1.3 service方法 2. HttpRequest接口2.1 在浏览器上显示请求首行2.2 在浏览器上显示请求header2.3 getParameter方法 - 最常用的API之一2.4 js…

【MAC】nvm安装和使用

傻瓜式使用教程如下,不用担心443 和 mac的文件夹权限问题 ! 1.将nvm包clone下来并克隆到nvm 文件夹中 打开终端后执行: git clone https://gitee.com/mirrors/nvm.git ~/.nvm2.激活nvm sudo source ~/.nvm/nvm.sh接着就可以通过nvm ls命令…

2023/6/1总结

学习CSS 动画: 2023-05-31 21-48-43-504 效果图: 2023-06-01 13-58-26-168 3D转换 3D移动: transform:translateX() 在x轴移动 transform:translateY() 在y轴移动 transform:translateZ() 在z轴移动 transform:translate3d(x,y,z); …

程序设计综合实习(C语言):链表的创建

一、目的 1.掌握单向链表的概念 2.掌握单向链表的创建、查找、删除方法 二、实习环境 Visual Stdio 2022 三、实习内容、步骤与要求 1.创建一个单向链表,存放10个学生的学号,姓名,并输出这种10个学生的信…

分布式锁框架-Redisson

分布式锁框架-Redisson 一、Redisson介绍二、在SpringBoot中使用Redisson三、Redisson工作原理四、Redisson使用扩展4.1、Redisson单机连接4.2、Redisson集群连接4.3、Redisson主从连接 五、分布式锁总结5.1、分布式锁特点5.2、锁的分类5.3、Redission的使用 基于Redis看门狗机…

chatgpt赋能python:Python以图搜图:如何用Python优化SEO?

Python以图搜图:如何用Python优化SEO? 随着搜索引擎算法的普及,优化您的SEO策略需要更多的创意和技巧。一种方法是使用Python以图搜图,具有该技能可以使您的网站上升到搜索结果列表的顶部。在这篇文章中,我们将探讨Py…

在外部编译器中使用pyqgis

pyqgis_dragonzoebai的博客-CSDN博客 升级后整理 例如在vscode当中添加qgis提供的python解释器,那么就可以使用qgis.core等库 批量处理gdb文件夹,导出对应文件夹目录的geojson文件。 我的gdb文件均没有坐标系,因此需要自己设置正确的坐标系…