【深度学习实验】注意力机制(三):打分函数——加性注意力模型

news2024/11/16 18:04:22

文章目录

  • 一、实验介绍
  • 二、实验环境
    • 1. 配置虚拟环境
    • 2. 库版本介绍
  • 三、实验内容
    • 0. 理论介绍
      • a. 认知神经学中的注意力
      • b. 注意力机制
    • 1. 注意力权重矩阵可视化(矩阵热图)
    • 2. 掩码Softmax 操作
    • 3. 打分函数——加性注意力模型
      • 1. 初始化
      • 2. 前向传播
      • 3. 内部组件
      • 4. 实现细节
      • 5. 模拟实验
      • 6. 代码整合

一、实验介绍

  注意力机制作为一种模拟人脑信息处理的关键工具,在深度学习领域中得到了广泛应用。本系列实验旨在通过理论分析和代码演示,深入了解注意力机制的原理、类型及其在模型中的实际应用。

本文将介绍打分函数——加性注意力模型

二、实验环境

  本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 理论介绍

a. 认知神经学中的注意力

  人脑每个时刻接收的外界输入信息非常多,包括来源于视
觉、听觉、触觉的各种各样的信息。单就视觉来说,眼睛每秒钟都会发送千万比特的信息给视觉神经系统。人脑通过注意力来解决信息超载问题,注意力分为两种主要类型:

  • 聚焦式注意力(Focus Attention):
    • 这是一种自上而下的有意识的注意力,通常与任务相关。
    • 在这种情况下,个体有目的地选择关注某些信息,而忽略其他信息。
    • 在深度学习中,注意力机制可以使模型有选择地聚焦于输入的特定部分,以便更有效地进行任务,例如机器翻译、文本摘要等。
  • 基于显著性的注意力(Saliency-Based Attention)
    • 这是一种自下而上的无意识的注意力,通常由外界刺激驱动而不需要主动干预。
    • 在这种情况下,注意力被自动吸引到与周围环境不同的刺激信息上。
    • 在深度学习中,这种注意力机制可以用于识别图像中的显著物体或文本中的重要关键词。

  在深度学习领域,注意力机制已被广泛应用,尤其是在自然语言处理任务中,如机器翻译、文本摘要、问答系统等。通过引入注意力机制,模型可以更灵活地处理不同位置的信息,提高对长序列的处理能力,并在处理输入时动态调整关注的重点。

b. 注意力机制

  1. 注意力机制(Attention Mechanism):

    • 作为资源分配方案,注意力机制允许有限的计算资源集中处理更重要的信息,以应对信息超载的问题。
    • 在神经网络中,它可以被看作一种机制,通过选择性地聚焦于输入中的某些部分,提高了神经网络的效率。
  2. 基于显著性的注意力机制的近似: 在神经网络模型中,最大汇聚(Max Pooling)和门控(Gating)机制可以被近似地看作是自下而上的基于显著性的注意力机制,这些机制允许网络自动关注输入中与周围环境不同的信息。

  3. 聚焦式注意力的应用: 自上而下的聚焦式注意力是一种有效的信息选择方式。在任务中,只选择与任务相关的信息,而忽略不相关的部分。例如,在阅读理解任务中,只有与问题相关的文章片段被选择用于后续的处理,减轻了神经网络的计算负担。

  4. 注意力的计算过程:注意力机制的计算分为两步。首先,在所有输入信息上计算注意力分布,然后根据这个分布计算输入信息的加权平均。这个计算依赖于一个查询向量(Query Vector),通过一个打分函数来计算每个输入向量和查询向量之间的相关性。

    • 注意力分布(Attention Distribution):注意力分布表示在给定查询向量和输入信息的情况下,选择每个输入向量的概率分布。Softmax 函数被用于将分数转化为概率分布,其中每个分数由一个打分函数计算得到。

    • 打分函数(Scoring Function):打分函数衡量查询向量与输入向量之间的相关性。文中介绍了几种常用的打分函数,包括加性模型、点积模型、缩放点积模型和双线性模型。这些模型通过可学习的参数来调整注意力的计算。

      • 加性模型 s ( x , q ) = v T tanh ⁡ ( W x + U q ) \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{v}^T \tanh(\mathbf{W}\mathbf{x} + \mathbf{U}\mathbf{q}) s(x,q)=vTtanh(Wx+Uq)

      • 点积模型 s ( x , q ) = x T q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{q} s(x,q)=xTq

      • 缩放点积模型 s ( x , q ) = x T q D \mathbf{s}(\mathbf{x}, \mathbf{q}) = \frac{\mathbf{x}^T \mathbf{q}}{\sqrt{D}} s(x,q)=D xTq (缩小方差,增大softmax梯度)

      • 双线性模型 s ( x , q ) = x T W q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{W} \mathbf{q} s(x,q)=xTWq (非对称性)

  5. 软性注意力机制

    • 定义:软性注意力机制通过一个“软性”的信息选择机制对输入信息进行汇总,允许模型以概率形式对输入的不同部分进行关注,而不是强制性地选择一个部分。

    • 加权平均:软性注意力机制中的加权平均表示在给定任务相关的查询向量时,每个输入向量受关注的程度,通过注意力分布实现。

    • Softmax 操作:注意力分布通常通过 Softmax 操作计算,确保它们成为一个概率分布。

