Torch 论文复现:Vision Transformer (ViT)

news2025/1/10 2:41:10

论文标题:An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

从 TPUv3-core-days 可以看到,ViT 所需的训练时间比 ResNet 更短,同时 ViT 取得了更高的准确率

ViT 的基本思想是,把一张图片拆分成若干个 patch (16×16),每个 patch 当作 NLP 中的一个单词,若干个 patch 组成一个句子,用 Transformer 进行处理

ViT 的核心计算模块有:Multihead Attention (torch.nn.MultiheadAttention),Transformer Encoder (torch.nn.TransformerEncoder),Patch Embedding

其中两个可以在 torch.nn 中找到,但是其源代码是由 Python 写的,而且非常冗长。比如: torch.nn.MultiheadAttention 的 forward 函数需要输入 query, key, value,并进行相互间的数值比较,但很多情况下这三者是相等的 (共用一个 tensor),这样的比较显然不必要;而且 torch.nn.TransformerEncoder 又调用了 torch.nn.MultiheadAttention,这也意味着要简化 torch.nn.MultiheadAttention 的话两者都必须重新编写

Multihead Attention

我阅读了 torch 官方的源代码,也参考了其它大佬的代码,整理出了如下的计算流程图。其中 L, B, C 分别表示 Sequence length、Batch size、Channel。N 表示注意力头的个数,并满足 N \cdot C_{head} = C

Multihead Attention 所涉及到的乘法计算有:

  • Input 的线性变换 (Linear [C, 3C])
  • Query 的逐元素乘法
  • Query 和 Key 的矩阵乘法,可看作 Linear [C_{head}, L]
  • Weights 的 softmax 运算
  • Weights 和 Value 的矩阵乘法,可看作 Linear [L, C_{head}]
  • Attention 的线性变换 (Linear [C, C])

其乘法次数可表示为 (含幂运算):

3LBC^2 + LBNC_{head} + L^2BNC_{head} +2L^2BN + L^2BNC_{head} + LBC^2

LB(4C^2 + C)+ 2L^2B(C+N)

虽然 Multihead Attention 的输入通道数 = 输出通道数 (输入输出 shape 相同),但注意力头的个数 N 对乘法次数的影响还是相当大的 (源于 softmax 运算)

class MultiheadAttention(nn.Module):
    ''' n: 注意力头数'''

    def __init__(self, c1, n, drop=0.1):
        super().__init__()
        self.c_head = c1 // n
        assert n * self.c_head == c1, 'c1 must be divisible by n'

        self.scale = self.c_head ** -0.5
        self.qkv = nn.Linear(in_features=c1, out_features=3 * c1, bias=False)
        self.dropout = nn.Dropout(p=drop)
        self.proj = nn.Linear(in_features=c1, out_features=c1)

    def forward(self, x):
        L, B, C = x.shape
        # view: [L, B, C] -> [L, BN, C_head]
        q, k, v = map(lambda t: t.contiguous().view(L, -1, self.c_head), self.qkv(x).chunk(3, dim=-1))
        q, k, v = q.transpose(0, 1), k.permute(1, 2, 0), v.transpose(0, 1)
        # q[BN, L, C_head] × k[BN, C_head, L] = w[BN, L, L]
        # N 对浮点运算量的影响主要在 softmax
        weight = self.dropout((q * self.scale @ k).softmax(dim=-1))
        # w[BN, L, L] × v[BN, L, C_head] = a[BN, L, C_head] -> a[L, B, C]
        attention = (weight @ v).transpose(0, 1).contiguous().view(L, B, C)
        return self.proj(attention)

Transformer Encoder

在参考了 torch 官方的源代码后,我对 LayerNorm 的位置进行了调整,也就是在每次张量与残差相加时才进行层标准化

class TransformerEncoder(nn.Module):
    ''' n: 注意力头数
        e: 全连接层通道膨胀比'''

    def __init__(self, c1, n, e=1., drop=0.1):
        super().__init__()
        self.attn = nn.Sequential(
            MultiheadAttention(c1, n, drop),
            nn.Dropout(p=drop)
        )
        c_ = max([1, round(c1 * e)])
        self.mlp = nn.Sequential(
            nn.Linear(c1, c_),
            nn.GELU(),
            nn.Dropout(p=drop),
            nn.Linear(c_, c1),
            nn.Dropout(p=drop)
        )
        self.norm1 = nn.LayerNorm(c1)
        self.norm2 = nn.LayerNorm(c1)

    def forward(self, x):
        # x[L, B, C]
        x = self.norm1(x + self.attn(x))
        return self.norm2(x + self.mlp(x))

Vision Transformer

