【Transformer从零开始代码实现 pytoch版】(三)Decoder编码器组件:多头自注意力+多头注意力+全连接层+规范化层

news2024/12/23 7:58:08

解码器组件

在这里插入图片描述

解码器部分:

  • 由N个解码器层堆叠而成
  • 每个解码器层由三个子层连接结构组成
  • 第一个子层连接结构包括一个多头自注意力子层和规范化层以及一个残差连接
  • 第二个子层连接结构包括一个多头注意力子层和规范化层以及一个残差连接
  • 第三个子层连接结构包括一个前馈全连接子层和规范化层以及一个残差连接

解码器层的作用:
作为解码器的组成单元,每个解码器层根据给定的输入向目标方向进行特征提取操作,即解码过程。

解码器层代码

解码器曾主要由三个子层组成,这里面三个子层还用之前构建Encoder时的代码,详情请看:【Transformer从零开始代码实现 pytoch版】(二)Encoder编码器组件:mask + attention + feed forward + add&norm

class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        """

        :param size: 词嵌入维度
        :param self_attn: 多头自注意力层 Q=K=V
        :param src_attn: 多头注意力层 Q!=K=V
        :param feed_forward: 前馈全连接层
        :param dropout: 置0比率
        """
        super(DecoderLayer, self).__init__()

        # 传参到类中
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.dropout = dropout

        # 按照解码器层的结构图,使用clones函数克隆3个子层连接对象
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, source_mask, target_mask):
        """构建出三个子层:多头自注意力子层、普通的多头注意力子层、前馈全连接层

        :param x: 上一层输入的张量
        :param memory: 编码器的语义存储张量(K=V)
        :param source_mask: 源数据的掩码张量
        :param target_mask: 目标数据的掩码张量
        :return:一层解码器的解码输出
        """
        m = memory

        # 第一步,让x进入第一个子层(多头自注意力机制的子层)
        # 采用target_mask,将解码时未来的信息进行遮掩。
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, target_mask))

        # 第二步,让x进入第二个子层(常规多头注意力机制的子层,Q!=K=V)
        # 采用source_mask,遮掩掉已经判定出来的对结果信息无用的数据(减少对无用信息的关注),提升计算效率
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, source_mask))

        # 第三步,让x进入第三个子层(前馈全连接层)
        return self.sublayer[2](x, self.feed_forward)

示例

# 定义参数
size = d_model = 512
head = 8
d_ff = 64
dropout = 0.2

self_attn = src_attn = MultiHeadedAttention(head, d_model, dropout)     # 定义多头注意力层
ff = PositionwiseFeedForward(d_model, d_ff, dropout)                    # 定义前馈全连接层
x = pe_res
memory = enc_res    # 将之前编码器实例中的enc_res结果赋值给memory作为K和V
mask = torch.zeros(2, 4, 4)
source_mask = target_mask = mask    # 简单示范,都先给同样的mask

dl = DecoderLayer(size, self_attn, src_attn, ff, dropout)
dl_res = dl(x, memory, source_mask, target_mask)
print(f"dl_res: {dl_res}\n shape:{dl_res.shape}")


dl_res: tensor([[[-2.7233e+01,  3.7782e+01,  1.7257e+01,  ...,  1.2275e+01,
          -4.7017e+01,  1.7687e+01],
         [-2.6276e+01,  1.4660e-01,  5.5642e-02,  ..., -2.5157e+01,
          -2.8655e+01, -3.8758e+01],
         [ 1.0419e+00, -2.7726e+01, -2.3628e+01,  ..., -7.7137e+00,
          -5.7320e+01,  4.6977e+01],
         [-3.3436e+01,  3.2082e+01, -1.6754e+01,  ..., -2.5161e-01,
          -4.0380e+01,  4.7144e+01]],
        [[-5.3706e+00, -2.4270e+01,  2.1009e+01,  ...,  6.5833e+00,
          -4.3054e+01,  2.5535e+01],
         [ 3.1999e+01, -8.3981e+00, -5.6480e+00,  ...,  3.1037e+00,
           2.1093e+01,  3.0293e+00],
         [ 5.5799e+00,  1.0306e+01, -2.0165e+00,  ...,  3.8163e+00,
           4.0567e+01, -1.2256e+00],
         [-3.6323e+01, -1.4260e+01,  3.3353e-02,  ..., -9.4611e+00,
          -1.6435e-01, -3.5157e+01]]], grad_fn=<AddBackward0>)
 shape:torch.Size([2, 4, 512])

