• tensorflow学习笔记(三十九):双向rnn


    tensorflow 双向 rnn

    如何在tensorflow中实现双向rnn

    单层双向rnn

    这里写图片描述
    单层双向rnn (cs224d)


    tensorflow中已经提供了双向rnn的接口,它就是tf.nn.bidirectional_dynamic_rnn(). 我们先来看一下这个接口怎么用.

     1 bidirectional_dynamic_rnn(
     2     cell_fw, #前向 rnn cell
     3     cell_bw, #反向 rnn cell
     4     inputs, #输入序列.
     5     sequence_length=None,# 序列长度
     6     initial_state_fw=None,#前向rnn_cell的初始状态
     7     initial_state_bw=None,#反向rnn_cell的初始状态
     8     dtype=None,#数据类型
     9     parallel_iterations=None,
    10     swap_memory=False,
    11     time_major=False,
    12     scope=None
    13 )

    返回值:一个tuple(outputs, outputs_states), 其中,outputs是一个tuple(outputs_fw, outputs_bw). 关于outputs_fwoutputs_bw,如果time_major=True则它俩也是time_major的,vice versa. 如果想要concatenate的话,直接使用tf.concat(outputs, 2)即可.

    如何使用: 
    bidirectional_dynamic_rnn 在使用上和 dynamic_rn

    n是非常相似的.
    
    定义前向和反向rnn_cell
    定义前向和反向rnn_cell的初始状态
    准备好序列
    调用bidirectional_dynamic_rnn
    import tensorflow as tf
    from tensorflow.contrib import rnn
    cell_fw = rnn.LSTMCell(10)
    cell_bw = rnn.LSTMCell(10)
    initial_state_fw = cell_fw.zero_state(batch_size)
    initial_state_bw = cell_bw.zero_state(batch_size)
    seq = ...
    seq_length = ...
    (outputs, states)=tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, seq,
     seq_length, initial_state_fw,initial_state_bw)
    out = tf.concat(outputs, 2)
    View Code
    # ....

    多层双向rnn

    这里写图片描述
    多层双向rnn(cs224d)

    单层双向rnn可以通过上述方法简单的实现,但是多层的双向rnn就不能使将MultiRNNCell传给bidirectional_dynamic_rnn了. 
    想要知道为什么,我们需要看一下bidirectional_dynamic_rnn的源码片段.

    1 with vs.variable_scope(scope or "bidirectional_rnn"):
    2   # Forward direction
    3   with vs.variable_scope("fw") as fw_scope:
    4     output_fw, output_state_fw = dynamic_rnn(
    5         cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
    6         initial_state=initial_state_fw, dtype=dtype,
    7         parallel_iterations=parallel_iterations, swap_memory=swap_memory,
    8         time_major=time_major, scope=fw_scope)

    这只是一小部分代码,但足以看出,bi-rnn实际上是依靠dynamic-rnn实现的,如果我们使用MuitiRNNCell的话,那么每层之间不同方向之间交互就被忽略了.所以我们可以自己实现一个工具函数,通过多次调用bidirectional_dynamic_rnn来实现多层的双向RNN 这是我对多层双向RNN的一个精简版的实现,如有错误,欢迎指出

    bidirectional_dynamic_rnn源码一探

    上面我们已经看到了正向过程的代码实现,下面来看一下剩下的反向部分的实现. 
    其实反向的过程就是做了两次reverse 
    1. 第一次reverse:将输入序列进行reverse,然后送入dynamic_rnn做一次运算. 
    2. 第二次reverse:将上面dynamic_rnn返回的outputs进行reverse,保证正向和反向输出的time是对上的.

     1 def _reverse(input_, seq_lengths, seq_dim, batch_dim):
     2   if seq_lengths is not None:
     3     return array_ops.reverse_sequence(
     4         input=input_, seq_lengths=seq_lengths,
     5         seq_dim=seq_dim, batch_dim=batch_dim)
     6   else:
     7     return array_ops.reverse(input_, axis=[seq_dim])
     8 
     9 with vs.variable_scope("bw") as bw_scope:
    10   inputs_reverse = _reverse(
    11       inputs, seq_lengths=sequence_length,
    12       seq_dim=time_dim, batch_dim=batch_dim)
    13   tmp, output_state_bw = dynamic_rnn(
    14       cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
    15       initial_state=initial_state_bw, dtype=dtype,
    16       parallel_iterations=parallel_iterations, swap_memory=swap_memory,
    17       time_major=time_major, scope=bw_scope)
    18 
    19 output_bw = _reverse(
    20   tmp, seq_lengths=sequence_length,
    21   seq_dim=time_dim, batch_dim=batch_dim)
    22 
    23 outputs = (output_fw, output_bw)
    24 output_states = (output_state_fw, output_state_bw)
    25 
    26 return (outputs, output_states)

    tf.reverse_sequence

    对序列中某一部分进行反转

    1 reverse_sequence(
    2     input,#输入序列,将被reverse的序列
    3     seq_lengths,#1Dtensor,表示输入序列长度
    4     seq_axis=None,# 哪维代表序列
    5     batch_axis=None, #哪维代表 batch
    6     name=None,
    7     seq_dim=None,
    8     batch_dim=None
    9 )

    官网上的例子给的非常好,这里就直接粘贴过来:

     1 # Given this:
     2 batch_dim = 0
     3 seq_dim = 1
     4 input.dims = (4, 8, ...)
     5 seq_lengths = [7, 2, 3, 5]
     6 
     7 # then slices of input are reversed on seq_dim, but only up to seq_lengths:
     8 output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...]
     9 output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...]
    10 output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...]
    11 output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...]
    12 
    13 # while entries past seq_lens are copied through:
    14 output[0, 7:, :, ...] = input[0, 7:, :, ...]
    15 output[1, 2:, :, ...] = input[1, 2:, :, ...]
    16 output[2, 3:, :, ...] = input[2, 3:, :, ...]
    17 output[3, 2:, :, ...] = input[3, 2:, :, ...]

    例二:

     1 # Given this:
     2 batch_dim = 2
     3 seq_dim = 0
     4 input.dims = (8, ?, 4, ...)
     5 seq_lengths = [7, 2, 3, 5]
     6 
     7 # then slices of input are reversed on seq_dim, but only up to seq_lengths:
     8 output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...]
     9 output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...]
    10 output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...]
    11 output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...]
    12 
    13 # while entries past seq_lens are copied through:
    14 output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...]
    15 output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...]
    16 output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...]
    17 output[2:, :, 3, :, ...] = input[2:, :, 3, :, ...]
  • 相关阅读:
    tomcat7的catalina.sh配置说明
    nginx防攻击的简单配置
    linux系统自签发免费ssl证书,为nginx生成自签名ssl证书
    mysql ERROR 1045 (28000): Access denied for user 'root'@'localhost'
    /var/log/secure 文件清空
    Linux日志文件
    记一次网站被挂马处理
    Uedit32对文本进行回车换行
    安装mysql血泪史。
    mysql-8.0.19安装教程(Windows)
  • 原文地址:https://www.cnblogs.com/silence-tommy/p/8058333.html
Copyright © 2020-2023  润新知