Transformer(Vit+注意力机制)

news2024/11/25 21:30:54

文献基本信息:

Encoder-Decoder:

Transformer的结构:

  • 输入
  • 编码器
  • 解码器
  • 输出

Transformer的工作流程:

  • 获取输入句子的每一个单词的表示向量X,X由单词的embedding(embedding是一种将高维特征映射到低维的技术)和单词位置的embedding相加得到

  • 将得到的单词表示向量矩阵(如下图所示),每一行是一个单词的表示,传入Encoder中,经过6个Encoder block后可以得到句子所有单词的编码信息矩阵C(如下图所示);单词向量矩阵用X_n*d表示,n是句子中单词的个数,d是表示向量的维度(论文中d=512);每一个Encoder block输出的矩阵维度与输入维度一致

  • 将Encoder输出的编码信息矩阵C传递到Decoder中,Decoder依次会根据当前翻译过的单词1~i翻译下一个单词i+1(如下图所示);在使用的过程中,翻译到单词i+1的时候需要通过Mask(掩盖)操作遮盖住i+1之后的单词

编码器结构:

如图,红色框是编码器结构;N表示这一层有几个编码器;编码器由一个多头注意力机制、一个残差网络、LayerNorm(层归一化)和前馈神经网络组成。

信息进入的完整流程:

  1. 多头注意力机制:处理输入特征的全局关系
  2. 输出加上输入经过:残差连接
  3. 结果经过LayerNorm:进行归一化
  4. 进入前馈神经网络:进行非线性特征变换
  5. 再次经过残差连接+ LayerNorm

论文信息:

从论文来读是Encoder有N=6层,每层包含两个子层(sub-layers)且都有多头自注意力机制,允许模型在处理序列时同时关注序列的不同部分和全连接前馈网络,对每个位置的向量进行相同的操作。

解码器结构:

由一个decoder由两个多头注意力机制构成,全连接神经网络FNN构成

Decoder block(下图红色框)与Encoder block相似,但存在一些区别

与Encoder 区别:

  • 第一个Multi-Head Attention层采用了Masked操作,使得 decoder 不能看见未来的信息,对于一个序列,在 time_step 为 t 的时刻,解码输出应该只能依赖于 t 时刻之前的输出,而不能依赖 t 之后的输出。因此要把 t 之后的信息给隐藏起来。
  • 第二个Multi-Head Attention层的K、V矩阵使用Encoder的编码信息矩阵C进行计算,而Q使用上一个Decoder block 的输出计算
  • 最后使用softmax层计算下一个翻译单词的概率

第一个 Multi-Head Attention:

采用了 Masked 操作原因:

防止模型在生成输出时查看未来的信息

第二个Multi-Head Attention:

Decoder block第二个Multi-Head Attention变化不大,主要区别在于其中的Self-Attention的K、V矩阵不是使用上一个Decoder block的输出计算的,而是使用Encoder的编码信息矩阵计算的;

输入来源:

根据Encoder的输出C计算得到K、V,根据上一个Decoder block的输出Z计算Q(若第一个Decoder block则使用输入矩阵X进行计算),后续计算方法于之前描述的一致

作用:

在Decoder的每一位单词都可以利用到Encoder所有单词的信息(之前的信息不需要Mask)

Input输入:

使用词嵌入算法将每个词转换为一个词向量;论文中,词嵌入向量的维度是512

词嵌入算法(Word Embedding):

将文本中的单词或短语转换为连续向量的技术

Positional Encoding:

表示输入序列中各个元素位置的一种技术;一般在输入嵌入之后,将位置编码与词嵌入结合做加法操作

添加位置编码原因:
  • 一句话中的同一个词,若词语出现的位置把不同,意思也不同;而Treansformer使用的self-attention不能获取词语的位置信息;若不添加位置编码,那么无论单词在什么位置,它的注意力分数都是确定的
  • 为了理解单词顺序,Transformer为每个输入的词嵌入添加了一个向量,这样能够更好的表达词与词之间的关系
获取方式:

论文中的Transformer使用的是正余弦位置编码;位置编码通过使用不同频率的正弦、余弦函数生成,然后和对应的位置的词向量相加,位置向量为度必须和词向量的维度一致;

位置编码公式:

