NLP——Transfromer 详解

news2024/9/24 15:20:38

Transformer总体架构图

在这里插入图片描述

  1. 输入部分:源文本嵌入层及其位置编码器、目标文本嵌入层及其位置编码器

  2. 编码器部分
    由N个编码器层堆叠而成
    每个编码器层由两个子层连接结构组成
    第一个子层连接结构包括一个多头自注意力子层和规范化层以及一个残差连接
    第二个子层连接结构包括一个前馈全连接子层和规范化层以及一个残差连接

  3. 解码器部分
    由N个解码器层堆叠而成
    每个解码器层由三个子层连接结构组成
    第一个子层连接结构包括一个多头自注意力子层和规范化层以及一个残差连接
    第二个子层连接结构包括一个多头注意力子层和规范化层以及一个残差连接
    第三个子层连接结构包括一个前馈全连接子层和规范化层以及一个残差连接

  4. 输出部分包:线性层、softmax层

输入部分实现

1. 文本嵌入层
import torch
# 预定义的网络层torch.nn, 工具开发者已经帮助我们开发好的一些常用层, 
# 比如,卷积层, lstm层, embedding层等, 不需要我们再重新造轮子.
import torch.nn as nn
import math

# torch中变量封装函数Variable.
from torch.autograd import Variable

# 定义Embeddings类来实现文本嵌入层,这里s说明代表两个一模一样的嵌入层, 他们共享参数.
# 该类继承nn.Module, 这样就有标准层的一些功能, 这里我们也可以理解为一种模式, 我们自己实现的所有层都会这样去写.
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        """类的初始化函数, 有两个参数, d_model: 指词嵌入的维度, vocab: 指词表的大小."""
        # 接着就是使用super的方式指明继承nn.Module的初始化函数, 我们自己实现的所有层都会这样去写.
        super(Embeddings, self).__init__()
        # 之后就是调用nn中的预定义层Embedding, 获得一个词嵌入对象self.lut
        self.lut = nn.Embedding(vocab, d_model)
        # 最后就是将d_model传入类中
        self.d_model = d_model

    def forward(self, x):
        """可以将其理解为该层的前向传播逻辑,所有层中都会有此函数
           当传给该类的实例化对象参数时, 自动调用该类函数
           参数x: 因为Embedding层是首层, 所以代表输入给模型的文本通过词汇映射后的张量"""

        # 将x传给self.lut并与根号下self.d_model相乘作为结果返回

        # 让 embeddings vector 在增加 之后的 postion encoing 之前相对大一些的操作,
        # 主要是为了让position encoding 相对的小,这样会让原来的 embedding vector 中的信息在和 position encoding 的信息相加时不至于丢失掉
        # 让 embeddings vector 相对大一些
        return self.lut(x) * math.sqrt(self.d_model)
# 词嵌入维度是512维
d_model = 512

# 词表大小是1000
vocab = 1000

# 输入x是一个使用Variable封装的长整型张量, 形状是2 x 4
x = Variable(torch.LongTensor([[100,2,421,508],[491,998,1,221]]))

emb = Embeddings(d_model, vocab)
embr = emb(x)
print("embr:", embr)

2. 位置编码器

因为在Transformer的编码器结构中, 并没有针对词汇位置信息的处理,因此需要在Embedding层后加入位置编码器,将词汇位置不同可能会产生不同语义的信息加入到词嵌入张量中, 以弥补位置信息的缺失.

编码器部分实现

