视觉Transformers中的位置嵌入 - 研究与应用指南

news2024/12/22 18:35:11

视觉 Transformer 中位置嵌入背后的数学和代码简介。

自从 2017 年推出《Attention is All You Need》以来,Transformer 已成为自然语言处理 (NLP) 领域最先进的技术。 2021 年,An Image is Worth 16x16 Words² 成功地将 Transformer 应用于计算机视觉任务。从那时起,人们提出了许多基于Transformer的计算机视觉架构。

本文[1]研究了为什么位置嵌入是视觉Transformer的必要组成部分,以及不同的论文如何实现位置嵌入。它包括位置嵌入的开源代码以及概念解释。所有代码都使用 PyTorch 包。

为什么使用位置嵌入?

Attention is All You Need 指出,Transformer由于缺乏递归或卷积,无法学习有关一组标记顺序的信息。如果没有位置嵌入,Transformer对于标记的顺序是不变的。对于图像,这意味着可以对图像的补丁进行加扰,而不会影响预测的输出。

让我们看一下 Luis Zuno 的像素艺术《黄昏山》中补丁顺序的示例。原始艺术作品已被裁剪并转换为单通道图像。这意味着每个像素都有一个介于 0 和 1 之间的值。单通道图像通常以灰度显示;但是,我们将以紫色配色显示它,因为它更容易看到。

mountains = np.load(os.path.join(figure_path, 'mountains.npy'))

H = mountains.shape[0]
W = mountains.shape[1]
print('Mountain at Dusk is H =', H, 'and W =', W, 'pixels.')
print('\n')

fig = plt.figure(figsize=(10,6))
plt.imshow(mountains, cmap='Purples_r')
plt.xticks(np.arange(-0.5, W+110), labels=np.arange(0, W+110))
plt.yticks(np.arange(-0.5, H+110), labels=np.arange(0, H+110))
plt.clim([0,1])
cbar_ax = fig.add_axes([0.95.110.050.77])
plt.clim([01])
plt.colorbar(cax=cbar_ax);
#plt.savefig(os.path.join(figure_path, 'mountains.png'), bbox_inches='tight')
alt

我们可以将此图像分割成大小为 20 的块。

P = 20
N = int((H*W)/(P**2))
print('There will be', N, 'patches, each', P, 'by', str(P)+'.')
print('\n')

fig = plt.figure(figsize=(10,6))
plt.imshow(mountains, cmap='Purples_r')
plt.clim([0,1])
plt.hlines(np.arange(P, H, P)-0.5-0.5, W-0.5, color='w')
plt.vlines(np.arange(P, W, P)-0.5-0.5, H-0.5, color='w')
plt.xticks(np.arange(-0.5, W+110), labels=np.arange(0, W+110))
plt.yticks(np.arange(-0.5, H+110), labels=np.arange(0, H+110))
x_text = np.tile(np.arange(9.5, W, P), 3)
y_text = np.repeat(np.arange(9.5, H, P), 5)
for i in range(1, N+1):
    plt.text(x_text[i-1], y_text[i-1], str(i), color='w', fontsize='xx-large', ha='center')
plt.text(x_text[2], y_text[2], str(3), color='k', fontsize='xx-large', ha='center');
#plt.savefig(os.path.join(figure_path, 'mountain_patches.png'), bbox_inches='tight')
alt

据称,视觉Transformer将无法区分原始图像和补丁被打乱的版本。

np.random.seed(21)
scramble_order = np.random.permutation(N)
left_x = np.tile(np.arange(0, W-P+120), 3)
right_x = np.tile(np.arange(P, W+120), 3)
top_y = np.repeat(np.arange(0, H-P+120), 5)
bottom_y = np.repeat(np.arange(P, H+120), 5)

scramble = np.zeros_like(mountains)
for i in range(N):
    t = scramble_order[i]
    scramble[top_y[i]:bottom_y[i], left_x[i]:right_x[i]] = mountains[top_y[t]:bottom_y[t], left_x[t]:right_x[t]]
    
fig = plt.figure(figsize=(10,6))
plt.imshow(scramble, cmap='Purples_r')
plt.clim([0,1])
plt.hlines(np.arange(P, H, P)-0.5-0.5, W-0.5, color='w')
plt.vlines(np.arange(P, W, P)-0.5-0.5, H-0.5, color='w')
plt.xticks(np.arange(-0.5, W+110), labels=np.arange(0, W+110))
plt.yticks(np.arange(-0.5, H+110), labels=np.arange(0, H+110))
x_text = np.tile(np.arange(9.5, W, P), 3)
y_text = np.repeat(np.arange(9.5, H, P), 5)
for i in range(N):
    plt.text(x_text[i], y_text[i], str(scramble_order[i]+1), color='w', fontsize='xx-large', ha='center')
    
