机器学习深度学习——注意力分数(详细数学推导+代码实现)

news2025/1/13 14:04:31

👨‍🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——机器翻译(序列生成策略)
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助

这篇文章实际上应该要接着上次讲过的注意力提示和注意力池化讲的,但是由于学到后面感觉有点不对劲,因为我跳过了一些基础的东西,所以导致transformer的思想有点疑惑,所以暂停了注意力机制的学习,而去实现了机器翻译。
这边接着上次的注意力机制的内容进行学习:
机器学习&&深度学习——注意力提示、注意力池化(核回归)

注意力分数(详细数学推导+代码实现)

  • 数学思维推导
    • 注意力分数
    • 高维拓展
    • 注意力分数设计
      • Additive Attention(key和query不等长时)
      • Scaled Dot-Product Attention(key和query等长时)
    • 总结
  • 实现复杂注意力机制
    • 遮蔽softmax操作
    • 加性注意力
    • 缩放点积注意力
  • 小结

数学思维推导

注意力分数

在之前使用了高斯核来对查询和键之间的关系建模,而高斯核指数部分可以视为注意力评分函数(简称评分函数),然后把这个函数的输出结果输入到softmax函数中进行运算。通过上述步骤,将得到与键对应的值的概率分布(也就是注意力权重)。最后,注意力池化的输出就是基于这些注意力权重的值的加权和。
下图说明了如何将注意力池化的输出计算称为值的加权和,其中a表示注意力评分函数。由于注意力权重是概率分布,因此加权和其本质上是加权平均值。
在这里插入图片描述
我们可以回顾一下注意力池化的函数f:
f ( x ) = ∑ i α ( x , x i ) y i = ∑ i = 1 n s o f t m a x ( − 1 2 ( x − x i ) 2 ) y i 其中, α ( x , x i ) 指的就是注意力权重,而 a = − 1 2 ( x − x i ) 2 指的是注意力分数 f(x)=\sum_iα(x,x_i)y_i=\sum_{i=1}^nsoftmax(-\frac{1}{2}(x-x_i)^2)y_i\\ 其中,α(x,x_i)指的就是注意力权重,而a=-\frac{1}{2}(x-x_i)^2指的是注意力分数 f(x)=iα(x,xi)yi=i=1nsoftmax(21(xxi)2)yi其中,α(x,xi)指的就是注意力权重,而a=21(xxi)2指的是注意力分数

高维拓展

现在让我们拓展到高维度,假设:
q u e r y   q ∈ R q , m 对 k e y − v a l u e ( k 1 , v 1 ) , . . . , ( k m , v m ) ,这里 k i ∈ R k , v i ∈ R v query \ q∈R^q,m对key-value(k_1,v_1),...,(k_m,v_m),这里\\ k_i∈R^k,v_i∈R^v query qRqmkeyvalue(k1,v1),...,(km,vm),这里kiRkviRv
那么注意力池化层就可以表示为:
f ( q , ( k 1 , v 1 ) , . . . , ( k m , v m ) ) = ∑ i = 1 m α ( q , k i ) v i ∈ R v α ( q , k i ) = s o f t m a x ( a ( q , k i ) ) = e x p ( a ( q , k i ) ) ∑ j = 1 m e x p ( a ( q , k j ) ) ∈ R f(q,(k_1,v_1),...,(k_m,v_m))=\sum_{i=1}^mα(q,k_i)v_i∈R^v\\ α(q,k_i)=softmax(a(q,k_i))=\frac{exp(a(q,k_i))}{\sum_{j=1}^mexp(a(q,k_j))}∈R f(q,(k1,v1),...,(km,vm))=i=1mα(q,ki)viRvα(q,ki)=softmax(a(q,ki))=j=1mexp(a(q,kj))exp(a(q,ki))R

注意力分数设计

也就是上面的a函数的设计方式

Additive Attention(key和query不等长时)

Additive Attention也叫做加性注意力
定义三个可学的参数:
W k ∈ R h × k , W q ∈ R h × q , v ∈ R h W_k∈R^{h×k},W_q∈R^{h×q},v∈R^h WkRh×kWqRh×qvRh
此时我们需要和上面一样,把k和q(key和query)结合起来,计算注意力分数:
a ( k , q ) = t a n h ( W k k + W q q ) a(k,q)=tanh(W_kk+W_qq) a(k,q)=tanh(Wkk+Wqq)
这时候就很容易看出前面的两个可学参数的意义了,是为了将k和q拉回到同一个维度,方便他们进行计算。
计算出来后的式子一定会是属于Rh的。
那么此时我们计算注意力权重的方式为:
α ( k , q , v ) = v T a ( k , q ) = v T t a n h ( W k k + W q q ) α(k,q,v)=v^Ta(k,q)=v^Ttanh(W_kk+W_qq) α(k,q,v)=vTa(k,q)=vTtanh(Wkk+Wqq)
那么我们最终会得到一个固定的值,这个值就是注意力权重了。
这里的意义我们可以想象得到,这就等价于将key和value合并起来后放到一个隐藏大小为h,输出大小为1的单隐藏层MLP

