Pytorch从零开始实现Vision Transformer (from scratch)

news2024/12/24 8:28:08

Pytorch从零开始实现Vision Transformer

  • 前言
  • 一、Vision Transformer架构介绍
    • 1. Patch Embedding
    • 2. Multi-Head Attention
    • 3. Transformer Block
      • Feed Forward
  • 二、预备知识
    • 1. Einsum
    • 2. Einops
  • 三、Vision Transformer代码实现
    • 0. 导入库
    • 1. Patch Embedding
    • 2. Residual & Norm
    • 3. Multi-Head Attention & FeedForward
    • 4. Transformer Encoder
    • 6. Vision Transformer
    • 7. Test Code
    • 模型参数量计算
      • 1. 卷积核参数量计算
      • 2. 全连接层参数量计算
      • 3. ViT参数量计算
  • 总结
  • 日志
  • 参考文献


前言

Transformer在NLP领域大放异彩,而实际上NLP(Natural Language Processing,自然语言处理)领域技术的发展都要先于CV(Computer Vision,计算机视觉),那么如何将Transformer这类模型也能适用到图像数据上呢?
在2017年Transformer发布后,历经3年时间,Vision Transformer于2020年问世。与Transformer相同,Vision Transformer也是由Google Brain和Google Research团队开发,然而并不是同一批人(除了Jakob Uszkoreit)。
值得一提的是,Vision Transformer并不是第一个将Transformer应用到CV上的,因为这些巨头的存在(如Google,FaceBook),论文的名气也自然会更大,而且从如今ViT的泛用程度来看也是,大家对其认可度更高纷纷follow。和这些巨头庞大资源比,高校产出的论文光芒显得黯淡了许多。而在大模型时代更是如此,都是“大力出奇迹”的结果。可大模型大数据训练就是AI的最终形态了吗,我觉得不然……或许在AI真正具有“智能”时,深度学习的模型也并不需要这么大吧,因为人脑正是有了联想推理才能拥有知识和技能,而不完全单靠记忆。


一、Vision Transformer架构介绍

在这里插入图片描述

1. Patch Embedding

2. Multi-Head Attention

3. Transformer Block

如图,(a) 是最初Transformer的Encoder结构图, (b)则是ViT的。可以明显看出,Transformer是在multi-head attention和feedforward模块后进行残差操作(即Add)和Norm(标准化),而ViT则是在这些模块前使用Norm操作。

Feed Forward

ViT的Feed Forward模块使用两层全连接层(Linear)和GeLU激活函数。而Transformer使用的是ReLu激活函数。
GeLu于2016年被提出,见于Bridging Nonlinearities and Stochastic Regularizers with Gaussian Error Linear Units,后来经过论文修改改名为“Gaussian Error Linear Units (GELUs)”。论文给出了ReLu和GeLu的图示:
在这里插入图片描述
ReLu确实好用,但缺点也很明显,其在输入值小于0时都会输出0,这样“一刀切”的策略势必会丢掉信息,累计error。因此后来出现了GeLu、LeakyReLu等一系列激活函数来解决神经元”死亡“问题,让输入值小于0时输出不总是0。


二、预备知识

本节的两个操作都是为了方便编程人员更好对tensor进行操作,且让代码更具可读性。

1. Einsum

Einsum即爱因斯坦和,torch.einsum即可调用。

2. Einops

是大牛 受Einsum启发所开发的一个库,主要用于张量的变形等操作。


三、Vision Transformer代码实现

这次代码并不是直接取用某一份代码,而是参考包括Pytorch官方的代码库、网上博客、github项目综合出的一份Vision Transformer代码,尽可能还原ViT又兼顾代码可读性以便读者学习理解。此处引用比ViT原论文更加具体的ViT模型图:
ViT流程图
此图出自论文Vision Transformers for Remote Sensing Image Classification。

0. 导入库

import torch
import torch.nn as nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

1. Patch Embedding

class PatchEmbedding(nn.Module):
    def __init__(self, embed_size=768, patch_size=16, channels=3, img_size=224):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        # Version 1.0
        # self.patch_projection = nn.Sequential(
        #     Rearrange("b c (h h1) (w w1) -> b (h w) (h1 w1 c)", h1=patch_size, w1=patch_size),
        #     nn.Linear(patch_size * patch_size * channels, embed_size)
        # )

        # Version 2.0
        self.patch_projection = nn.Sequential(
            nn.Conv2d(channels, embed_size, kernel_size=(patch_size, patch_size), stride=(patch_size, patch_size)),
            Rearrange("b e (h) (w) -> b (h w) e"),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_size))
        self.positions = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, embed_size))

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.patch_projection(x)
        # prepend the cls token to the input
        cls_tokens = repeat(self.cls_token, "() n e -> b n e", b=batch_size)
        x = torch.cat([cls_tokens, x], dim=1)
        # add position embedding
        x += self.positions
        return x

