Transformer的PyTorch实现之若干问题探讨(二)

news2025/1/15 20:06:59

在《Transformer的PyTorch实现之若干问题探讨(一)》中探讨了Transformer的训练整体流程,本文进一步探讨Transformer训练过程中teacher forcing的实现原理。

1.Transformer中decoder的流程

在论文《Attention is all you need》中,关于encoder及self attention有较为详细的论述,这也是网上很多教程在谈及transformer时候会重点讨论的部分。但是关于transformer的decoder部分,他的结构上与encoder实际非常像,但其中有一些巧妙的设计。本文会详细谈谈。首先给出一个完整transformer的结构图:
在这里插入图片描述

上图左侧为encoder部分,右侧为decoder部分。对于decoder部分,将enc_input经过multi head attention后得到的张量,以K,V送入decoder中。而decoder阶段的masked multi head attention需要解决如何将dec_input编码成Q。最终输出的logits实际是与Q的维度一致。对于Scaled Dot-Product Attention,其公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q, K, V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
在《Transformer的PyTorch实现之若干问题探讨(一)》中,decoder阶段,Q的维度为[2,8,6,64](2为batch size,8为head数,6为句子长度,64为向量长度),K的维度为[2,8,5,64],V的维度为[2,8,5,64]。其中, Q K T QK^T QKT的维度为[2,8,6,5] 的,可以理解每个查询张量Q对每个键值张K的注意力权重。之后乘以V,维度为[2,8,6,64]。可以看到最终的维度是根据查询张量Q来加权值向量V。Q就是dec_input经过masked multi head attention得来。那么,dec_input中实际是包含了所有的标签的。那么dec_input是如何mask掉不需要的token的呢?

2.Decoder中的self attention mask

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])


    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        '''
        这三个参数对应的不是Q、K、V,dec_inputs是Q,enc_outputs是K和V,enc_inputs是用来计算padding mask的
        dec_inputs: [batch_size, tgt_len]
        enc_inpus: [batch_size, src_len]
        enc_outputs: [batch_size, src_len, d_model]
        '''
        dec_outputs = self.tgt_emb(dec_inputs)#词序号编码成向量
        dec_outputs = self.pos_emb(dec_outputs).cuda()#位置编码
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda() #[2, 6, 6]
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda() #[2, 6, 6],上三角矩阵
        # 将两个mask叠加,布尔值可以视为0和1,和大于0的位置是需要被mask掉的,赋为True,和为0的位置是有意义的为False
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask +
                                       dec_self_attn_subsequence_mask), 0).cuda()
        # 这是co-attention部分,为啥传入的是enc_inputs而不是enc_outputs:enc_outputs是向量,这儿是需要通过词编码来判断是否需要mask掉
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) #[2, 6, 5]

        for layer in self.layers:
            dec_outputs = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)

        return dec_outputs # dec_outputs: [batch_size, tgt_len, d_model]

上述代码为Decoder部分。可以看到有两个mask:dec_self_attn_pad_mask(用于将dec_inputs中的P mask掉)与dec_self_attn_subsequence_mask(用于实现decoder的self attention)。这两个mask在后面会相加合并。这儿可以分别展示二者的值,其中:

dec_self_attn_pad_mask:
tensor([[[False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False]],
        [[False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False]]], device='cuda:0')#[2, 6, 6]
dec_self_attn_subsequence_mask:
tensor([[[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]],
        [[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]]], device='cuda:0', dtype=torch.uint8)#[2, 6, 6]

可以看到,dec_self_attn_pad_mask全为false,这是因为dec_input中不包含P,而dec_self_attn_subsequence_mask为上三角矩阵,对于每个token,需要mask掉它之后的token(本代码中,为1或True的位置会被mask掉)。接下来进一步追问,为什么上三角矩阵就可以mask掉该token之后的token?具体是如何实现的呢?
对于前文的Scaled Dot-Product Attention公式,代码中的表述实际为:

    def forward(self, Q, K, V, attn_mask):
        '''
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v] 全文两处用到注意力,一处是self attention,另一处是co attention,前者不必说,后者的k和v都是encoder的输出,所以k和v的形状总是相同的
        attn_mask: [batch_size, n_heads, seq_len, seq_len]
        '''
        # 1) 计算注意力分数QK^T/sqrt(d_k)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # scores: [batch_size, n_heads, len_q, len_k]
        # 2)  进行 mask 和 softmax
        # mask为True的位置会被设为-1e9
        scores.masked_fill_(attn_mask, -1e9) # 把True设为-1e9
        attn = nn.Softmax(dim=-1)(scores)  # attn: [batch_size, n_heads, len_q, len_k]
        # 3) 乘V得到最终的加权和
        context = torch.matmul(attn, V)  # context: [batch_size, n_heads, len_q, d_v], [2, 8, 5, 64]
        '''
        得出的context是每个维度(d_1-d_v)都考虑了在当前维度(这一列)当前token对所有token的注意力后更新的新的值,
        换言之每个维度d是相互独立的,每个维度考虑自己的所有token的注意力,所以可以理解成1列扩展到多列

        返回的context: [batch_size, n_heads, len_q, d_v]本质上还是batch_size个句子,
        只不过每个句子中词向量维度512被分成了8个部分,分别由8个头各自看一部分,每个头算的是整个句子(一列)的512/8=64个维度,最后按列拼接起来
        '''
        return context # context: [batch_size, n_heads, len_q, d_v]

