从零实现深度学习框架——注意力机制

news2024/9/21 4:31:06

引言

本着“凡我不能创造的,我就不能理解”的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导。
💡系列文章完整目录: 👉点此👈
要深入理解深度学习,从零开始创建的经验非常重要,从自己可以理解的角度出发,尽量不适用外部框架的前提下,实现我们想要的模型。本系列文章的宗旨就是通过这样的过程,让大家切实掌握深度学习底层实现,而不是仅做一个调包侠。

本文我们来了解注意力机制。在前面几篇文章讨论的seq2seq模型,非常依赖上下文向量。而上下文向量只是最后一个时间步的隐藏状态。它可能会成为一个瓶颈,因为输入序列的长度可能是任意的,我们期望一个固定大小的隐藏向量保存任意长度的信息是不可能的。

注意力机制就为了解决这个瓶颈,它允许解码器可以从编码器所有的隐藏状态中获取信息,而不仅仅是最后一个隐藏状态。

注意力机制

注意力机制是论文NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE提出来的,建议去看原文,也可以看笔者尝试的翻译 :[论文翻译]NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE。

image-20230704153143022

传统的seq2seq模型,编码器读取了整个输入序列后,会为每个时间步都生成一个隐藏状态,但只是取最后一个隐藏状态作为固定大小的上下文向量传递给解码器。

而在注意力机制中,上下文向量通过一个函数生成,该函数读编码器所有时刻的隐藏状态:
c = f ( h 1 e , ⋯   , h n e ) (1) \pmb c= f(\pmb h_1^e,\cdots,\pmb h_n^e) \tag 1 c=f(h1e,,hne)(1)
这里用 c \pmb c c表示这个上下文向量。因为输入序列的长度 n n n是不固定的,所以我们一般还是只能用一个向量来表示。

那为什么叫注意力呢?因为就像我们人类注意一副图片一样。

0_1

比如说上面这幅图片,博主第一眼注意到的是里面的人。随着眼球的移动又看到了电线杆啥的。

对于注意力机制来说,也是类似的,虽然编码器输出了 n n n个隐藏状态,但我们只会注意到少数部分,就像我们每次都看到图片的某个部分一样。怎么强调注意到少数部分呢?实际上通过权重参数来实现的。就是给这几个少数部分增加较大的权重,其他剩下的部分很小的权重。

对于解码器正要生成的token来说,通过为编码器生成的每个隐藏状态赋予不同的权重来仅注意此时感兴趣的相关的(编码器的)输入,然后编码器器的每个隐藏状态都会有一个权重,对它们计算加权和就得到了我们要的上下文向量 c \pmb c c:
c i = ∑ j = 1 n α i j h j e (2) \pmb c_i = \sum_{j=1}^n \alpha_{ij} \pmb h_j^e \tag 2 ci=j=1nαijhje(2)
由于是与解码器的位置相关的, c i \pmb c_i ci表示在解码步骤 i i i重新生成的,从公式可以看出,它考虑了所有编码器的隐藏状态。然后,在解码过程中计算当前时间步解码器隐藏状态时就可以使用此上下文:
h i d = g ( y ^ i − 1 , h i − 1 d , c i ) (3) \pmb h_i^d = g(\hat y_{i-1}, \pmb h_{i-1}^d, \pmb c_i) \tag{3} hid=g(y^i1,hi1d,ci)(3)

这个公式和我们上篇文章中看到的公式很像:
h i d = g ( y ^ i − 1 , h i − 1 d , c ) \pmb h_i^d = g(\hat y_{i-1}, \pmb h_{i-1}^d , \pmb c) hid=g(y^i1,hi1d,c)
不过在公式 ( 3 ) (3) (3)中的上下向量 c i \pmb c_i ci不再是静态的。

