• 【nlp】Attention机制学习


    Attention 机制学习

    Attention 机制中一般需要用到的三个参数

    query(Q), key(K), value(V)

    preview

    attention 包括硬编码和软编码

    其中 [公式] 是编码器每个step的输出, [公式] 是解码器每个step的输出,计算步骤是这样的:

    1. 先对输入进行编码,得到 [公式]
    2. 开始解码了,先用固定的start token也就是 [公式] 最为Q,去和每个 [公式] (同时作为K和V)去计算attention,得到加权的 [公式]
    3. [公式] 作为解码的RNN输入(同时还有上一步的 [公式] ),得到 [公式] 并预测出第一个词是machine
    4. 再继续预测的话,就是用 [公式] 作为Q去求attention:

    增加了attention的学习机制后,可以编码更长的序列信息,同时,也可以优化输出序列和输入序列中,单词排序不同情况下的表现,这在机器对语句进行理解、摘要或者翻译中,具有重要影响。

    当然,这种attention可能会减少对序列顺序的敏感性,同时,由于使用rnn,不能并行化计算。


    在实现Seq2Seq模型中,Decoder解码部分,对于前一个预测词,有两种来源可以采用:

    • 模型预测的单词
    • 给定结果的单词

    模型01、02,我没有采用Attention机制,同时只用给定结果的单词参与运算。导致了模型训练极度拟合了train训练集,因此,在valid集上的loss越来越大。在自己抽样调查中,可以明显感知,valid上,完全是用train中的原句去预测。可以在抽样中看到。

    在模型03中,我采用了Attention机制,虽然沿用了只“给定结果的单词”的方式,但是效果还不错,在20个epoch训练后,valid集上交叉熵损失随着train集的损失稳步下降。抽样调查也令人欣慰。

    参考:https://zhuanlan.zhihu.com/p/44121378
    图片来源也是

    skr
  • 相关阅读:
    c++中vector的用法详解[转]
    C++ String
    va_list用法
    如何高效的分析AWR报告
    Oracle存储过程跟踪错误的方法
    Oracle找出锁,并KILL掉
    OracleAWR报告概念和生成
    Linux系统的内存管理
    AIX系统下配置FTP服务
    通过修改注册表配置IE选项
  • 原文地址:https://www.cnblogs.com/ckxkexing/p/14423780.html
Copyright © 2020-2023  润新知