Scaled Dot-Product Attention(key和query等长时)

Scaled Dot-Product Attention也叫缩放点积注意力
当query和key都是相同的长度,也就是:
q , k i ∈ R d q,k_i∈R^d q,kiRd
那么可以:
a ( q , k i ) = < q , k i > / d a(q,k_i)=<q,k_i>/{\sqrt{d}} a(q,ki)=<q,ki>/d
也就是说key和query等长时无须再通过可学习的参数把他们拉回到同一纬度,直接计算点击即可。而除以根号d的用意是为了防止其对于长度过于敏感。
向量化版本:
Q ∈ R n × d , K ∈ R m × d , V ∈ R m × v 注意力分数: a ( Q , K ) = Q K T / d ∈ R n × m 注意力池化: f = s o f t m a x ( a ( Q , K ) ) V ∈ R n × v Q∈R^{n×d},K∈R^{m×d},V∈R^{m×v}\\ 注意力分数:a(Q,K)=QK^T/{\sqrt{d}}∈R^{n×m}\\ 注意力池化:f=softmax(a(Q,K))V∈R^{n×v} QRn×dKRm×dVRm×v注意力分数:a(Q,K)=QKT/d Rn×m注意力池化:f=softmax(a(Q,K))VRn×v

总结

1、注意力分数是query和key的相似度,注意力权重是分数的softmax结果
2、两种常见的分数计算:
(1)将query和key合并起来进入一个单输出单隐藏层的MLP
(2)直接将query和key做内积

实现复杂注意力机制

接下来将用上面的两个流行评分函数来实现更复杂的注意力机制:

import math
import torch
from torch import nn
from d2l import torch as d2l

遮蔽softmax操作

softmax操作用于输出一个概率分布作为注意力权重,但在某些情况下,并非所有的值都应该被纳入到注意力池化中,例如某些文本序列被填充了没有意义的特殊词元等。为了仅将有意义的词元作为值来获取注意力池化,可以指定一个有效序列长度(即词元的个数),以便在计算softmax时过滤掉超出指定范围的位置:

#@save
def masked_softmax(X, valid_lens):
    """通过在最后一个轴上掩蔽元素来执行softmax操作"""
    # X:3D张量,valid_lens:1D或2D张量
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                              value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)

接下来我们可以验证一下,考虑2个2×4矩阵表示的样本,这两个样本的有效长度分别为2和3,那么超出的部分的softmax值都会被置为0:

print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3])))

运行结果:
在这里插入图片描述
可以看出没啥问题,第一个矩阵都是样本的前两列算,第二个矩阵都是样本的前三列算,当然也满足加起来和为1。
同样也可以使用二维张量,为矩阵样本的每一行都指定有效长度:

print(masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]])))

运行结果:
在这里插入图片描述

加性注意力

#@save
class AdditiveAttention(nn.Module):
    """加性注意力"""
    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # 在维度扩展后,
        # queries的形状:(batch_size,查询的个数,1,num_hidden)
        # key的形状:(batch_size,1,“键-值”对的个数,num_hiddens)
        # 使用广播方式进行求和
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # self.w_v仅有一个输出,因此从形状中移除最后那个维度。
        # scores的形状:(batch_size,查询的个数,“键-值”对的个数)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # values的形状:(batch_size,“键-值”对的个数,值的维度)
        return torch.bmm(self.dropout(self.attention_weights), values)

缩放点积注意力

使用点积可以得到计算效率更高的评分函数,但是点积操作要求查询和键具有相同的长度d。下面的缩放点积注意力的实现使用了dropout进行模型正则化。

#@save
class DotProductAttention(nn.Module):
    """缩放点积注意力"""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # queries的形状:(batch_size,查询的个数,d)
    # keys的形状:(batch_size,“键-值”对的个数,d)
    # values的形状:(batch_size,“键-值”对的个数,值的维度)
    # valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # 设置transpose_b=True为了交换keys的最后两个维度
        scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

