d2l 注意力评分函数 --附加mask_softmax讲解

news2024/11/17 20:36:43

本章节tensor处理操作也不少,逐个讲解下:

目录

1.mask_softmax

1.1探索源码d2l.sequence_mask

2.加性注意力

3.缩放注意力


1.mask_softmax

  dim=-1表示对最后一个维度进行softmax
  .dim()返回的是维度数
  对于需要mask的数,要用绝对值非常大的负数替换,不能用0,因为0进行softmax时exp=1,返回值不会约等于0.

#@save
def masked_softmax(X, valid_lens):
    """通过在最后⼀个轴上掩蔽元素来执⾏softmax操作"""
    # X:3D张量,valid_lens:1D或2D张量
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # 最后⼀轴上被掩蔽的元素使⽤⼀个⾮常⼤的负值替换,从⽽其softmax输出为0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
                                value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)

验证:2个2×4矩阵样本,指定两个样本的有效长度分别为2和3

masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))

'''
tensor([[[0.4265, 0.5735, 0.0000, 0.0000],
         [0.6215, 0.3785, 0.0000, 0.0000]],

        [[0.2043, 0.3346, 0.4611, 0.0000],
         [0.3598, 0.2352, 0.4050, 0.0000]]])
'''

指定二维张量,len中的形状为(2,2),第一个表示每个指哪个样本,第二个维度里面的表示指定每个样本的每一行的有效长度

masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]]))

'''
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4087, 0.3961, 0.1952, 0.0000]],

        [[0.6028, 0.3972, 0.0000, 0.0000],
         [0.1992, 0.2031, 0.3061, 0.2915]]])
'''

1.1探索源码d2l.sequence_mask

庐山真面目:

# @save
def sequence_mask(X, valid_len, value=0):
    """在序列中屏蔽不相关的项"""
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X

送进去X(bs,T)与valid_len(bs),返回的是(bs,T),且valid_len后全为0

 

 

 将两项全部广播成(bs,T),然后挨个比较再反向赋值

 最终返回的能够对上len长度

注意~mask是取反,对False设置value值

2.加性注意力

公式:

用在query与key向量长度不同时,使用两个权重相乘让他们相同,再做内积,等价于将二者合并后送入MLP

实现:

#@save
class AdditiveAttention(nn.Module):
    """加性注意⼒"""
    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # 在维度扩展后,
        # queries的形状:(batch_size,查询的个数,1,num_hidden)
        # key的形状:(batch_size,1,“键-值”对的个数,num_hiddens)
        # 使⽤⼴播⽅式进⾏求和
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # self.w_v仅有⼀个输出,因此从形状中移除最后那个维度。
        # scores的形状:(batch_size,查询的个数,“键-值”对的个数)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # values的形状:(batch_size,“键-值”对的个数,值的维度)
        return torch.bmm(self.dropout(self.attention_weights), values)

  最终得到的是(bs,q,values),对每个query都会拿到长为values维度的向量。
  重点在forward里面的广播,对于query(bs,q,h)与key(bs,k-v,h),将两个扩充为(bs,q,1,h)与(bs,1,k-v,h),再通过广播相加,最后再激活,通过最后Linear变成(bs,q,k-v,1)。注意,最后维度为1,所以可以压缩掉最后的维度。
  scores里面mask的解读:先将score处理成valid_len后的值替换成-1e6(很小的数),再进行softmax,使得valid_len后面的得分都是0
  bmm里面的权重为(bs,q,k-v),values为(bs,k-v,values),进行bmm矩阵乘法最终得到(bs,q,values)

### 验证一下,可看到queries中为(bs,q,q_size),keys(bs,k-v,k_size),values(bs,k-v,values)为(2,10,4)
### 最终得到为(bs,q,values)

queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
# values的⼩批量,两个值矩阵是相同的
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
                        2, 1, 1)
    
valid_lens = torch.tensor([2, 6])
attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,
                            dropout=0.1)

attention.eval()
attention(queries, keys, values, valid_lens)

'''
tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)
'''

3.缩放注意力

公式:

  用于query与key长度相同,故可做转置后相乘--内积。 

  注意transpose,是进行将k转置再与queries做内积

