• 『TensorFlow』第十弹_队列&多线程_道路多坎坷


    一、基本队列:

    队列有两个基本操作,对应在tf中就是enqueue&dequeue

     tf.FIFOQueue(2,'int32')

    import tensorflow as tf
    
    '''FIFO队列操作'''
    
    # 创建队列
    # 队列有两个int32的元素
    q = tf.FIFOQueue(2,'int32')
    # 初始化队列
    init= q.enqueue_many(([0,10],))
    # 出队
    x = q.dequeue()
    y = x + 1
    # 入队
    q_inc = q.enqueue([y])
    
    with tf.Session() as sess:
        init.run()
        for _ in range(5):
            v,_ = sess.run([x,q_inc])
            print(v)
    

     tf.RandomShuffleQueue(capacity=10,min_after_dequeue=2,dtypes='float')

    '''随机队列操作'''
    
    # 最大长度10,最小长度2,类型float的随机队列
    q = tf.RandomShuffleQueue(capacity=10,min_after_dequeue=2,dtypes='float')
    
    sess = tf.Session()
    for i in range(0,10):
        sess.run(q.enqueue(i))
    for i in range(0,8): # 在输出8次后会被阻塞
        print(sess.run(q.dequeue()))
    
    #run_option = tf.RunOptions(timeout_in_ms = 10000) # 等待时间10s
    #for i in range(0,7): # 在输出8次后会被阻塞
    #    # 超时报错继续,不会退出
    #    try:
    #        print(sess.run(q.dequeue(),options=run_option))
    #    except tf.errors.DeadlineExceededError:
    #        print('out of range')
    
    print('-----'*5)
    

    二、队列管理:

    tf.train.QueueRunner(q,enqueue_ops=[increment_op,enqueue_op]*2)

    '''队列管理器'''
    
    # 队列管理器使用线程管理队列
    
    q = tf.FIFOQueue(1000,'float')
    counter = tf.Variable(0.0)  # 计数器
    increment_op = tf.assign_add(counter, tf.constant(1.0))   # 计数器加一
    enqueue_op = q.enqueue(counter)                           # 入队
    
    # 线程面向队列q,启动2个线程,每个线程中是[in,en]两个操作
    qr = tf.train.QueueRunner(q,enqueue_ops=[increment_op,enqueue_op]*2)
    
    
    sess.run(tf.global_variables_initializer())
    enqueue_threads = qr.create_threads(sess,start=True)      # 启动入队线程
    for i in range(10):
        print(sess.run(q.dequeue()))
        # 由于主线程和入队线程异步,所以输出不是自然数序列
    

    出队操作还有Queu.dequeue_many(batch_size),如果入队时采用enqueue([image, label]),则可以实现队列数据参与训练。

    tf.train.Coordinator()

    '''协调器'''
    
    q = tf.FIFOQueue(1000,'float')
    counter = tf.Variable(0.0)  # 计数器
    increment_op = tf.assign_add(counter, tf.constant(1.0))   # 计数器加一
    enqueue_op = q.enqueue(counter)                           # 入队
    
    # 线程面向队列q,启动2个线程,每个线程中是[in,en]两个操作
    qr = tf.train.QueueRunner(q,enqueue_ops=[increment_op,enqueue_op]*2)
    
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    coord = tf.train.Coordinator()
    # 线程管理器启动线程,接收协调器管理
    enqueue_thread = qr.create_threads(sess,coord=coord,start=True)
    
    for i in range(0,10):
        print(sess.run(q.dequeue()))
    
    coord.request_stop()            # 向各个线程发终止信号
    coord.join(enqueue_thread)      # 等待各个线程成功结束
    
  • 相关阅读:
    OC3(字符串,值类)
    OC2(初始化方法)
    OC1(类和对象)
    postgresql 时间戳格式为5分钟、15分钟
    centos添加ftp用户并禁止外切目录
    postgresql 判断日期是否合法
    tigerVNC的简单使用教程(CentOS的远程桌面连接)
    linux awk解析csv文件
    由windows向linux上传下载文件方法
    codeblocks linux编译告警乱码解决办法
  • 原文地址:https://www.cnblogs.com/hellcat/p/6941367.html
Copyright © 2020-2023  润新知