Vision Transformer结构解析

news2024/10/24 0:22:13

Vision Transformer结构解析

  • ViT简介
  • ViT三大模块
    • ViT图像预处理模块——PatchEmbed
    • 多层Transformer Encoder模块
    • MLP(FFN)模块
  • 基本的Transformer模块
  • Vision Transformer类的实现
  • Transformer知识点

ViT简介

Vision Transformer。transformer于2017年的Attention is all your need提出,该模型最大的创新点就是将transformer应用于cv任务。

论文题目:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
论文链接:https://arxiv.org/pdf/2010.11929.pdf
代码地址:https://github.com/google-research/vision_transformer

ViT模型整体结构图如下:
在这里插入图片描述

ViT三种不同尺寸模型的参数对比:

Panda

ViT三大模块

ViT主要包含三大模块:PatchEmbed、多层Transformer Encoder、MLP(FFN),下面用结构图和代码解析这第三大模块。

ViT图像预处理模块——PatchEmbed

VIT划分patches的原理:
输入图像尺寸(224x224x3),按16x16的大小进行划分,共(224x224) / (16x16) = 196个patches,每个patch的维度为(16x16x3),为满足Transformer的需求,对每个patch进行投影,[16, 16, 3]->[768],这样就将原始的[224, 224, 3]转化为[196, 768]。

Panda

代码实现如下:

class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding,二维图像patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)  # 图片尺寸224*224
        patch_size = (patch_size, patch_size)  #下采样倍数,一个grid cell包含了16*16的图片信息
        self.img_size = img_size
        self.patch_size = patch_size
        # grid_size是经过patchembed后的特征层的尺寸
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1] #path个数 14*14=196

        # 通过一个卷积,完成patchEmbed
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        # 如果使用了norm层,如BatchNorm2d,将通道数传入,以进行归一化,否则进行恒等映射
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape  #batch,channels,heigth,weigth
        # 输入图片的尺寸要满足既定的尺寸
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # proj: [B, C, H, W] -> [B, C, H,W] , [B,3,224,224]-> [B,768,14,14]
        # flatten: [B, C, H, W] -> [B, C, HW] , [B,768,14,14]-> [B,768,196]
        # transpose: [B, C, HW] -> [B, HW, C] , [B,768,196]-> [B,196,768]
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x

多层Transformer Encoder模块

该模块的主要结构是Muti-head Attention,也就是self-attention,它能够使得网络看到全局的信息,而不是CNN的局部感受野。

self-attention的结构示例如下:

Panda
class Attention(nn.Module):
    """
    muti-head attention模块,也是transformer最主要的操作
    """
    def __init__(self,
                 dim,   # 输入token的dim,768
                 num_heads=8, #muti-head的head个数,实例化时base尺寸的vit默认为12
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads  #平均每个head的维度
        self.scale = qk_scale or head_dim ** -0.5  #进行query操作时,缩放因子
        # qkv矩阵相乘操作,dim * 3使得一次性进行qkv操作
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim) 
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        # [batch_size, num_patches + 1, total_embed_dim] 如 [bactn,197,768]
        B, N, C = x.shape  # N:197 , C:768

        # qkv进行注意力操作,reshape进行muti-head的维度分配,permute维度调换以便后续操作
        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim] 如 [b,197,2304]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head] 如 [b,197,3,12,64]
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # qkv的维度相同,[batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = (q @ k.transpose(-2, -1)) * self.scale  #矩阵相乘操作
        attn = attn.softmax(dim=-1) #每一path进行softmax操作
        attn = self.attn_drop(attn)

        # [b,12,197,197]@[b,12,197,64] -> [b,12,197,64]
        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # 维度交换 transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)  #经过一层卷积
        x = self.proj_drop(x)  #Dropout
        return x

MLP(FFN)模块

一个MLP模块的结构如下:

Panda
class Mlp(nn.Module):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU,  # GELU是更加平滑的relu
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features  #如果out_features不存在,则为in_features
        hidden_features = hidden_features or in_features #如果hidden_features不存在,则为in_features
        self.fc1 = nn.Linear(in_features, hidden_features) # fc层1
        self.act = act_layer() #激活
        self.fc2 = nn.Linear(hidden_features, out_features)  # fc层2
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

基本的Transformer模块

由Self-attention和MLP可以组合成Transformer的基本模块。Transformer的基本模块还使用了残差连接结构。
一个Transformer Block的结构如下:

