Transformer--输入部分

news2024/11/18 1:32:18

🏷️上文我们简单介绍了Transformer模型的总体架构,本章我们主要介绍其输入部分


 📖前言

 📖文本嵌入层的作用

 📖位置编码器的作用 


📖前言

输入部分主要包括源文本嵌入层以及位置编码器,目标文本嵌入层以及位置编码器


 📖文本嵌入层的作用

🏷️无论是源文本嵌入还是目标文本嵌入,都是为了将文本中词汇的数字表示转变为向量表示, 希望在这样的高维空间捕捉词汇间的关系.

  • 文本嵌入层的代码分析:
# 导入必备的工具包
import torch

# 预定义的网络层torch.nn, 工具开发者已经帮助我们开发好的一些常用层, 
# 比如,卷积层, lstm层, embedding层等, 不需要我们再重新造轮子.
import torch.nn as nn

# 数学计算工具包
import math

# torch中变量封装函数Variable.
from torch.autograd import Variable

# 定义Embeddings类来实现文本嵌入层,这里s说明代表两个一模一样的嵌入层, 他们共享参数.
# 该类继承nn.Module, 这样就有标准层的一些功能, 这里我们也可以理解为一种模式, 我们自己实现的所有层都会这样去写.
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        """类的初始化函数, 有两个参数, d_model: 指词嵌入的维度, vocab: 指词表的大小."""
        # 接着就是使用super的方式指明继承nn.Module的初始化函数, 我们自己实现的所有层都会这样去写.
        super(Embeddings, self).__init__()
        # 之后就是调用nn中的预定义层Embedding, 获得一个词嵌入对象self.lut
        self.lut = nn.Embedding(vocab, d_model)
        # 最后就是将d_model传入类中
        self.d_model = d_model

    def forward(self, x):
        """可以将其理解为该层的前向传播逻辑,所有层中都会有此函数
           当传给该类的实例化对象参数时, 自动调用该类函数
           参数x: 因为Embedding层是首层, 所以代表输入给模型的文本通过词汇映射后的张量"""

        # 将x传给self.lut并与根号下self.d_model相乘作为结果返回

        # 让 embeddings vector 在增加 之后的 postion encoing 之前相对大一些的操作,
        # 主要是为了让position encoding 相对的小,这样会让原来的 embedding vector 中的信息在和 position encoding 的信息相加时不至于丢失掉
        # 让 embeddings vector 相对大一些
        return self.lut(x) * math.sqrt(self.d_model)

 📖位置编码器的作用 

🏷️因为在Transformer的编码器结构中, 并没有针对词汇位置信息的处理,因此需要在Embedding层后加入位置编码器,将词汇位置不同可能会产生不同语义的信息加入到词嵌入张量中, 以弥补位置信息的缺失.

  •  位置编码器的代码分析
