Transformer模型:Decoder的self-attention mask实现

news2025/1/6 18:56:28

前言

        这是对Transformer模型Word Embedding、Postion Embedding、Encoder self-attention mask、intra-attention mask内容的续篇。

视频链接:20、Transformer模型Decoder原理精讲及其PyTorch逐行实现_哔哩哔哩_bilibili

文章链接:Transformer模型:WordEmbedding实现-CSDN博客

                  Transformer模型:Postion Embedding实现-CSDN博客

                  Transformer模型:Encoder的self-attention mask实现-CSDN博客

                  Transformer模型:intra-attention mask实现-CSDN博客


 正文

        首先介绍一下Deoder的self-attention mask,它与前面的两个mask不一样地方在于Decoder是生成一个单词之后,将改单词作为输入给到Decoder中继续生成下一个,也就是相当于下三角矩阵,一次多一个,直到完成整个预测。

        先生成一个下三角矩阵:

tri_matrix = [torch.tril(torch.ones(L, L)) for L in tgt_len]

         这里生成的两个下三角矩阵的维度是不一样的,首先要统一维度:

valid_decoder_tri_matrix = [F.pad(torch.tril(torch.ones(L, L)), (0, max_tgt_seg_len-L, 0, max_tgt_seg_len-L)) for L in tgt_len]

        然后就是将它转为1个3维的张量形式,过程跟先前类似,这里就不一步步拆解了:

valid_decoder_tri_matrix = torch.cat([torch.unsqueeze(F.pad(torch.tril(torch.ones(L, L)), (0, max_tgt_seg_len-L, 0, max_tgt_seg_len-L)),0) for L in tgt_len])

        后续掩码过程还是跟前两篇一样,这里也不多解释了:

invalid_decoder_tri_matrix = 1 - valid_decoder_tri_matrix
mask_decoder_self_attention = invalid_decoder_tri_matrix.to(torch.bool)
score2 = torch.randn(batch_size, max_tgt_seg_len, max_tgt_seg_len)
mask_score3 = score2.masked_fill(mask_decoder_self_attention, -1e9)
prob3 = F.softmax(mask_score3, -1)

 代码

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# 句子数
batch_size = 2

# 单词表大小
max_num_src_words = 10
max_num_tgt_words = 10

# 序列的最大长度
max_src_seg_len = 12
max_tgt_seg_len = 12
max_position_len = 12

# 模型的维度
model_dim = 8

# 生成固定长度的序列
src_len = torch.Tensor([11, 9]).to(torch.int32)
tgt_len = torch.Tensor([10, 11]).to(torch.int32)

# 单词索引构成的句子
src_seq = torch.cat(
    [torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L,)), (0, max_src_seg_len - L)), 0) for L in src_len])
tgt_seq = torch.cat(
    [torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L,)), (0, max_tgt_seg_len - L)), 0) for L in tgt_len])

# Part1:构造Word Embedding
src_embedding_table = nn.Embedding(max_num_src_words + 1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words + 1, model_dim)
src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)

# 构造Pos序列跟i序列
pos_mat = torch.arange(max_position_len).reshape((-1, 1))
i_mat = torch.pow(10000, torch.arange(0, 8, 2) / model_dim)

# Part2:构造Position Embedding
pe_embedding_table = torch.zeros(max_position_len, model_dim)
pe_embedding_table[:, 0::2] = torch.sin(pos_mat / i_mat)
pe_embedding_table[:, 1::2] = torch.cos(pos_mat / i_mat)

pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)

# 构建位置索引
src_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len), 0) for _ in src_len]).to(torch.int32)
tgt_pos = torch.cat([torch.unsqueeze(torch.arange(max_position_len), 0) for _ in tgt_len]).to(torch.int32)

src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)

# Part3:构造encoder self-attention mask
valid_encoder_pos = torch.unsqueeze(
    torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max_src_seg_len - L)), 0) for L in src_len]), 2)
valid_encoder_pos_matrix = torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2))
invalid_encoder_pos_matrix = 1 - torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2))
mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool)
score = torch.randn(batch_size, max_src_seg_len, max_src_seg_len)
mask_score1 = score.masked_fill(mask_encoder_self_attention, -1e9)
prob1 = F.softmax(mask_score1, -1)

# Part4:构造intra-attention mask
valid_encoder_pos = torch.unsqueeze(
    torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max_src_seg_len - L)), 0) for L in src_len]), 2)
