Vision Transformer (ViT)及各种变体

news2024/12/23 14:16:24

目录

0.Vision Transformer介绍

1.ViT 模型架构

1.1 Linear Projection of Flattened Patches

1.2 Transformer Encoder

1.3 MLP Head

1.4 ViT架构图

1.5 model scaling

2.Hybrid ViT

4.其他Vision Transformer变体

5.Vit代码

6.参考博文


0.Vision Transformer介绍

2017年Vaswani等人在发表的《Attention Is All You Need》中提出Transformer模型,是第一个完全依靠自注意力计算其输入和输出的模型,从此在自然语言处理领域大获成功。

2021年Dosovitskiy等人将注意力机制的思想应用于计算机视觉领域,提出了Vision Transformer(ViT )模块。在大规模数据集的支持下,ViT模型可以达到与CNNs模型相当的精度,如下图所示为ViT的不同版本与ResNet和EfficientNet在不同数据集下的准确率对比。

论文名称:《AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE》 

论文地址:https://arxiv.org/abs/2010.11929

1.ViT 模型架构

作者在文中给出ViT模型如下架构图,其中主要有三个部分组成:

1)Linear Projection of Flattened Patches(Embedding层,将子图映射为向量);

2)Transformer Encoder(编码层,对输入的信息进行计算学习);

3)MLP Head(用于分类的层结构);

1.1 Linear Projection of Flattened Patches

在标准的Transformer模块中,输入是Token(向量)序列,即二维矩阵[num_token, token_dim]。而图像数据格式为[H, W, C]的三维数据,因此需要将图像数据经过Embedding层进行变换,转换为Transformer模块能够输入的数据类型。

 以ViT-B/16为例,将输入图片(224x224)按照16x16大小的Patch进行划分,划分后会得到( 224 / 16 ) * ( 224 / 16 ) =196个Patches。接着通过线性映射(Linear Projection)将每个Patch映射到一维向量中。

Linear Projection:使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现线性映射,这个卷积操作产生shape变化为[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平(Flattened Patches)即可,shape变化为([14, 14, 768] -> [196, 768]),此时正好变成了一个二维矩阵,符合Transformer输入的需求。其中,196表征的是patches的数量,将每个Patche数据shape为[16, 16, 3]通过卷积映射得到一个长度为768的向量(后面都直接称为token)。

在输入Transformer Encoder之前注意需要加上[class]token以及Position Embedding

1)[class]token:原文中,作者参考了Bert模型,在刚刚得到的一堆tokens中插入一个专门用于分类的[class]token,这个[class]token是一个可训练的参数,数据格式和其他token一样都是一个向量,以ViT-B/16为例,就是一个长度为768的向量,与之前从图片中生成的tokens拼接在一起,Cat([1, 768], [196, 768]) -> [197, 768]。

2)Position Embedding:Position Embedding采用的是一个可训练的一维位置编码(1D Pos. Emb.),是直接叠加在tokens上的(add),所以shape要一样。以ViT-B/16为例,刚刚拼接[class]token后shape是[197, 768],那么这里的Position Embedding的shape也是[197, 768]。

自注意力是所有的元素两两之间去做交互,所以是没有顺序的,但是图片是一个整体,子图patches是有自己的顺序的,在空间位置上是相关的,所以要给patch embedding加上了positional embedding这样一组位置参数,让模型自己去学习patches之间的空间位置相关性。

对于Position Embedding作者也有做一系列对比试验,虽然没有位置嵌入的模型和有位置嵌入的模型的性能有很大差距,但是不同的位置信息编码方式之间几乎没有差别,由于Transformer编码器工作在patch级别的输入上,相对于pixel级别,如何编码空间信息的差异不太重要,结果如下所示:

1.2 Transformer Encoder

Transformer Encoder 主要由以下几部分组成:

1)Layer Norm:Transformer中使用Layer Normalization进行归一化操作,能够加快训练的速度,提高训练的稳定性;

2)Multi-Head Attention:与Transformer中的一样,详见:Transformer-《Attention Is All You Need》_HM-hhxx!的博客-CSDN博客