1. 注意力权重矩阵可视化(矩阵热图)

【深度学习实验】注意力机制(一):注意力权重矩阵可视化(矩阵热图heatmap)

在这里插入图片描述

2. 掩码Softmax 操作

【深度学习实验】注意力机制(二):掩码Softmax 操作
在这里插入图片描述
PS:记录~一天两上热搜榜

3. 打分函数——加性注意力模型

s ( x , q ) = v T tanh ⁡ ( W x + U q ) \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{v}^T \tanh(\mathbf{W}\mathbf{x} + \mathbf{U}\mathbf{q}) s(x,q)=vTtanh(Wx+Uq)

class AdditiveAttention(nn.Module):
    """加性注意力"""

    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # queries的形状:(batch_size,查询的个数,1,num_hidden)
        # key的形状:(batch_size,1,“键-值”对的个数,num_hiddens)
        # 使用广播方式进行求和
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # self.w_v仅有一个输出,因此从形状中移除最后那个维度。
        # scores的形状:(batch_size,查询的个数,“键-值”对的个数)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # values的形状:(batch_size,“键-值”对的个数,值的维度)
        return torch.bmm(self.dropout(self.attention_weights), values)

AdditiveAttention 类实现了加性注意力机制。下面对该类的主要组件和功能进行详细介绍:

1. 初始化

  • 参数:
    • key_size: 键的维度。
    • query_size: 查询的维度。
    • num_hiddens: 隐藏层的维度。
    • dropout: Dropout 正则化的概率。
  • 说明: 定义了模型的各个组件,包括线性变换层和 Dropout 层。

2. 前向传播

  • 参数:
    • queries: 查询张量,形状为 (batch_size, num_queries, query_size).
    • keys: 键张量,形状为 (batch_size, num_kv_pairs, key_size).
    • values: 值张量,形状为 (batch_size, num_kv_pairs, value_size).
    • valid_lens: 有效长度张量,形状为 (batch_size,).
  • 返回值: 加权平均后的值张量,形状为 (batch_size, num_queries, value_size)

3. 内部组件

  • W_q: 将查询进行线性变换的层。
  • W_k: 将键进行线性变换的层。
  • w_v: 将特征进行线性变换的层,该层输出的形状为 (batch_size, num_queries, num_kv_pairs, 1),并通过 squeeze(-1) 操作得到注意力分数。
  • dropout: Dropout 正则化层。

4. 实现细节

  • 使用线性变换将查询和键映射到相同的隐藏维度。
  • 通过广播方式计算注意力得分。
  • 使用 tanh 激活函数将得分映射到 (-1, 1) 范围。
  • 计算注意力分数,并通过 masked_softmax 函数计算注意力权重。
  • 对注意力权重进行 Dropout 正则化。
  • 将注意力权重应用到值上,得到最终的加权平均结果。

  此加性注意力实现的目标是通过学习注意力权重,根据输入的查询和键对值进行加权平均,用于处理序列数据中的关联信息。

5. 模拟实验

  模拟输入数据并使用上述AdditiveAttention模型进行前向传播。

# 创建模拟的输入数据
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
  • queries: 一个形状为 (2, 1, 20) 的张量,表示两个查询。这个张量的第一个维度是批量大小,第二个维度是查询的个数,第三个维度是查询的特征维度。
    在这里插入图片描述

  • keys: 一个形状为 (2, 10, 2) 的张量,表示两个样本,每个样本包含10个键值对。每个键值对的维度为2。
    在这里插入图片描述

# 创建模拟的values数据
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
  • values: 一个形状为 (2, 10, 4) 的张量,表示两个样本,每个样本包含10个值。每个值的维度为4。
    在这里插入图片描述
# 创建模拟的有效长度数据
valid_lens = torch.tensor([2, 6])
  • valid_lens: 一个形状为 (2,) 的张量,表示两个样本的有效长度。在注意力机制中,这用于指定每个查询关注的键值对数量。
