手撕Transformer -- Day6 -- DecoderBlock

news2025/1/16 20:54:45

手撕Transformer – Day6 – DecoderBlock

目录

  • 手撕Transformer -- Day6 -- DecoderBlock
    • Transformer 网络结构图
    • DecoderBlock 代码
      • Part1 库函数
      • Part2 实现一个解码器Block,作为一个类
      • Part3 测试
    • 参考

Transformer 网络结构图

在这里插入图片描述

Transformer 网络结构

DecoderBlock 代码

Part1 库函数

# 这个是解码器的block,和编码器来说多了一个掩码注意力机制,但是其实就是把掩码换一下即可,同时还对于第二个多头注意力机制的k_v和q不同源了
# 主要构成要素,输入嵌入好的句子,经过1.掩码注意力机制+残差归一化 2. 交叉注意力+残差归一化 3. 前向+残差归一化。保证输入输出同纬度(batch_size,seq_len,emding)
'''
# Part1 引入库函数
'''
import torch
from torch import nn
from multihead_attn import MultiHeadAttention
# 应该是用于测试
from dataset import train_dataset,de_preprocess,de_vocab,en_preprocess,en_vocab,PAD_IDX
from emb import EmbeddingWithPosition
from encoder import Encoder

Part2 实现一个解码器Block,作为一个类

'''
# Part2 写个类,实现EncoderBlock
'''
class DecoderBlock(nn.Module):
    def __init__(self,head,emd_size,q_k_size,v_size,f_size):
        super().__init__()
        # 首先要进行掩码多头注意力机制
        self.mask_multi_atten=MultiHeadAttention(head=head,emd_size=emd_size,q_k_size=q_k_size,v_size=v_size)
        self.linear1=nn.Linear(head*v_size,emd_size)
        # 归一化(填写的是最后一个的那个维度大小)
        self.norm1=nn.LayerNorm(emd_size)

        # 交叉注意力机制
        self.cross_multi_atten=MultiHeadAttention(head=head,emd_size=emd_size,q_k_size=q_k_size,v_size=v_size)
        self.linear2 = nn.Linear(head * v_size, emd_size)
        # 归一化(填写的是最后一个的那个维度大小)
        self.norm2 = nn.LayerNorm(emd_size)

        # 前向
        self.feedforward=nn.Sequential(
            nn.Linear(emd_size,f_size),
            nn.ReLU(),
            nn.Linear(f_size, emd_size)
        )
        self.norm3 = nn.LayerNorm(emd_size)
    def forward(self, x, encoder_z, mask_1, mask_2): # x(batch_size,q_seq_len,emd_size)
        # 掩码注意力机制
        z1=self.mask_multi_atten(x_q=x, x_k_v=x, mask_pad=mask_1) # (batch_size,q_seq_len,head*v_size)
        z1=self.linear1(z1) # (batch_size,q_seq_len,emd_size)
        # 第一个残差归一化,得到第一层的输出output
        outpu1=self.norm1(z1+x) # (batch_size,q_seq_len,emd_size)

        # 交叉注意力机制,把output作为q,编码器作为k_v
        z2=self.cross_multi_atten(x_q=outpu1, x_k_v=encoder_z, mask_pad=mask_2) # (batch_size,q_seq_len,head*v_size)
        # 第二个残差归一化
        z2 = self.linear1(z2) # (batch_size,q_seq_len,emd_size)
        output2=self.norm2(z2+outpu1) # (batch_size,q_seq_len,emd_size)

        # 前向
        z3=self.feedforward(output2) # (batch_size,q_seq_len,emd_size)
        # 第三个残差归一化
        output3 = self.norm3(z3 + output2) # (batch_size,q_seq_len,emd_size)
        return output3

Part3 测试

