【AI】计算机视觉VIT文章(Transformer)源码解析

news2025/1/24 8:33:10

论文:Dosovitskiy A, Beyer L, Kolesnikov A, et al. An image is worth 16x16 words: Transformers for image recognition at scale[J]. arXiv preprint arXiv:2010.11929, 2020

源码的Pytorch版:https://github.com/lucidrains/vit-pytorch

0.前言

Transformer提出后在NLP领域中取得了极好的效果,其全Attention的结构,不仅增强了特征提取能力,还保持了并行计算的特点,可以又快又好的完成NLP领域内几乎所有任务,极大地推动自然语言处理的发展。
VIT这篇文章就是将Transformer模型应用在了CV领域,它将图像处理成Transformer模型可以应用的形式,沿用NLP领域中Transformer的方法,直接验证了其精度可以和ResNet不相上下,展示了在计算机视觉中使用纯Transformer结构的可能,为Transformer在CV领域的应用打开了大门。

直接读文章通常比较抽象,英文的原文更能劝退一大部分人,但对于程序员来说,代码是通行于世界的语言,理解起来就比较简单,结合源码理解论文中的结构,就比较事半功倍。

在这里插入图片描述

上图是VIT文章中的结构,我们看图提问题,从数据的流向来看:图像怎么切分重排的?Linear Projection of Flattened Patches对图像作了什么,怎么让图像变成Transformer能够输入的格式?
Position Embedding是怎么做图像位置编码的,为什么会多出来一个0的位置编码?Transformer Encoder中的各个结构分别代表什么,是怎么实现的?输出的类别是什么?

1.图像是如何适配Transformer输入的?

Transformer的输入是一个一维的向量,而我们的图像是二维的,需要把图像拉伸成一维的,最简单的方式就是沿着x轴展开,将所有的行拼接在一起,也就是Flatten的操作,但是这样处理会导致向量维度比较大,而且同一张图片也只能生成一个Embedding,不能适配Multi-head Embedding(这种解释有点牵强,有点拿结果去解释原因)。

1.1 图像切分重排

如结构图所示,VIT采取了图像切分重排的方式,将一个完整的图像按照行列的方向切分成小块,然后再进行后续处理。切分重排是怎么实现的,我们看代码:

self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), #图片切分重排
            nn.Linear(patch_dim, dim), # Linear Projection of Flattened Patches
        )

这里的Rearrange使用到了einops的库,相关的介绍可以查看文档
这里主要是对b c (h p1) (w p2) -> b (h w) (p1 p2 c)表达式的理解,b是batchsize,c是chanel,h和w是图像的高度和宽度,表达式前边的(h p1)表示将输入的h重新拆分成一个h*p1的向量,同理(w p2)是将输入的w拆分成w*p2的向量,这里的p1和p2是模型定义的patch_height和patch_width,可以理解为切分后小图的高和宽;
表达式右边表示输出向量的维度,(p1 p2 c)表示这3个数相乘,表示的是切分后一个小图的一维向量大小,(h w)则表示总共有多少个切分后的小图所生成的向量。这里的h和w的值跟输入的h和w值不同,表示的是原有h和w除以patch_height或patch_width的值,也就是在高和宽上各能切分出几个小图。
通过这样的处理,就将输入的c个channel的h*w的二维向量,转换成小图展开后的一维向量。

1.2 构建patch0

图像在输入Transformer之前,还连接了一个patch0,这一步的操作可以认为是延续了nlp的操作,在VIT这边的操作,可以认为是将分类类别拼接上去。

cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)

1.3 Positional Embedding

为了在模型中引入位置信息,VIT引入了位置编码的形式,也就是图中的0、1、2、3…

x += self.pos_embedding[:, :(n + 1)]

然后将position_embedding与图像的Embedding相加。

2.Transformer Encoder中的各个结构分别代表什么,是怎么实现的?

结构中的Norm可以认为是归一化处理,MLP是多个全连接层,理解起来都比较简单。其实最主要的结构就是Multi-head Attention。
在这里插入图片描述

2.1 Attention 定义

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        #得到qkv
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

2.2 Attention 的理解