Pos:单词在句子中的绝对位置,pos=0,1,2…;例如Jerry在“Tom chase Jerry”中的pos=2

dmodel:表示词向量的维度,在这里dmodel=512

i:表示词向量中的第几维;2i和2i+1用来表示奇偶性;例如dmodel=512,故i=0~255

将positional encoding与词向量相加,而不是拼接的原因:

拼接和相加都可以,只是本身词向量的维度512维就已经很大了,若再拼接512维的位置向量变成1024维,训练速度会变慢,从而影响效率

注意力机制:

自注意力机制:

Self-Attention结构:

Q、K、V计算:

若输入序列是Thinking Machines,x1、x2就是对应“Thinking”和“Machines”添加过位置编码之后的词向量,将x1通过embedding高维映射成为a1,将x2通过embedding高维映射成为a2;然后词向量通过三个权值矩阵Wq、Wk、Wv(对于所有的a都是共享的,可训练参数)转变成为计算Attention值所需的Query、Keys、Values

QKV含义:

Query:接下来用来匹配每一个Key

Keys:会被Q进行match的

Values:从a中学习到的有用信息

在transformer中可以实现并行化,因为是矩阵乘法

步骤:
  • 计算得分:通过所有输入句子的单词的键向量与“Thinking”的查询向量相点积来计算,分数决定了在编码单词“Thinking”的过程中有多重视句子的其它部分

  • 相关性得分归一化:对相似度矩阵除以根下dk,dk为K的维度大小,这个除法被称为scale;对于输入序列中每个单词之间的相关性得分进行归一化,归一化的目的主要是为了训练时梯度能够稳定,使方差变小

  • Softmax函数计算:通过softmax函数,将每个单词之间的得分向量转换为[0,1]之间的概率分布,同时更凸显单词之间的关系;使所有单词的分数归一化,得到的分数都是正值且和为1

  • Soft分数与矩阵相乘:根据每个单词之间的概率分布,然后乘上对应的values值
  • 对加权值向量求和:得到自注意力层再该位置的输出

整体计算过程图:

最终结果:

融合了输入序列中每个元素与其他元素的关系,从而生成了上下文相关的表示;表示了该单词与上下文之间的联系。

自注意力机制的输入是整个序列的特征,输出是考虑了序列中所有元素之间上下文关系的序列表示

多头注意力机制:

过程:

将生成的q、k、v分为两个头(多个头),按照Head的个数采用均分(线性映射)的方式;例如将q1分为q11、q12;将第二个索引相同的qkv的归类为head1,例如q11、k11、v11、q21、k21、v21归类为head1

对每一head执行self-attention

得到的向量进行concat拼接

为什么要乘W矩阵?

W使用线性投影,将拼接后的输出映射回原始维度;每个线性变化中都有W

步骤:

  • 输入原始的Value、Key、Query矩阵,后进入到一个线性层,投影到比较低的维度
  • 做上面的注意力机制做8次(假设使用了8个注意力头),得到8个输出
  • 将 8 个头的输出拼接在一起,形成一个新的向量
  • 线性变换(将拼接后的输出映射回原始的嵌入维度)

前馈层只需要一个矩阵,则把得到的8个矩阵拼接在一起,然后用一个附加的权重矩阵W与它们相乘

多头注意力机制的优点:

  • 扩展了模型专注于不同位置的能力
  • 有多个Query、Key、Value权重矩阵集合,(Transformer使用9个注意力头)并且每一个都是随机初始化的;和过程中所提到的一样,用矩阵X*Wq、Wk、Wv来产生Query、Key、Value矩阵
  • Self-attention只使用了一组Wq、Wk、Wv来进行变换得到Query、Key、Value矩阵,而Multi-Head Attention使用多组Wq、Wk、Wv得到多组Query、Key、Value矩阵,然后每个组分别计算得到一个Z矩阵

基于Encoder-Decoder 的Multi-Head Attention:

  • Encoder中的Multi-Head Attention是基于Self-Attention的
  • Decoder中的第二个Multi-Head Attention就只是基于Attention,它的输入Query来自于Masked Multi-Head Attention的输出,Keys和Values来自于Encoder中最后一层的输出。

Add&Normalize:

在经过多头注意力机制得到矩阵Z之后,并没有直接传入全连接神经网络,而是经过了一步Add&Normalize。

