【深度学习实验】注意力机制(二):掩码Softmax 操作

news2024/12/25 23:58:33

文章目录

  • 一、实验介绍
  • 二、实验环境
    • 1. 配置虚拟环境
    • 2. 库版本介绍
  • 三、实验内容
    • 0. 理论介绍
      • a. 认知神经学中的注意力
      • b. 注意力机制:
    • 1. 注意力权重矩阵可视化(矩阵热图)
    • 2. 掩码Softmax 操作
      • a. 导入必要的库
      • b. masked_softmax
      • c. 实验结果

一、实验介绍

  注意力机制作为一种模拟人脑信息处理的关键工具,在深度学习领域中得到了广泛应用。本系列实验旨在通过理论分析和代码演示,深入了解注意力机制的原理、类型及其在模型中的实际应用。

本文将介绍将介绍带有掩码的 softmax 操作

二、实验环境

  本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 理论介绍

a. 认知神经学中的注意力

  人脑每个时刻接收的外界输入信息非常多,包括来源于视
觉、听觉、触觉的各种各样的信息。单就视觉来说,眼睛每秒钟都会发送千万比特的信息给视觉神经系统。人脑通过注意力来解决信息超载问题,注意力分为两种主要类型:

  • 聚焦式注意力(Focus Attention):
    • 这是一种自上而下的有意识的注意力,通常与任务相关。
    • 在这种情况下,个体有目的地选择关注某些信息,而忽略其他信息。
    • 在深度学习中,注意力机制可以使模型有选择地聚焦于输入的特定部分,以便更有效地进行任务,例如机器翻译、文本摘要等。
  • 基于显著性的注意力(Saliency-Based Attention)
    • 这是一种自下而上的无意识的注意力,通常由外界刺激驱动而不需要主动干预。
    • 在这种情况下,注意力被自动吸引到与周围环境不同的刺激信息上。
    • 在深度学习中,这种注意力机制可以用于识别图像中的显著物体或文本中的重要关键词。

  在深度学习领域,注意力机制已被广泛应用,尤其是在自然语言处理任务中,如机器翻译、文本摘要、问答系统等。通过引入注意力机制,模型可以更灵活地处理不同位置的信息,提高对长序列的处理能力,并在处理输入时动态调整关注的重点。

b. 注意力机制:

  1. 注意力机制(Attention Mechanism):

    • 作为资源分配方案,注意力机制允许有限的计算资源集中处理更重要的信息,以应对信息超载的问题。
    • 在神经网络中,它可以被看作一种机制,通过选择性地聚焦于输入中的某些部分,提高了神经网络的效率。
  2. 基于显著性的注意力机制的近似: 在神经网络模型中,最大汇聚(Max Pooling)和门控(Gating)机制可以被近似地看作是自下而上的基于显著性的注意力机制,这些机制允许网络自动关注输入中与周围环境不同的信息。

  3. 聚焦式注意力的应用: 自上而下的聚焦式注意力是一种有效的信息选择方式。在任务中,只选择与任务相关的信息,而忽略不相关的部分。例如,在阅读理解任务中,只有与问题相关的文章片段被选择用于后续的处理,减轻了神经网络的计算负担。

  4. 注意力的计算过程:注意力机制的计算分为两步。首先,在所有输入信息上计算注意力分布,然后根据这个分布计算输入信息的加权平均。这个计算依赖于一个查询向量(Query Vector),通过一个打分函数来计算每个输入向量和查询向量之间的相关性。

    • 注意力分布(Attention Distribution):注意力分布表示在给定查询向量和输入信息的情况下,选择每个输入向量的概率分布。Softmax 函数被用于将分数转化为概率分布,其中每个分数由一个打分函数计算得到。

    • 打分函数(Scoring Function):打分函数衡量查询向量与输入向量之间的相关性。文中介绍了几种常用的打分函数,包括加性模型、点积模型、缩放点积模型和双线性模型。这些模型通过可学习的参数来调整注意力的计算。

      • 加性模型 s ( x , q ) = v T tanh ⁡ ( W x + U q ) \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{v}^T \tanh(\mathbf{W}\mathbf{x} + \mathbf{U}\mathbf{q}) s(x,q)=vTtanh(Wx+Uq)

      • 点积模型 s ( x , q ) = x T q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{q} s(x,q)=xTq

      • 缩放点积模型 s ( x , q ) = x T q D \mathbf{s}(\mathbf{x}, \mathbf{q}) = \frac{\mathbf{x}^T \mathbf{q}}{\sqrt{D}} s(x,q)=D xTq (缩小方差,增大softmax梯度)

      • 双线性模型 s ( x , q ) = x T W q \mathbf{s}(\mathbf{x}, \mathbf{q}) = \mathbf{x}^T \mathbf{W} \mathbf{q} s(x,q)=xTWq (非对称性)

  5. 软性注意力机制

    • 定义:软性注意力机制通过一个“软性”的信息选择机制对输入信息进行汇总,允许模型以概率形式对输入的不同部分进行关注,而不是强制性地选择一个部分。

    • 加权平均:软性注意力机制中的加权平均表示在给定任务相关的查询向量时,每个输入向量受关注的程度,通过注意力分布实现。

    • Softmax 操作:注意力分布通常通过 Softmax 操作计算,确保它们成为一个概率分布。