if __name__ == '__main__':
    # 取2个de句子转词ID序列,输入给encoder
    de_tokens1, de_ids1 = de_preprocess(train_dataset[0][0])
    de_tokens2, de_ids2 = de_preprocess(train_dataset[1][0])
    # 对应2个en句子转词ID序列,再做embedding,输入给decoder
    en_tokens1, en_ids1 = en_preprocess(train_dataset[0][1])
    en_tokens2, en_ids2 = en_preprocess(train_dataset[1][1])

    # de句子组成batch并padding对齐
    if len(de_ids1) < len(de_ids2):
        de_ids1.extend([PAD_IDX] * (len(de_ids2) - len(de_ids1)))
    elif len(de_ids1) > len(de_ids2):
        de_ids2.extend([PAD_IDX] * (len(de_ids1) - len(de_ids2)))

    enc_x_batch = torch.tensor([de_ids1, de_ids2], dtype=torch.long)
    print('enc_x_batch batch:', enc_x_batch.size())

    # en句子组成batch并padding对齐
    if len(en_ids1) < len(en_ids2):
        en_ids1.extend([PAD_IDX] * (len(en_ids2) - len(en_ids1)))
    elif len(en_ids1) > len(en_ids2):
        en_ids2.extend([PAD_IDX] * (len(en_ids1) - len(en_ids2)))

    dec_x_batch = torch.tensor([en_ids1, en_ids2], dtype=torch.long)
    print('dec_x_batch batch:', dec_x_batch.size())

    # Encoder编码,输出每个词的编码向量
    enc = Encoder(vocab_size=len(de_vocab), emd_size=128, q_k_size=256, v_size=512, f_size=512, head=8, nums_encoderblock=3)
    enc_outputs = enc(enc_x_batch)
    print('encoder outputs:', enc_outputs.size())

    # 生成decoder所需的掩码
    first_attn_mask = (dec_x_batch == PAD_IDX).unsqueeze(1).expand(dec_x_batch.size()[0], dec_x_batch.size()[1],
                                                                   dec_x_batch.size()[1])  # 目标序列的pad掩码
    first_attn_mask = first_attn_mask | torch.triu(torch.ones(dec_x_batch.size()[1], dec_x_batch.size()[1]),
                                                   diagonal=1).bool().unsqueeze(0).expand(dec_x_batch.size()[0], -1,
                                                                                          -1) # &目标序列的向后看掩码
    print('first_attn_mask:', first_attn_mask.size())
    # 根据来源序列的pad掩码,遮盖decoder每个Q对encoder输出K的注意力
    second_attn_mask = (enc_x_batch == PAD_IDX).unsqueeze(1).expand(enc_x_batch.size()[0], dec_x_batch.size()[1],
                                                                    enc_x_batch.size()[
                                                                        1])  # (batch_size,target_len,src_len)
    print('second_attn_mask:', second_attn_mask.size())

    first_attn_mask = first_attn_mask
    second_attn_mask = second_attn_mask

    # Decoder输入做emb先
    emb = EmbeddingWithPosition(len(en_vocab), 128)
    dec_x_emb_batch = emb(dec_x_batch)
    print('dec_x_emb_batch:', dec_x_emb_batch.size())

    # 5个Decoder block堆叠
    decoder_blocks = []
    for i in range(5):
        decoder_blocks.append(DecoderBlock(emd_size=128, q_k_size=256, v_size=512, f_size=512, head=8))

    for i in range(5):
        dec_x_emb_batch = decoder_blocks[i](dec_x_emb_batch, enc_outputs, first_attn_mask, second_attn_mask)
    print('decoder_outputs:', dec_x_emb_batch.size())

参考

视频讲解:transformer-带位置信息的词嵌入向量_哔哩哔哩_bilibili

github代码库:github.com

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2277699.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【功能测试总结】

功能测试 1. 功能测试用例1.1 设计用例容易出现的问题 2. 如何写用例2.1 什么是好的用例2.2 测试用例设计常见方法 3. 用例分级 1. 功能测试用例 1.1 设计用例容易出现的问题 基础功能点用例覆盖不全/描述不清 描述不清 什么是正常内容&#xff0c;仅看用例能否知道该输入什么…

