自注意力简介

news2024/9/21 18:37:11

在注意力机制中,每个查询都会关注所有的键值对并生成一个注意力输出。如果查询q,键k和值v都来自于同一组输入,那么这个注意力就被称为是自注意力(self-attention)。自注意力这部分理论,我觉得台大李宏毅老师的课程讲得最好。

自注意力就是输入一堆向量,假设称为a1,a2,a3,a4,那么这四个向量都会参与自注意力机制的运算,得到的结果仍然是四个输出,这四个输出再去做全连接运算。而每一个自注意力机制的输出都用到了a1~a4四个向量来进行运算,也就是说每个输出都是观察了所有的输入之后才得到的。

首先,输入a1需要和a2,a3,a4分别计算相关性,这个相关性可以由缩放点积方式来计算,也就称作缩放点积注意力,也可以由两个两个输入向量相加后再做非线性处理得到,称为加性注意力。

缩放点积的计算方法如下:

输入一个向量v1和一个向量v2,v1去乘上一个可训练矩阵Wq得到q,v2去乘上一个可训练矩阵Wk得到k,再把这个q和k做一个点积运算,得到的就是α,类似于相似度。

回到前面的例子中,a1这里既作为q,又作为k,又作为v。其中,Wq*a1就是q1,Wk*a1就是k1,q1和k1的点积就是α11,相当于a1自己和自己的相似度,同样的,a1和a2,a3,a4分别计算得到α12,α13,α14,然后将α11,α12,α13,α14经过softmax得到最终的四个输出,如下图所示:

然后再用一个可训练矩阵Wv去乘以a1得到v1,用计算得到的相似度α'11去乘以v1,得到一个值temp11;同样的,用可训练矩阵Wv去乘以a2得到v2,用计算得到的相似度α‘12去乘以v2,得到temp12;类似的得到temp13,temp14,然后把temp11+temp12+temp13+temp14得到b1,这个b1就是自注意力机制的第一个输出。

我们刚刚是以a1的视角做的运算,得到b1,同样可以以a2,a3,a4的视角做运算,得到b2,b3,b4。这次就得到了自注意力机制的输出。光看最后这个结构图,有点类似全连接,只是里面的运算过程比全连接要复杂。

下面,我们来看一下如何用代码实现自注意力的计算。

import torch  
import torch.nn as nn  
import torch.nn.functional as F  

class SelfAttention(nn.Module):
    # embed_size代表输入的向量维度,heads代表多头注意力机制中的头数量
    def __init__(self, embed_size, heads): 
        super(SelfAttention,self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads # 每个头的维度
        # 用assert断言机制判断
        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads" 
        # 没有偏置项,其实这个线性层本质上就是为了计算值Wv*a = V
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        # 最后的全连接操作,输出仍是输入的向量维度,也就是说大小是不变的
        self.fc_out = nn.Linear(heads*self.head_dim, self.embed_size)
        
    def forward(self, values, keys, query, mask):
        # 这个mask也很关键,它用于控制模型在处理序列数据时应该关注哪些部分,以及忽略哪些部分
        N = query.shape[0] # 获取输入的批量个数
        print("N:",N)
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1] # 获取输入序列的长度
        # Split the embedding into self.heads different pieces  
        # 把k,q,v都切分为多个组
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)
        # 计算k,q,v
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)
        
        energy = torch.einsum("nqhd,nkhd->nhqk",[queries, keys]) # 格式转化
        print("queries.shape:", queries.shape)
        print("keys.shape:", keys.shape)
        print("energy.shape:", energy.shape)
        
        if mask is not None:
            energy = energy.masked_fill(mask==0, float("-1e20"))
            
        attention = torch.softmax(energy/(self.embed_size**(1/2)), dim=3) # softmax内部是缩放点积
        print("attention.shape:", attention.shape)
        print("values.shape:", values.shape)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)
        out = self.fc_out(out)
        return out
embed_size = 512
heads = 8
attention = SelfAttention(embed_size, heads)

# batch size 1, seq length 60
values = torch.rand(1,60,embed_size)
keys = torch.rand(1,60,embed_size)
queries = torch.rand(1,60,embed_size)
mask = None # 假设没有mask

out = attention(values, keys, queries, mask)
print(out.shape)

# 输出
N: 1
queries.shape: torch.Size([1, 60, 8, 64])
keys.shape: torch.Size([1, 60, 8, 64])
energy.shape: torch.Size([1, 8, 60, 60])
attention.shape: torch.Size([1, 8, 60, 60])
values.shape: torch.Size([1, 60, 8, 64])
torch.Size([1, 60, 512])

