手撕Transformer(二)| Transformer掩码机制的两个功能,三个位置的解析及其代码

news2024/10/6 2:25:18

文章目录

  • 1 掩码的两个功能
    • 1.1 功能1 输入掩码,统一长度
    • 1.2 功能2 遮挡未预测信息
  • 2 掩码存在的三个位置
  • 3 代码实现

Transformer的掩码机制是非常重要的

三个位置,两个功能

1 掩码的两个功能

1.1 功能1 输入掩码,统一长度

实现不同的长度句子同时训练,统一长度,计算注意力机制的时候方便并行计算

假设训练中最长的句子为10个tokens

那么这两句话

“我喜欢猫猫“

“我喜欢打羽毛球”

在不足十个token的情况下,填充pad直到10个,这样可以使得所有序列的长度保持一致都是十个

那么问题来了,模型怎么知道哪些是pad,哪些是真正的token

就通过掩码机制,两个句子的掩码分别为

(0,0,0,0,0,1,1,1,1,1)

(0,0,0,0,0,0,0,1,1,1)

这时候模型对1的位置进行替换处理,使得最后对于计算结果没有影响

具体代码实现中,你会发现在encoder和decoder两个重要组件中里都有padding mask,位置是在softmax之前,

我们在这些位置上补一些无穷小(负无穷)的值,经过softmax操作,这些值就成了0,就不在影响全局概率的预测。

padding的另一种解释,我觉得也不错,放在这里供大家参考

一句文本输入:[1, 2, 3, 4, 5]
input size: 1* 8
加padding:[1, 2, 3, 4, 5, 0, 0, 0]

padding 引入的问题:padding填充数量不一致,导致均值计算偏离
原始均值:(1 + 2 + 3 + 4 + 5) / 5 = 3
padding后的均值: (1 + 2 + 3 + 4 + 5) / 8 = 1.875

引入mask,解决padding的缺陷:

1.2 功能2 遮挡未预测信息

在解码器部分,我们在模型训练的过程中不能让模型知道未来时间步的信息,否则话就相当于告诉了模型最终的答案是什么。

什么意思呢,比如对于翻译任务

我们的编码器输入的是源语言的完整信息,最终编码一个特征memory给到解码器,比如“我喜欢可爱的猫猫”

那么解码器输入的是来自编码器的编码后的整体特征memory还有目标语言的已经预测的信息,比如“I like cute cats”

对于cute这一个词的预测就要给模型memory和I 和 like这两个token的信息

对于cats 这一个词的预测就要给模型memory和I 和 like和 cute 这三个token的信息

那么对于未预测的信息不能提前给模型看~

那么现在问题来了,看样子我们需要模型一个词一个词的预测,那这样多慢啊,这不是有点循环神经网络的意思了吗,但是我在这篇博客手撕Transformer(一)| 经典Positional_encoding 用法and代码详解-CSDN博客里面讲到,深度学习一定要有矩阵思维,这里的训练过本质上也是这个过程,通过矩阵使得这些词的预测同步进行,怎么做到的?这一点非常重要,要充分理解

Transformer不是一个词一个词预测的。他是一批都一次性预测,然后每一个和之后的进行比较~矩阵的并行计算

1用编码器的输出memory结合第1 个token预测第二个tokens( 第二步)

2用编码器的输出memory结合第1,2 个token预测第三个token( 第三步)

3用编码器的输出memory结合第1,2,3 个token 预测第四个token (第四步)

以此类推

以上这些是同时进行的,就是通过掩码机制

我们最终会生成这样一个矩阵

[0,1,1

0,0,1

0,0,0 ]

第一行代表模型最终只能看到第一个token,此时第二个和第三个都看不到,这时预测第二个

第二行代表模型最终只能看到第一,二个token,此时第三个看不到,这时预测第三个

依此类推

2 掩码存在的三个位置

存在于三个位置

位置1,2的实现的功能一,位置3实现的功能2

但位置1,2 虽然功能差不多,但输入有差别,因为位置1实现的是自注意力,位置2是交叉注意力

或者可以理解为什么,最后对最后的Q,K乘积的注意力矩阵的遮蔽,为0的部分计算注意力,为1的部分不计算

3 代码实现

