【TGRS 2023】RingMo: A Remote Sensing Foundation ModelWith Masked Image Modeling

news2024/11/16 9:31:06

RingMo: A Remote Sensing Foundation Model With Masked Image Modeling, TGRS 2023

论文:https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9844015

代码:https://github.com/comeony/RingMo

MindSpore/RingMo-Framework (gitee.com)

解读:自监督论文阅读笔记 RingMo: A Remote Sensing Foundation Model with Masked Image Modeling-CSDN博客

【代码复现】RingMo:应用于遥感领域的图像掩码建模技术_lalula1999的博客-CSDN博客

摘要

深度学习方法促进了遥感 (RS) 图像解释的快速发展。最广泛使用的是利用ImageNet预训练模型来处理指定任务的 RS 数据。然而,存在自然场景与RS场景之间的领域差距,以及 RS模型泛化能力差 等问题。开发具有通用 RS 特征表示的基础模型是有意义的。由于有大量未标记的数据可用,自监督方法在遥感方面比全监督方法具有更大的发展意义。然而,目前大多数自监督方法都使用 对比学习,其性能对数据增强、附加信息以及正负对的选择很敏感。

本文利用 生成式自监督学习 对 RS 图像的优势,提出一个名为 RingMo 的遥感基础模型框架,它由两部分组成。首先,通过从卫星和航空平台收集200万张RS图像,构建一个大规模的数据集,覆盖全球多个场景和物体。其次,提出了一种 针对 复杂 RS 场景中 密集和小型物体的  RS 基础模型训练方法

本文展示了使用 RingMo 方法在本文的数据集上训练的基础模型在四个下游任务的八个数据集上实现了最先进的模型,证明了所提出框架的有效性。通过深入探索,本文鼓励人们对RS拥抱生成式自监督学习,并利用其通用的表示能力 来加速 RS 应用程序的开发。

引言

 动机

RS图像解译 的一般范式是 加载ImageNet预训练的权重,然后在指定的任务中使用RS数据和标签信息进行模型训练。尽管这些方法在特定任务中 针对特定目标 表现良好,但仍存在以下问题

  • 由于自然场景与RS场景的差异,ImageNet预训练的权值存在域偏差,这在一定程度上影响了RS模型的性能。
  • 模型泛化能力较弱,有限的RS训练数据只能提供少量的场景和对象,难以适应其他任务。
  • 为了扩展到更多的目标和任务,研究人员需要不断地注释数据,这降低了实际应用的效率。

 综上,使用海量的 RS 数据 并设计具有通用 RS 特征表示的基础模型 有望解决上述问题。现有的 RS 基础模型训练方法 包括 监督学习 和 自监督学习 两大主流。

  • 监督学习,需要大量标记数据来训练基础模型。尽管效果好,但获取标记数据需要专家花费大量时间。由于数据标注和数据采集速度之间的巨大差距,大量RS数据无法使用。
  • 自监督方法,可以利用大量未标记的数据,并且是 RS 基础模型的主流方法对比学习 [MoCo,SimCLR v1,v2 ] 在 RS 自监督方法领域占据主导地位。 

自监督方法可以利用大量的未标记数据,是RS大模型建模的主流方法。自监督常用方法:

  • 对比学习

对比学习的核心思想是缩小相似样本的特征距离。由于没有标签,研究者对同一张图像进行数据增强,形成正对,而其他图像作为该图像的负对。此外,一些研究者通过仔细搜索或匹配,引入RS特有的附加信息(地理信息、时间序列数据、音频数据等),并鼓励模型在数据增强的同时,学习附加信息的相似性。对比学习方法取得了很好的效果,但是它们默认将不同的图像作为负对,并且RS数据集中的不同图像经常包含同一类的实例。

  • 掩码图像建模(MIM)

它的目的是重建被遮挡的像素点,并学习数据分布中的一般特征表示。一方面,它避免引入额外的信息,从而更容易利用大量数据。随着建模数据的增多,基础模型对不同场景的适应能力增强。另一方面,它的目标是从原始图像重建像素级信号。在不需要任何数据增广和负对的情况下,该模型也可以获得更好的特征表示。因此,生成式自监督方法更适合于RS基础模型的建立

