​注意力机制中的掩码详解

news2024/12/22 7:10:04

注意力机制的掩码允许我们发送不同长度的批次数据一次性的发送到transformer中。在代码中是通过将所有序列填充到相同的长度,然后使用“attention_mask”张量来识别哪些令牌是填充的来做到这一点,本文将详细介绍这个掩码的原理和机制。

我们先介绍下如果不使用掩码,是如何运行的。这里用GPT-2每次使用一个序列来执行推理,因为每次只有一个序列,所以速度很慢:

 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
 tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
 gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
 
 context = tokenizer('It will rain in the', return_tensors='pt')
 
 prediction = gpt2.generate(**context, max_length=10)
 tokenizer.decode(prediction[0])
 # prints 'It will rain in the morning, and the rain'

在显存允许的情况下,使用批处理输入的速度更快,因为我们在一次推理的过程可以同时处理多个序列。对许多样本执行推理要快得多,但也稍微复杂一些,下面是使用transformer库进行推理的代码:

 tokenizer.padding_side = "left"
 tokenizer.pad_token = tokenizer.eos_token
 
 sentences = ["It will rain in the",
             "I want to eat a big bowl of",
             "My dog is"]
 inputs = tokenizer(sentences, return_tensors="pt", padding=True)
 
 output_sequences = gpt2.generate(**inputs)
 
 for seq in output_sequences:
     print(tokenizer.decode(seq))

transformer库帮我们处理了很多细节,我们现在详细的介绍它里面到底做了什么。

我们将令牌输入到语言模型中,如GPT-2和BERT,作为张量进行推理。张量就像一个python列表,但有一些额外的特征和限制。比如说,对于一个2+维的张量,该维中的所有向量必须是相同的长度。例如,

 from torch import tensor
 
 tensor([[1,2], [3,4]])  # ok
 tensor([[1,2], [3]])   # error!

当我们对输入进行标记时,它将被转换为序列的张量,每个整数对应于模型词表中的一个项。以下是GPT-2中的标记化示例:

如果我们想在输入中包含第二个序列:

因为这两个序列有不同的长度,所以不能把它们组合成一个张量。这时就需要用虚拟标记填充较短的序列,以便每个序列具有相同的长度。因为我们想让模型继续向序列的右侧添加,我们将填充较短序列的左侧。

这就是注意力掩码的一个应用。注意力掩码告诉模型哪些令牌是填充的,在填充令牌的位置放置0,在实际令牌的位置放置1。现在我们理解了这一点,让我们逐行查看代码。

 tokenizer.padding_side = "left"

这一行告诉标记器从左边开始填充(默认是右边),因为最右边标记的logits将用于预测未来的标记。

 tokenizer.pad_token = tokenizer.eos_token

这一行指定将使用哪个令牌进行填充。选择哪一个并不重要,这里我们选择的是“序列结束”标记。

 sentences = ["It will rain in the",
             "I want to eat a big bowl of",
             "My dog is"]

上面这三个序列在标记时都有不同的长度,我们使用下面的方法填充:

 inputs = tokenizer(sentences, return_tensors="pt", padding=True)

在进行表计划和添加填充后,得到了以下的结果:

 {'input_ids': tensor([
     [50256, 50256, 50256,  1026,   481,  6290,   287,   262],
     [   40,   765,   284,  4483,   257,  1263,  9396,   286],
     [50256, 50256, 50256, 50256, 50256,  3666,  3290,   318]
   ]),
 'attention_mask': tensor([
     [0, 0, 0, 1, 1, 1, 1, 1],
     [1, 1, 1, 1, 1, 1, 1, 1],
     [0, 0, 0, 0, 0, 1, 1, 1]
   ])}

可以看到,第一个和第三个序列在开始时进行了填充,并且attention_mask参数标记了这个填充的位置。

现在让我们将这个输入传递给模型来生成新的文本:

 output_sequences = gpt2.generate(**inputs)

如果你不熟悉函数调用的**kwargs语法,它是将输入字典作为命名参数传入,使用键作为参数名,并使用值作为相应的实参值。

我们只需要循环遍历每个生成的序列并以人类可读的形式打印出结果,使用decode()函数将令牌id转换为字符串。

 for seq in output_sequences:
     print(tokenizer.decode(seq))