valid_decoder_pos = torch.unsqueeze(
    torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max_tgt_seg_len - L)), 0) for L in tgt_len]), 2)

valid_cross_pos_matrix = torch.bmm(valid_decoder_pos, valid_encoder_pos.transpose(1, 2))
invalid_cross_pos_matrix = 1 - valid_cross_pos_matrix
mask_cross_attention = invalid_cross_pos_matrix.to(torch.bool)
mask_score2 = score.masked_fill(mask_cross_attention, -1e9)
prob2 = F.softmax(mask_score2, -1)

# Part5:构造Decoder self-attention mask
valid_decoder_tri_matrix = torch.cat([torch.unsqueeze(F.pad(torch.tril(torch.ones(L, L)), (0, max_tgt_seg_len-L, 0, max_tgt_seg_len-L)),0) for L in tgt_len])
invalid_decoder_tri_matrix = 1 - valid_decoder_tri_matrix
mask_decoder_self_attention = invalid_decoder_tri_matrix.to(torch.bool)
score2 = torch.randn(batch_size, max_tgt_seg_len, max_tgt_seg_len)
mask_score3 = score2.masked_fill(mask_decoder_self_attention, -1e9)
prob3 = F.softmax(mask_score3, -1)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1925556.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【JVM实战篇】内存调优:内存泄露危害+内存监控工具介绍+内存泄露原因介绍

文章目录 内存调优内存溢出和内存泄漏内存泄露带来什么问题内存泄露案例演示内存泄漏的常见场景场景一场景二 解决内存溢出的方法常用内存监控工具Top命令优缺点 VisualVM软件、插件优缺点监控本地Java进程监控服务器的Java进程(生产环境不推荐使用) Art…

JavaWeb(三:JDBC 与 MVC)

JavaWeb(一:基础知识和环境搭建)https://blog.csdn.net/xpy2428507302/article/details/140365130?spm1001.2014.3001.5501JavaWeb(二:Servlet与Jsp,监听器与过滤器)https://blog.csdn.net/xpy…

Python蜂窝通信Wi-Fi和GPU变分推理及暴力哈希加密协议图消息算法

🎯要点 🎯图模型和消息传递推理算法 | 🎯消息传递推理和循环消息传递推理算法 | 🎯空间人工智能算法多维姿势估计 | 🎯超图结构解码算法量子计算 | 🎯GPU处理变分推理消息传递贝叶斯网络算法 | &#x1f3…

5G-A通感融合赋能低空经济-RedCap芯片在无人机中的应用

1. 引言 随着低空经济的迅速崛起,无人机在物流、巡检、农业等多个领域的应用日益广泛。低空飞行器的高效、安全通信成为制约低空经济发展的关键技术瓶颈。5G-A通感一体化技术通过整合通信与感知功能,为低空网络提供了强大的技术支持。本文探讨了5G-A通感…

【中国近代史】林则徐虎门销烟(1839年)

中国古代朝代历史经过两周时间(7.03-7.13)的分享已经正式结束,首先感谢大家通过那个专栏点赞收藏关注我,这是我继续创作的动力。 接下来新的专栏就是中国近代史。让我们再次走入近代史的潮流中,去学习去感受先辈们的拼…

计组_多处理器的基本概念

2024.06.26:计算机组成原理多处理器的基本概念学习笔记 第21节 多处理器的基本概念 1. 计算机体系结构1.1 SISD单指令流单数据流(前面几章一直在学习的内容)1.2 SIMD单指令流多数据流1.2.1 改进:向量处理器 1.3 MISD多指令流单数据…

应用帕累托原则学习新的编程语言

在本文中,我将讨论如何应用帕累托原则快速学习一门新的编程语言,并在加深对编程语言的理解的同时开始解决实际问题。 什么是帕累托原则? 帕累托原则,又称 80/20 法则,指出对于许多结果而言,大约 80% 的后…

数据湖仓一体(一) 编译hudi

目录 一、大数据组件版本信息 二、数据湖仓架构 三、数据湖仓组件部署规划 四、编译hudi 一、大数据组件版本信息 hudi-0.14.1zookeeper-3.5.7seatunnel-2.3.4kafka_2.12-3.5.2hadoop-3.3.5mysql-5.7.28apache-hive-3.1.3spark-3.3.1flink-1.17.2apache-dolphinscheduler-3.1.9…

我的智能辅助大师-办公小浣熊