3)Dropout/DropPath:在原论文的代码中是直接使用的Dropout层,在但实现的代码中使用的是DropPath;

4)MLP Block:全连接+GELU激活函数+Dropout组成, 第一个全连接层会把输入节点个数翻4倍[197, 768] -> [197, 3072],第二个全连接层会还原回原节点个数[197, 3072] -> [197, 768],MLPBlock结构如下图所示

 

Encoder结构如下图所示,左侧为实际结构,右侧为论文中结构,省去了Dropout/DropPath层,其实就是重复堆叠如下图所示的Encoder Block L次,MLP Block结构如上图所示:

   

1.3 MLP Head

在经过Transformer Encoder时,输入的shape和输出的shape保持不变。在论文中,以ViT-B/16为例,输入的是[197, 768]输出的还是[197, 768]。在Transformer Encoder后还有一个Layer Norm,结构图中并没有给出,如下图所示:

这里我们只是需要Transformer Encoder中的分类信息,所以我们只需要提取出[class]token生成的对应结果就行,即[197, 768]中抽取出[class]token对应的[1, 768],因为self-attention计算全局信息的特征,这个[class]token其中已经融合了其他token的信息。接着我们通过MLP Head得到我们最终的分类结果。MLP Head原论文中说在训练ImageNet21K时是由Linear+tanh激活函数+Linear组成。但是迁移到ImageNet1K上或者你自己的数据上时,只用一个Linear即可。

1.4 ViT架构图

1.5 model scaling

论文中,作者根据Bert模型设计了‘Base’和‘Large’模型,并增加了一个‘Huge’模型。并对名称进行了解释,例如ViT-L / 16表示具有16 × 16输入patch size的" Large "变体。需要注意的是Transformer的序列长度与patch size的平方成反比,因此patch size较小的模型计算开销较大。因此在ViT源码中,除了patch size为16×16的,还有32×32的。

下表中的Layers就是Transformer Encoder中重复堆叠Encoder Block的次数,Hidden Size就是对应通过Embedding层后每个token的dim(向量的长度),MLP size是Transformer Encoder中MLP Block第一个全连接的节点个数(是Hidden Size的四倍),Heads代表Transformer中Multi-Head Attention的heads数。

2.Hybrid ViT

论文在4.1章节中介绍模型的缩放后,对模型的混合模型进行了介绍,即将传统的CNN特征提取和Transformer进行结合。文中将以ResNet50作为特征提取器的混合模型,但这里的R50的卷积层采用的StdConv2d不是传统的Conv2d,然后将所有的BatchNorm层替换成GroupNorm层。在原Resnet50网络中,stage1重复堆叠3次,stage2重复堆叠4次,stage3重复堆叠6次,stage4重复堆叠3次,但在这里的R50中,把stage4中的3个Block移至stage3中,所以stage3中共重复堆叠9次。

通过R50 Backbone进行特征提取后,得到的特征矩阵shape是[14, 14, 1024],接着再输入Patch Embedding层,注意Patch Embedding中卷积层Conv2d的kernel_size和stride都变成了1,只是用来调整channel。后面的部分和前面的ViT结构一样。

 下表是论文中对比ViT、ResNet及R-ViT模型的效果,通过对比发现,在训练epoch较少时hybrid模型效果优于ViT,但在epoch增加时ViT效果更好。

4.其他Vision Transformer变体

在论文《A review of convolutional neural network architectures and their optimizations》中,作者指出一些研究表明ViT模型与CNN相比缺乏可优化性,这是由于ViT缺乏空间归纳偏差。因此,在ViT模型中使用卷积策略来削弱这种偏差,可以提高其稳定性和性能。并列出如下Vit变体:

1)LeVit(2021):映入主义偏向的思想来结合位置信息。

论文名称:《LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference》

论文地址:https://arxiv.org/abs/2104.01136

2)PVT(2021):金字塔vit,将transformer融入到CNNs中,在图像的密集分区上进行训练,以实现输出高分辨率。克服了transformer对于密集预测任务的缺点。

 论文名称:《Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions》

