5、MAE:探索视觉预训练模型

news2025/1/9 17:49:18

目录

1、论文

2、背景与动机

3、回答的问题

4、创新与卖点

5、实现细节

模型框架

具体步骤

简单代码示例

6、一些资料


1、论文

Masked Autoencoders Are Scalable Vision Learnersicon-default.png?t=N7T8https://arxiv.org/pdf/2111.06377.pdf

2、背景与动机

        在深度学习和计算机视觉的领域中,预训练模型已经成为了提高下游任务性能的重要手段。传统上,许多预训练模型如ResNet、VGG等都是在大规模数据集(如ImageNet)上通过监督学习训练得到的。然而,监督学习需要大量的标记数据,这在成本和可扩展性上都是一个不小的挑战。

        最近,自监督学习作为一个新兴研究领域,提供了一种无需手工标注数据的解决方案。自监督学习的一个关键点是设计预测任务,通过这些任务模型可以从输入数据本身学习到有用的表示。在自然语言处理(NLP)领域,BERT通过掩码语言模型(MLM)任务表现出色,这激发了计算机视觉领域对类似方法的探索。

        MAE (Masked Autoencoder) 正是从这样的背景和动机出发,它将自监督学习中的掩码预测任务引入到视觉领域,致力于从图像数据中以无监督的方式学习高效的特征表示。

3、回答的问题

        论文中回答了一个问题。为什么自监督在CV领域的发展要滞后于NLP呢?论文中给了两个解释:

(1)NLP主流方法是Transformer,视觉里CNN是主流方法,结构差异让视觉很难构造类似于“masked autoencoding”的任务。但是ViT的提出解决了这个问题;

(2)语言和视觉的信息密度(information density)差异巨大,前者是强语义的,高信息密度的(highly semantic and information-dense),在NLP中即使只mask一个token,对模型来说可能都是很难的任务,因此模型可以通过学习获得复杂的语言理解能力(sophisticated language understanding),但是对视觉图像来说,信息是高度冗余的,缺失一个patch,可能并不会让模型产生多少困惑,模型可以通过周围的像素信息进行推断

        所以MAE做的一件事就是mask很高比例的patches,制造高难度的学习任务,方法简单但是极其有效

4、创新与卖点

MAE 的核心创新在于其独特的自监督预训练方法。不同于之前的自监督视觉模型通常需要对比学习或复杂的数据增强,MAE 提出了一种简洁高效的方法:

  1. Masking 策略,并且mask比例非常高:MAE 对输入图像进行随机遮蔽,只露出一小部分像素,模型的任务是预测被遮蔽部分的原始像素。这种策略减少了模型需要处理的数据量,同时迫使模型学习丰富的上下文信息来重建图像。

  2. 编码器-解码器架构:MAE 采用了一个不对称的编码器-解码器架构,其中编码器只对未被遮蔽的部分进行处理,大幅减少了计算量。解码器则负责图像的重建工作,它的结构相对简单,因为其主要任务是理解编码器提供的特征。

  3. 预训练与微调:MAE 的预训练阶段不依赖于标签,这使得模型可以在非常大的数据集上进行训练。一旦预训练完成,MAE 可以通过微调在各种下游任务上实现优异的性能,包括分类、检测和分割等。

5、实现细节

模型框架

具体步骤

  1. 数据遮掩:首先,在输入图像或序列数据中随机选择一定比例的区域进行遮掩,将其替换为特定的遮掩标记(如0或[MASK])。

  2. 编码阶段:编码器实际上就是ViT,将input image切分为不重叠的patches之后,执行linear projection,再加上positional embeddings (the sine-cosine version) ,然后送入transformer blocks。

  3. 解码器:同样使用ViT,将mask tokens和encoded visible patches作为输入,加上位置编码 (the sine-cosine version) 。decoder的最后一层是linear projection,输出通道数量和一个patch内的pixel数量相同(方便重构),然后再reshape,重构image。损失函数使用MSE,损失函数只对masked patches计算(和BERT相同)。同时作者也尝试了normalization的方式,即计算一个patch内像素值的均值和标准差,然后对patch执行normalization,此时encoder的重构任务发生了一些变化,需要重构normalized pixel values,实验表明这种方式效果更好一点

            MAE中decoder的设计并不重要,因为预训练结束之后,只保留encoder,decoder只需要完成预训练时的图像重构任务。但是作者也表示decoder决定了latent representations的语义级别

  4. 损失函数:使用L1或L2距离作为损失函数,衡量预测的像素值或词向量与原始未遮掩数据之间的差异。

  5. 预训练与微调:经过大规模无标签数据上的预训练后,可以将模型参数迁移到特定的下游任务中进行微调,进一步提升任务性能。

简单代码示例

import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    # 用于添加位置信息的模块,通常在Transformer结构中使用
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class Encoder(nn.Module):
    def __init__(self, embed_dim, num_layers, num_heads, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=int(embed_dim * mlp_ratio),
                                      dropout=drop_rate, attention_dropout=attn_drop_rate, bias_qkv=qkv_bias)
            for _ in range(num_layers)])

    def forward(self, src, mask=None):
        output = src
        for layer in self.layers:
            output = layer(output, src_key_padding_mask=mask)
        return output