一、基本介绍 随着2022年ChatGPT为代表的AI工具对互联网领域进行第一次冲击后,作为一名对编程领域涉足不算特别深的一名程序员,对AI大模型的接触也真的不能算少了,这是时代的必然趋势。在此之前也曾接触过很多的AI工具,他们都能在…

跨境电商系统如何进行搭建

目前越来越多的商家,他们都在进行跨境电商建站,便于自己在网络上进行营销推广,跨境电商系统的搭建是至关重要的,商家应该先了解跨境电商的模式有哪些,这样才能对跨境电商系统有更好的搭建结果。 跨境电商模式 目前来…

1589. 【中山市第十二届义务教育段学生信息学邀请赛】象战(bishop)(Standard IO)

题目描述 国际象棋的棋盘可以表示为一个 8 行 8 列的格子图,其中每个格子都可以放一枚棋子。我们将第 1 行第 2 列的格子用 (1,2) 来表示,以此类推。 为了帮助妹妹认识国际象棋中的“象”这种棋子,Jimmy 可谓是煞费苦心——他首先教会妹妹&…

【Java】数据类型及类型转换

数据类型 Java语言的数据类型分为两大类: 基础数据类型引用数据类型 基础数据类型 基础数据类型包括以下8种: 类型名称关键字占用内存取值范围区间描述字节型byte1 字节-128~127-27~27-1短整型short2 字节-32768~32767-215~215-1整型int4 字节-2147…

算法学习笔记(8.3)-(0-1背包问题)

目录 最常见的0-1背包问题: 第一步:思考每轮的决策,定义状态,从而得到dp表 第二步:找出最优子结构,进而推导出状态转移方程 第三步:确定边界条件和状态转移顺序 方法一:暴力搜素…

经典元启发式算法的适用范围及基本思想

元启发式算法是针对优化问题设计的一类高级算法,它们具有广泛的适用性,可以解决不同类型的问题。不同的元启发式算法由于其特定的搜索机制和策略,适用的优化问题类型也有所不同。以下是一些常见元启发式算法及其适用范围: 1. 遗传…

OrangePi AIpro 浅上手及搭建卡通图像生成多元化AI服务

前言 很高兴,收到了一份新款 OrangePi AIpro 开发板,这是香橙派第一次与华为昇腾合作,使用昇腾系列 AI 处理器来设计这款高性价比的 AI 开发板。这块开发板不仅性能强大,还支持丰富的硬件接口,为AI开发者提供了一个理…

Nginx的访问限制与访问控制

访问限制 访问限制是一种防止恶意访问的常用手段,可以指定同一IP地址在固定时间内的访问次数,或者指定同一IP地址在固定时间内建立连接的次数,若超过网站指定的次数访问将不成功。 请求频率限制配置 请求频率限制是限制客户端固定时间内发…

代码随想录第十二天|226.翻转二叉树、 101.对称二叉树、104.二叉树的最大深度、111.二叉树的最小深度

文章目录 226.翻转二叉树思路解法一、前序遍历递归解法二、深度优先搜索--迭代 101.对称二叉树思路解法一、递归解法二、迭代 104.二叉树的最大深度解法一、深度优先搜索--递归解法二、广度优先搜索--迭代 111.二叉树的最小深度解法一、深度优先搜索--递归解法二、广度优先搜索…

制作显卡版docker并配置TensorTR环境

感谢阅读 相关概念docker准备下载一个自己电脑cuda匹配的docker镜像拉取以及启动镜像安装cudaTensorRT部署教程 相关概念 TensorRT是可以在NVIDIA各种GPU硬件平台下运行的一个模型推理框架,支持C和Python推理。即我们利用Pytorch,Tensorflow或者其它框架…

2、matlab打开显示保存点云文件(.ply/.pcd)以及经典点云模型数据

1、点云数据简介 点云数据是三维空间中由大量二维点坐标组成的数据集合。每个点代表空间中的一个坐标点,可以包含有关该点的颜色、法向量、强度值等额外信息。点云数据可以通过激光扫描、结构光扫描、摄像机捕捉等方式获取,广泛应用于计算机视觉、机器人…

常用控件(六)

布局管理器 布局管理器垂直布局QHBoxLayoutQGridLayoutQFormLayoutQSpacerItem 布局管理器 之前使⽤ Qt 在界⾯上创建的控件, 都是通过 “绝对定位” 的⽅式来设定的. 也就是每个控件所在的位置, 都需要计算坐标, 最终通过 setGeometry 或者 move ⽅式摆放过去.这种设定⽅式其…