# 创建 AdditiveAttention 模型
attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8, dropout=0.1)
attention.eval()
  • 创建了一个 AdditiveAttention 的实例,指定了键的维度 key_size、查询的维度 query_size、隐藏层的维度 num_hiddens 和 Dropout 的概率 dropout
# 使用模型进行前向传播
attention(queries, keys, values, valid_lens)
  • 调用 attention 模型的前向传播方法,传入查询 queries、键 keys、值 values 和有效长度 valid_lens
  • 模型返回加权平均后的值,形状为 (2, 1, 4),表示两个查询的输出。

在这里插入图片描述

  • 绘制矩阵热图:
show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
              xlabel='Keys', ylabel='Queries')

在这里插入图片描述

6. 代码整合

# 导入必要的库
import torch
from torch import nn
from d2l import torch as d2l


def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5), cmap='Reds'):
    """显示矩阵热图"""
    d2l.use_svg_display()
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6)


def masked_softmax(X, valid_lens):
    """通过在最后一个轴上掩蔽元素来执行softmax操作"""
    # X:3D张量,valid_lens:1D或2D张量
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)


class AdditiveAttention(nn.Module):
    """加性注意力"""

    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # queries的形状:(batch_size,查询的个数,1,num_hidden)
        # key的形状:(batch_size,1,“键-值”对的个数,num_hiddens)
        # 使用广播方式进行求和
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # self.w_v仅有一个输出,因此从形状中移除最后那个维度。
        # scores的形状:(batch_size,查询的个数,“键-值”对的个数)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # values的形状:(batch_size,“键-值”对的个数,值的维度)
        return torch.bmm(self.dropout(self.attention_weights), values)


queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))

values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
valid_lens = torch.tensor([2, 6])

attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8, dropout=0.1)
attention.eval()
attention(queries, keys, values, valid_lens)
weights = attention.attention_weights
show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
              xlabel='Keys', ylabel='Queries')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1234331.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SQL零基础入门教程,贼拉详细!贼拉简单! 速通数据库期末考!(十一)