Mac玩Steam游戏秘籍!

Mac玩Steam游戏秘籍&#xff01; 大家好&#xff01;最近有不少朋友在用MacBook玩Steam游戏时遇到不支持mac的问题。别担心&#xff0c;我来教你如何用第三方工具Crossover来畅玩这些不支持的游戏&#xff0c;简单又实用&#xff01; 第一步&#xff1a;下载Crossover 首先&…

基于Springboot + vue实现的旅游网站

&#x1f942;(❁◡❁)您的点赞&#x1f44d;➕评论&#x1f4dd;➕收藏⭐是作者创作的最大动力&#x1f91e; &#x1f496;&#x1f4d5;&#x1f389;&#x1f525; 支持我&#xff1a;点赞&#x1f44d;收藏⭐️留言&#x1f4dd;欢迎留言讨论 &#x1f525;&#x1f525;&…

题解 CodeForces 430B Balls Game 栈 C/C++

题目传送门&#xff1a; Problem - B - Codeforceshttps://mirror.codeforces.com/contest/430/problem/B翻译&#xff1a; Iahub正在为国际信息学奥林匹克竞赛&#xff08;IOI&#xff09;做准备。有什么比玩一个类似祖玛的游戏更好的训练方法呢&#xff1f; 一排中有n个球…

Vue3播放视频报ReferenceError: SharedArrayBuffer is not defined

解决办法 前端本地测试vue.config.js server: {headers: {"Cross-Origin-Opener-Policy": "same-origin","Cross-Origin-Embedder-Policy": "require-corp",}, }, 后端vue.js生产环境 跨域隔离 是一种现代Web安全策略&#xff0c;…

Android BottomNavigationView不加icon使text垂直居中,完美解决。

这个问题网上千篇一律的设置iconsize为0&#xff0c;labale固定什么的&#xff0c;都没有效果。我的这个基本上所有人用都会有效果。 问题解决之前的效果&#xff1a;垂直方向&#xff0c;文本不居中&#xff0c;看着很难受 问题解决之后&#xff1a;舒服多了 其实很简单&…

微调神经机器翻译模型全流程

MBART: Multilingual Denoising Pre-training for Neural Machine Translation 模型下载 mBART 是一个基于序列到序列的去噪自编码器&#xff0c;使用 BART 目标在多种语言的大规模单语语料库上进行预训练。mBART 是首批通过去噪完整文本在多种语言上预训练序列到序列模型的方…

基于32QAM的载波同步和定时同步性能仿真,包括Costas环的gardner环

目录 1.算法仿真效果 2.算法涉及理论知识概要 3.MATLAB核心程序 4.完整算法代码文件获得 1.算法仿真效果 matlab2022a仿真结果如下&#xff08;完整代码运行后无水印&#xff09;&#xff1a; 仿真操作步骤可参考程序配套的操作视频。 2.算法涉及理论知识概要 载波同步是…

设计模式-工厂模式/抽象工厂模式

工厂模式 定义 定义一个创建对象的接口&#xff0c;让子类决定实列化哪一个类&#xff0c;工厂模式使一个类的实例化延迟到其子类&#xff1b; 工厂方法模式是简单工厂模式的延伸。在工厂方法模式中&#xff0c;核心工厂类不在负责产品的创建&#xff0c;而是将具体的创建工作…

【机器学习】零售行业的智慧升级:机器学习驱动的精准营销与库存管理

我的个人主页 我的领域&#xff1a;人工智能篇&#xff0c;希望能帮助到大家&#xff01;&#xff01;&#xff01;&#x1f44d;点赞 收藏❤ 在当今数字化浪潮汹涌澎湃的时代&#xff0c;零售行业正站在转型升级的十字路口。市场竞争的白热化使得企业必须另辟蹊径&#xff0…

day_2_排序算法和树