从公式 ( 3 ) (3) (3)也可以看出重要的一点,即先后问题,先有 c i \pmb c_i ci后才有的 h i d \pmb h_i^d hid,说明无法通过当前位置解码器的隐藏状态来计算 c i \pmb c_i ci,而是通过上一位置的隐藏状态 h i − 1 d \pmb h_{i-1}^d hi1d计算的!

只要能通过某种机制生成注意力权重,那么在解码器不同的时间步上可以注意到编码器不同的输入,得到的上下文向量也是不同的。相当于此时的上下文向量不再是固定的,而是解码器动态生成的感兴趣的。

现在的问题是,这个注意力权重要如何生成。

注意力权重也可以理解为相关性,说到相关性大家应该能想到余弦相似度,对于两个向量 a \pmb a a b \pmb b b来说,它们的余弦相似度就是:
similarity = a ⋅ b ∣ ∣ a ∣ ∣   ∣ ∣ b ∣ ∣ \text{similarity} = \frac{\pmb a \cdot \pmb b}{||\pmb a|| \,||\pmb b||} similarity=∣∣a∣∣∣∣b∣∣ab
分母是为了归一化,得到的是一个标量,那么不难想到一种最简单的方法就是计算这两个向量的点乘,这种方法称为点乘注意力(dot-product attention):
score ( h i − 1 d , h j e ) = h i − 1 d ⋅ h j e (4) \text{score}(\pmb h_{i-1}^d,\pmb h^e_j) = \pmb h^d_{i-1}\cdot \pmb h^e_j \tag{4} score(hi1d,hje)=hi1dhje(4)
这里再次强调一下,上标 d d d表示Decoder,解码器;上标 e e e表示Encoder,编码器。

所以是用解码器第 i − 1 i-1 i1个位置的隐藏状态与编码器第 j j j个位置的隐藏状态做点积得到的标量(得分)作为注意力得分,这个得分的取值范围应该是实数空间。

编码器共有 n n n个输入,我们就可以通过 h i − 1 d \pmb h_{i-1}^d hi1d得到 n n n个注意力得分,即 n n n个标量,这 n n n个标量如何变成权重呢?很简单,通过Softmax即可:
α i j = softmax ( score ( h i − 1 d , h j e ) ) = exp ⁡ ( score ( h i − 1 d , h j e ) ) ∑ k = 1 n exp ⁡ ( score ( h i − 1 d , h k e ) ) (5) \begin{aligned} \alpha_{ij} &= \text{softmax}(\text{score}(\pmb h^d_{i-1},\pmb h^e_j)) \\ &= \frac{\exp(\text{score}(\pmb h^d_{i-1},\pmb h^e_j))}{\sum_{k=1}^n \exp(\text{score}(\pmb h^d_{i-1},\pmb h^e_k))} \end{aligned} \tag 5 αij=softmax(score(hi1d,hje))=k=1nexp(score(hi1d,hke))exp(score(hi1d,hje))(5)
这样得到的 α \alpha α就是0到1之间的一个权重。

再利用公式 ( 2 ) (2) (2)就可以得到当前时间步 i i i需要的上下文向量 c i \pmb c_i ci

最后用一张图片来总结注意力计算的过程:

image-20230705113431666

  • 在计算 h i d h_i^d hid时先用 h i − 1 d h_{i-1}^d hi1d与编码器所有的隐状态计算点积,然后计算Softmax得到权重 α \alpha α
  • 用权重 α \alpha α和编码器所有的隐状态计算加权和,得到上下文向量 c i c_i ci
  • c i , h i − 1 d c_i,h_{i-1}^d ci,hi1d和上一步的输出 y i y_i yi计算当前时间步的隐状态 h i d h_i^d hid
  • 利用 h i d h_i^d hid计算当前时间步的输出 y i + 1 y_{i+1} yi+1

以上就是注意力机制的内容,是不是也挺简单的。

我们上面介绍的是点乘注意力,其实还有很多其他的注意力。

常见注意力方式

