位置编码公式
偶数位置用sin,奇数位置用cos. d_model 表示token的维度;pos表示token在序列中的位置;i表示每个token编码的第i个位置,属于[0,d_model)。
torch实现
import math
import torch
from torch import nn
from torch.autograd import Variable
import matplotlib.pyplot as plt
class PositionalEncoder(nn.Module):
def __init__(self, max_seq_len=50, d_model=128):
super().__init__()
self.d_model = d_model # d_model 表示token的维度
pe = torch.zeros(max_seq_len, d_model) # max_seq_len * d_model 的二维张量 例如: 50*128
for pos in range(max_seq_len): # 重新初始化
for i in range(0, d_model, 2):
pe[pos, i] = math.sin(pos / (10000 ** (i / d_model)))
pe[pos, i + 1] = math.cos(pos / (10000 ** (i / d_model)))
pe = pe.unsqueeze(0) # 1*50*128
self.register_buffer('pe', pe)
def forward(self, x):
x = x * math.sqrt(self.d_model)
seq_len = x.size(1)
x = x + Variable(self.pe[:, :seq_len], requires_grad=False).cuda()
return x
if __name__ == '__main__':
positional_encoder = PositionalEncoder(50, 128)
plt.pcolormesh(positional_encoder.pe.numpy()[0], cmap='RdBu')
plt.xlabel('Depth') # 50
plt.xlim((0, 128))
plt.ylabel('Position') # 128
plt.colorbar()
plt.show()