【Transformer从零开始代码实现 pytoch版】(五)总架构类的实现

news2025/1/23 14:52:24

Transformer总架构

在这里插入图片描述
在实现完输入部分、编码器、解码器和输出部分之后,就可以封装各个部件为一个完整的实体类了。

【Transformer从零开始代码实现 pytoch版】(一)输入部件:embedding+positionalEncoding

【Transformer从零开始代码实现 pytoch版】(二)Encoder编码器组件:mask + attention + feed forward + add&norm

【Transformer从零开始代码实现 pytoch版】(三)Decoder编码器组件:多头自注意力+多头注意力+全连接层+规范化层

【Transformer从零开始代码实现 pytoch版】(四)输出部件:Linear+softmax

编码器-解码器总结构代码实现

class EncoderDecoder(nn.Module):
    """ 编码器解码器架构实现、定义了初始化、forward、encode和decode部件
    """
    def __init__(self, encoder, decoder, source_embed, target_embed, generator):
        """ 传入五大部件参数

        :param encoder: 编码器
        :param decoder: 解码器
        :param source_embed: 源数据embedding函数
        :param target_embed: 目标数据embedding函数
        :param generator: 输出部分类被生成器对象
        """
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = source_embed
        self.tgt_embed = target_embed
        self.generator = generator					# 生成器后面会专门用到

    def forward(self, source, target, source_mask, target_mask):
        """ 构建数据流入流出

        :param source: 源数据
        :param target: 目标数据
        :param source_mask: 源数据掩码张量
        :param target_mask: 目标数据掩码张量
        :return:
        """
        # 注意这里先用的encode和decode函数,又才在其函数里面,再用了encoder和decoder
        return self.decode(self.encode(source, source_mask), source_mask, target, target_mask)

    def encode(self, source, source_mask):
        """ 编码函数,编码部件

        :param source: 源数据张量
        :param source_mask: 源数据的掩码张量
        :return: 经过解码器的输出
        """
        return self.encoder(self.src_embed(source), source_mask)

    def decode(self, memory, source_mask, target, target_mask):
        """ 解码函数,解码部件

        :param memory:编码器的输出QV
        :param source_mask:源数据的掩码张量
        :param target:目标数据
        :param target_mask:目标数据的掩码张量
        :return:
        """
        return self.decoder(self.tgt_embed(target), memory, source_mask, target_mask)

示例

# 输入参数
vocab_size = 1000
size = d_model = 512

# 编码器部分
dropout = 0.2
d_ff = 64				# 隐藏层参数
head = 8				# 注意力头数
c = copy.deepcopy
attn = MultiHeadedAttention(head, d_model)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
encoder_layer = EncoderLayer(size, c(attn), c(ff), dropout)
encoder_N = 8
encoder = Encoder(encoder_layer, encoder_N)

# 解码器部分
dropout = 0.2
d_ff = 64
head = 8
c = copy.deepcopy
attn = MultiHeadedAttention(head, d_model)
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
decoder_layer = DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout)
decoder_N = 8
decoder = Decoder(decoder_layer, decoder_N)

# 用了nn的embedding作为输入示意
source_embed = nn.Embedding(vocab_size, d_model)
target_embed = nn.Embedding(vocab_size, d_model)
generator = Generator(d_model, vocab_size)

# 输入张量和掩码张量
source = target = torch.LongTensor([[100, 2, 421, 508], [491, 998, 1, 221]])
source_mask = target_mask = torch.zeros(2, 4, 4)

# 实例化编码器-解码器,再带入参数实现
ed = EncoderDecoder(encoder, decoder, source_embed, target_embed, generator)
ed_res = ed(source, target, source_mask, target_mask)
print(f"ed_res: {ed_res}\n shape:{ed_res.shape}")


ed_res: tensor([[[-0.1861,  0.0849, -0.3015,  ...,  1.1753, -1.4933,  0.2484],
         [-0.3626,  1.3383,  0.1739,  ...,  1.1304,  2.0266, -0.5929],
         [ 0.0785,  1.4932,  0.3184,  ..., -0.2021, -0.2330,  0.1539],
         [-0.9703,  1.1944,  0.1763,  ...,  0.1586, -0.6066, -0.6147]],
        [[-0.9216, -0.0309, -0.6490,  ...,  1.0177,  0.5574,  0.4873],
         [-1.4097,  0.6678, -0.6708,  ...,  1.1176,  0.1959, -1.2494],
         [-0.3204,  1.2794, -0.4022,  ...,  0.6319, -0.4709,  1.0520],
         [-1.3238,  1.1470, -0.9943,  ...,  0.4026,  1.0911,  0.1327]]],
       grad_fn=<AddBackward0>)
 shape:torch.Size([2, 4, 512])

编码器-解码器模型构建函数

