各种注意力评分函数的实现

news2024/9/20 6:30:39

预备知识

本文基于MXNet进行实现,需要对于注意力机制有一定初步了解。也需要对Python有足够了解。

另外这里稍加说明,在注意力机制中,本质上是“注意”的位置,即加权计算后进行Softmax回归的结果。在Nadaraya-Watson核回归中,首先具有一个键值对(key-value),输入称为一个查询(query),对于每个查询,有对应计算,计算查询与键的关系,根据关系的大小,取键所对应的值,通过带权重的值进行预测,这就是Nadaraya-Watson核回归的基本思想。

注意力评分函数

注意力评分函数本质上是对查询和键之间的关系建模,即\hat{y}=\Sigma_i^n \alpha(x,x_i)y_i

在Nadaraya-Watson核回归中,α为查询与键的距离。将注意力评分函数的输出结果输入到softmax函数中进行运算。 通过上述步骤,将得到与键对应的值的概率分布(即注意力权重)。 最后,注意力汇聚的输出就是基于这些注意力权重的值的加权和。

准备工作

选择不同的注意力评分函数α会导致不同的注意力汇聚操作。 本节将介绍两个流行的评分函数,稍后将用他们来实现更复杂的注意力机制。

引入库

import math
from mxnet import np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()

掩蔽Softmax

为了使注意力机制的实现是有效的,可以采用掩蔽Softmax操作,仅对一定的值纳入注意力汇聚中,而无意义的值则排除掉。

def masked_softmax(X, valid_lens):
    if valid_lens is None:
        return npx.softmax(X)
    else:
        shape = X.shape
        if valid_lens.ndim == 1:
            valid_lens = valid_lens.repeat(shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
        X = npx.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, True,
                              value=-1e6, axis=1)
        return npx.softmax(X).reshape(shape)

加性注意力机制

对于给定查询q和键k,分别乘以对应权重,连结后输入一个多层感知机,具有一个隐藏层,禁用bias项,对于这一步产生的结果再进行tanh激活函数的操作,最后通过一个权重矩阵W_v输出结果。(这里还使用了Dropout。)

大致可以理解为,输入含有若干特征x,对其进行运算,获得num_hiddens个隐藏单元,又有若干个键key,对其进行运算,获得num_hiddens个隐藏单元,进行连结,再经过tanh运算,最后乘以权重矩阵,获得输出为一个神经元的结果,这个结果是对键和查询的关系进行加权运算的结果。