目前,大多数 MIM 方法都是 基于自然图像进行训练的。与自然场景相比,RS图像存在以下难点

  • 分辨率和方向范围大。受遥感传感器的影响,图像具有多种空间分辨率。此外,具有固定方向的自然图像不同,RS 图像中的对象 从鸟瞰角度 具有较大的角度分布范围。因此,由于尺度和角度的多样性,同一物体在不同的RS图像中具有不同的特征
  • 许多密集和小的目标。自然图像一般仅有少量目标且多在前景区域,RS图像中存在许多小目标,通常分布较密,在一定程度上影响了物级解译精度。
  • 复杂的背景。RS图像中包含了大量的背景信息,导致图像的信噪比较低。物体的边界和背景模糊,干扰了物体的分类。RS影像容易受到天气、光线、云雾等外部因素的干扰,影响成像质量。

基于上述差异,使用自然图像训练的生成式自监督模型在 RS 场景上表现不佳。因此,有必要设计一个适合RS数据的基础模型。

本文提出了一个遥感基础模型框架RingMo,它由两部分组成:一个大规模的RS数据集 一个适合 多任务解译的基础模型训练方法

有了更多的 RS 数据,基础模型可以学习更多的通用特征表示。为了覆盖 具有 复杂背景 和 大量目标 的 RS 数据,本文在无监督条件下收集了 200 万张图像的数据集。 RS 图像的分辨率范围 从 0.1m 到 30m,覆盖六大洲的多个场景和物体,从卫星和空中平台捕获。

基于该数据集,本文设计 RingMo 训练方法,将 MIM 方法应用于训练 RS 基础模型。由一个用于提取隐特征表示的 相对较重的编码器 一个 用于输出重建结果的解码器 组成。

现有的 随机掩码策略 往往会丢失 RS 图像中的密集和小目标。为了解决这个问题,本文根据 RS 场景和目标的属性 设计了一种 Patch Incomplete Mask (PIMask) 策略。 RS目标在 复杂的遥感场景中 通过采用 局部不完全掩码 在保持整体掩码率的情况下 进行保留。

本文使用 RingMo 在本文的数据集上 进行训练 以获得基础模型,并且训练的特征表示 可以应用于广泛的 RS 下游任务。本文在四个 RS 任务上评估RingMo。使用 UCM、AID、NWPU-RESISC45进行 场景识别实验,使用 FAIR1M和 DIOR  进行 目标检测实验,使用 iSAID和 ISPRS Postdam 2 进行 语义分割实验,和 LEVIR-CD 用于 变化检测实验。实验结果表明,本文的基础模型可以在八个数据集上实现最先进的 (SOTA) 性能。此外,由于本文的框架不需要手动标注任何标签,还讨论了不同数量的 RS 数据对基础模型训练的影响,增加 RS 数据的数量可以获得更好的实验结果。

贡献

  • 提出了遥感领域首个生成式自监督基础模型框架(RingMo)。该框架利用大量的遥感数据来获得通用特征表示,提高了各种遥感解译任务的准确性。
  • 为了增强基础模型对RS数据的处理能力,论文根据RS图像的特性设计了一种自监督方法,改善了以往掩码策略在复杂RS场景中可能忽略密集小目标的情况。
  • 在没有人为监督的情况下,论文收集了200万张图像的RS数据集,这些图像来自卫星和航空平台,涵盖了六大洲不同的物体和场景,包含了大量不同的RS图像,提高了基础模型对不同场景的适应性。
  • 在收集到的数据集上使用RingMo训练方法得到基础模型后,在四个典型的RS任务上对其进行微调。实验表明,本文方法在8个下游数据集上实现了SOTA,并验证了RS基础模型在各种应用中的有效性和泛化性。

RingMo框架

本文提出了一种遥感基础模型开发框架RingMo,旨在 以生成的自监督学习方式 从收集的数据集中 获得训练的基础模型。在本节中,首先从 统计信息、捕获源 和 场景分布方面介绍本文提出的数据集。然后简要介绍了应用 RingMo 训练方法的自监督学习。其中,将详细介绍 掩码策略、模型架构 和 重构目标。

数据集

 本节为遥感领域的自监督学习任务 收集了一个大规模数据集。该数据集全面覆盖了各种遥感场景,具有多源、多时相、多实例的特点。数据集的细节将在后续内容中从统计信息、采集来源、场景分布三个方面进行说明。 

  • 统计信息

