self-attention(transformer)

news2024/10/7 16:18:33

自注意力机制

在传统的CNN中,都是对感受野内部的事情进行关联后理解。

感受野实际上关乎了模型对全局信息的理解。

而本质上,感受野是一种特殊的注意力机制,也就是说感受野是一种受限的、具有特定参数的注意力。

之前的内容如DANet,则更加接近广义的注意力机制。

在这种机制的作用下,像素与像素间的内容会产生相互联系,通过注意力权重矩阵,对图像中全局的信息进行提取和理解。

《attention is all your need》这篇文章中,提出了一种全新的注意力机制,其意义在于:

1. 从原理上解释了CNN和RNN的一般形式
2. 为NLP和CV的大一统模型建立的基础
3. 从后续的swin transformer的刷榜提供了理论和技术支撑

如果总结近10年的人工智能发展,除了alexnet之外,就当属transformer这篇文章。

序列到序列的模型

目前所接触的图像任务,都是单一输入到单一输出的情况。

如果在复杂任务中,则会出现输入一个序列,输出一个序列的场景。


这就是seq2seq模型。

seq2seq模型中,主要是用于NLP领域,例如分词、词性判断、语义问答等等。

显然,一个单词就是一个输入,一句话就组成了一个序列。

在图像中,如何使用序列输入?

2020年经典论文Vision Transformer中,就将这个问题建模为了一种图像patch输入的序列化处理。

其中,一张图像被不重叠地切割成了若干个小块。所有小块按顺序连接起来,组成输入序列。

这也是图像与自然语言处理大一统的关键步骤。

seq2seq模型中天然要解决的问题

输入序列中无法有效的沟通和交流,从而无法建立长效的序列间的通信机制,从而无法实现可靠的处理。
常规的对于序列输入的处理方法(如视频、多帧图像等等),一些常用的思路有什么?

1. 依次输入,依次输出

2. 合并成矩阵进行输入

3. RNN

显然,第一种方法缺少可靠性,因为无法沟通。

第二种方法不够灵活,因为需要固定长度输入。

第三种方法在此之前用的比较多,因为他能够充分沟通前后序列信息


然而,即便采用RNN的方法也存在局限,因为他只能关注前后有限个序列的关联。(长序列产生较高复杂度)

**序列到序列模型的最终目的**

seq2seq的最终目的在于,所输出的内容,都是充分考虑了其他序列中的内容所得到的。

一种通用的思路

通用的思路就如同全连接


显然,全连接并不可靠,计算复杂度太高容易过拟合都是他的缺点。 

这里,就有attention出场了。

self-attention的基本过程


序列输入后,每个序列进入自注意力模块,计算出对应的输出

这些输出综合了所有序列的信息,具有更加广泛的全局性特征。

这些特征可以进而作为下一个attention模块的输入,逐级泛化,从而得到可靠的输出。

上述过程就是transformer的基本过程。transformer采用了6层encoder和decoder

与s2s不同的是:s2s的encoder和decoder是RNN网络,而transformer完全抛弃了RNN

对于self-attention模块,可以简单的视为一个这样的模块:


在这样的模块中,b1的结果,是综合了a1到a4的多个输入的结果
那么如何实现这样的过程? 

**自注意模块的两个过程:** 

1. 计算序列间的相关性
2. 为每个输出加入不同的相关性权重


计算序列间的相关性

先简化到最基本的情况,对于两个样本来说,如何计算他们之间的相关性?

有这么几个步骤:

1. 先变换到一个便于度量的特征空间

2. 用一种方法将二者度量


这里采用第一种方式: 

1. 特征a和b分别经过两个不同的矩阵变换,得到embedding表达
2. embedding相乘之后,得到两者之间的相关性。
在self attention中,将所有的输入元素融合到一起,就可以得到两两间的相关性系数,即为权重系数。

注意,为了避免尺度不一致,这里需要对一个元素对其他序列元素的相关性的指标进行归一化(softmax)
具体上,如下图所示