2. Residual & Norm

class Residual(nn.Module):
    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super(PreNorm, self).__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

3. Multi-Head Attention & FeedForward

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super(FeedForward, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.mlp(x)


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim=768, n_heads=8, dropout=0.):
        """
        Args:
            embed_dim: dimension of embeding vector output
            n_heads: number of self attention heads
        """
        super(MultiHeadAttention, self).__init__()

        self.embed_dim = embed_dim  # 768 dim
        self.n_heads = n_heads  # 8
        self.head_dim = self.embed_dim // self.n_heads  # 768/8 = 96. each key,query,value will be of 96d
        self.scale = self.head_dim ** -0.5

        self.attn_drop = nn.Dropout(dropout)
        # key,query and value matrixes
        self.to_qkv = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """
        Args:
           x : a unified vector of key query value
        Returns:
           output vector from multihead attention
        """
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.n_heads), qkv)

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        attn = dots.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")

        out = self.to_out(out)
        return out

4. Transformer Encoder

class Transformer(nn.Module):
    def __init__(self, dim=768, depth=12, n_heads=8, mlp_expansions=4, dropout=0.):
        super(Transformer, self).__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, MultiHeadAttention(dim, n_heads, dropout))),
                Residual(FeedForward(dim, dim * mlp_expansions, dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x)
            x = ff(x)
        return x

6. Vision Transformer

class VisionTransformer(nn.Module):
    def __init__(self, dim=768,
                 patch_size=16,
                 channels=3,
                 img_size=224,
                 depth=12,
                 n_heads=8,
                 mlp_expansions=4,
                 dropout=0.,
                 num_classes=0,
                 global_pool='avg'):
        super(VisionTransformer, self).__init__()
        assert global_pool in ('avg', 'token')
        self.global_pool = global_pool
        self.patch_embedding = PatchEmbedding(dim, patch_size, channels, img_size)
        self.transformer = Transformer(dim, depth, n_heads, mlp_expansions, dropout)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        ) if num_classes > 0 else nn.Identity()

    def forward(self, img):
        x = self.patch_embedding(img)
        x = self.transformer(x)
        x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.mlp_head(x)
        return x

7. Test Code

if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    images = torch.randn((16, 3, 224, 224)).to(device)
    vit = VisionTransformer(num_classes=4, global_pool="token").to(device)
    output = vit(images)
    print(output)
    torch.save(vit.state_dict(), "model.pth")

模型参数量计算

1. 卷积核参数量计算

对于二维卷积层,其参数量由输入通道数(C)、卷积核的大小(KxK)、卷积核的数量或者说输出通道数(F)、偏置项的数量等因素决定。计算公式为:
( K × K × C + 1 ) × F (K \times K \times C + 1)\times F (K×K×C+1)×F,其中1为偏置项。

2. 全连接层参数量计算

对于某一层全连接层的参数量只由其输入维度和输出维度(是否带偏置项)决定,将全连接层理解为一个映射函数,假设输入为矩阵A(维度为HxW),输出为矩阵C(维度为HxH),那么一层全连接层参数量就来自其所代表的矩阵B根据矩阵乘法其维度应为WxH,即Linear(W,H),输入维度W,输出维度也是H。计算公式易得:
W × H + H × 1 W \times H + H\times 1 W×H+H×1,其中1代表偏置项,需要输出维度个偏置项。

3. ViT参数量计算