论文地址:https://arxiv.org/abs/2102.12122

3)T2T-ViT(2021):通过递归地将相邻的token聚合为一个token,图像最终被逐步结构化为token;提供了具有更深更窄的高效backbone;将图像结构化为token。

论文名称:《Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet》

论文地址:https://arxiv.org/abs/2101.11986

4)MobileVit(2021):将mobilenet v2连接vit,效果明显优于其他轻量级网络。结合逆残差(inverse residual)和Vit。

论文名称:《MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer》

论文地址:https://arxiv.org/abs/2110.02178

5)VTs(2021):Visual Transformers,通过词法分析将特征图转换为一系列视觉token,然后通过投影仪( Wu等2020b)将处理后的视觉令牌投影到原始地图和原始图像上。实验表明,VTs在使用较少的FLOPs和参数的情况下,将ImageNet top - 1的ResNet精度提高了4.6 ~ 7个点。通过映射将图像输入transformer。

论文名称:《Visual Transformers: Token-based Image Representation and Processing for Computer Vision》

论文地址:https://arxiv.org/abs/2006.03677

6)Conformer(2021):结合CNN与Transformer的优点,并行的通过特征耦合单元(Feature Coupling Unit,FCU)与每个阶段的局部和全局特征进行交互,从而兼具CNN和Transformer的优点。将CNNs和Transformer模块并行组合。

论文名称:《Conformer: Local Features Coupling Global Representations for Visual Recognition》

论文地址:https://arxiv.org/abs/2105.03889

7)BoTNet(2021):通过在ResNet的最后三个瓶颈块中用全局自注意力替换空间卷积显著改善了基线。用全局自注意力代替空间卷积。

论文名称:《Bottleneck Transformers for Visual Recognitiont》

论文地址:CVPR 2021 Open Access Repository

8)CoAtNets(2021):并认为CNNs的深层结构和注意力机制可以通过简单的相对注意力联系起来。此外还认为叠加卷积层和Transformer encoder可以产生好的效果。堆叠卷积层和Transformer编码器。

论文名称:《CoAtNet: Marrying Convolution and Attention for All Data Sizes》

论文地址:https://arxiv.org/abs/2106.04803

9)Swin Transformer(2021):通过移动窗口将自注意力计算限制在不重叠的局部窗口,同时允许跨窗口连接,在Imagenet - 1K上达到了87.3 %的准确率。将自注意力限制在不重叠的局部窗口并将其进行连接。

论文名称:《Swin Transformer: Hierarchical Vision Transformer using Shifted Windows》

论文地址:https://arxiv.org/abs/2103.14030

5.Vit代码

原始Vit代码地址:

pytorch-image-models/vision_transformer.py at main · huggingface/pytorch-image-models · GitHub

model.py:

"""
original code from rwightman:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
from functools import partial
from collections import OrderedDict

import torch
import torch.nn as nn


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # work with diff dim tensors, not just 2D ConvNets
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + \
        torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(
            in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2)  # [B,196,768]
        x = self.norm(x)
        return x


class Attention(nn.Module):
    def __init__(self,
                 dim,   # 输入token的dim
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5  # 开根号操作
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        # [batch_size, num_patches + 1, total_embed_dim]
        B, N, C = x.shape

        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
                                  self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # q*k的转置*缩放因子,缩放因子就是根号下dk
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

# encoder block


class Block(nn.Module):
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(
            drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_c (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(
            1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(
            1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[
                      i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(
            self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(
                self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768]
        # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(
                x.shape[0], -1, -1), x), dim=1)

        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        # 图片分类没走if这块
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during inference, return the average of both classifier predictions
                return x, x_dist
            else:
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return x


def _init_vit_weights(m):
    """
    ViT weight initialization
    :param m: module
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)


def vit_base_patch16_224(num_classes: int = 1000):
    """
    ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    链接: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  密码: eu9f
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=None,
                              num_classes=num_classes)
    return model


def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_base_patch32_224(num_classes: int = 1000):
    """
    ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    链接: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg  密码: s5hl
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=None,
                              num_classes=num_classes)
    return model


