Visual Transformer开端——ViT及其代码实现

news2025/1/4 15:45:05

深度学习知识点总结

专栏链接:
https://blog.csdn.net/qq_39707285/article/details/124005405

此专栏主要总结深度学习中的知识点,从各大数据集比赛开始,介绍历年冠军算法;同时总结深度学习中重要的知识点,包括损失函数、优化器、各种经典算法、各种算法的优化策略Bag of Freebies (BoF)等。


从RNN到Attention到Transformer系列

专栏链接:
https://blog.csdn.net/qq_39707285/category_11814303.html

此专栏主要介绍RNN、LSTM、Attention、Transformer及其代码实现。


YOLO系列目标检测算法

专栏链接:
https://blog.csdn.net/qq_39707285/category_12009356.html

此专栏详细介绍YOLO系列算法,包括官方的YOLOv1、YOLOv2、YOLOv3、YOLOv4、Scaled-YOLOv4、YOLOv7,和YOLOv5,以及美团的YOLOv6,还有PaddlePaddle的PP-YOLO、PP-YOLOv2等,还有YOLOR、YOLOX、YOLOS等。


Visual Transformer

专栏链接:
https://blog.csdn.net/qq_39707285/category_12184436.html

此专栏详细介绍各种Visual Transformer,包括应用到分类、检测和分割的多种算法。


本章目录

  • 1. 简介
  • 2. 模型
    • 2.1 输入图片2D转1D
    • 2.2 [class] token
    • 2.3 位置嵌入
    • 2.4 Inductive bias
    • 2.5 Hybrid Architecture
  • 3. 代码实现
    • 3.1 定义参数
    • 3.2 图片编码
    • 3.3 加入class token
    • 3.4 位置编码
    • 3.5 Transformer
    • 3.6 MLP层
    • 3.7 总体代码


ViT
ViT是Visual Transformer的开端之作,第一次应用Transformer到CV领域。论文:《AN IMAGE IS WORTH 16X16 WORDS:TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE》。

1. 简介

  虽然Transformer架构已成为自然语言处理任务的基本标准,但其在计算机视觉中的应用仍然有限。在视觉上,注意力要么与卷积网络结合使用,要么用于替换卷积网络的某些组件,同时保持其整体结构不变。本文表明,这种对神经网络的依赖是不必要的,直接应用图像patch序列的纯Transformer可以很好地执行图像分类任务。当对大量数据进行预训练并将其迁移到多个中型或小型图像识别基准(ImageNet、CIFAR-100、VTAB等)时,与最先进的卷积网络相比,Visual Transformer(ViT)获得了优异的结果,同时训练所需的计算资源大大减少。

2. 模型

  模型总体图如下所示:
在这里插入图片描述
模型总体图。将图像分割成固定大小的patchs,线性嵌入每个patch,添加位置嵌入,并将生成的矢量序列提供给标准的Transformer编码器。为了进行分类,在序列中添加额外可学习的“分类标记class token”。

2.1 输入图片2D转1D

  标准的Transformer输入的是1D的序列,为了处理2D的图片,把图片x-(H×W×C)reshape成一系列拉平的2D patchs x p s h a p e : ( N × ( P 2 ⋅ C ) ) x_p shape:(N×(P^2·C)) xpshape:(N×(P2C)),其中 H 、 W H、W HW是原始图片的高和宽, C C C是通道数, ( P , P ) (P,P) (P,P)是每一个图片patch的分辨率, N = H W / P 2 N=HW/P^2 N=HW/P2是最终的patchs的总数,也是Transformer输入序列的长度。Transformer在其所有层中使用固定大小的向量D,因此需要将patch拉平,并使用可训练的线性投影映射到D维度(公式1)。将此投影的输出称为patch embeddings(patch嵌入)。

在这里插入图片描述

2.2 [class] token

  和BERT的[class]token类似,在嵌入patch序列 ( z 0 0 = x c l a s s ) (z^0_0=x_{class}) z00=xclass中添加了一个可学习的嵌入,其在Transformer编码器输出端的状态 ( z L 0 ) (z^0_L) zL0用作图像的表示y(公式4)。在预训练和微调期间, z L 0 z^0_L zL0上都安装了一个分类头。分类头在预训练时由具有一个隐藏层的MLP实现,在微调时由单个线性层实现。

2.3 位置嵌入

  位置嵌入被添加到patch嵌入以保留位置信息。本文使用标准的可学习1D位置嵌入,因为没有观察到使用更高级的2D感知位置嵌入带来的显著性能提高。添加后所得的嵌入向量序列用作编码器的输入。
Transformer编码器由multi-head self attention(MSA)和MLP块的交替组成。在每个块之前应用Layernorm(LN),在每个块之后应用残差连接。