class MaskedAutoencoder(nn.Module):
    def __init__(self, image_size, patch_size, num_channels, embed_dim, num_layers, num_heads, mlp_ratio, num_classes):
        super(MaskedAutoencoder, self).__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_patches = (image_size // patch_size) ** 2
        self.encoder = nn.Sequential(
            nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size),
            nn.LayerNorm(embed_dim),
        )
        self.pos_embed = PositionalEncoding(embed_dim)
        self.transformer_encoder = Encoder(embed_dim, num_layers, num_heads, mlp_ratio)
        self.decoder = nn.Sequential(          # 这里做简单了,实际解码器还是用的vit
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, num_channels * patch_size ** 2),
            nn.PixelShuffle(patch_size),
        )
        self.to_patch_embedding = nn.Sequential(
            nn.Unflatten(dim=1, unflattened_size=(num_patches, embed_dim)),
            nn.Dropout(p=0.1),
        )

    def forward(self, x, mask_ratio=0.75):
        B, C, H, W = x.shape
        assert H == W, "Input image must be square"
        x = self.encoder(x)
        x = self.pos_embed(x)
        
        # 随机掩码
        rand_mask = torch.rand(B, self.num_patches, 1, 1, device=x.device) < mask_ratio
        masked_x = x.clone()
        masked_x[rand_mask] = 0.

        # 编码
        encoded_patches = self.transformer_encoder(self.to_patch_embedding(masked_x))

        # 解码
        reconstructed_image = self.decoder(encoded_patches)

        return reconstructed_image

# 初始化模型
model = MaskedAutoencoder(image_size=224, patch_size=16, num_channels=3, embed_dim=768, num_layers=12, num_heads=12, mlp_ratio=4., num_classes=0)

# 假设我们有输入数据x
x = torch.randn((10, 3, 224, 224))

# 计算重构后的图像
reconstruction = model(x)

6、一些资料