1. 注意力权重矩阵可视化(矩阵热图)

【深度学习实验】注意力机制(一):注意力权重矩阵可视化(矩阵热图heatmap)

2. 掩码Softmax 操作

  掩码Softmax操作的用处在于在处理序列数据时,对于某些位置的输入可能需要进行忽略或者特殊处理。通过使用掩码张量,可以将这些无效或特殊位置的权重设为负无穷大,从而在进行Softmax操作时,使得这些位置的输出为0。
  这种操作通常在序列模型中使用,例如自然语言处理中的文本分类任务。在文本分类任务中,输入是一个句子或一个段落,长度可能不一致。为了保持输入的统一性,需要进行填充操作,使得所有输入的长度相同。然而,在经过填充操作后,一些位置可能对应于填充字符,这些位置的权重应该被忽略。通过使用掩码Softmax操作,可以确保填充位置的输出为0,从而在计算损失函数时不会对填充位置产生影响。

a. 导入必要的库

import torch
from torch import nn
import torch.nn.functional as F
from d2l import torch as d2l

b. masked_softmax

  带有掩码的 softmax 操作主要用于处理变长序列,其中填充的元素不应该对 softmax 操作的结果产生影响。

def masked_softmax(X, valid_lens):
    """通过在最后一个轴上掩蔽元素来执行softmax操作"""
    # X:3D张量,valid_lens:1D或2D张量
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
        

参数解释

  • X: 一个三维张量,表示输入的 logits。

  • valid_lens: 一个一维或二维张量,表示每个序列的有效长度。如果是一维张量,它会被重复到匹配 X 的第二维。

函数流程

  1. 如果 valid_lensNone,则直接应用标准的 softmax 操作,返回 nn.functional.softmax(X, dim=-1)

  2. 如果 valid_lens 不是 None,则进行以下步骤:

    • 获取 X 的形状 shape

    • 如果 valid_lens 是一维张量,将其重复到匹配 X 的第二维,以便与 X 进行逐元素运算。

    • X 重塑为一个二维张量,形状为 (-1, shape[-1]),这样可以在最后一个轴上进行逐元素操作。

    • 使用 d2l.sequence_mask 函数,将有效长度外的元素替换为一个很大的负数(-1e6)。这样,这些元素在经过 softmax 后的输出会趋近于零。

    • 将处理后的张量重新塑形为原始形状,然后应用 softmax 操作。最终输出是带有掩码的 softmax 操作结果。

c. 实验结果

masked_softmax(torch.rand(3, 8, 5), torch.tensor([2, 2, 2]))
  • 随机生成了一个形状为 (3, 8, 5) 的 3D 张量,其中有效长度全为 2。

在这里插入图片描述