在注意力掩码中,我们的输入是0和1,但是在最终的计算时,会将在将无效位置的注意力权重设置为一个很小的值,通常为负无穷(-inf),以便在计算注意力分数时将其抑制为接近零的概率。

这时因为,在计算注意力权重时,需要进行Softmax的计算:

Softmax函数的性质:注意力机制通常使用Softmax函数将注意力分数转化为注意力权重,Softmax函数对输入值进行指数运算,然后进行归一化。当输入值非常小或负无穷时,经过指数运算后会接近零。因此,将掩码设置为负无穷可以确保在Softmax函数计算时,对应位置的注意力权重趋近于零。

排除无效位置的影响:通过将无效位置的注意力权重设置为负无穷,可以有效地将这些位置的权重压低。在计算注意力权重时,负无穷的权重会使对应位置的注意力权重接近于零,从而模型会忽略无效位置的影响。这样可以确保模型更好地关注有效的信息,提高模型的准确性和泛化能力。

但是负无穷并不是唯一的选择。有时也可以选择使用一个很大的负数,以达到相似的效果。具体的选择可以根据具体的任务和模型的需求来确定。

https://avoid.overfit.cn/post/0538d928a1c14940b3861437ea2fcffa

作者:Prudhviraju Srivatsavaya

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/758873.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

(简单)设计哈希集合 Java

为了实现哈希集合这一数据结构,有以下几个关键问题需要解决: 哈希函数:能够将集合中任意可能的元素映射到一个固定范围的整数值,并将该元素存储到整数值对应的地址上冲突处理:由于不同元素可能映射到相同的整数值&…

SpringBoot读取配置的方式

读取配置的几种方式 Spring Boot提供了多种方式来读取配置,下面是其中几种常用的方式: 使用application.properties或application.yml文件:在Spring Boot项目的classpath根目录下,可以创建一个名为application.properties或appli…

oc基本控件3

UIButton // // ViewController.m // OcDemoTest // // Created by Mac on 2023/7/14. //#import "ViewController.h"interface ViewController ()endimplementation ViewController- (void)viewDidLoad {[super viewDidLoad];// 1 创建按钮对象UIButton *button…

涂鸦智能打造专业家庭智能生活助手,实现人机交互升级

近年来,智能家居设备的品类不断拓展,同时,人们对AI与智能家居的联动愈发憧憬。自然语言交互是未来人机交互的主要趋势之一,其关键在于使AI具备主动理解信息的能力,让用户的交互更轻松。如何将智能场景的交互变得更“善…

MySQL-DDL-表结构操作-创建-案例

案例 根据页面原型/需求创建表(设计合理的数据类型、长度、约束) 具体操作 在idea中使用可视化图形界面创建 具体操作如下: 在该界面中进行属性的创建,进行属性名称、数据类型、约束、描述等信息的填写最终运行结果如下&…

800V高压电驱动系统盘点

2023年上海车展共有23家厂商的63个电驱动产品,经过梳理,本次展出的800V高压电驱动共有13款,可以说电驱动全面进入高压化。800V电驱动是一个系统性的话题,对于电机而言,挑战的方向主要围绕高速、高压、散热,…

替换空格

替换空格 请实现一个函数&#xff0c;把字符串 s 中的每个空格替换成"%20"。 题目给的测试用例里有以下限制&#xff1a; 0 < s.length < 14。 split() 把字符串分割为子字符串数组 例如&#xff1a; var txt"ABCD EFGH IJKL MNOP QRSTU VWXYZ"; v…

微信小程序下拉选择

微信小程序中下拉框选择一般的交互方式有以下两种 直接下拉选择点击选择框后&#xff0c;弹出浮层进行选择 下边分别介绍两种方式的实现。在微信小程序中&#xff0c;这两种实现都需要修改三个文件 js 文件&#xff1a;下拉选择逻辑的具体实现 wxml 文件&#xff1a;下拉组件…

C#正则表达式校验某个字符串是否是合格的email

C#正则表达式校验某个字符串是否是合格的email 可以借助正则表达式校验某个字符串是否是合规的电子邮箱。对于邮箱的正则表达式有严格的模式&#xff0c;如&#xff1a;^[a-zA-Z0-9_&*-](?:\\.[a-zA-Z0-9_&*-])*(?:[a-zA-Z0-9-]\\.)[a-zA-Z]{2,7}$ 对应的C#实现如下…

TCP编程流程和粘包

