"""Manually unroll a dynamic RNN with tf.while_loop + TensorArray (TF1 graph mode).

Feeds a time-major (T, B, D) batch through a single-layer GRU wrapped in a
MultiRNNCell, one timestep per while_loop iteration, collecting each step's
output in a TensorArray and stacking the result back into a (T, B, D) tensor.
"""
import tensorflow as tf
import numpy as np

B = 3  # batch size
D = 4  # feature size (also used as the GRU hidden size)
T = 5  # number of timesteps

tf.reset_default_graph()
# Time-major input: one (B, D) slice per timestep.
xs = tf.placeholder(shape=[T, B, D], dtype=tf.float32)

with tf.variable_scope('rnn'):
    GRUcell = tf.nn.rnn_cell.GRUCell(num_units=D)
    cell = tf.nn.rnn_cell.MultiRNNCell([GRUcell])

output_ta = tf.TensorArray(size=T, dtype=tf.float32)
input_ta = tf.TensorArray(size=T, dtype=tf.float32)
# unstack splits xs along axis 0, so input_ta.read(t) yields xs[t] of shape (B, D).
input_ta = input_ta.unstack(xs)


def body(time, output_ta_t, state):
    """One RNN step: read x_t, apply the cell, store the step output."""
    xt = input_ta.read(time)
    new_output, new_state = cell(xt, state)
    output_ta_t = output_ta_t.write(time, new_output)
    return (time + 1, output_ta_t, new_state)


def condition(time, output, state):
    """Continue while there are timesteps left to process."""
    return time < T


time = 0
state = cell.zero_state(B, tf.float32)
time_final, output_ta_final, state_final = tf.while_loop(
    cond=condition, body=body, loop_vars=(time, output_ta, state))
output_final = output_ta_final.stack()  # back to shape (T, B, D)

# Cast explicitly: np.random.randn returns float64, the placeholder is float32.
x = np.random.randn(T, B, D).astype(np.float32)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    output_final_, state_final_ = sess.run(
        [output_final, state_final], feed_dict={xs: x})
"""tf.while_loop + dynamically sized TensorArray demo (TF1 eager mode).

Writes the constant row [2.4, 3.5] into a TensorArray max_time (= 3) times,
then stacks the array into a (3, 2) tensor.
"""
import tensorflow as tf

tf.enable_eager_execution()


def condition(time, max_time, output_ta_l):
    """Continue while time < max_time."""
    return tf.less(time, max_time)


def body(time, max_time, output_ta_l):
    """Write a fixed 2-element row at index `time`, then advance the counter."""
    output_ta_l = output_ta_l.write(time, [2.4, 3.5])
    return time + 1, max_time, output_ta_l


max_time = tf.constant(3)
time = tf.constant(0)
# dynamic_size=True lets the array grow past its initial size of 1.
output_ta = tf.TensorArray(dtype=tf.float32, size=1, dynamic_size=True)
result = tf.while_loop(condition, body, loop_vars=[time, max_time, output_ta])
last_time, max_time, last_out = result
final_out = last_out.stack()


print(last_time)
print(final_out)


'''
TensorArray cheat sheet:
ta.stack(name=None)                stack the TensorArray's elements into a single Tensor
ta.unstack(value, name=None)       inverse of stack: split a Tensor into a new TensorArray
ta.write(index, value, name=None)  write a Tensor at the given index
ta.read(index, name=None)          read the Tensor at the given index
Author: 加勒比海鲜
Source: https://blog.csdn.net/guolindonggld/article/details/79256018
'''
TensorArray can be thought of as a Tensor array with dynamic-size support. It is usually used together with `tf.while_loop` or `tf.map_fn`.
Tip: `[n.name for n in tf.get_default_graph().as_graph_def().node]` lists the names of all nodes in the default graph.