COUNT() 计数函数 COUNT() 函数返回匹配指定条件的行数。 语法: 1.返回指定列的字段值条数 SELECT COUNT(column_name) FROM table_name;2.返回整表数据行条数 SELECT COUNT(*) FROM table_name;3.返回指定列去重后的字段值条数 SELECT COUNT(DISTINCT column_…

猫罐头牌子哪个好一点?精选5款口碑好的猫罐头推荐!

猫罐头牌子哪个好一点?选择猫罐头是十分重要的事情,千万不能将就。因为,好的猫罐头不仅可以营养丰富,水分充足,适口性好,还能易吸收。而一旦选择错误,不仅无法达到上述效果,还可能产…

服务号迁移到订阅号流程步骤

服务号和订阅号有什么区别?服务号转为订阅号有哪些作用?首先我们要知道服务号和订阅号有什么区别。服务号侧重于对用户进行服务,每月可推送4次,每次最多8篇文章,发送的消息直接显示在好友列表中。订阅号更侧重于信息传…

Stable Diffusion专场公开课

从SD原理、本地部署到其二次开发 分享时间:11月25日14:00-17:00 分享大纲 从扩散模型DDPM起步理解SD背后原理 SD的本地部署:在自己电脑上快速搭建、快速出图如何基于SD快速做二次开发(以七月的AIGC模特生成系统为例) 分享人简介 July&#…

通过AppLink把拼多多热门榜单商品同步至小红书

上篇说到AppLink当中定时调度方式如何配置,这次来演示一下,如何把热门榜单信息同步至小红书 1.拉取一个定时器作为触发动作,通过配置定时器调度时间将定时策略配置为每天执行一次 2.触发动作完成后通过好单库获取拼多多每日热门榜单&#xf…

循环链表3

插入函数——插入数据,在链表plsit的pos位置插入val数据元素 位置pos(在无特别说明的情况下)是从0开始计数的 要改变链表结构,就要依赖前驱,每个前驱的next存储着下一个数据结点的地址,也就是依靠前驱的ne…

笔尖笔帽检测3:Android实现笔尖笔帽检测算法(含源码 可是实时检测)

目录 1. 前言 2.笔尖笔帽检测方法 (1)Top-Down(自上而下)方法 (2)Bottom-Up(自下而上)方法: 3.笔尖笔帽关键点检测模型训练 4.笔尖笔帽关键点检测模型Android部署 (1) 将Pytorch模型转换ONNX模型 (2) 将ONNX模…

leetcode刷题日记:205. Isomorphic Strings(同构字符串)

205. Isomorphic Strings(同构字符串) 对于同构字符串来说也就是对于字符串s与字符串t,对于 s [ i ] s[i] s[i]可以映射到 t [ i ] t[i] t[i],同时对于任意 s [ k ] s [ i ] s[k]s[i] s[k]s[i]都有 s [ k ] s[k] s[k]映射到 t [ k ] t[k] t[k],则 t [ k ] t [ i …

海外IP代理科普——API代理是什么?怎么用?

随着互联网的不断发展,越来越多的企业开始使用API(应用程序接口)来实现数据的共享和交流。而在API使用中,海外代理IP也逐渐普及。那么,什么是API代理IP呢?它有什么作用?API接口有何用处&#xf…

Spring-IOC-@Value和@PropertySource

1、Book.java PropertySource(value"classpath:配置文件地址") 替代 <context:property-placeholder location"配置文件地址"/> Value("${book.bid}") Value("${book.bname}") Value("${book.price}") <bean id&…

SpringCloud原理-OpenFeign篇(三、FeignClient的动态代理原理)

文章目录 前言正文一、前戏&#xff0c;FeignClientFactoryBean入口方法的分析1.1 从BeanFactory入手1.2 AbstractBeanFactory#doGetBean(...)中对FactoryBean的处理1.3 结论 FactoryBean#getObject() 二、FeignClientFactoryBean实现的getObject()2.1 FeignClientFactoryBean#…

蓝桥杯物联网_STM32L071_1_CubMxkeil5基础配置

CubMx配置&#xff1a; project工程中添加.h和.c文件&#xff1a; keil5配置: 运行&#xff1a; 代码提示与解决中文乱码&#xff1a;

结合两个Python小游戏,带你复习while循环、if判断、函数等知识点

&#x1f490;作者&#xff1a;insist-- &#x1f490;个人主页&#xff1a;insist-- 的个人主页 理想主义的花&#xff0c;最终会盛开在浪漫主义的土壤里&#xff0c;我们的热情永远不会熄灭&#xff0c;在现实平凡中&#xff0c;我们终将上岸&#xff0c;阳光万里 ❤️欢迎点…

Google Chrome 任意文件读取 (CVE-2023-4357)漏洞

漏洞描述 该漏洞的存在是由于 Google Chrome 中用户提供的 XML 输入验证不足。远程攻击者可以创建特制网页&#xff0c;诱骗受害者访问该网页并获取用户系统上的敏感信息。远程攻击者可利用该漏洞通过构建的 HTML 页面绕过文件访问限制&#xff0c;导致chrome任意文件读取。Li…

Redis篇---第九篇

系列文章目录 文章目录 系列文章目录前言一、如果有大量的 key 需要设置同一时间过期,一般需要注意什么?二、什么情况下可能会导致 Redis 阻塞?三、缓存和数据库谁先更新呢?前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击…

90天,广告商单43张,小红书AI庭院风视频制作详解教程

今天给大家分享一个目前在小红书很火的AI绘画商单号案例。 首先给大家看看案例视频形态 这类视频内容非常简单&#xff0c;主要展示农家庭院的别致景色。通过AI绘画工具生成图片&#xff0c;再利用剪辑工具将画面增加动态元素&#xff0c;让整个视频逼真鲜活&#xff0c;加上…

Selenium4+python被单独定义<div>的动态输入框和二级下拉框要怎么定位?

今天在做练习题的时候,发现几个问题捣鼓了好久,写下这篇来记录 问题一: 有层级的复选框无法定位到二级目录 对于这种拥有二级框的选项无法定位,也不是<select>属性. 我们查看下HTML,发现它是被单独封装在body内拥有动态属性的独立<div>,当窗口点击的时候才会触发…

【算法设计实验三】动态规划解决01背包问题

请勿原模原样复制&#xff01; 01背包dp具体解释详见链接 ↓ 【算法5.1】背包问题 - 01背包 &#xff08;至多最大价值、至少最小价值&#xff09;_背包问题求最小价值_Roye_ack的博客-CSDN博客 关于如何求出最优物品选择方案&#xff1f; 先在递归求dp公式时&#xff0c;若…

【广州华锐互动】VR溺水预防教育:在虚拟世界中学会自救!

在现代社会中&#xff0c;水上安全和救援行动的重要性不言而喻。尤其在自然灾害、游泳事故或航海事故中&#xff0c;有效的救援行动可以挽救许多生命。然而&#xff0c;传统的救援训练往往存在成本高、风险大、效率低等问题。在这样的背景下&#xff0c;虚拟现实&#xff08;VR…

vue3的两个提示[Vue warn]: 关于组件渲染和函数外部使用

1. [Vue warn]: inject() can only be used inside setup() or functional components. 这个消息是提示我们&#xff0c;需要将引入的方法作为一个变量使用。以vue-store为例&#xff0c;如果我们按照如下的方式使用&#xff1a; import UseUserStore from ../../store/module…