Panda
class Block(nn.Module):
    """
    基本的Transformer模块
    """
    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super(Block, self).__init__()
        self.norm1 = norm_layer(dim)  #norm层
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        # 代码使用了DropPath,而不是原版的dropout
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim) #norm层
        mlp_hidden_dim = int(dim * mlp_ratio)  #隐藏层维度扩张后的通道数
        # 多层感知机
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))  # attention后残差连接
        x = x + self.drop_path(self.mlp(self.norm2(x)))   # mlp后残差连接
        return x

Vision Transformer类的实现

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_c (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes  #分类类别数量
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1  #distilled在vit中没有使用到
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) #层归一化
        act_layer = act_layer or nn.GELU  #激活函数

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))  #[1,1,768],以0填充
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        # 按照block数量等间距设置drop率
        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)  # layer_norm

        # Representation layer
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s),分类头,self.num_features=768
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init,权重初始化
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768]
        # cls_token类别token [1, 1, 768] -> [B, 1, 768],扩张为batch个cls_token
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 196, 768]-> [B, 197, 768],维度1上的cat
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)

        x = self.pos_drop(x + self.pos_embed)  #添加位置嵌入信息
        x = self.blocks(x)  #通过attention堆叠模块(12个)
        x = self.norm(x)  #layer_norm
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])  #返回第一层特征,即为分类值
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        # 分类头
        x = self.forward_features(x) # 经过att操作,但是没有进行分类头的前传
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during inference, return the average of both classifier predictions
                return x, x_dist
            else:
                return (x + x_dist) / 2
        else:
            x = self.head(x)
        return x

Transformer知识点

论文:Attention Is All You Need
论文地址:https://arxiv.org/pdf/1706.03762.pdf

Transformer由Attention和Feed Forward Neural Network(也称FFN)组成,其中Attention包含self Attention与Mutil-Head Attention。

网络结构如下:

Panda

attention和multi-head-attention结构:

Panda

计算过程:

Panda

计算复杂度对比:

Panda

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1498106.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【操作系统概念】 第7章:死锁

文章目录 0.前言7.1 系统模型7.2 死锁特征7.2.1 必要条件7.2.2 资源分配图 7.3 死锁处理方法7.4 死锁预防(deadlock prevention)7.4.1 互斥7.4.2 占有并等待7.4.3 非抢占7.4.4 循环等待 7.5 死锁避免(deadlock-avoidance)7.5.1 安…

银行数字化转型导师坚鹏:银行数字化转型案例研究

银行数字化转型案例研究 课程背景: 数字化背景下,很多银行存在以下问题: 不清楚银行科技金融数智化案例? 不清楚银行供应链金融数智化案例? 不清楚银行普惠金融数智化案例? 不清楚银行跨境金融数智…

Visual Studio如何进行类文件的管理(类文件的分离)

大家好: 衷心希望各位点赞。 您的问题请留在评论区,我会及时回答。 一、问题背景 实际开发中,类的声明放在头文件中,给程序员看类的成员和方法。比如:Dog.h(类的声明文件) 类的成员函数的具体…

[LeetCode][239]【学习日记】滑动窗口最大值——O(n)单调队列

题目 239. 滑动窗口最大值 难度:困难相关标签相关企业提示 给你一个整数数组 nums,有一个大小为 k 的滑动窗口从数组的最左侧移动到数组的最右侧。你只可以看到在滑动窗口内的 k 个数字。滑动窗口每次只向右移动一位。 返回滑动窗口中的最大值。 示例 1…

devc++8x8取模软件

这几天在搞arduino nano和单个max7219模块,涉及到16进制的取模,在网上转了一圈,没找到合适的取模软件,于是自己做了一个,试过,可以用,按esc退出并生成16进制的取模结果 源代码: #i…

Unity 动画(旧版-新版)