2.4 Inductive bias

  注意到,Vision Transformer比CNN具有更少的图像特定感应偏置。在神经网络中,整个模型的每个层的局部性、二维邻域结构和平移不变性都能体现。在ViT中,只有MLP层是局部的并且是平移不变的,而self-attention层是全局的。二维邻域结构的使用非常谨慎:在模型开始时,通过将图像切割成小块,并在微调时调整不同分辨率图像的位置嵌入。除此之外,初始化时的位置嵌入不携带关于patch的2D位置的信息,并且必须从头学习patch之间的所有空间关系。

2.5 Hybrid Architecture

  作为原始图像patch的替代,输入序列可以由CNN的特征图形成。在该混合模型中,将patch嵌入投影E(公式1)应用于从CNN特征图提取的patch。作为一种特殊情况,patch可以具有空间大小1x1,这意味着通过简单地展平特征图的空间维度并投影到Transformer维度来获得输入序列。如上所述添加分类输入嵌入和位置嵌入。

3. 代码实现

3.1 定义参数

  • 输入图片尺寸:image_size=256
  • 每个patch尺寸:patch_size=16
  • 输出分类总数:num_classes=1000
  • 图片patch编码维度:dim=1024
  • Transformer编码器深度:depth=6
  • MAS head总数:heads=16
  • MLP维度:mlp_dim=2048
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)
    
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        ...

3.2 图片编码

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64):
        super().__init__()
        
        ...
		self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )
        ...
        
    def forward(self, img):
        x = self.to_patch_embedding(img)
        ...
        

  输入图片 s h a p e = ( B × C × H × W ) shape=(B×C×H×W) shape=(B×C×H×W),首先reshape成 ( B × ( h × w ) × ( p 1 × p 2 × C ) ) (B×(h×w)×(p_1×p_2×C)) (B×(h×w)×(p1×p2×C)),其中H、W是图片原始宽和高, p 1 、 p 2 p_1、p_2 p1p2是图片patch的尺寸, h 、 w h、w hw是图片patch的数量。然后使用线性层转换成指定维度。

3.3 加入class token

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64):
        super().__init__()
        
        ...
		self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        ...
        
    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        ...

  首先把class tokens复制B份,B是batchsize。然后联结到patch序列前面。

3.4 位置编码

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64):
        super().__init__()
        
        ...
		self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        ...
        
    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        ...
    

3.5 Transformer


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64):
        super().__init__()
        
        ...
		self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        ...
        
    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]

		x = self.transformer(x)
        ...
        

3.6 MLP层

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64):
        super().__init__()
        
        ...
		self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
        ...
        
    def forward(self, img):
        ...
		x = self.transformer(x)

        x = x[:, 0]
        out = self.mlp_head(x)
        return out
    

3.7 总体代码

代码下载地址:
https://download.csdn.net/download/qq_39707285/87405676

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/189808.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

购买和登录Linux云服务器

目录 云服务器的购买 SSH登录云服务器 云服务器的购买 我们以腾讯云为例, 其他的服务器厂商也是类似。 1. 进入腾讯云官方网站:学生云服务器_云校园特惠套餐 - 腾讯云 (tencent.com) 2. 登陆网站(可以使用微信登陆) 3.购买云服务器 购买最低级即可,对于…

36/365 java 类的加载 链接 初始化 ClassLoader

1.类的加载,链接,初始化 注意点: Class对象是在类的加载过程中生成的(类的数据(static,常量,代码)在方法区,Class类对象在堆中),这个Class类对象作为方法区中…

Canvas 实现台球假想球精准定位

1. 前言 台球是一个让人非常着迷的运动项目,充满了各种计算逻辑,十分有趣。 对于初学者,母球、目标球、袋口三者在一条线上的时候,是非常容易进球的,但对于三者不在一条线上时,就是需要假想球的帮助&…

Windows 上安装 Insomnia 代替 Postman

Windows 上安装 Insomnia 代替 PostmanInsomnia 概述官网地址下载和安装 Insomnia使用 InsomniaInsomnia 概述 Insomnia 是一个开源桌面应用程序,它提供了设计、调试和测试API的简单方法。 通过对开发者友好的界面、内置的自动化和可扩展的插件生态系统&#xff0…

自动驾驶中间件:量产落地的关键技术

/ 导读 /对于初入自动驾驶行业的人来说,各色各样的新型传感器、线控系统、芯片域控制器、算法软件似乎是自动驾驶未来实现的重中之重,对于中间件大多数人可能都不太熟悉,有些甚至从未听说过其存在。但中间件却也是极为重要的一环,…

设计模式-创建型模式

目录 4.创建型模式 4.1 单例设计模式 4.1.1 单例模式的结构 4.1.2 单例模式的实现 4.1.3 存在的问题 4.1.4 JDK源码解析-Runtime类 4.2 工厂模式 4.2.1 概述 4.2.2 简单工厂模式 4.2.3 工厂方法模式 4.2.4 抽象工厂模式 4.2.5 模式扩展 4.2.6 JDK源码解析-Collecti…

Kotlin~生成器模式

