SimMIM: a Simple Framework for Masked Image Modeling

news2025/1/15 16:45:25

论文名称:SimMIM: a Simple Framework for Masked Image Modeling
发表时间:CVPR2022
开源地址: 开源代码
作者及组织:Zhenda Xie, Zheng Zhang, Hu Han等,来自清华,微软亚洲研究院。

前言

 本文提出一种新的自监督视觉预训练方法,是跟MAE同期工作 ,两篇论文有点儿类似,但本文较MAE额外验证MIM在swin-transformer网络上也有效。

1、方法

在这里插入图片描述

  MIM基本成了下游感知任务的预训练标配,结合代码简单说下pipline。

1.1.数据读取

  1)给定一张192*192图像,假设每个遮挡的块size=32,则这张图像能拆成192/32 * 192/32 = 36个块;
  2)然后生成长度为36的全0的mask张量,此时在引入额外的参数(遮挡比例=0.6),然后随机将mask张量中的36*0.6 = 22个位置置为1;
  3)siwn_t会先将图像经过一个kernel_size = stride = 4的conv,将图像变成(1,96,48,48)的张量,由于下采样4倍所以每个遮挡块大小成了32/4=8*8。
  4)之后将mask中每个元素广播成(8,8)大小在和图像相乘即可,当然这里将mask区域替换成可学习张量。

class MaskGenerator:
    def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size  
        self.model_patch_size = model_patch_size   # 即4中的kernel = stride = 4
        self.mask_ratio = mask_ratio
        
        assert self.input_size % self.mask_patch_size == 0
        assert self.mask_patch_size % self.model_patch_size == 0
        
        self.rand_size = self.input_size // self.mask_patch_size
        self.scale = self.mask_patch_size // self.model_patch_size
        
        self.token_count = self.rand_size ** 2
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))
        
    def __call__(self):
        mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
        mask = np.zeros(self.token_count, dtype=int)
        mask[mask_idx] = 1
        
        mask = mask.reshape((self.rand_size, self.rand_size))
        # 广播成(48,48)的0/1块张量
        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)  
        
        return mask
# 大卷积核处理图像
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        return x
class SwinTransformerForSimMIM(SwinTransformer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        assert self.num_classes == 0

        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) # 可学习mask_token
        trunc_normal_(self.mask_token, mean=0., std=.02)

    def forward(self, x, mask):
        x = self.patch_embed(x)

        assert mask is not None
        B, L, _ = x.shape

        mask_tokens = self.mask_token.expand(B, L, -1)
        w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens) 
        x = x * (1. - w) + mask_tokens * w     #   用可学习mask_token替换遮挡区域

1.2.模型

 后续数据经过swin_transformer,在接一个简单decoder直接回归 被遮挡块 的RGB像素值即可。

class SimMIM(nn.Module):
    def __init__(self, encoder, encoder_stride):
        super().__init__()
        self.encoder = encoder
        self.encoder_stride = encoder_stride
        # decoder 
        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=self.encoder.num_features,
                out_channels=self.encoder_stride ** 2 * 3, kernel_size=1),
            nn.PixelShuffle(self.encoder_stride),
        )

        self.in_chans = self.encoder.in_chans
        self.patch_size = self.encoder.patch_size

    def forward(self, x, mask):
        z = self.encoder(x, mask)
        x_rec = self.decoder(z)
        # 重新插值回4倍的块大小。
        mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous()
        loss_recon = F.l1_loss(x, x_rec, reduction='none')
        loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans
        return loss

2、实验

 经过预训练+微调在ImageNet1k上MIM取得83.8,但Linear probe效果很差。
在这里插入图片描述

 对比全监督训练方法,都是高的。
在这里插入图片描述

 下游任务效果:比全监督点高。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1392775.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2024-01-17复盘和总结

今日复盘 今天是我失业的第一天,心里有点难受,但是没办法,生活需要继续。 1.做了什么? 今天早上9点出发,骑电动车去了闵行区的图书馆,在图书馆里优化了简历,把word版的简历变成了pdf版的简历…

import { ArrowRight } from “@element-plus/icons-vue“;

今天下午快被这个问题折磨疯了 虽然知道这个问题怎么产生的 但项目里那个碍眼的红线就是去不掉 后来才发现 这是插件的锅 我的心情 你知道我想要说什么的 想必能看到这篇文章的 也知道这个问题是怎么产生的 vue3ts使用的时候 默认是需要带上文件名的 但是引入el组件时 …

transbigdata笔记:轨迹切片

1 方法介绍 在transbigdata笔记:轨迹停止点和行程提取-CSDN博客中,已经可以把轨迹点拆分成停止点和行程点,但是行程点只有起止位置,不包含行程轨迹信息为了进一步分析车辆的行驶轨迹,需要从每次行程的时间段中提取轨迹…

Apache Doris (六十四): Flink Doris Connector - (1)-源码编译

🏡 个人主页:IT贫道-CSDN博客 🚩 私聊博主:私聊博主加WX好友,获取更多资料哦~ 🔔 博主个人B栈地址:豹哥教你学编程的个人空间-豹哥教你学编程个人主页-哔哩哔哩视频 目录 1. Flink与Doris版本兼容

探索自适应学习在考试培训系统中的优势

近年来,随着互联网的普及和发展,自适应学习作为一种个性化、灵活的学习方式受到越来越多教育工作者的关注。在考试培训系统中引入自适应学习,可以为学生提供更加有效和高效的学习体验。 自适应学习可以根据学生个体的学习特点和水平&#xff…

