TensorFlow bidirectional RNN
How to implement a bidirectional RNN in TensorFlow
Single-layer bidirectional RNN
TensorFlow already provides an interface for bidirectional RNNs: tf.nn.bidirectional_dynamic_rnn(). Let's first look at how this interface is used.
```python
bidirectional_dynamic_rnn(
    cell_fw,                 # forward RNN cell
    cell_bw,                 # backward RNN cell
    inputs,                  # the input sequences
    sequence_length=None,    # lengths of the sequences
    initial_state_fw=None,   # initial state of the forward rnn_cell
    initial_state_bw=None,   # initial state of the backward rnn_cell
    dtype=None,              # data type
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)
```
Return value: a tuple (outputs, output_states), where outputs is itself a tuple (output_fw, output_bw). If time_major=True, then output_fw and output_bw are time-major as well, and vice versa. If you want to concatenate the two directions, simply use tf.concat(outputs, 2).
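To make the shapes concrete, here is a minimal sketch (the batch size, sequence length, input depth, and cell size are assumed values, not anything mandated by the API): each direction produces [batch, time, num_units], and concatenating on axis 2 doubles the feature dimension.

```python
import tensorflow as tf
from tensorflow.contrib import rnn

# Assumed shapes for illustration: batch=32, time=20, depth=8.
inputs = tf.placeholder(tf.float32, [32, 20, 8])
cell_fw = rnn.LSTMCell(10)
cell_bw = rnn.LSTMCell(10)

outputs, states = tf.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, inputs, dtype=tf.float32)

print(outputs[0].shape)             # (32, 20, 10) -- forward outputs
print(outputs[1].shape)             # (32, 20, 10) -- backward outputs
print(tf.concat(outputs, 2).shape)  # (32, 20, 20) -- both directions
```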
How to use it:
Using bidirectional_dynamic_rnn is very similar to using dynamic_rnn:

1. Define the forward and backward rnn_cells.
2. Define the initial states of the forward and backward rnn_cells.
3. Prepare the input sequences.
4. Call bidirectional_dynamic_rnn.

```python
import tensorflow as tf
from tensorflow.contrib import rnn

cell_fw = rnn.LSTMCell(10)
cell_bw = rnn.LSTMCell(10)
initial_state_fw = cell_fw.zero_state(batch_size, tf.float32)
initial_state_bw = cell_bw.zero_state(batch_size, tf.float32)
seq = ...
seq_length = ...
(outputs, states) = tf.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, seq, seq_length,
    initial_state_fw, initial_state_bw)
out = tf.concat(outputs, 2)
# ....
```
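As a follow-up, the returned states is a tuple (state_fw, state_bw), and for LSTM cells each element is an LSTMStateTuple. A common pattern (a sketch, assuming the LSTM cells defined above) is to concatenate the two final hidden states to get a fixed-size representation of the whole sequence:

```python
# states = (state_fw, state_bw); each is an LSTMStateTuple(c, h).
state_fw, state_bw = states
final_hidden = tf.concat([state_fw.h, state_bw.h], 1)  # [batch, 2 * num_units]
```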
Multi-layer bidirectional RNN
A single-layer bidirectional RNN can be implemented straightforwardly with the method above, but for a multi-layer bidirectional RNN you cannot simply pass a MultiRNNCell to bidirectional_dynamic_rnn.
To see why, we need to look at a snippet of the bidirectional_dynamic_rnn source code.
```python
with vs.variable_scope(scope or "bidirectional_rnn"):
    # Forward direction
    with vs.variable_scope("fw") as fw_scope:
        output_fw, output_state_fw = dynamic_rnn(
            cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
            initial_state=initial_state_fw, dtype=dtype,
            parallel_iterations=parallel_iterations, swap_memory=swap_memory,
            time_major=time_major, scope=fw_scope)
```
This is only a small snippet, but it is enough to see that the bi-RNN is actually built on top of dynamic_rnn. If we pass in a MultiRNNCell, the entire forward stack and the entire backward stack each run independently, so the interaction between the two directions at every intermediate layer is ignored. Instead, we can write our own utility function that calls bidirectional_dynamic_rnn once per layer to build a multi-layer bidirectional RNN. A stripped-down version of such a helper is sketched below; if you spot errors, please point them out.
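The sketch assumes LSTM cells, batch-major inputs, and float32 activations; the helper name stacked_bidirectional_rnn and its parameters are illustrative, not part of the TensorFlow API.

```python
import tensorflow as tf
from tensorflow.contrib import rnn

def stacked_bidirectional_rnn(num_layers, num_units, inputs, seq_lengths):
    # Run one bidirectional layer at a time: the concatenated
    # forward/backward outputs of layer i become the inputs of layer i+1,
    # so both directions interact at every layer.
    _inputs = inputs
    for layer in range(num_layers):
        with tf.variable_scope("bidirectional_rnn_%d" % layer):
            cell_fw = rnn.LSTMCell(num_units)
            cell_bw = rnn.LSTMCell(num_units)
            outputs, _ = tf.nn.bidirectional_dynamic_rnn(
                cell_fw, cell_bw, _inputs,
                sequence_length=seq_lengths, dtype=tf.float32)
            _inputs = tf.concat(outputs, 2)
    return _inputs  # [batch, time, 2 * num_units]
```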
A peek at the bidirectional_dynamic_rnn source
We have already seen the implementation of the forward pass above; now let's look at the remaining backward part.
The backward pass actually just performs two reverses:
1. First reverse: reverse the input sequence, then run it through dynamic_rnn once.
2. Second reverse: reverse the outputs returned by dynamic_rnn above, so that the forward and backward outputs are aligned in time. For example, for an input [x1, x2, x3], the reversed input is [x3, x2, x1]; the backward cell emits its outputs in that order, and reversing them again puts the output aligned with x1 back at time step 1, matching the forward outputs.
```python
def _reverse(input_, seq_lengths, seq_dim, batch_dim):
    if seq_lengths is not None:
        return array_ops.reverse_sequence(
            input=input_, seq_lengths=seq_lengths,
            seq_dim=seq_dim, batch_dim=batch_dim)
    else:
        return array_ops.reverse(input_, axis=[seq_dim])

with vs.variable_scope("bw") as bw_scope:
    inputs_reverse = _reverse(
        inputs, seq_lengths=sequence_length,
        seq_dim=time_dim, batch_dim=batch_dim)
    tmp, output_state_bw = dynamic_rnn(
        cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
        initial_state=initial_state_bw, dtype=dtype,
        parallel_iterations=parallel_iterations, swap_memory=swap_memory,
        time_major=time_major, scope=bw_scope)

output_bw = _reverse(
    tmp, seq_lengths=sequence_length,
    seq_dim=time_dim, batch_dim=batch_dim)

outputs = (output_fw, output_bw)
output_states = (output_state_fw, output_state_bw)

return (outputs, output_states)
```
tf.reverse_sequence
Reverses part of each sequence in a batch.
```python
reverse_sequence(
    input,             # the input sequences to be (partially) reversed
    seq_lengths,       # 1-D tensor giving the length of each sequence
    seq_axis=None,     # which axis is the sequence (time) axis
    batch_axis=None,   # which axis is the batch axis
    name=None,
    seq_dim=None,
    batch_dim=None
)
```
The examples on the official site are very good, so they are pasted here directly:
```python
# Given this:
batch_dim = 0
seq_dim = 1
input.dims = (4, 8, ...)
seq_lengths = [7, 2, 3, 5]

# then slices of input are reversed on seq_dim, but only up to seq_lengths:
output[0, 0:7, :, ...] = input[0, 7:0:-1, :, ...]
output[1, 0:2, :, ...] = input[1, 2:0:-1, :, ...]
output[2, 0:3, :, ...] = input[2, 3:0:-1, :, ...]
output[3, 0:5, :, ...] = input[3, 5:0:-1, :, ...]

# while entries past seq_lens are copied through:
output[0, 7:, :, ...] = input[0, 7:, :, ...]
output[1, 2:, :, ...] = input[1, 2:, :, ...]
output[2, 3:, :, ...] = input[2, 3:, :, ...]
output[3, 5:, :, ...] = input[3, 5:, :, ...]
```
Example 2:
```python
# Given this:
batch_dim = 2
seq_dim = 0
input.dims = (8, ?, 4, ...)
seq_lengths = [7, 2, 3, 5]

# then slices of input are reversed on seq_dim, but only up to seq_lengths:
output[0:7, :, 0, :, ...] = input[7:0:-1, :, 0, :, ...]
output[0:2, :, 1, :, ...] = input[2:0:-1, :, 1, :, ...]
output[0:3, :, 2, :, ...] = input[3:0:-1, :, 2, :, ...]
output[0:5, :, 3, :, ...] = input[5:0:-1, :, 3, :, ...]

# while entries past seq_lens are copied through:
output[7:, :, 0, :, ...] = input[7:, :, 0, :, ...]
output[2:, :, 1, :, ...] = input[2:, :, 1, :, ...]
output[3:, :, 2, :, ...] = input[3:, :, 2, :, ...]
output[5:, :, 3, :, ...] = input[5:, :, 3, :, ...]
```
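To see this behave on a concrete tensor, here is a small runnable sketch (the values and sequence lengths are made up for illustration):

```python
import tensorflow as tf

# Two padded sequences of nominal length 4, with true lengths 3 and 2.
x = tf.constant([[1, 2, 3, 4],
                 [5, 6, 7, 8]])
y = tf.reverse_sequence(x, seq_lengths=[3, 2], seq_axis=1, batch_axis=0)

with tf.Session() as sess:
    print(sess.run(y))
    # [[3 2 1 4]   <- first 3 entries reversed, padding copied through
    #  [6 5 7 8]]  <- first 2 entries reversed, padding copied through
```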