在论文中,作者用四个等式表述了 ViT 的计算过程 (先不考虑 Batch size),其中的符号意义为:

  • N:一幅图像所包含的 patch 的数量
  • x_{class}:可训练的 embedding,shape 为 [D]
  • x_p^i:第 i 个 patch 的特征图
  • P:每一个 patch 的边长 
  • E:二维卷积核 (in_channels=C, out_channels=D, k_size=[P, P], stride=[P, P]),可将特征图 x_p^i\ [C, P, P] 变换为 x_p^i E\ [D, 1, 1] \rightarrow [D]
  • E_{pos}:可训练的 embedding,表征每一个 patch 在图像中的位置
  • z_l:第 i 个 Transformer Encoder 的输出,shape 为 [N+1, D]z_l^0 = z_l[0] 的 shape 为 [D]

ViT 所完成的操作如下 (其中 B 为 Batch size):

  • 用 torch.nn.Conv2d 把图像分割成若干个 patch,每个 patch 用一个向量表示 (可看作 NLP 中的单词),展平后得到 shape 为 [B, N, D] 的“句子”
  • 拼接 x_{class} 之后将 shape 变为 [B, N+1, D],并与 E_{pos} 相加
  • transpose 将 shape 变为 [N+1, B, D],输入若干个 Transformer Encoder 之后取 z_l^0 输出