数据集中的部分图像选自遥感领域的公开数据集,其他图像来自中国高分二号卫星。然后将不同平台捕获的图像裁剪为448 × 448像素,得到多种分辨率和场景的图像2,096,640幅数据集的复杂特征可以使基础模型学习到更鲁棒的特征表示,更适合和特定于遥感领域的下游任务。我们的数据集图像数量最多,数据集的分辨率变化范围非常大,从0.3m到30m不等。

  • 捕获来源

所利用的图像来自不同的传感器,具有不同的成像参数。与现有的其他遥感基础模型训练任务数据集相比,该数据集的图像同时来自航空和卫星平台:JL-1、GF-2、Ganfen satellites、GeoEye、WorldView、QuickBird、IKONOS、SPOT等。航空影像与卫星影像之间存在差异,使基础模型表现更好。卫星图像的拍摄角度往往是垂直的,而航空图像的拍摄角度则包含了倾斜的情况。航空图像的图像质量通常高于卫星图像,但卫星图像的覆盖范围比航空图像大。

  • 场景分布

利用的图像覆盖了亚洲、欧洲、北美、南美、非洲、大洋洲六大洲的不同场景。在不同的季节和时间拍摄的图像也包括在内。图像之间的光照情况也不同,如光照强度、拍摄时间造成的差异、阴影造成的阴影等。上述所有情况都在下游任务中 引入了类内变化和类间相似,这对获得高度泛化的模型 提出了严峻的挑战。因此,利用尽可能多的不同特征的图像 进行预训练过程,学习更好的统一的特征表示。

基础模型训练方法RingMo

RingMo 训练方法通过生成式自监督学习来学习遥感表示。这种建模是一种典型的自动编码方法,它从原始信号的部分观察中重建。为了避免丢失小物体的特征信息,本文设计了PIMask策略。给定输入图像,PIMask 实现 区域选择 和 掩码生成。本文方法有一个编码器,提取掩码图像的隐表示,然后用于重建掩码区域的原始信号。学习到的编码器应该对各种光学遥感下游任务有用。在这项工作中,主要考虑了两种经典的视觉 Transformer 架构:ViT和 Swin Transformer。重建目标 指定要预测的原始信号的形式,L1回归损失 用于计算重建结果与像素值的差异。所提出的框架架构如图所示。

PIMask策略

大多数MIM方法常用的掩蔽策略是随机掩码,如图4所示。随机选取一定比例的图像斑块,然后进行完全掩码。这种方法在自然图像中很有用。然而,在遥感图像的应用中存在一些问题。特殊的成像机制使得背景更复杂,目标尺寸更小随机掩码策略很容易忽略许多整体的小物体。如图 4 右侧红色块所示,随机掩模策略完全丢失了掩码patches中的小目标信息,这影响了基础模型重建小目标,增加了图像重建的难度。

因此,论文设计了一种新的掩码策略PIMask来解决这个问题。 如图 4 中左边的红色块所示,本文没有完全屏蔽图像块,而是 在屏蔽块中随机保留一些像素。采用这种掩码策略,可以有效地保留小目标的部分像素信息。就像图 4 中的蓝色块一样,本文增加了掩码块的数量以保持总掩码比率不变

此外,为了更好地利用这些保留像素,采用 多层卷积 实现 patch embedding。一些相关研究证明,将早期卷积层 添加到 vision transformer 中可以帮助模型更好地学习图像特征。具体来说,在卷积过程中,本文让卷积核只在每个patch内部进行计算,不破坏模型的mask约束。并且不同于传统的embedding结果,多层卷积后的所有token都有特征信息,进一步提高了encoder的学习效率。

PIMask 策略的数学计算过程如下所示。首先,本文将图像分割成许多不重叠的块的集合

本文以 被掩蔽的patch的比例α 和 patch的总数len(Cpatches) 作为输入,通过RTG(Randomly Label Generator)函数 确定每个patch的mask标签。因此,输出 mtag 是一个一维向量,其中包括patches的所有掩码标签。 

 

得到patchi的 mask标签mtag后,就可以判断patch是否被部分屏蔽了。如公式2所示,以mask inside ratio β 和 patch的维度 dim(patchi))为输入,通过RTG函数生成每个patch的mask张量PIMaski。如公式 3 所示,直接将 patchi 与 I - PIMaski相乘,其中 I 表示单位矩阵,× 表示逐元素相乘。

模型架构

编码器:Vision Transformer / Swin Transformer
解码器:仅用于图像重建的预训练。本文用线性层进行实验。

