《动手学深度学习 Pytorch版》 10.3 注意力评分函数

news2024/12/22 23:42:17

上一节使用的高斯核的指数部分可以视为注意力评分函数(attention scoring function),简称评分函数(scoring function)。

后续把评分函数的输出结果输入到softmax函数中进行运算。最后,注意力汇聚的输出就是基于这些注意力权重的值的加权和。该过程可描述为下图:

在这里插入图片描述

用数学语言描述为:

f ( q , ( k 1 , v 1 ) , … , ( k m , v m ) ) = ∑ i = 1 m α ( q , k i ) v i ∈ R v f(\boldsymbol{q},(\boldsymbol{k}_1,\boldsymbol{v}_1),\dots,(\boldsymbol{k}_m,\boldsymbol{v}_m))=\sum^m_{i=1}{\alpha(\boldsymbol{q},\boldsymbol{k}_i)\boldsymbol{v}_i}\in\R^v f(q,(k1,v1),,(km,vm))=i=1mα(q,ki)viRv

其中查询 q \boldsymbol{q} q 和键 k i \boldsymbol{k}_i ki 的注意力权重(标量)是通过注意力评分函数 a a a 将两个向量映射成标量,再经过softmax运算得到的:

α ( q , k i ) = s o f t m a x ( a ( q , k i ) ) = a ( q , k i ) ∑ j = 1 m exp ⁡ a ( q , k i ) ∈ R \alpha(\boldsymbol{q},\boldsymbol{k}_i)=\mathrm{softmax}(a(\boldsymbol{q},\boldsymbol{k}_i))=\frac{a(\boldsymbol{q},\boldsymbol{k}_i)}{\sum^m_{j=1}{\exp{a(\boldsymbol{q},\boldsymbol{k}_i)}}}\in\R α(q,ki)=softmax(a(q,ki))=j=1mexpa(q,ki)a(q,ki)R

import math
import torch
from torch import nn
from d2l import torch as d2l

以下介绍的是两个流行的评分函数。

10.3.1 遮蔽 softmax 操作

并非所有的值都应该被纳入到注意力汇聚中。下面的 masked_softmax 函数实现了这样的掩蔽softmax操作(masked softmax operation),其中任何超出有效长度的位置都被掩蔽并置为0。

#@save
def masked_softmax(X, valid_lens):
    """通过在最后一个轴上掩蔽元素来执行softmax操作"""
    # X:3D张量,valid_lens:1D或2D张量
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                              value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3])))  # 两样本有效长度分别为 2 和 3
print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]])))  # 也可以给每一行指定有效长度
tensor([[[0.4297, 0.5703, 0.0000, 0.0000],
         [0.6186, 0.3814, 0.0000, 0.0000]],

        [[0.2413, 0.3333, 0.4254, 0.0000],
         [0.4165, 0.2801, 0.3034, 0.0000]]])
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.3277, 0.4602, 0.2121, 0.0000]],

        [[0.5026, 0.4974, 0.0000, 0.0000],
         [0.2684, 0.2599, 0.2613, 0.2103]]])

10.3.2 加性注意力

当查询和键是不同长度的矢量时,可以使用加性注意力作为评分函数。加性注意力(additive attention)的评分函数为:

a ( q , k i ) = w v T tanh ⁡ ( W q q + W k k ) ∈ R a(\boldsymbol{q},\boldsymbol{k}_i)=\boldsymbol{\mathrm{w}}_v^T\tanh{(\boldsymbol{\mathrm{W}}_q\boldsymbol{q}+\boldsymbol{\mathrm{W}}_k\boldsymbol{k})}\in\R a(q,ki)=wvTtanh(Wqq+Wkk)R

参数字典:

  • q ∈ R q \boldsymbol{q}\in\R^q qRq 表示查询

  • k ∈ R k \boldsymbol{k}\in\R^k kRk 表示键

  • W q ∈ R h × q \boldsymbol{\mathrm{W}}_q\in\R^{h\times q} WqRh×q W k ∈ R h × k \boldsymbol{\mathrm{W}}_k\in\R^{h\times k} WkRh×k W v ∈ R h \boldsymbol{\mathrm{W}}_v\in\R^h WvRh 均为可学习参数。

