自然语言处理---Self Attention自注意力机制

news2024/10/7 8:22:24

Self-attention介绍

Self-attention是一种特殊的attention,是应用在transformer中最重要的结构之一。attention机制,它能够帮助找到子序列和全局的attention的关系,也就是找到权重值wi。Self-attention相对于attention的变化,其实就是寻找权重值的wi过程不同。

  • 为了能够产生输出的向量yi,self-attention其实是对所有的输入做了一个加权平均的操作,这个公式和上面的attention是一致的。
  • j代表整个序列的长度,并且j个权重的相加之和等于1。值得一提的是,这里的 wij并不是一个需要神经网络学习的参数,它是来源于xi和xj的之间的计算的结果(这里wij的计算发生了变化)。它们之间最简单的一种计算方式,就是使用点积的方式。
  • xi和xj是一对输入和输出。对于下一个输出的向量yi+1,有一个全新的输入序列和一个不同的权重值。

  • 这个点积的输出的取值范围在负无穷和正无穷之间,所以要使用一个softmax把它映射到[0,1] 之间,并且要确保它们对于整个序列而言的和为1。
  • 以上这些就是self-attention最基本的操作。

Self-attention和Attention使用方法

根据他们之间的重要区别,可以区分在不同任务中的使用方法:

  • 在神经网络中,通常来说会有输入层(input),应用激活函数后的输出层(output),在RNN当中会有状态(state)。如果attention (AT) 被应用在某一层的话,它更多的是被应用在输出或者是状态层上,而当使用self-attention(SA),这种注意力的机制更多的实在关注input上。
  • Attention (AT) 经常被应用在从编码器(encoder)转换到解码器(decoder)。比如说,解码器的神经元会接受一些AT从编码层生成的输入信息。在这种情况下,AT连接的是**两个不同的组件**(component),编码器和解码器。但是如果用**SA**,它就不是关注的两个组件,它只是在关注应用的**那一个组件**。那这里就不会去关注解码器了,就比如说在Bert中,使用的情况,就没有解码器。
  • SA可以在一个模型当中被多次的、独立的使用(比如说在Transformer中,使用了18次;在Bert当中使用12次)。但是,AT在一个模型当中经常只是被使用一次,并且起到连接两个组件的作用。
  • SA比较擅长在一个序列当中,寻找不同部分之间的关系。比如说,在词法分析的过程中,能够帮助去理解不同词之间的关系。AT却更擅长寻找两个序列之间的关系,比如说在翻译任务当中,原始的文本和翻译后的文本。这里也要注意,在翻译任务重,SA也很擅长,比如说Transformer。
  • AT可以连接两种不同的模态,比如说图片和文字。SA更多的是被应用在同一种模态上,但是如果一定要使用SA来做的话,也可以将不同的模态组合成一个序列,再使用SA。
  • 其实有时候大部分情况,SA这种结构更加的general,在很多任务作为降维、特征表示、特征交叉等功能尝试着应用,很多时候效果都不错。

Self-attetion实现步骤

  • 这里实现的注意力机制是现在比较流行的点积相乘的注意力机制
  • self-attention机制的实现步骤
    • 第一步: 准备输入
    • 第二步: 初始化参数
    • 第三步: 获取key,query和value
    • 第四步: 给input1计算attention score
    • 第五步: 计算softmax
    • 第六步: 给value乘上score
    • 第七步: 给value加权求和获取output1
    • 第八步: 重复步骤4-7,获取output2,output3

1. 准备输入

# 这里随机设置三个输入, 每个输入的维度是一个4维向量
import torch
x = [
  [1, 0, 1, 0], # Input 1
  [0, 2, 0, 2], # Input 2
  [1, 1, 1, 1]  # Input 3
]
x = torch.tensor(x, dtype=torch.float32)

2. 初始化参数

# 每一个输入都有三个表示,分别为key(橙黄色),query(红色),value(紫色)。
# 每一个表示,希望是一个3维的向量。由于输入是4维,所以参数矩阵为 4*3 维。

# 为了能够获取这些表示,每一个输入(绿色)要和key,query和value相乘

# 在例子中,使用如下的方式初始化这些参数。
w_key = [
  [0, 0, 1],
  [1, 1, 0],
  [0, 1, 0],
  [1, 1, 0]
]
w_query = [
  [1, 0, 1],
  [1, 0, 0],
  [0, 0, 1],
  [0, 1, 1]
]
w_value = [
  [0, 2, 0],
  [0, 3, 0],
  [1, 0, 3],
  [1, 1, 0]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)

print("w_key: \n", w_key)
print("w_query: \n", w_query)
print("w_value: \n", w_value)

3. 获取key,query和value

# 使用向量化获取keys的值
                    [0, 0, 1]
[1, 0, 1, 0]    [1, 1, 0]    [0, 1, 1]
[0, 2, 0, 2] x [0, 1, 0] = [4, 4, 0]
[1, 1, 1, 1]    [1, 1, 0]    [2, 3, 1]

