【论文及代码详解】BEIT: BERT Pre-Training of Image Transformers

news2024/9/20 14:56:02

记录下论文《BEIT: BERT Pre-Training of Image Transformers》,这是一篇将Transformer应用于图像领域,并使用自监督方法进行参数初始化的文章。

论文链接

整体概要

在这里插入图片描述

由于网络整体流程图没有标注好模型的运行过程,结合论文的描述:第一阶段为训练一个变分自动编码器,即图中的 Encoder和Decoder部分,然后保留自动编码器的Encoder部分,即图中的Tokenizer,用于生成文中提及的Token;第二个阶段为自监督训练过程,将图像分块,然后进行一定比例的随机mask(文中说遮挡40%的图像内容),然后将经过遮挡处理的图像送入Transformer进行学习,学习的目标就是完整图像送入Tokenizer得到的特征图,即让Transformer拥有一定的特征学习能力(本质上是得到具有一定学习能力的Transformer模型,即相当于对模型参数进行了合理的初始化);第三个阶段会将经过预训练后的Transformer模型进行下游任务的学习(图像分类、图像分割等),在此过程中,仅用添加下游任务相关的网络头部即可,也就是仅微调网络头部参数。

方法产生原因及存在的问题

1 文中没有直接使用Tranformer对masked图像的像素进行重构/恢复,作者认为这将会另Transformer模型更倾向关注短距离相关性(即过度关注局部特征,限制全局特征的学习)或高频去噪能力(因为模型中经过掩码后,与周围的图像形成强烈的反差或明显边缘,类似加了高频信息)。而我们一般使用深度神经网络模型一般是希望模型能够学习到抽象的特征,而不是针对某种特定任务的模型,所以这里使用自动编码器学习到的特征作为目标,训练Transformer,让其在输入有遮挡的情况下,仍然能学习到目标特征,一定程度上可认为经过预训练后的Transformer为更强大的编码器,能够对输入生成抽象的高维特征。
2 个人认为存在的问题在于文中在自监督预训练过程中,使用了随机遮挡的方法,文中提到是大约40%的图像内容被遮挡,作者并未讨论为什么大约40%内容被遮挡?此外作者在预训练模型过程中,输入端使用了一个特殊的Token [S] 并未解释该Token的用途,可能是进一步扩展模型学习的余量?

官方关键代码

官方代码链接,可以使用 git 或 svn 进行下载,比较快。
在这里插入图片描述
在这里插入图片描述
这个代码里其实看起来有些乱,但是文件命名整体上还相对较规范,可以说明一些问题。下面将结合和整体的模型相关的关键代码来说一下。

1 modeling_pretrain.py

这个文件中基本包含了模型相关的代码,整体上可以辅助我们理解模型的整体运行流程。整体上包含:PatchEmbed -> Relative position encoding -> Transformer Block

import math
import torch
import torch.nn as nn
from functools import partial

from modeling_finetune import Block, _cfg, PatchEmbed, RelativePositionBias
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_ as __call_trunc_normal_


def trunc_normal_(tensor, mean=0., std=1.):
    __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)


__all__ = [
    'beit_base_patch16_224_8k_vocab', 
    'beit_large_patch16_224_8k_vocab', 
]


class VisionTransformerForMaskedImageModeling(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, vocab_size=8192, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=None, init_values=None, attn_head_dim=None,
                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, init_std=0.02, **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
                attn_head_dim=attn_head_dim,
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        self.init_std = init_std
        self.lm_head = nn.Linear(embed_dim, vocab_size)

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=self.init_std)
        trunc_normal_(self.cls_token, std=self.init_std)
        trunc_normal_(self.mask_token, std=self.init_std)
        trunc_normal_(self.lm_head.weight, std=self.init_std)
        self.apply(self._init_weights)
        self.fix_init_weight()

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_num_layers(self):
        return len(self.blocks)

    def forward_features(self, x, bool_masked_pos):
        x = self.patch_embed(x, bool_masked_pos=bool_masked_pos)
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        mask_token = self.mask_token.expand(batch_size, seq_len, -1)

        # replace the masked visual tokens by mask_token
        w = bool_masked_pos.unsqueeze(-1).type_as(mask_token)
        x = x * (1 - w) + mask_token * w

        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            x = blk(x, rel_pos_bias=rel_pos_bias)

        return self.norm(x)

    def forward(self, x, bool_masked_pos, return_all_tokens=False):
        x = self.forward_features(x, bool_masked_pos=bool_masked_pos)
        x = x[:, 1:]
        if return_all_tokens:
            return self.lm_head(x)
        else:
            # return the masked tokens
            return self.lm_head(x[bool_masked_pos])