Add:

在Z的基础上加了一个残差块X,加入残差块的目的是为了防止在深度神经网络的训练过程中发生退化的问题,退化意为深度神经网络通过增加网络的层数,Loss逐渐减小,然后区域稳定达到饱和,然后再继续增加网络层数,Loss反而会增加

ResNet残差神经网络:

引入ResNet残差神经网络,神经网络退化指的是达到最优网络层数之后,神经网络还在继续训练导致Loss增大,对于多余的层,我们需保证多出来的网络进行恒等映射;只有进行了恒等映射后才能保证多出来的神经网络不会影响到模型的效果

残差块:

X为输入值,F(X)经过第一层线性变换后并且使用Relu激活函数输出,在第二层线性变换和激活函数之前,F(X)加入这一层的输入值X,完成相加操作后再进行Relu激活

恒等映射:

为防止过拟合,需要进行恒等映射;需要将F(X)=0;X经过线性变换(随机初始化权重一般偏向于0),输出值会明显偏向于0,经过激活函数Relu会将负数变为0,过滤了负数的影响;这样输出使左边那部分趋于0,输出就更接近于右边的数。

作用:这样当网络自己决定哪些网络层为冗余层时,使用ResNet的网络很大程度上解决了学习恒等映射的问题,用学习残差F(x)=0更新该冗余层的参数来代替学习h(x)=x更新冗余层的参数。

Normalize:

作用:

  • 加快训练的速度
  • 提高训练的稳定性

前馈网络(Feed-Forward Networks):

全连接层是一个两层的神经网络,先线性变换,然后Relu非线性变换,再进行一个线性变换

作用:

将输入的Z映射到更加高维的空间中,然后通过非线性函数Relu进行筛选,筛选后再变回原来的维度

经过6个encoder后输入到decoder中

Output输出:

Output如图所示,首先经过一次线性变换,然后softmax得到输出的概率分布,然后通过词典,输出概率最大的对应单词作为预测输出

VIT模型:

工作流程:

  • 将输入的图像进行patch的划分
  • Linear Projection of Flatted patches,将patch拉平并进行线性映射
  • 生成CLS token(用向量有效地表示整个输入图像的特征)特殊字符“*”,生成Position Embedding,用Patch+Position Embedding相加作为inputs token
  • Transformer Encoder编码,特征提取
  • MLP Head进行分类输出结果

图片分块嵌入:

输入图片大小为224x224,将图片切分为固定大小16x16的patch,则每张图像会生成224x224/16x16=196个patch,输出序列为196;

196个patch(196x3)通过线性投射层(3x768)之后的维度为196x768,即一共有196个token,每个token的维度是768。

Learnable embedding(可学习的嵌入):

它是一个可学习的向量,它用来表示整个输入图像的信息;这个向量的长度是 768。

Position Embedding:

如图,编号有0-9的紫色框表示各个位置的position embedding,而紫色框旁边的粉色框则是经过linear projection之后的flattened patch向量。这样每一个 token 既包括图像信息又包括了位置信息

ViT的Position Embedding采用的是一个可学习/训练的 1-D 位置编码嵌入,是直接叠加在tokens上的

Transformer Encoder:

Encoder输入的维度为[197, 768],输出的维度为[197, 768],可以把中间过程简单的理解成为特征提取的过程

MLP Head:

定义:

MLP Head 是指位于模型顶部的全连接前馈神经网络模块,用于将提取的图像特征表示转换为最终的分类结果或其他预测任务输出。MLP Head 通常跟在 Transformer Encoder 的输出之后,作为整个模型的最后一层。

MLP Block,如图下图所示,就是全连接+GELU激活函数+Dropout组成也非常简单,需要注意的是第一个全连接层会把输入节点个数翻4倍[197, 768] -> [197, 3072],第二个全连接层会还原回原节点个数[197, 3072] -> [197, 768]