本文将 可见patches 和掩码patches 连同位置嵌入的投影一起馈送到编码器中。具体来说,ViT 接收转换后的token嵌入序列 作为输入。与 BERT类似,ViT 预先为嵌入的token序列 准备了一个可学习的类嵌入,旨在表示 Transformer 编码器输出的全局类。 ViT 使用标准的可学习一维位置嵌入。 Transformer 编码器包括 交替的多层感知器 和 多头自注意力层。在每个部分之前使用分层模块和残差连接。 ViT 中特定于视觉的归纳偏差小于卷积神经网络 (CNN) 中的归纳偏差。

对于 Swin Transformer,它首先将图像裁剪成没有重叠的块。每个patches 都被视为一个“token”,其特征被设置为原始像素信息的聚合。投影层用于将每个patches 标记为随机维度(表示为 C)。patches token 上使用了几个指定的 Swin Transformer 块。这些块一起生成具有与经典 CNN 相同分辨率的分层表示。 Swin Transformer 是通过将 Transformer 块中的标准多头自注意力更改为基于移动窗口的模块来构建的,其他层保持不变。因此,Swin Transformer 可以方便地替换以前方法中的骨干网络来完成多项任务。

重建目标

由于像素值在光学空间中是连续的,我们的方法通过回归预测被遮挡区域的原始像素来重建输入。损失函数计算重建图像和原始图像之间的 L1 距离

其中 x,y \in R^{3HW\times1} 表示原始像素值和重建像素值; Ω 是元素的数量; M 表示掩码像素的集合。与其他 MIM 方法类似,RingMo只计算遮盖区域的损失

实验 

遥感基础模型

        本文使用 RingMo 训练方法用本文收集的数据集训练遥感基础模型。为了验证遥感数据对基础模型的影响,本文使用大规模遥感数据以随机掩码策略 训练基础模型,并评估不同数据量和训练epochs 的影响。此外,为了验证本文提出的创新方法,使用 PIMask 来训练基础模型。随后的实验表明,所提出的RS基础模型框架可以有效地在遥感图像上实现,更重要的是,它可以有效地提高各种下游解释任务的性能,而无需复杂的操作。 

图 5:可视化了 使用 PIMask 策略重建基础模型的图像。它表明,在遥感数据集上预训练的生成式自监督模型基本上是 根据保留块的上下文信息 重建被掩蔽的区域。此外,如图5所示,这些重建结果与原始图像略有不同,表明 本文的模型 真正基于未掩蔽区域重建图像,而不是简单地记忆原始图像,并具有泛化能力


   在对基础模型进行预训练后,本文将特定任务的头 附加到预训练的主干上,并对下游任务的参数进行微调。本文在遥感图像解译中的常见任务上验证了本文的基础模型,例如场景分类、目标检测、语义分割和变化检测。实验结果验证了本文在构建的遥感数据集上预训练的生成式自监督模型的有效性。使用所提出的 RingMo 框架,可以直接在其他遥感解译任务中利用 pre-training-then-fine-tuning 的范式。 

遥感场景分类

遥感目标检测

遥感语义分割

遥感变化检测 

关键代码

ringmo.py

# https://github.com/comeony/RingMo/blob/master/ringmo_framework/arch/ringmo.py

# ============================================================================
"""ringmo of ringmo-framework"""
from mindspore import nn
from mindspore import ops as P
from mindspore import dtype as mstype

from ringmo_framework.loss.loss import L1Loss
from ringmo_framework.models.backbone.vit import Vit
from ringmo_framework.models.backbone.swin_transformer import SwinTransformer


class SwinTransformerForRingMo(SwinTransformer):
    """swim transformer for ringmo"""
    def __init__(self, **kwargs):
        super(SwinTransformerForRingMo, self).__init__(**kwargs)
        assert self.num_classes == 0
        dp = self.parallel_config.data_parallel
        self.reshape = P.Reshape()
        self.transpose = P.Transpose().shard(((dp, 1, 1),))
        self.add_pos = P.Add().shard(((dp, 1, 1), (1, 1, 1)))
        self.sub = P.Sub().shard(((), (dp, 1, 1)))
        self.multi = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
        self.hw = int(self.final_seq ** 0.5)

    def construct(self, x, mask):
        """construct of SwinTransformerForRingMo"""
        # pylint: disable=W0221
        x = self.multi(x, self.sub(1, mask))
        x = self.patch_embed(x)

        if self.ape:
            x = self.add_pos(x, self.absolute_pos_embed)

        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)
        x = self.transpose(x, (0, 2, 1))
        x = self.reshape(x, (x.shape[0], x.shape[1], self.hw, self.hw))
        return x

    def no_weight_decay(self):
        return super().no_weight_decay() | {'mask_token'}