大文件的断点续传如何实现

断点续传 断点续传是一种数据恢复技术,主要用于在读取或发送数据时,因为网络问题、磁盘问题等原因导致数据传输中断。断点续传技术允许你在已经传输的数据基础上继续传输,从而节省数据传输时间。 断点续传通常用于文件传输过程中,…

Spring Security的使用条件

Spring Security要求使用Java 8或更高版本的运行时环境。 由于Spring Security旨在以自包含的方式运行,因此您无需在Java运行时环境中放置任何特殊的配置文件。特别是,您无需配置特殊的Java认证和授权服务(JAAS)策略文件&#xf…

OpenHarmony—Docker编译环境

Docker环境介绍 OpenHarmony为开发者提供了两种Docker环境,以帮助开发者快速完成复杂的开发环境准备工作。两种Docker环境及适用场景如下: 独立Docker环境:适用于直接基于Ubuntu、Windows操作系统平台进行版本编译的场景。基于HPM的Docker环…

工业企业能源管理平台,可以帮助企业解决哪些方面的能源问题?

随着全球工业化进程的加快,工业企业在生产经营过程中消耗的能源也越来越庞大。能源成本的上升和环境保护的压力使得工业企业对能源管理的重要性有了深刻的认识。为了提高能源利用效率、降低能源消耗、减少环境污染,工业企业在能源管理方面迫切需要一套规…

ASP.NET Core SingleR:初次体验和简单项目搭建

文章目录 前言应用场景SignalR 网站长什么样?第一个ASP.NET core SignalR程序确定SignalR版本新建MVC项目添加unpkg管理器添加客户端添加ChatHub文件添加SignalR服务添加网页运行测试浏览器Websocket调试type1type6Type为其它时 总结 前言 平常的网页通讯都是基于H…

拷贝 hugging face 仓库到 colab

# 挂在谷歌云硬盘 from google.colab import drive drive.mount(/content/drive) # 转到文件夹 import os os.chdir(/content/drive/MyDrive/) !pwd# 安装并引入包 !pip install gradio bypy huggingface_hub import os import shutil from huggingface_hub import snapshot_do…

PuTTY的ppk密钥与OpenSSH密钥之间的相互转换

几个概念说明:id_rsa、id_rsa.pub、ppk、pem 目前有两个主流的密钥格式:OpenSSH格式的密钥 和 PuTTY格式的密钥。 id_rsa和id_rsa.pub 都是OpenSSH格式的密钥。 id_rsa是OpenSSH格式的SSH私钥。 id_rsa.pub是OpenSSH格式的SSH公钥。ppk文件 ppk文件是P…

【Docker】安装 Nginx 容器并部署前后端分离项目

🎉🎉欢迎来到我的CSDN主页!🎉🎉 🏅我是Java方文山,一个在CSDN分享笔记的博主。📚📚 🌟推荐给大家我的专栏《Docker实战》。🎯🎯 &…

OpenHarmony 应用开发入门 (一、环境搭建及第一个Hello World)

万事开头难。难在迈出第一步。心无旁骛,万事可破。没有人一开始就能想清楚,只有做起来,目标才会越来越清晰。--马克.扎克伯格 前言 2024年1月16日,华为目前开启已HarmonyOS NEXT开发者预览版Beta招募,报名周期为1月15…

【MATLAB】SVMD_LSTM神经网络时序预测算法

有意向获取代码,请转文末观看代码获取方式~也可转原文链接获取~ 1 基本定义 SVMD-LSTM神经网络时序预测算法是一种结合了单变量经验模态分解(Singular Value Decomposition,SVD)和长短期记忆神经网络(LSTM&#xff09…

Three.js 学习笔记之模型(学习中1.17更新)

文章目录 模型 几何体 材质模型点模型Points - 用于显示点线模型Line | LineLoop | LineSegments网格模型mesh - 三角形 几何体BufferGeometry缓冲类型几何体BufferGeometry - 没有任何形状的空几何体创建几何体的方式BufferAttribute Types定义顶点法线 geometry.attributes…

感觉捡到宝了!这究竟是哪位大神出的神器?

你们在制作简历时,是不是基本只关注两件事:简历模板,还有基本信息的填写。 当你再次坐下来更新你的简历时,可能会发现自己不自觉地选择了那个“看起来最好看的模板”,填写基本信息,却没有深入思考如何使简历…

Java--业务场景:在Spring项目启动时加载Java枚举类到Redis中

文章目录 前言实现项目启动时加载枚举值到Redis1. 定义EnumInterface接口2. 创建EnumDTO3. 创建ClassUtils工具类4. 创建EnumService接口5. 创建EnumServiceImpl6. 修改枚举类7. 创建ApplicationInit 测试结果 前言 新的一年即将来到,回首2023年,也是学…

GEE中Landsat、Sentinel、Modis主要数据集区别

一、Landsat 1. Collection 1/2 的区别 Collection 2 是Landsat Level 1 数据的又一次重大再处理,显著提高了绝对地理定位精度。 Collection1Collection2时间跨度1972~2021底1972~至今数据等级level 1level1:1972~2021底 level2:1982~至今 …

MyBatisPlus学习笔记四-扩展功能

1、代码生成器 1.1、官方的1 1.3、官方的2-idea插件 1.3、非官方的-idea插件 2、静态工具 先查询,再分组 3、逻辑删除 4、枚举处理器 5、JSON处理器