机器学习笔记:注意力机制中多头注意力的实现

news2025/1/9 12:52:31

目录

介绍

模型

代码实现 

引入库

单个注意力头

多个注意力头的实现

测试

思考


介绍

在注意力机制中,单个注意力学到的东西有限,可以通过对不同的注意力进行组合,学到不同的知识,以达到想要的目的。因此采用”多头注意力“的方法进行实现,即有多个注意力”头“,对其进行连结得到输出。

模型

首先,对于我们输入的查询,以及每一个键值对,都有需要学习的一系列权重参数W,另外,注意力汇聚函数f也需要学习得到。 多头注意力的输出需要经过另一个线性转换, 它对应着h个头连结后的结果,因此这里也有一个参数需要进行学习。基于这种设计,每个头都可能会关注输入的不同部分, 可以表示比简单加权平均值更复杂的函数。

代码实现 

引入库

首先引入深度学习相关的库。

import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()

单个注意力头

这里,我们使用缩放点积注意力,先来对每一个注意力头进行实现。这里首先说明一点,即可以设定p_q=p_k=p_v=\frac{p_o}{h}。如果将查询、键和值的线性变换的输出数量设置为p_qh=p_kh=p_vh=p_o,则可以并行计算h个头。

[注,原文如此,但是我其实完全没有明白它这里在说什么,我不知道为什么这样设置。]

详解

这里定义一个多头注意力类,定义其头的数量,并定义隐藏层数量,以实现缩放点击注意力。在前向计算时,注意queries.shape=(batchSize,queryNum,numHiddens),key.shape=values.shape=(batchSize,k-vNum,numHiddens)。经过变换后,输出的queries,keys,values 的形状:  (batch_size*num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)

class MultiHeadAttention(nn.Block):
    """多头注意力"""
    def __init__(self, num_hiddens, num_heads, dropout, use_bias=False,
                 **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)

    def forward(self, queries, keys, values, valid_lens):
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            valid_lens = valid_lens.repeat(self.num_heads, axis=0)
            output = self.attention(queries, keys, values, valid_lens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)

多个注意力头的实现

为了能够使多个头并行计算, 上面的MultiHeadAttention类将使用下面定义的两个转置函数。 具体来说,transpose_output函数反转了transpose_qkv函数的操作。输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens) 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)

def transpose_qkv(X, num_heads):
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    X = X.transpose(0, 2, 1, 3)

    return X.reshape(-1, X.shape[2], X.shape[3])