class VisionTransformerForRingMo(Vit):
    """vision transformer for ringmo"""
    def __init__(self, **kwargs):
        super(VisionTransformerForRingMo, self).__init__(**kwargs)

        assert self.num_classes == 0
        dp = self.parallel_config.data_parallel
        self.reshape = P.Reshape()
        self.transpose = P.Transpose().shard(((dp, 1, 1),))
        self.add_pos = P.Add().shard(((dp, 1, 1), (1, 1, 1)))
        self.sub = P.Sub().shard(((), (dp, 1, 1)))
        self.multi = P.Mul().shard(((dp, 1, 1), (dp, 1, 1)))
        self.hw = int(self.num_patches ** 0.5)

        self.slice = P.Slice().shard(((dp, 1, 1),))

    def construct(self, x, mask):
        """construct of VisionTransformerForRingMo"""
        # pylint: disable=W0221
        x = self.multi(x, self.sub(1, mask))
        x = self.patch_embed(x)

        batch, seq, channel = x.shape
        cls_tokens = self.tile(self.cls_tokens, (batch, 1, 1))
        x = self.cat((cls_tokens, x))
        if self.pos_embed is not None:
            x = self.add_pos(x, self.pos_embed)

        x = self.dropout(x)

        if self.rel_pos_bias:
            rel_pos_bias = self.rel_pos_bias()
            x = self.encoder(x, self.encoder_input_mask, rel_pos_bias=rel_pos_bias)
        else:
            x = self.encoder(x, self.encoder_input_mask)
        x = self.norm(x)
        x = self.slice(x, (0, 1, 0), (batch, seq, channel))  # x = x[:, 1:]
        x = self.transpose(x, (0, 2, 1))
        x = self.reshape(x, (x.shape[0], x.shape[1], self.hw, self.hw))
        return x


class RingMo(nn.Cell):
    """RingMo"""
    def __init__(self, encoder, encoder_stride, use_lbp=False, parallel_config=None):
        super(RingMo, self).__init__()
        self.encoder = encoder
        self.encoder_stride = encoder_stride
        self.use_lbp = use_lbp
        if parallel_config:
            dp = parallel_config.data_parallel
        else:
            dp = 1

        self.decoder = nn.Conv2d(
            in_channels=self.encoder.num_features,
            out_channels=self.encoder_stride ** 2 * 3,
            kernel_size=1, has_bias=True, pad_mode='pad'
        )

        # encoder output -> [B,C,H,W]
        self.decoder.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1)))
        self.decoder.bias_add.shard(((dp, 1, 1, 1), (1,)))

        self.decoder_lbp = nn.Conv2d(
            in_channels=self.encoder.num_features,
            out_channels=self.encoder_stride ** 2 * 3,
            kernel_size=1, has_bias=True, pad_mode='pad'
        )

        # encoder output -> [B,C,H,W]
        self.decoder_lbp.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1)))
        self.decoder_lbp.bias_add.shard(((dp, 1, 1, 1), (1,)))

        self.pixelshuffle = P.DepthToSpace(self.encoder_stride).shard(((dp, 1, 1, 1),))
        self.in_chans = self.encoder.in_chans
        self.patch_size = self.encoder.patch_size
        self.l1_loss = L1Loss(reduction='none', parallel_config=parallel_config)

        self.expand_dim = P.ExpandDims().shard(((dp, 1, 1),))
        self.cast = P.Cast()
        self.div = P.Div().shard(((), ()))
        self.multi = P.Mul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))

        self.sum = P.ReduceSum().shard(((dp, 1, 1, 1),))
        self.add = P.Add().shard(((), ()))

    def ringmo_loss(self, x, x_rec, lbp=None, lbp_rec=None, mask=None):
        """ringmo loss"""
        x = self.cast(x, mstype.float32)
        x_rec = self.cast(x_rec, mstype.float32)
        mask = self.cast(mask, mstype.float32)
        loss_ori_recon = self.l1_loss(x, x_rec)
        loss_ori_mask = self.mean(loss_ori_recon, mask)
        loss_lbp_mask = 0.
        if self.use_lbp:
            loss_lbp_recon = self.l1_loss(lbp, lbp_rec)
            loss_lbp_mask = self.mean(loss_lbp_recon, mask)
        loss = self.add(loss_ori_mask, loss_lbp_mask)
        return loss

    def mean(self, loss, mask):
        mul_a = self.multi(loss, mask)
        div_a = self.sum(mul_a)
        sum_b = self.sum(mask)
        div_b = self.add(sum_b, 1e-5)
        loss_mask = self.div(div_a, div_b)
        loss_mask = self.div(loss_mask, self.in_chans)
        return loss_mask

    def _check_input(self, inputs):
        if not self.use_lbp:
            return inputs[0], None, inputs[1]

        return inputs[0], inputs[1], inputs[2]

    def construct(self, *inputs):
        """construct of RingMo"""
        x_in, lbp_in, mask_in = self._check_input(inputs)

        # x -> [B,L,C]
        z = self.encoder(x_in, mask_in)
        # z -> [B,C,H,W]
        x_rec = self.decoder(z)
        # self.summary_4d("decoder_conv2d", self.decoder.weight)
        # z -> [B,C,H,W]
        x_rec = self.pixelshuffle(x_rec)

        lbp_rec = None
        if lbp_in is not None:
            lbp_rec = self.decoder_lbp(z)
            lbp_rec = self.pixelshuffle(lbp_rec)

        sim_loss = self.ringmo_loss(x_in, x_rec, lbp_in, lbp_rec, mask_in)

        return sim_loss

    def no_weight_decay(self):
        if hasattr(self.encoder, 'no_weight_decay'):
            return {'encoder.' + i for i in self.encoder.no_weight_decay()}
        return {}

    def no_weight_decay_keywords(self):
        if hasattr(self.encoder, 'no_weight_decay_keywords'):
            return {'encoder.' + i for i in self.encoder.no_weight_decay_keywords()}
        return {}


