Scaled dot-prodect Attention的原理和实现(附源码)

news2024/12/24 0:29:45

文章目录

  • 背景
  • 什么是Attention
  • Attention权重的计算方法
    • 1. 多层感知机法
    • 2. Bilinear方法
    • 3. Dot Product
    • 4. Scaled Dot Product
  • Scaled dot-prodect Attention的源码实现

背景

  要了解深度学习中的Attention,就不得不先谈Encoder-Decoder框架(sequence to sequence模型),因为Attention的最常见用途就是解决Encoder-Decoder中的信息丢失问题(注:Attention可以看作一种通用的思想,本身并不依赖于特定框架)。
  下图是NLP领域里常用的Encoder-Decoder框架的示意图。

  可以把上图看作是由一个句子生成另外一个句子的通用处理模型。对于句子对(Source,Target),我们的目标是给定输入句子Source,期待通过Encoder-Decoder框架来生成目标句子Target。Source和Target可以是同一种语言,也可以是两种不同的语言。
  Encoder的作用是学习并提取Source句子里的信息,生成一个代表Source的hidden向量,Decoder则是根据接收到的hidden向量生成对应的Target。可见hidden是连接Encoder和Decoder的桥梁,而由于其的大小有限,所能承载的Source句子的信息就有限,如果遇到Source过长的情况,就会出现hidden无法有效表示Source句子的情况,此时的Decoder也就无法正确理解Source,更无法准确生成Target。
  此时,Attention的作用便凸显了出来。

什么是Attention

在这里插入图片描述
  如上图,由于一个hidden无法涵盖所有的Source句子信息,故将句子中每个字对应的hidden信息都输入到Attention中,再将Attention作为Decoder的输入,这样就可以防止Source句子信息的丢失。
  而此时又遇到另一个问题,Source句子中的每个字(词)对于Decoder当前时刻要生成的字(词)的影响力不同,我们不能简单的把所有hidden都传入Attention,而是要告诉Decoder哪个hidden对其当前的生成任务更重要(即影响力更大),故需要对所有hidden做weighted sum,如下图:
在这里插入图片描述
  注:其中的S(i-1)是Decoder生成的上一个字,也会作为Attention输入,因其对同样影响当前的生成任务。
  weight即权重,weighted sum即对所有的hidden分配不同的权重后求和,对Decoder当前要生成的字(词)影响力大的hidden,权重大,反之,则权重小。
  那么,如何计算每个字在句子中的权重大小呢?

Attention权重的计算方法

  设Q(query)、K(key)分别为代表两个字的向量。

1. 多层感知机法

在这里插入图片描述
  主要是先将query和key进行拼接,然后接一个激活函数为tanh的全连接层,再与一个网络定义的权重矩阵做乘积。这种方法对于大规模的数据特别有效。

2. Bilinear方法

在这里插入图片描述
  通过一个权重矩阵直接建立q和k的关系映射,比较直接,且计算速度较快。

3. Dot Product

在这里插入图片描述
  这个方法更直接,连权重矩阵都省了,直接建立q和k的关系映射,优点是计算速度更快了,且不需要参数,降低了模型的复杂度。但是需要q和k的维度要相同。

4. Scaled Dot Product

在这里插入图片描述
  上面的点积方法有一个问题,就是随着向量维度的增加,最后得到的权重也会增加,为了提升计算效率,防止数据上溢,对其进行scaling,即图中除以的根号下k的维度。后续的Transformer模型中self-attention也是采用了该计算方法。

Scaled dot-prodect Attention的源码实现

  Scaled dot-prodect Attention定义如下:
在这里插入图片描述
  可以理解为:将Source中的构成元素想象成是由一系列的(Key,Value)数据对构成,此时给定Target中的某个元素Query,通过计算Query和各个Key的相似性或者相关性,得到每个Key对应Value的权重系数,然后对Value进行加权求和,即得到了最终的Attention数值。
  计算过程图示如下:
在这里插入图片描述
  源码如下:

