条件设置
batch_size=1
src_len = 8 # 源句子的最大长度 根据这个进行padding的填充
tgt_len = 7 # 目标输入句子的最大长度 根据这个进行padding的填充
d_model=512 # embedding的维度
d_ff=2048 # 全连接层的维度
h_head=8 # Multi-Head Attention 的个数
d_k=d_q= 64 # dimension of K(=Q)(Q和K的维度需要相同,为了方便也可以让K=Q=V)
d_v=128 # dimension of V
Encoder
经过处理后数据作为Encoder的输入:[batch_size,src_len] #即 [1,8]
1.1输入层
经过word embedding:[batch_size, src_len, d_model]
经过positional encoding:[batch_size,src_len, d_model]
两者相加为:[batch_size,src_len, d_model]
数据经过dropout:[batch_size,src_len, d_model]
1.2 计算multi-head attention
W_Q = [d_model, d_q * n_heads]
W_K = [d_model, d_k * n_heads]
W_V = [d_model, d_v * n_heads]
X=[batch_size, src_len, d_model]
Q=X*W_Q=[batch_size, src_len,d_q * n_heads]
K=X*W_K=[batch_size, src_len,d_k * n_heads]
V=X*W_V=[batch_size, src_len,d_v * n_heads]
1.2.1 进行分割
Q=[batch_size,n_heads,src_len,d_q]
K=[batch_size,n_heads,src_len,d_k]
V=[batch_size,n_heads,src_len,d_v]
1.2.2 计算dot-product attention
K改变形状为 K=[batch_size,n_heads,d_k,src_len]
outputs=Q*K=[batch_size,n_heads,src_len,d_q]*[batch_size,n_heads,d_k,src_len]=[batch_size,n_heads,src_len,src_len]
outputs /= d_k**0.5
1.2.3进行softmax
先进行mask,将padding部分变为-inf
outputs=softmax(outputs)=[batch_size,n_heads,src_len,src_len]
1.2.4 最终结果相乘
outputs=outputs*V=[batch_size,n_heads,src_len,src_len]*[batch_size,n_heads,src_len,d_v]=[batch_size,n_heads,src_len,d_v]
1.2.5 全连接、残差链接和层归一化
先进行reshape,得到outputs=[batch_size,src_len,n_heads*d_v]
fc=[n_heads*d_v,d_model]
outputs=outputs*fc=[batch_size,src_len,n_heads*d_v]*[n_heads*d_v,d_model]=[batch_size,src_len,d_model]
1.3 全连接、残差链接和层归一化
outputs=[batch_size,src_len,d_model]
Decoder
2.1 Masked multi-head attention
这部分除了多了一步mask,其他步骤是相同的就不加赘述了
2.2 multi-head attention
Q=[batch_size,n_heads,tgt_len,d_q]
K=[batch_size,n_heads,src_len,d_k]
V=[batch_size,n_heads,src_len,d_v]
K改变形状为 K=[batch_size,n_heads,d_k,src_len]
outputs=Q*K=[batch_size,n_heads,tgt_len,d_q]*[batch_size,n_heads,d_k,src_len]=[batch_size,n_heads,tgt_len,src_len]
outputs /= d_k**0.5
outputs=softmax(outputs)=[batch_size,n_heads,tgt_len,src_len]
outputs=outputs*V=[batch_size,n_heads,tgt_len,src_len]*[batch_size,n_heads,src_len,d_v]=[batch_size,n_heads,src_len,d_v]
2.3
剩下的都是相同的这里就不赘述了