masked_softmax(torch.rand(3, 8, 5), torch.tensor([1, 2, 3]))

在这里插入图片描述

  • 使用二维张量,为矩阵样本中的每一行指定有效长度
masked_softmax(torch.rand(2, 2, 5), torch.tensor([[1, 3], [2, 4]]))

  • 对于形状为 (2, 2, 5) 的 3D 张量
    • 第一个二维矩阵的第一个序列的有效长度为 1,第二个序列的有效长度为 3。
    • 第二个二维矩阵的第一个序列的有效长度为 2,第二个序列的有效长度为 4。
      在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1227218.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

单线程的JS中Vue导致的“线程安全”问题

目录 现象分析原因 浏览器中Js是单线程的,当然不可能出现线程安全问题。只是遇到的问题的现象与多线程的情况十分相似,导致对不了解Vue实现的我怀疑起了人生… 现象 项目中用到了element-plus中的加载组件,简单封装了一下,用来保…

一、MySQL-Replication(主从复制)

1.1、MySQL Replication 主从复制(也称 AB 复制)允许将来自一个MySQL数据库服务器(主服务器)的数据复制到一个或多个MySQL数据库服务器(从服务器)。 根据配置,您可以复制数据库中的所有数据库&a…

男子遗失30万天价VERTU唐卡手机,警察2小时“光速”寻回

今天,一则“男子丢失30万元手机女子捡到一位老年机”的新闻迅速冲上热搜第一,引发全网热议。据宿城公安消息:近日,江苏省宿迁市市民王先生在购物时不慎失落了一部价值30万元的全球知名奢侈品VERTU手机,被民警2个多小时…

Linux驱动开发——块设备驱动

目录 一、 学习目标 二、 磁盘结构 三、块设备内核组件 四、块设备驱动核心数据结构和函数 五、块设备驱动实例 六、 习题 一、 学习目标 块设备驱动是 Linux 的第二大类驱动,和前面的字符设备驱动有较大的差异。要想充分理解块设备驱动,需要对系统…

两栏布局:左侧固定,右侧自适应

左侧宽度固定&#xff0c;右侧宽度自适应剩余空间 方法一&#xff1a;float margin 方法二&#xff1a;flex布局 相关HTML代码 <div class"container"><div class"left"></div><div class"main"></div> </d…

C++--哈希表--散列--冲突--哈希闭散列模拟实现

文章目录 哈希概念一、哈希表闭散列的模拟实现二、开散列(哈希桶)的模拟实现数据类型定义析构函数插入查找删除 哈希概念 unordered系列的关联式容器之所以效率比较高&#xff0c;是因为其底层使用了哈希结构。 顺序结构以及平衡树中&#xff0c;元素关键码与其存储位置之间没…

【Linux进阶之路】动静态库

文章目录 回顾一. 静态库1.代码传递的方式2.简易制作3.原理 二. 动态库1.简易制作2.基本原理 尾序 回顾 前面在gcc与g的使用中&#xff0c;我们简单的介绍了动态库与静态库的各自的优点与区别&#xff1a; 动态链接库&#xff0c;也就是所有的程序公用一份代码,虽然方便省空间&…

ACWSpring1.3

首先,前端写ajax写上我们的访问路径(就在我们前端的源代码里面),我们建了两个包pkController用于前端页面url映射过来一层一层找到我们的RestController返回bot1里面有键值,返回的这就是一个session对象bot1这个map.前端拿到我们bot1里的两个值给到我们前端显示出来 1准备页面:…

《Fine-Grained Image Analysis with Deep Learning: A Survey》阅读笔记

论文标题 《Fine-Grained Image Analysis with Deep Learning: A Survey》 作者 魏秀参&#xff0c;南京理工大学 初读 摘要 与上篇综述相同&#xff1a; 细粒度图像分析&#xff08;FGIA&#xff09;的任务是分析从属类别的视觉对象。 细粒度性质引起的类间小变化和类内…

