Transformer 中的绝对位置编码
(以下由gpt 生成)
Transformer 的绝对位置编码(Absolute Position Encoding, APE)是用于对序列数据中的位置信息进行建模的一种方法。在 Transformer 的架构中,输入数据(如句子中的单词或字符)通过词嵌入层转化为连续的向量表示。这些向量本身并不包含任何位置信息,而 Transformer 是一种无序模型(对输入的顺序敏感性取决于位置信息的建模方式)。因此,需要加入位置信息使模型能够感知输入序列中每个元素的位置。
绝对位置编码通过为每个序列位置生成固定的编码向量,将这些向量添加到输入的词嵌入中,从而将位置信息显式地融入模型。
- 绝对位置编码的公式
解释:
正弦和余弦函数的使用使得不同位置的编码之间具有周期性。
1/10000^(2i/d)控制了不同维度的频率,使得高维特征的变化比低维特征更快。
2. 绝对位置编码的性质
固定性:绝对位置编码是固定的,不需要通过训练学习。这减少了模型的参数量。
平滑性:相邻位置的编码向量之间具有平滑的变化,便于模型捕捉到序列中的局部连续性。
维度分布:不同维度的频率不同,能够捕捉到不同尺度的位置信息
3. 绝对位置编码的使用
在 Transformer 中,绝对位置编码通常与词嵌入相加后输入到模型中:
# 假设 word_embeddings 是嵌入向量 (seq_len, d_model)
word_embeddings = torch.randn(seq_len, d_model)
input_with_pos = word_embeddings + pos_encoding # 加入位置编码
这种加法操作将词嵌入与位置编码结合在一起,使得模型既能够感知词语的语义信息,也能感知其在序列中的位置信息。
4. 实现代码
import torch
import math
def generate_position_encoding(seq_len, d_model):
"""
生成绝对位置编码的函数。
:param seq_len: 序列的长度
:param d_model: 嵌入向量的维度
:return: 位置编码矩阵 (seq_len, d_model)
"""
# 初始化位置编码矩阵
position_encoding = torch.zeros(seq_len, d_model)
# 生成位置索引和维度索引
position = torch.arange(0, seq_len).unsqueeze(1) # (seq_len, 1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) # (d_model // 2,)
# 应用正弦和余弦函数
position_encoding[:, 0::2] = torch.sin(position * div_term) # 偶数维度
position_encoding[:, 1::2] = torch.cos(position * div_term) # 奇数维度
return position_encoding
# 示例
seq_len = 32 #10 # 序列长度
d_model = 128 #16 # 嵌入维度
pos_encoding = generate_position_encoding(seq_len, d_model)
print(pos_encoding)
# 可视化位置编码
import matplotlib.pyplot as plt
import seaborn as sns
plt.figure(figsize=(10, 6))
ax = sns.heatmap(pos_encoding.numpy(), cmap='coolwarm', annot=False, cbar=True)
# 将横轴放置在顶部
ax.xaxis.set_ticks_position('top') # 将x轴移至顶部
plt.title('Absolute Position Encoding')
plt.xlabel('Embedding Dimension')
plt.ylabel('Position in Sequence')
# 调整布局以避免标签重叠
plt.subplots_adjust(top=0.85)
plt.show()
print()
5. 绝对位置编码的优缺点
优点:
无参数化:绝对位置编码是固定的,不会增加模型的参数。
周期性和可扩展性:正弦和余弦函数的周期性使得编码具有平滑的性质,且理论上可以扩展到更长的序列。
简单易用:只需将固定的编码添加到词嵌入中即可。
缺点:
不灵活:固定的位置编码对任务或数据不具备适应性,可能限制模型的表现。
长序列表示问题:对于非常长的序列,编码的分辨率可能不足(由于正弦和余弦函数的周期性)。
相对位置信息不足:绝对位置编码只关注位置本身,无法直接捕捉相对位置关系