为了表示方便,我们用 q \pmb q q k \pmb k k分别表示解码器的隐藏状态和编码器的隐藏状态向量。

q i \pmb q_i qi是某个时刻 i i i的隐藏状态向量,它的形状是batch_size, decoder_hidden_size

k \pmb k k是编码器最顶层输出的所有隐藏状态向量,它的形状是src_len, batch_size, encoder_hidden_size

点积注意力

点积注意力,我们上面介绍的。它要求编码器的
score ( q i , k ) = q i ⋅ k (6) \text{score}(\pmb q_i,\pmb k) = \pmb q_i \cdot \pmb k \tag{6} score(qi,k)=qik(6)
优点是计算效率高,只需要计算点积。

缩放点积注意力

score ( q i , k ) = q i ⋅ k d k (7) \text{score}(\pmb q_i,\pmb k) = \frac{\pmb q_i\cdot \pmb k}{\sqrt{d_k}} \tag{7} score(qi,k)=dk qik(7)

但是点积运算有很大的方差,导致Softmax函数的梯度较小。因此缩放点积注意力除以一项来平滑分数的值,来缓解这个问题。

General注意力

score ( q i , k ) = q i T W k (8) \text{score}(\pmb {q_i},\pmb k) = \pmb q_i^T W\pmb k \tag{8} score(qi,k)=qiTWk(8)

和点积注意相比,引入了一个权重 W W W,使得编码器和解码器的大小可以不一致,在计算相似度时引入了非对称性。

上式可以改成为 q i T W k = q i T ( U T V ) k = ( U q i ) T ( V k ) \pmb q_i^T W\pmb k = \pmb q_i^T (U^TV)\pmb k = (U \pmb q_i )^T (V\pmb k) qiTWk=qiT(UTV)k=(Uqi)T(Vk),即分别对 q \pmb q q k \pmb k k进行线性变换之后再计算点积。

加性注意力

score ( q i , k ) = v T tanh ⁡ ( W q i + U k ) (9) \text{score}(\pmb {q_i},\pmb k) = \pmb v^T \tanh(W\pmb q_i + U\pmb k) \tag{9} score(qi,k)=vTtanh(Wqi+Uk)(9)

加性注意力引入了可学习的参数,将 q \pmb q q k \pmb k k映射到不同的空间后进行打分。

各种注意力做的事情其实都是为了在生成输出词时,考虑每个输入词和当前输出词的对齐关系,对齐越好的词,会有越大的权重,对生成当前输出词的影响也越大。

代码实现

