【深度学习实验】注意力机制(四):点积注意力与缩放点积注意力之比较

news2025/1/8 0:28:16

文章目录

  • 一、实验介绍
  • 二、实验环境
    • 1. 配置虚拟环境
    • 2. 库版本介绍
  • 三、实验内容
    • 0. 理论介绍
      • a. 认知神经学中的注意力
      • b. 注意力机制
    • 1. 注意力权重矩阵可视化(矩阵热图)
    • 2. 掩码Softmax 操作
    • 3. 打分函数——加性注意力模型
    • 3. 打分函数——点积注意力与缩放点积注意力
      • a. 缩放点积注意力模型
      • b. 点积注意力模型
      • c. 模拟实验
      • d. 模型比较与选择
      • e. 代码整合

一、实验介绍

  注意力机制作为一种模拟人脑信息处理的关键工具,在深度学习领域中得到了广泛应用。本系列实验旨在通过理论分析和代码演示,深入了解注意力机制的原理、类型及其在模型中的实际应用。

本文将介绍将介绍带有掩码的 softmax 操作

二、实验环境

  本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 理论介绍

a. 认知神经学中的注意力

  人脑每个时刻接收的外界输入信息非常多,包括来源于视
觉、听觉、触觉的各种各样的信息。单就视觉来说,眼睛每秒钟都会发送千万比特的信息给视觉神经系统。人脑通过注意力来解决信息超载问题,注意力分为两种主要类型:

  • 聚焦式注意力(Focus Attention):
    • 这是一种自上而下的有意识的注意力,通常与任务相关。
    • 在这种情况下,个体有目的地选择关注某些信息,而忽略其他信息。
    • 在深度学习中,注意力机制可以使模型有选择地聚焦于输入的特定部分,以便更有效地进行任务,例如机器翻译、文本摘要等。
  • 基于显著性的注意力(Saliency-Based Attention)
    • 这是一种自下而上的无意识的注意力,通常由外界刺激驱动而不需要主动干预。
    • 在这种情况下,注意力被自动吸引到与周围环境不同的刺激信息上。
    • 在深度学习中,这种注意力机制可以用于识别图像中的显著物体或文本中的重要关键词。

  在深度学习领域,注意力机制已被广泛应用,尤其是在自然语言处理任务中,如机器翻译、文本摘要、问答系统等。通过引入注意力机制,模型可以更灵活地处理不同位置的信息,提高对长序列的处理能力,并在处理输入时动态调整关注的重点。

b. 注意力机制

  1. 注意力机制(Attention Mechanism):

    • 作为资源分配方案,注意力机制允许有限的计算资源集中处理更重要的信息,以应对信息超载的问题。
    • 在神经网络中,它可以被看作一种机制,通过选择性地聚焦于输入中的某些部分,提高了神经网络的效率。
  2. 基于显著性的注意力机制的近似: 在神经网络模型中,最大汇聚(Max Pooling)和门控(Gating)机制可以被近似地看作是自下而上的基于显著性的注意力机制,这些机制允许网络自动关注输入中与周围环境不同的信息。

  3. 聚焦式注意力的应用: 自上而下的聚焦式注意力是一种有效的信息选择方式。在任务中,只选择与任务相关的信息,而忽略不相关的部分。例如,在阅读理解任务中,只有与问题相关的文章片段被选择用于后续的处理,减轻了神经网络的计算负担。

  4. 注意力的计算过程:注意力机制的计算分为两步。首先,在所有输入信息上计算注意力分布,然后根据这个分布计算输入信息的加权平均。这个计算依赖于一个查询向量(Query Vector),通过一个打分函数来计算每个输入向量和查询向量之间的相关性。

    • 注意力分布(Attention Distribution):注意力分布表示在给定查询向量和输入信息的情况下,选择每个输入向量的概率分布。Softmax 函数被用于将分数转化为概率分布,其中每个分数由一个打分函数计算得到。

    • 打分函数(Scoring Function):打分函数衡量查询向量与输入向量之间的相关性。文中介绍了几种常用的打分函数,包括加性模型、点积模型、缩放点积模型和双线性模型。这些模型通过可学习的参数来调整注意力的计算。

      • 加性模型 s ( x , q ) = v T tanh ⁡ ( W x + U q ) \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{v}^T \tanh(\mathbf{W}\mathbf{x} + \mathbf{U}\mathbf{q}) s(x,q)=vTtanh(Wx+Uq)

      • 点积模型 s ( x , q ) = x T q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{q} s(x,q)=xTq

      • 缩放点积模型 s ( x , q ) = x T q D \mathbf{s}(\mathbf{x}, \mathbf{q}) = \frac{\mathbf{x}^T \mathbf{q}}{\sqrt{D}} s(x,q)=D xTq (缩小方差,增大softmax梯度)

      • 双线性模型 s ( x , q ) = x T W q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{W} \mathbf{q} s(x,q)=xTWq (非对称性)

  5. 软性注意力机制

    • 定义:软性注意力机制通过一个“软性”的信息选择机制对输入信息进行汇总,允许模型以概率形式对输入的不同部分进行关注,而不是强制性地选择一个部分。

    • 加权平均:软性注意力机制中的加权平均表示在给定任务相关的查询向量时,每个输入向量受关注的程度,通过注意力分布实现。

    • Softmax 操作:注意力分布通常通过 Softmax 操作计算,确保它们成为一个概率分布。