要理解这部分的代码可以核心看

【深度学习】Transformer中的mask机制超详细讲解_transformer mask-CSDN博客

分别介绍了三个位置的代码的实现,每个位置介绍了两种代码实现的方法

我觉得原作者写的非常好,我就不重复造轮子了,如果之后有时间基于他的我会进一步完善

参考

[transformer掩码-CSDN博客](https://blog.csdn.net/Arctic_Beacon/article/details/122416407#:~:text=transformer掩码 1 一、padding mask 数据输入模型的时候长短不一,为了保持输入一致,通过加padding将input转成固定tensor。 一句文本输入: [1%2C 2%2C,scaled_dot_product_attention ( q%2C k%2C v%2C mask )%3A )

4.4 Transformer编码器部分实现–掩码张量 - 知乎 (zhihu.com)

图片基于李沐老师的动手深度学习进行标记

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1469503.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用C++和SFML库创建2D游戏

FML(Simple and Fast Multimedia Library)是一个跨平台的C库,用于开发2D游戏和多媒体应用程序。它提供了许多功能,包括图形、声音、网络、窗口管理和事件处理等。 ———————————————不怎么完美的分割线——————…

使用命令符用cd切换不了

bug:cd 切换不进去 解决办法: 在cd后面加 /d

(HAL)STM32F103C6T8——软件模拟I2C驱动0.96寸OLED屏幕

一、电路接法 电路接法参照江科大视频。 二、相关代码及文件 说明:代码采用hal库,通过修改江科大代码实现。仅OLED.c文件关于引脚定义作了hal库修改,并将宏定义OLED_W_SCL(x)、OLED_W_SDA(x)作了相关修改。 1、OLED.c void OLED_I2C_Init(voi…

32单片机基础:旋转编码器计次

接线图如上图所示。 我们初始化一下PB0和PB1两个GPIO口外设中断,当然,这里只初始化一个外部中断也能完成功能的对于编码器而言,下图所示为正转的波形。如果把一相的下降沿用作触发中断,在中断时刻读取另一相的电平,正…

普中51单片机学习(红外通信)

红外通信 红外线系统的组成 外线遥控器已被广泛使用在各种类型的家电产品上,它的出现给使用电器提供了很多的便利。红外线系统一般由红外发射装置和红外接收设备两大部分组成。红外发射装置又可由键盘电路、红外编码芯片、电源和红外发射电路组成。红外接收设备可由…

Spring Boot中的@Scheduled注解:定时任务的原理与实现

1. 前言 本文将详细探讨Spring Boot中Scheduled注解的使用,包括其原理、实现流程、步骤和代码示例。通过本文,读者将能够了解如何在Spring Boot应用中轻松创建和管理定时任务。 2. Scheduled注解简介 在Spring框架中,Scheduled注解用于标记…

面试经典150题【21-30】

文章目录 面试经典150题【21-30】6.Z字形变换28.找出字符串中第一个匹配项的下标68.文本左右对齐392.判断子序列167.两数之和11.盛最多水的容器15.三数之和209.长度最小的子数组3.无重复字符的最长子串30.串联所有单词的子串 面试经典150题【21-30】 6.Z字形变换 对于“LEETC…

五种多目标优化算法(MOBA、NSWOA、MOJS、MOAHA、MOPSO)性能对比(提供MATLAB代码)

一、5种多目标优化算法简介 多目标优化算法是用于解决具有多个目标函数的优化问题的一类算法。其求解流程通常包括以下几个步骤: 1. 定义问题:首先需要明确问题的目标函数和约束条件。多目标优化问题通常涉及多个目标函数,这些目标函数可能存在冲突,需要在不同目标之间进…

udp服务器【Linux网络编程】

目录 一、UDP服务器 1、创建套接字 2、绑定套接字 3、运行 1)读取数据 2)发送数据 二、UDP客户端 创建套接字: 客户端不用手动bind 收发数据 处理消息和网络通信解耦 三、应用场景 1、服务端执行命令 2、Windows上的客户端 3…

DiceCTF 2024 -- pwn

baby-talk 题目给了 Dockerfile,但由于笔者 docker 环境存在问题启动不起来,所以这里用虚拟机环境做了(没错,由于不知道远程 glibc 版本,所以笔者远程也没打通)笔者本地环境为 glibc 2.31-0ubuntu9.9。然后…