# 使用向量化获取values的值
                    [0, 2, 0]
[1, 0, 1, 0]    [0, 3, 0]    [1, 2, 3] 
[0, 2, 0, 2] x [1, 0, 3] = [2, 8, 0]
[1, 1, 1, 1]    [1, 1, 0]    [2, 6, 3]

# 使用向量化获取querys的值
                    [1, 0, 1]
[1, 0, 1, 0]    [1, 0, 0]    [1, 0, 2]
[0, 2, 0, 2] x [0, 0, 1] = [2, 2, 2]
[1, 1, 1, 1]    [0, 1, 1]    [2, 1, 3]

# 将query key  value分别进行计算
keys = x @ w_key
querys = x @ w_query
values = x @ w_value
print("Keys: \n", keys)
print("Querys: \n", querys)
print("Values: \n", values)

4. 给input1计算attention score

# 获取input1的attention score,使用点乘来处理所有的key和query,包括自己的key和value。
# 这样就能够得到3个key的表示(因为有3个输入),就获得了3个attention score(蓝色)
                [0, 4, 2]
[1, 0, 2] x [1, 4, 3] = [2, 4, 4]
                [1, 0, 1]

# 注意: 这里只用input1举例。其他的输入的query和input1做相同的操作.

attn_scores = querys @ keys.T
print(attn_scores)

5. 计算softmax

from torch.nn.functional import softmax

attn_scores_softmax = softmax(attn_scores, dim=-1)
print(attn_scores_softmax)
attn_scores_softmax = [
  [0.0, 0.5, 0.5],
  [0.0, 1.0, 0.0],
  [0.0, 0.9, 0.1]
]
attn_scores_softmax = torch.tensor(attn_scores_softmax)
print(attn_scores_softmax)

softmax([2, 4, 4]) = [0.0, 0.5, 0.5]

6. 给value乘上score

使用经过softmax后的attention score乘以它对应的value值(紫色),这样就得到了3个weighted values(黄色)

1: 0.0 * [1, 2, 3] = [0.0, 0.0, 0.0]
2: 0.5 * [2, 8, 0] = [1.0, 4.0, 0.0]
3: 0.5 * [2, 6, 3] = [1.0, 3.0, 1.5]

weighted_values = values[:,None] * attn_scores_softmax.T[:,:,None]
print(weighted_values)

7. 给value加权求和获取output1

把所有的weighted values(黄色)进行element-wise的相加。

   [0.0, 0.0, 0.0]

+ [1.0, 4.0, 0.0]

+ [1.0, 3.0, 1.5]

------------------------

= [2.0, 7.0, 1.5]

得到结果向量[2.0, 7.0, 1.5](深绿色)就是ouput1的和其他key交互的query representation

8. 重复步骤4-7,获取output2,output3