MAE(Masked Autoencoders) - 知乎简介MAE(Masked Autoencoders)是用于CV的自监督学习方法,优点是扩展性强的(scalable),方法简单。在MAE方法中会随机mask输入图片的部分patches,然后重构这些缺失的像素。MAE基于两个核心设计:(1)不对称的(…icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/446761025

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1377048.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【实用技巧】Steam Wallpaper Engine 壁纸引擎向手机导入壁纸方法

一、内容简介 本文介绍如何使用电脑上的 Wallpaper Engine &#xff08;Steam 平台中的壁纸引擎&#xff09;向安卓手机导入并使用壁纸。 二、所需原材料 安卓手机&#xff08;以笔者使用的华为荣耀50为例&#xff09;、安装有Steam以及Wallpaper Engine的电脑 三、导入方法…

美摄视频SDK,卓越的视频解决方案

视频已经成为企业传播信息、展示品牌形象的重要工具。然而&#xff0c;高质量的视频制作并不容易&#xff0c;需要专业的技术和设备支持。这就是我们的美摄科技视频SDK发挥作用的地方。作为一家专注于视频技术开发的公司&#xff0c;我们的目标是为企业提供最优质的视频解决方案…

缓存学习实战篇

缓存练习题&#xff08;用户查询操作&#xff09; public List<ShopType> queryAllType() throws JsonProcessingException {//从缓存中查数据String shopTypeJson stringRedisTemplate.opsForValue().get("cache:shopType");//如果缓存命中&#xff0c;if (S…

Linux内存管理:(八)页面迁移

文章说明&#xff1a; Linux内核版本&#xff1a;5.0 架构&#xff1a;ARM64 参考资料及图片来源&#xff1a;《奔跑吧Linux内核》 Linux 5.0内核源码注释仓库地址&#xff1a; zhangzihengya/LinuxSourceCode_v5.0_study (github.com) 1. 可迁移页面 页面迁移机制支持两…

VUE+bpmn.js实现工作流

1、安装bpmn.js npm install bpmn-js7.3.1 // 我安装的版本是7.3.1npm install bpmn-js-properties-panel0.37.2npm install bpmn-moddle7.1.3 npm install --save camunda-bpmn-moddle 2、配置axios&#xff0c;在main.js中引入axios import axios from axiosVue.proto…

Dcoker构建部署Java项目过程

目录 前言 一、打包 二、Docker File文件编写 一个简单的Docker File文件 三、上传文件 四、构建镜像 五、运行 六、端口开放 前言 使用Dcoker构建部署Java项目&#xff0c;发布到服务器 一、打包 我这里打包的是item-service这个module&#xff0c;clean-cpmpile-pa…

【深度学习每日小知识】Logistic Loss 逻辑回归

逻辑回归的损失函数 线性回归的损失函数是平方损失。逻辑回归的损失函数是对数损失&#xff0c;定义如下&#xff1a; L o g L o s s ∑ ( x , y ) ∈ D − y log ⁡ ( y ′ ) − ( 1 − y ) log ⁡ ( 1 − y ′ ) LogLoss\sum_{(x,y)\in D}-y\log(y)-(1-y)\log(1-y) LogLoss…

测试老鸟汇总,接口测试总结与用例编写,一文策底概全...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 1、为什么要做接口…

Java SE入门及基础(9)

if选择结构 1. 基本if选择结构 语法 if ( 条件 ){ // 如果条件满足&#xff0c;则执行代码块 //代码块 } 案例 从控制台输入一个整数&#xff0c;如果该数字小于 10 &#xff0c;则输出 10 与该数字的差值。 流程图 代码实现 public class Example1 { public s…

Dart 空感知操作符:??

示例 写了如下代码&#xff1a; var str1 "hello"; var str2 "world"; var result str1 ?? str2.toUpperCase(); //如果str1不为空&#xff0c;则执行后面的语句 print(result); 代码可以正常执行&#xff0c;但是报了如下错误&#xff1a; Warnin…

【网络安全】Nessus部署自动更新和端口权限开放

文章目录 Nessus 自动更新配置Nessus服务端口开放Nessus profession 版本需要开放端口Sensor ProxyTenable Security Center (TSC)Tenable OT Security (TOT)Tenable OT Security Enterprise Manager (IEM)Tenable OT Security Industrial Core Platform (ICP)Tenable OT Secur…

kafka除了作为消息队列还能做什么?

Kafka 最初是为大规模处理日志而构建的。它可以保留消息直到过期&#xff0c;并让各个消费者按照自己的节奏提取消息。 与其之前的竞品不同&#xff0c;Kafka 不仅仅是一个消息队列&#xff0c;它还是一个适用于各种情况的开源事件流平台。 让我们回顾一下流行的 Kafka 用例。 …

基于集成学习算法XGBoost农作物产量可视化分析预测系统

文章目录 基于集成学习算法XGBoost农作物产量可视化分析预测系统一、项目简介二、开发环境三、项目技术四、功能结构五、功能实现模型构建封装类用于网格调参训练模型系统可视化数据请求接口模型评分 0.5*mse 六、系统实现七、总结 基于集成学习算法XGBoost农作物产量可视化分析…

中间人攻击如何进行防护

中间人攻击&#xff08;Man-in-the-Middle Attack&#xff0c;简称 MITM 攻击&#xff09;是一种常见的网络攻击方式&#xff0c;攻击者通过截获两个通信实体之间的通信数据&#xff0c;并在此基础上进行篡改、窃取或伪造等恶意行为。这种攻击方式因其攻击手段的隐蔽性和难以防…

2024 年1月12日最热NLP大模型论文:Transformers are Multi-State RNNs

揭秘Transformer的无限可能&#xff0c;Meta研究发现Transformer其实是多状态RNN 引言&#xff1a;重新定义Transformer的视角 在自然语言处理&#xff08;NLP&#xff09;的领域&#xff0c;Transformer架构自2017年提出以来&#xff0c;已经成为了一种主流的模型&#xff0…

呼吸道病毒感染后,为何会引发细菌性肺炎?气道和肠道微生物组改变是关键

谷禾健康 病毒-细菌合并或继发感染 引起呼吸道感染的病毒是导致全世界高发病率和死亡率的原因&#xff0c;数十年来通常发生在冬季。在冬天&#xff0c;空气干燥&#xff0c;那些可能含有病毒的飞沫可以在空气中停留更长时间&#xff0c;并可以进一步传播。此外人的免疫力在冬季…

“Frontiers”系列多本期刊分区下跌,1本SCI被踢,2本SCI升为Top,还可投吗?

近期&#xff0c;2023年中科院分区正式发布&#xff0c;不少学者都很关心期刊变动情况。此次分区更新中&#xff0c;Frontiers出版社旗下的医学期刊表现让人大跌眼镜。 据汇总来看&#xff0c;32本大类医学SCI期刊中&#xff0c;Frontiers of Hormone Research直接从原来的医学…

照片模糊如何变清晰不妨试试这款软件吧

很多人希望能把模糊的图片或照片变得很清晰&#xff0c;或者把一个只有几十KB的小图变成有几M大小的高清大图。一般来说&#xff0c;一张模糊或打了马赛克的图片本身很多细节信息就没有或被删除了&#xff0c;就像一本书缺了很多页&#xff0c;我们是可能百分百的还原出它原来的…

云服务器ECS_GPU云服务器_AIGC_弹性计算-阿里云

阿里云高性能云服务器60%单实例最大性能提升&#xff0c;35Gbps内网带宽&#xff0c;网络增强&通用型云服务器、本地SSD型云服务器、大数据型云服务器、GPU异构型云服务器&#xff0c;阿里云百科aliyunbaike.com分享阿里云高性能云服务器&#xff1a; 阿里云高性能云服务器…

小红书年终“礼物营销”玩法:种拔一体,实现品效破圈

恰逢年末&#xff0c;用户送礼需求旺盛&#xff0c;小红书推出“礼物季”&#xff0c;品牌们纷纷入局&#xff0c;话题上线18天浏览量破9亿。“礼物营销”覆盖全年营销节点&#xff0c;贯穿始终&#xff0c;礼赠场景下用户消费决策链路缩短&#xff0c;种拔一体&#xff0c;帮助…