其中,Q,K,V的维度都是[2, 8, 6, 64], score的维度为[2, 8, 6, 6],即每个token之间的注意力分数。这儿取出一个batch中的一个head下的注意力分数a为例,a的维度为[6, 6],如图所示:
在这里插入图片描述

如上图所示,在得分score中,标黄的0.71和0.24分别是S与S,以及S与I的词向量相乘得到。由于I在S后面,所以需要通过mask将其置为负无穷大,而0.71需要保留,因为是S与S在同一个位置上。因此这个mask矩阵为上三角矩阵。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1442582.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电商小程序06用户审核

目录 1 创建自定义应用2 显示待办数量3 创建审核页面4 开发审核功能5 搭建布局6 最终效果总结 上一篇我们讲解了用户注册的功能,用户注册之后状态是待审核,需要管理员进行审核。通常给管理员提供一套PC端的软件进行相关的操作,在低代码中&…

ChatGPT高效提问—prompt常见用法(续篇五)

ChatGPT高效提问—prompt常见用法(续篇五) 1.1 种子词 ​ 种子词(seed word)通常指的是在对话中使用的初始提示或关键词,用于引导ChatGPT生成相关回复。种子词可以是一个词、短语或句子,通常与对话的主题…

问题:老年人心理健康维护与促进的原则为________、________、发展原则。 #媒体#知识分享

问题:老年人心理健康维护与促进的原则为________、________、发展原则。 参考答案如图所示

肯尼斯·里科《C和指针》第12章 使用结构和指针(1)链表

只恨当时学的时候没有读到这本书,,,,,, 12.1 链表 有些读者可能还不熟悉链表,这里对它作一简单介绍。链表(linked list)就一些包含数据的独立数据结构(通常称为节点)的集…

第5章 数据库操作

学习目标 了解数据库,能够说出数据库的概念、特点和分类 熟悉Flask-SQLAlchemy的安装,能够在Flask程序中独立安装扩展包Flask-SQLAlchemy 掌握数据库的连接方式,能够通过设置配置项SQLALCHEMY_DATABASE_URI的方式连接数据库 掌握模型的定义…

来自谷歌的新年礼物!速来免费领取2个月谷歌Gemini Advanced会员!价值280元!对标ChatGPT Plus!

大家好,我是木易,一个持续关注AI领域的互联网技术产品经理,国内Top2本科,美国Top10 CS研究生,MBA。我坚信AI是普通人变强的“外挂”,所以创建了“AI信息Gap”这个公众号,专注于分享AI全维度知识…

Vue3.3新特新和Vue3-Pinia

文章目录 1.Vue3.3新特性 - defineOptionsVue3.3新特性 - defineModel3.Pinia快速入门4.手动添加Pinia到Vue项目5.Vue3 - Pinia的基本语法6.action的异步实现7.Vue3-Pinia-storeToRefs方法8.Pinia持久化插件安装用法 1.Vue3.3新特性 - defineOptions 背景说明 有<script se…

项目02《游戏-11-开发》Unity3D

基于 项目02《游戏-10-开发》Unity3D &#xff0c; 任务&#xff1a;飞行坐骑 首先创建脚本&#xff0c; 绑定脚本&#xff0c; using UnityEngine; public class Dragon : MonoBehaviour{ [SerializeField] private float speed 10f; public Transfo…

redis-sentinel(哨兵模式)

目录 1、哨兵简介:Redis Sentinel 2、作用 3、工作模式 4、主观下线和客观下线 5、配置哨兵模式 希望能够帮助到大家&#xff01;&#xff01;&#xff01; 1、哨兵简介:Redis Sentinel Sentinel(哨兵)是用于监控redis集群中Master状态的工具&#xff0c;其已经被集成在re…

