【大规模语言模型:从理论到实践】Transformer中PositionalEncoder详解

news2024/11/16 2:42:50

书籍链接:大规模语言模型:从理论到实践

第15页位置表示层代码详解
PositionalEncoder

1. 构造函数 __init__()

def __init__(self, d_model, max_seq_len=80):
    super().__init__()
    self.d_model = d_model  # 嵌入的维度(embedding dimension)
  • d_model: 表示输入词向量的维度。
  • max_seq_len: 表示句子的最大长度(最大序列长度)。
  • self.d_model: 保存词嵌入的维度。
创建 PE 矩阵
pe = torch.zeros(max_seq_len, d_model)
for pos in range(max_seq_len):
    for i in range(0, d_model, 2):
        pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/d_model)))
        pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))

这里,我们为所有可能的位置 pos 和维度 i 生成了位置编码矩阵 pe。编码规则是使用正弦和余弦函数来生成位置编码:

  • 对于每个位置 pos,在每个嵌入维度 i 上:

    • 奇数维度使用正弦函数 sin(pos / 10000^(2i/d_model))
    • 偶数维度使用余弦函数 cos(pos / 10000^(2i/d_model))

    这样做的好处是,正弦和余弦函数生成了一个平滑的周期性变化,使得位置编码具有一定的连续性和距离信息。

pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
  • pe.unsqueeze(0):将 pe 的第一个维度扩展为 1,这是为了便于后续将其与输入批次结合在一起。
  • register_buffer:将 pe 作为一个不可训练的参数(Tensor),并注册为模型的一部分,以确保其在模型的 .cuda().to(device) 等操作时也能够转移到对应设备上。

2. 前向传播 forward()

def forward(self, x):
    x = x * math.sqrt(self.d_model)  # 对输入乘以嵌入维度的平方根,使得它们的值更大一些
  • 这里的 x 是输入的词嵌入(word embeddings),即一个形状为 [batch_size, seq_len, d_model] 的张量。
  • x = x * math.sqrt(self.d_model):这一行操作是为了放大嵌入值,使得单词嵌入值的范围更加合适。
seq_len = x.size(1)  # 获取序列长度(句子长度)
x = x + Variable(self.pe[:, :seq_len], requires_grad=False).cuda()
  • seq_len = x.size(1):获取当前输入序列的长度。
  • self.pe[:, :seq_len]:根据当前序列长度,从 pe 中提取对应的位置信息(只取前 seq_len 个位置的编码)。
  • x + Variable(self.pe[:, :seq_len], requires_grad=False).cuda():将位置信息 pe 添加到输入词嵌入中。requires_grad=False 表示不对位置编码进行梯度更新。

3. 详细分析x + Variable(self.pe[:, :seq_len], requires_grad=False).cuda()

这行代码在位置编码器中的作用是将预计算好的位置编码矩阵 pe 加到输入的词嵌入矩阵 x 上。这是为了在词嵌入的基础上加入位置信息,使模型能够同时使用词汇语义和位置信息。我们分解这句话的各个部分:

x = x + Variable(self.pe[:, :seq_len], requires_grad=False).cuda()
1. self.pe[:, :seq_len]
  • self.pe 是我们在初始化时生成的位置编码矩阵,其形状为 [1, max_seq_len, d_model]

    • 这里的 1 是 batch 维度,用来保持与输入张量 x 形状的一致性。
    • max_seq_len 是句子可能的最大长度,表示可以编码的最大序列长度。
    • d_model 是词嵌入的维度。
  • self.pe[:, :seq_len] 表示从 pe 矩阵中取出前 seq_len 个位置的编码。这个操作的作用是根据输入句子的实际长度(seq_len)来选择对应长度的位置信息。例如,如果 seq_len 是 50,则取出 pe 中前 50 行的编码。

2. Variable(self.pe[:, :seq_len], requires_grad=False)
  • Variable 是用于包裹张量,使其在反向传播中能够区分哪些需要计算梯度,哪些不需要。
    • requires_grad=False 表示位置编码 pe 不参与梯度计算,位置编码是一个固定值,不会像模型权重那样进行训练或更新。

注意: 在较新的版本的 PyTorch 中,Variable 已经被整合到了 Tensor 中,不再需要显式使用 Variable。直接使用张量即可,它们本身已经具有 requires_grad 属性。