小结

1、将注意力池化的输出计算可以作为值的加权平均,选择不同的注意力评分函数会带来不同的注意力池化操作。
2、当查询和键是不同长度的矢量时,可以使用可加性注意力评分函数。当它们的长度相同时,使用缩放的“点-积”注意力评分函数的计算效率更高。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/885633.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【实战】十一、看板页面及任务组页面开发(一) —— React17+React Hook+TS4 最佳实践,仿 Jira 企业级项目(二十三)

文章目录 一、项目起航&#xff1a;项目初始化与配置二、React 与 Hook 应用&#xff1a;实现项目列表三、TS 应用&#xff1a;JS神助攻 - 强类型四、JWT、用户认证与异步请求五、CSS 其实很简单 - 用 CSS-in-JS 添加样式六、用户体验优化 - 加载中和错误状态处理七、Hook&…

docker安装redis7-主从模式

说明 系统版本&#xff1a;CentOS7.9 redis版本&#xff1a;7.0.5镜像 此模式为1主2从,主节点端口为6379&#xff0c;从节点端口为6380、6381以下所有的示例以redis7.0.5为例 下载镜像 docker pull redis:7.0.5 创建挂载路径 所有节点的数据、配置文件以及日志都挂载到宿…

隧道人员定位方案

针对隧道环境的人员定位方案&#xff0c;UWB定位技术同样可以提供高精度和可靠的定位服务。以下是一个可行的方案&#xff1a; 部署基站网络&#xff1a;在隧道内建立一个基站网络&#xff0c;基站需要均匀分布在各个关键位置&#xff0c;以确保全方位的覆盖。由于隧道的特殊环…

CMake:检测外部库---使用pkg-config

CMake:检测外部库---使用pkg-config 导言ZMQ安装项目结构CMakeLists.txt相关源码 导言 前面几篇内容的学习&#xff0c;我们基本上了解了如何链接一个三方库的方法。本篇以及下一篇将补充两个检测外部库的方法。 目前为止&#xff0c;我们已经学习了两种检测外部依赖关系的方…

文件的导入与导出

文章目录 一、需求二、分析1. Excel 表格数据导出2. Excel 表格数据导入一、需求 在我们日常开发中,会有文件的导入导出的需求,如何在 vue 项目中写导入导出功能呢 二、分析 以 Excel 表格数据导出为例 1. Excel 表格数据导出 调用接口将返回的数据进行 Blob 转换,附: 接…

iPhone删除的照片能恢复吗?不小心误删了照片怎么找回?

iPhone最近删除清空了照片还能恢复吗&#xff1f;大家都知道&#xff0c;照片对于我们来说是承载着美好回忆的一种形式。它记录着我们的平淡生活&#xff0c;也留住了我们的美好瞬间&#xff0c;具有极其重要的纪念价值。 照片不小心误删是一件非常难受的事&#xff0c;那么iP…

【智慧工地源码】:人工智能、BIM技术、机器学习在智慧工地的应用

智慧工地云平台是专为建筑施工领域所打造的一体化信息管理平台。通过大数据、云计算、人工智能、BIM、物联网和移动互联网等高科技技术手段&#xff0c;将施工区域各系统数据汇总&#xff0c;建立可视化数字工地。同时&#xff0c;围绕人、机、料、法、环等各方面关键因素&…

Unity用NPOI创建Exect表,保存数据,和修改删除数据。以及打包后的坑——无法打开新创建的Exect表

先说坑花了一下午才找到解决方法解决&#xff0c; 在Unity编辑模式下点击物体创建对应的表&#xff0c;获取物体名字与在InputText填写的注释数据。然后保存。创建Exect表可以打开&#xff0c;打包PC后&#xff0c;点击物体创建的表&#xff0c;打不开文件破损 解决方法&#…

AIGC绘画:kaggle部署stable diffusion项目绘画

文章目录 kaggle介绍项目部署edit my copy链接显示 结果展示 kaggle介绍 Kaggle成立于2010年&#xff0c;是一个进行数据发掘和预测竞赛的在线平台。从公司的角度来讲&#xff0c;可以提供一些数据&#xff0c;进而提出一个实际需要解决的问题&#xff1b;从参赛者的角度来讲&…

【Docker】Docker network之bridge、host、none、container以及自定义网络的详细讲解