#@save
class DotProductAttention(nn.Module):
    """缩放点积注意⼒"""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        
    # queries的形状:(batch_size,查询的个数,d)
    # keys的形状:(batch_size,“键-值”对的个数,d)
    # values的形状:(batch_size,“键-值”对的个数,值的维度)
    # valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # 设置transpose_b=True为了交换keys的最后两个维度
        scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

验证一下,仍是得到(bs,q,values)

queries = torch.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
attention(queries, keys, values, valid_lens)

'''
tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]])
'''

该方法实现简单,但是可学习的参数少,几乎没有。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/445363.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

FreeRTOS 任务创建与删除实验

本实验主要实现 FreeRTOS 使用动态方法创建和删除任务&#xff0c;本实验设计了四个任务&#xff0c;这四 个任务的功能如下表所示&#xff1a; 软件设计 1. 程序流程图 本实验的程序流程图&#xff0c;如下图所示&#xff1a; 2. FreeRTOS 函数解析 (1) 函数 xTaskCreate…

spring框架基础知识和基于XML的Bean对象的管理回顾

什么是spring框架&#xff1f; spring基本功能所必须的jar包就是这些 如何获取bean&#xff1f; IOC原理 上面耦合度太高了 改进使用工厂模式 上面并没有把耦合度降低到最低&#xff0c;使用反射 spring实现IOC的两种方式 BeanFactory和ApplicationContext IOC如何管理Bea…

5个方法,帮助你快速提高团队管理效率

团队中&#xff0c;大家看起来都很忙&#xff0c;但最终交付的结果却总是差强人意。会议那么多&#xff0c;但有效的却很少越管理&#xff0c;但偏偏有时候越管理越乱......相信以上这些问题&#xff0c;很多管理者都有遇到过&#xff0c;团队管理是一个项目中最关键的一环。好…

如何打造全流程数字化零工场景,实现零工管理一体化?

近年来&#xff0c;零工市场发展迅速&#xff0c;不仅为企业提供更低成本、更便捷的用工方式&#xff0c;也为劳动者就业提供更低门槛更灵活形式&#xff0c;发挥了就业「蓄水池」的重要作用。但由于零工经济模式下的用工形式非常灵活&#xff0c;企业想要管好零工并不容易。 …

短视频平台-小说推文(知乎)推广任务详情

知乎会员 知乎日结内测中&#xff0c;可能暂只对部分优质会员开放! 2023/03/29通知: 知乎拉新项目&#xff0c;由于内部测试转化较低&#xff0c;暂时下线&#xff0c;原有关键词出单不受影响。 1、关键词 1.1 选择会员文 在知乎【首页】或者【会员】里面选取&#xff0c;需…

PEIS体检系统全套源代码,C# 源码

医院体检信息系统PEIS源码,C# 源码&#xff0c;PEIS源码源码 文末获取联系&#xff01; 系统概述 医院体检信息系统是专门针对医院体检中心的日常业务运作的特点和流程&#xff0c;结合数字化医院建设要求进行设计研发的一套应用系统。该系统覆盖体检中心的所有业务&#xff0…

使用nvm替换nvmw作为nodejs的版本切换(亲测)

之前的文章&#xff1a;同时使用vue2.0和vue3.0版本的采坑记录 安装的nvmw&#xff0c;今天想要用nvmw切换时&#xff0c;居然给我报错了&#xff1a; 然后我就走上了使用nvm替换nvmw之路。。 1.安装 nvm-windows下载 下载release版 中Assets中的包&#xff0c;window10&…

APIs -- DOM正则表达式

1. 介绍 正则表达式(Regular Expression)是用于匹配字符串中字符组合的模式。在JavaScript中&#xff0c;正则表达式也是对象通常用来查找、替换那些符合正则表达式的文本&#xff0c;许多语言都支持正则表达式。正则表达式在JavaScript中的使用场景: 例如验证表单:用户名表单…

数据库的实际操作

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 一、关系模型二、数据库的操作 创建数据库查看数据库选择数据库删除数据库三、MySQL 数据库命名规范总结 一、关系模型 关系数据库是建立在关系模型上的。而关系模…

flutter学习之旅(一)

