《动手学深度学习 Pytorch版》 10.5 多头注意力

news2024/9/20 10:49:19

多头注意力(multihead attention):用独立学习得到的 h 组不同的线性投影(linear projections)来变换查询、键和值,然后并行地送到注意力汇聚中。最后,将这 h 个注意力汇聚的输出拼接在一起,并且通过另一个可以学习的线性投影进行变换,以产生最终输出。

对于 h 个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)。

在这里插入图片描述

10.5.1 模型

用数学语言描述多头注意力:

h i = f ( W i ( q ) q , W i ( k ) k , W i ( v ) v ) ∈ R p \boldsymbol{h}_i=f(\boldsymbol{W}_i^{(q)}\boldsymbol{q},\boldsymbol{W}_i^{(k)}\boldsymbol{k},\boldsymbol{W}_i^{(v)}\boldsymbol{v})\in\R^p hi=f(Wi(q)q,Wi(k)k,Wi(v)v)Rp

参数字典:

  • f f f 表示注意力汇聚函数

  • q ∈ R d q \boldsymbol{q}\in\R^{d_q} qRdq k ∈ R d k \boldsymbol{k}\in\R^{d_k} kRdk v ∈ R d v \boldsymbol{v}\in\R^{d_v} vRdv 分别是查询、键和值

  • W i ( q ) ∈ R p d × d q \boldsymbol{W}_i^{(q)}\in\R^{p_d\times d_q} Wi(q)Rpd×dq W i ( k ) ∈ R p k × d k \boldsymbol{W}_i^{(k)}\in\R^{p_k\times d_k} Wi(k)Rpk×dk W i ( v ) ∈ R p v × d v \boldsymbol{W}_i^{(v)}\in\R^{p_v\times d_v} Wi(v)Rpv×dv 均为可学习参数

多头注意力的输出需要经过另一个线性转换:

y = [ h 1 ⋮ h h ] ∈ R p o y= \begin{bmatrix} \boldsymbol{h}_1\\ \vdots\\ \boldsymbol{h}_h \end{bmatrix} \in\R^{p_o} y= h1hh Rpo

import math
import torch
from torch import nn
from d2l import torch as d2l

10.5.2 实现

在实现过程中通常选择缩放点积注意力作为每一个注意力头。为了避免计算代价和参数代价的大幅增长,设定 p q = p k = p v = p o / h p_q=p_k=p_v=p_o/h pq=pk=pv=po/h。值得注意的是,如果将查询、键和值的线性变换的输出数量设置为 p q h = p k h = p v h = p o p_qh=p_kh=p_vh=p_o pqh=pkh=pvh=po,则可以并行计算 h 个头。在下面的实现中, p o p_o po 是通过参数 num_hiddens 指定的。

MultiHeadAttention 类将使用下面定义的两个转置函数,transpose_output 函数反转了 transpose_qkv 函数的操作。转来转去是为了避免 for 循环。

#@save
def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""
    # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
    # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)

    # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