i3 = np.where(scramble_order==2)[0][0]
plt.text(x_text[i3], y_text[i3], str(scramble_order[i3]+1), color='k', fontsize='xx-large', ha='center');
#plt.savefig(os.path.join(figure_path, 'mountain_scrambled_patches.png'), bbox_inches='tight')
alt

显然,这是与原始图像非常不同的图像,并且您不希望视觉Transformer将这两个图像视为相同。

排列的注意力不变性

让我们研究一下视觉Transformer对于标记顺序不变的说法。Transformer中对 token 顺序不变的组件是注意力模块。

注意力是根据三个矩阵(查询、键和值)计算得出的,每个矩阵都是通过将token传递到线性层而生成的。生成 Q、K 和 V 矩阵后,将使用以下公式计算注意力。

alt

其中 Q、K、V 分别是查询、键和值; dₖ 是缩放值。为了证明注意力对 token 顺序的不变性,我们将从三个随机生成的矩阵开始来表示 Q、K 和 V。Q、K 和 V 的形状如下:

alt

在此示例中,我们将使用 4 个预计长度为 9 的标记。矩阵将包含整数以避免浮点乘法错误。生成后,我们将交换token 0 和token 2 在所有三个矩阵中的位置。具有交换标记的矩阵将用下标 s 表示。

n_tokens = 4
l_tokens = 9
shape = n_tokens, l_tokens
mx = 20 #max integer for generated matricies

# Generate Normal Matricies
np.random.seed(21)
Q = np.random.randint(1, mx, shape)
K = np.random.randint(1, mx, shape)
V = np.random.randint(1, mx, shape)

# Generate Row-Swapped Matricies
swapQ = copy.deepcopy(Q)
swapQ[[02]] = swapQ[[20]]
swapK = copy.deepcopy(K)
swapK[[02]] = swapK[[20]]
swapV = copy.deepcopy(V)
swapV[[02]] = swapV[[20]]

# Plot Matricies
fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(8,8))
fig.tight_layout(pad=2.0)
plt.subplot(321)
mat_plot(Q, 'Q')
plt.subplot(322)
mat_plot(swapQ, r'$Q_S$')
plt.subplot(323)
mat_plot(K, 'K')
plt.subplot(324)
mat_plot(swapK, r'$K_S$')
plt.subplot(325)
mat_plot(V, 'V')
plt.subplot(326)
mat_plot(swapV, r'$V_S$')
alt

注意力公式中的第一个矩阵乘法是 Q·Kᵀ=A,其中得到的矩阵 A 是一个大小等于 token 数量的正方形。当我们用 Qₛ 和 Kₛ 计算 Aₛ 时,得到的 Aₛ 的行 [0, 2] 和列 [0,2] 都与 A 交换。

A = Q @ K.transpose()
swapA = swapQ @ swapK.transpose()
modA = copy.deepcopy(A)
modA[[0,2]] = modA[[2,0]] #swap rows
modA[:, [20]] = modA[:, [02]] #swap cols

fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(8,3))
fig.tight_layout(pad=1.0)
plt.subplot(131)
mat_plot(A, r'$A = Q*K^T$')
plt.subplot(132)
mat_plot(swapA, r'$A_S = Q_S * K_S^T$')
plt.subplot(133)
mat_plot(modA, 'A\nwith rows [0,2] swaped\n and cols [0,2] swaped')
alt

下一个矩阵乘法是 A·V=A,其中生成的矩阵 A 与初始 Q、K 和 V 矩阵具有相同的形状。当我们用 Aₛ 和 Vₛ 计算 Aₛ 时,得到的 Aₛ 的行 [0,2] 与 A 交换。

A = A @ V
swapA = swapA @ swapV
modA = copy.deepcopy(A)
modA[[0,2]] = modA[[2,0]] #swap rows

fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(127))
fig.tight_layout(pad=1.0)
plt.subplot(221)
mat_plot(A, r'$A = A*V$')
plt.subplot(222)
mat_plot(swapA, r'$A_S = A_S * V_S$')
plt.subplot(224)
mat_plot(modA, 'A\nwith rows [0,2] swaped')
axs[1,0].axis('off')
alt

这表明,更改注意力层输入中标记的顺序会导致输出注意矩阵的相同标记行发生变化。因为注意力是对标记之间关系的计算。如果没有位置信息,更改token顺序不会改变token的关联方式。

示例

