The inputs are a query and a set of key-value pairs. The attention mechanism first computes a compatibility score between the query and each key; each score serves as the weight of the corresponding value, and the output is the sum of the values weighted by those scores.
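This is the Scaled Dot-Product Attention of "Attention Is All You Need" (Vaswani et al., 2017), usually written as

    \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V

where the rows of Q, K, and V hold the queries, keys, and values, and the scaling by \sqrt{d_k} keeps the dot products from growing with the key dimension and saturating the softmax. The implementation below follows this formula step by step: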
import numpy as np
import torch
import torch.nn as nn


class ScaledDotProductAttention(nn.Module):
    """ Scaled Dot-Product Attention """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale
        self.softmax = nn.Softmax(dim=2)  # normalize over the key dimension

    def forward(self, q, k, v, mask=None):
        u = torch.bmm(q, k.transpose(1, 2))   # 1. Matmul: (batch, n_q, n_k) scores
        u = u / self.scale                    # 2. Scale by sqrt(d_k)
        if mask is not None:
            u = u.masked_fill(mask, -np.inf)  # 3. Mask: True positions get -inf
        attn = self.softmax(u)                # 4. Softmax
        output = torch.bmm(attn, v)           # 5. Output: weighted sum of values
        return attn, output
if __name__ == "__main__":
    n_q, n_k, n_v = 2, 4, 4
    d_q, d_k, d_v = 128, 128, 64
    batch = 2

    q = torch.randn(batch, n_q, d_q)
    k = torch.randn(batch, n_k, d_k)
    v = torch.randn(batch, n_v, d_v)
    mask = torch.zeros(batch, n_q, n_k).bool()  # all False, so nothing is masked

    attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))
    attn, output = attention(q, k, v, mask=mask)

    print(attn)
    print(output)
Output (attn first, with shape (batch, n_q, n_k) = (2, 2, 4) and each row summing to 1; then output, with shape (batch, n_q, d_v) = (2, 2, 64)):
tensor([[[0.4165, 0.3548, 0.1667, 0.0620],
[0.0381, 0.3595, 0.4584, 0.1439]],
[[0.3611, 0.1587, 0.2078, 0.2723],
[0.1603, 0.0530, 0.0670, 0.7198]]])
tensor([[[ 2.2813e-01, -6.3289e-01, 1.3624e+00, 8.4069e-01, 8.1762e-02,
-6.3727e-01, -6.3929e-01, -1.0091e+00, 3.7668e-01, -2.9384e-01,
-6.2543e-02, -4.4706e-01, 3.8331e-01, 2.2979e-02, -1.1968e+00,
-3.7061e-01, -1.9007e-01, -1.7616e-01, 3.6516e-01, 1.1321e-01,
-9.5077e-01, -1.3449e+00, -1.2594e+00, 4.2644e-01, -6.3195e-01,
-5.2016e-01, -2.5782e-01, -2.4116e-01, 1.7582e-01, -1.5177e+00,
-9.3120e-01, -4.9671e-01, -4.5024e-01, -1.0746e+00, 5.4357e-01,
-6.2079e-01, 5.1379e-01, 5.6308e-02, -6.3830e-01, -3.6174e-01,
-3.0044e-01, -3.0946e-01, -5.0303e-01, -1.8382e-01, 1.1064e+00,
-7.5142e-01, -1.5372e-01, -3.3204e-01, -7.9568e-01, 1.3108e-01,
-8.6041e-01, 2.5165e-01, 8.8248e-02, 3.7294e-01, -5.2247e-02,
4.8462e-01, -7.4389e-01, -5.4351e-01, -9.7697e-01, -9.3327e-01,
-4.4550e-02, 6.1108e-01, -5.4613e-01, 2.3962e-01],
[ 6.9032e-02, 9.0591e-01, 8.3206e-01, 1.3668e+00, 1.8095e-02,
-7.3172e-02, -3.0873e-01, -9.2571e-01, 4.3452e-01, -4.7707e-02,
-3.0431e-01, -1.7578e-01, 4.0575e-01, -4.4958e-01, -4.9809e-01,
-1.7263e-02, -3.8684e-01, 2.8536e-01, 4.1150e-02, -3.7069e-01,
-7.2903e-01, -2.5185e-01, -1.0011e-01, 9.0434e-01, -7.8387e-02,
6.9680e-01, 5.3684e-01, 2.8456e-01, 2.2887e-01, -1.7423e+00,
-4.4135e-01, -2.9209e-01, 1.7053e-01, -6.4208e-01, 1.7977e-01,
1.3822e-01, -1.7873e-01, -4.7619e-01, -6.7788e-01, -5.3340e-01,
3.1518e-01, -5.6127e-02, 2.2175e-01, -3.9524e-01, 5.4478e-01,
-5.7730e-01, 5.8043e-01, -3.0143e-01, -5.7146e-01, 1.5063e-05,
-6.8221e-01, -1.3456e-02, -6.5192e-01, 7.4233e-02, 3.1776e-01,
3.1504e-01, -9.5457e-01, -8.9894e-01, -7.8422e-01, -4.1440e-01,
-9.4272e-02, 2.7226e-01, -7.0286e-01, 8.9388e-01]],
[[-7.6068e-02, 1.6911e-01, 5.1532e-02, -5.3612e-02, 2.4258e-02,
1.6490e-01, 7.4469e-01, -1.1471e+00, -4.5234e-01, 1.0684e-01,
1.0929e+00, -5.8079e-01, 1.7665e-01, -2.0187e-02, -3.3850e-01,
4.4517e-01, -4.5871e-01, 6.7840e-01, -4.3617e-01, 7.6141e-01,
3.8135e-02, -2.3898e-01, 3.2086e-01, 4.1481e-01, -1.8267e-01,
8.4337e-01, 7.8504e-02, -1.0101e+00, 5.0766e-02, 2.3338e-01,
-3.5572e-01, 1.3751e-01, -4.9570e-02, 4.8627e-01, -3.3225e-01,
6.5361e-01, 2.8979e-01, 9.9991e-02, 8.6995e-01, -7.2569e-02,
2.5490e-01, -2.6418e-01, 6.1185e-01, -7.7243e-01, -4.6956e-01,
-3.1459e-01, -2.1278e-01, 9.1588e-01, -2.1349e-02, -5.0036e-01,
3.6214e-01, 1.3723e-02, 1.2322e-01, -5.3018e-01, 2.4809e-01,
-3.2042e-01, 2.4807e-01, -1.5764e-01, -2.6655e-01, 1.8610e-01,
-1.6585e-01, 2.3454e-01, 3.1852e-01, 6.1627e-01],
[-1.7126e-01, 8.6634e-01, 4.7069e-01, -8.1842e-01, -6.2145e-01,
-3.8596e-02, 1.2991e+00, -8.4528e-01, -1.5742e+00, 1.2813e+00,
1.1197e+00, -1.2562e+00, 7.3848e-01, 2.2198e-02, -4.1664e-01,
1.1044e+00, -1.2744e+00, -1.6599e-01, -6.4863e-01, 1.1497e+00,
-1.4236e-01, -1.2829e-01, -2.7600e-01, 4.7095e-01, -5.1933e-02,
8.7453e-01, -6.4251e-01, -4.2953e-01, 3.5337e-01, -2.2782e-01,
2.5079e-01, 1.7728e-01, 6.4826e-01, 2.4980e-01, 8.3032e-02,
2.1247e+00, -3.0265e-01, -1.9821e-01, 9.7439e-01, -3.6237e-01,
-2.6392e-01, -5.1498e-01, 1.3055e+00, -9.1860e-01, -6.9769e-01,
6.5717e-01, 5.8009e-01, 3.6944e-01, 2.0414e-01, -9.0271e-01,
4.5972e-01, 9.4667e-01, 1.3700e-02, -2.7962e-01, 3.7535e-01,
-4.1842e-01, -6.2615e-01, 6.8238e-03, -3.4866e-01, 5.7681e-01,
-5.5240e-01, 1.8245e-01, 6.2508e-01, 6.0020e-01]]])
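The all-False mask above leaves every score untouched. As a minimal sketch of how masking is actually used (the shapes and names here are illustrative, not from the original), setting the entries for, say, padded key positions to True makes masked_fill push their scores to -inf, so they receive exactly zero attention weight after the softmax:

import numpy as np
import torch

batch, n_q, n_k, d_k, d_v = 1, 2, 4, 8, 8
q = torch.randn(batch, n_q, d_k)
k = torch.randn(batch, n_k, d_k)
v = torch.randn(batch, n_k, d_v)

# Hypothetical padding mask: treat the last two key positions as padding.
mask = torch.zeros(batch, n_q, n_k).bool()
mask[:, :, 2:] = True

attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))
attn, output = attention(q, k, v, mask=mask)
print(attn)  # the last two columns are exactly 0 for every query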