def DotProductAttention(query, key, value, mask, scale=True):
    """Dot product self-attention.
    Args:
        query (numpy.ndarray): 代表q的向量,shape为(L_q by d)
        key (numpy.ndarray): 代表k的向量,shape为(L_k by d)
        value (numpy.ndarray): 代表v的向量,shape为(L_k by d) ,注L_v = L_k
        mask (numpy.ndarray): causal attention标志位,用于判断attention的计算类型
        scale (bool): 是否scale,即是否除以根号下词维度(q的维度和k的维度相同)

    Returns:
        numpy.ndarray: 代表Self-attention的矩阵,shape为(L_q by d)
    """
    # 是否除以维度
    if scale: 
        depth = query.shape[-1]
    else:
        depth = 1

    # 计算q和k的点乘
    dots = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(depth) 
    
    # 对于causal attention,加上mask
    if mask is not None:
        dots = np.where(mask, dots, np.full_like(dots, -1e9)) 
    
    # 对点乘做softmax计算
    logsumexp = scipy.special.logsumexp(dots, axis=-1, keepdims=True)
    dots = np.exp(dots - logsumexp)

	# 计算attention
    attention = np.matmul(dots, value)
    
    return attention

def dot_product_self_attention(q, k, v, scale=True):
    mask_size = q.shape[-2]
    mask = np.tril(np.ones((1, mask_size, mask_size), dtype=np.bool_), k=0)  
        
    return DotProductAttention(q, k, v, mask, scale=scale)

dot_product_self_attention(q, k, v)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/531862.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

拍立淘API接口说明文档 按图搜索淘宝商品API 实时数据返回

开发背景: 随着电商行业的不断发展,人们的购物需求日益增多。在购买商品时,很多人会通过搜索引擎、社交媒体等手段来获取信息或灵感。但是,在这些渠道中找到想要的商品并不容易,因为其中可能会混杂着一些广告或无关内…

Android内存优化检测工具LeakCanary使用

一、什么是LeakCanary leakCanary是Square开源框架,是一个Android和Java的内存泄露检测库。如果检测到某个activity有内存泄露,LeakCanary就是自动地显示一个通知,所以可以把它理解为傻瓜式的内存泄露检测工具。通过它可以大幅度减少开发中遇…

Java 并发队列详解

一,简介 1,并发队列两种实现 以ConcurrentLinkedQueue为代表的高性能非阻塞队列以BlockingQueue接口为代表的阻塞队列 2,阻塞队列与非阻塞队列的区别 当阻塞队列是空的时,从队列中获取元素的操作将会被阻塞,试图从…

【BFS】华子20230506笔试第三题(动态迷宫问题)Java实现

文章目录 题目链接思路BFS板子我的解答 题目链接 塔子哥的codeFun2000:http://101.43.147.120/p/P1251 测试样例1 输入 3 2 1 0 1 2 2 1 2 0 100 100 100 100 000 100 000 000 001输出 1测试样例2 输入 3 2 1 0 2 0 0 1 2 2 000 000 001 010 101 101 110 010 …

在docker容器中启动docker服务并实现构建多平台镜像的能力

在docker容器中启动docker服务并实现构建多平台镜像的能力 背景 在容器中运行docker,是devops中无法避免的场景,通常被应用于提供统一的镜像构建工具,出于安全考虑,不适合将主机的docker进程暴露给公司的内部人员使用&#xff0…

SpringCloud alibaba微服务b2b2c电子商务平台