class Attention(nn.Module):
    def __init__(self, enc_hid_dim=None, dec_hid_dim=None, method: str = "dot") -> None:
        """

        Args:
            enc_hid_dim: 编码器的隐藏层大小
            dec_hid_dim: 解码器的隐藏层大小
            method:  dot | scaled_dot | general | bahdanau | concat

        Returns:

        """
        super().__init__()
        self.method = method
        self.encoder_hidden_size = enc_hid_dim
        self.decoder_hidden_size = dec_hid_dim

        if self.method not in ["dot", "scaled_dot", "general", "bahdanau", "concat"]:
            raise ValueError(self.method, "is not an appropriate attention method.")

        if self.method == "general":
            self.linear = nn.Linear(self.encoder_hidden_size, self.decoder_hidden_size, bias=False)
        elif self.method == "bahdanau":
            self.W = nn.Linear(self.decoder_hidden_size, self.decoder_hidden_size, bias=False)
            self.U = nn.Linear(self.encoder_hidden_size, self.decoder_hidden_size, bias=False)
            self.v = nn.Linear(self.decoder_hidden_size, 1, bias=False)
        elif self.method == "concat":
            # concat
            self.linear = nn.Linear((self.encoder_hidden_size + self.decoder_hidden_size), self.decoder_hidden_size,
                                    bias=False)
            self.v = nn.Linear(self.decoder_hidden_size, 1, bias=False)

    def _score(self, hidden: Tensor, encoder_outputs: Tensor) -> Tensor:
        """

        Args:
            hidden: (batch_size, decoder_hidden_size)  解码器前一时刻的隐藏状态
            encoder_outputs: (src_len, batch_size, encoder_hidden_size) 编码器的输出(隐藏状态)序列

        Returns:

        """

        src_len, batch_size, encoder_hidden_size = encoder_outputs.shape

        if self.method == "dot":
            # 这里假设编码器和解码器的隐藏层大小一致,如果不一致,不能直接使用点积注意力
            # (batch_size, hidden_size) * (src_len, batch_size, hidden_size) -> (src_len, batch_size, hidden_size)
            # (src_len, batch_size, hidden_size).sum(axis=2) -> (src_len, batch_size)
            return (hidden * encoder_outputs).sum(axis=2)
        elif self.method == "scaled_dot":
            # 和点积注意力类似,不过除以sqrt(batch_size),也可以指定一个自定义的值
            return (hidden * encoder_outputs / math.sqrt(encoder_hidden_size)).sum(axis=2)
        elif self.method == "general":
            # energy = (src_len, batch_size, decoder_hidden_size)
            energy = self.linear(encoder_outputs)
            # (batch_size, decoder_hidden_size)  * (src_len, batch_size, decoder_hidden_size)
            # -> (src_len, batch_size, decoder_hidden_size)
            # .sum(axis=2) -> (src_len, batch_size)
            return (hidden * energy).sum(axis=2)
        elif self.method == "bahdanau":
            # hidden = (batch_size, decoder_hidden_size)
            # encoder_outputs = (src_len, batch_size, encoder_hidden_size)
            # W(hidden) -> (batch_size, decoder_hidden_size)
            # U(encoder_outputs) -> (src_len, batch_size, decoder_hidden_size)
            # energy = (src_len, batch_size, decoder_hidden_size)
            energy = F.tanh(self.W(hidden) + self.U(encoder_outputs))
            # v(energy) ->  (src_len, batch_size, 1)
            # squeeze -> (src_len, batch_size)
            return self.v(energy).squeeze(2)
        else:
            # concat
            # unsqueeze -> (1, batch_size, decoder_hidden_size)
            # hidden = (src_len, batch_size, decoder_hidden_size)
            hidden = hidden.unsqueeze(0).repeat(src_len, 1, 1)
            #  encoder_outputs = (src_len, batch_size, encoder_hidden_size)
            # cat -> (src_len, batch_size,  encoder_hidden_size +decoder_hidden_size)
            # linear -> (src_len, batch_size, decoder_hidden_size)
            energy = F.tanh(self.linear(F.cat((hidden, encoder_outputs), axis=2)))
            # v -> (src_len, batch_size , 1)
            # squeeze -> (src_len, batch_size)
            return self.v(energy).squeeze(2)

    def forward(self, hidden: Tensor, encoder_outputs: Tensor) -> Tuple[Tensor, Tensor]:
        """

        Args:
            hidden: (batch_size, decoder_hidden_size)  解码器前一时刻的隐藏状态
            encoder_outputs: (src_len, batch_size, encoder_hidden_size) 编码器的输出(隐藏状态)序列

        Returns: 注意力权重, 上下文向量

        """
        # (src_len, batch_size)
        attn_scores = self._score(hidden, encoder_outputs)
        # (batch_size, src_len)
        attn_scores = attn_scores.T
        # (batch_size, 1, src_len)
        attention_weight = F.softmax(attn_scores, axis=1).unsqueeze(1)
        # encoder_outputs = (batch_size, src_len, num_hiddens)
        encoder_outputs = encoder_outputs.transpose((1, 0, 2))
        # context = (batch_size, 1, num_hiddens)
        context = F.bmm(attention_weight, encoder_outputs)
        # context = (1, batch_size, num_hiddens)
        context = context.transpose((1, 0, 2))

        return attention_weight, context