VIT不同于之前RNN的地方就是引入了Attention机制,其本质可以认为是相似度计算,是计算每一个输入值与其他输入值的相似度,然后带入到计算中,如图所示,q是输入的查询值,k是关键词,v是计算值,计算每一个q与其他k的相似度,然后再带入计算中去。
在这里插入图片描述

其对应的计算公式为:
在这里插入图片描述
相似度计算是用的点积的形式,除以根号下dk是为了抑制极端值,保证softmax之后数值不至于丧失梯度。对应的代码如下:

qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

然后过一个softmax

attn = self.attend(dots) # self.attend = nn.Softmax(dim = -1)
attn = self.dropout(attn)

然后跟value值相乘

out = torch.matmul(attn, v)

2.3 Multihead Attention

多头注意力机制可以认为是定义多个Attention(每个attention关注的重点不同)分别来对数据进行处理,这里从源码中的循环结构可以体现出来:

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

3.使用示例

这里可以参考官方的readme文档

$ pip install vit-pytorch

使用示例

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1338972.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

绝缘电阻测试仪的测量范围有多少?它的测量方法是什么?

绝缘电阻测试仪广泛应用于设备检测和故障排除。它广泛应用于电力检测行业。甚至可以说,电力设备离不开绝缘电阻测试仪设备。对于许多经验丰富的电力测试工人来说,绝缘电阻测试仪的常规测量范围和方法应该非常清楚。在本文中,我们将向一些新的…

分享免费视频素材网站,第三弹

今天继续给大家分享免费视频素材网站,整理不易,觉得内容不错的话,可以点赞收藏一下哦~ 1.livelybg 一个免费视频素材小站,资源很少很少,但均为很少见的科幻、超现实、赛博朋克风格素材!&#xf…

Linux账号和权限管理

目录 前言 一、管理用户账号 1、Linux系统中用户账号类型 2、用户标识UID的分类 3、用户账号文件 4、用户账号的初始配置文件 5、用户账号的管理命令 5.1 useradd 5.2 usermod 5.3 passwd 5.4 userdel 二、管理组账号 1、Linux系统中组账号类型 2、组标识号GID的…

blackbox黑盒监控部署(k8s内)tensuns专用

一、前言 部署在k8s中需要用到deployment、configmap、service服务 二、部署 创建存放yaml的目录 mkdir /opt/blackbox-exporter && cd /opt/blackbox-exporter 编辑blackbox配置文件,使用configmap挂在这 vi configmap.yaml apiVersion: v1 kind: Confi…

移动app软件开发50个创意

那些需要努力工作才能赚钱的日子已经一去不复返了。现在,可以通过软件app赚钱。好吧,人类完全依赖智能手机这并不是谎言,无论是订购项目还是呼叫某人提供各种服务。因此,投资移动软件app开发是一个好主意。然而,移动软…

【51单片机系列】DS1302时钟模块

本文是关于DS1302时钟芯片的相关介绍。 文章目录 一、 DS1302时钟芯片介绍二、DS1302的使用2.1、DS1302的控制寄存器2.2、DS1302的日历/时钟寄存器2.3、片内RAM2.4、DS1302的读写时序 三、SPI总线介绍四、DS1302使用示例 一、 DS1302时钟芯片介绍 DS1302是DALLAS公司推出的涓流…

【教程】使用ipagurd打包与混淆Cocos2d-x的Lua脚本

文章目录 摘要引言正文1. 准备工作2. 使用ipaguard处理Lua文件3. 运行ipagurd进行混淆代码加密具体步骤测试和配置阶段IPA 重签名操作步骤4. IPA重签名与发布 总结 摘要 本文将介绍如何使用ipagurd工具对Cocos2d-x中的Lua脚本进行打包与混淆,以及在iOS应用开发中的…

IDEA2023创建web项目

一、新建项目 点击File->New->Project...,如果是第一次创建项目则单击New Project 二、添加Web Application 建好的样子 把web移动到main目录下同时改名为webapp 三、不存在Add Framework Support添加Web Application 如何存在Add Framework Support&#…

React快速入门之交互性

响应事件 创建事件处理函数 处理函数名常以handle事件名命名 function handlePlayClick() {alert(Playing);}传递事件处理函数 函数名、匿名两种方式&#xff01; function PlayButton() {function handlePlayClick() {alert(Playing);}return (<Button handleClick{handl…