初学Flutter flutter官网和中文开发手册 安装flutter - windows 官方文档-windows flutter_windows_3.7.9-stable.zip 编辑环境变量 在 用户变量 一栏中&#xff0c;检查是否有 Path 这个条目&#xff1a; 如果存在这个条目&#xff0c;以 ; 分隔已有的内容&#xff0c;加入 f…

物联网能源能耗之场景控制原理

物联网能源能耗系统利用物联网技术&#xff0c;可帮助企业构建能耗分布&#xff0c;帮助操作人员实时监控各类关键参数&#xff0c;计算关键环节的能耗指标&#xff0c;和既定的能耗基线进行对比&#xff0c;得出能耗差距。 对于制造企业而言&#xff0c;物联网能源能耗不仅能…

商业策划的基本功:竞品分析

商业策划的基本功&#xff1a;竞品分析 商业的三个视角&#xff1a;用户&#xff0c;竞争&#xff0c;自己 有方法会更有效 趣讲大白话&#xff1a;磨刀不误砍柴工 【趣讲信息科技138期】 **************************** 世界上如果只有一种矿泉水 就不会竞争 就不会有农夫山泉这…

由世纪互联运营的Microsoft Teams携创新功能正式发布,夯实“企业数字中枢”

2023年4月18日&#xff0c;北京——今日&#xff0c;微软宣布由世纪互联运营的Microsoft Teams推出一系列创新功能&#xff0c;围绕企业数字核心能力&#xff0c;赋能数字化协作空间、智能化协作体验、整合生产力工具和工作流、安全合规、构建团队文化等五大落地场景&#xff0…

基于重要抽样技术的非序贯蒙特卡洛法(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

前端学习:HTML响应式设计、计算机代码、语义元素

目录 HTML响应Web设计 一、什么是响应式Web设计&#xff1f; 1.流体网格 2.媒体查询 3.响应媒体 4.视口元标记 二、使用Bootstrap HTML计算机代码元素 HTML 语义元素 一、什么是语义元素 二、HTML5中的新的语义元素 HTML响应Web设计 一、什么是响应式Web设计&…

现在学习Java,还有出路吗?

当然有出路&#xff0c;Java一直都是市场占有率最高的编程语言&#xff0c;我们生活涉及到的方方面面都有Java的身影&#xff0c;Java基本也覆盖了所有的行业。同时Java自身也是不断在升级更新&#xff0c;平均一年半左右进行一次&#xff0c;而未来的发展还会更加的强势。 随…

Mysql安装步骤

1、解压服务端Mysql安装包 解压之后的目录就是以上这样的。 2.复制改变my.ini文件 把my.ini文件添加到目录中去 [mysql] # 设置mysql客户端默认字符集 default-character-setutf8 [mysqld] #设置3306端口 port 3306 # 设置mysql的安装目录 basedirE:/mysql/mysql-8.0.18-wi…

Spring AOP核心概念与操作示例

AOP 核心概念 还记得我们Spring有两个核心的概念嘛&#xff1f;一个是IOC/DI&#xff0c;另一个是AOP咯。 先来认识两个概念&#xff1a; AOP(Aspect Oriented Programming)面向切面编程&#xff1b;作用&#xff1a;在不惊动原始设计的基础上为其进行功能增强&#xff0c;类…

Linux命令行操作/选项介绍,文件分类/内容与属性/绝对相对路径,隐藏文件与整个目录结构

Linux的命令行操作介绍 Linux操作的特点&#xff1a;纯命令行&#xff0c;当然Linux它也有图形化界面或桌面版。Windows也有命令行&#xff0c;也有图形化界面。不过它是面向普通客户的操作系统&#xff0c;所以必须得是好用好玩的&#xff0c;所以图形化界面那是必然。无论是…

PCL点云库(2) - IO模块

目录 2.1 IO模块接口 2.2 PCD数据读写 &#xff08;1&#xff09; PCD数据解析 &#xff08;2&#xff09;PCD文件读写示例 2.3 PLY数据读写 &#xff08;1&#xff09;PLY数据解析 &#xff08;2&#xff09;PLY文件读写示例 2.4 OBJ数据读写 &#xff08;1&#xff…