模块/变量名计算过程参数量
PatchEmbedding c o n v 2 d + c l s _ t o k e n + p o s t i t i o n s conv2d + cls\_token + postitions conv2d+cls_token+postitions742656
conv2d ( 16 × 16 × 3 + 1 ) × 768 (16\times 16\times 3 + 1)\times 768 (16×16×3+1)×768590592
cls_token 1 × 1 × 768 1\times1\times768 1×1×768768
postitions ( ( 224 ÷ 16 ) 2 + 1 ) × 768 ((224\div 16)^2+1)\times768 ((224÷16)2+1)×768151296
Feedforward ( 768 × ( 768 × 4 ) + ( 768 × 4 ) ) + ( ( 768 × 4 ) × 768 + 768 ) (768\times(768\times4)+(768\times4)) + ((768\times4)\times768+768) (768×(768×4)+(768×4))+((768×4)×768+768)4722432
MultiHeadAttention t o _ q k v + t o _ o u t to\_qkv + to\_out to_qkv+to_out2360064
to_qkv 768 × ( 768 × 3 ) 768\times(768\times3) 768×(768×3)1769472
to_out 768 × 768 + 768 768\times768+768 768×768+768590592
Transformer 12 × ( F e e d f o r w a r d + M u l t i H e a d A t t e n t i o n ) 12\times(Feedforward+MultiHeadAttention) 12×(Feedforward+MultiHeadAttention)84989952
ViT T r a n s f o r m e r + P a t c h E m b e d d i n g + m l p _ h e a d Transformer+PatchEmbedding+mlp\_head Transformer+PatchEmbedding+mlp_head85735684
mlp_head 768 × n u m _ c l a s s e s + n u m _ c l a s s e s ,本文设置 n u m _ c l a s s e s 为 4 768\times num\_classes+num\_classes,本文设置num\_classes为4 768×num_classes+num_classes,本文设置num_classes43076

最终参数量为 85735684 × 4 ( B ) = 342942736 ( B ) 85735684\times 4(B) = 342942736(B) 85735684×4(B)=342942736(B)为什么要乘以4字节呢?
因为这些参数权重默认为float32保存,需要用到32bits即4Bytes,最终通过换算得,
342942736 ( B ) ÷ 1024 ÷ 1024 = 327.055679321 ( M B ) 342942736(B)\div 1024\div 1024 = 327.055679321(MB) 342942736(B)÷1024÷1024=327.055679321(MB)
因为我们在Test code有保存模型权重为model.pth文件,可以查看model.pth属性来验证计算是否准确。
在这里插入图片描述
在字节数上有所偏差,但足以表明计算过程大致是正确的! 偏差可能原因是model.pth不止要保存权重,还会附带一些其他信息,所以实际文件大小会比参数量要略大。


总结

日志

参考文献

https://theaisummer.com/vision-transformer/
https://medium.com/mlearning-ai/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c
https://www.kaggle.com/code/hannes82/vision-transformer-trained-from-scratch-pytorch
https://towardsdatascience.com/implementing-visualttransformer-in-pytorch-184f9f16f632
https://github.com/FrancescoSaverioZuppichini/ViT

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/542626.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Ethereum以太坊事件日志查询参数

目录 一、Ethereum事件日志查询参数二、需求三、实现四、其他 一、Ethereum事件日志查询参数 addresses:合约地址列表fromBlock:开始区块toBlock:结束区块topics:主题数组blockhash:区块哈希,优先级高于fr…

chatgpt赋能Python-python3_9_1怎么用

Python3.9.1是什么? Python是一种高级、动态、解释型语言,具有优雅简洁、易于学习和阅读、功能丰富的特点。Python 3.9.1是Python编程语言的一个版本,于2020年12月21日正式发布,是Python 3的最新稳定版本。它包含了许多新的特性、…

chatgpt赋能Python-python3_7降级3_6

Python 3.7降级3.6:为什么?如何做? 如果你是一个认真的Python开发者,你可能会对Python 3.7的某些改进感到兴奋。但是,在某些情况下,你可能需要将Python降级到3.6版本。在这篇文章中,我们将探讨…

网吧管理系统修正

文章目录 网吧管理系统修正1. 改动1:上机缴费2. 改动2:下机超时计费3. 改动3:注销强制下机操作4. 改动4:计费标准数据的获取与释放 网吧管理系统修正 Carry文件中,打开文件忘记关闭了,虽然C语言中不会报错…

【STM32G431RBTx】备战蓝桥杯嵌入式→决赛试题→第八届

文章目录 前言一、题目二、模块初始化三、代码实现interrupt.h:interrupt.c:main.h:main.c: 四、完成效果五、总结 前言 学习完了所有模块之后(LIS302考点取消了, 扩展板也找不到了,如果你能找到可能你不是在十四届省赛后买的扩展板), 跟省赛一样,先拿第…

这个屏幕录制太好用了!

哈喽,大家好!今天给各位小伙伴测试了一屏幕录制的小工具——ApowerREC。它是一款专业同步录制屏幕画面及声音的录屏软件。界面简洁,操作简单,支持实时编辑屏幕录像、创建计划任务、录制摄像头高清视频等功能。废话不多说&#xff…

汇编九、直接地址与寄存器、ROM和RAM

1、直接地址与寄存器 1.1、A和ACC (1)A是没有地址的,可以当作寄存器使用。 (2)ACC是有地址的,地址为E0H(查数据手册)。 (3)如果使用压栈操作,要用ACC。因为ACC可当做直接地址。 (4)如果使用位操作,要用ACC。因为ACC可位寻址。…