通过这个程序,我们可以看到,自注意力机制是不改变输入和输出的形状的,输入的Q,K,V格式是[1,60,512],输出的结果的仍然是[1,60,512]。

下面是几点说明:

1. 这里的embed_size代表的是输入到自注意力层中的每个元素的向量维度。在Transformer模型中,输入数据首先会被转换成一个固定长度的向量,这个向量的长度就称为embed_size。

2. mask表示的是模型在处理序列数据时,应该忽略掉哪部分,我这里设置为None,也就是全部参与计算。

3. einsum,称为爱因斯坦求和,起源是爱因斯坦在研究广义相对论时,需要处理大量求和运算,为了简化这种繁复的运算,提出了求和约定,推动了张量分析的发展。einsum 可以计算向量、矩阵、张量运算,如果利用得当,sinsum可完全代替其他的矩阵计算方法。

例如,C = einsum('ij,jk->ik', A, B),就相当于两个矩阵求内积:cik = Σj AijBjk。

通过输出可以看到,在计算前queries的形状是[1,60,8,64],keys的形状是[1,60,8,64],在表达式"nqhd,nkhd->nhqk"中,n=1,q=60,h=8,d=64,k=60,两个矩阵进行内积,因此得到的结果是nhqk,也就是[1,8,60,60]。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1916218.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一位互联网公司项目经理繁忙的一天

早晨:准备与计划 7:00 AM - 起床与准备 项目经理起床后,快速洗漱并享用早餐。之后花几分钟查看手机上的邮件和消息,确保没有紧急事务需要立即处理。 7:30 AM - 通勤时间 前往公司。在通勤途中,通过手机或平板电脑查看当天的会议…

硅谷甄选运营平台-vue3组件通信方式

vue3组件通信方式 vue2组件通信方式: props:可以实现父子组件、子父组件、甚至兄弟组件通信自定义事件:可以实现子父组件通信全局事件总线$bus:可以实现任意组件通信pubsub:发布订阅模式实现任意组件通信vuex:集中式状态管理容器,实现任意组件通信ref:父…

简单了解下安全测试!

一、基本概念 安全测试是在软件产品开发基本完成时,验证产品是否符合安全需求定义和产品质量标准的过程。它主要检查系统对非法侵入渗透的防范能力,旨在通过全面的脆弱性安全测试,发现系统未知的安全隐患并提出相关建议,确保系统…

星申刹车盘平衡机:精准与高效兼备

星申动双工位刹车盘动平衡机是一款具备测量和切削两个工位的高精度设备。该机器由平衡测量单元、内部搬运系统、切削校正模块及电气控制部分组成,专为满足自动化生产需求设计。其主要功能特点包括: 1. 实现全自动刹车盘平衡测试,精度高达30gm…

vue3 ts 报错:无法找到模块“../views/index/Home.vue”的声明文件

