【课程总结】day19(下):Transformer源码深入理解

news2024/9/23 19:17:18

前言

在上一章【课程总结】day19(下):Transformer架构及注意力机制了解总结中,我们对Transformer架构以及注意力机制有了初步了解,本章将结合《The Annotated Transformer》中的源码,对Transformer的架构进行深入理解。

背景

《The Annotated Transformer》是由 Harvard NLP Group 提供的一个详细教程,旨在帮助读者理解 Transformer 模型的工作原理和实现细节。

原文博客地址:https://nlp.seas.harvard.edu/2018/04/03/attention.html
Github仓库地址:https://github.com/harvardnlp/annotated-transformer

整体框架

Transformer架构

class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many
    other models.
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

代码解析:

  • Transformer架构总体由一个 EncoderDecoder 类构成。
  • EncoderDecoder 类的成员变量包含五部分:
    • encoder:编码器,负责将输入序列编码为固定长度的向量。
    • decoder:解码器,负责将编码后的向量解码为输出序列。
    • src_embed:输入序列的嵌入层,将输入序列转换为固定维度的向量。
    • tgt_embed:输出序列的嵌入层,将输出序列转换为固定维度的向量。
    • generator:生成器,将解码后的向量转换为输出序列。

Encoder

首先查看 Encoder 类的结构:
源码如下:

class Encoder(nn.Module):
    "Core encoder is a stack of N layers"

    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        "Pass the input (and mask) through each layer in turn."
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

代码解析:

  • Encoder 类的初始化函数中,分别创建了self.layersself.norm
  • self.layers 的创建方式是使用clones函数进行deepcopy的批量化创建
    • clones 函数可以创建指定数量的相同对象的列表。(Transformer的多层网络能力即由该函数体现)
    • self.layers 创建并实例化的layer对象,其类型为EncoderLayer
    • Encoder的前向传播 forward 函数中,会依次给每个EncoderLayer对象传入mask以便进行pad掩码操作。
  • self.norm 对应是LayerNorm类,该类用于对输入序列进行归一化处理。
EncoderLayer

因为 Encoder 类是由多个 EncoderLayer构成,所以接着了解EncoderLayer类。
源码如下:

class EncoderLayer(nn.Module):
    "Encoder is made up of self-attn and feed forward (defined below)"

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        "Follow Figure 1 (left) for connections."
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

代码解析:

  • EncoderLayer 类中包含4个成员变量:
    • self_attn:自注意力机制,用于对输入序列进行自注意力,该成员变量创建时会通过公共函数attension函数创建,并作为参数传入给self_attn成员变量。
    • feed_forward:前馈网络,对应PositionwiseFeedForward类的对象。
    • sublayer:包含两个SublayerConnection类的对象(对应图示中的Add&Norm),其作用是对输入序列进行归一化处理。
    • size:输入序列的维度大小。
  • forward函数中:
    • self.sublayer[0]代表两个 SublayerConnection 实例的列表第一个子层,即自注意力机制的连接。
    • lambda x: self.self_attn(x, x, x, mask) 是一个匿名函数,它接收输入 x 并执行 self.self_attn(x, x, x, mask) 自注意力计算.
    • self.self_attn(x, x, x, mask) 表示使用输入 x 作为查询(Q)、键(K)和值(V),同时传入 mask
    • 然后,SublayerConnection 将处理这个输出,通常包括残差连接和层归一化。

Decoder

其次,查看 Decoder 的实现。
源码如下:

class Decoder(nn.Module):
    "Generic N layer decoder with masking."

    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1976492.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LaneATT推理详解及部署实现(上)

目录 前言1. 概述2. 环境配置3. Demo测试4. ONNX导出初探5. ONNX导出优化6. ONNX导出总结结语下载链接参考 前言 最近想关注下车道线检测任务,在 GitHub 上找了一个模型 LaneATT,想通过调试分析 LaneATT 代码把 LaneATT 模型导出来,并在 tens…

Java游戏源码:象棋网络对战版

学习java朋友们,福利来了,今天小编给大家带来了一款象棋网络对战版源码。 源码搭建和讲解 源码分为客户端和服务器,采用java原生 java.net.Socket 实现,服务器主循环代码: import java.net.ServerSocket; import jav…

二维码生成原理及解码原理

☝☝☝二维码配图 二维码 二维码(Quick Response Code,简称QR码)是一种广泛使用的二维条形码技术,由日本公司Denso Wave在1994年开发。二维码能有效地存储和传递信息,广泛应用于商品追溯、支付、广告等多个领域。二维…

Star-CCM+负体积网格检查与出现原因

要使网格可用于有限体积计算,每个网格单元必须具有正体积,否则初始化过程将失败,且模拟计算无法运行。 负体积网格单元可能会以多种不同的方式出现,但必须修复或从网格中移除,才能继续执行任何后续操作。 要检查体网…

<数据集>人员摔倒识别数据集<目标检测>

数据集格式:VOCYOLO格式 图片数量:8605张 标注数量(xml文件个数):8605 标注数量(txt文件个数):8605 标注类别数:1 标注类别名称:[fall] 序号类别名称图片数框数1fall860512275 使用标注工具&#xf…

当前生物信息学研究面临的四大机遇和挑战(特别是最后一个,一定要足够重视)...