1. 注意力权重矩阵可视化(矩阵热图)

【深度学习实验】注意力机制(一):注意力权重矩阵可视化(矩阵热图heatmap)

在这里插入图片描述

2. 掩码Softmax 操作

【深度学习实验】注意力机制(二):掩码Softmax 操作
在这里插入图片描述

3. 打分函数——加性注意力模型

  • 加性模型 s ( x , q ) = v T tanh ⁡ ( W x + U q ) \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{v}^T \tanh(\mathbf{W}\mathbf{x} + \mathbf{U}\mathbf{q}) s(x,q)=vTtanh(Wx+Uq)

【深度学习实验】注意力机制(三):打分函数——加性注意力模型

3. 打分函数——点积注意力与缩放点积注意力

  • 点积模型 s ( x , q ) = x T q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{q} s(x,q)=xTq

  • 缩放点积模型 s ( x , q ) = x T q D \mathbf{s}(\mathbf{x}, \mathbf{q}) = \frac{\mathbf{x}^T \mathbf{q}}{\sqrt{D}} s(x,q)=D xTq (缩小方差,增大softmax梯度)

a. 缩放点积注意力模型

class DotProductAttention(nn.Module):
    """缩放点积注意力"""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        # 使用暂退法进行模型正则化
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        self.scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(self.scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

  1. 初始化方法 (__init__):

    • 参数:
      • dropout: Dropout 正则化的概率。
    • 说明: 初始化方法定义了模型的组件,仅包含了一个 Dropout 正则化层。
  2. 前向传播方法 (forward):

    • 参数:
      • queries: 查询张量,形状为 (batch_size, num_queries, d)
      • keys: 键张量,形状为 (batch_size, num_kv_pairs, d)
      • values: 值张量,形状为 (batch_size, num_kv_pairs, value_size)
      • valid_lens: 有效长度张量,形状为 (batch_size,)(batch_size, num_queries)
    • 返回值: 加权平均后的值张量,形状为 (batch_size, num_queries, value_size)
  3. 实现细节:

    • 计算缩放点积得分:通过张量乘法计算 querieskeys 的点积,然后除以 d \sqrt{d} d 进行缩放,其中 d d d 是查询或键的维度。
    • 使用 masked_softmax 函数计算注意力权重,根据有效长度对注意力进行掩码。
    • 将注意力权重应用到值上,得到最终的加权平均结果。
    • 使用 Dropout 对注意力权重进行正则化。

b. 点积注意力模型

class DotProductAttention2(nn.Module):
    """点积注意力"""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention2, self).__init__(**kwargs)
        # 使用暂退法进行模型正则化
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        # P195:(8.3),(8.4)
        # 在计算得分时不进行缩放操作(即不再除以sqrt(d))
        self.scores = torch.bmm(queries, keys.transpose(1, 2))
        self.attention_weights = masked_softmax(self.scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

  • 区别:
    • 计算点积得分:通过张量乘法计算 querieskeys 的点积。

c. 模拟实验

  • 模型数据
queries, keys = torch.normal(0, 1, (2, 1, 2)), torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
    2, 1, 1)
valid_lens = torch.tensor([2, 6])
  • 模型应用
# 创建缩放点积注意力模型
attention = DotProductAttention(0.5)
attention.eval()
# 使用模型进行前向传播
attention(queries, keys, values, valid_lens)

# 创建点积注意力模型
attention2 = DotProductAttention2(0.5)
attention2.eval()
# 使用模型进行前向传播
attention2(queries, keys, values, valid_lens)

在这里插入图片描述

  • 权重可视化
      为了直观地展示模型在不同输入下的注意力分布,使用前文 show_heatmaps 函数,通过热图的形式展示注意力权重。
# 可视化缩放点积注意力权重
show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
              xlabel='Keys', ylabel='Queries')

# 可视化点积注意力权重
show_heatmaps(attention2.attention_weights.reshape((1, 1, 2, 10)),
              xlabel='Keys', ylabel='Queries')