3. .cuda()
  • .cuda() 将张量移动到 GPU 上进行计算,确保模型的所有张量在同一个设备上。如果你使用的是 CPU,这一部分会报错或需要改成 .to(device),以便适应不同设备。
4. x + self.pe[:, :seq_len]
  • x 是输入的词嵌入矩阵,形状为 [batch_size, seq_len, d_model]
  • self.pe[:, :seq_len] 是位置编码矩阵,形状为 [1, seq_len, d_model],即与 x 的第二、第三维度一致。
  • 加法操作x + self.pe[:, :seq_len] 表示将对应位置的词嵌入和位置编码逐元素相加。这个加法是一个广播操作,即 self.pe 的第一个维度为 1,自动扩展到与 xbatch_size 相同大小,然后再进行相加操作。
5. self.pe[:, :seq_len]self.pe[:, :seq_len, :]相互替换

两者在功能上是等价的,但后者更明确地表达了正在获取 pe 矩阵的所有维度。这种做法在某些情况下可以提高代码的可读性,特别是当你的张量具有多个维度时。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2108771.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Echart 环形图 特殊字体 富文本

注&#xff1a;特殊字体需要UI人员提供一下 .ttf 文件 完整代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0">&…

QAM星座图平均功率能量

文章目录 引言结论计算&推导推导公式数值结果验证参考文献 引言 本文主要参考了博客1&#xff0c;文章写的比较漂亮。但可惜推导过程是错误的 结论 先说结论&#xff0c;对于M-QAM调制而言&#xff0c;QAM符号的平均能量 E s E_{s} Es​ 可以由下式计算得到 E s ( M …

使用Qt this->pos,和event->pos 实现界面跟随鼠标移动

对比 this->pos() 表示当前窗口左上角相对于整个桌面屏幕的位置。 如下图所示 event->globalPos() 是当前鼠标点击的位置相对于桌面的位置。 想要做到鼠标界面跟随鼠标左键移动&#xff0c;就需要计算他们的相对位置。 最后让鼠标移动到新的位置的时候&#xff0c;使用…

【SRC挖掘】越权漏洞——burp插件被动检测越权漏洞,一个插件让挖洞效率翻倍!Autorize

越权与未授权漏洞 越权漏洞什么是越权漏洞&#xff1f;Autorize插件安装使用步骤拦截过滤器 越权漏洞 什么是越权漏洞&#xff1f; 越权漏洞是指应用程序未对当前用户操作的身份权限进行严格校验&#xff0c;导致用户可以操作超出自己管理权限范围的功能&#xff0c;从而操作…

硬件工程师笔试面试——继电器

目录 6、继电器 6.1 基础 继电器原理图 继电器实物图 6.1.1 概念 6.1.2 结构组成及工作 6.1.3 应用场景 6.1.4 优点与缺点 6.1.5 继电器工作原理 6.2 相关问题 6.2.1 如何选择合适的继电器满足特定的应用需求 6.2.2 继电器在汽车电子系统中通常承担那些角色 6.2.3…

Android调整第三方库PickerView宽高--回忆录

一、效果 // 时间选择implementation com.contrarywind:Android-PickerView:4.1.9 多年前&#xff0c;使用到事件选择器&#xff0c;但是PickerView默认宽度使满屏的&#xff0c;不太符合业务需求&#xff0c;当时为此花了许多时间&#xff0c;最终找到了解决方案&#xff0c;…

二维高斯函数的两种形式

第一种形式很常见 多元正态分布 多元正态分布&#xff08;Multivariate Normal Distribution&#xff09;&#xff0c;也称为多变量正态分布或多维正态分布&#xff0c;是统计学中一种重要的概率分布&#xff0c;用于描述多个随机变量的联合分布。 假设有 n n n 个随机变量…

自己设计的QT系统,留个档

注册登录 主界面展示 天气预报 音乐播放

卷积神经网络与小型全连接网络在MNIST数据集上的对比

卷积神经网络&#xff08;CNN&#xff09; 深度卷积神经网络中&#xff0c;有如下特性 很多层&#xff08;Compositionality&#xff0c;组合性&#xff09;: 深度卷积神经网络通常由多层卷积和非线性激活函数组成。这种多层结构使得网络能够逐步提取和组合低层次的特征&…