生物信息学是应用计算方法分析生物数据,如 DNA,RNA,蛋白质和代谢物。生物信息学已成为促进我们对生命科学的理解以及开发新的诊断,治疗和生物技术产品的重要工具。本文我们将探讨生物信息学研究的一些当前趋势和发展,以…

如何快速入门 PyTorch ?

PyTorch是一个机器学习框架,主要依靠深度神经网络,目前已迅速成为机器学习领域中最可靠的框架之一。 PyTorch 的大部分基础代码源于 Ronan Collobert 等人 在 2007 年发起的 Torch7 项目,该项目源于 Yann LeCun 和 Leon Bottou 首创的编程语…

【C++题解】1249. 搬砖问题

欢迎关注本专栏《C从零基础到信奥赛入门级(CSP-J)》 问题:1249. 搬砖问题 类型:嵌套穷举 题目描述: 36 块砖, 36 人搬。男搬 4 ,女搬 3 ,两个小儿抬一砖。 要求一次全搬完。问需…

GitHub最全中文排行榜开源项目,助你轻松发现优质资源!

文章目录 GitHub-Chinese-Top-Charts:中文开发者的开源项目精选项目介绍项目特点核心功能1. 热门项目榜单2. 详细项目信息 如何使用覆盖范围软件类资料类 GitHub-Chinese-Top-Charts:中文开发者的开源项目精选 在全球范围内,GitHub已经成为了…

谷歌外链:提升网站权重的秘密武器!

谷歌外链之被称为提升网站权重的秘密武器,主要是因为它们对网站的搜索引擎排名有着直接且显著的影响 谷歌和其他搜索引擎使用外链作为衡量网站信任度和权威性的重要指标。当一个网站获得来自其他信誉良好的源的链接时,这被视为信任的投票。多个高质量链…

opencv-图像仿射变换

仿射变换就是将矩形变为平行四边形,而透视变换可以变成任意不规则四边形。实际上,仿射变换是透视变换的子集,仿射变换是线性变换,而透视变换不仅仅是线性变换。 仿射变换设计图像位置角度的变化,是深度学习预处理中常…

力扣SQL50 患某种疾病的患者 正则表达式

Problem: 1527. 患某种疾病的患者 在SQL查询中,REGEXP 是用于执行正则表达式匹配的操作符。正则表达式允许使用特殊字符和模式来匹配字符串中的特定文本。具体到你的查询,^DIAB1|\\sDIAB1 是一个正则表达式,它使用了一些特殊的通配符和符号。…

Vue:vue-router使用指南

一、简介 点击查看vue-router官网 Vue Router 是 Vue.js 的官方路由。它与 Vue.js 核心深度集成,让用 Vue.js 构建单页应用变得轻而易举。功能包括: 嵌套路由映射动态路由选择模块化、基于组件的路由配置-路由参数、查询、通配符-展示由 Vue.js 的过渡系…

DNS常见面试题

DNS是什么? 域名使用字符串来代替 IP 地址,方便用户记忆,本质上一个名字空间系统;DNS 是一个树状的分布式查询系统,但为了提高查询效率,外围有多级的缓存;DNS 就像是我们现实世界里的电话本、查…

电路板热仿真覆铜率,功率,结温,热阻率信息计算获取方法总结

🏡《电子元器件学习目录》 目录 1,概述2,覆铜率3,功率4,器件尺寸5,结温6,热阻1,概述 电路板热仿真操作是一个复杂且细致的过程,旨在评估和优化电路板内部的热分布及温度变化,以确保电子元件的可靠性和性能。本文简述在进行电路板的热仿真时,元器件热信息的计算方法…

59.DevecoStudio项目引入不同目录的文件进行函数调用

59.DevecoStudio ArkUI项目引入不同目录的文件进行函数调用 arkUi,ets,cj文件,ts文件的引用 import common from ohos.app.ability.common; import stringutils from ./uint8array2string; //index.ts的当前目录 import StringUtils2 from ../http2/uint8array2st…

python全栈开发《23.字符串的find与index函数》

1.补充说明上文 python全栈开发《22.字符串的startswith和endswith函数》 endswith和startswith也可以对完整(整体)的字符串进行判断。 info.endswith(this is a string example!!)或info.startswith(this is a string example!!)相当于bool(info this …

鸿蒙媒体开发【拼图】拍照和图片

拼图 介绍 该示例通过ohos.multimedia.image和ohos.file.photoAccessHelper接口实现获取图片,以及图片裁剪分割的功能。 效果预览 使用说明: 使用预置相机拍照后启动应用,应用首页会读取设备内的图片文件并展示获取到的第一个图片&#x…

Animate软件基础:关于补间动画中的图层

Animate 文档中的每一个场景都可以包含任意数量的时间轴图层。使用图层和图层文件夹可组织动画序列的内容和分隔动画对象。在图层和文件夹中组织它们可防止它们在重叠时相互擦除、连接或分段。若要创建一次包含多个元件或文本字段的补间移动的动画,请将每个对象放置…

go 中 string 并发写导致的 panic

类型的一点变化 在Go语言的演化过程中,引入了unsafe.String来取代之前的StringHeader结构体,这是为了提供更安全和简洁的字符串操作方式。 旧设计 (StringHeader 结构体) StringHeader注释发生了一点变动,被标注了 Deprecated,…