Vision Transformer(VIT)论文解读及实现

news2024/9/22 11:35:53

1 论文解读

paper:VIT

1.1 VIT模型架构如下图所示:

  • 图片原始输入维度 H * W * C
  • 在H和W按像素P切分,则H 、W可分割为 NPP, N=HW/(PP),N为输入transform序列的长度。
  • x ∈ R H ∗ W ∗ C = > x ∈ R N ∗ P 2 ∗ C x \in R^{H*W*C} => x\in R^{N*P^2*C} xRHWC=>xRNP2C
  • 固定每层的维度D不变,The Transformer uses constant latent vector size D through all of its layers, so we flatten the patches and map to D dimensions with a trainable linear projection
  • 在N序列长度的基础上,增加一个Class token,类似bert用于分类任务学习
  • 增加位置信息,使用拉长后的一维数据作为位置编码信息。(使用图片的二维坐标位置,模型效果没有明显改善)
    VIT模型架构

VIT模型公式

输入 x ∈ N ∗ p 2 ∗ C 输入 x \in N*p^2*C 输入xNp2C
x p 1 ∈ P 2 ∗ C x_p^1 \in P^2*C xp1P2C
E ∈ ( P 2 ∗ C ) ∗ D E \in (P^2*C) *D E(P2C)D
其中E对序列N中的每一个xi都是一样的,z0的维度为(N+1)* D
公式(2)MSA(多头注意力)不改变z0的维度
公式(3)经过MLP层后与原始z相加,类似残差网络
公式(4)只取z的第一个值(之前在第一个位置手动添加了一个class标识)用于分类任务,进行模型学习
在这里插入图片描述

2 代码实现

2.1 embedding 层

  • 模型输入x.shape=[16,3,224,224] #16为batch_size
  • x输入patch_embedding 后,shape =[16,768,14,14]
  • 将上面的patch_embedding最后两位(H,W)拉平后,与channel调换位置,shape=[16,196,768]
  • 然后与手动的cls_token拼接 shape=[16,197,768]
  • 加入位置信息后,即可得到embdeeing的输出,shape=[16,197,768]
self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=16,
                                       stride=16)
  • cls_token shape=[1,1,768]
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
#备注:n_patches=14*14   ,config.hidden_size=768
self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))

2.2 block层

  • 输入为Embedding层输入的x ;shape=[16,197,768]
  • 通过layer_norm层,,shape不变
  • 通过attn层,构建多头注意力,query,key,value的shape都为shape=[16,12,197,64]
  • 加上原始的x,纪委multi-head的输出,shape=[16,197,768]
  • 再经过layer_norm和全连接层,加上上层x,即为block的输出,shape=[16,197,768]

layer_norm层

  self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)

2.3Encoder层

经过L个Block层,输出结果即为encoder层,shape=[16,197,768]

2.4 模型输出

  • transform最后的输出层为 shape=[16,197,768]
  • 取序列197的第一个作为输出x,x shape=[16,768]
  • 输出x,经过全连接层,shape=[16,num_class]
  • 模型loss为交叉熵损失