@register_model
def beit_base_patch16_224_8k_vocab(pretrained=False, **kwargs):
    model = VisionTransformerForMaskedImageModeling(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), vocab_size=8192, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.load(
            kwargs["init_ckpt"], map_location="cpu"
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def beit_large_patch16_224_8k_vocab(pretrained=False, **kwargs):
    model = VisionTransformerForMaskedImageModeling(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), vocab_size=8192, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.load(
            kwargs["init_ckpt"], map_location="cpu"
        )
        model.load_state_dict(checkpoint["model"])
    return model

2 PatchEmbed 方法,该方法在 modeling_finetune.py 中。从代码中可以看出,假设输入为224x224的话,每个patch的尺寸为 16x16,总共包含的 num_patches 为 14x14,最后模型通过一层核大小为 16,步长为16的卷积得到维度为 (768,14,14)的嵌入层,即将每个 16x16的patch映射为了14x14的768维向量。
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
3 RelativePositionBias相对位置编码,这个相对位置编码与SwinTransformer中的不太一样,其实就是比Swin中的相符位置编码在长和宽上增加了一维数据,即文中提及的 特殊Token [S] 的表示。
class RelativePositionBias(nn.Module):

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = \
            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index)

        # trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self):
        relative_position_bias = \
            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
4 代码中包含有Mask的操作,我们可以看下 masking_generator.py 这段代码,下面这段代码其实就是文中关于Mask部分的描述,文中说的 Blockwise mask 通过下面的算法伪代码,可以发现,该操作有最低的块数限制:16-0.4*N 块。

结合说下Mask的实现逻辑:首先确定 num_patchs 为原图划分patch后的数量,然后设置最小块数min_num_patches和最大块数max_mask_patches;然后从这个区间内随机获取需要mask的块数,当然在变化之前,还对随机选取得块数(target_area)进行了横纵比的变化,得到mask的 h和w,最后在原图中对 h,w的区域进行 mask 。此外代码中还对 mask 操作进行了限制,当目标区域全部mask为1后,就循环结束,代码中进行了10次的循环尝试,只要该过程中有mask则退出 mask的生成。
Mask代码是在数据集处理的部分调用的

在这里插入图片描述

import random
import math
import numpy as np
class MaskingGenerator:
    def __init__(
            self, input_size, num_masking_patches, min_num_patches=4, max_num_patches=None,
            min_aspect=0.3, max_aspect=None):
        if not isinstance(input_size, tuple):
            input_size = (input_size, ) * 2
        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches

        max_aspect = max_aspect or 1 / min_aspect
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    def __repr__(self):
        repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height, self.width, self.min_num_patches, self.max_num_patches,
            self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])
        return repr_str

    def get_shape(self):
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        delta = 0
        for attempt in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                num_masked = mask[top: top + h, left: left + w].sum()
                # Overlap
                if 0 < h * w - num_masked <= max_mask_patches:
                    for i in range(top, top + h):
                        for j in range(left, left + w):
                            if mask[i, j] == 0:
                                mask[i, j] = 1
                                delta += 1

                if delta > 0:
                    break
        return delta

    def __call__(self):
        mask = np.zeros(shape=self.get_shape(), dtype=np.int)
        mask_count = 0
        while mask_count < self.num_masking_patches:
            max_mask_patches = self.num_masking_patches - mask_count
            max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            else:
                mask_count += delta

        return mask

总结

以上为BEiT论文我的一些感受,并结合关键代码进行了说明,希望可以帮到您!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/358198.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

收藏,核心期刊的投稿、审稿、出刊流程详解