# 定义位置编码器类, 我们同样把它看做一个层, 因此会继承nn.Module    
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        """位置编码器类的初始化函数, 共有三个参数, 分别是d_model: 词嵌入维度, 
           dropout: 置0比率, max_len: 每个句子的最大长度"""
        super(PositionalEncoding, self).__init__()

        # 实例化nn中预定义的Dropout层, 并将dropout传入其中, 获得对象self.dropout
        self.dropout = nn.Dropout(p=dropout)

        # 初始化一个位置编码矩阵, 它是一个0阵,矩阵的大小是max_len x d_model.
        pe = torch.zeros(max_len, d_model)

        # 初始化一个绝对位置矩阵, 在我们这里,词汇的绝对位置就是用它的索引去表示. 
        # 所以我们首先使用arange方法获得一个连续自然数向量,然后再使用unsqueeze方法拓展向量维度使其成为矩阵, 
        # 又因为参数传的是1,代表矩阵拓展的位置,会使向量变成一个max_len x 1 的矩阵, 
        position = torch.arange(0, max_len).unsqueeze(1)

        # 绝对位置矩阵初始化之后,接下来就是考虑如何将这些位置信息加入到位置编码矩阵中,
        # 最简单思路就是先将max_len x 1的绝对位置矩阵, 变换成max_len x d_model形状,然后覆盖原来的初始位置编码矩阵即可, 
        # 要做这种矩阵变换,就需要一个1xd_model形状的变换矩阵div_term,我们对这个变换矩阵的要求除了形状外,
        # 还希望它能够将自然数的绝对位置编码缩放成足够小的数字,有助于在之后的梯度下降过程中更快的收敛.  这样我们就可以开始初始化这个变换矩阵了.
        # 首先使用arange获得一个自然数矩阵, 但是细心的同学们会发现, 我们这里并没有按照预计的一样初始化一个1xd_model的矩阵, 
        # 而是有了一个跳跃,只初始化了一半即1xd_model/2 的矩阵。 为什么是一半呢,其实这里并不是真正意义上的初始化了一半的矩阵,
        # 我们可以把它看作是初始化了两次,而每次初始化的变换矩阵会做不同的处理,第一次初始化的变换矩阵分布在正弦波上, 第二次初始化的变换矩阵分布在余弦波上, 
        # 并把这两个矩阵分别填充在位置编码矩阵的偶数和奇数位置上,组成最终的位置编码矩阵.
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # 这样我们就得到了位置编码矩阵pe, pe现在还只是一个二维矩阵,要想和embedding的输出(一个三维张量)相加,
        # 就必须拓展一个维度,所以这里使用unsqueeze拓展维度.
        pe = pe.unsqueeze(0)

        # 最后把pe位置编码矩阵注册成模型的buffer,什么是buffer呢,
        # 我们把它认为是对模型效果有帮助的,但是却不是模型结构中超参数或者参数,不需要随着优化步骤进行更新的增益对象. 
        # 注册之后我们就可以在模型保存后重加载时和模型结构与参数一同被加载.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """forward函数的参数是x, 表示文本序列的词嵌入表示"""
        # 在相加之前我们对pe做一些适配工作, 将这个三维张量的第二维也就是句子最大长度的那一维将切片到与输入的x的第二维相同即x.size(1),
        # 因为我们默认max_len为5000一般来讲实在太大了,很难有一条句子包含5000个词汇,所以要进行与输入张量的适配. 
        # 最后使用Variable进行封装,使其与x的样式相同,但是它是不需要进行梯度求解的,因此把requires_grad设置成false.
        x = x + Variable(self.pe[:, :x.size(1)], 
                         requires_grad=False)
        # 最后使用self.dropout对象进行'丢弃'操作, 并返回结果.
        return self.dropout(x)

🏷️还有一部分知识设计绘制词汇向量中特征的分布曲线 ,其思想有些抽象,我们只需要知道我们通过上面的操作把嵌入的数值很好的匹配到正弦和余弦图像上,值域的范围都在[-1,1],我们可以更快的计算梯度

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1955773.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Vulnhub系列】Vulnhub_SecureCode1靶场渗透(原创)

【Vulnhub系列靶场】Vulnhub_SecureCode1靶场渗透 原文转载已经过授权 原文链接:Lusen的小窝 - 学无止尽,不进则退 (lusensec.github.io) 一、环境配置 1、从百度网盘下载对应靶机的.ova镜像 2、在VM中选择【打开】该.ova 3、选择存储路径&#xff0…

高效管理基础设施:掌握 Terraform 的 templatefile 函数技巧

由于Terraform的许可证变更,我曾经担心未来的动向,但IBM宣布收购HashiCorp后,我感到有所安心。我将继续关注相关动向。 本文将介绍Terraform的内置函数templatefile。 什么是templatefile函数? templatefile函数用于读取指定路…

Ip2region - 基于xdb离线库的Java IP查询工具提供给脚本调用

文章目录 Pre效果实现git clone编译测试程序将ip2region.xdb放到指定目录使用改进最终效果 Pre OpenSource - Ip2region 离线IP地址定位库和IP定位数据管理框架 Ip2region - xdb java 查询客户端实现 效果 最终效果 实现 git clone git clone https://github.com/lionsou…

用SQL将数值转换为进度条

hi,大家好呀! 最近天气是真的热,上周我们在某音做了一次直播,主要是讲解一下表,那我们最近的会在视频号,也就是微信上给大家直播讲解一下查询,直播预告晚点会分享给大家,请大家关注…

队列queue介绍

队列是一种常见的数据结构,它遵循FIFO(先进先出)的原则,即最先进入队列的元素将最先被移除。队列在Java中有多种实现方式,其中包括: 1.ArrayDeque:这是一个基于数组的双端队列,可以在…

模拟实现短信登录功能 (session 和 Redis 两种代码实例) 带前端演示

目录 整体流程 发送验证码 短信验证码登录、注册 校验登录状态 基于 session 实现登录 实现发送短信验证码功能 1. 前端发送请求 2. 后端处理请求 3. 演示 实现登录功能 1. 前端发送请求 2. 后端处理请求 校验登录状态 1. 登录拦截器 2. 注册拦截器 3. 登录完整…

Boost_Searcher测试用例编写