# 定义位置编码器类, 我们同样把它看做一个层, 因此会继承nn.Module    
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        """位置编码器类的初始化函数, 共有三个参数, 分别是d_model: 词嵌入维度, 
           dropout: 置0比率, max_len: 每个句子的最大长度"""
        super(PositionalEncoding, self).__init__()

        # 实例化nn中预定义的Dropout层, 并将dropout传入其中, 获得对象self.dropout
        self.dropout = nn.Dropout(p=dropout)

        # 初始化一个位置编码矩阵, 它是一个0阵,矩阵的大小是max_len x d_model.
        pe = torch.zeros(max_len, d_model)

        # 初始化一个绝对位置矩阵, 在我们这里,词汇的绝对位置就是用它的索引去表示. 
        # 所以我们首先使用arange方法获得一个连续自然数向量,然后再使用unsqueeze方法拓展向量维度使其成为矩阵, 
        # 又因为参数传的是1,代表矩阵拓展的位置,会使向量变成一个max_len x 1 的矩阵, 
        position = torch.arange(0, max_len).unsqueeze(1)

        # 绝对位置矩阵初始化之后,接下来就是考虑如何将这些位置信息加入到位置编码矩阵中,
        # 最简单思路就是先将max_len x 1的绝对位置矩阵, 变换成max_len x d_model形状,然后覆盖原来的初始位置编码矩阵即可, 
        # 要做这种矩阵变换,就需要一个1xd_model形状的变换矩阵div_term,我们对这个变换矩阵的要求除了形状外,
        # 还希望它能够将自然数的绝对位置编码缩放成足够小的数字,有助于在之后的梯度下降过程中更快的收敛.  这样我们就可以开始初始化这个变换矩阵了.
        # 首先使用arange获得一个自然数矩阵, 但是细心的同学们会发现, 我们这里并没有按照预计的一样初始化一个1xd_model的矩阵, 
        # 而是有了一个跳跃,只初始化了一半即1xd_model/2 的矩阵。 为什么是一半呢,其实这里并不是真正意义上的初始化了一半的矩阵,
        # 我们可以把它看作是初始化了两次,而每次初始化的变换矩阵会做不同的处理,第一次初始化的变换矩阵分布在正弦波上, 第二次初始化的变换矩阵分布在余弦波上, 
        # 并把这两个矩阵分别填充在位置编码矩阵的偶数和奇数位置上,组成最终的位置编码矩阵.
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # 这样我们就得到了位置编码矩阵pe, pe现在还只是一个二维矩阵,要想和embedding的输出(一个三维张量)相加,
        # 就必须拓展一个维度,所以这里使用unsqueeze拓展维度.
        pe = pe.unsqueeze(0)

        # 最后把pe位置编码矩阵注册成模型的buffer,什么是buffer呢,
        # 我们把它认为是对模型效果有帮助的,但是却不是模型结构中超参数或者参数,不需要随着优化步骤进行更新的增益对象. 
        # 注册之后我们就可以在模型保存后重加载时和模型结构与参数一同被加载.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """forward函数的参数是x, 表示文本序列的词嵌入表示"""
        # 在相加之前我们对pe做一些适配工作, 将这个三维张量的第二维也就是句子最大长度的那一维将切片到与输入的x的第二维相同即x.size(1),
        # 因为我们默认max_len为5000一般来讲实在太大了,很难有一条句子包含5000个词汇,所以要进行与输入张量的适配. 
        # 最后使用Variable进行封装,使其与x的样式相同,但是它是不需要进行梯度求解的,因此把requires_grad设置成false.
        x = x + Variable(self.pe[:, :x.size(1)], 
                         requires_grad=False)
        # 最后使用self.dropout对象进行'丢弃'操作, 并返回结果.
        return self.dropout(x)
# 词嵌入维度是512维
d_model = 512

# 置0比率为0.1
dropout = 0.1

# 句子最大长度
max_len=60
# 输入x是Embedding层的输出的张量, 形状是2 x 4 x 512
x = embr


解码器部分实现



输出部分实现



本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1981779.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Liunx---批量安装服务器

目录 一、环境准备 一、环境准备 1.准备一台rhel7的主机并且打开主机图形。 2.配置好可用ip 3.做kickstart自动安装脚本后面需要用到DHCP,关闭VMware DHCP功能 二、安装图形化kickstart自动安装脚本的工具 yum install system-config-kickstart ----安装图形化生…

Guitar Pro简谱怎么输入 ?如何把简谱设置到六线谱的下面?

一、Guitar Pro简谱怎么输入 简谱在音乐学习、演奏、创作和传播中都起着非常重要的作用,是音乐领域不可或缺的工具。吉他乐谱的制作可以使简谱,也可以使五线谱、六线谱等多种形式,这几种乐谱都可以使用Guitar Pro来完成。下面来看看Guitar Pr…

springboot大学生社会实践管理信息系统-计算机毕业设计源码61970

目 录 摘要 Abstract 1 绪论 1.1 研究背景与意义 1.2 国内外研究现状 1.3 论文结构与章节安排 2 系统分析 2.1 可行性分析 2.1.1技术可行性 2.1.2 经济可行性 2.1.3 社会可行性 2.2 系统流程分析 2.2.1 数据新增流程 2.2.2 数据删除流程 2.3 系统功能分析 2.3.…

谷歌账号被停用后,申诉没有反馈或者被拒绝后怎么办?附:谷歌账号申诉信要点和模板

有一些朋友在登录谷歌账号的时候,或者在是用谷歌账号的过程中突然被强制退出来,然后再次登录的时候就遇到了下面的提醒:您的账号已停用,而且原因通常是两大类:1)谷歌账号与其他多个账号一起创建或使用的&am…

Rust 所有权

所有权 Rust的核心特性就是所有权所有程序在运行时都必须管理他们使用计算机内存的方式 有些语言有垃圾收集机制,在程序运行时,他们会不断地寻找不再使用的内存在其他语言中,程序员必须显式的分配和释放内存 Rust采用了第三种方式&#xff1…

FFmpeg内存对齐简述

目录 引文 行字节数的计算 ffmpeg中的align ffmpeg中的linesize 内容参考 引文 在ffmpeg的使用过程中有时会发现align这个参数,那么这个参数代表什么意思,不同的值会产生什么影响呢,详见下文。 行字节数的计算 理解内存对齐之前首先要…

无人机之导航系统篇

一、导航系统组成 包括惯性导航系统、卫星导航系统、视觉导航系统等。 二、导航原理 利用传感器感知无人机的位置、速度和姿态信息,结合地图数据和导航算法,计算出无人机当前的位置和航向,从而引导无人机按照预设的航线飞行。 三、导航精…