1. 涉及平台 平台管理、商家端(PC端、手机端)、买家平台(H5/公众号、小程序、APP端(IOS/Android)、微服务平台(业务服务、系统服务、中间件服务) 2. 核心架构 Spring Cloud、Spring Boot2、My…

飞书开发流程

1、进入飞书并创建一个应用 链接: 创建应用 创建应用成功后需要审核通过,如果你拥有管理权限则可以自己进入管理后台通过审核,否则需要联系管理员通过审核 2、进入开发者后台 链接: 发者后台 3、在该调试平台上测试 以这个订阅审批事件为例 这一步…

DHCP协议简单配置

实验原理 网络中主机需要与外界进行通信时,需要配置自己的IP地址、网关地址、DNS服务器等网络参数信息。手工在每台主机上配置维护成本高,容易出错,而且不利于管理员统一维护。 通过DHCP地址自动配置协议,使终端设备能自动获取地址,实现即插即用且IP地址统一由服务器管理…

springboot+java充电桩充电额维修管理系统

项目介绍 Spring Boot 是 Spring 家族中的一个全新的框架,它用来简化Spring应用程序的创建和开发过程。也可以说 Spring Boot 能简化我们之前采用SSM(Spring MVC Spring MyBatis )框架进行开发的过程。 系统基于B/S即所谓浏览器/服务器模式…

STM32 学习笔记_9 定时器中断:编码器接口模式

TIM编码器接口 之前我们处理旋转编码器,是转一下中断一次,挺消耗资源的。 我们可以利用TIM的编码器功能,隔一段时间取一下旋转器值使得cnt或–,以此判断旋转位置以及计算速度,相比中断节约资源。相当于外接了一个有方…

Kubernetes那点事儿——暴露服务之Service

Kubernetes那点事儿——暴露服务之Service 前言一、Service二、Service与Pod关系三、Service常用类型ClusterIPNodePortLoadBalancer 四、Service代理模式IptablesIPVS修改代理模式 前言 K8s中,我们将应用跑在Pod里。多数情况下是一组Pod,用户如何访问这…

凌恩生物美文分享 | 提升科研有一套 | 宏基因组磷循环分析又出新!

磷是包括微生物在内的所有生命体中不可缺少的元素。在生物大分子核酸、高能量化合物ATP、以及生物体内糖代谢的某些中间体中,都有磷的存在。在自然界中,磷的循环包括可溶性无机磷的同化、有机磷的矿化、不溶性磷的溶解等。微生物分解含磷化合物的作用&am…

操作系统面试相关知识

目录 一、简介1、什么是操作系统2、操作系统主要有哪些功能? 二、操作系统结构1、什么是内核?2、什么是用户态和内核态?3、 用户态和内核态是如何切换的? 三、 进程和线程1、并行和并发有什么区别?2、什么是进程上下文…

无线传感器网络的时钟同步估计问题(Matlab代码实现)

目录 💥1 概述 📚2 运行结果 🎉3 参考文献 👨‍💻4 Matlab代码 💥1 概述 随着无线传感器网络的快速发展,其应用领域也越来越广。在诸多的应用环境中都需要大量已同步的传感器节点通过协同作用执行一个…

Python 中IndexError: list assignment index out of range 错误解决

文章目录 Python IndexError:列表分配索引超出范围修复 Python 中的 IndexError: list assignment index out of range修复 IndexError: list assignment index out of range 使用 append() 函数修复 IndexError: list assignment index out of range 使用 insert()…

怎么把文本翻译成英文?安利三个文本翻译方法

在当今全球化的时代,跨国交流和合作已经成为常态。然而,不同语言之间的沟通障碍经常阻碍着信息传递和理解。为了帮助我们更好地进行国际交流,文本翻译英文软件应运而生。这类软件能够将各种语言的文本迅速准确地翻译成英文,使我们…

【起飞】让你电脑速度快到飞起的一些牛逼的设置整理【电脑卡顿反应慢等问题解决】

对于开发来说电脑的反应速度简直影响了思维的速度,要让电脑速度跟上我们的思维,提高工作效率,早点打卡下班回家陪老婆孩子哈哈 这篇文章主要对windows系统做的一些优化,是真的好用,仿佛在访问静态页面一样,…

超实用!年薪40W的项目经理都在用的6个项目管理软件

项目管理软件是帮助团队进行项目计划、任务分配、进度跟踪和团队协作等方面的工具,已经成为了项目经理必不可少的工具之一。 市面上的项目管理软件有很多,这就来分享一下几款我认为好用的项目管理软件! 一、六款好用的项目管理软件 1.简道…

C++开发工具 VTK技术实现三维重建CT医学影像PACS系统

一、信息管理 1、支持对患者、检查项目、申请医生、申请单据、设备等信息进行管理; 2、支持检查病人排队管理功能; 3、支持大屏幕队列显示和语音呼叫; 4、提供预约调整、插队管理和掉队处理等功能; 5、支持急诊申请优先安排。…

美股股指期货重要吗?要注意哪些风险?

美股股指期货是一种以美股股票价格指数作为标的物的金融期货合约。美股股指期货的表现对全球股指市场都有重要的影响力,具体体现在以下方面。 美股股指期货成为全球股指市场风向标 自20世纪70年代起,美国芝加哥商品交易所(CME)推…