2023年【广东省安全员C证第四批(专职安全生产管理人员)】考试题库及广东省安全员C证第四批(专职安全生产管理人员)考试试卷

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 广东省安全员C证第四批&#xff08;专职安全生产管理人员&#xff09;考试题库根据新广东省安全员C证第四批&#xff08;专职安全生产管理人员&#xff09;考试大纲要求&#xff0c;安全生产模拟考试一点通将广东省安…

网络协议入门 笔记一

一、服务器和客户端及java的概念 JVM (Java Virtual Machine) : Java虚拟机&#xff0c;Java的跨平台:一次编译&#xff0c;到处运行&#xff0c;编译生成跟平台无关的字节码文件 (class文件)&#xff0c;由对应平台的JVM解析字节码为机器指令 (010101)。 如下图所示&#xff0…

【数据结构】C语言实现队列

目录 前言 1. 队列 1.1 队列的概念 1.2 队列的结构 2. 队列的实现 2.1 队列的定义 2.2 队列的初始化 2.3 入队 2.4 出队 2.5 获取队头元素 2.6 获取队尾元素 2.7 判断空队列 2.8 队列的销毁 3. 队列完整源码 Queue.h Queue.c &#x1f388;个人主页&#xff1a…

100.相同的树(LeetCode)

关于树的递归问题&#xff0c;永远考虑两方面&#xff1a;返回条件和子问题 先考虑返回条件&#xff0c;如果当前的根节点不相同&#xff0c;那就返回false&#xff08;注意&#xff0c;不要判断相等时返回什么&#xff0c;因为当前相等并不能说明后面节点相等&#xff0c;所以…

BatchNormalization:解决神经网络中的内部协变量偏移问题

ICML2015 截至目前51172引 论文链接 代码连接(planing) 文章提出的问题 减少神经网络隐藏层中的”内部协变量偏移”问题。 在机器学习领域存在“协变量偏移”问题,问题的前提是我们划分数据集的时候,训练集和测试集往往假设是独立同分布(i.i.d)的,这种独立同分布更有利于…

Java面向对象(高级)-- 类的成员之四:代码块

文章目录 一、回顾&#xff08;1&#xff09;三条主线&#xff08;2&#xff09;类中可以声明的结构及作用1.结构2.作用 二、代码块&#xff08;1&#xff09;代码块的修饰与分类1. 代码块的修饰2. 代码块的分类3. 举例 &#xff08;2&#xff09; 静态代码块1. 语法格式2. 静态…

2023年高压电工证考试题库及高压电工试题解析

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 2023年高压电工证考试题库及高压电工试题解析是安全生产模拟考试一点通结合&#xff08;安监局&#xff09;特种作业人员操作证考试大纲和&#xff08;质检局&#xff09;特种设备作业人员上岗证考试大纲随机出的高压…

2023年【G1工业锅炉司炉】报名考试及G1工业锅炉司炉理论考试

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 G1工业锅炉司炉报名考试是安全生产模拟考试一点通生成的&#xff0c;G1工业锅炉司炉证模拟考试题库是根据G1工业锅炉司炉最新版教材汇编出G1工业锅炉司炉仿真模拟考试。2023年【G1工业锅炉司炉】报名考试及G1工业锅炉…

SQL INSERT INTO 语句详解:插入新记录、多行插入和自增字段

SQL INSERT INTO 语句用于在表中插入新记录。 INSERT INTO 语法 可以以两种方式编写INSERT INTO语句&#xff1a; 指定要插入的列名和值&#xff1a; INSERT INTO 表名 (列1, 列2, 列3, ...) VALUES (值1, 值2, 值3, ...);如果要为表的所有列添加值&#xff0c;则无需在SQL…

vscode c++ 报错identifier “string“ is undefined

vscode c 报identifier “string” is undefined 问题 新装了电脑, 装好vsc和g等, 发现报错 但开头并没问题 解决 shiftctrlp选择 C/C Edit:COnfigurations (JSON)自动生成打开 c_cpp_properties.json添加g路径等 "cStandard": "c11","cppStanda…

c盘清除文件

打开设置 搜索存储