参考资料:
PyTorch的Transformer
register_buffer的作用是:
登记成员变量,它会自动成为模型中的参数,随着模型移动(gpu/cpu)而移动,但是不会随着梯度进行更新。
参考资料:【Torch API】pytorch 中register_buffer()函数详解
在不同的上下文中,一个token要表达出不一样的含义。我们引入这三个可学习的Wq、Wk、Wv,就可以更好更灵活的把embedding映射到更合适的空间,可以做到更好的提取特征表达含义。
torch.full(size, fill_value)
:用fill_value
填充得到大小为size
的tensor,注意数据类型是浮点类型。
torch.triu()
:返回上三角矩阵。可以设置对角线偏移。