class AdditiveAttention(nn.Block):
    def __init__(self, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Dense(num_hiddens, use_bias=False, flatten=False)
        self.W_q = nn.Dense(num_hiddens, use_bias=False, flatten=False)
        self.w_v = nn.Dense(1, use_bias=False, flatten=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        features = np.expand_dims(queries, axis=2) + np.expand_dims(
            keys, axis=1)
        features = np.tanh(features)
        scores = np.squeeze(self.w_v(features), axis=-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return npx.batch_dot(self.dropout(self.attention_weights), values)

缩放点积注意力

缩放点击直接将查询和键进行点积操作,之后进行缩放,得到的值进行Softmax回归。显然,直接进行矩阵乘法操作是更加快速的,因此缩放点积注意力的运算效率远远高于加性注意力机制,不过缩放点积注意力对于输入和键的大小是有要求的,要求输入和键具有相同大小,否则不可乘。

我个人认为加性注意力机制更类似于一种一般的深度学习方法,而缩放点积注意力则是一种特殊方法。

实现过程如下,需要注意的是:
    # queries的形状:(batch_size,查询的个数,d)
    # keys的形状:(batch_size,“键-值”对的个数,d)
    # values的形状:(batch_size,“键-值”对的个数,值的维度)
    # valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)

class DotProductAttention(nn.Block):
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # 设置transpose_b=True为了交换keys的最后两个维度
        scores = npx.batch_dot(queries, keys, transpose_b=True) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return npx.batch_dot(self.dropout(self.attention_weights), values)

说明

疑问

加性注意力中的直接学习的机制我是可以理解的,但在缩放点积注意力的点积部分我感到不解,对于既定的若干个键值对,为什么查询和键直接进行点积操作可以有效获得类似于“权重”的结果呢?

对应解答

查询和键的点积操作有效地衡量了它们之间的相关性或匹配程度。这个操作可以理解为测量查询和每个键的“相似度”或“匹配度”。点积较大的结果意味着查询和对应的键在特征空间中更接近,因此它们之间的匹配程度更高。这个相似度分数在经过缩放和 Softmax 后转化为权重,反映了查询对各个键值对的关注程度。最终,这些权重用于加权值(Value),从而产生最终的注意力输出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2082907.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

问界M7 Pro发布,又做回纯视觉?华为智驾系统这3年到底功夫下在哪?

北京时间8月26日下午14:00,华为又公布了一款新问界车型,时至今日,华为问界家族已有三大款,细分9个系列车型(从动力方面看,各自都分为增程和纯电两种版本)。 1个多小时的发布会上,除…

Zabbix和Prometheus

1.Zabbix 1.1 Zabbix监控获取数据的方式 zabbix-agent 适用于服务器,主机监控 SNMP协议 适用于网络设备(交换机、路由器、防火墙) IPMI协议 适用于监控硬件设备信息(温度、序列号) JMX协议 适用于Java应用监控 1.2 …

基于SSM+微信小程序的跑腿平台管理系统(跑腿3)(源码+sql脚本+视频导入教程+文档)

👉文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 1 、功能描述 基于SSM微信小程序的跑腿平台管理系统实现了管理员、接单员及用户三个角色。 1、管理员实现了首页、个人中心、管理员管理、基础数据管理、接单详情、跑腿任务管理等。 2、接单员实现了…

C++ TinyWebServer项目总结(14. 多线程编程)

早期Linux不支持线程,直到1996年,Xavier Leroy等人开发出第一个基本符合POSIX标准的线程库LinuxThreads,但LinuxThreads效率低且问题多,自内核2.6开始,Linux才开始提供内核级的线程支持,并有两个组织致力于…

离线环境下的 Prometheus 生态部署攻略

一、前言 在当今高度数字化的世界中,监控系统的稳定性和可靠性对于确保业务连续性和性能优化至关重要。特别是在网络隔离或无互联网接入的局域网环境下,离线部署监控解决方案成为了一种必要且挑战性的任务。本文将深入探讨如何在离线环境中成功部署 Pro…

深圳保障房、商品房、小产权房子类型对比

摘要: 整理了我认知以内的深圳房子类型,有安居房,可售人才房,共有产权房、配售型保障房、商品房、统建楼、农民房的区别。如果数据存疑,可以多方对比论证,我也主要靠百度。 我发现我很多同事是非深户&#…

秋招突击——算法练习——8/26——图论——200-岛屿数量、994-腐烂的橘子、207-课程表、208-实现Trie

文章目录 引言正文200-岛屿数量个人实现 994、腐烂的橘子个人实现参考实现 207、课程表个人实现参考实现 208、实现Trie前缀树个人实现参考实现 总结 引言 正文 200-岛屿数量 题目链接 个人实现 我靠,这道题居然是腾讯一面的类似题,那道题是计算最…

《分析模式》2024中译本-前言-01(加红色标注)

写在前面 今天开始,我们逐渐发布一些《分析模式》2024中译本的译文。 红色字体标出的文字,表示我认为之前的译本可能会让读者产生误解的地方。 感兴趣的读者,可以对照之前译本以及原文,捉摸一下为什么要标红。 主要原因当然是…

基于SpringBoot+Vue+MySQL的小区物业管理系统

系统背景 在当今信息化高速发展的时代,小区物业管理正经历着从传统模式向智能化、高效化转型的深刻变革。这一转变的核心驱动力,正是小区物业管理系统的全面智能化升级。该系统不仅极大地提升了物业管理的效率与精确度,还深刻重塑了物业与业主…

数分基础(03-1)客户特征分析

文章目录 客户特征分析1. 数据集2. 思路与步骤2.1 特征工程2.2 识别方法2.3 可视化 3. 分析准备3.1 读取数据集3.2 识别不同客户群体3.2.1 使用K-Means聚类进行初步细分3.2.2 关于聚类方法(1)特征缩放1)平衡特征对模型的影响力,避…

通过ICMP判断网络故障

一、ICMP协议 Internet控制消息协议ICMP(Internet Control Message Protocol)是IP协议的辅助协议。 ICMP协议用来在网络设备间传递各种差错和控制信息,对于收集各种网络信息、诊断和排除各种网络故障等方面起着至关重要的作用。 TypeCode描述备注00Echo Replyping…

C++从入门到起飞之——list使用 全方位剖析!

​ 🌈个人主页:秋风起,再归来~🔥系列专栏:C从入门到起飞 🔖克心守己,律己则安 目录 1、迭代器 2、push_back与emplace_back 3、list成员函数sort与库sort比较 4、merge 5、uniqu…

2024117读书笔记|《李煜词(果麦经典)》——一壶酒,一竿身,快活如侬有几人?一片芳心千万绪,人间没个安排处

2024117读书笔记|《李煜词(果麦经典)》——一壶酒,一竿身,快活如侬有几人?一片芳心千万绪,人间没个安排处 《李煜词(果麦经典)》李煜的词很美,插图也不错,很值…

基于粒子群优化算法的六自由度机械臂三维空间避障规划

摘要:本研究旨在解决机械臂在复杂环境中避障路径规划的问题。本文提出了一种利用粒子群优化算法(PSO)进行机械臂避障规划的方法,通过建立机械臂的运动模型,将避障问题转化为优化问题。PSO算法通过模拟群体中个体的社会…

ggml 简介

ggml是一个用 C 和 C 编写、专注于 Transformer 架构模型推理的机器学习库。该项目完全开源,处于活跃的开发阶段,开发社区也在不断壮大。ggml 和 PyTorch、TensorFlow 等机器学习库比较相似,但由于目前处于开发的早期阶段,一些底层…

8月28c++

c手动封装顺序表 #include <iostream>using namespace std; using datatype int;//类型重命名struct SeqList { private:datatype *data;//顺序表数组int size0;//数组大小int len0;//顺序表实际长度 public:void init(int s);//初始化函数bool empty();//判空函数bool …

python有主函数吗

python和C/Java不一样&#xff0c;没有主函数一说&#xff0c;也就是说python语句执行不是从所谓的主函数main开始的。 当运行单个python文件时&#xff0c;如运行a.py&#xff0c;这个时候a的一个属性__name__是__main__。 当调用某个python文件时&#xff0c;如b.py调用a.p…

HDD介绍

HDD是“Hard Disk Drive”的缩写&#xff0c;意为“硬盘驱动器”&#xff0c;是计算机中用于存储数据和程序的主要设备之一。 硬盘有机械硬盘(Hard Disk Drive&#xff0c;HDD)和固态硬盘(SSD)之分。机械硬盘即是传统普通硬盘&#xff0c;主要由&#xff1a;盘片&#xff0c;磁…

2024年华侨生联考英语真题全析:难度变化与备考策略

导读 在前面我们和大家一起分享了2024年华侨生联考各科真题的难度情况。今天我们就来和大家具体的看一下2024年港澳台华侨生联考英语真题试卷具体分析哈。 听力部分 今年的听力和去年的听力总体难度差别不大&#xff0c;一段听力材料对应一道听力题目&#xff08;简称一对一…

谐波电抗器选择的最佳方法

选择谐波电抗器的最佳方法取决于系统的具体要求和条件。 以下是选择谐波电抗器时需要考虑的关键因素和方法&#xff1a; 1、确定系统谐波频率 谐波分析&#xff1a;使用谐波分析仪测量系统中的谐波频率&#xff0c;确定主要的谐波频率和幅值。谐波电抗器的选择需要针对这些谐…