• tensorflow TensorArray 代码例子


     1 import tensorflow as tf 
     2 import numpy as np 
     3 
     4 B=3
     5 D=4
     6 T=5
     7 
     8 tf.reset_default_graph()
     9 xs=tf.placeholder(shape=[T,B,D],dtype=tf.float32)
    10 
    11 with tf.variable_scope('rnn'):
    12     GRUcell = tf.nn.rnn_cell.GRUCell(num_units=D)
    13     cell = tf.nn.rnn_cell.MultiRNNCell([GRUcell])
    14 
    15     output_ta = tf.TensorArray(size=T,dtype=tf.float32)
    16     input_ta = tf.TensorArray(size=T,dtype=tf.float32)
    17     input_ta = input_ta.unstack(xs)
    18 
    19     def body(time,output_ta_t,state):
    20         xt = input_ta.read(time)
    21         new_output,new_state = cell(xt,state)
    22         output_ta_t = output_ta_t.write(time, new_output)
    23         return (time+1,output_ta_t,new_state)
    24 
    25     def condition(time,output,state):
    26         return time<T
    27     
    28     time=0
    29     state=cell.zero_state(B,tf.float32)
    30     time_final,output_ta_final,state_final=tf.while_loop(cond=condition,body=body,loop_vars=(time,output_ta,state))
    31     output_final = output_ta_final.stack()
    32 
    33 x=np.random.randn(T,B,D)
    34 with tf.Session() as sess:
    35     sess.run(tf.global_variables_initializer())
    36     output_final_,state_final_=sess.run([output_final,state_final],feed_dict={xs:x})
     1 import tensorflow as tf
     2 tf.enable_eager_execution()
     3 
     4 def condition(time,max_time, output_ta_l):
     5     return tf.less(time, max_time)
     6 
     7 def body(time,max_time, output_ta_l):
     8     output_ta_l = output_ta_l.write(time, [2.4, 3.5])
     9     return time + 1, max_time,output_ta_l
    10 
    11 max_time=tf.constant(3)
    12 time = tf.constant(0)
    13 output_ta = tf.TensorArray(dtype=tf.float32, size=1, dynamic_size=True)
    14 result = tf.while_loop(condition, body, loop_vars=[time,max_time,output_ta])
    15 last_time,max_time, last_out = result
    16 final_out = last_out.stack()
    17 
    18 
    19 print(last_time)
    20 print(final_out)
    21 
    22 
    23 '''
    24 ta.stack(name=None) 将TensorArray中元素叠起来当做一个Tensor输出
    25 ta.unstack(value, name=None) 可以看做是stack的反操作,输入Tensor,输出一个新的TensorArray对象
    26 ta.write(index, value, name=None) 指定index位置写入Tensor
    27 ta.read(index, name=None) 读取指定index位置的Tensor
    28 作者:加勒比海鲜 
    29 原文:https://blog.csdn.net/guolindonggld/article/details/79256018 
    30 '''

     TensorArray可以看做是具有动态size功能的Tensor数组。通常都是跟while_loop或map_fn结合使用

    tips:[n.name for n in tf.get_default_graph().as_graph_def().node]获取图中所有节点

  • 相关阅读:
    Python统计excel表格中文本的词频,生成词云图片
    springboot application.properties 常用完整版配置信息
    JAVA高级-面试题总结
    删除csdn上面自己上传的资源
    本博客背景特效源码
    我的自定义框架 || 基于Spring Boot || 第一步
    PYTHON 实现的微信跳一跳【辅助工具】仅作学习
    PM2守护babel-node
    记一个HOST引起的前端项目打不开的问题
    迭代器与iterable
  • 原文地址:https://www.cnblogs.com/buyizhiyou/p/9914102.html
Copyright © 2020-2023  润新知