上面是各种注意力机制的代码实现,其中点积注意力要求编码器和解码器隐藏层大小一致。

参考

  1. Seq2Seq中常见注意力机制的实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/729542.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

网络数据包的监听与分析——IP数据报文分析

1. 抓包工具下载 x下面是一个IP数据报的抓包软件——IPtool的蓝奏云下载链接 https://wwix.lanzoue.com/iaGpy11klpnc 2. iptool使用 下载解压之后,右击以管理员身份运行,打开该exe文件即可 然后点击绿色运行就开始捕包了 随便点一个包进去进行分析就可…

Java文档

API(Application Programming Interface,应用程序编程接口)是Java提供的基本编程接口(java提供的类还有相关方法)。中文在线文档:Java 8 中文版 - 在线API手册 - 码工具 (matools.com) Java语言提供了大量…

Java项目 仿天猫商城系统(springboot+mybatis+mysql+maven+jsp)

基于springbootmybatismysqlmavenjsp仿天猫商城系统 一、系统介绍1、系统主要功能:2.涉及技术框架: 二、功能展示三、其它系统四、获取源码 一、系统介绍 1、系统主要功能: 项目主要参考天猫商城的购物流程:用户从注册开始&…

浅谈线段树

1.前言 Oi-Wiki上的线段树 同步于 c n b l o g s cnblogs cnblogs发布。 如有错误,欢迎各位 dalao 们指出。 注:本篇文章个人见解较多,如有不适,请谅解。 前置芝士 1.二叉树的顺序储存 2.线段树是什么? 线段树…

“量贩零食”热潮袭来:真风口还是假繁荣?

以前只听过量贩式KTV,现在“量贩零食店”也出现在了大街小巷。 高考结束后,家住武汉的花花频繁逛起了量贩零食店。这类店把各种零食集合在一起销售,用低价来换取高销量,主打一个性价比。店里的散装零食即便按斤售卖,也…

蛋白组学 差异蛋白分析 富集分析 go kegg

生信学习day1-蛋白组分析 蛋白质组差异分析的三个R包 - 知乎 (zhihu.com)https://zhuanlan.zhihu.com/p/448479536Introduction to DEP (bioconductor.org)http://www.bioconductor.org/packages/release/bioc/vignettes/DEP/inst/doc/DEP.html#lfq-based-dep-analysis浅谈蛋白…

阿里云coluder认证训练营开班!

在这个充满机遇和挑战的时代里,云计算已经成为推动企业创新和发展的关键技术。而作为云计算人才培训领域的领军企业,摩尔狮致力于培养更多优秀的云技术人才, 所以摩尔狮联合阿里云为大家打造了免费的云计算入门课程——Clouder认证集训营&…

全方位了解VR全景展示与制作

引言: 虚拟现实(VR)技术正在以惊人的速度改变我们的生活方式和体验方式。其中,VR全景展示与制作作为虚拟现实的重要应用之一,为用户提供了身临其境的视听体验。 一、了解VR全景展示与制作 1.VR全景展示 VR全景展示是…

JDK,JRE,JVM的区别

1.JVM JVM,也叫java虚拟机,用来运行字节码文件,可将字节码翻译为机器码,JVM是实现java跨平台的关键,可以让相同的java代码在不同的操作系统上运行出相同的结果。 2.JRE JRE,也叫java运行时环境&#xff…

【JS】javascript学习笔记

step by step. 目录 严格区分大小写 点击事件: JavaScript关键字/语句标识符 数据类型 对象Object 创建对象方法 事件 循环 标签 正则表达式 异常 未定义adddlert-> throw—— 调试工具debugger ​编辑 JS严格模式 表单 严格区分大小写 点击事件&am…

python3 学习笔记

