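• The state of an LSTM in TensorFlow

    Consider the state_is_tuple option.

    output, new_state = cell(inputs, state)

    The state actually consists of two tensors: the cell state c, and m, the hidden state (labeled hidden, or h, in most LSTM diagrams). m is also exactly the output of the step.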
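A minimal pure-Python sketch of one LSTM step makes the (c, m) split concrete. It is not TensorFlow's code, and it rests on these assumptions:

- LSTMStateTuple here is a plain namedtuple that only mirrors the field names (c, h) of TensorFlow's class.
- lstm_step is a hypothetical helper with no peepholes, projection or forget bias.
- The weights are made up.

The point is the return value. The step hands back m as its output and also stores m, next to c, in the new state.

```python
import math
from collections import namedtuple

# Stand-in for TensorFlow's LSTMStateTuple; only the field names (c, h) are mirrored.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x, state, weights):
    """One step of a plain LSTM on Python lists (hypothetical helper, no peepholes/projection).

    weights maps each gate name ("i", "f", "o", "g") to (W, U, b), where
    W is num_units x input_size, U is num_units x num_units, b has num_units entries.
    Returns (output, new_state), matching the cell(inputs, state) call pattern.
    """
    c_prev, m_prev = state

    def gate(name, activation):
        W, U, b = weights[name]
        return [activation(sum(w * xi for w, xi in zip(W[k], x))
                           + sum(u * mi for u, mi in zip(U[k], m_prev))
                           + b[k])
                for k in range(len(b))]

    i = gate("i", sigmoid)    # input gate
    f = gate("f", sigmoid)    # forget gate
    o = gate("o", sigmoid)    # output gate
    g = gate("g", math.tanh)  # candidate cell values
    c = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c_prev, i, g)]
    m = [ok * math.tanh(ck) for ok, ck in zip(o, c)]
    # The output is m itself; the new state carries both c and m.
    return m, LSTMStateTuple(c, m)


# Made-up weights: num_units = 2, input_size = 1, shared by all four gates for brevity.
weights = {name: ([[0.5], [-0.3]], [[0.1, -0.2], [0.4, 0.2]], [0.0, 0.1])
           for name in "ifog"}
state = LSTMStateTuple(c=[0.0, 0.0], h=[0.0, 0.0])
for x in ([1.0], [0.5], [-1.0]):
    output, state = lstm_step(x, state, weights)

print(output == state.h)  # True: the last output equals m in the state
```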

       

       

       

       

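    # End of the LSTM cell's __call__ (TensorFlow 0.x): pack c and m into the new state.
    # Note: concat(1, [c, m]) is the old concat(dim, values) order; TF >= 1.0 uses concat([c, m], 1).
    new_state = (LSTMStateTuple(c, m) if self._state_is_tuple
                 else array_ops.concat(1, [c, m]))

    # The step's output is m itself
    return m, new_state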
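The two branches of that expression differ only in how c and m are packaged. The sketch below mimics both on plain nested lists of shape batch x num_units. pack_state and unpack_state are hypothetical helpers. LSTMStateTuple is again only a namedtuple stand-in with TensorFlow's field names.

```python
from collections import namedtuple

# Stand-in for TensorFlow's LSTMStateTuple (fields c, h); not the real class.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def pack_state(c, m, state_is_tuple):
    """Mirror of the new_state expression; c and m are batch x num_units lists."""
    if state_is_tuple:
        return LSTMStateTuple(c, m)
    # Concatenate along axis 1: each batch row becomes c_row followed by m_row.
    return [c_row + m_row for c_row, m_row in zip(c, m)]


def unpack_state(state, num_units, state_is_tuple):
    """Recover (c, m) from either representation."""
    if state_is_tuple:
        return state.c, state.h
    c = [row[:num_units] for row in state]
    m = [row[num_units:] for row in state]
    return c, m


c = [[1.0, 2.0]]   # batch of 1, num_units = 2
m = [[3.0, 4.0]]
print(pack_state(c, m, state_is_tuple=True))   # LSTMStateTuple(c=[[1.0, 2.0]], h=[[3.0, 4.0]])
print(pack_state(c, m, state_is_tuple=False))  # [[1.0, 2.0, 3.0, 4.0]]
```

With the tuple form, m is simply state[1] (or state.h). With the concatenated form, the caller has to know num_units to slice c and m apart again.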

       

       

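    def basic_rnn_seq2seq(
            encoder_inputs, decoder_inputs, cell, dtype=dtypes.float32, scope=None):
        with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):
            # Run the encoder (rnn.rnn is the TF 0.x static unroll); keep only its final state
            _, enc_state = rnn.rnn(cell, encoder_inputs, dtype=dtype)
            # The encoder's final state becomes the decoder's initial state
            return rnn_decoder(decoder_inputs, enc_state, cell)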
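The data flow of basic_rnn_seq2seq can be sketched without TensorFlow. Everything below is hypothetical: scalar tanh cells, made-up weights, and the helper names make_cell, run_rnn and basic_seq2seq. As a simplification, the sketch passes two separate cells for encoder and decoder. The TensorFlow code above passes the same cell object to both, with the decoder running inside its own "rnn_decoder" variable scope. The sketch shows that the encoder's outputs are dropped and only its final state reaches the decoder.

```python
import math


def make_cell(w_x, w_h, b):
    """Hypothetical scalar tanh RNN cell with the cell(inp, state) -> (output, state) shape."""
    def cell(inp, state):
        h = math.tanh(w_x * inp + w_h * state + b)
        return h, h
    return cell


def run_rnn(cell, inputs, initial_state=0.0):
    """Static unroll over a Python list; returns (outputs, final_state)."""
    state, outputs = initial_state, []
    for inp in inputs:
        output, state = cell(inp, state)
        outputs.append(output)
    return outputs, state


def basic_seq2seq(encoder_inputs, decoder_inputs, enc_cell, dec_cell):
    """Encoder outputs are discarded; only the final encoder state reaches the decoder."""
    _, enc_state = run_rnn(enc_cell, encoder_inputs)
    return run_rnn(dec_cell, decoder_inputs, initial_state=enc_state)


enc_cell = make_cell(0.9, 0.5, 0.0)
dec_cell = make_cell(0.7, 0.6, 0.1)
decoder_inputs = [0.0, 0.0, 0.0]
out_a, _ = basic_seq2seq([1.0, 2.0], decoder_inputs, enc_cell, dec_cell)
out_b, _ = basic_seq2seq([-1.0, -2.0], decoder_inputs, enc_cell, dec_cell)
print(out_a != out_b)  # True: identical decoder inputs, different encoder states
```

Both runs use the same decoder inputs, so the difference between out_a and out_b comes entirely from enc_state.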

       

       

    def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None,
                    scope=None):
        with variable_scope.variable_scope(scope or "rnn_decoder"):
            state = initial_state  # e.g. the encoder's final state
            outputs = []
            prev = None
            for i, inp in enumerate(decoder_inputs):
                # If a loop_function is given, feed the previous output back in instead of inp
                if loop_function is not None and prev is not None:
                    with variable_scope.variable_scope("loop_function", reuse=True):
                        inp = loop_function(prev, i)
                # Reuse the cell's variables on every step after the first
                if i > 0:
                    variable_scope.get_variable_scope().reuse_variables()
                output, state = cell(inp, state)
                outputs.append(output)
                if loop_function is not None:
                    prev = output
        # outputs: the output (hidden state m) of every step; state: the decoder's final state
        return outputs, state
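The role of loop_function is easiest to see in a stripped-down copy of the loop. This pure-Python version keeps the control flow of rnn_decoder but drops the variable-scope handling. tanh_cell and feed_back are hypothetical stand-ins. In a real model the feedback function typically picks the most likely symbol from the output and embeds it.

```python
import math


def tanh_cell(inp, state):
    """Hypothetical scalar cell: output and new state are the same value."""
    h = math.tanh(0.8 * inp + 0.5 * state)
    return h, h


def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None):
    """Same control flow as rnn_decoder above, minus the variable-scope handling."""
    state, outputs, prev = initial_state, [], None
    for i, inp in enumerate(decoder_inputs):
        if loop_function is not None and prev is not None:
            # From step 1 on, the supplied input is ignored; the previous output is fed back.
            inp = loop_function(prev, i)
        output, state = cell(inp, state)
        outputs.append(output)
        if loop_function is not None:
            prev = output
    return outputs, state


def feed_back(prev, i):
    """Hypothetical loop_function: pass the previous output straight through."""
    return prev


# Teacher forcing (no loop_function): every decoder input is used.
teacher_a, _ = rnn_decoder([1.0, -5.0, 3.0], 0.2, tanh_cell)
teacher_b, _ = rnn_decoder([1.0, 9.0, 9.0], 0.2, tanh_cell)
# With loop_function: only decoder_inputs[0] (e.g. a GO symbol) matters.
loop_a, state_a = rnn_decoder([1.0, -5.0, 3.0], 0.2, tanh_cell, feed_back)
loop_b, _ = rnn_decoder([1.0, 9.0, 9.0], 0.2, tanh_cell, feed_back)
print(teacher_a == teacher_b, loop_a == loop_b)  # False True
print(loop_a[-1] == state_a)                     # True
```

Without loop_function every decoder input is consumed, as in training. With it, only decoder_inputs[0] is used and later inputs are replaced by the fed-back output, which is how decoding usually runs at inference time. For this toy cell the last line also confirms that outputs[-1] equals the final state.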

       

       

    Here the decoder receives the encoder's final state as its input, namely as initial_state.

       

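    It then returns the decoder's final state together with the list of all outputs (i.e., the hidden state m at every step).

    Note that outputs[-1] is numerically identical to the m stored in the final state.

    One caveat: with dynamic_rnn and a sequence length, outputs past the end of a sequence are padded with zeros, so outputs[-1] can be all zeros. The m in the state then equals the output at the last valid step, as in the example below.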
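The zero-padding caveat can be imitated in plain Python. The names below are hypothetical:

- dynamic_unroll reproduces what the dynamic_rnn documentation describes for sequence_length: outputs past a sequence's length are zeroed and the state is copied through.
- last_relevant picks the output at index length - 1, which is the value that matches m in the final state.
- tanh_cell is a scalar toy cell.

```python
import math


def tanh_cell(inp, state):
    """Hypothetical scalar cell: output and new state are the same value."""
    h = math.tanh(0.8 * inp + 0.5 * state)
    return h, h


def dynamic_unroll(cell, batch_inputs, seq_lengths, initial_state=0.0):
    """Imitates dynamic_rnn's sequence_length handling on scalar sequences:
    for t >= length the output is 0.0 and the state is carried through unchanged."""
    all_outputs, final_states = [], []
    for seq, length in zip(batch_inputs, seq_lengths):
        state, outputs = initial_state, []
        for t, inp in enumerate(seq):
            if t < length:
                output, state = cell(inp, state)
            else:
                output = 0.0
            outputs.append(output)
        all_outputs.append(outputs)
        final_states.append(state)
    return all_outputs, final_states


def last_relevant(all_outputs, seq_lengths):
    """Output at index length - 1 of each sequence (lengths must be >= 1)."""
    return [outs[length - 1] for outs, length in zip(all_outputs, seq_lengths)]


batch = [[1.0, 0.5, -1.0, 0.0],   # 3 real steps + 1 padding step
         [0.3, 0.2, 0.1, 0.4]]    # 4 real steps
lengths = [3, 4]
outputs, states = dynamic_unroll(tanh_cell, batch, lengths)
print(outputs[0][-1])                             # 0.0 (padded step)
print(last_relevant(outputs, lengths) == states)  # True
```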

       

    encode_feature, state = melt.rnn.encode(
        cell,
        inputs,
        seq_length,
        encode_method=0,
        output_method=3)

       

    encode_feature.eval()

    array([[[ 4.27834410e-03,  1.45841937e-03,  1.25767402e-02,
               5.00775501e-03],
            [ 6.24437723e-03,  2.60074623e-03,  2.32168660e-02,
               9.47457738e-03],
            [ 7.59789022e-03, -5.34060055e-05,  1.64511874e-02,
              -5.71310846e-03],
            [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
               0.00000000e+00]]], dtype=float32)

       

       

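    state[1].eval()

    array([[ 7.59789022e-03, -5.34060055e-05,  1.64511874e-02,
            -5.71310846e-03]], dtype=float32)

    state[1] is m (the h field of LSTMStateTuple). Its values match the third row of encode_feature, the output at the last valid step, rather than the zero-padded fourth row.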

       

       

       

  • Original post: https://www.cnblogs.com/rocketfan/p/6257137.html
Copyright © 2020-2023  润新知