3 transformer 结构

  (embeddings): Embeddings(
    (patch_embeddings): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): Encoder(
    (layer): ModuleList(
      (0): Block(
        (attention_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (ffn_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (ffn): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (attn): Attention(
          (query): Linear(in_features=768, out_features=768, bias=True)
          (key): Linear(in_features=768, out_features=768, bias=True)
          (value): Linear(in_features=768, out_features=768, bias=True)
          (out): Linear(in_features=768, out_features=768, bias=True)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (proj_dropout): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
      )
... 省略10层Block
      (11): Block(
        (attention_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (ffn_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (ffn): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (attn): Attention(
          (query): Linear(in_features=768, out_features=768, bias=True)
          (key): Linear(in_features=768, out_features=768, bias=True)
          (value): Linear(in_features=768, out_features=768, bias=True)
          (out): Linear(in_features=768, out_features=768, bias=True)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (proj_dropout): Dropout(p=0.0, inplace=False)
          (softmax): Softmax(dim=-1)
        )
      )
    )
    (encoder_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  )
)

3 代码总览

3.1 Embedding类

class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.
    """
    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            n_patches = (img_size[0] // 16) * (img_size[1] // 16)
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
                                         width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        print(x.shape)
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)
        print(cls_tokens.shape)
        if self.hybrid:
            x = self.hybrid_model(x)
        x = self.patch_embeddings(x)
        print(x.shape)
        x = x.flatten(2)
        print(x.shape)
        x = x.transpose(-1, -2)
        print(x.shape)
        x = torch.cat((cls_tokens, x), dim=1)
        print(x.shape)

        embeddings = x + self.position_embeddings
        print(embeddings.shape)
        embeddings = self.dropout(embeddings)
        print(embeddings.shape)
        return embeddings

3.2 Block层

class Block(nn.Module):
def init(self, config, vis):
super(Block, self).init()
self.hidden_size = config.hidden_size
self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
self.ffn = Mlp(config)
self.attn = Attention(config, vis)

def forward(self, x):
    print(x.shape)
    h = x
    x = self.attention_norm(x)
    print(x.shape)
    x, weights = self.attn(x)
    x = x + h
    print(x.shape)

    h = x
    x = self.ffn_norm(x)
    print(x.shape)
    x = self.ffn(x)
    print(x.shape)
    x = x + h
    print(x.shape)
    return x, weights

3 encoder层

class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))
 
    def forward(self, hidden_states):
        print(hidden_states.shape)
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights

attention 层

class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        print(new_x_shape)
        x = x.view(*new_x_shape)
        print(x.shape)
        print(x.permute(0, 2, 1, 3).shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        print(hidden_states.shape)
        mixed_query_layer = self.query(hidden_states)
        print(mixed_query_layer.shape)
        mixed_key_layer = self.key(hidden_states)
        print(mixed_key_layer.shape)
        mixed_value_layer = self.value(hidden_states)
        print(mixed_value_layer.shape)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        print(query_layer.shape)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        print(key_layer.shape)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        print(value_layer.shape)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        print(attention_scores.shape)
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        print(attention_scores.shape)
        attention_probs = self.softmax(attention_scores)
        print(attention_probs.shape)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)
        print(attention_probs.shape)

        context_layer = torch.matmul(attention_probs, value_layer)
        print(context_layer.shape)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        print(context_layer.shape)
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        print(context_layer.shape)
        attention_output = self.out(context_layer)
        print(attention_output.shape)
        attention_output = self.proj_dropout(attention_output)
        print(attention_output.shape)
        return attention_output, weights

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/732956.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第三章 SSD存储介质:闪存 3.1

3.1 闪存物理结构 闪存芯片从小到大依此是由:cell(单元)、page(页)、block(块)、plane(平面)、die(核心)、NAND flash(闪存芯片&#…

Python find()函数使用详解

「作者主页」:士别三日wyx 「作者简介」:CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者 「推荐专栏」:小白零基础《Python入门到精通》 find 1、指定检索位置2、参数为负数3、超出范围3、find()和index()的区别&#x…

【Docker】Docker安装MySQL

🚀欢迎来到本文🚀 🍉个人简介:陈童学哦,目前专攻C/C、Python、Java等方向,一个正在慢慢前行的普通人。 🏀系列专栏:陈童学的日记 💡其他专栏:CSTL、蓝桥杯&am…

Win11系统如何安装Oracle数据库(超级详细)

前言:在我们安装Oracle之前我们得理解Oracle数据库的优点是什么: Oracle是一个功能强大、可扩展和全面的数据库平台,具有广泛的功能和企业级能力,适用于处理复杂的企业级应用和大型数据集。 目录 一.下载Oracle数据库软件&…

解决idea只能通过 idea.bat打开的问题

解决:C盘用户下面 有idea的配置文件 ,找到idea64.exe.vmoptions 把 -jetbrain : 配置的 jar路径删除

Tablet vs. eReader: Which Is Better for Ebooks? 平板电脑与电子阅读器:哪个更适合电子书?

eReaders are best if all you want to do is have something as close to a paper book as possible. However, if you need anything more than that, a tablet makes more sense as a general-purpose device that can also read ebooks. 如果您只想拥有尽可能接近纸质书的东…

认识文件操作与IO

文章目录 认识文件文件夹文件路径文件分类 文件操作File类构造方法常用方法 字节流IOInputStream常用方法 FileInputStream构造方法FileInputStream实例 OutputStream方法 FileOutputStream 字符流IO 认识文件 我们平时所说的文件指的是存在硬盘上的文件,我们平时的…

Openlayers实战:回显多点、多线段、多多边形

Openlayers地图中,回显数据是非常重要的。 继上一示例回显点、线、圆形、多边形后。本示例回显多线,多点,多个多边形。用到了MultiPoint,MultiLineString,MultiPolygon。 多个信息的显示可以采用循环的方式,单个显示点、线、面。 但是循环方式是要多次计算的,而MultiPoint…

GUI (java)

GUI 一.GUI概念二.Swing概述三.容器组件四.常用容器1.窗体(1) JFrame类的构造方法(2) JFrame类的常用方法 2.面板(1)JPanel类的构造方法(2)JPanel类的常用方法 五.布局管理器1. FlowLayout 流式布局(1)FlowLayout构造方法 2.BorderLayout 边界布局3.GridLayout 网格布局 六.常用…

LangChain: 大语言模型的新篇章

本文介绍了LangChain框架,它能够将大型语言模型与其他计算或知识来源相结合,从而实现功能更加强大的应用。接着,对LangChain的关键概念进行了详细说明,并基于该框架进行了一些案例尝试,旨在帮助读者更轻松地理解LangCh…

Mycat【Mycat高级特性_搭建双主双从、Mycat分片技术_垂直拆分-分库 】(四)-全面详解(学习总结---从入门到深化)

目录 Mycat高级特性_搭建双主双从 Mycat分片技术_垂直拆分-分库 Mycat高级特性_搭建双主双从 环境准备 创建docker容器 #启动第一台 docker run -d -p 3350:3306 -e MYSQL_ROOT_PASSWORD123456 --namemaster1 mysql:5.7#启动第二台 docker run -d -p 3360:3306 -e MYSQL_R…

Qt自定义控件之动画文本

文章目录 前言一、动画文本的效果二、具体实现定义动画对象设置动画时长的实现设置text函数实现绘制代码设置字体函数 三、高级部分操作代码总结 前言 在 Qt 中,自定义控件可以让我们实现丰富的用户界面效果和交互体验。其中,动画文本是一种常见的效果&…

电路分析基础学习(上)第4章

李瀚荪版电分第二版 ----------------------------------------------------------------------------------------------------------------------------- 求单口网络的VCR 两大基本方法: 1.外接电流源求电压; 2.外接电压源求电流; ---…

Netty序列化算法参数调优

目录 一、扩展序列化算法 1、Java 2、Json 二、参数调优 1、CONNECT_TIMEOUT_MILLIS 2、SO_BACKLOG 3、ulimit-n 4、TCP_NODELAY 5、SO_SNDBUF & SO_RCVBUF 6、ALLOCATOR 7、RCVBUF_ALLOCATOR 一、扩展序列化算法 1、Java 我们先写Java中jdk的序列方式&#x…

RabbitMQ系列(28)--RabbitMQ使用Federation Queue(联邦队列)解决异地访问延迟问题

前言: 联邦队列可以在多个Broker节点(或者集群)之间为单个队列提供均衡负载的功能。一个联邦队列可以连接一个或者多个上游队列(upstream queue),并从这些上游队列中获取消息以满足本地消费者消费消息的需求。 1、Federation Queue工作原理图 2、添加策…

Oracle19c默认用户名system密码不正确不能登录问题解决

Oracle19c默认用户名system密码不正确不能登录问题解决 1、oracle 命令乱码问题 oracle乱码问题一般是由于oracle字符集设置和操作系统字符集设置不一致造成的。 查看oracle字符集方式如下: 1.进入sqlplus 命令: sqlplus /nolog2.以系统管理员身份连…

【Git原理与使用】-- 企业级开发模型

目录 引入 系统开发环境 Git 分支设计规范 master 分支 release 分支 develop 分支 feature 分支 hotfix 分支 开发场景 - 基于git flow模型的实践 DevOps研发平台 修复测试环境 Bug 修改预发布环境 Bug 修改正式环境 Bug 紧急修复正式环境 Bug 拓展实践 都说&a…

java的RSA加密解密示例

RSA算法是一种非对称加密算法,公钥和私钥都可以用于加密和解密操作。在RSA算法中,公钥用于加密数据,私钥用于解密数据。 具体来说,使用公钥加密的数据只能使用相应的私钥进行解密。而使用私钥加密的数据则可以使用相应的公钥进行…

【云原生|云计算系列】云计算基础概念

欢迎来到云原生专题的云计算系列第一篇博客,我们将探索云计算的基础知识,以帮助您深入了解这个迅速发展的领域。在前一篇博客中,我们介绍了云原生的概念和重要性,强调了它作为云计算的核心理念和实践的关键角色。本篇博客将进一步…

基于单片机智能水杯 保温杯 定时提醒喝水 温度控制的设计与实现

功能介绍 以51单片机作为主控系统;LCD1602液晶显示当前水温,定时提醒,水量变化DS18B20检测当前水体温度;水位传感器检测当前水位;继电器驱动加热片进行水温加热;定时提醒喝水,蜂鸣器报警&#x…