定义位置嵌入

现在,我们可以看看正弦位置嵌入的细节。该代码基于 Tokens-to-Token ViT 的公开可用 GitHub 代码。从功能上来说,位置嵌入是一个与 token 形状相同的矩阵。这看起来像:

alt

正弦位置嵌入公式如下所示

alt

其中 PE 是位置嵌入矩阵,i 是沿着标记的数量,j 是沿着标记的长度,d 是标记长度。代码实现:

def get_sinusoid_encoding(num_tokens, token_len):
    """ Make Sinusoid Encoding Table

        Args:
            num_tokens (int): number of tokens
            token_len (int): length of a token
            
        Returns:
            (torch.FloatTensor) sinusoidal position encoding table
    """


    def get_position_angle_vec(i):
        return [i / np.power(100002 * (j // 2) / token_len) for j in range(token_len)]

    sinusoid_table = np.array([get_position_angle_vec(i) for i in range(num_tokens)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) 

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)

让我们生成一个示例位置嵌入矩阵。我们将使用 176 个tokens。每个token的长度为 768,这是 T2T-ViT代码中的默认长度。一旦生成了矩阵,我们就可以绘制它。

PE = get_sinusoid_encoding(num_tokens=176, token_len=768)

fig = plt.figure(figsize=(108))
plt.imshow(PE[0, :, :], cmap='PuOr_r')
plt.xlabel('Along Length of Token')
plt.ylabel('Individual Tokens');
cbar_ax = fig.add_axes([0.95.360.050.25])
plt.clim([-11])
plt.colorbar(label='Value of Position Encoding', cax=cbar_ax);
#plt.savefig(os.path.join(figure_path, 'fullPE.png'), bbox_inches='tight')
alt

放大标记的开头。

fig = plt.figure()
plt.imshow(PE[0, :, 0:301], cmap='PuOr_r')
plt.xlabel('Along Length of Token')
plt.ylabel('Individual Tokens');
cbar_ax = fig.add_axes([0.95.20.050.6])
plt.clim([-11])
plt.colorbar(label='Value of Position Encoding', cax=cbar_ax);
#plt.savefig(os.path.join(figure_path, 'zoomedinPE.png'), bbox_inches='tight')
alt

具有正弦结构!

将位置嵌入应用于tokens

现在,我们可以将位置嵌入添加到我们的tokens中!我们将使用《Mountain at Dusk》,并具有与上述相同的补丁标记化。这将为我们提供 15 个长度为 20²=400 的token。

fig = plt.figure(figsize=(10,6))
plt.imshow(mountains, cmap='Purples_r')
plt.hlines(np.arange(P, H, P)-0.5-0.5, W-0.5, color='w')
plt.vlines(np.arange(P, W, P)-0.5-0.5, H-0.5, color='w')
plt.xticks(np.arange(-0.5, W+110), labels=np.arange(0, W+110))
plt.yticks(np.arange(-0.5, H+110), labels=np.arange(0, H+110))
x_text = np.tile(np.arange(9.5, W, P), 3)
y_text = np.repeat(np.arange(9.5, H, P), 5)
for i in range(1, N+1):
    plt.text(x_text[i-1], y_text[i-1], str(i), color='w', fontsize='xx-large', ha='center')
plt.text(x_text[2], y_text[2], str(3), color='k', fontsize='xx-large', ha='center')
cbar_ax = fig.add_axes([0.95.110.050.77])
plt.clim([01])
plt.colorbar(cax=cbar_ax);
#plt.savefig(os.path.join(figure_path, 'mountain_patches_w_colorbar.png'), bbox_inches='tight')
alt

当我们将这些补丁转换为token时,它看起来像

tokens = np.zeros((1520**2))
for i in range(15):
    patch = gray_mountains[top_y[i]:bottom_y[i], left_x[i]:right_x[i]]
    tokens[i, :] = patch.reshape(120**2)
tokens = tokens.astype(int)
tokens = tokens/255

fig = plt.figure(figsize=(10,6))
plt.imshow(tokens, aspect=5, cmap='Purples_r')
plt.xlabel('Length of Tokens')
plt.ylabel('Number of Tokens')
cbar_ax = fig.add_axes([0.95.360.050.25])
plt.clim([01])
plt.colorbar(cax=cbar_ax)
alt

现在,我们可以以正确的形状进行位置嵌入:

PE = get_sinusoid_encoding(num_tokens=15, token_len=400).numpy()[0,:,:]
fig = plt.figure(figsize=(10,6))
plt.imshow(PE, aspect=5, cmap='PuOr_r')
plt.xlabel('Length of Tokens')
plt.ylabel('Number of Tokens')
cbar_ax = fig.add_axes([0.95.360.050.25])
plt.clim([01])
plt.colorbar(cax=cbar_ax)
alt

我们现在准备将位置嵌入添加到标记中。位置嵌入中的紫色区域将使令牌变暗,而橙色区域将使它们变亮。

mountainsPE = tokens + PE
resclaed_mtPE = (position_mountains - np.min(position_mountains)) / np.max(position_mountains - np.min(position_mountains))

fig = plt.figure(figsize=(10,6))
plt.imshow(resclaed_mtPE, aspect=5, cmap='Purples_r')
plt.xlabel('Length of Tokens')
plt.ylabel('Number of Tokens')
cbar_ax = fig.add_axes([0.95.360.050.25])
plt.clim([01])
plt.colorbar(cax=cbar_ax)
alt

您可以从原始token中看到结构,以及位置嵌入中的结构!这两条信息都将被转发到Transformer中。

Reference
[1]

Source: https://towardsdatascience.com/position-embeddings-for-vision-transformers-explained-a6f9add341d5

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1489633.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

小迪安全31WEB 攻防-通用漏洞文件上传js 验证mimeuser.ini语言特性

#知识点: 1、文件上传-前端验证 2、文件上传-黑白名单 3、文件上传-user.ini 妙用 4、文件上传-PHP 语言特性 #详细点: 检测层面:前端,后端等 2、检测内容:文件头,完整性,二次渲染…

docker快照备份回滚

1. 安装系统 1.1 vm安装Ubuntu 参考:https://blog.csdn.net/u010308917/article/details/125157774 1.2 其他操作 添加自定义物理卷 –待补充– 1.2.1 查询可用物理卷 fdisk -l 输出如下 Disk /dev/loop0: 73.9 MiB, 77492224 bytes, 151352 sectors Units: sectors of …

Vue 项目重复点击菜单刷新当前页面

需求:“在当前页面点击当前页面对应的菜单时,也能刷新页面。” 由于 Vue 项目的路由机制是路由不变的情况下,对应的组件是不重新渲染的。所以重复点击菜单不会改变路由,然后页面就无法刷新了。 方案一 在vue项目中,…

英特尔/ARM/国产化EMS储能控制器解决方案

新型储能是建设新型电⼒系统、推动能源绿⾊低碳转型的重要装备基础和关键⽀撑技术,是实现碳达峰、碳中和⽬标的重要⽀撑。说到储能,大众首先想到的就是电池,其好坏关系到能量转换效率、系统寿命和安全等重要方面,但储能要想作为一…

[LeetBook]【学习日记】数组内乘积

题目 按规则计算统计结果 为了深入了解这些生物群体的生态特征,你们进行了大量的实地观察和数据采集。数组 arrayA 记录了各个生物群体数量数据,其中 arrayA[i] 表示第 i 个生物群体的数量。请返回一个数组 arrayB,该数组为基于数组 arrayA …

vue点击按钮同时下载多个文件

点击下载按钮根据需要的id调接口拿到返回需要下载的文件 再看返回的数据结构 数组中一个对象,就是一个文件,多个对象就是多个文件 下载函数 // 下载tableDownload(row) {getuploadInventoryDownload({ sysBatch: row.sysBatch, fileName: row.fileName…

STM32 TIM编码器接口

单片机学习! 目录 文章目录 前言 一、编码器接口简介 1.1 编码器接口作用 1.2 编码器接口工作流程 1.3 编码器接口资源分布 1.4 编码器接口输入引脚 二、正交编码器 2.1 正交编码器功能 2.2 引脚作用 2.3 如何测量方向 2.4 正交信号优势 2.5 执行逻辑 三、编码器定时…

前端+php:实现提示框(自动消失)

效果 php部分&#xff1a;只展示插入过程 <?php//插入注册表中$sql_insert "INSERT INTO regist_user(userid,password,phone,email)VALUES (" . $_POST[UserID] . "," . CryptPass($_POST[Password]) . "," . $_POST[Phone] . ",&qu…

“而且,再加上”可以用哪个语法来表示,柯桥考级韩语学习

语法 --는/은/ㄴ 데다가 1.语法&#xff1a;는/은/ㄴ 데다가 2.表示&#xff1a;用于谓词词干和体词谓词形后, 表示在原有的状况上再加上其他情况。 3.添加&#xff1a; 4.例句&#xff1a; 当然&#xff0c;与这个语法含义相近的还有不少语法&#xff0c;有一部分是初级暂时…

【AI视野·今日NLP 自然语言处理论文速览 第八十期】Fri, 1 Mar 2024

AI视野今日CS.NLP 自然语言处理论文速览 Fri, 1 Mar 2024 Totally 67 papers &#x1f449;上期速览✈更多精彩请移步主页 Daily Computation and Language Papers Loose LIPS Sink Ships: Asking Questions in Battleship with Language-Informed Program Sampling Authors G…

Libevent的使用及reactor模型

Libevent 是一个用C语言编写的、轻量级的开源高性能事件通知库&#xff0c;主要有以下几个亮点&#xff1a;事件驱动&#xff08; event-driven&#xff09;&#xff0c;高性能;轻量级&#xff0c;专注于网络&#xff0c;不如 ACE 那么臃肿庞大&#xff1b;源代码相当精炼、易读…

HTTP常用状态码详解

目录 1xx - 信息性状态码 2xx - 成功状态码 3xx - 重定向状态码 4xx - 客户端错误状态码 5xx - 服务器错误状态码 总结 HTTP&#xff08;Hypertext Transfer Protocol&#xff09;是一种用于传输超文本的应用层协议。在HTTP通信中&#xff0c;服务器和客户端之间会通过状态…

Chrome禁止自动升级

一、关闭计划任务 1、首先我们需要右键点击我的电脑&#xff0c;在打开的选项里选择管理。   2、在打开的对话框中选择任务计划程序。   3、在任务计划程序库中找到两个和chrome自动更新相关的任务计划GoogleUpdateTaskMachineCore与GoogleUpdateTaskMachineUA。     4…

C++模拟揭秘刘谦魔术,领略数学的魅力

新的一年又开始了&#xff0c;大家新年好呀~。在这我想问大家一个问题&#xff0c;有没有同学看了联欢晚会上刘谦的魔术呢&#xff1f; 这个节目还挺有意思的&#xff0c;它最出彩的不是魔术本身&#xff0c;而是小尼老师“念错咒语”而导致他手里的排没有拼在一起&#xff0c;…

Python数据可视化库之bashplotlib使用详解

概要 在数据可视化领域,Python拥有许多优秀的库,如Matplotlib、Seaborn等,它们可以创建漂亮而复杂的图形。但是,有时候我们可能需要在终端中绘制简单的图形,这时候Bashplotlib就派上了用场。Bashplotlib是一个Python库,可以在终端中绘制基本的图形,如条形图、散点图等。…

用Python爬取古诗文网的各类古诗

fetch-gushiwen 用途 可以拿去用于个人知识库、知识图谱的创建等其他学习用途。 使用 输入古诗文网的链接&#xff0c;即可爬取该页面所有诗歌的诗名&#xff0c;作者&#xff0c;朝代&#xff0c;内容&#xff0c;译文&#xff0c;注释&#xff0c;赏析&#xff0c;创作背…

文心一言 VS 讯飞星火 VS chatgpt (208)-- 算法导论15.4 5题

五、设计一个 O( n 2 n^2 n2) 时间的算法&#xff0c;求一个 n 个数的序列的最长单调递增子序列。要写代码的时候&#xff0c;请用go语言。 文心一言&#xff0c;抛panic&#xff1a; 在 Go 语言中设计一个 O(n^2) 时间复杂度的算法来求一个 n 个数的序列的最长单调递增子序列…

【数据结构】20 树的应用(堆及其操作)

堆的定义 堆是一种优先队列&#xff0c;从堆中取出元素时按照元素的优先级大小&#xff0c;而不是元素进入队列的先后顺序。 堆的最常用结构是二叉树&#xff0c;一般为完全二叉树。通常使用数组来实现堆的存储。 最大堆&#xff1a;任一结点的值大于或等于其子节点的值。这一…

【论文阅读】Mamba:选择状态空间模型的线性时间序列建模(一)

文章目录 Mamba:选择状态空间模型的线性时间序列建模介绍状态序列模型选择性状态空间模型动机&#xff1a;选择作为一种压缩手段用选择性提升SSM 选择性SSM的高效实现先前模型的动机选择扫描总览&#xff1a;硬件感知状态扩展 Mamba论文 Mamba:选择状态空间模型的线性时间序列建…

【.NET Core】深入理解IO - 读取器和编写器

【.NET Core】深入理解IO - 读取器和编写器 文章目录 【.NET Core】深入理解IO - 读取器和编写器一、概述二、BinaryReader和BinaryWriter2.1 BinartReader类2.2 BinaryWriter类 三、StreamReader和StreamWriter3.1 StreamReader类3.1 StreamWriter类StreamWriter类构造函数Str…