#@save
class AdditiveAttention(nn.Module):
    """加性注意力"""
    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)  # 使用了暂退法进行模型正则化

    def forward(self, queries, keys, values, valid_lens):
        # 初始 q 和 k 的形状如下,不好直接加
        # queries 的形状:(batch_size,查询的个数,num_hidden)
        # key 的形状:(batch_size,“键-值”对的个数,num_hiddens)
        queries, keys = self.W_q(queries), self.W_k(keys)
        # 在维度扩展后,
        # queries 的形状:(batch_size,查询的个数,1,num_hidden)
        # key 的形状:(batch_size,1,“键-值”对的个数,num_hiddens)
        features = queries.unsqueeze(2) + keys.unsqueeze(1)  # 优雅,实在优雅 使用广播方式进行求和
        features = torch.tanh(features)
        # self.w_v 仅有一个输出,因此从形状中移除最后那个维度。
        # scores 的形状:(batch_size,查询的个数,“键-值”对的个数)
        scores = self.w_v(features).squeeze(-1)  # 把最后一个维度去掉
        self.attention_weights = masked_softmax(scores, valid_lens)
        # values的形状:(batch_size,“键-值”对的个数,值的维度)
        return torch.bmm(self.dropout(self.attention_weights), values)
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))  # 查询、键和值的形状为(批量大小,步数或词元序列长度,特征大小)
# values的小批量,两个值矩阵是相同的
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
    2, 1, 1)
valid_lens = torch.tensor([2, 6])

attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,
                              dropout=0.1)
attention.eval()
attention(queries, keys, values, valid_lens)  # 注意力汇聚输出的形状为(批量大小,查询的步数,值的维度)
tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),  # 本例子中每个键都是相同的,所以注意力权重是均匀的,由指定的有效长度决定。
                  xlabel='Keys', ylabel='Queries')


在这里插入图片描述

10.3.3 缩放点积注意力

使用点积可以得到计算效率更高的评分函数,缩放点积注意力(scaled dot-product attention)评分函数为:

a ( q , k ) = q T k d a(\boldsymbol{q},\boldsymbol{k})=\boldsymbol{q}^T\boldsymbol{k}\sqrt{d} a(q,k)=qTkd

在实践中,我们通常从小批量的角度来考虑提高效率:

s o f t m a x ( Q K T d ) V ∈ R n × v \mathrm{softmax}\left(\frac{\boldsymbol{Q}\boldsymbol{K}^T}{\sqrt{d}}\right)\boldsymbol{V}\in\R^{n\times v} softmax(d QKT)VRn×v

#@save
class DotProductAttention(nn.Module):
    """缩放点积注意力"""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # queries的形状:(batch_size,查询的个数,d)
    # keys的形状:(batch_size,“键-值”对的个数,d)
    # values的形状:(batch_size,“键-值”对的个数,值的维度)
    # valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # 设置transpose_b=True为了交换keys的最后两个维度
        scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
queries = torch.normal(0, 1, (2, 1, 2))  # 点积操作需要查询的特征维度与键的特征维度大小相同
attention = DotProductAttention(dropout=0.5)  # 使用了暂退法进行模型正则化
attention.eval()
attention(queries, keys, values, valid_lens)
tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]])
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')


在这里插入图片描述

练习

(1)修改小例子中的键,并且可视化注意力权重。可加性注意力和缩放的“点-积”注意力是否仍然产生相同的结果?为什么?

不一样,评分函数不一样,键值不同的话那注意力汇聚肯定不一样的。