shell 学习笔记:数组

目录 1. 定义数组 2. 读取数组元素值 3. 关联数组 4. 在数组前加一个感叹号 ! 可以获取数组的所有键 5. 在数组前加一个井号 # 获取数组的长度 6. 数组初始化的时候&#xff0c;也可以用变量 7. 循环输出数组的方法 7.1 for循环输出 7.2 while循环输出 7.2.1 …

大数据-120 - Flink Window 窗口机制-滑动时间窗口、会话窗口-基于时间驱动基于事件驱动

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; 目前已经更新到了&#xff1a; Hadoop&#xff08;已更完&#xff09;HDFS&#xff08;已更完&#xff09;MapReduce&#xff08;已更完&am…

Redis实战宝典:开发规范与最佳实践

目录标题 Key命名设计&#xff1a;可读性、可管理性、简介性Value设计&#xff1a;拒绝大key控制Key的生命周期&#xff1a;设定过期时间时间复杂度为O(n)的命令需要注意N的数量禁用命令&#xff1a;KEYS、FLUSHDB、FLUSHALL等不推荐使用事务删除大key设置合理的内存淘汰策略使…

Java | Leetcode Java题解之第387题字符串中的第一个唯一字符

题目&#xff1a; 题解&#xff1a; class Solution {public int firstUniqChar(String s) {Map<Character, Integer> position new HashMap<Character, Integer>();Queue<Pair> queue new LinkedList<Pair>();int n s.length();for (int i 0; i …

【python因果推断库8】工具变量回归与使用 pymc 验证工具变量1

目录 工具变量回归与使用 pymc 验证工具变量 回归机制与局部平均处理效应 旁白&#xff1a;从多元正态分布中采样 import arviz as az import daft import matplotlib.pyplot as plt import numpy as np import pandas as pd import pymc as pm import scipy from matplotli…

如何阅读PyTorch文档及常见PyTorch错误

如何阅读PyTorch文档及常见PyTorch错误 文章目录 如何阅读PyTorch文档及常见PyTorch错误阅读PyTorch文档示例常见Pytorch错误Tensor在不同设备上维度不匹配cuda内存不足张量类型不匹配 参考 PyTorch文档查看https://pytorch.org/docs/stable/ torch.nn -> 定义神经网络 torc…

红队攻防 | 利用GitLab nday实现帐户接管

在一次红队任务中&#xff0c;目标是一家提供VoIP服务的公司。该目标拥有一些重要的客户&#xff0c;如政府组织&#xff0c;银行和电信提供商。该公司要求外部参与&#xff0c;资产测试范围几乎是公司拥有的每一项互联网资产。 第一天是对目标进行信息收集。这一次&#xff0…

结构开发笔记(七):solidworks软件(六):装配摄像头、摄像头座以及螺丝,完成摄像头结构示意图

若该文为原创文章&#xff0c;转载请注明原文出处 本文章博客地址&#xff1a;https://hpzwl.blog.csdn.net/article/details/141931518 长沙红胖子Qt&#xff08;长沙创微智科&#xff09;博文大全&#xff1a;开发技术集合&#xff08;包含Qt实用技术、树莓派、三维、OpenCV…

成功之路:如何获得机器学习和数据科学实习机会

一年内获得两份实习机会的数据科学家的建议和技巧 欢迎来到雲闪世界。在当今竞争激烈的就业市场中&#xff0c;获得数据科学实习机会可以成为您在科技领域取得成功的门票。 但申请者如此之多&#xff0c;你该如何脱颖而出呢&#xff1f; 无论您是学生、应届毕业生还是想要转行…

IDEA2024.2最新工具下载

​软件使用 1、解压缩包 2、打开如图第三个 3、运行过十来秒等待提示以下信息即可

Ubuntu 无法全局安装 node 包

Anchor: $: cat /etc/lsb* DISTRIB_IDUbuntu DISTRIB_RELEASE22.04 DISTRIB_CODENAMEjammy DISTRIB_DESCRIPTION"Ubuntu 22.04.4 LTS" $: node -v v20.17.0 $: npm -v 10.8.2Question: $: npm install -g docsify-cli结果&#xff1a;超时或者如下图 Answer: 有…