Linux文件或图片名称中文乱码解决【适用于centos、ubuntu等系统】

👨‍🎓博主简介 🏅CSDN博客专家   🏅云计算领域优质创作者   🏅华为云开发者社区专家博主   🏅阿里云开发者社区专家博主 💊交流社区:运维交流社区 欢迎大家的加入&#xff01…

【unittest】TestSuite搭建测试用例示例二

1.1 打开串口示例 常用的模组则包含AT指令测试,或串口数据测试,则可添加串口配置,将指令通过串口发送出去,如下所示: import serial def open_serial_port(port, baudrate115200, timeout2): try: # 创建并配置串…

Vue 3+Vite+Eectron从入门到实战系列之一环境安装篇

Electron 都应该不会陌生了,是一个使用 JavaScript、HTML 和 CSS 构建桌面应用的框架。通过将 Chromium 和 Node.js 嵌入到其二进制文件中,Electron 允许你维护一个 JavaScript 代码库并创建可在 Windows、macOS 和 Linux 上运行的跨平台应用 - 无需原生开发经验。 实现效果…

YOLOv6训练自己的数据集

文章目录 前言一、YOLOv6简介二、环境搭建三、构建数据集四、修改配置文件①数据集文件配置②权重下载③模型文件配置 五、模型训练和测试模型训练模型测试 总结 前言 提示:本文是YOLOv6训练自己数据集的记录教程,需要大家在本地已配置好CUDA,cuDNN等环…

思源笔记结合群晖WebDav与cpolar内网穿透实现跨网络笔记云同步

文章目录 前言1. 开启群晖WebDav 服务2. 本地局域网IP同步测试3. 群晖安装Cpolar4. 配置远程同步地址5. 笔记远程同步测试6. 固定公网地址7. 配置固定远程同步地址 前言 本教程主要分享如何将思源笔记、cpolar内网穿透和群晖WebDav三者相结合,实现思源笔记的云同步…

如何使用代理IP进行电子邮件保护?

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 前言 随着企业信息化的深入发展,电子邮件在私人生活和商业运营中起到越来越重要的作用,随之而来电子邮件…

掌握eBay刊登:十大工具助力卖家脱颖而出

在经济全球化的浪潮中,eBay作为全球最大的跨境电商平台之一,为卖家提供了一个展示商品、拓展市场的广阔舞台。然而,平台越大,意味着商家之间的竞争越激烈。如何在eBay上有效刊登商品,是卖家吸引用户的关键步骤。本文将…

500元蓝牙耳机排行榜有哪些?四款百元蓝牙耳机品牌排行推荐

在如今这个充满科技魅力的时代,蓝牙耳机已成为我们日常生活中不可或缺的一部分,无论是沉浸在音乐的世界中,还是在繁忙的通勤路上享受片刻宁静,一副优秀的蓝牙耳机都能为我们带来无与伦比的听觉享受,面对市场上琳琅满目…

合作文章(IF=5.9)|16s和非靶代谢组分析揭示亚麻籽木脂素对PAM过量诱导的肝毒性的保护作用

研究背景 扑热息痛(PAM)是世界上最常用的镇痛解热的药物之一。在肝酶细胞色素P450 Cyp2E1和Cyp1A2PAM酶的作用下,PAM转化为一种高活性的代谢物乙酰对位苯醌亚胺(NAPQI),通过与谷胱甘肽(GSH)偶联可解毒为无毒的谷胱甘肽-NAPQI。然…

视频汇聚平台EasyCVR接入移动执法记录仪,视频无法播放且报错500是什么原因?

GB28181国标视频汇聚平台EasyCVR视频管理系统以其强大的拓展性、灵活的部署方式、高性能的视频能力和智能化的分析能力,为各行各业的视频监控需求提供了优秀的解决方案。视频智能分析平台EasyCVR支持多协议接入,兼容多类型的设备,包括IPC、NV…

自动化测试中元素定位失败的解决策略

🍅 点击文末小卡片,免费获取软件测试全套资料,资料在手,涨薪更快 一、引言 自动化测试是软件开发流程中的重要组成部分,它能够帮助测试人员快速地验证应用程序的功能是否符合预期。然而,在自动化测试的过程…

互联网解决方案-文件存储方案:seafile真实案例

目录 seafile可靠性保证 事件驱动 seafile.log events.log 事件驱动好处 本地联思文件同步云联思真实案例 本地联思文件同步云联思架构 云联思客户端检查文件API 本地联思访问客户端封装 本地联思队列消费检查 实践过程中的弯路 文件目录处理 move = copy & de…

ANTD PRO VUE使用

目录 1.访问Antd Pro Vue官网 2.安装 3.目录结构 4.安装运行 5.npm run serve可能会报以下错误 6.解决办法 ​7.缩放会报以下错误 ​8.解决办法 1.访问Antd Pro Vue官网 https://pro.antdv.com 点击开始使用 2.安装 从 GitHub 仓库中直接安装最新的脚手架代码。 git…