基于RNN和Transformer的词级语言建模 代码分析 _generate_square_subsequent_mask
flyfish
Word-level Language Modeling using RNN and Transformer
word_language_model
PyTorch 提供的 word_language_model 示例展示了如何使用循环神经网络RNN(GRU或LSTM)和 Transformer 模型进行词级语言建模 。默认情况下,训练使用Wikitext-2数据集,generate.py可以使用训练好的模型来生成新文本。
源码地址
https://github.com/pytorch/examples/tree/main/word_language_model
文件:model.py
import torch
import matplotlib.pyplot as plt
import numpy as np
def _generate_square_subsequent_mask(sz):
return torch.log(torch.tril(torch.ones(sz, sz)))
# 设置矩阵大小
sz = 5
mask = _generate_square_subsequent_mask(sz)
# 将 mask 转换为 numpy 数组,方便可视化
mask_np = mask.numpy()
# 可视化
plt.imshow(mask_np, cmap='viridis')
plt.colorbar()
plt.title("Square Subsequent Mask")
plt.show()
可视化图示
在可视化结果中,你会看到一个下三角矩阵,其值为 0 的部分为下三角部分,值为负无穷的部分为上三角部分。图像中通常负无穷会被显示为一种不同的颜色。
这样,你可以直观地理解生成的掩码矩阵的结构和作用。这个掩码矩阵主要用于 Transformer 模型中,以确保模型在预测时只能看到当前时刻及之前的时刻信息,而不能看到未来的信息。
结果
运行这段代码,你会看到一个 5x5 的矩阵,其中下三角部分是 0(因为 log(1) = 0),上三角部分是负无穷(由于 log(0) 是负无穷)。
def _generate_square_subsequent_mask(sz):
return torch.log(torch.tril(torch.ones(sz, sz)))
# 设置矩阵大小
sz = 5
mask = _generate_square_subsequent_mask(sz)
# 打印矩阵
print(mask)
输出
tensor([[0., -inf, -inf, -inf, -inf],
[0., 0., -inf, -inf, -inf],
[0., 0., 0., -inf, -inf],
[0., 0., 0., 0., -inf],
[0., 0., 0., 0., 0.]])
在数学上,定义对数函数时,log(0) 是未定义的,但在计算中,我们处理这种情况的方式是认为 log(0) 的极限值是负无穷。因此,计算机通常会返回负无穷来表示这种情况。
在 PyTorch 中,torch.log(0) 的结果是 -inf(负无穷)。这是因为对数函数是单调递增的,并且在接近0时值会急剧下降到负无穷。