class VisionTransformer(nn.Module):
    ''' n: 注意力头数
        l: TransformerEncoder 堆叠数
        e: TransformerEncoder 全连接层通道膨胀比'''

    def __init__(self, c1, c2, n, l, img_size, patch_size, e=1., drop=0.1):
        super().__init__()
        # 校验 img_size 和 patch_size
        self.img_size = (img_size,) * 2 if isinstance(img_size, int) else img_size
        self.patch_size = (patch_size,) * 2 if isinstance(patch_size, int) else patch_size
        assert sum([self.img_size[i] % self.patch_size[i] for i in range(2)]
                   ) == 0, 'img_size must be divisible by patch_size'
        n_patch = math.prod([self.img_size[i] // self.patch_size[i] for i in range(2)])

        self.cls_embed = nn.Parameter(torch.empty(1, 1, c2))
        self.pos_embed = nn.Parameter(torch.empty(n_patch + 1, c2))

        self.patch_embed = nn.Conv2d(c1, c2, kernel_size=patch_size, stride=patch_size)
        assert c2 % n == 0, 'c2 must be divisible by n'
        self.encoders = nn.Sequential(*[TransformerEncoder(c2, n, e, drop) for _ in range(l)])

    def forward(self, x):
        B, C, H, W = x.shape
        # view: [B, C, N_patch] -> [B, N_patch, C]
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        cls_embed = self.cls_embed.repeat(B, 1, 1)
        x = torch.cat([cls_embed, x], dim=1) + self.pos_embed
        # view: [B, N_patch + 1, C] -> [N_patch + 1, B, C]
        return self.encoders(x.transpose(0, 1))[0]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/188721.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Paddle入门实战系列(四):中文场景文字识别

✨写在前面:强烈推荐给大家一个优秀的人工智能学习网站,内容包括人工智能基础、机器学习、深度学习神经网络等,详细介绍各部分概念及实战教程,通俗易懂,非常适合人工智能领域初学者及研究者学习。➡️点击跳转到网站。…

每日学术速递1.31

CV - 计算机视觉 | ML - 机器学习 | RL - 强化学习 | NLP 自然语言处理 今天带来的arXiv上最新发表的3篇NLP论文。 Subjects: cs.CL、cs.AI、cs.DB、cs.LG 1.Editing Language Model-based Knowledge Graph Embeddings 标题:编辑基于语言模型的知识图谱嵌入 作…

C语言求幂运算——奇特中文变量命名

写在前面 主要涉及C/C趣味编程应用及解析面向初学者撰写专栏,个人代码原创如有错误之处请各位读者指正,各位可以类比做自己的编程作业请读者评论回复、参与投票,反馈给作者,我会获得持续更新各类干货的动力。 致粉丝:…

【Rust】8. 包、Crate 和 模块管理(公有、私有特性)

8.1 包和 Crate 8.1.1 基本概念 crate 是 Rust 在编译时最小的代码单位;crate 有两种形式:二进制项(可以被编译为可执行程序)和库(没有 main 函数,也不会编译为可执行程序,而是提供一些诸如函…

Selenium+Java+Maven(12):引入Allure作为报告生成器

一、前言 本篇作为SeleniumJava系列的补充,讲了如何使用Allure作为测试报告生成器,来替代TestNG自带的测试报告或ReportNG测试报告,生成更加美观的(领导更喜欢的)测试报表。话不多说,一起来学习吧~ 二、A…

蒙特卡洛算法详解

蒙特卡洛算法是20世纪十大最伟大的算法之一,阿法狗就采用了蒙特卡洛算法。 1、定义 蒙特卡洛方法也称为 计算机随机模拟方法,它源于世界著名的赌城——摩纳哥的Monte Carlo(蒙特卡洛)。 它是基于对大量事件的统计结果来实现一些确定性问题的计算。其实…

什么是独立性?如何提高独立性?

独立是每个人必经的成长阶段,也是实现人生价值最重要的途径。没有独立就不能实现真正意义上的人生。独立是我们克服困难、实现抱负的最重要的精神力量,也是我们收获幸福的保障。1、什么是独立性?独立性是意志指不受他人影响、能够独立解决问题…

迟到两年的求职总结经验分享

迟到两年的求职总结&经验分享 写在前面 ​ 号主于2021年3月-2021年9月断断续续参加了校园招聘,包括但不限于:暑期实习、秋招提前批、秋招正式批。收获offer包括但不限于:某互联网推荐算法工程师、某通讯公司数据挖掘工程师、某金融科技…

docker 安装mysql8

docker 安装mysql8无法远程登录 # 启动容器 docker run \ -p 13306:3306 \ --name mysql \ --privilegedtrue \ --restartalways \ -v /home/mysqldata/mysql:/etc/mysql \ -v /home/mysqldata/mysql/logs:/logs \ -v /home/mysqldata/mysql/data:/var/lib/mysql \ -v /etc/l…

C++11线程间共享数据

C11线程间共享数据 使用全局变量等不考虑安全的方式以及原子变量这里就不进行说明了。 在多线程中的全局变量,就好比现实生活中的公共资源一样,比如你有一个同时只能允许一个人做饭的厨房,那么在你占用期间,你的室友就必须等待。…

synchronized锁的升级

synchronized锁优化的背景 用锁能够实现数据的安全性,但是会带来性能的下降 无锁能够基于线程并行提升程序性能,带来安全性的下降 java5 synchronized默认是重量级锁,java6以后引入偏向锁和轻量锁,java15 逐步废弃了偏向锁 …

机器学习实战(第二版)读书笔记(4)——seq2seq模型注意力机制(BahdanauAttention,LuongAttention)详解

一、Seq2seq模型 机器学习实战(第二版)读书笔记(1)——循环神经网络(RNN) 中详细介绍了RNN如下图1所示,可以发现RNN结构大多数对序列长度比较局限,对于机器翻译等任务(输入输出长度不想等N to M),RNN没办法处理&…

SVN使用:Mac电脑中修改SVN输出信息为英文的方法

前言 作为软件开发人员,关于项目代码管理以及维护想必都不陌生,尤其是在团队协作的时候,多人开发维护同一个项目更是需要代码管理。关于项目代码管理维护工具,常用的就是Git、SVN等管理工具。本篇文章只来分享一下关于SVN的配置设…

C语言学习笔记-常量

“常量”的广义概念是:‘不变化的量’。例如:在计算机程序运行时,不会被程序修改的量。 以上是百度百科上对常量的部分定义。C语言的学习过程中将会接触很多的常量,不同类型的常量其定义、用法等会有所差异。要搞清楚他们的相似与…

如何恢复已删除的文件?5分钟搞定的简单方法。

本文介绍如何使用文件恢复程序恢复已删除的文件。它包括与恢复已删除文件相关的提示。 如何恢复已删除的文件 从硬盘驱动器恢复已删除的文件并不是一件疯狂的事情,但一旦您意识到文件已被删除,就尝试恢复会有所帮助。被删除的文件通常不会被真正删除&am…

终于有人把数据仓库讲明白了

数仓概念 ⚫ 数据仓库(英语:Data Warehouse,简称数仓、DW),是一个用于存储、分析、报告的数据系统。 ⚫ 数据仓库的目的是构建面向分析的集成化数据环境,分析结果为企业提供决策支持(Decision Support&am…

Linux入门教程|| Linux 忘记密码解决方法|| Linux 远程登录

很多朋友经常会忘记Linux系统的root密码,linux系统忘记root密码的情况该怎么办呢?重新安装系统吗?当然不用!进入单用户模式更改一下root密码即可。 步骤如下: 重启linux系统 3 秒之内要按一下回车,出现如…

解决Error: Electron failed to install correctly, please delete......报错的问题

问题 在启动electron项目的时候,报mlgb错 Error: Electron failed to install correctly, please delete node_modules/electron and try installing again 搞了 好久 才解决 原因 升级Electron到7.0.0,提示Electron failed to install correctly, p…

python数据可视化开发(3):使用psutil和socket模块获取电脑系统信息(Mac地址、IP地址、主机名、系统用户、硬盘、CPU、内存、网络)

系列文章目录 python开发低代码数据可视化大屏:pandas.read_excel读取表格python实现直接读取excle数据实现的百度地图标注python数据可视化开发(1):Matplotlib库基础知识python数据可视化开发(2):pandas读取Excel的数据格式处理 文章目录系…

Linux下监控类命令:ps,du,top,df,free详解

Linux下监控类命令top命令top信息解释top参数使用ps命令ps信息解释ps参数使用du和dffree命令top命令 top命令,是Linux下常用的性能分析工具,能够实时显示系统中各个进程的资源占用状况,一般系统资源导致的崩溃问题可以使用top实时监控各进程…