旧版 旧版-动画组件:Animation 窗口-动画 动画文件后缀: .anim 将制作后的动画拖动到Animation组件上 旧版的操作 using System.Collections; using System.Collections.Generic; using UnityEngine;public class c1 : MonoBehaviour {// Start is called before…

Latex公式太长换行标号

Latex中公式太长换行,且编号,可以采用align,不编号行公式用\nonumber,示例如下: \begin{align}\nonumber %第1行公式不编号&a+b+a+b+a+b+a+b+a+b+a+b+a+b+a\\&+c+d=m %第2行公式编号 \end{align}效果如下 原文件链接 公式不同命令的区别 \begin{align} 与 \…

信号处理--卷积残差网络实现单通道脑电的睡眠分期监测

目录 背景 亮点 环境配置 数据 方法 结果 代码获取 参考文献 背景 人类大约花三分之一的时间睡觉,这使得监视睡眠成为幸福感的组成部分。 在本文中,提出了用于端到端睡眠阶段的34层深残留的Convnet架构 亮点 使用深度1D CNN残差架构&#xff0…

高并发服务器模型

高并发服务器模型 1.高并发服务器模型--select2.高并发服务器模型--poll3.epoll模型3.1 epoll原理3.2epoll反应堆 1.高并发服务器模型–select 我们知道实现服务器的高并发,可以用多线程或多进程去实现。但还可以利用多路IO技术:select来实现,它可以同时…

【框架学习 | 第二篇】暴打MyBatis-Plus——MyBatis的升级版本

教程来源链接:https://www.quanxiaoha.com/mybatis-plus/mybatis-plus-tutorial.html 教程作者:犬小哈 文章目录 1.Mybatis Plus介绍1.1Mybatis和Mybatis Plus的区别是什么1.1.1什么是Mybatis?1.1.2区分Mybatis Plus和Mybatis 1.2Mybatis Plus特点1.3支…

【C语言】终の指针(前篇)

个人主页点这里~ 指针初阶点这里~ 指针初阶2.0点这里~ 指针进阶点这里~ 终の指针 一、回调函数二、qsort函数1、整形比较2、结构数据比较①结构体②-> 的使用③结构数据比较 一、回调函数 回调函数就是⼀个通过函数指针调用的函数。 把一个函数的指针作为参数传递给另一…

勾股定理的七种经典证明

据说勾股定理约有500种证明方法,下面介绍几种经典的证明方法。 一、切割重拼法。 顾名思义,就是将图形切割成其他形式的图形,然后通过拼图转换为另一种图形,这个过程中图形的面积是不变的。 “赵爽弦图”是这种方法的经典应用&…

Mysql案例之GROUP_CONCAT函数详解

Hello,大家好,我是灰小猿,一个超会写bug的程序员! 今天这篇文章记录一个最近开发中遇到的mysql实战场景,觉得还挺典型的,就在此做一下记录。 先看一下举例场景: mysql中学生表与学科表通过关…

Linux设备模型(九) - bus/device/device_driver/class

一,设备驱动模型 1,概述 在前面写的驱动中,我们发现编写驱动有个固定的模式只有往里面套代码就可以了,它们之间的大致流程可以总结如下: 实现入口函数xxx_init()和卸载函数xxx_exit() 申请设备号 register_chrdev_r…

首发:鸿蒙面试真题分享【独此一份】

最早在23年华为秋季发布会中,就已经宣布了“纯血鸿蒙”。而目前鸿蒙处于星河版中,加速了各大互联网厂商的合作。目前已经有200参与鸿蒙的原生应用开发当中。对此各大招聘网站上的鸿蒙开发需求,每日都在增长中。 2024大厂面试真题 目前的鸿蒙…

OpenHarmony教程指南—ArkUI中组件、通用、动画、全局方法的集合

介绍 本示例为ArkUI中组件、通用、动画、全局方法的集合。 本示例使用 Tabs容器组件搭建整体应用框架,每个 TabContent内容视图 使用 div容器组件 嵌套布局,在每个 div 中使用 循环渲染 加载此分类下分类导航数据,底部导航菜单使用 TabCont…

LeetCode 2917.找出数组中的 K-or 值:基础位运算

【LetMeFly】2917.找出数组中的 K-or 值:基础位运算 力扣题目链接:https://leetcode.cn/problems/find-the-k-or-of-an-array/ 给你一个下标从 0 开始的整数数组 nums 和一个整数 k 。 nums 中的 K-or 是一个满足以下条件的非负整数: 只有…

如何合理布局子图--确定MATLAB的subplot子图位置参数

确定MATLAB的subplot子图位置参数 目录 确定MATLAB的subplot子图位置参数摘要1. 问题描述2. 计算过程2.1 确定子图的大小和间距2.2 计算合适的figure大小2.3 计算每个子图的position数据 3. MATLAB代码实现3.1 MATLAB代码3.2 绘图结果 4. 总结 摘要 在MATLAB中,使用…

网络编程套接字(1)—网络编程基础

目录 一、为什么需要网络编程? 二、什么是网络编程 三、网络编程中的基本概念 1、发送端和接收端 2、请求和响应 3、客户端和服务端 四、常见的客户端服务端模型 1、一问一答模型 2、一问多答模型 3、多问一答模型 4、多问多答模型 一、为什么需要网络编程? 为什么…

(二十二)从零开始搭建k8s集群——高可用kubernates集群搭建上篇

前言 本节内容分为上、中、下三篇,上篇主要是关于搭建k8s的基础环境,包括服务器基本环境的配置(网络、端口、主机名、防火墙、交换分区、文件句柄数等)、docker环境部署安装配置、镜像源配置等。中篇会介绍k8s的核心组件安装、k8…