&#x1f680;欢迎来到本文&#x1f680; &#x1f349;个人简介&#xff1a;陈童学哦&#xff0c;目前学习C/C、算法、Python、Java等方向&#xff0c;一个正在慢慢前行的普通人。 &#x1f3c0;系列专栏&#xff1a;陈童学的日记 &#x1f4a1;其他专栏&#xff1a;CSTL&…

建筑工地的水泥分配和料场选址问题(Cplex求解线性规划模型+粒子群搜索算法)【Java实现】

问题 问题一求解 求解思路 该问题可以直接建立一个线性规划模型&#xff0c;然后使用cplex求解器来求解 模型 决策变量 x i j &#xff1a;第 i 个料场向第 j 个工地运送的水泥吨数&#xff0c;其中 1 ≪ i ≪ m &#xff1b; 1 ≪ j ≪ n 其中 x i j 的取值范围是 [ 0 , d…

prisma的增删改查

目录 一、单表1.增自增问题2.查询所有信息3.查询以l开头的数据4.查询限定数据5.查询唯一的数据6.分页查询7.改8.删 二、联表1.新增文章2.将文章和用户关联3.查询用户的同时查询用户的文章4.关联查询&#xff08;级联操作&#xff0c;链式调用&#xff09; 一、单表 模型 mode…

腾讯云轻量服务器测评:2核 2G 4M

腾讯云轻量2核2G4M服务器&#xff0c;4M带宽下载速度可达512KB/秒&#xff0c;系统盘为50GB SSD盘&#xff0c;300GB月流量&#xff0c;地域节点可选上海、广州和北京&#xff0c;腾讯云百科分享腾讯云2核2G4M轻量应用服务器配置性能表&#xff1a; 目录 腾讯云轻量2核2G4M服…

Spring MVC 中的常见注解的用法

目录 认识 Spring MVC什么是 Spring MVCMVC 的定义 Spring MVC 注解的运用1. Spring MVC 的连接RequestMapping 注解 2. 获取参数获取单个参数获取多个参数传递对象表单传参后端参数重命名RequestBody 接收 JSON 对象PathVariable 获取 URL 中的参数上传文件 RequestPart获取 C…

最小生成树,Kruskal算法

最小生成树&#xff08;Minimum Spanning Tree&#xff0c;简称 MST&#xff09;是一个连通图的子图&#xff0c;它包含图中的所有节点&#xff0c;并且是一个树&#xff08;无环连通图&#xff09;&#xff0c;同时保证连接所有节点的边的权重之和最小。 在一个带权重的连通图…

R语言实现非等比例风险生存资料分析(1)

#非等比例风险的生存资料分析 ###1 生成模拟数据### library(flexsurv) set.seed(123) # 生成样本数量 n <- 100 # 生成时间数据 time <- sample(1:1000,n,replaceF) # 调整shape和scale参数以控制生存曲线形状 # 生成事件数据&#xff08;假设按比例风险模型&#xff0…

【SpringBoot】中的ApplicationRunner接口 和 CommandLineRunner接口

1. ApplicationRunner接口 用法&#xff1a; 类型&#xff1a; 接口 方法&#xff1a; 只定义了一个run方法 使用场景&#xff1a; springBoot项目启动时&#xff0c;若想在启动之后直接执行某一段代码&#xff0c;就可以用 ApplicationRunner这个接口&#xff0c;并实现接口…

YB2416是支持高电压输入的同步降压电源管理芯片

简介&#xff1a; YB2416是支持高电压输入的同步降压电源管理芯片&#xff0c;在 4~30V 的宽输入电压范围内可实现3A的连续电流输出。通过调节 FB 端口的分压电阻&#xff0c;可以输出1.8V到28V的稳定电压。YB2416具有优秀的恒压/恒流(CC/C)特性。YB2416 采用电流模式的环路控制…

UI自动化测试常见的Exception

一. StaleElementReferenceException&#xff1a; - 原因&#xff1a;引用的元素已过期。原因是页面刷新了&#xff0c;此时当然找不到之前页面的元素。- 解决方案&#xff1a;不确定什么时候元素就会被刷新。页面刷新后重新获取元素的思路不变&#xff0c;这时可以使用python的…

【云原生】【k8s】从小白到大神之路之学习运维第82天-------基于Prometheus监控Kubernetes集群

第四阶段 时 间&#xff1a;2023年8月17日 参加人&#xff1a;全班人员 内 容&#xff1a; 基于Prometheus监控Kubernetes集群 目录 一、Prometheus简介 &#xff08;一&#xff09;Prometheus的基本原理 &#xff08;二&#xff09;Prometheus优势 &#xff08;三&a…