加入权重


通过矩阵V来得到对应权重的输入,从而获得更好的表达。 

对于矩阵V来说,是综合了不同的权重的内容

总体框架


一些细节

1. 多头注意力
2. 位置信息

多头注意力机制

多头,指的是产生多组qkv矩阵,从而获得更加个性化的权重矩阵。

位置矩阵

将图像切成若干个patch,并将其随机打乱,得到的结果对于transformer来说没有任何区别

这显然是不正常的。所以,我们需要为每个序列中的输入加入位置信息,来补充信息。
所谓位置信息,就是生成一种和位置相关的特征编码,从而来获得更好的结果。


位置编码position codeing
transformer在较大数据集上性能比CNN好
qkv分别可以这样理解:q当前想知道的,k键值,v-value应该输出得到的值

自注意力与卷积


Transformer 

transformer是一种深度学习框架,在图像领域可以用于图像分类(ViT)以及经典图像处理任务的处理(SwinTransformer)
从功能上,transfomer分为两个部分:

1. encoder
2. decoder



input embedding:可以看做一个查找表,用来获取每个单词的学习向量表示
positional encoding:将位置信息嵌入向量中

编码器层:将所有输入序列映射到一个抽象的连续表示,包含了整个序列的学习信息。有两个子模块:多头注意力和一个全连接网络,两个子模块周围由残差连接,后面一次标准化层

多头注意力应用了一种特定的注意力机制:自注意力self attention

为实现自注意力,将输入分别送入三个全连接层,来创建查询向量q,键向量k,值向量v

自注意力过程:

1. 查询q映射到键k:q和v进行点乘得到一个分数矩阵,分数矩阵确定了一个单词应该怎样关注其他单词,分数越高,关注度越高。

将查询和键的维度开方将分数缩放,为了让梯度更稳定(因为乘法有可能产生梯度爆炸)

2. 然后对缩放后的得分进行softmax(归一化)《其中较高的得分会被加强,较低的得分会被抑制》,得到0-1之间的概率值。

3. 将注意力权重与值v相乘得到输出,输出向量送入线性层处理

4. 将多头注意力输出向量加到原始输入上,即残差连接。连接后输出经过层归一化后送入MLP

多头:即n组q,k,v经过同样的自注意力过程,每个自注意力过程称为一个头,输出经过线性层之前,拼接为一个向量。

理论上,没个头都会学到不同的东西。

编码层所有的操作是为了:将输入编码变为带有注意力信息的连续表示。从而帮助解码器在解码时关注输入的适当词汇信息

解码器层:输入通过嵌入层和位置编码层得到位置嵌入,然后降维位置嵌入送到第一个多头注意力层,计算解码器输入的注意力得分。与上面的多头注意力不同的是:

带有masked:为了防止解码器关注该单词之后出现的单词(即未来的标记)

masked是一个与注意力得分矩阵大小相同的矩阵,有0和负无穷(经过softmax计算得到0,将当前词对后出现的词的注意力得分变为0)填充。

第一个多头注意力层输出(是值)带有掩码的向量,包含模型如何关注解码器的输入信息

第二个多头注意力层,编码器输出q,k,v ,将编码器输入和解码器输入进行匹配,让解码器决定哪个解码器输入是相关焦点。

关于VIT 

vision transformer是用于视觉任务的transformer
代码

import torch
from torch import nn, einsum
import torch.nn.functional as F
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim), 
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)
   
class Attention(nn.Module):              
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)           # (b, n(65), dim*3) ---> 3 * (b, n, dim)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)          # q, k, v   (b, h, n, dim_head(64))

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

除以dk(特征维度):防止梯度爆炸 ------  缩放点积注意力 

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))
    
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

TYD:transformer是一个特征提取器,数据经过transformer,将不同的特征突出或下降,把“平平无奇”的数据变为“波澜起伏”的特征,送给计算机,计算机更喜欢这种“对比鲜明”的特征。