【Makefile语法 01】程序编译与执行

目录 一、编译原理概述 二、编译过程分析 三、编译动静态库 四、执行过程分析 一、编译原理概述 make&#xff1a; 一个GCC工具程序&#xff0c;它会读 makefile 脚本来确定程序中的哪个部分需要编译和连接&#xff0c;然后发布必要的命令。它读出的脚本&#xff08;叫做 …

blender几何节点中样条线参数中的系数(factor)是个什么概念?

一根样条线&#xff0c;通常由两个及以上的控制点构成。 每个控制点的系数&#xff0c;其实相当于该点处位于整个样条线的比值。 如图&#xff0c;一根样条线有十一个控制点。相当于把它分成了十段&#xff0c;那每一段可以看到x、y都是0&#xff0c;唯独z每次增加0.1&#xff…

单片机项目调试中的技巧和常见问题解决

单片机是嵌入式系统中的重要组成部分&#xff0c;在各种电子设备中发挥着重要的作用。在单片机项目开发过程中&#xff0c;调试是至关重要的一环&#xff0c;同时也会遇到一些常见问题。本文将介绍一些单片机项目调试的技巧以及常见问题的解决方法&#xff0c;希望能够对单片机…

跟着cherno手搓游戏引擎【22】CameraController、Resize

前置&#xff1a; YOTO.h: #pragma once//用于YOTO APP#include "YOTO/Application.h" #include"YOTO/Layer.h" #include "YOTO/Log.h"#include"YOTO/Core/Timestep.h"#include"YOTO/Input.h" #include"YOTO/KeyCod…

【开源】JAVA+Vue+SpringBoot实现班级考勤管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 系统基础支持模块2.2 班级学生教师支持模块2.3 考勤签到管理2.4 学生请假管理 三、系统设计3.1 功能设计3.1.1 系统基础支持模块3.1.2 班级学生教师档案模块3.1.3 考勤签到管理模块3.1.4 学生请假管理模块 3.2 数据库设…

深入理解java之多线程(一)

前言&#xff1a; 本章节我们将开始学习多线程&#xff0c;多线程是一个很重要的知识点&#xff0c;他在我们实际开发中应用广泛并且基础&#xff0c;可以说掌握多线程编写程序是每一个程序员都应当必备的技能&#xff0c;很多小伙伴也会吐槽多线程比较难&#xff0c;但因为其实…

春晚刘谦第二个魔术原理讲解

目录 1. 先说一下步骤&#xff1a;2. 原理讲解&#xff1a;2.1 第一步分析2.1 第二步分析2.1 第三步分析2.1 第四步分析2.1 第五步分析2.1 第六步分析2.1 第七步分析2.1 第八步分析2.1 第七步重新分析 小结&#xff1a; 首先&#xff0c;先叠个甲。我本人很喜欢刘谦老师&#x…

H12-821_23

23.网络管理员在某台路由器上查看BGP部居信息,邻居信息如图所示,关于该信息下列说法中正确的是? A.该路由器未从对等体接收到BGP路由前缀 B.本地路由器的Router ID为1.1.1.1 C.与对等体3.3.3 3接收、发送的报文数量分别为10、20个 D.与对等体3.3.3.3之间的邻居关系为IBGP对等体…

使用Arcgis裁剪

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、掩膜提取二、随意裁剪三、裁剪 前言 因为从网站下载的是全球气候数据&#xff0c;而我们需要截取成中国部分&#xff0c;需要用到Arcgis的裁剪工具 一、掩…

批归一化(Batch Normalization,简称BN)层的作用!!

批归一化&#xff08;Batch Normalization&#xff0c;简称BN&#xff09;层在卷积神经网络中的作用主要有以下几点&#xff1a; 规范化数据&#xff1a;批归一化可以对每一批数据进行归一化处理&#xff0c;使其均值接近0&#xff0c;方差接近1。这有助于解决内部协变量偏移&…

龙年定制红包封面赠送第三波

你好啊&#xff0c;今天是大年三十&#xff0c;外面已是爆竹声声&#xff0c;年味十足。祝你&#xff0c;我亲爱的读者&#xff0c;甲辰龙年新春快乐&#xff0c;万事如意&#xff01; 最近一直有读者提出&#xff0c;希望我能再发一波红包封面。 尤其是「经典」版本的&#xf…