def ringmo_vit_base_p16(**kwargs):
    encoder = VisionTransformerForRingMo(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
    return RingMo(encoder=encoder, encoder_stride=16)


def ringmo_vit_large_p16(**kwargs):
    encoder = VisionTransformerForRingMo(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
    return RingMo(encoder=encoder, encoder_stride=16)


def ringmo_swin_tiny_p4_w6(**kwargs):
    encoder = SwinTransformerForRingMo(
        image_size=192, patch_size=4, embed_dim=96, depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24], window_size=6, mlp_ratio=4, **kwargs)
    return RingMo(encoder=encoder, encoder_stride=32)


def ringmo_swin_tiny_p4_w7(**kwargs):
    encoder = SwinTransformerForRingMo(
        image_size=224, patch_size=4, embed_dim=96, depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24], window_size=6, mlp_ratio=4, **kwargs)
    return RingMo(encoder=encoder, encoder_stride=32)


def ringmo_swin_base_p4_w6(**kwargs):
    encoder = SwinTransformerForRingMo(
        image_size=192, patch_size=4, embed_dim=128, depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32], window_size=6, mlp_ratio=4, **kwargs)
    return RingMo(encoder=encoder, encoder_stride=32)


def ringmo_swin_base_p4_w7(**kwargs):
    encoder = SwinTransformerForRingMo(
        image_size=224, patch_size=4, embed_dim=128, depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32], window_size=7, mlp_ratio=4, **kwargs)
    return RingMo(encoder=encoder, encoder_stride=32)