#@save
def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
#@save
class MultiHeadAttention(nn.Module):
    """多头注意力"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # queries,keys,values的形状: (batch_size,查询或者“键-值”对的个数,num_hiddens)
        # valid_lens 的形状: (batch_size,) 或 (batch_size,查询的个数)
        # 经过变换后,输出的queries,keys,values 的形状: (batch_size*num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,然后如此复制第二项,然后诸如此类。
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # output的形状:(batch_size*num_heads,查询的个数,num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # output_concat的形状:(batch_size,查询的个数,num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()
MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
torch.Size([2, 4, 100])

练习

(1)分别可视化这个实验中的多个头的注意力权重。

d2l.show_heatmaps(attention.attention.attention_weights.reshape((2, 5, 4, 6)),
                  xlabel='Keys', ylabel='Queries', figsize=(5,5))


在这里插入图片描述

(2)假设有一个完成训练的基于多头注意力的模型,现在希望修剪最不重要的注意力头以提高预测速度。如何设计实验来衡量注意力头的重要性呢?

不会,略。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1135788.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

0024Java程序设计-毕业论文管理系统的设计与实现

文章目录 **目录**系统设计开发环境 随着高校的规模不断扩大,如何合理地利用教学资源、有效地加强教学管理工作,已成为各大高校关注的焦点。目前,在教学管理方面,特别是在学生学籍管理、成绩管理等方面都普遍采用了网络化管理手段,而对于课程设计这一重要的教学环节,则普遍采用…

暴跌5600亿!台积电没有想到,中国5G手机如此要命

自从国产5G手机上市以来,台积电的股价持续下跌,至今已下跌了蒸发770亿美元(约5600亿元),损失规模居亚洲之首,然而更可怕的是这款手机产生的影响正持续扩大,台积电可能面临生存危机。 一、台积电…

Maven项目用jetty在服务器部署与配置

Maven项目用jetty在服务器部署与配置 零.Jetty在服务器部署配置 0.1 修改jetty的默认端口 修改 $JETTY_HOME/etc/jetty.xml 文件, 将jetty.port的值改为指定自己需要的端口号即可, 默认为8080。 如下图 jetty 9 版本中,修改%JETTY_HOME%…

A股风格因子看板 (2023.10 第12期)

该因子看板跟踪A股风格因子,该因子主要解释沪深两市的市场收益、刻画市场风格趋势的系列风格因子,用以分析市场风格切换、组合风格暴露等。 今日为该因子跟踪第12期,指数组合数据截止日2023-09-30,要点如下 近1年A股风格因子检验统…

贝锐蒲公英推出二层组网功能,实现远程工业设备数据互通、扫描发现

工业物联是目前的发展趋势所在,包含人机互动、状态感知、设备监测、数据交互等应用场景,海量的设备需要实现互联网接入与管理能力。 但是,工业设备往往位于分散在各地的制造工厂或是户外,且不同地区通常使用了不同的网络运营商&am…

windows中毒

一.查看系统账户安全 1.查看服务器是否有弱口令、可疑账号、隐藏账号、克隆账号、远 程管理端口是否对公网开放 2.winr 查看他 二.检查异常端口 进程 查看端口 定位exe程序 3.另一种方法 d盾 火绒剑 xuetr 判断可疑进程 三.检查启动项 计划任务 服务 …

分享5款小而精的实用软件

分享是一种神奇的东西,它使快乐增大,它使悲伤减小。分享好用软件给大家的同时,我自己也能获得愉悦的心情。 1.鼠标点击特效——ClickShow ​ ClickShow是一款给鼠标点击加上特效的软件,可以让用户在点击鼠标时显示一层波纹特效,左键&#x…

ArcGIS中如何为跨带数据投影?

北京54、西安80高斯克吕格投影是我国常用的投影坐标系统,它们是一种分带投影方式,有3和6分带,不适合大范围内的投影使用。但是如果有份数据范围较大,跨越了多个度带,该选择哪个坐标系统进行投影转换呢? 在大范围内,常用的坐标系统有Albers等面积投影和Lambert等角投影,…

【C++面向对象】9. 重载

文章目录 【 1. 函数重载 】【 2. 运算符重载 】2.1 可重载运算符 / 不可重载运算符2.2 一元 运算符重载2.3 二元 运算符重载2.4 关系 运算符重载2.5 输入/输出 运算符重载2.6 和-- 运算符重载2.7 赋值 运算符重载2.8 函数调用() 运算符重载2.9 下标[ ] 运算符重载2.10 类成员访…

众和策略可靠吗?股权除息是好是坏?

可靠 股权除息,指的是公司在股息发放前,将公司股票分拆,以减少股东持有的股份,添加每股的股息金额。简略来说,就是将股份拆成更小的比例,每股股息也随之添加。 股权除息关于股东来说,好坏参半…

YB5302是一款工作于2.7V到6.5V的PFM升压型双节锂电池充电控制集成电路

YB5302 锂电输入升压型双节锂电池充电芯片 概述: YB5302是一款工作于2.7V到6.5V的PFM升压型双节锂电池充电控制集成电路。YB5302采用恒流和准恒压模式(Quasi-CVT™)对电池进行充电管理,内部集成有基准电压源,电感电流检测单元,电池电压检测电…

第二证券:企业债转常规后受理审核进入常态化运行阶段

第一批14单项目获受理 拟征集资金估计超550亿元 14单项目是企业债转常规后第一批受理的项目,标志企业债的受理、审理、发行等作业进入常态化运转阶段,企业债的审理透明度与功率将有用改善 沪深北证券生意所网站10月25日宣布的信息显现,14单…

CAN接口的PCB Layout规则要求汇总

随着时代高速发展,控制器局域网(CAN)接口的应用越来越广泛,尤其是在汽车电子、航空航天等领域中发挥着重要作用,为了确保CAN接口的可靠性和稳定性,工程师必须在其PCB Layout方面下功夫,下面来看…

酷开科技 | 酷开系统时时刻刻相伴你左右

作家张小娴曾说过一句话:陪伴,是最长情的告白。每个人都需要别人的陪伴,每个人也都要陪伴别人。无论是亲情、友情还是爱情,陪伴永远是这世间一切感情中最不可或缺的一部分。同样,酷开系统通过各种功能及大内容战略陪伴…

绩效考核有什么好处?除了考核员工外?

绩效考核的真正作用,根本不在”考核员工“!绩效考核的真正作用只有一个—— 辅助企业经营目标的达成。 只不过是因为企业想到达成这个经营目标,光靠老板是不可能的,必须靠”员工“,所以说考核员工只是手段&#xff0…

ITSS信息技术服务运行维护标准符合性证书申请详解及流程

ITSS信息技术服务运行维护标准符合性证书 认证介绍 ITSS(InformationTechnologyServiceStandards,信息技术服务标准,简称ITSS)是一套成体系和综合配套的信息技术服务标准库,全面规范了IT服务产品及其组成要素,用于指导实施标准化…

【文件加密软件】文字+视频超详细解析

文件加密软件是一种用于保护文件安全的工具,可以有效地防止未经授权的访问和数据泄露。随着信息化的不断发展,文件加密软件已成为企业和个人不可或缺的安全保障之一。 一、需求分析 文件加密软件的需求主要包括功能、性能和易用性等方面。 1、软件应具…

「实用技巧」后端如何使用 Eolink Apikit 快速调试接口?

程序员最讨厌的两件事: 写文档 别人不写文档 写文档、维护文档比较麻烦,而且费时,还会经常出现 API 更新了,但文档还是旧的,各种同步不一致的情况,从而耽搁彼此的时间,大多数开发人员不愿意写…

PDF编辑工具Acrobat Pro DC 2023中文

Acrobat Pro DC 2023是一款全面、高效的PDF编辑和管理软件。它提供了丰富的PDF编辑功能,如创建、编辑、合并、分割、压缩、旋转、裁剪等,让用户可以轻松处理各种PDF文档。同时,该软件还具有智能的PDF处理技术,可以自动识别和修复P…

智慧实验室系统云LIS全套源码,满足医院实验室、医院集团、独立实验室、临检中心及其它检验机构的专业化检验需求。

​电子化检验信息平台 智慧实验室系统云LIS全套源码 LIS系统是医院信息管理的重要组成部分之一,集申请、采样、核收、计费、检验、审核、发布、质控、查询、耗材控制等检验科工作为一体的网络管理系统。LIS系统不仅是自动接收检验数据,打印检验报告&…