一、基本队列:
队列有两个基本操作,对应在tf中就是enqueue&dequeue
tf.FIFOQueue(2,'int32')
import tensorflow as tf '''FIFO队列操作''' # 创建队列 # 队列有两个int32的元素 q = tf.FIFOQueue(2,'int32') # 初始化队列 init= q.enqueue_many(([0,10],)) # 出队 x = q.dequeue() y = x + 1 # 入队 q_inc = q.enqueue([y]) with tf.Session() as sess: init.run() for _ in range(5): v,_ = sess.run([x,q_inc]) print(v)
tf.RandomShuffleQueue(capacity=10,min_after_dequeue=2,dtypes='float')
'''随机队列操作''' # 最大长度10,最小长度2,类型float的随机队列 q = tf.RandomShuffleQueue(capacity=10,min_after_dequeue=2,dtypes='float') sess = tf.Session() for i in range(0,10): sess.run(q.enqueue(i)) for i in range(0,8): # 在输出8次后会被阻塞 print(sess.run(q.dequeue())) #run_option = tf.RunOptions(timeout_in_ms = 10000) # 等待时间10s #for i in range(0,7): # 在输出8次后会被阻塞 # # 超时报错继续,不会退出 # try: # print(sess.run(q.dequeue(),options=run_option)) # except tf.errors.DeadlineExceededError: # print('out of range') print('-----'*5)
二、队列管理:
tf.train.QueueRunner(q,enqueue_ops=[increment_op,enqueue_op]*2)
'''队列管理器''' # 队列管理器使用线程管理队列 q = tf.FIFOQueue(1000,'float') counter = tf.Variable(0.0) # 计数器 increment_op = tf.assign_add(counter, tf.constant(1.0)) # 计数器加一 enqueue_op = q.enqueue(counter) # 入队 # 线程面向队列q,启动2个线程,每个线程中是[in,en]两个操作 qr = tf.train.QueueRunner(q,enqueue_ops=[increment_op,enqueue_op]*2) sess.run(tf.global_variables_initializer()) enqueue_threads = qr.create_threads(sess,start=True) # 启动入队线程 for i in range(10): print(sess.run(q.dequeue())) # 由于主线程和入队线程异步,所以输出不是自然数序列
出队操作还有Queu.dequeue_many(batch_size),如果入队时采用enqueue([image, label]),则可以实现队列数据参与训练。
tf.train.Coordinator()
'''协调器''' q = tf.FIFOQueue(1000,'float') counter = tf.Variable(0.0) # 计数器 increment_op = tf.assign_add(counter, tf.constant(1.0)) # 计数器加一 enqueue_op = q.enqueue(counter) # 入队 # 线程面向队列q,启动2个线程,每个线程中是[in,en]两个操作 qr = tf.train.QueueRunner(q,enqueue_ops=[increment_op,enqueue_op]*2) sess = tf.Session() sess.run(tf.global_variables_initializer()) coord = tf.train.Coordinator() # 线程管理器启动线程,接收协调器管理 enqueue_thread = qr.create_threads(sess,coord=coord,start=True) for i in range(0,10): print(sess.run(q.dequeue())) coord.request_stop() # 向各个线程发终止信号 coord.join(enqueue_thread) # 等待各个线程成功结束