介绍 主要作用 逐步构造复杂对象,该对象的属性更多的扩展属性,如Glide的使用。 组成 Builder:提供逐步创建产品的步骤 Director:创建可复用的特定产品(规定Builder规定一系列的步骤创建产品,非必须&…

21新版FL Studio水果电音编曲Daw宿主软件好不好用?

首先是FL Studio(以下简称FL)的逻辑和其它宿主软件都不太一样,FL的逻辑就与众不同。FL的逻辑也可以分为三部分:通道机架、混音台和播放列表。在Live里每个发送轨都可以插入一个乐器以及若干个效果器。你有200个发送轨,…

vcenter 起不来报错VMware ESX 找不到虚拟磁盘“vCenter Server 7.0U3_12.vmdk”。请确认路径有效并重试

针对无快照时丢失.vmdk描述符文件:基础磁盘文件为-flat.vmdk是存在的 那个可以进行恢复操作步骤如下1.确定 flat.vmdk基础磁盘文件的大小(字节)2.创建与flat.vmdk相同大小的新的空虚拟磁盘。3.重命名新创建的.vmdk磁盘的描述符文件匹配原始虚…

如何运行一个py项目

在pycharm中打开项目文件确保安装python环境此时是使用python3.7版本,没有的话需要添加环境:add interpreter在anaconda(安装参考https://blog.csdn.net/m0_67357141/article/details/123633490)中选择基础环境(base&a…

Python中的列表

1.创建列表 使用中括号把要添加的元素括起来,不同元素用逗号隔开。 >>> rhyme [1, 2, 3, 4, 5, "上山打老虎"] >>> print(rhyme) [1, 2, 3, 4, 5, 上山打老虎]2.访问列表中的元素 (1)希望顺序访问列表中的元…

博弈论入门

分类 要素 常见博弈 完全信息静态博弈 纳什均衡 囚徒困境 古诺双寡头模型 古诺双寡头模型的条件 市场中有且仅有两家公司策略为同质商品的量,qiq_iqi​边际成本为c,生产成本就为c*q,在这里我们的边际成本是常数。需求曲线:Pa−b∗…

2009-01-从学校毕业步入社会

在一间坐满学生的教室中,台上同学正在对自己毕业答辩项目进行介绍,台下第一排坐着打分的老师,这群人正在进行计算机专业的毕业答辩,台下人群中一个叫刘文轩的同学紧张又期盼的看着前面正在进行答辩的同学,看着同学们优…

react中useReduer和useEffect

相信很多人对于变成中reduce、reducer命名都存在困惑,为了更好理解useRedecuer,我们不妨先来说说reduce。 如何理解reduce和reducer reduce:函数式编程当中的一个术语,reduce操作被称为Fold折叠 // 通过reduce,数组…

公司内部有奖知识答题活动怎么做

公司年会趣味问答、员工业务知识考核、消防安全、党史等知识测试......公司内部的答题活动已经成了众多管理者、HR日常工作中一部分。如何让组织者更轻松、更公平公正地举办答题活动?如何让员工更积极参与呢?试试答题小博士的有奖答题。有奖答题活动形式…

中晶FileScan 3222扫描仪 Code:-206,卡纸或滚筒出错

中晶FileScan 3222是中晶品牌下的一款扫描仪。 型号 3222 产品类型 平板式+馈纸式 扫描光源 LED

机器人中的数值优化之BFGS(convex and smooth)

本文ppt来自深蓝学院《机器人中的数值优化》 目录 1 Why Quasi-Newton Methods 2 Rate of convergence 3 Quasi-Newton Methods 3.1 Quasi-Newton approximation 3.2 preserve descent direction 3.3 secant condition 3.4 iterate B 3.5 Parsed solution B 4 Cont…

微信小程序学习第2天——模板语法与样式,全局配置与页面配置

文章目录一、WXML模板语法1、数据绑定2、事件绑定3、条件渲染4、列表渲染二、WXSS模板样式rpximport语法导入样式全局和局部样式三、全局配置全局配置文件及常用配置项windowtabBar四、页面配置一、WXML模板语法 1、数据绑定 数据绑定的原则:①在data中定义数据 ②…

亚信安慧携AntDB数据库入选信通院软件供应链厂商和产品名录

日前,中国信息通讯研究院(简称:中国信通院)在其主办的3SCON软件供应链安全大会上,发布了软件供应链厂商和产品名录。中国信通院云计算与大数据研究所副所长栗蔚表示,我国软件供应链安全发展面临制度体系待完…

微服务组件(高并发带来的问题 服务器雪崩效应 Sentinel入门)

高并发带来的问题 服务器雪崩效应 Sentinel入门高并发带来的问题模拟高并发服务器雪崩效应常见容错方案Sentinel入门(常见的容错组件)什么是Sentinel?订单服务集成Sentinel流控规则预热流控等待流控关联流控链路流控降级(提供一个兜底方案)慢调用比例异常比例异常数案例高并发…