• Notes on the new seq2seq API


    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units=FLAGS.rnn_hidden_size, memory=encoder_outputs, memory_sequence_length=encoder_sequence_length)

    This step creates an attention_mechanism. It is invoked via __call__(self, query, previous_alignments): query is the decoder hidden state, and previous_alignments are the alignments from the previous decoding step (the memory, i.e. the encoder hiddens, was already bound at construction). The output is a probability distribution (alignments) over the encoder time steps.
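
    To make the scoring concrete, here is a minimal NumPy sketch of the additive (Bahdanau) score, score = vᵀ·tanh(W1·key + W2·query), followed by a softmax over encoder time steps; the names and shapes are illustrative, not taken from the TF source:

    import numpy as np

    def bahdanau_alignments(query, keys, W1, W2, v):
        # query: [dec_units], keys: [memory_time, enc_units]
        # W1: [num_units, enc_units], W2: [num_units, dec_units], v: [num_units]
        scores = np.tanh(keys @ W1.T + query @ W2.T) @ v  # [memory_time]
        e = np.exp(scores - scores.max())
        return e / e.sum()  # softmax over memory_time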

    helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(inputs, tf.to_int32(sequence_length), emb, tf.constant(FLAGS.scheduled_sampling_probability))

    This creates a helper that handles the input and output of each decoding step. For this helper, scheduled_sampling_probability is the probability of feeding back a sampled model output instead of the ground-truth input.
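
    A rough sketch of the per-step decision the helper makes (illustrative only; the real TF helper samples from the output distribution and does this batched, symbolically):

    import numpy as np

    def next_step_input(ground_truth_emb, prev_logits, emb_matrix, sampling_probability):
        # With probability sampling_probability, feed the embedding of a token
        # taken from the model's previous output; otherwise feed the ground truth.
        if np.random.rand() < sampling_probability:
            sampled_id = int(np.argmax(prev_logits))  # greedy here; TF samples categorically
            return emb_matrix[sampled_id]
        return ground_truth_emb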

    my_decoder = tf.contrib.seq2seq.BasicDecoder(cell=cell, helper=helper, initial_state=state)

    This is the core of the decoding loop: each decode step is driven through def step(self, time, inputs, state, name=None).

    First, inputs and attention are concatenated to form the cell input. (Why is this valid? Recall the LSTM implementation of W1·U + W2·V: it actually concatenates U and V and multiplies by a single W.) Here inputs plays the role of U and attention the role of V. (The attention vector itself is essentially an output projection of tf.concat([query, alignments · memory]); see _compute_attention below.)
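
    Inside AttentionWrapper this is just a concat along the feature axis; a simplified sketch of one wrapped step (names illustrative):

    import tensorflow as tf

    def attention_cell_step(cell, inputs, attention, cell_state):
        # Concatenate the step input with the previous attention vector and
        # feed the result into the wrapped RNN cell.
        cell_inputs = tf.concat([inputs, attention], axis=-1)
        return cell(cell_inputs, cell_state)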

    outputs, state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(my_decoder, scope='seq_decode')

    Finally, dynamic_decode drives the whole flow.
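
    Putting the pieces together: for attention to actually be applied, the cell and initial_state handed to BasicDecoder must be the AttentionWrapper-wrapped versions. A hedged end-to-end sketch (TF 1.x contrib API; decoder_cell, batch_size and encoder_state are assumed to exist):

    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
        num_units=FLAGS.rnn_hidden_size,
        memory=encoder_outputs,
        memory_sequence_length=encoder_sequence_length)

    # Wrap the plain cell so every step concatenates attention into its input.
    attn_cell = tf.contrib.seq2seq.AttentionWrapper(
        decoder_cell, attention_mechanism,
        attention_layer_size=FLAGS.rnn_hidden_size)

    # The wrapper has its own state type; seed it with the encoder final state.
    initial_state = attn_cell.zero_state(batch_size, tf.float32).clone(
        cell_state=encoder_state)

    helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
        inputs, tf.to_int32(sequence_length), emb,
        tf.constant(FLAGS.scheduled_sampling_probability))

    my_decoder = tf.contrib.seq2seq.BasicDecoder(
        cell=attn_cell, helper=helper, initial_state=initial_state)

    outputs, state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(
        my_decoder, scope='seq_decode')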

    Some background on the RNN cells first.

    Start with:

    class BasicRNNCell(RNNCell):

        def call(self, inputs, state):
            """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
            if self._linear is None:
                self._linear = _Linear([inputs, state], self._num_units, True)

            output = self._activation(self._linear([inputs, state]))
            return output, output

    This is the core: the implementation of W * input + U * state + B. TF implements it with _Linear, which simply concatenates input and state and multiplies by a single weight matrix. Since a vanilla RNN has only a hidden vector, state here is just the hidden state.
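
    The concat-then-multiply trick is a plain linear-algebra identity; a quick NumPy check (illustrative):

    import numpy as np

    num_units, in_dim, state_dim = 5, 3, 4
    W = np.random.randn(num_units, in_dim)     # weights for the input
    U = np.random.randn(num_units, state_dim)  # weights for the state
    x = np.random.randn(in_dim)
    h = np.random.randn(state_dim)

    # W @ x + U @ h  ==  [W | U] @ concat(x, h)
    combined = np.concatenate([W, U], axis=1) @ np.concatenate([x, h])
    assert np.allclose(W @ x + U @ h, combined)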

    Next, look at:

    class BasicLSTMCell(RNNCell):

        def call(self, inputs, state):
            if self._state_is_tuple:
                c, h = state
            else:
                c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

            if self._linear is None:
                self._linear = _Linear([inputs, h], 4 * self._num_units, True)
            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i, j, f, o = array_ops.split(
                value=self._linear([inputs, h]), num_or_size_splits=4, axis=1)

            new_c = (
                c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
            new_h = self._activation(new_c) * sigmoid(o)

            if self._state_is_tuple:
                new_state = LSTMStateTuple(new_c, new_h)
            else:
                new_state = array_ops.concat([new_c, new_h], 1)
            return new_h, new_state

    Now it's obvious: since the LSTM state consists of two parts, the cell state c and the hidden state h, the first step is to split them. Then inputs and h go through the linear layer; because we need four gate outputs, the output dimension must be 4 * _num_units. The rest follows the LSTM equations, and the new hidden and cell state are returned. Very direct.
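
    The same step in plain NumPy, to make the 4-way split concrete (a sketch with sigmoid/tanh spelled out; names and shapes are illustrative):

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def lstm_step(inputs, c, h, W, b, forget_bias=1.0):
        # W: [in_dim + num_units, 4 * num_units] -- one concatenated weight matrix
        linear = np.concatenate([inputs, h], axis=-1) @ W + b
        i, j, f, o = np.split(linear, 4, axis=-1)  # each gate is [num_units]
        new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
        new_h = np.tanh(new_c) * sigmoid(o)
        return new_h, (new_c, new_h)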

    Now, how does attention get added on top of this?

    Here the attention vector is computed from the encoder hiddens. According to the formula, attention and the decoder hidden are concatenated into one big hidden vector, which then enters the network together with the inputs.

    However, TF implements it the other way round: it first concatenates attention with inputs, then feeds the concatenated result, as the inputs, together with the decoder hidden into the network. Why is this allowed? Because inside the cell everything is concatenated before the linear layer anyway (see the BasicLSTMCell implementation above), so the only thing that matters is that (inputs, attention, decoder hidden) all get concatenated together; the order is irrelevant. At this point you finally see what AttentionWrapper actually does. So how is the attention itself computed? Through the _compute_attention function, which is very direct; attention_mechanism is the scoring function you chose for mapping the query onto the memory:
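
    One wrapped step therefore looks roughly like this (heavily simplified pseudocode of AttentionWrapper's call, assuming a single attention mechanism; not the verbatim TF source):

    def attention_wrapper_step(cell, attention_mechanism, attention_layer,
                               inputs, state):
        # state carries (cell_state, previous attention, previous alignments)
        cell_inputs = tf.concat([inputs, state.attention], axis=-1)
        cell_output, next_cell_state = cell(cell_inputs, state.cell_state)

        # Score the new decoder hidden against the memory (see below).
        attention, alignments = _compute_attention(
            attention_mechanism, cell_output, state.alignments, attention_layer)

        next_state = state._replace(cell_state=next_cell_state,
                                    attention=attention,
                                    alignments=alignments)
        return attention, next_state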

    def _compute_attention(attention_mechanism, cell_output, previous_alignments,
                           attention_layer):
        """Computes the attention and alignments for a given attention_mechanism."""
        alignments = attention_mechanism(
            cell_output, previous_alignments=previous_alignments)

        # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
        expanded_alignments = array_ops.expand_dims(alignments, 1)
        # Context is the inner product of alignments and values along the
        # memory time dimension.
        # alignments shape is
        #   [batch_size, 1, memory_time]
        # attention_mechanism.values shape is
        #   [batch_size, memory_time, memory_size]
        # the batched matmul is over memory_time, so the output shape is
        #   [batch_size, 1, memory_size].
        # we then squeeze out the singleton dim.
        context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
        context = array_ops.squeeze(context, [1])

        if attention_layer is not None:
            attention = attention_layer(array_ops.concat([cell_output, context], 1))
        else:
            attention = context

        return attention, alignments
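
    The batched matmul that builds the context vector is easy to sanity-check in NumPy (shapes as in the comments above):

    import numpy as np

    batch_size, memory_time, memory_size = 2, 7, 16
    alignments = np.random.rand(batch_size, memory_time)
    alignments /= alignments.sum(axis=1, keepdims=True)  # rows sum to 1
    values = np.random.randn(batch_size, memory_time, memory_size)

    # [batch, 1, time] @ [batch, time, size] -> [batch, 1, size] -> [batch, size]
    context = np.matmul(alignments[:, None, :], values).squeeze(1)
    assert context.shape == (batch_size, memory_size)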

• Original post: https://www.cnblogs.com/dmesg/p/8195621.html