def make_model(source_vocab, target_vocab, N=6, d_model=512, d_ff=2048, head=8, dropout=0.1):
    """ 用于构建模型

    :param source_vocab: 源数据词汇总数
    :param target_vocab: 目标词汇总数
    :param N: 解码器/解码器堆叠层数
    :param d_model: 词嵌入维度
    :param d_ff: 前馈全连接层隐藏层维度
    :param dropout: 置0比率
    :return: 返回构建编码器-解码器模型
    """
    # 拷贝函数,来保证拷贝的函数彼此之间相互独立,不受干扰
    c = copy.deepcopy

    # 实例化多头注意力
    attn = MultiHeadedAttention(head, d_model)

    # 实例化全连接层
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)

    # 实例化位置编码类,得到对象position
    position = PositionalEncoding(d_model, dropout)

    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(Embedding(d_model, source_vocab), c(position)),
        nn.Sequential(Embedding(d_model, source_vocab), c(position)),
        Generator(d_model, target_vocab)
    )

    # 模型结构构建完成后,初始化模型中的参数
    for p in model.parameters():
        # 这里判定当参数维度大于1的时候,则会将其初始化成一个服从均匀分布的矩阵
        if p.dim() > 1:
            nn.init.xavier_normal(p)        # 生成服从正态分布的数,默认为U(-1, 1),更改第二个参数可以改值

    return model

示例

source_vocab = target_vocab = 11
N = 6
res = make_model(source_vocab, target_vocab, N)
print(res)