对比下面编码器的编码结果:

enc_res: tensor([[[-0.9458,  1.4723,  0.6997,  ...,  0.6569, -1.9873,  0.7674],
         [-0.9278,  0.0055, -0.0309,  ..., -1.2925, -1.2145, -1.6950],
         [ 0.1456, -1.1068, -0.8927,  ..., -0.2079, -2.2481,  1.8858],
         [-1.2406,  1.3828, -0.8069,  ...,  0.1041, -1.5828,  1.9792]],
        [[-0.1922, -1.1158,  0.7787,  ...,  0.2102, -1.7763,  1.1359],
         [ 1.4014, -0.3193, -0.3572,  ..., -0.0428,  0.7563,  0.1116],
         [ 0.3749,  0.4738, -0.0470,  ...,  0.1295,  1.8679,  0.0937],
         [-1.5545, -0.5667, -0.0432,  ..., -0.6391, -0.0121, -1.4567]]],
       grad_fn=<AddBackward0>)

原数据的掩码张量存在意义:
掩码原数据中,关联性弱的数据,不让注意力计算分散,提升计算效率。

解码器代码

N个解码器层构成一个解码器

class Decoder(nn.Module):
    def __init__(self, layer, N):
        """ 确定解码器层和层数

        :param layer: 解码器层
        :param N: 解码器层的个数
        """
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)          # 使用clones函数克隆N个类
        self.norm = LayerNorm(layer.size)        # 实例化规范化层

    def forward(self, x, memory, source_mask, target_mask):
        """ 循环构建解码器,经过规范化层后输出

        :param x:目标数据的嵌入表示
        :param memory:解码器层的输出QV
        :param source_mask:源数据掩码张量
        :param target_mask:目标数据掩码张量
        :return:经过规范化后的解码器
        """
        for layer in self.layers:
            x = layer(x, memory, source_mask, target_mask)

        return self.norm(x)

示例

size = d_model = 512
head = 8
d_ff =64
dropout = 0.2
c = copy.deepcopy
attn = MultiHeadedAttention(head, d_model)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
layer = DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout)     # 第一个attn作为自注意力机制,第二个attn作为注意力机制

N = 8
x = pe_res
memory = enc_res
mask = torch.zeros(2, 4, 4)
source_mask = target_mask = mask

de = Decoder(layer, N)      # 实例化解码器
de_res = de(x, memory, source_mask, target_mask)
print(f"de_res: {de_res}\n shape: {de_res.shape}")


de_res: tensor([[[-0.7714,  0.1066,  1.8197,  ..., -0.1137,  0.2005,  0.5856],
         [-0.9215, -0.9844, -0.4962,  ..., -0.1074,  0.4848,  0.3493],
         [-2.2495,  0.0859, -0.7644,  ..., -0.0679, -0.7270, -1.3438],
         [-0.4822,  0.2821,  1.0786,  ..., -1.9442,  0.8834, -1.1757]],
        [[-0.2491, -0.6117,  0.7908,  ..., -2.1624,  0.6212,  0.6190],
         [-0.3938, -0.5203,  0.6412,  ..., -0.8679,  0.8462,  0.3037],
         [-1.0217, -1.0685, -0.5138,  ...,  1.2010,  2.0795, -0.0143],
         [-0.2919, -0.5916,  1.5231,  ..., -0.1215,  0.7127, -0.0586]]],
       grad_fn=<AddBackward0>)
 shape: torch.Size([2, 4, 512])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1204819.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

牛客网上收藏题目总结及重写(C语言)(3)

