Attention 机制学习
Attention 机制中一般需要用到的三个参数
query(Q), key(K), value(V)
attention 包括硬编码和软编码
其中 是编码器每个step的输出, 是解码器每个step的输出,计算步骤是这样的:
- 先对输入进行编码,得到
- 开始解码了,先用固定的start token也就是 最为Q,去和每个 (同时作为K和V)去计算attention,得到加权的
- 用 作为解码的RNN输入(同时还有上一步的 ),得到 并预测出第一个词是machine
- 再继续预测的话,就是用 作为Q去求attention:
增加了attention的学习机制后,可以编码更长的序列信息,同时,也可以优化输出序列和输入序列中,单词排序不同情况下的表现,这在机器对语句进行理解、摘要或者翻译中,具有重要影响。
当然,这种attention可能会减少对序列顺序的敏感性,同时,由于使用rnn,不能并行化计算。
在实现Seq2Seq模型中,Decoder解码部分,对于前一个预测词,有两种来源可以采用:
- 模型预测的单词
- 给定结果的单词
模型01、02,我没有采用Attention机制,同时只用给定结果的单词参与运算。导致了模型训练极度拟合了train训练集,因此,在valid集上的loss越来越大。在自己抽样调查中,可以明显感知,valid上,完全是用train中的原句去预测。可以在抽样中看到。
在模型03中,我采用了Attention机制,虽然沿用了只“给定结果的单词”的方式,但是效果还不错,在20个epoch训练后,valid集上交叉熵损失随着train集的损失稳步下降。抽样调查也令人欣慰。
参考:https://zhuanlan.zhihu.com/p/44121378
图片来源也是