在这里插入图片描述

d. 模型比较与选择

  • 缩放点积注意力模型

    • 适用于处理高维度的查询和键。
    • 通过缩放操作有助于防止点积得分的方差过大。
  • 点积注意力模型

    • 适用于处理相对较低维度的查询和键。
    • 更方便地利用矩阵乘积提高计算效率。
      在这里插入图片描述

e. 代码整合

# 导入必要的库
import math
import torch
from torch import nn
import torch.nn.functional as F
from d2l import torch as d2l
from torch.utils import data


def masked_softmax(X, valid_lens):
    """通过在最后一个轴上掩蔽元素来执行softmax操作"""
    # X:3D张量,valid_lens:1D或2D张量
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)


def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5), cmap='Reds'):
    """显示矩阵热图"""
    d2l.use_svg_display()
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6)

class DotProductAttention(nn.Module):
    """缩放点积注意力"""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        # 使用暂退法进行模型正则化
        self.dropout = nn.Dropout(dropout)

    # queries的形状:(batch_size,查询的个数,d)
    # keys的形状:(batch_size,“键-值”对的个数,d)
    # values的形状:(batch_size,“键-值”对的个数,值的维度)
    # valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)
    def forward(self, queries, keys, values, valid_lens=None):
        print(queries)
        d = queries.shape[-1]
        print(d)
        self.scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        print(self.scores)
        self.attention_weights = masked_softmax(self.scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)


class DotProductAttention2(nn.Module):
    """点积注意力"""

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention2, self).__init__(**kwargs)
        # 使用暂退法进行模型正则化
        self.dropout = nn.Dropout(dropout)

    # queries的形状:(batch_size,查询的个数,d)
    # keys的形状:(batch_size,“键-值”对的个数,d)
    # values的形状:(batch_size,“键-值”对的个数,值的维度)
    # valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)

    def forward(self, queries, keys, values, valid_lens=None):
        # P195:(8.3),(8.4)
        # 在计算得分时不进行缩放操作(即不再除以sqrt(d))
        self.scores = torch.bmm(queries, keys.transpose(1, 2))
        print(self.scores)
        self.attention_weights = masked_softmax(self.scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)


queries, keys = torch.normal(0, 1, (2, 1, 2)), torch.ones((2, 10, 2))
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
    2, 1, 1)
valid_lens = torch.tensor([2, 6])


# 缩放点积注意力模型
attention = DotProductAttention(0.5)
attention.eval()
attention(queries, keys, values, valid_lens)

# 点积注意力模型
attention2 = DotProductAttention2(0.5)
attention2.eval()
attention2(queries, keys, values, valid_lens)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1234638.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ky10 server aarch64 离线安装openssl3.1.4

离线程序 https://gitcode.net/zengliguang/ky10_aarch64_openssl_install.git 输入下面命令执行离线安装脚本 source openssl_offline_install.sh 安装完成

仪表盘:pyecharts绘制

一、仪表盘 在数据分析中,仪表盘图(dashboard)的作用是以一种简洁、图表化的方式呈现数据的关键指标和核心信息,以帮助用户快速理解数据的情况,并从中提取关键见解。 仪表盘图通常由多个图表、指标和指示器组成&…

南京--ChatGPT/GPT4 科研实践应用

2023年随着OpenAI开发者大会的召开,最重磅更新当属GPTs,多模态API,未来自定义专属的GPT。微软创始人比尔盖茨称ChatGPT的出现有着重大历史意义,不亚于互联网和个人电脑的问世。360创始人周鸿祎认为未来各行各业如果不能搭上这班车…

杨氏矩阵解法

每日一言 「 人生如逆旅,我亦是行人。 」--临江仙送钱穆父-苏轼题目 杨氏矩阵 有一个数字矩阵,矩阵的每行从左到右是递增的,矩阵从上到下是递增的,请编写程序在这样的矩阵中查找某个数字是否存在。 解法思路 法一:…

Vue3 provide 和 inject 实现祖组件和后代组件通信