Vision Transformer维度变换:

  • 输入图像的input shape=[1,3,224,224],1是batch_size,3是通道数,224是高和宽
  • 输入图像经过patch Embedding,其中Patch大小是14,卷积核是768,则经过分块后,获得的块数量为196,每个块的维度被转换为768,即得到的patch embedding的shape=[1,196,768]
  • 将可学习的[class] token embedding拼接到patch embedding前,得到shape=[1,197,768]
  • 将position embedding加入到拼接后的embedding中,组成最终的输入嵌入,最终的输入shape=[1,197,768]
  • 输入嵌入送入到Transformer encoder中,shape并不发生变化
  • 最后transformer的输出被送入到MLP或FC中执行分类预测,选取[class] token作为分类器的输入,以表示整个图像的全局信息,假设分类的类目为K KK,最终的shape=[1,768]*[768,K]=[1,K]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2219976.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

opencv出错以及解决技巧

opencv配置 一开始&#xff0c;include的路径是<opencv4/opencv2/…> 这样在using namespace cv的时候导致了报错&#xff0c; 所以在cmakelist中需要对cmake的版本进行升级。 set(CMAKE_CXX_FLAGS “-stdc14 -O0 -Wall”)-O0 表示在编译过程中不进行任何优化 对应的pac…

Linux操作系统如何制作U盘启动盘

在麒麟系统中有一款U盘启动器软件&#xff0c;它是用于制作系统启动U盘的工具&#xff0c;方便无光驱的电脑安装操作系统&#xff0c;也可以反复使用一个U盘&#xff0c;避免光盘的浪费。下面对该U盘启动器使用方法做详细讲解。 1.准备需要安装的系统镜像文件。 图 1 2.准备1…

Node-RED开源项目的modbus通信(TCP)

一、Modbus 通信协议 Modbus是一种串行通信协议&#xff0c;是Modicon公司&#xff08;现在的施耐德电气 Schneider Electric&#xff09;于1979年为使用可编程逻辑控制器&#xff08;PLC&#xff09;通信而发表。Modbus已经成为工业领域通信协议的业界标准&#xff08;De fact…

Redis高阶篇之Redis单线程与多线程

文章目录 0 前言1. 为什么Redis是单线程&#xff1f;1.1 Redis单线程1.2 为什么Redis3时代单线程快的原因1.3 使用单线程原因 2.为什么逐渐加入多线程呢&#xff1f;2.1 如何解决 3.redis6/7的多线程特性和IO多路复用入门3.1主线程和IO线程怎么协作完成请求处理的3.2 Unix网络编…

政府采购合同公告明细数据(1996-2024年)

透明度成为了公众对政府活动的基本要求之一。特别是在政府采购领域&#xff0c;透明度不仅关系到公共资源的合理分配&#xff0c;更是维护市场公平竞争的重要保障。政府采购合同公告制度正是为了满足这一需求而设立的。 1996-2024年政府采购合同公告明细数据&#xff08;dta文…

Perl打印9x9乘法口诀

本章教程主要介绍如何用Perl打印9x9乘法口诀。 一、程序代码 1、写法① use strict; # 启用严格模式&#xff0c;帮助捕捉变量声明等错误 use warnings; # 启用警告&#xff0c;帮助发现潜在问题# 遍历 1 到 9 的数字 for my $i (1..9) {# 对于每个 $i&#xff0c;遍历 1…

Javascript 脚本查找B站限时免费番剧

目录 前言 脚本编写 脚本 前言 B站的一些番剧时不时会“限时免费”&#xff0c;白嫖党最爱&#xff0c;主打一个又占到便宜的快乐。但是在番剧索引里却没有搜索选项可以直接检索“限时免费”的番剧&#xff0c;只能自己一页一页的翻去查看&#xff0c;非常麻烦。 自己找限…

Git极速入门

git初始化 git -v git config --global user.name "" git config --global user.email "" git config --global credential.helper store git config --global --list省略(Local) 本地配置&#xff0c;只对本地仓库有效–global 全局配置&#xff0c;所有…

spring boot yml文件中引用*.properties文件中的属性

1、首先在*.properties文件中加入一个属性&#xff0c;如&#xff1a; 2、然后再application.yml文件中通过${jdbc.driver}来引用&#xff0c;如&#xff1a; 3、然后再创建一个资源配置类&#xff0c;通过PropertySource来引入这个*.properties文件&#xff0c;如&#xff1…

JDK中socket源码解析