def build_ringmo(config):
    """build ringmo"""
    model_type = config.model.backbone
    if model_type == 'swin':
        encoder = SwinTransformerForRingMo(
            parallel_config=config.parallel_config,
            moe_config=config.moe_config,
            batch_size=config.train_config.batch_size * config.device_num
            if config.parallel.parallel_mode == "semi_auto_parallel" else config.train_config.batch_size,
            image_size=config.train_config.image_size,
            patch_size=config.model.patch_size,
            in_chans=config.model.in_chans,
            num_classes=0,
            embed_dim=config.model.embed_dim,
            depths=config.model.depth,
            num_heads=config.model.num_heads,
            window_size=config.model.window_size,
            mlp_ratio=config.model.mlp_ratio,
            qkv_bias=config.model.qkv_bias,
            qk_scale=config.model.qk_scale,
            drop_rate=config.model.drop_rate,
            drop_path_rate=config.model.drop_path_rate,
            ape=config.model.ape,
            patch_norm=config.model.patch_norm,
            patch_type=config.model.patch_type)
        encoder_stride = 32
    elif model_type == 'vit':
        encoder = VisionTransformerForRingMo(
            parallel_config=config.parallel_config,
            moe_config=config.moe_config,
            batch_size=config.train_config.batch_size * config.device_num
            if config.parallel.parallel_mode == "semi_auto_parallel" else config.train_config.batch_size,
            image_size=config.train_config.image_size,
            patch_size=config.model.patch_size,
            in_chans=config.model.in_chans,
            num_classes=0,
            embed_dim=config.model.embed_dim,
            depth=config.model.depth,
            num_heads=config.model.num_heads,
            mlp_ratio=config.model.mlp_ratio,
            drop_rate=config.model.drop_rate,
            drop_path_rate=config.model.drop_path_rate,
            use_abs_pos_emb=config.model.use_abs_pos_emb,
            init_values=config.model.init_values,
            use_rel_pos_bias=config.model.use_rel_pos_bias,
            use_shared_rel_pos_bias=config.model.use_shared_rel_pos_bias,
            patch_type=config.model.patch_type)
        encoder_stride = 16
    else:
        raise NotImplementedError(f"Unknown pre-train model: {model_type}")

    model = RingMo(encoder=encoder, encoder_stride=encoder_stride, parallel_config=config.parallel_config,
                   use_lbp=config.model.use_lbp)

    return model

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1133233.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

解决:vscode和jupyter远程连接无法创建、删除文件的问题(permission denied)

目录 问题:vscode和jupyter远程连接服务器无法创建、删除文件的问题原因:代码文件的权限不够解决方法:1.ls -l查看目录所在组,权限2.chown修改拥有者和所在组 问题:vscode和jupyter远程连接服务器无法创建、删除文件的…

【兔子王赠书第3期】《案例学Python(进阶篇)》

文章目录 前言推荐图书本书特色本书目录本书样章本书读者对象粉丝福利丨评论免费赠书尾声 前言 随着人工智能和大数据的蓬勃发展,Python将会得到越来越多开发者的喜爱和应用。因为Python语法简单,学习速度快,大家可以用更短的时间掌握这门语…

Spring学习笔记—JDK动态代理

✅作者简介:大家好,我是Leo,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Leo的博客 💞当前专栏: Spring专栏 ✨特色专栏: M…

Unity中Shader的ShaderLOD

文章目录 前言一、ShaderLOD的使用步骤1、ShaderLOD使用在不同的SubShader中,用于区分SubShader所对应的配置2、在 C# 中使用 Shader.globalMaximumLOD 赋值来选择不同的 SubShader,以达到修改配置对应Shader的效果3、在设置LOD时,是需要和程序讨论统一 …

WebGL笔记:矩阵的变换之平移的实现

矩阵的变换 变换 变换有三种状态:平移、旋转、缩放。当我们变换一个图形时,实际上就是在移动这个图形的所有顶点。解释 webgl 要绘图的话,它是先定顶点的,就比如说我要画个三角形,那它会先把这三角形的三个顶点定出来…

为什么需要山洪灾害监测预警系统?

在山洪高发地区,安装山洪灾害监测预警系统能够通过实时监测,预警山洪信息,对于保障我们的生命财产安全具有重要意义。 监测山洪不仅需要对山体进行监测,还要监测降雨量以及水位上升情况。山洪灾害监测预警系统是由GNSS监测站和水…

linux安装node(含npm命令) 并配置淘宝镜像源

1. 下载压缩包 wget https://nodejs.org/dist/v16.14.0/node-v16.14.0-linux-x64.tar.xz # node14 https://nodejs.org/dist/v14.15.4/node-v14.15.4-linux-x64.tar.xz # 推荐将压缩包放置到/usr/local/node文件夹中安装 mv node-v16.14.0-linux-x64.tar.xz /usr/local/node …

LeetCode217——存在重复元素