queries_new, keys_rand = torch.normal(0, 1, (2, 1, 2)), torch.rand((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
    2, 1, 1)
valid_lens = torch.tensor([2, 6])

attention_rand = AdditiveAttention(key_size=2, query_size=2, num_hiddens=8,
                              dropout=0.1)
attention_rand.eval()
attention_rand(queries_new, keys_rand, values, valid_lens)

d2l.show_heatmaps(attention_rand.attention_weights.reshape((1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')


在这里插入图片描述

attention_rand = DotProductAttention(dropout=0.5)
attention_rand.eval()
attention_rand(queries_new, keys_rand, values, valid_lens)

d2l.show_heatmaps(attention_rand.attention_weights.reshape((1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')


在这里插入图片描述


(2)只使用矩阵乘法,能否为具有不同矢量长度的查询和键设计新的评分函数?

可以想办法把他俩映射到一个长度。


(3)当查询和键具有相同的矢量长度时,矢量求和作为评分函数是否比“点-积”更好?为什么?

不会,略。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1136507.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PC浏览器获取京东key和pin码

1。 登录京东网站 京东 2。快捷键按F12 打开开发者模式或者下图打开 3. 找到Cookie 4. 复制 pt_keyAAJlOf5ZADDabccjhaljhdgsTU2gbszktPPPD7my5-QN88OZc4mI3__SYGUDyt8GgpbCkVPk; pt_pinjd_9jghsakdjg7687b8;

【广州华锐互动】VR公司工厂消防逃生演练带来沉浸式的互动体验

在工业生产过程中&#xff0c;安全问题始终是我们不能忽视的重要环节。特别是火灾事故&#xff0c;不仅会造成重大的经济损失&#xff0c;更会威胁到员工的生命安全。传统的消防安全训练方法&#xff0c;如讲座、实地演练等&#xff0c;虽然具有一定的效果&#xff0c;但是无法…

YOLOv5 onnx \tensorrt 推理

一、yolov5 pt模型转onnx code: https://github.com/ultralytics/yolov5 python export.py --weights yolov5s.pt --include onnx二、onnx 推理 import os import cv2 import numpy as np import onnxruntime import timeCLASSES [person, bicycle, car, motorcycle, airpl…

strace跟着-编译和解决sip的bus srror问题记录

1 问题&#xff1a; 我编译了一个开源sip代码&#xff0c;可以确定的是&#xff0c;在nuc980dk61yc、nuc97251y上都可以跑的正常程序, 但在该开发板&#xff08;NUC97261Y&#xff09;上运行&#xff0c;报错bus error&#xff1b; 此文记录了 解决该问题的过程 我手里有一个97…

Dubbo 路由及负载均衡性能优化

作者&#xff1a;vivo 互联网中间件团队- Wang Xiaochuang 本文主要介绍在vivo内部针对Dubbo路由模块及负载均衡的一些优化手段&#xff0c;主要是异步化缓存&#xff0c;可减少在RPC调用过程中路由及负载均衡的CPU消耗&#xff0c;极大提升调用效率。 一、概要 vivo内部Java…

【STM32】STM32中断体系

一、STM32的NVIC和起始代码中的ISP 1.NVIC(嵌套向量中断控制器) (1)数据手册中相关部分浏览 (2)地址映射时0地址映射到Flash或SRAM (3)中断向量表可以被人为重新映射&#xff0c;一般用来IAP中 (4)STM32采用一维的中断向量表 (5)中断优先级设置有点复杂&#xff0c;后面细说 1…

windows10专业版优化记录

用来记录我的windows10专业版配置的所有设置 资源管理器占用CPU资源高 gpedit.msc打开本地组策略管理器 这样时不时资源管理器会占用CPU高。 禁用的service的列表 Problem Reports and Solutions Control Panel Support Diagnostic Policy Service 组件诊断服务 WMI Prov…

如何进行渗透测试以提高软件安全性?

对于各种规模的企业和组织来说&#xff0c;软件安全是一个至关重要的问题。随着网络攻击越来越复杂&#xff0c;软件中的漏洞越来越多&#xff0c;确保你的软件安全比以往任何时候都更重要。提高软件安全性的一个有效方法是渗透测试&#xff08;penetration testing&#xff09…

基于Or-Tools的整数规划问题求解

基于Or-Tools的整数规划问题求解 Or-Tools官网整数规划问题导入线性求解器声明 MIP 求解器定义变量定义约束条件定义目标函数调用 MIP 求解器打印结果完整代码 Or-Tools官网例题:求解大规模问题的数组表示构造数据实例化求解器定义变量定义约束条件定义目标函数调用求解器 完整…

罗马数字转整数------题解报告

题目&#xff1a;力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 很简单&#xff0c;感觉没什么可以讲的&#xff0c;就是按照题目要求做判断就好了 public int romanToInt(String s) {char []arg s.toCharArray();int sum 0;for(int i0;i<arg…

【JAVA学习笔记】45 - (35 - 43)第十章作业

项目代码 https://github.com/yinhai1114/Java_Learning_Code/tree/main/IDEA_Chapter10/src/com/yinhai/homework10 1.静态属性的共享性质 判断下列输出什么 public class HomeWork01 {public static void main(String[] args) {Car c new Car();//无参构造时改变color为red…

鸿蒙跨平台框架来了ArkUi-X

前言&#xff1a; 各位同学大家有段时间没有给大家更新博客了 之前鸿蒙推出了鸿ArkUi-X 框架所以就写个文章分享一下 效果图&#xff1a; 首先需要下载支持 ArkUI-X 套件的华为开发工具 DevEco &#xff0c;版本为 4.0 以上&#xff0c;目前可以下载预览版进行体验。下载地址…

常见的网络攻击类型及防范措施

网络攻击是指针对计算机网络、系统或数据的恶意行为&#xff0c;旨在破坏、入侵、窃取信息或干扰网络服务。网络攻击的类型多种多样&#xff0c;以下是一些常见的网络攻击类型&#xff1a; DDoS 攻击&#xff08;分布式拒绝服务攻击&#xff09;&#xff1a; 攻击者通过多个受感…

【数据集】指针式圆形表计关键点数据集

指针式圆形表计关键点数据集 数据集简介数据集一览 数据集简介 数据类型&#xff1a;指针式圆形表计ROI区域 数据数量&#xff1a;1069 标注标签&#xff1a;中心点&#xff0c;起点&#xff0c;终点&#xff0c;指针端点 图像质量&#xff1a;高清90%&#xff0c;较模糊10% …

如何利用视频号提取视频,视频号下载视频教程

随着视频号的兴起&#xff0c;越来越多的人开始关注这个平台&#xff0c;并希望能够提取视频号中的精彩视频。然而&#xff0c;视频号并不支持直接下载视频&#xff0c;也不能复制链接。那么&#xff0c;我们如何才能实现视频号提取视频的需求呢&#xff1f;答案就是借助视频下…

苹果将于10月31日举行今秋的第二场发布会

在今日凌晨&#xff0c;苹果宣布&#xff0c;将于北京时间10月31日早上8点举行今秋的第二场发布会&#xff0c;主题为“来势迅猛”。据多方猜测苹果本次活动的核心产品大概率是搭载全新M3芯片的Mac系列产品。 据了解&#xff0c;在苹果的产品线中&#xff0c;搭载M3芯片的Mac系…

【强化学习】09——价值和策略近似逼近方法

文章目录 前言对状态/动作进行离散化参数化值函数近似值函数近似的主要形式Incremental MethodsGradient DescentLinear Value Function ApproximationFeature Vectors特征化状态Table Lookup Features Incremental Prediction AlgorithmsMonte-Carlo with Value Function Appr…

【蓝桥杯选拔赛真题02】C++计算天数 青少年组蓝桥杯C++选拔赛真题 STEMA比赛真题解析

目录 C/C++计算天数 一、题目要求 1、编程实现 2、输入输出 二、算法分析 <

VR全景拍摄市场需求有多大?适用于哪些行业?

随着VR全景技术的成熟&#xff0c;越来越多的商家开始借助VR全景来宣传推广自己的店铺&#xff0c;特别是5G时代的到来&#xff0c;VR全景逐渐被应用在我们的日常生活中的各个方面&#xff0c;VR全景拍摄的市场需求也正在逐步加大。 通过VR全景技术将线下商家的实景“搬到线上”…