解决办法: env.d.ts 新增代码片段: declare module "*.vue" {import type { DefineComponent } from "vue";// eslint-disable-next-line typescript-eslint/no-explicit-any, typescript-eslint/ban-typesconst component: Define…

ExtruOnt——为工业 4.0 系统描述制造机械类型的本体

概述 论文地址 :https://arxiv.org/abs/2401.11848 原文地址:https://ai-scholar.tech/articles/ontology/ExtruOnt 在工业 4.0 应用场景中,以机器可解释代码提供的、语义丰富的制造机械描述可以得到有效利用。然而,目前显然还缺…

使用nvm安装node包后,安装vue提示“vue不是内部或外部命令,也不是可运行的程序或批处理命令”

前言 使用 npm 安装了 vue-cli 后,输入 "vue -V" 查询vue版本命令提示: “vue不是内部或外部命令,也不是可运行的程序或批处理命令”。 解决方法 第一步:首先,查看 C 盘下有没有 npm 文件夹。 目录类似于&#xff1…

六、数据可视化—Echars(爬虫及数据可视化)

六、数据可视化—Echars(爬虫及数据可视化) Echarts应用 Echarts Echarts官网,很多图表等都是我们可以 https://echarts.apache.org/zh/index.html 是百度自己做的图表,后来用的人越来越多,捐给了orange组织&#xf…

网络渗透CTF实践:获取靶机Web Developer 文件/root/flag.txt中flag

实验目的:通过对目标靶机的渗透过程,了解CTF竞赛模式,理解CTF涵盖的知识范围,如MISC、PPC、WEB等,通过实践,加强团队协作能力,掌握初步CTF实战能力及信息收集能力。熟悉网络扫描、探测HTTP web服…

交易员需要克服的十大心理问题

撰文:Koroush AK 编译:Chris,Techub News 本文来源香港Web3媒体:Techub News 一个交易者在交易上所犯下的最大的错误可能更多来自于心态的失衡而并非技术上的失误,类似的情况已经发生在了无数交易者身上。作为交易者…

如何压缩pdf文件大小,怎么压缩pdf文件大小

在数字化时代,pdf文件因其稳定的格式和跨平台兼容性,成为了工作与学习中不可或缺的一部分。然而,随着pdf文件内容的丰富,pdf文件的体积也随之增大,给传输和存储带来了不少挑战。本文将深入探讨如何高效压缩pdf文件大小…

C++入门 模仿mysql控制台输出表格

一、 说明 控制台输出表格&#xff0c;自适应宽度 二、 源码 #include <iostream> #include <map> #include <string> #include <vector>using namespace std;void printTable(vector<vector<string>> *pTableData) {int row pTableDa…

前端八股文 箭头函数和普通函数的区别

箭头函数是匿名函数&#xff0c;不能作为构造函数&#xff0c;不能使用new箭头函数不绑定 arguments &#xff0c;取而代之用 rest 参数...解决箭头函数不绑定 this &#xff0c;会捕获其所在的上下文的this值&#xff0c;作为自己的this值箭头函数通过 call() 或 apply() 方法…

【学术会议征稿】第四届新材料与化学工程国际学术会议(AMCE 2024)

第四届新材料与化学工程国际学术会议&#xff08;AMCE 2024&#xff09; 2024 4th International Conference on Advanced Materials and Chemical Engineering 为了促进我国新材料与化学工程领域绿色、规范、持续、健康发展,提升科技创新能力&#xff0c;推进学科交叉融合和…

苹果笔记本电脑能玩哪些游戏 苹果电脑可以玩的单机游戏推荐

苹果笔记本有着优美的外观和强大的性能。用户不仅可以使用苹果笔记本办公、剪辑&#xff0c;越来越多的用户开始关注苹果笔记本在游戏领域的表现&#xff0c;尤其是在大型游戏方面。本文将为你详细介绍苹果笔记本都能玩什么游戏&#xff0c;以及为你推荐苹果电脑可以玩的单机游…

浙江宁波G761-3005B穆格伺服阀 有货

G761-3005B穆格伺服阀是一种用于流体控制的阀门。 宁波秉圣与各国的多家进口产品维护服务提供商建立了紧密的合作关系。我们售出的产品提供1年的质保&#xff0c;并且都经过了严格的测试和认证。公司的优势品牌如下&#xff1a;德国MOOG、美国 PARKER&#xff08;派克&#xf…

Sqli-labs合集之环境搭建

Sqli-labs的搭建 搭建第一个SQL注入学习靶场环境&#xff1a; 软件&#xff1a;sqli-labs 安装过程&#xff1a; 1.源码地址&#xff1a;GitHub - Audi-1/sqli-labs: SQLI labs to test error based, Blind boolean based, Time based.&#xff1b; 2.将压缩包解压到phpst…

特征值究竟体现了矩阵的什么特征?

特征值究竟体现了矩阵的什么特征&#xff1f; 简单来说就是x经过矩阵A映射后和自己平行 希尔伯特第一次提出eigenvalue,这里的eigen就是自己的。所以eigenvalue也称作本征值 特征值和特征向量刻画了矩阵变换空间的特征 对平面上的任意向量可以如法炮制&#xff0c;把他在特征…

【UE5】调用ASR接口,录制系统输出。录制音频采样率不匹配

暂时测出window能用。阿里的ASR接口当前仅支持8000和16000。UE默认采样44100。

CSS技巧专栏:一日一例 1.纯CSS实现 会讨好的热情按钮 特效

题外话: 从今天开始,我准备开设一个新的专栏,专门写 使用CSS实现各种酷炫按钮的方法,本专栏目前准备写40篇左右,大概会完成如下按钮效果: 今天,我来介绍第一个按钮的实现方法:会讨好的热情按钮。为什么我给它起这样的名字呢?你看它像不像一个不停摇尾巴的小黄?当你鼠…