目录 1、TCP编程流程 2、粘包 1、TCP编程流程 socket() 是创建套接字&#xff0c;返回值为监听套接字描述符&#xff0c;有了套接字才能通过网络进行数据的传输。创建套接字的参数要指定服务类型&#xff0c;TCP协议使用的是流式服务&#xff08;SOCK_STREAM&#xff09;。 b…

用Matlab听音乐 - 动态频谱

文章目录 高帧率版本效果: 定时器版本music_play主函数&#xff1a;定时器回调函数&#xff1a;效果: 高帧率版本 由于matlab这款科学计算软件本身庞大略显笨重&#xff0c;执行代码的速度受当前系统影响&#xff0c;很难做到严格定时仿真&#xff08;造成音画不同步&#xff…

互联网行业真的不行了吗?

文章目录 前言一、起因二、互联网真的完了吗&#xff1f;三、是不是要转行&#xff1f;四、十年磨一剑五、统一回复 前言 英雄算法联盟 - 七月集训 已经开始 16 天&#xff0c;八月算法集训 将于 08月01日 正式开始&#xff0c;目前已经提前开始报名&#xff0c;报名方式参见&a…

英国24所顶尖大学撤销禁令,更新AI使用规定!

自从ChatGPT展现了其高超的AI技术后&#xff0c;备受全球年轻人的喜爱。ChatGPT功能多样化&#xff0c;可以节省查阅复杂文献的时间、编写简单的Python代码、辅助学生理解知识点... 同时&#xff0c;ChatGPT引发的学术不诚信问题也让各大院校头疼不已。 连续数月以来&#xff…

js 浮点位数超过17位乘以10^18,精度丢失问题

我有一个浮点型 var num 9.963407954080194743 用num * (10 ** 18) 计算得出的结果是9963407954080195000, 但是我想要得到的结果是9963407954080194743 问ChatGPT问题得以解决&#xff1a; GPT提供的代码&#xff1a; import Big from big.js;const num1 new Big(9.9634…

从输入URL到页面渲染的整个过程

从输入URL到页面渲染的整个过程 1.DNS解析&#xff0c;把url中的域名解析成对应的IP地址。如果本地DNS缓存没有响应的记录&#xff0c;则会向DNS发送请求&#xff0c;获取相应的IP地址。 2.浏览器使用获取到的目标服务器的IP地址&#xff0c;通过TCP/IP协议与服务器建立连接&a…

python-web开发(Djaongo)课程基本内容

python-web开发&#xff08;Djaongo&#xff09;课程基本内容及其前置技术 参考内容&#xff1a; 【最新Python的web开发全家桶&#xff08;django前端数据库&#xff09;】 https://www.bilibili.com/video/BV1rT4y1v7uQ/?share_sourcecopy_web&vd_source84fd4883bb478d0…

CDA数据分析系01 anaconda

简介 数据处理集成包&#xff0c;不局限于python 创建一个新的environment conda create --name python34 python3.4 激活一个environment activate python34 # for windows conda的package管理 类似pip&#xff0c;conda install xxxx 查看已安装的python包 conda list…

利用技术优势:程序员如何通过互联网自媒体项目实现财务自由?

作为程序员&#xff0c;通过互联网自媒体项目实现财务自由是一个很好的选择。以下是一些技术优势的利用方法&#xff1a; 选择适合的自媒体平台&#xff1a;在互联网上有许多不同类型的自媒体平台&#xff0c;如博客、YouTube、Podcast等。选择适合你技术背景和兴趣的平台&…

手机忘记密码怎么办? 帮你快速解锁手机的十大软件请收好

有许多不同类型的手机锁&#xff0c;这些锁对于手机的用户或所有者来说可能非常烦人和恼人。这些锁可称为手机锁、SIM 锁、主锁或运营商锁。这些锁实际上是手机的实际限制。 为了仅在有限的国家/地区阻止电话访问&#xff0c;该区域之外的任何其他人都无法使用。 手机解锁如何…

极速上手k8s,Kubernetes 从入门到摸鱼系列-实践篇

大家好&#xff0c;我是比特桃。本文为《极速上手k8s&#xff0c;Kubernetes 从入门到摸鱼系列》的实战篇&#xff0c;旨在快速上手k8s。如没有阅读过k8s相关理论的朋友&#xff0c;可以先阅读理论篇。 1. 实践环境 k8s 的意义在于分布式大规模容器编排&#xff0c;所以如果我…