文章目录 排序算法和树排序算法算法稳定性排序算法☆ 冒泡排序冒泡思路冒泡步骤代码实现效率优化 ☆ 选择排序排序思路排序步骤代码实现 ... 树01-树的基本概念02-树的相关术语03-二叉树的种类04-二叉树的存储05-树的应用场景_数据库索引06-二叉树的概念和性质07-广度优先遍历0…

蓝桥杯刷题第二天——背包问题

题目描述 有N件物品和一个容量是V的背包。每件物品只能使用一次。第i件物品的体积是Vi价值是Wi。 求解将哪些物品装入背包&#xff0c;可使这些物品的总体积不超过背包容量&#xff0c;且总价值最大。 输出最大价值。 输入格式 第一行两个整数&#xff0c;N&#xff0c;V&am…

Linux x86_64 程序动态链接之GOT 和 PLT

文章目录 前言一、动态链接二、位置无关代码三、GOT 和 PLT3.1 GOT3.2 PLT3.3 延时绑定3.4 示例 四、demo演示五、延迟绑定技术和代码修补参考资料 前言 这篇文章描述了&#xff1a;Linux x86_64 程序静态链接之重定位&#xff0c;接来本文描述Linux x86_64 程序动态链接之GOT…

学习记录-责任链模式验证参数

学习记录-责任链模式验证参数 1.什么是责任链模式 责任链模式&#xff08;Chain of Responsibility Pattern&#xff09;是一种行为设计模式&#xff0c;它允许将请求沿着一个处理链传递&#xff0c;直到链中的某个对象处理它。这样&#xff0c;发送者无需知道哪个对象将处理…

练习:MySQL单表查询与多表查询

一.单表查询 创建worke数据库&#xff0c;在数据库底下创建worker表 mysql> create database worke; Query OK, 1 row affected (0.00 sec)mysql> show databases; -------------------- | Database | -------------------- | information_schema | | mysql …

HarmonyOS NEXT应用开发边学边玩系列:从零实现一影视APP (四、最近上映电影滚动展示及加载更多的实现)

在HarmonyOS NEXT开发环境中&#xff0c;可以使用多种组件和库来构建丰富且交互友好的应用。本文将展示如何使用HarmonyOS NEXT框架和nutpi/axios库&#xff0c;从零开始实现一个简单的影视APP的首页&#xff0c;主要关注最近上映电影的滚动展示及加载更多功能的实现。 开源项目…

卷积神经05-GAN对抗神经网络

卷积神经05-GAN对抗神经网络 使用Python3.9CUDA11.8Pytorch实现一个CNN优化版的对抗神经网络 简单的GAN图片生成 CNN优化后的图片生成 优化模型代码对比 0-核心逻辑脉络 1&#xff09;Anacanda使用CUDAPytorch2&#xff09;使用本地MNIST进行手写图片训练3&#xff09;…

客户案例:某家居制造企业跨境电商,解决业务端(亚马逊平台)、易仓ERP与财务端(金蝶ERP)系统间的业务财务数据对账互通

一、系统定义 1、系统定位&#xff1a; 数据中台系统是一种战略选择和组织形式&#xff0c;通过有型的产品支撑和实施方法论&#xff0c;解决企业面临的数据孤岛、数据维护混乱、数据价值利用低的问题&#xff0c;依据企业特有的业务和架构&#xff0c;构建一套从数据汇聚、开…

服务器一次性部署One API + ChatGPT-Next-Web

服务器一次性部署One API ChatGPT-Next-Web One API ChatGPT-Next-Web 介绍One APIChatGPT-Next-Web docker-compose 部署One API ChatGPT-Next-WebOpen API docker-compose 配置ChatGPT-Next-Web docker-compose 配置docker-compose 启动容器 后续配置 同步发布在个人笔记服…

辅助云运维

为客户提供运维支持&#xff0c;保障业务连续性。 文章目录 一、服务范围二、服务内容三、服务流程四、 服务交付件五、责任分工六、 完成标志 一、服务范围 覆盖范围 云产品使用咨询、问题处理、配置指导等&#xff1b; 云产品相关操作的技术指导&#xff1b; 云相关资源日常…