def transpose_output(X, num_heads):
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.transpose(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

测试

下面使用键和值相同的小例子来测试我们编写的MultiHeadAttention类。 多头注意力输出的形状是(batch_sizenum_queriesnum_hiddens)。

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()

batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
Y = np.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape

输出结果和我们的想法是一样的:

(2, 4, 100) 

思考

  1. 假设有一个完成训练的基于多头注意力的模型,现在希望修剪最不重要的注意力头以提高预测速度。如何设计实验来衡量注意力头的重要性呢?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2045885.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

windows 安装 Mysql

一、安装Mysql 下载完成后直接双击进行安装 安装一路默认 如下图所示,在MySQL Servers/MySQL Server/MySQL Server 5.7的下方找到MySQL Server 5.7.41 - X64,然后选中它,点击两框之间的第一个箭头,将其移到右边的框中 点击Exe…

接口基础知识8_详解response header(响应头)

课程大纲 一、定义 HTTP响应头(HTTP Response Header):在HTTP协议中用于描述服务器响应的元数据。 它是服务器在响应客户端请求时,发送给客户端的一部分响应信息,包含了服务器的相关配置和响应内容的描述。 二、常见…

[机器学习]--KNN算法(K邻近算法)

KNN (K-Nearest Neihbor,KNN)K近邻是机器学习算法中理论最简单,最好理解的算法,是一个 非常适合入门的算法,拥有如下特性: 思想极度简单,应用数学知识少(近乎为零),对于很多不擅长数学的小伙伴十分友好虽然算法简单,但效果也不错 KNN算法原理 上图是每一个点都是一个肿瘤病例…

【C++深度探索】unordered_set、unordered_map封装

🔥 个人主页:大耳朵土土垚 🔥 所属专栏:C从入门至进阶 这里将会不定期更新有关C/C的内容,欢迎大家点赞,收藏,评论🥳🥳🎉🎉🎉 文章目录…

CSS继承、盒子模型、float浮动、定位、diaplay

一、CSS继承 1.文字相关的样式会被子元素继承。 2.布局样式相关的不会被子元素继承。(用inherit可以强行继承) 实现效果: 二、盒子模型 每个标签都有一个盒子模型,有内容区、内边距、边框、外边距。 从内到外:cont…

基于 Android studio 实现停车场管理系统--原创

目录 一、项目演示 二、开发环境 三、项目页面 四、项目详情 五、项目完整源码 一、项目演示 二、开发环境 三、项目详情 1.启动页 这段代码是一个简单的Android应用程序启动活动(Activity),具体功能如下: 1. **延迟进入登…

计算机网络三级笔记--原创 风远 恒风远博

典型设备中间设备数据单元网络协议物理层中继器、集线器中继器、集线器数据位(bit) binary digit二进 制数据的缩写HUB使用了光纤、 同轴电缆、双绞 线.数据链路层网卡、网桥、交换机网桥、交换机数据帧(Frame)STP、ARQ、 SW、CSMA/CD、 PPP(点对点)、 HDLC、ATM网络层路由器、…

SQL注入(cookie、base64、dnslog外带、搜索型注入)

目录 COOKIE注入 BASE64注入 DNSLOG注入—注入判断 什么是泛解析? UNC路径 网上邻居 LOAD_FILE函数 搜索型注入—注入判断 本文所使用的sql注入靶场为sqli-labs-master,靶场资源文件已上传,如有需要请前往主页或以下链接下载 信安必备…

【漫谈C语言和嵌入式002】嵌入式中的大小端

在计算机科学中,"端序"(Endianness)是指多字节数据类型(如整数或浮点数)在内存中的存储方式。主要分为两种:大端模式(Big-Endian)和小端模式(Little-Endian&am…

星戈瑞FITC-DXMS荧光素标记地塞米松不同方向的应用

FITC-DXMS,全称异硫氰基荧光素-地塞米松,是一种创新的科研试剂。他是由FITC-NH2的(-NH2)氨基与地塞米松的-OH(羟基)结合。它结合了地塞米松的特性和荧光素的高灵敏度标记技术,为医药研究、生物医…

栈与括号匹配——20、636、591、32(简中难难)

20. 有效的括号(简单) 给定一个只包括 (,),{,},[,] 的字符串 s ,判断字符串是否有效。 有效字符串需满足: 左括号必须用相同类型的右括号闭合。左括号必须以正确的顺序闭…

springboot的学习(二):常用配置

简介 springboot的各种常用的配置。 springboot 项目是要打成jar包放到服务器上运行的。 打包 idea上使用maven打包的时候,会执行自动测试,可能会对数据库中的数据有影响,先点跳过测试,在点package。 运行 Windows上运行的…

新闻资讯小程序的设计

管理员账户功能包括:系统首页,个人中心,新闻类别管理,新闻信息管理,用户管理,管理员管理,系统管理 微信端账号功能包括:系统首页,新闻信息,我的 开发系统&a…

极市平台 | 如何通俗理解扩散模型?

本文来源公众号“极市平台”,仅用于学术分享,侵权删,干货满满。 原文链接:如何通俗理解扩散模型? 极市导读 还有谁没有看过diffusion的工作,席卷AI圈的diffusion到底是什么?本文作者用尽量通…

tcpdump快速入门及实践手册

tcpdump快速入门及实践手册 1. 快速入门 [1]. 基本用法 基本用法: tcpdump [选项 参数] [过滤器 参数] [rootkysrv1 pwe]# tcpdump -h tcpdump version 4.9.3 libpcap version 1.9.1 (with TPACKET_V3) OpenSSL 1.1.1f 31 Mar 2020 Usage: tcpdump [-aAbdDefhH…

Python爬虫使用实例

IDE:大部分是在PyCharm上面写的 解释器装的多 → 环境错乱 → error:没有配置,no model 爬虫可以做什么? 下载数据【文本/二进制数据(视频、音频、图片)】、自动化脚本【自动抢票、答题、采数据、评论、点…

3.2 实体-关系模型(ER模型)

欢迎来到我的博客,很高兴能够在这里和您见面!欢迎订阅相关专栏: 工💗重💗hao💗:野老杂谈 ⭐️ 全网最全IT互联网公司面试宝典:收集整理全网各大IT互联网公司技术、项目、HR面试真题.…

Keycloak中授权的实现-转载

在Keycloak中实现授权,首先需要了解与授权相关的一些概念。授权,简单地说就是某个(些)用户或者某个(些)用户组(Policy),是否具有对某个资源(Resource&#xf…

基于SpringBoot的餐饮订单系统-计算机毕业设计源码39867

摘 要 随着现代生活节奏的加快和人们对便捷餐饮服务的需求不断增长,基于Spring Boot的餐饮订单系统的设计与实现成为当前研究的关键课题。本研究旨在开发一款包括首页、通知公告、餐饮资讯、餐饮菜单、商城管理等功能模块的系统,旨在提供便捷高效的餐饮订…

了解一下内测系统

内测系统是什么? 在软件或应用程序开发的过程中,供开发人员进行测试和调试的系统。 内测系统的作用是什么? 达到让用户使用游戏或者软件的时候体验感更好、减少风险、方便开发者更好的找到并解决自己软件中的问题。测试好后的app可以将自己的…