目录 1、Java.net包 1. Socket通信相关类 2. URL和URI处理类 3. 网络地址和主机名解析类 4. 代理和认证相关类 5. 网络缓存和Cookie管理类 6. 其他网络相关工具类 2、什么是socket&#xff1f; 3、JDK中socket核心Api 4、核心源码 1、核心方法 2、本地方法 3、lin…

基于stm32的esp8266的WIFI控制风扇实验

实验案例&#xff37;&#xff29;&#xff26;&#xff29;控制风扇 项目需求 电脑通过esp8266模块远程遥控风扇。 项目框图 ​ 风扇模块封装 #include "sys.h" #include "fan.h"void fan_init(void) {GPIO_InitTypeDef gpio_initstruct;//打开时钟…

4K Mini-LED显示器平民价,一千多的联合创新27M3U到底有多香

哈喽小伙伴们好&#xff0c;我是Stark-C~ 要说前几年买显示器还是普通IPS的天下&#xff0c;那个时候虽说也有MiniLED或者OLED显示器&#xff0c;但是价格那也是真贵啊&#xff0c;毕竟那个时候MiniLED和OLED还没普及&#xff0c;只有一些高档电视或者显示器才会用到此技术。不…

OpenCV高级图形用户界面(18)手动设置轨迹条(Trackbar)的位置函数setTrackbarPos()的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 该函数设置指定窗口中指定轨迹条的位置。 注意 [仅 Qt 后端] 如果轨迹条附加到控制面板&#xff0c;则 winname 可以为空。 函数原型 void cv…

三周精通FastAPI:4 使用请求从客户端(例如浏览器)向 API 发送数据

FastAPI官网手册&#xff1a;https://fastapi.tiangolo.com/zh/tutorial/query-params/ 上节内容&#xff1a;三周精通FastAPI&#xff1a;3 查询参数 请求 FastAPI 使用请求从客户端&#xff08;例如浏览器&#xff09;向 API 发送数据。 请求是客户端发送给 API 的数据。响…

国家信息安全水平考试(NISP一级)最新题库-第十六章

目录 另外免费为大家准备了刷题小程序和docx文档&#xff0c;有需要的可以私信获取 1 防火墙是一种较早使用、实用性很强的网络安全防御技术&#xff0c;以下关于防火墙说法错误的是&#xff08;&#xff09; A.防火墙阻挡对网络的非法访问和不安全数据的传递&#xff1b;B.防…

Leecode刷题之路第27天之移除元素

题目出处 27-移除元素-题目描述 题目描述 给你一个数组 nums 和一个值 val&#xff0c;你需要 原地 移除所有数值等于 val 的元素。元素的顺序可能发生改变。然后返回 nums 中与 val 不同的元素的数量。假设 nums 中不等于 val 的元素数量为 k&#xff0c;要通过此题&#x…

C++ | Leetcode C++题解之第491题非递减子序列

题目&#xff1a; 题解&#xff1a; class Solution { public:vector<int> temp; vector<vector<int>> ans;void dfs(int cur, int last, vector<int>& nums) {if (cur nums.size()) {if (temp.size() > 2) {ans.push_back(temp);}return;}if…

【题解】—— LeetCode一周小结42

&#x1f31f;欢迎来到 我的博客 —— 探索技术的无限可能&#xff01; &#x1f31f;博客的简介&#xff08;文章目录&#xff09; 【题解】—— 每日一道题目栏 上接&#xff1a;【题解】—— LeetCode一周小结41 14.鸡蛋掉落 题目链接&#xff1a;887. 鸡蛋掉落 给你 k 枚…

c++迷宫游戏

1、问题描述 程序开始运行时显示一个迷宫地图&#xff0c;迷宫中央有一只老鼠&#xff0c;迷宫的右下方有一个粮仓。游戏的任务是使用键盘上的方向健操纵老鼠在规定的时间内走到粮仓处。 基本要求: 老鼠形象可以辨认,可用键盘操纵老鼠上下左右移动&#xff1b;迷宫的墙足够结…

博弈论学习笔记【施工中】

SG函数 首先定义就不用我讲了吧&#xff0c;还不会的自己看看 传送门 再进一步理解一下吧&#xff1a; 黑色数字是节点编号&#xff0c;红色是 S G SG SG 函数值 看下它的过程&#xff1a; 首先 5 5 5 和 6 6 6 没有后继节点&#xff0c;为必败态&#xff0c;先赋值为 …