学术期刊论文&#xff08;核心和普刊&#xff09;的发表流程总的来说其实是一样的&#xff0c;整个流程包括&#xff1a;1写作-2选择刊物-3投稿-4审稿-5返修或拒稿-6录用-7出刊-8上网检索。 其中1和2其实顺序是可以调换的&#xff0c;可以选择好刊物再写作&#xff0c;根据刊物…

麦克风阵列波束基本概念理解

波束形成 本质上是设计合适的滤波器&#xff0c;对于一类固定滤波器系数的阵列来说&#xff0c;无论输入信号或者噪声信号的统计特征如何&#xff0c;其滤波器系数固定不变&#xff0c;此类波束形成叫Fixed Beamforming&#xff0c;固定波束形成好比传统数字信号处理里面的经典…

TCP并发服务器(多进程与多线程)

欢迎关注博主 Mindtechnist 或加入【Linux C/C/Python社区】一起探讨和分享Linux C/C/Python/Shell编程、机器人技术、机器学习、机器视觉、嵌入式AI相关领域的知识和技术。 TCP并发服务器&#xff08;多进程与多线程&#xff09;1. 多进程并发服务器&#xff08;1&#xff09;…

NPDP认证|2023年,0基础转行产品经理可以吗?

2023年&#xff0c;告别了疫情&#xff0c;各个行业正在快速回暖&#xff0c;很多企业都在高薪招聘产品经理岗位&#xff0c;这让很多其他岗位的朋友也想转行做产品经理&#xff0c;那没有基础&#xff0c;没有经验能转行做产品经理吗&#xff1f; 0基础转行产品经理是可能的&a…

Redis 删除策略

过期数据Redis中的数据特征 Redis是一种内存级数据库&#xff0c;所有数据均存放在内存中&#xff0c;内存中的数据可以通过TTL指令获取其状态XX &#xff1a;具有时效性的数据-1 &#xff1a;永久有效的数据-2 &#xff1a;已经过期的数据 或 被删除的数据 或 未定义的数据数…

Windows出现0xc00d36e5错误怎么办?

当我们使用Windows Media Player来播放视频文件时&#xff0c;可能会出现无法播放&#xff0c;并显示0xc00d36e5错误代码。该错误可能是因为Windows Media Player不支持视频格式、注册表项损坏、系统配置问题、第三方应用程序冲突等。下面将开始介绍0xc00d36e5错误的解决方法&a…

K_A12_014 基于STM32等单片机驱动S12SD紫外线传感器模块 串口与OLED0.96双显示

K_A12_014 基于STM32等单片机驱动S12SD紫外线传感器模块 串口与OLED0.96双显示一、资源说明二、基本参数参数引脚说明三、驱动说明IIC地址/采集通道选择/时序对应程序:数据对比&#xff1a;四、部分代码说明1、接线引脚定义1.1、STC89C52RCS12SD紫外线传感器模块1.2、STM32F103…

【数据结构】P0 三要素与五个特征

三要素与五个特征什么是数据结构数据结构的三要素逻辑结构存储结构数据的运算算法的五个特征时间复杂度什么是数据结构 数据元素之间存在着一种或者多种关系&#xff0c;这种关系称为结构。 数据结构的三要素 数据结构的三要素&#xff1a;逻辑结构、存储结构、数据的运算。 …

【面试题】手写防抖和节流

1. 手写防抖 debounce 首先介绍一个防抖的应用场景。假如需要监听一个输入框在输入文字后触发的change事件&#xff0c;那么通过keyup事件&#xff0c;每次输入文字后都会触发change事件&#xff0c;频繁触发的情况会影响系统的性能。因此可以使用防抖来降低触发频率&#xff…

flutter系列之:在flutter中使用导航Navigator

文章目录简介flutter中的NavigatorNavigator的使用总结简介 一个APP如果没有页面跳转那么是没有灵魂的&#xff0c;页面跳转的一个常用说法就是Navigator,flutter作为一个最为优秀的前端框架&#xff0c;Navigator肯定是必不可少的&#xff0c;那么在flutter中如何使用Navigat…

自建Git服务器

