文章目录
- 一、诞生背景
- 1.1 自监督学习的趋势
- 2.2 ViT 的出现
- 二、模型
- 2.1 模型架构
- 2.1.1 数据shape变化
- 2.1.2 模型架构流程图
- 2.1.3 PyTorch 代码示例(核心部分)
- 2.2 位置信息
- 2.3 非对称的编码器-解码器结构
- 2.4图片重构
- 三、实验
- 3.1 主实验
- 3.2 消融实验
- 3.3 总结
- 四、总结
论文 📝 MAE:Masked Autoencoders Are Scalable Vision Learners
Masked Autoencoders Are Scalable Vision Learners (He et al., 2021)。
一、诞生背景
1.1 自监督学习的趋势
在自然语言处理(NLP)中,BERT 使用了“掩码语言建模(MLM)”策略,通过掩盖输入序列的一部分并预测它们,从海量文本中学习通用语言表示。这种方法不依赖人工标注,大幅推动了 NLP 的发展。
而在视觉领域:
- 早期主流是 对比学习(如 MoCo, SimCLR),强调“相似-不相似”之间的对比。
- 但对比学习存在:
- 对 batch size 敏感;
- 训练难度高;
- 依赖复杂的负样本设计。
Yann LeCun(图灵奖得主、Meta 首席 AI 科学家)的一个非常经典的演讲或访谈中:
“如果机器学习是一块蛋糕,那么:
- 无监督学习是这块蛋糕的主体;
- 监督学习只是蛋糕上的奶油;
- 强化学习也许只是蛋糕上的樱桃;
- 而像 GAN 这样花哨的生成模型,是那几颗漂亮的水果装饰。”
—— Yann LeCun
蛋糕体 = 无监督学习(Unsupervised/Self-supervised)
- 最大份量:海量的数据(图像、文本、视频等)是未标注的,无监督学习能利用它们;
- 最核心:掌握“世界的结构”是智能的关键,不靠标签,仅靠模式本身;
- 未来趋势:LeCun 强调世界模型(world model)的学习是实现通用智能的基础,而这必须靠无监督/自监督。
奶油 = 有监督学习(Supervised)
- 很有效但需要人工标签;
- 数据依赖重:每个任务都要手工标数据,很昂贵;
- 不是 scalable 的方向:尤其在医学、遥感、视频等领域,标签严重不足。
樱桃 = 强化学习(RL)
- 在很多 AI 系统中是点缀;
- 很难训练,需求结构明确的环境;
- 在实际任务中用得不多,但“显得很酷”。
水果 = GAN 等生成模型
- GAN、Diffusion 之类生成模型非常吸睛;
- 能做图像生成、艺术创作、深度伪造等;
- 但更像是展示能力而非“认知核心”。
2.2 ViT 的出现
Vision Transformer (ViT) 将图像表示为 patch tokens,借助全局自注意力学习图像上下文,为图像应用 NLP 的技术打开了大门。
因此,自然的想法是:是否可以像 BERT 一样在图像领域做掩码建模?
二、模型
2.1 模型架构
2.1.1 数据shape变化
如同ViT一样做patch:
(B, 196, 768)
Mask 操作(通常 mask 掉 75%)
→ 随机选择 25% 的 patch 作为输入
→ kept_patches.shape = (B, 49, 768) # 196 × 0.25 = 49
编码器输出
encoder_out.shape = (B, 49, D) # D 是维度,通常仍是 768
解码器输入(插入 learnable mask tokens)
→ 加入 147 个 mask token,拼接回去变成:
decoder_in.shape = (B, 196, D)
解码器输出:重建所有 patch
decoder_out.shape = (B, 196, patch_dim) # patch_dim = 768
然后计算 MSE loss,仅在被 mask 的 75% 上:
loss = MSE(decoder_out[masked], original_patch[masked])
2.1.2 模型架构流程图
┌────────────────┐
│ Input: x │
│(B, 3, 224,224) │
└─────┬──────────┘
↓
┌────────────────────┐
│ Split into patches │
│→ (B, 196, 768) │
└─────┬──────────────┘
↓
┌────────────────────────────┐
│ Random Masking (75%) │
│ → keep 25%: (B, 49, 768) │
└───────┬────────────────────┘
↓
┌───────────────────────────┐
│ Encoder: Transformer │
│ Input: (B, 49, 768) │
│ Output: (B, 49, 768) │
└───────┬───────────────────┘
↓
┌────────────────────────────────────┐
│ Decoder Input = Encoder output │
│ + learnable mask tokens (147 个) │
│ → (B, 196, 768) │
└──────┬─────────────────────────────┘
↓
┌────────────────────────────┐
│ Decoder: Transformer │
│ Output: (B, 196, 768) │
└──────┬─────────────────────┘
↓
┌─────────────────────────────────────────────┐
│ Prediction = reconstruct patch pixel values │
│ 仅在被 mask 的 patch 上计算 MSE loss │
└─────────────────────────────────────────────┘
2.1.3 PyTorch 代码示例(核心部分)
class MAE(nn.Module):
def __init__(self, encoder, decoder, mask_ratio=0.75):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.mask_ratio = mask_ratio
self.patch_embed = nn.Linear(16*16*3, 768)
self.mask_token = nn.Parameter(torch.zeros(1, 1, 768))
def forward(self, x):
B, C, H, W = x.shape
x = self.patchify(x) # (B, 196, 768)
# === masking ===
x, mask, ids_restore = self.random_masking(x, self.mask_ratio)
latent = self.encoder(x) # (B, 49, 768)
# === decode ===
full_tokens = self.restore_with_mask(latent, ids_restore) # (B, 196, 768)
x_rec = self.decoder(full_tokens) # reconstruct pixels
return x_rec, mask
def patchify(self, imgs):
# divide into patches and flatten
p = 16
B, C, H, W = imgs.shape
patches = imgs.reshape(B, C, H//p, p, W//p, p)
patches = patches.permute(0,2,4,3,5,1).reshape(B, -1, p*p*C)
return self.patch_embed(patches)
def random_masking(self, x, mask_ratio):
B, L, D = x.shape
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(B, L)
ids_shuffle = torch.argsort(noise, dim=1)
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D))
# 返回还原顺序 index,mask 向量等
return x_masked, ..., ...
def restore_with_mask(self, latent, ids_restore):
B, N, D = latent.shape
mask_tokens = self.mask_token.repeat(B, ids_restore.shape[1] - N, 1)
x_ = torch.cat([latent, mask_tokens], dim=1)
x_full = torch.gather(x_, 1, ids_restore.unsqueeze(-1).expand(-1, -1, D))
return x_full
2.2 位置信息
Transformer 是 序列模型,本身不具备位置感知能力。而图像是二维的,有强烈的空间结构。如果 Transformer 不知道某个 patch 来自图像的哪一块,它根本没法理解图像结构。
MAE 如何保留位置信息?
Patch embedding 之前或之后(但一定是在mask之前),加入位置编码
- 对每个 patch 位置
i
∈
[
0
,
N
)
i \in [0, N)
i∈[0,N)
生成一个 learnable 向量 pos i ∈ R D \text{pos}_i \in \mathbb{R}^D posi∈RD
将它加到 patch embedding 上:
x i ′ = patch i + pos i x_i' = \text{patch}_i + \text{pos}_i xi′=patchi+posi - 编码器(Encoder)只看部分 patch,但这些patch 带着它们的位置信息进入 encoder,所以 encoder 能知道“我正在看的是图的哪一部分”
注意:虽然 patch 是被乱序选择的,但位置编码是 绝对位置编码(absolute positional embedding),可以标记每个 token 的二维坐标(如 第 5 行第 2 列)
解码器如何知道 token 顺序?
MAE 中的设计:ids_restore 机制
- Mask 之前:每个 patch 是有顺序的,从 0 到 195
- Mask 后:encoder 只接收其中一部分,例如 patch 7, 31, 80…
- decoder 需要恢复原来的顺序,包括被遮挡的 patch,所以用一个 index 向量
ids_restore
:
x_ = torch.cat([visible_tokens, mask_tokens], dim=1)
x_full = torch.gather(x_, dim=1, index=ids_restore)
这样 decoder 的输入顺序就和原图一致了,可以:
- 按照顺序进行解码
- 正确叠加位置编码(decoder 通常也有自己的 pos embedding)
2.3 非对称的编码器-解码器结构
在 MAE 中,所谓“不对称”,主要体现在两个方面:
组件 | Encoder | Decoder |
---|---|---|
输入 | 仅可见 patch(25%) | 全部 patch(含 mask token) |
计算量 | 少得多(仅处理 49/196) | 多,但结构浅(轻量) |
网络结构 | 深 + 大(ViT-B/L) | 浅 + 小(3~4 层小 ViT) |
输出 | 表征(用于下游任务) | 重建像素 |
Encoder 要小输入 + 大容量
-
输入小(只看 25% patch)
- 节省计算:ViT 的注意力复杂度是 O(N^2),patch 多了算力爆炸;
- 只处理可见 patch(~49 个 token),减少 75% 的负担。
-
模型大(ViT-B、ViT-L、ViT-H)
- 虽然只看了图的一角,但模型有足够的 capacity 来理解上下文;
- 学到的是高质量的图像表征,能迁移到分类、检测、分割等任务。
Decoder 要大输入 + 小容量
-
输入大(加上了 mask token,还原全图)
- decoder 要 reconstruct 被遮挡的 patch,因此需要“填补回全部 196 个位置”。
-
模型小(浅层 Transformer)
- 目标是像素级重建,不是深层语义理解;
- decoder 只是帮 encoder 完成训练目标(像素恢复),不是核心;
- 如果 decoder 太强,反而容易“帮 encoder 猜答案”,encoder 就没学到好的表征(信息泄露问题);
优点如下:
优点 | 原因 |
---|---|
高效 | encoder 只处理 1/4 的 patch,计算量 ↓3~4 倍 |
表征纯净 | encoder 没被迫 reconstruct,focus on语义学习 |
分工明确 | encoder 学表征,decoder 做辅助重建 |
易扩展 | 可以用 ViT-B/L/H 做 encoder,decoder 保持轻量 |
2.4图片重构
MAE在这一点上做的十分简单
Decoder Output:
(B, 196, 768)
↓ reshape
(B, 14, 14, 16, 16, 3)
↓ 合并 patch
(B, 224, 224, 3)
↓ permute
(B, 3, 224, 224) ← final reconstructed image
MAE 的 decoder 输出每个 patch 的像素值(flatten 后),通过 unpatchify 把它 reshape 回图像结构,最终拼成整张图,用于训练中的重建误差计算。
三、实验
3.1 主实验
实验任务 | 模型设置 | 数据集 | 结果指标 | 性能 | 结论 |
---|---|---|---|---|---|
图像分类 | ViT-B / ViT-L / ViT-H + MAE | ImageNet-1K | Top-1 Accuracy | 83.6% / 85.9% / 86.9% | MAE 可训练出大规模 ViT,性能超 ResNet |
目标检测 | ViT-L + MAE + Mask R-CNN | COCO | box mAP | 50.9 | MAE 预训练可迁移至目标检测 |
语义分割 | ViT-L + MAE + UPerNet | ADE20K | mIoU | 55.4 | MAE 表征对分割也有提升 |
3.2 消融实验
问题 | 对比设置 | 最佳设置 | 结论 |
---|---|---|---|
最佳 mask 比例? | 0%、50%、75%、90% | 75% | 足够稀疏才能激发全局建模,太多则信息不足 |
Decoder 深度影响? | 1、4、8 层 | 4 层 | 太浅重建不准,太深信息泄露 |
Decoder 宽度(dim)? | 512、768、1024 | 768 | 与 encoder 一致效果最佳,decoder 无需变大 |
预训练时长? | 400、800、1600 epoch | 1600 epoch | 长训练时间带来稳定性能增益 |
loss 应该计算哪里? | 仅 masked、全部 patch | 仅 masked | 避免 decoder 影响 visible 部分,提升表征纯度 |
目标类型(预测什么)? | 原始像素、CNN 中间特征、视觉 token | RGB 像素 | 最简单的重建目标效果最佳 |
Decoder 架构复杂度? | 不同 decoder 宽度 / 层数 | 轻量、浅层 | decoder 太强会抢 encoder 的任务 |
Patch 尺寸(未系统化,但有提及) | 16×16 vs 其他 | 16×16 | 平衡了细粒度与序列长度,训练稳定 |
3.3 总结
原则 | 解释 |
---|---|
轻 decoder,重 encoder | encoder 专注学语义,decoder 只是辅助拼图 |
高 mask 比例是关键 | 掩码越多,模型越要“理解图”而不是“记忆图” |
不要复杂预测目标 | 原始像素 MSE 最简单,效果也最稳定 |
长训练周期 + 大模型 | 大 ViT + 足够训练时间是成功前提 |
仅 mask 区域监督 | 避免 visible 部分 loss 污染 encoder 学习 |
四、总结
总的来说,MAE 是一个特别“干净利落”的自监督学习方法。它的核心思路就是:把图像拆成小块(patch),随机遮掉大部分,然后让模型用剩下的一点点信息去“脑补”整张图。整个过程不依赖任何标签,完全靠模型自己学。
MAE 之所以成功,是因为它做到了三个字:简单、有用、高效。它把 encoder 和 decoder 分工明确——encoder 专注提取语义特征,decoder 只是个工具人,帮忙还原像素。这样不仅训练快、计算省,而且学出来的表示还能很好地迁移到下游任务,比如分类、检测、分割等等。
论文里也做了很多消融实验,验证了比如“75% 掩码比例最合适”、“decoder 不要太深”、“loss 只在被遮挡的地方算”等等这些设计选择确实有用。可以说,这些细节上的坚持,才成就了 MAE 的整体效果。
更重要的是,MAE 的成功也带动了后续一大批基于遮挡重建的视觉自监督方法,像 SimMIM、CAE、MaskFeat 等都受到了它的启发。
所以如果用一句话总结 MAE,那就是:
它用最简单的方式,把图像看懂了。