一、注释 1.单行注释:# 开头 2.多行注释: 和 """ 二、缩进 python是使用缩进来表示代码块,不需要使用大括号{} python具有严格的缩进原则,每个缩进一般可以有两个或四个空格组成,也可以是任意数量的…

深度学习常用优化器总结

一、优化器的定义 优化器(optimizer)本质上是一种算法,用于优化深度学习模型的参数,通过不断更新模型的参数来最小化模型损失。在选择优化器时,需要考虑模型的结构、模型的数据量、模型的目标函数等因素。 二、常用…

web前端总结(一)HTML标签

1.语法结构&#xff1a; <标签 属性 “值”>内容</标签> <p align "center">标签内容</P> 2.标签 1.标题标签&#xff1a; **标题标签 <h1> - <h6>&#xff08;重要&#xff09;** 为了使网页更具有语义化&#xff0c;我们…

stm32_<一文通>_cubemx_freertos

文章目录 前言一、任务调度1.1 延时1.1.1 相对延时1.1.2 绝对延时 1.2 挂起和恢复1.2.1 cmsis的挂起和恢复函数1.2.2 freertos的挂起和恢复函数 1.3 删除1.3.1 cmsis的删除任务函数1.3.2 freertos的删除任务函数 二、Freertos任务与中断三、消息队列3.1 写入和读取一个数据3.2 …

6阶高清视频滤波驱动MS1681

MS1681 是一个单通道视频缓冲器&#xff0c;它内部集成6dB 增益的轨到轨输出驱动器和6 阶输出重建滤波器。MS1681 的-3dB 带宽为35MHz&#xff0c;压摆率为160V/us。MS1681 比无源LC 滤波器与外加驱动的解决方案能提供更好的图像质量。它单电源供电范围为2.5V 到5.5V&#xff0…

什么是提示词工程师?

前言 你可能听说过人工智能模型&#xff0c;但你是否知道&#xff0c;背后的神奇之处源自于那些执着于提示设计和优化的专业人员&#xff1f;提示词工程师是引导我们与机器对话的幕后英雄&#xff0c;他们通过精心构造的提示&#xff0c;让模型理解我们的意图、解答问题&#…

React + TypeScript 实践

主要内容包括准备知识、如何引入 React、函数式组件的声明方式、Hooks、useRef<T>、useEffect、useMemo<T> / useCallback<T>、自定义 Hooks、默认属性 defaultProps、Types or Interfaces、获取未导出的 Type、Props、常用 Props ts 类型、常用 React 属性类…

zabbix基础4——自定义监控案例

文章目录 一、监控进程二、监控日志三、监控mysql主从四、监控mysql延迟 一、监控进程 示例&#xff1a;监控客户端上的httpd服务进程&#xff0c;当进程书少于1时&#xff0c;说明服务已经挂掉&#xff0c;需要及时处理。 1.客户端开启自定义监控功能。 vim /usr/local/etc/…

YApi-高效、易用、功能强大的可视化接口管理平台——(一)使用 Docker 本地部署

Docker 本地部署 YApi 安装 Docker安装设置 USTC 镜像启动 Docker Docker 安装 MongoDBDocker 安装 YApi登录 YApi 本内容以虚拟机【系统&#xff1a;Centos7】为例&#xff0c;云服务器步骤相同。使用Docker 的方式搭建 YApi&#xff0c;拉取 MongoDB 镜像和 YApi 镜像即可。 …

SpringBoot学习——追根溯源servlet是啥,tomcat是啥,maven是啥 springBoot项目初步,maven构建,打包 测试

目录 引出追根溯源&#xff0c;过渡衔接servlet是啥&#xff1f;tomcat是啥&#xff1f; 前后端开发的模式1.开发模式&#xff1a;JavaWeb&#xff1a;MVC模型2.Web&#xff1a;Vue&#xff0c;MVVC模型3.后端相关3.1 同步与异步3.2 Controller层3.3 Service层&#xff1a;要加…