Gitea - Git with a cup of tea是一个国外团队基于国内一位大牛写的gogs开源项目&#xff08;Go语言开发&#xff09;二次开发的轻量Git社区&#xff0c;其稳定性非常好&#xff0c;而且是非常轻量级在个人亲测在1核1G的centos7主机上1个月不重启依然稳定运行&#xff0c;引用g…

chatgpt怎么去玩?解析各种用途和强大的功能

关于chatgpt怎么玩&#xff1f;他的一些原理以及玩法&#xff0c;相信大家都是挺好奇的吧&#xff0c;毕竟这个新的人工智能和以往我们玩过的&#xff0c;是完全不一样的&#xff0c;它具备更多的可玩性&#xff0c;而且具备有一定的学习能力&#xff0c;这就造就了它的强大&am…

记一次IDE的Docker插件实战(Dockfile篇)

IDEA下使用Docker插件制作镜像、推送及运行 前言 本部分主要根据IDEA的Docker插件实战(Dockerfile篇)_程序员欣宸的博客-CSDN博客_idea编写dockerfile一文所述内容进行实践&#xff0c;并对其中遇到的问题进行解答&#xff0c;从而串接多个知识点。 如何编写Dockfile 在Int…

【elasticsearch】elasticsearch es读写原理

一、前言&#xff1a; 今天来学习下 es 的写入原理。 Elasticsearch底层使用Lucene来实现doc的读写操作&#xff1a; Luence 存在的问题&#xff1a; 没有并发设计 lucene只是一个搜索引擎库&#xff0c;并没有涉及到分布式相关的设计&#xff0c;因此要想使用Lucene来处理海量…

「可信计算」与软件行为学

可信计算组织&#xff08;Ttrusted Computing Group,TCG&#xff09;是一个非盈利的工业标准组织&#xff0c;它的宗旨是加强在相异计算机平台上的计算环境的安全性。TCG于2003年春成立&#xff0c;并采纳了由可信计算平台联盟&#xff08;the Trusted Computing Platform Alli…

亮个相吧小宝贝儿,五款压箱底的软件

今天要给大家推荐5款压箱底的宝贝软件了&#xff0c;百度搜索一下就能找到下载链接了。 1.开源浏览器——Firefox Firefox是一个自由的&#xff0c;开放源码的浏览器&#xff0c;适用于 Windows, Linux 和 MacOS X平台&#xff0c;Mozilla Firefox官方版体积小速度快&#xf…

【项目】Vue3+TS CMS 登录模块搭建

&#x1f4ad;&#x1f4ad; ✨&#xff1a;Vue3 TS   &#x1f49f;&#xff1a;东非不开森的主页   &#x1f49c;: keep going&#x1f49c;&#x1f49c;   &#x1f338;: 如有错误或不足之处&#xff0c;希望可以指正&#xff0c;非常感谢&#x1f609;   Vue3TS一、…

微服务面试题:熔断和降级有什么区别?

文章目录引言1.概念不同1.1 熔断概念1.2 降级概念2.熔断器模型3.种状态之间的转换关系4.熔断策略5.熔断和降级的关系6.降级方式6.1、熔断降级&#xff08;不可用&#xff09;6.2、超时降级6.3、限流降级7.题外话8.总结引言 熔断和降级都是系统自我保护的一种机制&#xff0c;但…

进阶C语言 第四章-------《自定义类型》 (结构体、枚举、联合)知识点+完整思维导图+深入细节+通俗易懂+基本练习题+建议收藏

绪论 书接上回&#xff0c;通过上章的一些函数&#xff0c;我们可以让我们对于一些数值的调整有很大的帮助&#xff0c;本章自定义类型在C语言中同样也有着非常重要的地位&#xff0c;相信只要认真的阅读了本章&#xff0c;一定会对你有很大的帮助。 所以安全带系好&#xff0c…

使用Cmake从源码编译Lua

前置要求&#xff1a;电脑已经设置好了Cmake能够使用 首先下载Lua源码&#xff0c;文件后缀是tar.gz 各版本可以从这里找到&#xff1a;Lua - Version history 解压下载文件至所需目录&#xff0c;文件内容如下图&#xff1a; 解压即可。 在解压的文件夹&#xff08;本例是lua…