总结855

学习目标: 月目标:5月(张宇强化前10讲,背诵15篇短文,熟词僻义300词基础词) 周目标:张宇强化前3讲并完成相应的习题并记录,英语背3篇文章并回诵 每日必复习(5分钟&#…

chatgpt赋能Python-python3_9_2怎么安装

Python3.9.2的安装指南 Python是一种高级编程语言,广泛应用于各种领域,从数据科学到机器学习到Web应用程序。Python具有易学性和通用性,因此成为众多开发人员的首选语言。此篇文章将指导您如何安装Python3.9.2版本。 为什么选择Python3.9.2…

Packet Tracer – 对 VLAN 实施进行故障排除 – 方案 1

Packet Tracer – 对 VLAN 实施进行故障排除 – 方案 1 地址分配表 设备 接口 IP 地址 子网掩码 交换机端口 VLAN PC1 NIC 172.17.10.21 255.255.255.0 S2 F0/11 10 PC2 NIC 172.17.20.22 255.255.255.0 S2 F0/18 20 PC3 NIC 172.17.30.23 255.255.255.0…

一星期学mysql day2

文章目录 DQL(数据查询语言)基础查询条件查询聚合查询(聚合函数)分组查询注意事项 排序查询注意事项 分页查询注意事项 编写顺序DQL执行顺序 DCL 管理用户管理用户注意事项 权限控制注意事项 函数字符串函数数值函数日期函数流程函…

车辆管理系统的设计与实现

背景 4S店车辆系统,为用户随时随地查看4S店车辆信息提供了便捷的方法,更重要的是大大的简化了管理员管理4S店车辆信息的方式方法,更提供了其他想要了解4S店车辆信息及运作情况以及挑选方便快捷的可靠渠道。相比于传统的管理方法,…

chatgpt赋能Python-python3_8怎么安装numpy库

Python3.8怎么安装numpy库?一个详细的步骤指南 你是否想要在Python3.8中安装numpy库,但不知道如何开始?没问题,本篇文章将介绍Python3.8的numpy安装步骤。 什么是numpy? numpy是Python中的一个重要的数学计算库&…

chatgpt赋能Python-python3_7怎么保存

Python 3.7 - 新的保存方式 Python 语言是一种广泛使用的编程语言之一,它在数据分析、网络编程、科学计算等领域都得到了广泛的应用。Python 的一个新版本,Python 3.7,增加了一种新的保存方式,这种方式可以提高数据的保存效率&am…

ElasticSearch漫游 (5.RestClient初始化准备)

之前都是在客户端页面各种操作, 但是我们作为一个java开发者,需要使用代码的方式 来操作ES, 所以我们要先从 Rest Client 开始 准备数据 既然是ES 查数据 不整点测试数据 没法下饭是吧 先建个酒店表 搞点数据: (新手…

两款亲测非常优秀的压缩软件

哈喽,大家好。今天给大家带来两款超级好用的压缩软件,一款是老牌的WinRAR,另一款是开源的7-Zip。 这两款都是小编电脑上必备的压缩软件,至于为什么要装两款嘛,可能因为任性吧,哈哈。 一、测试演示参数 演…

nvm-windows安装和配置

1.下载nvm-windows https://github.com/coreybutler/nvm-windows/releases 进入后如下图 选择nvm-setup.exe下载 2.安装 2.1 运行安装包,同意协议,下一步 2.2 选择nvm的安装目录,安装路径名最好不要有空格 2.3 点击Next,设…

【CPP】NULL nullptr

文章目录 NULLnullptr总结Ref. NULL 在C中,NULL实际上是0.因为C中不能把void*类型的指针隐式转换成其他类型的指针(cpp是强类型语言),所以为了解决空指针的表示问题,C引入了0来表示空指针,这样就有了上述代…

Python 爬虫(七):pyspider 使用

1 简介 pyspider 是一个支持任务监控、项目管理、多种数据库,具有 WebUI 的爬虫框架,它采用 Python 语言编写,分布式架构。详细特性如下: 拥有 Web 脚本编辑界面,任务监控器,项目管理器和结构查看器&#…

chatgpt赋能Python-python3_8_6怎么用

Python 3.8.6 全面介绍及使用指南 Python是一种高级编程语言,广泛应用于Web开发、数据科学、人工智能等领域。Python 3.8.6是Python 3.x系列的最新版本,在性能、稳定性、开发效率等方面有了重大的提升。本文将对Python3.8.6进行全面介绍,并给…