每日一言 如果预计中的不幸没有发生的话&#xff0c;我们就会收获意外的喜悦。 --人生的智慧 题目BC84 错因&#xff1a;忘记要使用小数除法 代码 #include <stdio.h> int main() {int i 0;int n 0;scanf("%d",&n);double sum 0;for(i1;i<n;i){su…

分享一些有趣的MATLAB提示音(代码可直接复制)

先做一个声明&#xff1a;文章是由我的个人公众号中的推送直接复制粘贴而来&#xff0c;因此对智能优化算法感兴趣的朋友&#xff0c;可关注我的个人公众号&#xff1a;启发式算法讨论。我会不定期在公众号里分享不同的智能优化算法&#xff0c;经典的&#xff0c;或者是近几年…

联想笔记本Fn + A可以全选,Ctrl失效

问题&#xff1a;联想笔记本Fn A可以全选&#xff0c;ctrl失效。 原因&#xff1a;BIOS启用了Fn键和Ctrl键互换。 解决操作&#xff1a; 1.开机时一直按F2&#xff0c;进入BIOS 2.点击More Settings > 2.选取Configuration 3.将Fool Proof Fn Ctrl 设定变更为Disabled 4.按…

【算法与数据结构】491、LeetCode递增子序列

文章目录 一、题目二、解法三、完整代码 所有的LeetCode题解索引&#xff0c;可以看这篇文章——【算法和数据结构】LeetCode题解。 一、题目 二、解法 思路分析&#xff1a;本题和【算法与数据结构】78、90、LeetCode子集I&#xff0c; II中90.子集II问题有些类似&#xff0c;…

(四)七种元启发算法(DBO、LO、SWO、COA、LSO、KOA、GRO)求解无人机路径规划MATLAB

一、七种算法&#xff08;DBO、LO、SWO、COA、LSO、KOA、GRO&#xff09;简介 1、蜣螂优化算法DBO 蜣螂优化算法&#xff08;Dung beetle optimizer&#xff0c;DBO&#xff09;由Jiankai Xue和Bo Shen于2022年提出&#xff0c;该算法主要受蜣螂的滚球、跳舞、觅食、偷窃和繁殖…

Outlook如何精准搜索邮件

说明&#xff1a; 使用Outlook默认的搜索时&#xff0c;会出来很多无关的信息&#xff0c;对搜索邮件带来很大的不便&#xff0c;下面介绍一个使用精准搜索的方法。 操作指引&#xff1a; 1、在outlook左上角&#xff0c;进行如下操作&#xff0c;打开“其他命令” 2、打开快…

UBoot

uboot是什么&#xff1f; 嵌入式linux系统启动过程 嵌入式系统上电后先执行uboot、然后uboot负责初始化DDR&#xff0c;初始化Flash&#xff0c;然后将OS从Flash中读取到DDR中&#xff0c;然后启动OS&#xff08;OS启动后uboot就无用了&#xff09;uboot是什么&#xff0c;ubo…

【Java】集合(二)Set

1.Set接口基本介绍 无序:存取顺序不一致不重复:可以去除重复无索引:没有带索引的方法&#xff0c;所以不能使用普通for循环遍历&#xff0c;也不能通过索引来获取元素 2.Set集合的实现类 HashSet:无序、不重复、无索引LinkedHashSet: 有序、不重复、无索引TreeSet: 可排序、不…

3、Linux库的生成和使用(核心代码是程序员不可公开的小秘密)

目录 Linux库的概念 Linux 静态库 Linux 静态库作用 Linux 静态库的创建 1. 将.c文件生成.o文件 ​编辑 2. 将所有的.o文件归档为一个静态库.a文件 Linux 静态库的使用 Linux 动态库&#xff1a; Linux 动态库作用 Linux 动态库的创建 生成.so动态库文件 ​编辑 …

推荐一份适合所有人做的副业,尤其是程序员。

我建议每个人都去尝试一下网上接单&#xff0c;这是一个门槛低、类型多样的方式&#xff0c;尤其适合程序员&#xff01; 在接单平台上&#xff0c;你可以看到各种类型的兼职。以freelancer为例&#xff0c;你可以在这里找到技术、设计、写作等类型的兼职&#xff0c;只要发挥…

广告算法资料汇总【建设中】