LeetCode217——存在重复元素 1.题目描述: 给你一个整数数组 nums 。如果任一值在数组中出现 至少两次 ,返回 true ;如果数组中每个元素互不相同,返回 false 。 2.Result01(暴力解) public static boolean containsDuplicate(in…

SRAM与DRAM的区别

目录 SRAM 特点 应用场景 DRAM 特点 应用场景 SRAM和DRAM的区别 SRAM SRAM(静态随机存取存储器)是一种用于存储和检索数据的类型的计算机内存。SRAM的存储单元通过触发器(flip-flop)实现,它们可以保持数据的状态…

语雀崩溃7个小时的原因是什么??

1 语雀是什么 语雀是蚂蚁集团旗下的在线文档编辑与协同工具,使用了“结构化知识库管理”,形式上类似书籍的目录。用户量在千万级别,是非常强大的。身边有不少朋友是付费会员,有许多公司也付费在使用语雀作为知识库进行文档的存储…

通过社工进网站后台的渗透测试

目录 通过社工进网站后台 0X0 开始: 0X1 获取icp备案企业: 0X2 通过备案获取姓名: 0X3 通过姓名获取手机号: 0X4 弱口令进后台: 0X5 总结 通过社工进网站后台 记录一次通过简单社工获取信息后进入后台的经过。…

近似熵的计算

我们计算两个函数,一个是henon,另外一个是TriMap: 代码: 构造henon函数: function [x,y]=Henon(x0,y0,a,b,M) M=M+10000; x = zeros(1,M+1); y = zeros(1,M+1); x(1)=x0; y(1)=y0; for i = 1:Mx(i+1) = 1+y(i)-a*x(i)^2;y(i+1) = b*x(i); endx=x(10001:M); y=y(10001:M)…

【人工智能Ⅰ】实验1:谓词表示法与产生式知识表示

实验1 谓词表示法与产生式知识表示 一、实验目的 1、熟悉谓词逻辑表示法; 2、理解和掌握产生式知识表示方法,实现产生式系统的规则库。 二、实验内容 要求通过C/C/python语言编程实现: 1、猴子摘香蕉问题 2、动物识别系统 &#xff08…

C# 基于腾讯云人脸核身和百度云证件识别技术相结合的 API 实现

目录 腾讯云人脸核身技术 Craneoffice.net 采用的识别方式 1、活体人脸核身(权威库): 2、活体人脸比对: 3、照片人脸核身(权威库): 调用成本 百度云身份证识别 调用成本 相关结合点 核心代码 实现调用人脸核身API的示例 实现调用身…

css 三栏布局的实现?

目录 前言 用法 代码 理解 高质量图片 1. 左侧栏 - 导航菜单 2. 中间栏 - 主要内容 3. 右侧栏 - 小部件和广告 布局的响应式设计 三栏布局在前端页面设计中是一个常见的布局方式,通常包含左侧、中间和右侧三个部分。这种布局方式在多种场景中都很受欢迎&am…

企业安全—DevSecOps概述详情

0x00 前言 SDL存在的问题在于体量过于庞大,不利于快速进行适配和进行,所以就有了DevSecOps,实际上是因为敏捷开发也就是DevOps的推进,并且坐上了云服务模式的火车,所以这一系列的东西都开始普及。DevSecOps作为DevOps…

龙迅LT9211 是MIPI/TTL/2 PORT LVDS互转的一颗高性能芯片,支持车规级

龙迅LT9211 描述: Lontium LT9211是一个高性能转换器,可以在MIPI DSI/CSI-2/双端口LVDS和TTL之间互转换,除了24位TTL到24位TTL与同步和DE。LT9211反序列化输入的MIPI/LVDS/TTL视频数据,解码数据包,并将格式化的视频数…

2006-2021年上市公司社会责任报告基本信息数据

2006-2021年上市公司社会责任报告基本信息表数据 1、时间:2006-2021年 2、来源为上市公司社会责任BG 3、指标:股票代码、统计截止日期、GG日期、所属省份、行业代码、行业名称、纳税总额、每股社会贡献值、社会捐赠额、是否经第三方机构审验、审验机构…

Python Opencv实践 - 入门使用Tesseract识别图片中的文字

做车牌识别项目前试一试tesseract识别中文。tesseract的安装使用请参考: Python OCR工具pytesseract详解 - 知乎pytesseract是基于Python的OCR工具, 底层使用的是Google的Tesseract-OCR 引擎,支持识别图片中的文字,支持jpeg, png…

Android 中如何使用 App Links

1. 简介 什么是 App Links呢?App Links 是 Android 6.0 (API 级别23) 引入的新功能,它是基于 DeepLinking,允许应用自动处理网站的 URL,而无需提示用户启动相应的应用。 例如:如果你在手机浏览器中输入了某个网站&am…