EncoderDecoder(
  (encoder): Encoder(
    (layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0-3): 4 x Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (w1): Linear(in_features=512, out_features=2048, bias=True)
          (w2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0-1): 2 x SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (decoder): Decoder(
    (layers): ModuleList(
      (0-5): 6 x DecoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0-3): 4 x Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (src_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0-3): 4 x Linear(in_features=512, out_features=512, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (w1): Linear(in_features=512, out_features=2048, bias=True)
          (w2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (sublayer): ModuleList(
          (0-2): 3 x SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (src_embed): Sequential(
    (0): Embedding(
      (lut): Embedding(512, 11)
    )
    (1): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (tgt_embed): Sequential(
    (0): Embedding(
      (lut): Embedding(512, 11)
    )
    (1): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (generator): Generator(
    (project): Linear(in_features=512, out_features=11, bias=True)
  )
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1203157.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Power Automate-变量和excel表数据的应用

前提表格 Power Automate连接excel请参考&#xff1a;SharePoint-连接Excel-CSDN博客 需求1&#xff1a;计算表格中某列的和 添加操作&#xff0c;搜索变量&#xff0c;选择初始化变量 添加变量的名称、类型和初始值 再新增操作&#xff0c;搜索Excel&#xff0c;点击查看更多…

Kubernetes介绍和环境部署

文章目录 Kubernetes一、Kubernetes介绍1.Kubernetes简介2.Kubernetes概念3.Kubernetes功能4.Kubernetes工作原理5.kubernetes组件6.Kubernetes优缺点 二、Kubernetes环境部署环境基本配置1.所有节点安装docker2.所有节点安装kubeadm、kubelet、kubectl添加yum源containerd配置…

查询数据表格中的数据

1.创建这个表至少20个 1&#xff09;创建数据库&#xff1a;create database 四川信息职业技术; 2&#xff09;创建数据表 3&#xff09;插入数据&#xff08;第一条代码修改了一下手机号码的字段类型&#xff09; 2.统计表中的人数 如果你想根据某个特定的列来统计人数&…

Jenkins在Linux环境下的安装与配置

Jenkins是一个开源软件项目&#xff0c;是基于Java开发的一种持续集成&#xff08;CI&#xff09;工具&#xff0c;用于解决持续重复的部署、监控工作&#xff1b;它一个开放易用的软件平台&#xff0c;大大简化软件的持续集成。 安装Jenkins 1.使用docker安装 2.本地下载je…

Python数据结构:元组(Tuple)详解

1.介绍和基础操作 Python中的元组&#xff08;Tuple&#xff09;是不可变有序序列&#xff0c;可以容纳任意数据类型&#xff08;包括数字、字符串、布尔型、列表、字典等&#xff09;的元素&#xff0c;通常用圆括号() 包裹。与列表&#xff08;List&#xff09;类似&#xff…

[IJKPLAYER]基于DEMO分析IJKPLAYER(整理版本)

背景 博主主要是从事C语言开发&#xff0c;因此本文着重强调FFMPEG部分&#xff0c;关于JAVA应用和框架层只是一笔带过。IJKPLAYER的实质是对FFMPEG项目中的ffplayer程序进行的二次封装&#xff0c;通过JNI方式完成对外提供JAVA接口。 1.目录结构 activities:包含了demo的所有…

react函数式组件props形式父向子传参

父组件中定义 子组件中触发回调传值 import { useState } from "react"; function Son(params) {const [count, setCount] useState(0);function handleClick() {console.log(params, paramsparamsparamsparamsparamsparams);params.onClick(111)setCount(count 1…

多个微信快速同步发圈

做营销最重要的任务是什么&#xff1f; 毋庸置疑&#xff0c;就是发布朋友圈。 为什么要发圈呢&#xff1f; 现在社交媒体中&#xff0c;微信不管在生活上、工作上都是不可或缺的工具&#xff0c;而朋友圈是微信中社交场景之一&#xff0c;也是很多企业作为推广产品和服务的重…

腾讯云服务器多少钱一年?2023年腾讯云优惠云服务器推荐

作为一名程序员&#xff0c;技术的突飞猛进是从拥有第一台云服务器开始的。那时&#xff0c;我开始尝试使用Linux系统&#xff0c;并成功上线了自己的第一个小程序。自此之后&#xff0c;我和我的同事们都开始拥有自己的云服务器&#xff0c;用来搭建各种小项目或者好玩的东西。…

OpenAtom OpenHarmony三方库创建发布及安全隐私检测

OpenAtom OpenHarmony三方库&#xff08;以下简称“三方库”或“包”&#xff09;&#xff0c;是经过验证可在OpenHarmony系统上可重复使用的软件组件&#xff0c;可帮助开发者快速开发OpenHarmony应用。三方库根据其开发语言分为2种&#xff0c;一种是使用JavaScript和TypeScr…

wpf devexpress设置行和编辑器

如下教程示范如何计算行布局&#xff0c;特定的表格单元编辑器&#xff0c;和格式化显示值。这个教程基于前一个文章 选择行显示 GridControl为所有字段生成行和绑定数据源&#xff0c;如果AutoGenerateColumns 属性选择AddNew。添加行到GridControl精确显示为特别的几行设置。…

Containerd接入Harbor仓库

在使用容器时&#xff0c;避免不了会使用到私有仓库&#xff0c;一般都是采用 harbor 作为私有仓库&#xff0c;docker 对接 harbor 仓库非常简单&#xff0c;哪 containerd 如何对接 harbor 呢&#xff1f; 在内网使用 harbor 根据个人习惯&#xff0c;一般都是非 http 并且是…

【SpringBoot3+Vue3】一【基础篇】

目录 一、Spring Boot概述 1、Spring Boot 特性 1.1 起步依赖 1.2 自动配置 1.3 其他特性 1.3.1 内嵌的Tomcat、Jetty (无需部署WAR文件) 1.3.2 外部化配置 1.3.3 不需要XML配置(properties/yml) 二、Spring Boot入门 1、一个入门程序需求 2、步骤 2.1 创建Maven工…

智能配方颗粒管理系统解决方案,专业实现中医药产业数字化-亿发

“中药配方颗粒”&#xff0c;又被称为免煎中药&#xff0c;源自传统中药饮片&#xff0c;经过提取、分离、浓缩、干燥、制粒、包装等工艺加工而成。这种新型配方药物完整保留了原中药饮片的所有特性。既能满足医师的辨证论治和随症加减需求&#xff0c;同时具备强劲好人高效的…

Python实现猎人猎物优化算法(HPO)优化XGBoost回归模型(XGBRegressor算法)项目实战

说明&#xff1a;这是一个机器学习实战项目&#xff08;附带数据代码文档视频讲解&#xff09;&#xff0c;如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 猎人猎物优化搜索算法(Hunter–prey optimizer, HPO)是由Naruei& Keynia于2022年提出的一种最新的…

(论文阅读29/100 人体姿态估计)

29.文献阅读笔记 简介 题目 DeepCut: Joint Subset Partition and Labeling for Multi Person Pose Estimation 作者 Leonid Pishchulin, Eldar Insafutdinov, Siyu Tang, Bjoern Andres, Mykhaylo Andriluka, Peter Gehler, and Bernt Schiele, CVPR, 2016. 原文链接 h…

分享一套适合二开的JAVA开源版本MES系统

一、系统概述&#xff1a; 万界星空科技免费MES、开源MES、商业开源MES、市面上最好的开源MES、MES源代码 万界星空开源MES制造执行系统的Java开源版本。 开源mes系统包括系统管理&#xff0c;车间基础数据管理&#xff0c;计划管理&#xff0c;物料控制&#xff0c;生产执行…

ubuntu开机系统出错且无法恢复。请联系系统管理员。

背景&#xff1a; ubuntu22.04.2命令行&#xff0c;执行自动安装系统推荐显卡驱动命令&#xff0c;字体变大&#xff0c;重启后出现如下图错误&#xff0c;无法进入系统&#xff0c;无法通过CTRLALTF1-F3进入TTY模式。 解决办法&#xff1a; 1.首先要想办法进入系统&#xff…

江门車馬炮汽车金融中心 11月11日开张

江门车马炮汽车金融中心于11月11日正式开张&#xff0c;这是江门市汽车金融服务平台&#xff0c;旨在为广大车主提供更加便捷、高效的汽车金融服务。 江门市作为广东省的一个经济发达城市&#xff0c;汽车保有量持续增长&#xff0c;但车主在购车、用车、养车等方面仍存在诸多不…

华东“启明”青少年音乐艺术实践中心揭幕暨中国“启明”巴洛克合奏团首演音乐会

2023年11月11日&#xff0c;华东“启明”青少年音乐艺术实践中心在上海揭幕&#xff0c;中国“启明”巴洛克合奏团开启了首场音乐会。 华东“启明”青少年音乐艺术实践中心由中共宁波市江北区委宣传部与上音管风琴艺术中心联合指导&#xff0c;宁波音乐港、宁波市江北区洛奇音乐…