业内大佬 阿里妈妈技术 张俊林 王喆 萧瑟 朱小强 综合 付海军&#xff1a;基于互联网广告发展演变和思考&#xff08;附视频讲解PPT&#xff09; 广告算法工程师入门_广告与算法的博客-CSDN博客 广告算法学习笔记 20万、50万、100万的算法工程师&#xff0c;到底有什么区别…

EtherCAT转Modbus网关的 EtherCAT从站配置案例

兴达易控EtherCAT转Modbus网关&#xff08;XD-MDEC20 &#xff09;是一款具备ETHERCAT从站功能的通讯网关&#xff0c;其主要作用是将ETHERCAT网络和MODBUS-RTU网络连接起来。该网关可作为ETHERCAT总线中的从站使用&#xff0c;同时也能够连接到MODBUS-RTU总线中&#xff0c;作…

upload-labs关卡4(黑名单点空格绕过或htaccess绕过)通关思路

文章目录 前言一、回顾上一关知识点二、靶场第四关方法一通关思路1.看源码2、点空格绕过 三、靶场第四关方法二通关思路1、htaccess文件是什么2、通过上传htaccess文件进行绕过1、使用前提2、上传htaccess文件&#xff0c;然后再上传phpinfo的jpg文件 总结 前言 此文章只用于学…

轻量封装WebGPU渲染系统示例<29>- 深度模糊DepthBlur(源码)

当前示例源码github地址: https://github.com/vilyLei/voxwebgpu/blob/feature/rendering/src/voxgpu/sample/DepthBlur.ts 当前示例运行效果: 此示例基于此渲染系统实现&#xff0c;当前示例TypeScript源码如下: const blurRTTTex0 { diffuse: { uuid: "rtt0", …

数据结构-堆排序及其复杂度计算

目录 1.堆排序 1.1 向上调整建堆 1.2 向下调整建堆 2. 两种建堆方式的时间复杂度比较 2.1 向下调整建堆的时间复杂度 2.2 向上调整建堆的时间复杂度 Topk问题 上节内容&#xff0c;我们讲了堆的实现&#xff0c;同时还包含了向上调整法和向下调整法&#xff0c;最后我们…

为什么要安装田间气象站?

随着农业科技的发展&#xff0c;越来越多的农民朋友开始关注如何利用科技手段来提高农业生产效益。其中&#xff0c;安装田间气象站成为了许多农民朋友的选择之一&#xff0c;为什么会有这种情况呢&#xff1f;安装田间气象站会带来哪些优势呢&#xff1f; 一、了解气候变化 气…

Vue3问题:如何实现页面引导提示?

前端功能问题系列文章&#xff0c;点击上方合集↑ 序言 大家好&#xff0c;我是大澈&#xff01; 本文约1700字&#xff0c;整篇阅读大约需要3分钟。 本文主要内容分三部分&#xff0c;第一部分是需求分析&#xff0c;第二部分是实现步骤&#xff0c;第三部分是问题详解。 …

No194.精选前端面试题,享受每天的挑战和学习

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云课上架的前后端实战课程《Vue.js 和 Egg.js 开发企业级健康管理项目》、《带你从入…

C语言--假设共有鸡、兔30只,脚90只,求鸡、兔各有多少只​

一.题目描述 假设共有鸡、兔30只&#xff0c;脚90只&#xff0c;求鸡、兔各有多少只&#xff1f; 二.思路分析 本题是一个典型的穷举法例题&#xff0c;而穷举法&#xff0c;最重要的就是条件判断。⭐⭐ 本题中的条件很容易发现&#xff1a; 假设鸡有x只&#xff0c;兔有y只…

【C++类和对象下:解锁面向对象编程的奇妙世界】

【本节目标】 1. 再谈构造函数 2. Static成员 3. 友元 4. 内部类 5.匿名对象 6.拷贝对象时的一些编译器优化 7. 再次理解封装 1. 再谈构造函数 1.1 构造函数体赋值 在创建对象时&#xff0c;编译器通过调用构造函数&#xff0c;给对象中各个成员变量一个合适的初始值。…