Shell grep命令练习题

目录 1含有“48“字符串的行的总数 2显示含有“48“字符串的所有行的行号 3精确匹配只含有“48”字符串的行 4抽取代码为484和483的城市位置 5显示使行首不是4或8 6显示含有九月份(Sept)的行 7显示以K开头,以D结尾的所有代码 8显示头两个是大写字母,中间…

规则持久化(Sentinel)

规则持久化 基于Nacos配置中心实现推送 引入依赖 <dependency><groupId>com.alibaba.csp</groupId><artifactId>sentinel-datasource-nacos</artifactId> </dependency> 流控配置文件 [{"resource":"/order/flow",…

《Docker 简易速速上手小册》第6章 Docker 网络与安全(2024 最新版)

文章目录 6.1 Docker 网络概念6.1.1 重点基础知识6.1.2 重点案例&#xff1a;基于 Flask 的微服务6.1.3 拓展案例 1&#xff1a;容器间的直接通信6.1.4 拓展案例 2&#xff1a;跨主机容器通信 6.2 配置与管理网络6.2.1 重点基础知识6.2.2 重点案例&#xff1a;配置 Flask 应用的…

【Java程序设计】【C00302】基于Springboot的校园失物招领管理系统(有论文)

基于Springboot的校园失物招领管理系统&#xff08;有论文&#xff09; 项目简介项目获取开发环境项目技术运行截图 项目简介 这是一个基于Springboot的校园失物招领网站&#xff0c;本系统有管理员以及用户二种角色权限&#xff1b; 系统整体功能有&#xff1a;操作日志管理、…

Escalate_Linux靶机详解(1)

Escalate_Linux靶机详解&#xff08;1&#xff09; 一&#xff0c;信息收集 首先扫描存活主机 目标地址&#xff1a;192.168.236.131 使用nmap扫描保存为linux.nmap 2&#xff0c;HTTP探测 发现开放了80端口http 打开站点是apache的默认站点 默认页面&#xff0c;尝试对w…

PyTorch概述(六)---View

Tensor.view(*shape)-->Tensor 返回一个新的张量同之前的张量具有相同的数据&#xff0c;但是具有不同的形状&#xff1b;返回的张量同之前的张量共享相同的数据&#xff0c;必须具有相同数目的元素&#xff0c;可能具有不同的形状&#xff1b;对于经过view操作的张量&…

Android 内存优化内存泄漏处理

一:匿名内部类/非静态内部类 匿名内部类的泄漏原因&#xff1a;匿名内部类会隐式地持有外部类的引用.当外部类被销毁时&#xff0c;内部类并不会自动销毁&#xff0c;因为内部类并不是外部类的成员变量&#xff0c; 它们只是在外部类的作用域内创建的对象&#xff0c;所以内部…

《艾尔登法环 黄金树幽影》是什么?Mac电脑怎么玩《艾尔登法环》艾尔登法环下载

全体起立&#xff0c;《艾尔登法环 》最新DLC《黄金树幽影》将在6月21日发布&#xff0c;steam售价198元&#xff0c;现在就可以预订了。宫崎英高在接受FAMI通的采访时表示&#xff0c;新DLC的体量远超《黑暗之魂》和《血源诅咒》资料片。好家伙&#xff0c;别人是把DLC续作&am…

28V270V航空交直流线缆:满足飞机对高质量电气连接的需求

28V/270V航空交直流线缆&#xff1a;航空业的“神经系统” 在现代航空业中&#xff0c;无论是飞机、直升机还是其他飞行器&#xff0c;都离不开一种重要的设备&#xff0c;那就是航空28V/270V航空交直流线缆。航空28V/270V航空交直流线缆是飞行器上的电气系统的重要组成部分&am…

【操作系统】

计算机操作系统 计算机是如何让用户得到好的体验什么是操作系统&#xff08;OS&#xff09;操作系统如何管理 计算机是如何让用户得到好的体验 计算机系统是由计算机硬件和软件组成的。用户使用计算机&#xff0c;比如在文本文件填写内容&#xff0c;通过邮箱发送邮件&#xf…