总之,多头注意力是一个模块,用于计算注意力权重并生成一个带有编码信息的输出向量,指示序列中(每个词如何关注其他所以的词)每个小图片如何关注其他小patch

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/669382.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于webpack开发vue-cli

一、vue-cli开发 1. 项目整体目录 2. package.json {"name": "vue-cli","version": "1.0.0","description": "","main": "index.js","scripts": {"start": "npm …

机器学习常识 23: U-Net

摘要: U-Net 集编码-解码于一体, 是一种常见的网络架构. 图 1. U-Net 例. 如图 1 所示, U-Net 就是 U 形状的网络, 前半部分 (左边) 进行编码, 后半部分 (右边) 进行解码. 编码部分, 将一个图像经过特征提取, 变成一个向量. 前面说过: 深度学习本质上只做件事情, 就是特征提取…

【ESXi 7.x/8.x】ESXi 配置备份与还原

目录 1. 使用 ESXi命令行备份数据(1)将已更改的配置与持久存储同步(2)备份 ESXi 主机的配置数据(3)下载配置文件通过浏览器下载配置文件通过wget命令下载 (4)注意事项 2. 还原 ESXi …

基于Java班主任助理系统设计实现(源码+lw+部署文档+讲解等)

博主介绍: ✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战 ✌ 🍅 文末获取源码联系 🍅 👇🏻 精…

一直报错npm ERR! cb() never called!删除缓存仍然不行

看到npm下载包出错, 通常我们会手动删除node-modules这个文件夹来解决. 但是往往现实很骨感, 然后我们会找网上各种方法来解决, 比如这篇文章 但是当所有方法都尝试了一遍, 仍然还是出错, 这到底是什么原因呢? 可以使用npm config ls 查看一下我们电脑上是否会有一份.npmrc…

前端Vue自定义顶部搜索框 热门搜索 历史搜索 用于搜索跳转使用

前端Vue自定义顶部搜索框 热门搜索 历史搜索 用于搜索跳转使用&#xff0c; 下载完整代码请访问uni-app插件市场地址&#xff1a;https://ext.dcloud.net.cn/plugin?id13128 效果图如下&#xff1a; #### 自定义顶部搜索框 用于搜索跳转使用方法 使用方法 <!-- 自定义顶…

【MySQL新手入门系列四】:手把手教你MySQL数据查询由入门到学徒

SQL语言是与数据库交互的机制&#xff0c;是关系型数据库的标准语言。SQL语言可以用于创建、修改和查询关系数据库。SQL的SELECT语句是最重要的命令之一&#xff0c;用于从指定表中查询数据。在此博客中&#xff0c;我们将进一步了解SELECT语句以及WHERE子句以及它们的重要性。…

PCB设计实验|第一周|2月27日

目录 一、实验原理 二、实验环境 三、实验结果 四、实验总结 一、实验原理 Altium Designer 提供了唯一一款统一的应用方案&#xff0c;其综合电子产品一体化开发所需的所有必须技术和功能。Altium Designer 在单一设计环境中集成板级和FPGA系统设计、基于FPGA和分立处理器的…

Axure基础:中继器与热区

一、中继器 1、中继器的主要作用 中继器就是临时的数据库&#xff0c;在我们需要当前原型图存储和变更一些数据的时候会经常用到。 能用到中继器的一般都是高保真原型&#xff0c;如果不需要大量的数据动态展示&#xff0c;那么几乎用不到或者搞动态面板也可以实现。 下面我…

selenium之元素定位

一、selenium安装 pip3 install selenium 二、安装浏览器驱动 1&#xff1a;下载最新的浏览器驱动 chrome浏览器驱动下载地址&#xff1a; https://registry.npmmirror.com/binary.html?pathchromedriver/ 查看自己电脑上安装的chrome浏览器版本号&#xff0c;驱动和浏览…

【工程实践】python实现多进程

1 多线程与多进程 Python中比较常见的并发方式主要有两种&#xff1a;多线程和多进程。 1-1 多线程 多线程即在一个进程中启动多个线程执行任务。一般来说使用多线程可以达到并行的目的&#xff0c;但由于Python中使用了全局解释锁GIL的概念&#xff0c;导致Python中的多线程并…

Doo Prime 德璞资本:怎么买原油期货?原油期货买卖策略分享

随着中国经济市场的逐步开放&#xff0c;也为了快速和国际金融接轨&#xff0c;我国于2018年3月26日正式挂牌上市原油期货交易。并且我国的原油期货交易合约和美原油期货和布伦特原油期货交易是相互影响的&#xff0c;这让中国投资者可以足不出户的进行原油期货投资。那么在国内…

Jconsole 开启远程连接遇到的一些坑

最近在学习 JVM&#xff0c;其中涉及到性能、内存等指标分析需要使用工具分享&#xff0c;Java 提供了几个可视化工具来监控和管理 Java 应用&#xff0c;比如 Jconsole、JVisual、JMC&#xff0c;他们以图形化的界面实时的监控程序各种性能指标以及内存、CPU 的使用情况。 Jco…

Triton教程 --- 模型管理

Triton教程 — 模型管理 Triton系列教程: 快速开始利用Triton部署你自己的模型Triton架构模型仓库存储代理模型设置优化动态批处理速率限制器 Triton 提供的模型管理 API 是 HTTP/REST 和 GRPC 协议的一部分&#xff0c;也是 C API 的一部分。 Triton 以三种模型控制模式之一…

5.实用干货-你可能没留意的几个生信基础

Reads&#xff1a;高通量测序平台产生的序列。 Raw Reads&#xff1a;原始下机数据称为Raw Reads&#xff08;Raw data&#xff09;。 Clean Reads&#xff1a;通过生物信息的方法&#xff0c;去除一些质量差的reads&#xff08;比如测序错误&#xff0c;长度小于20的reads&a…

TypeScript零基础入门之背景介绍和环境安装

一、什么是TypeScript TypeScript是一种由微软开发和维护的开源编程语言。它是JavaScript的超集&#xff0c;意味着任何JavaScript程序都是一种有效的TypeScript程序。TypeScript添加了静态类型、类、接口、枚举和命名空间等概念&#xff0c;同时支持ES6特性。TypeScript被视为…

Flutter 初探原生混合开发

转载请注明出处&#xff1a;https://blog.csdn.net/kong_gu_you_lan/article/details/131320733?spm1001.2014.3001.5501 本文出自 容华谢后的博客 0.写在前面 现如今跨平台技术被越来越多的开发者提起和应用&#xff0c;从最早的Java到后来的RN、Weex&#xff0c;到现在的Co…

每日学术速递6.11

CV - 计算机视觉 | ML - 机器学习 | RL - 强化学习 | NLP 自然语言处理 Subjects: cs.CV 1.Video-ChatGPT: Towards Detailed Video Understanding via Large Vision and Language Model 标题&#xff1a;Video-ChatGPT&#xff1a;通过大型视觉和语言模型实现详细的视频理…

SCI论文插图怎么做?有这一篇文章就够了

SCI插图的整体要求 SCI杂志种类很多&#xff0c;对插图的要求也各有不同&#xff0c;但是以下几条是通用的&#xff1a; 1. 插图尺寸要符合SCI期刊要求 2. 同篇文稿插图中文字须统一字号及字体 3. 须提交SCI期刊指定文件类型的插图 4. 插图文件命名须符合SCI期…

C++基础(15)——STL常用算法(遍历和查找)

前言 本文介绍了C中STL常用遍历和查找算法。 9.1&#xff1a;常用遍历算法&#xff08;for_each、transform&#xff09; 9.1.1&#xff1a;foreach for_each&#xff1a;遍历容器&#xff0c;transform&#xff1a;搬运一个容器中的数据到另一个容器中 for_each中使用普通…