provide 和 inject 能够实现祖组件和其任意的后代组件之间通信: 一、provide 提供数据 我们在祖组件中使用provide 将数据提供出去。 使用provide 之前需要先进行引入: import { provide } from "vue"; 语法格式如下: provide(&q…

米哈游大数据云原生实践

云布道师 近年来,容器、微服务、Kubernetes 等各项云原生技术的日渐成熟,越来越多的公司开始选择拥抱云原生,并将企业应用部署运行在云原生之上。随着米哈游业务的高速发展,大数据离线数据存储量和计算任务量增长迅速&#xff0c…

SqlServer_idea连接问题

问题描述: sqlServer安装之后可以使用navicat进行连接idea使用账户密码进行登录连接失败 问题解决: 先使用sqlServer管理工具进行登录 使用window认证连接修改账户密码 启用该登录名 这时idea还是无法连接,还需要如下配置 打开sqlserve…

论文阅读:“基于特征检测与深度特征描述的点云粗对齐算法”

文章目录 摘要简介相关工作粗对齐传统的粗对齐算法基于深度学习的粗对齐算法 特征检测及描述符构建 本文算法ISS 特征检测RANSAC 算法3DMatch 算法 实验结果参考文献 摘要 点云对齐是点云数据处理的重要步骤之一,粗对齐则是其中的难点。近年来,基于深度…

java项目之社区互助平台(ssm+vue)

项目简介 社区互助平台实现了以下功能: 1、一般用户的功能及权限 所谓一般用户就是指还没有注册的过客,他们可以浏览主页面上的信息。但如果有中意的社区互助信息时,要登录注册,只有注册成功才有的权限。2、管理员的功能及权限 用户信息的添…

第十一章 目标检测中的NMS

精度提升 众所周知,非极大值抑制NMS是目标检测常用的后处理算法,用于剔除冗余检测框,本文将对可以提升精度的各种NMS方法及其变体进行阶段性总结。 总体概要: 对NMS进行分类,大致可分为以下六种,这里是依…

为什么 Django 后台管理系统那么“丑”?

哈喽大家好,我是咸鱼 相信使用过 Django 的小伙伴都知道 Django 有一个默认的后台管理系统——Django Admin 它的 UI 很多年都没有发生过变化,现在看来显得有些“过时且简陋” 那为什么 Django 的维护者却不去优化一下呢?原文作者去询问了多…

股票指标信息(六)

6-指标信息 文章目录 6-指标信息一. 展示股票的K线图数据,用于数据统计二. 展示股票指标数据,使用Java处理,集合形式展示三. 展示股票目前的最新的指标数据信息四. 展示股票指标数据,某一个属性使用Java处理五. 展示股票的指标数据,用于 Echarts 页面数据统计六. 展示股票指标数…

做好性能测试计划的4个步骤!全都是精华!【建议收藏】

如何做好一次性能测试计划呢?对于性能测试新手来说,也许你非常熟悉Jmeter的使用,也许你清楚的了解每一个系统参数代表的意义,但是想要完成好一次性能测试任务,并不仅仅是简单的写脚本,加压力,再…

RedisConnectionFactory is required已解决!!!!

1.起因🤶🤶🤶🤶 redis搭建完成后,准备启动主程序,异常兴奋,结果报错了!!!! 2.究竟是何原因 😭😭😭&#x1f…

【OpenCV实现图像:在Python中使用OpenCV进行直线检测】

文章目录 概要霍夫变换举个栗子执行边缘检测进行霍夫变换小结 概要 图像处理作为计算机视觉领域的重要分支,广泛应用于图像识别、模式识别以及计算机视觉任务中。在图像处理的众多算法中,直线检测是一项关键而常见的任务。该任务的核心目标是从图像中提…

动态神经网络时间序列预测

大家好,我是带我去滑雪! 神经网络投照是否存在反锁与记忆可以分为静态神经网络与动态神经网络。动态神经网络是指神经网络带有反做与记忆功能,无论是局部反馈还是全局反锁。通过反馈与记忆,神经网络能将前一时刻的数据保留&#x…

高压开关柜无线测温系统

高压开关柜无线测温系统是一种用于监测高压开关柜内部温度的系统。依托电易云-智慧电力物联网,它采用无线通信技术,实现对开关柜内部温度的实时监测和数据传输。下面我将为您介绍高压开关柜无线测温系统的组成、原理、功能以及优势。 一、系统组成 高压开…

springboot项目基于jdk17、分布式事务seata-server-1.7.1、分库分表shardingSphere5.2.1开发过程中出现的问题

由于项目需要,springboot项目需基于jdk17环境开发,结合nacos2.0.3、分布式事务seata-server-1.7.1、分库分表shardingSphere5.2.1等,项目启动过程中出现的问题解决方式小结。 问题一: Caused by: java.lang.RuntimeException: j…

svn文件不显示红色感叹号

如下图所示,受svn版本控制的文件不显示下图中红色感叹号和绿色对号时, 可以试着如下操作 空白处单击右键,具体操作如下图

AIoT智能物联网平台技术架构参考

具体来说,AIoT平台能够实现智能终端设备之间、不同系统平台之间、不同应用场景之间的互融互通,进一步推动万物互联的进程。 AIoT智能物联网平台是结合了人工智能(AI)和物联网(IoT)技术的平台。它旨在通过物…