outputs = weighted_values.sum(dim=0)
print(outputs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1121404.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

项目总结-商品购买流程

(1)添加购物车 Controller: CartService: 实现类: CartDetail detaildao.queryByCdid(cid,gds.getId()); CartDao: //获取详情对象Select("select * from t_cartdetail where cid#{cid} and gid#{gid…

buu第五页 wp

[RootersCTF2019]babyWeb 预期解 一眼就是sql注入,发现过滤了 UNION SLEEP " OR - BENCHMARK盲注没法用了,因为union被过滤,堆叠注入也不考虑,发现报错有回显,尝试报错注入。 尝试: 1||(updatex…

ubuntu20.04下安装nc

前言 nc在网络渗透测试中非常好用,这里的主要记一下Ubuntu20.04中nc的安装 编译安装 第一种方式是自己编译安装,先下载安装包 nc.zip wget http://sourceforge.net/projects/netcat/files/netcat/0.7.1/netcat-0.7.1.tar.gz/download -O netcat-0.7.…

anyproxy 的安装和抓包使用

简介 AnyProxy是阿里开发的开源的代理服务器,主要特性包括: 基于Node.js,开放二次开发能力,允许自定义请求处理逻辑支持Https的解析提供GUI界面,用以观察请求 安装运行Anyproxy 首先需要电脑由安装 node&#xff0…

H3C SecParh堡垒机 data_provider.php 远程命令执行漏洞

构造poc执行远程命令: /audit/data_provider.php?ds_y2019&ds_m04&ds_d02&ds_hour09&ds_min40&server_cond&service$(id)&identity_cond&query_typeall&formatjson&browsetrue漏洞证明: 文笔生疏&#xff0c…

【大模型应用开发教程】02_LangChain介绍

LangChain介绍 什么是 LangChain1. 模型输入/输出2. 数据连接3. 链(Chain)4. 记忆(Meomory)5. 代理(Agents)6.回调(Callback)在哪里传入回调 ?你想在什么时候使用这些东西呢&#x…

1024常玩到的漏洞(第十六课)

1024常玩到的两个漏洞(第十六课) 漏洞扫描工具 1024渗透OpenVas扫描工具使用(第十四课)-CSDN博客 流程 一 ms12-020漏洞分析 MS12-020漏洞是一种远程桌面协议(RDP)漏洞。在攻击者利用该漏洞之前,它需要将攻击者的计算机连接到受害者的计算机上。攻击者可以通过向受害者计算…

跟着NatureMetabolism学作图:R语言ggplot2转录组差异表达火山图

论文 Independent phenotypic plasticity axes define distinct obesity sub-types https://www.nature.com/articles/s42255-022-00629-2#Sec15 s42255-022-00629-2.pdf 论文中没有公开代码,但是所有作图数据都公开了,我们可以试着用论文中提供的数据…

Linux进程与线程的内核实现

进程描述符task_struct 进程描述符(struct task_struct)pid与tgid进程id编号分配规则内存管理mm_struct进程与文件,文件系统 进程,线程创建的本质 clone函数原型线程创建的实现进程创建的实现 总结 进程描述符task_struct 进程描述符(st…

自动驾驶的商业应用和市场前景

自动驾驶技术已经成为了交通运输领域的一项重要创新。它不仅在改善交通安全性和效率方面具有巨大潜力,还为各种商业应用提供了新的机会。本文将探讨自动驾驶在交通运输中的潜力,自动驾驶汽车的制造商和技术公司,以及自动驾驶的商业模式和市场…

Git GUI工具:SourceTree代码管理

Git GUI工具:SourceTree SourceTreeSourceTree的安装SourceTree的使用 总结 SourceTree 当我们对Git的提交、分支已经非常熟悉,可以熟练使用命令操作Git后,再使用GUI工具,就可以更高效。 Git有很多图形界面工具,这里…

JAVA高级教程Java 泛型(10)

目录 三、泛型的使用泛型类泛型接口泛型方法 四、泛型在集合中的使用1、使用泛型来创建set2、使用泛型来创建HashSet3、使用equals,hashCode的使用Person 类HashSet的使用泛型TreeSet的使用泛型comparator实现定制比较(比较器) 三、泛型的使用 泛型类语法 类名T表示…

散列表:如何打造一个工业级水平的散列表?

文章来源于极客时间前google工程师−王争专栏。 散列表的查询效率并不能笼统地说成是O(1)。它跟散列函数、装载因子、散列冲突等都有关系。如果散列函数设计得不好,或者装载因子过高,都可能导致散列冲突发生的概率升高,查询效率下降。 极端情…

在ESP32上使用Arduino(Arduino as an ESP-IDF component)

目录 前言 原理说明 操作步骤 下载esp-arduino 安装esp-arduino 工程里配置arduino 1、勾选该选项,工程将作为一个标准的arduino程序工作 2、不勾选该选型,工程将作为一个传统的嵌入式项目开发, 前言 Arduino拥有丰富的各类库&#…

一款WPF开发的网易云音乐客户端 - DMSkin-CloudMusic

前言 今天推荐一款基于DMSkin框架开发的网易云音乐播放器:DMSkin-CloudMusic。 DMSkin 框架介绍 DMSkin是一个开源的WPF样式UI框架,可以帮助开发者快速创建漂亮的用户界面。 下载体验 下载地址:https://github.com/944095635/DMSkin-Clou…

如何使用vim粘贴鼠标复制的内容

文章目录 一、使用步骤1.找到要编辑的配置文件2.找到目标文件3.再回到vim编辑器 一、使用步骤 1.找到要编辑的配置文件 用sudo vim /etc/apt/sources.list编辑软件源配置文件 sudo vim /etc/apt/sources.listvim 在默认的情况下当鼠标选中的时候进入的 Visual 模式&#xff…

加法器:如何像搭乐高一样搭电路(上)?

目录 背景 异或门和半加器 全加器 小结 补充阅读 背景 上一讲,我们看到了如何通过电路,在计算机硬件层面设计最基本的单元,门电路。我给你看的门电路非常简单,只能做简单的 “与(AND)”“或&#xff…

UG\NX二次开发 获取用户默认设置中的绘图信息 UF_PLOT_ask_session_job_options

文章作者:里海 来源网站:《里海NX二次开发3000例专栏》 感谢粉丝订阅 感谢 m0_58724732 订阅本专栏,非常感谢。 简介 UG\NX二次开发 获取用户默认设置中的绘图信息 UF_PLOT_ask_session_job_options 效果 代码 #include "me.hp

window11安装Python环境

python环境安装 访问Python官网:https://www.python.org/ 点击downloads按钮,在下拉框中选择系统类型(windows/Mac OS/Linux等) 选择下载最新版本的Python cmd命令如果出现版本号以及>>>则表示安装成功 如果出现命令行中输入python出现如下错误 可能…

【SSA-RFR预测】基于麻雀算法优化随机森林回归预测研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…