功能描述: 用户在客户端页面,在搜索框输入关键词,页面将显示Boost库中所有包含该关键词的内容。 界面功能兼容性易用性安全性性能弱网安装/卸载 编写测试用例: 功能: 在浏览器搜索框中输入ip地址与端口号&#xff0…

MySQL的库操作和表操作

文章目录 MYSQLSQL语句分类服务器,数据库和表的关系 库操作表操作 MYSQL SQL语句分类 DDL【data definition language】 数据定义语言,用来维护存储数据的结构代表指令: create, drop, alterDML【data manipulation language】 数据操纵语言&#xff0…

Playwright 的使用

Playwright 的特点 支持当前所有主流浏览器,包括 Chrome 和 Edge (基于 Chromiuns), Firefox , Safari 支持移动端页面测试,使用设备模拟技术,可以让我们在移动Web 浏览器中测试响应式的 Web 应用程序 支持所有浏览…

做一个能和你互动玩耍的智能机器人之四--固件

在openbot的firmware目录下我们能够找到arduino的固件源码和相关的文档。 openbot的controller目录下,是控制器的代码目录,用来控制机器人做一些动作。未来的目标是加入大模型,使其能够理解人的语言和动作来控制。 固件代码,支持…

数据结构 -- 算法的时间复杂度和空间复杂度

数据结构 -- 算法的时间复杂度和空间复杂度 1.算法效率1.1 如何衡量一个算法的好坏1.2 算法的复杂度 2.时间复杂度2.1 时间复杂度的概念2.2 大O的渐进表示法2.3常见时间复杂度计算举例 3.空间复杂度4. 常见复杂度对比 1.算法效率 1.1 如何衡量一个算法的好坏 如何衡量一个算法…

数据库实验:SQL Server基本表单表查询

一、实验目的: 1、掌握使用SQL语法实现单表查询 二、实验内容: 1. 查询订购日期为2001年5月22日的订单情况。(Orders)(时间日期的表达方式为 dOrderDate ‘2001-5-22’,类似字符串,使用单引号…

Linux---git工具

目录 初步了解 基本原理 基本用法 安装git 拉取远端仓库 提交三板斧 1、添加到缓存区 2、提交到本地仓库 3、提交到远端 其他指令补充 多人协作管理 windows用户提交文件 Linux用户提交文件 初步了解 在Linux中,git是一个指令,可以帮助我们做…

Python爬虫-中国汽车市场月销量数据

前言 本文是该专栏的第34篇,后面会持续分享python爬虫干货知识,记得关注。 在本文中,笔者将通过某汽车平台,来采集“中国汽车市场”的月销量数据。 具体实现思路和详细逻辑,笔者将在正文结合完整代码进行详细介绍。废话不多说,下面跟着笔者直接往下看正文详细内容。(附…

【原创】使用keepalived虚拟IP(VIP)实现MySQL的高可用故障转移

1. 背景 A、B服务器均部署有MySQL数据库,且互为主主。此处为A、B服务器部署MySQL数据库实现高可用的部署,当其中一台MySQL宕机后,VIP可自动切换至另一台MySQL提供服务,实现故障的自动迁移,实现高可用的目的。具体流程…

微服务-MybatisPlus下

微服务-MybatisPlus下 文章目录 微服务-MybatisPlus下1 MybatisPlus扩展功能1.1 代码生成1.2 静态工具1.3 逻辑删除1.4 枚举处理器1.5 JSON处理器**1.5.1.定义实体****1.5.2.使用类型处理器** **1.6 配置加密(选学)**1.6.1.生成秘钥**1.6.2.修改配置****…

哪里可以查找短视频素材?6个素材查找下载渠道分享!

在短视频的风靡浪潮中,不少创作者纷纷投身于这一领域,无论是分享生活点滴还是进行商业宣传,高质量的短视频内容总能吸引众多观众的目光。然而,精良的短视频制作离不开优质的素材支持。本文将为大家介绍6个优秀的高质量短视频素材下…

ProxmoxPVE虚拟化平台--U盘挂载、硬盘直通

界面说明 ### 网络设置 ISO镜像文件 虚拟机中使用到的磁盘 挂载USB设备 这个操作比较简单,不涉及命令 选中需要到的虚拟机,然后选择: 添加->USB设置选择使用USB端口:选择对应的U盘即可 硬盘直通 通常情况下我们需要将原有…

前端Long类型精度丢失:后端处理策略

文章目录 精度丢失的具体原因解决方法1. 使用 JsonSerialize 和 ToStringSerializer2. 使用 JsonFormat 注解3. 全局配置解决方案 结论 开发商城管理系统的品牌管理界面时,发现一个问题,接口返回品牌Id和页面展示的品牌Id不一致,如接口返回的…