介绍
在Bahadanu注意力机制中,本质上是序列到序列学习的注意力机制实现,在编码器-解码器结构中,解码器的每一步解码过程都依赖着整个上下文变量,通过Bahdanau注意力,使得解码器在每一步解码时,对于整个上下文变量的不同部分产生不同程度的对齐,如在文本翻译时,将“I am studying”的“studying”与“我正在学习”的“学习”进行对齐,即注意力在解码时将绝大多数注意力放在“studying”处。
原理和结构
原理
Bahdanau注意力机制本质上是将上下文变量进行转换即可,其中转换后的上下文变量计算方式如下式所示:
在传统注意力机制中,一般使用的公式形如,在Bahdanau中,键与值是同一个变量,都是t时刻的编码器隐状态,s表示该时刻的查询,即上一时刻的解码器隐状态。
架构
下图为Bahdanau注意力机制的编码器-解码器架构示意图:
为便于理解,对上述示意结构进行说明:首先将X依次输入GRU,之后在循环过程中依次产生len个隐状态,最后一个隐状态直接作为解码器的初始隐状态。在每个解码步骤 (t),注意力机制计算当前解码器隐藏状态 (s_t) 和编码器所有隐藏状态 (h_i) 的相似度,即应用注意力机制编写新的上下文变量,之后在解码器的循环解码过程中,都计算带注意力的上下文变量,通过此变量和上一解码隐状态计算当前时刻t的输出。
代码实现
引入库
注:本blog使用mxnet进行训练学习。
from mxnet import np, npx
from mxnet.gluon import nn, rnn
from d2l import mxnet as d2l
npx.set_np()
定义注意力解码器
这里实现一个接口,只需重新定义解码器即可。 为了更方便地显示学习的注意力权重, 以下AttentionDecoder
类定义了带有注意力机制解码器的基本接口。
class AttentionDecoder(d2l.Decoder):
def __init__(self, **kwargs):
super(AttentionDecoder, self).__init__(**kwargs)
def attention_weights(self):
raise NotImplementedError
接下来,让我们在接下来的Seq2SeqAttentionDecoder
类中实现带有Bahdanau注意力的循环神经网络解码器。 首先,初始化解码器的状态,需要下面的输入:
-
编码器在所有时间步的最终层隐状态,将作为注意力的键和值;
-
上一时间步的编码器全层隐状态,将作为初始化解码器的隐状态;
-
编码器有效长度(排除在注意力池中填充词元)。
在每个解码时间步骤中,解码器上一个时间步的最终层隐状态将用作查询。 因此,注意力输出和输入嵌入都连结为循环神经网络解码器的输入。
对接下来的代码实现略作补充说明:在编码器中,对一个batch每个输入(采用one-hot编码,长度为Vocab_size)依次进行嵌入层运算,得到固定embed_size个结果之后进行RNN运算,RNN使用层数为num_layers的深层循环神经网络,进行forward运算得到状态,部分过程进行闭包和解包,对于对维度大小出现疑惑的点,大多是闭包和解包造成的。
class Seq2SeqAttentionDecoder(AttentionDecoder):
def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):
super(Seq2SeqAttentionDecoder, self).__init__(**kwargs)
self.attention = d2l.AdditiveAttention(num_hiddens, dropout)
self.embedding = nn.Embedding(vocab_size, embed_size)
self.rnn = rnn.GRU(num_hiddens, num_layers, dropout=dropout)
self.dense = nn.Dense(vocab_size, flatten=False)
def init_state(self, enc_outputs, enc_valid_lens, *args):
outputs, hidden_state = enc_outputs
return (outputs.swapaxes(0, 1), hidden_state, enc_valid_lens)
def forward(self, X, state):
enc_outputs, hidden_state, enc_valid_lens = state
X = self.embedding(X).swapaxes(0, 1)
outputs, self._attention_weights = [], []
for x in X:
query = np.expand_dims(hidden_state[0][-1], axis=1)
context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens)
x = np.concatenate((context, np.expand_dims(x, axis=1)), axis=-1)
out, hidden_state = self.rnn(x.swapaxes(0, 1), hidden_state)
outputs.append(out)
self._attention_weights.append(self.attention.attention_weights)
outputs = self.dense(np.concatenate(outputs, axis=0))
return outputs.swapaxes(0, 1), [enc_outputs, hidden_state,enc_valid_lens]
def attention_weights(self):
return self._attention_weights
训练
我们在这里指定超参数,实例化一个带有Bahdanau注意力的编码器和解码器, 并对这个模型进行机器翻译训练。
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 250, d2l.try_gpu()
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers, dropout)
decoder = Seq2SeqAttentionDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
结果
采用BLEU计算困惑度,代码具有较好的表现。
go . => entre ., bleu 0.000 i lost . => j'ai gagné ., bleu 0.000 he's calm . => j'ai gagné ., bleu 0.000 i'm home . => je suis chez moi <unk> !, bleu 0.719