信息科技成“新课标”重点,家长必须要懂!

日新月异的当下&#xff0c;人工智能无疑是与生活最为密切相关的核心词语。或许在不远的将来&#xff0c;技术含量低的重复性工作将会被机器取代。甚至有人认为&#xff0c;现在的小学生&#xff0c;大概多数会在未来从事目前尚未发明出来的工作。 近年来&#xff0c;我国的教…

微信消息撤回拦截:x64dbg反汇编实现揭秘

在数字世界中&#xff0c;信息传递的速度快如闪电&#xff0c;但也常常伴随着一些遗憾。微信作为我们日常生活中最常用的通讯工具之一&#xff0c;其撤回功能让许多人在发出信息后有了后悔的机会。然而&#xff0c;有时候我们却希望能够拦截这些即将被撤回的信息。通过x64dbg反…

YOLOv7+Pose姿态估计+tensort部署加速

YOLOv7是一种基于深度学习的目标检测算法&#xff0c;它能够在图像中准确识别出不同目标的位置和分类。而姿态估计pose和tensort则是一种用于实现人体姿态估计的算法&#xff0c;可以对人体的关节位置和方向进行精准的检测和跟踪。 下面我将分点阐述YOLOv7姿态估计posetensort…

用 Unity 实现的安检模拟小游戏源码,通过安检设备 (扫描仪) 检查乘客的随身物品 根据禁止名单对乘客做出判断是否允许通行

介绍 用 Unity 实现的安检模拟小游戏 软件版本 Unity 2019.4.9f1 (64-bit) Visual Studio 2019 游戏玩法 在游戏中你将扮演一名安全检查员 通过安检设备 (扫描仪) 检查每位乘客的随身物品 根据禁止名单对乘客做出判断&#xff1a;允许通行或者下令逮捕 游戏效果 游戏截图…

maven工具的搭建以及使用

文章目录 &#x1f412;个人主页&#x1f3c5;JavaEE系列专栏&#x1f4d6;前言&#xff1a;&#x1f380;首先进行maven工具的搭建&#x1f993;1.[打开下载 maven 服务器官网](http://maven.apache.org)&#x1fa85;2.解压之后&#xff0c;配置环境变量&#x1f3e8;3.打开设…

【SAM系列】I-MedSAM: Implicit Medical Image Segmentation with Segment Anything

论文链接&#xff1a;https://arxiv.org/abs/2311.17081 比较有趣的点&#xff1a;frequency adapter

数据结构课设迷宫问题

迷宫问题 【问题描述】以一个mn的长方阵表示迷宫&#xff0c;0和1分别表示迷宫中的通路和障碍。设计一个程序&#xff0c;对任意设定的迷宫&#xff0c;求出一条从入口到出口的通路&#xff0c;或得出没有通路的结论。 【基本要求】编写一个求解迷宫的非递归程序。求得的通路…

微信小程序获取手机号

1、先新建vue页面 打开看到页面是下图 在method定义方法 源码&#xff1a; <template><view><button type"primary" open-type"getPhoneNumber" getphonenumber"getPhoneNumberFn">登录</button><view class"&…

Explain分析-Mysql索引优化(三)

欢迎大家关注我的微信公众号&#xff1a; 传送门&#xff1a;Explain分析——索引优化实践 传送门&#xff1a;Explain分析-Mysql索引优化&#xff08;二&#xff09; 目录 分页查询优化 Join关联查询优化 in和exsits优化 count(*)查询优化 分页查询优化 示例表&am…

115基于matlab的用于铣削动力学建模的稳定性叶瓣图分析(stablity lobe)

基于matlab的用于铣削动力学建模的稳定性叶瓣图分析(stablity lobe)&#xff0c;程序已调通&#xff0c;可直接运行。 115matlab铣削动力学 (xiaohongshu.com)

HarmonyOS的功能及场景应用

一、基本介绍 鸿蒙HarmonyOS主要应用的设备包括智慧屏、平板、手表、智能音箱、IoT设备等。具体来说&#xff0c;鸿蒙系统是一款面向全场景(移动办公、运动健康、社交通信、媒体娱乐等)的分布式操作系统&#xff0c;能够支持手机、平板、智能穿戴、智慧屏、车机等多种终端设备…