def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_large_patch16_224(num_classes: int = 1000):
    """
    ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    链接: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ  密码: qqt8
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=None,
                              num_classes=num_classes)
    return model


def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=1024 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=1024 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: converted weights not currently available, too large for github release hosting.
    """
    model = VisionTransformer(img_size=224,
                              patch_size=14,
                              embed_dim=1280,
                              depth=32,
                              num_heads=16,
                              representation_size=1280 if has_logits else None,
                              num_classes=num_classes)
    return model

6.参考博文

1.深度学习之图像分类(十二): Vision Transformer - 魔法学院小学弟

2. Vision Transformer详解_太阳花的小绿豆的博客-CSDN博客

3.论文《A review of convolutional neural network architectures and their optimizations》 A review of convolutional neural network architectures and their optimizations | SpringerLink

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/592858.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

样本不平衡的解决办法

背景 Focal loss是最初由何恺明提出的,最初用于图像领域解决数据不平衡造成的模型性能问题。本文试图从交叉熵损失函数出发,分析数据不平衡问题,focal loss与交叉熵损失函数的对比,给出focal loss有效性的解释。 交叉熵损失函数…

危机先知:TOOM舆情监控助力风险预警

随着社交媒体和互联网的普及,公众的声音在网络上如洪水般涌现。这些声音传递着情绪、态度和观点,对个人、组织甚至整个社会产生着巨大影响。因此,舆情监控成为了一个不可或缺的工具,帮助企业和组织及时了解公众对其品牌、产品或服…

决策树基本理论知识

目录 1、决策树是一种树模型 2、决策树的训练与测试 3、信息增益(ID3) 3.1、衡量标准-熵 3.2、决策树构造实例 4、决策树算法 ​5、连续值离散化 6、预剪枝 1、决策树是一种树模型: (1)、从根结点开始一步步走…

【C++】哈希表封装unordered系列

文章目录 前言一、哈希表的封装总结 前言 在看本篇文章前大家尽量拿出上一篇文章的代码跟着一步步实现,否则很容易引出大量模板错误而无法解决。 一、哈希表的封装 首先我们要解决映射的问题,我们目前的代码只能映射整形,那么如何支撑浮点数…

Java使用zxing.jar生成二维码

由于时代科学的进步,二维码已经和我们的生活密不可分,在开发过程中往往会涉及到和二维码相关的开发,今天这篇文章就教会大家如何使用zxing.jar包生成二维码 下面这个就是百度上面自带的一个生成二维码的功能,那他是怎么实现这个功…

计算机组成原理与体系结构概述

目录 一、计算机的发展 二、计算机的硬件系统 三、硬件的工作原理 四、计算机系统的层次结构 五、计算机的性能指标 一、计算机的发展 第一代计算机:电子管计算机 第一台电子计算机:ENIAC(1946) 设计目的:计算导弹…

平板触控笔哪种好?主动式电容笔推荐

现在市面上的电容笔分为主动式和被动式电容笔,很多小伙伴都分不清主动式和被动式电容笔的区别。今天给大家介绍一下这两款电容笔的区别。给大家分享几款好用的平替电容笔。 一、主动式电容笔和被动式电容笔的区别: 1.主动式电容笔: 主动式电…

数据结构与算法(九)

红黑树复习 图 图,是一种数据结构 集合只有同属于一个集合;线性结构存在一对一的关系,树形结构一对多的关系,图形结构,多对多的关系。 微信中:许多的用户组成了一个多对多的朋友关系网,这个关…

【C语言】变量

🚩 WRITE IN FRONT 🚩 🔎 介绍:"謓泽"正在路上朝着"攻城狮"方向"前进四" 🔎🏅 荣誉:2021|2022年度博客之星物联网与嵌入式开发TOP5|TOP4、2021|2022博客之星T…

【机器学习】分类问题和逻辑(Logistic)回归算法详解

在阅读本文前,请确保你已经掌握代价函数、假设函数等常用机器学习术语,最好已经学习线性回归算法,前情提要可参考https://blog.csdn.net/weixin_45434953/article/details/130593910 分类问题是十分广泛的一个问题,其代表问题是&…

Android studio 环境安装

1. Java JDK安装 https://download.oracle.com/java/17/latest/jdk-17_windows-x64_bin.exe 下载jdk-17 并安装 安装完成后设置环境变量 #新增环境变量JAVA_HOME C:\Program Files\Java\jdk-17#Path 环境变量添加 %JAVA_HOME%\bin %JAVA_HOME%\jdk\bin#新增环境变量CLASSPAT…

HEVC量化编码介绍

介绍 ● 视频编码中,残差信号经过DCT,变换系数具有较大动态范围,因此对变换系数量化可以有效减小信号取值空间,获得更好的压缩效果; ● 多对一映射机制,所以不可避免的引入失真,这是视频编码中…

Spring(三)对bean的详解

一、引入外部属性文件 首先我们将依赖进行导入&#xff1a; <!--MySQL驱动--><dependency><groupId>mysql</groupId><artifactId>mysql-connector-java</artifactId><version>8.0.22</version></dependency><!--数据…

idea连接Linux服务器

一、 介绍 配置idea的ssh会话和sftp可以实现对linux远程服务器的访问和文件上传下载&#xff0c;是替代Xshell的理想方式。这样我们就能在idea里面编写文件并轻松的将文件上传到linux服务器中。而且还能远程编辑linux服务器上的文件。掌握并熟练使用&#xff0c;能够大大提高我…

烂怂if-else代码优化方案 | 京东云技术团队

0.问题概述 代码可读性是衡量代码质量的重要标准&#xff0c;可读性也是可维护性、可扩展性的保证&#xff0c;因为代码是连接程序员和机器的中间桥梁&#xff0c;要对双边友好。Quora 上有一个帖子&#xff1a; “What are some of the most basic things every programmer s…

日本医疗保健和健康管理公司【Zerospo】申请纳斯达克IPO上市

来源&#xff1a;猛兽财经 作者&#xff1a;猛兽财经 猛兽财经获悉&#xff0c;来自日本的医疗保健和健康管理公司【Zerospo】&#xff0c;近期已向美国证券交易委员会&#xff08;SEC&#xff09;提交招股书&#xff0c;申请在纳斯达克IPO上市&#xff0c;股票代码为&#xff…

感谢海洋一所陈老师用Pospac MMS解算pospac数据及GNSS验潮

非常感谢海洋一所陈老师 帮忙用Pospac MMS解算博主的pospa从数据。解算的结果txt文件大小有2个G&#xff0c;令人非常吃惊&#xff0c;因为原始数据的时长不到1天&#xff0c;打开文件才知道每行位置数据的间隔时间是5ms&#xff0c;5ms正是惯导数据的采样频率。 用抽稀软件按…

短视频矩阵系统源码-开源开发php语言搭建

短视频矩阵系统源码---------- php源码是什么&#xff1f; PHP源码指的就是PHP源代码&#xff0c;源代码是用特定编程语言编写的人类可读文本&#xff0c;源代码的目标是为可以转换为机器语言的计算机设置准确的规则和规范。因此&#xff0c;源代码是程序和网站的基础。 PHP…

【数据结构】插入排序详细图解(一看就懂)

&#x1f4af; 博客内容&#xff1a;【数据结构】插入排序详细图解&#xff08;一看就懂&#xff09; &#x1f600; 作  者&#xff1a;陈大大陈 &#x1f989;所属专栏&#xff1a;数据结构笔记 &#x1f680; 个人简介&#xff1a;一个正在努力学技术的准前端&#xff0c;…

ARM-FS6818-点亮LED灯

点亮LED灯 1.开发板介绍 2.cpu控制硬件原理 六大指令里边&#xff0c;只有内存访问指令能访问cpu之外的内容。那cpu如何控制硬件&#xff1f; *load/store指令-->操作4G内存 任何一个芯片都有一个地址映射表。告诉地址空间是如何映射的&#xff0c;便于我们找到对应的硬件地…