• tensorflow中数据批次划分示例教程


    1.简介

    将数据划分成若干批次的数据,可以使用tf.train或者tf.data.Dataset中的方法。

    1.1 tf.train

    tf.train.slice_input_producer(tensor_list,shuffle=True,seed=None,capacity=32)

    tf.train.batch(tensors,batch_size,num_threads=1,capacity=32,allow_smaller_final_batch=False)

    参数说明:

    shuffle:为True时进行数据清洗

    allow_smaller_final_batch:为True时将小于batch_size的批次值输出

    -------------------------------------------------------------------------------------------------------------------------

    -------------------------------------------------------------------------------------------------------------------------

    1.2 tf.data.Dataset

    tf.data.Dataset是一个类,可以使用以下方法:

    from_tensor_slices(tensors)

    batch(batch_size,drop_remainder=False)

    shuffle(buffer_size,seed=None,reshuffle_each_iteration=None)

    repeat(count=None)

    make_one_shot_iterator() / get_next()

    注:make_one_shot_iterator() / get_next()用于Dataset数据的迭代器

    参数说明:

    tensors:可以是列表、字典、元组等类型

    drop_remainder:为False时表示不保留小于batch_size的批次,否则删除

    buffer_size:数据清洗时使用的buffer大小

    count:对应为epoch个数,为None时表示数据序列无限延续

    2.示例

    2.1 使用tf.train.slice_input_producer和tf.train.batch

     1 import tensorflow as tf
     2 import numpy as np
     3 import math
     4 
     5 # 生成样例数据集
     6 def generate_data():
     7     num = 15
     8     labels = np.asarray(range(num))
     9     images = np.random.random([num, 5, 5, 3])
    10     return images, labels
    11 
    12 # 打印样例信息
    13 images, labels = generate_data()
    14 print('images.shape={0}, labels.shape={1}'.format(images.shape, labels.shape))
    15 
    16 # 定义周期、批次、数据总量和遍历一次所有数据所需的迭代次数
    17 n_epochs = 3
    18 batch_size = 6
    19 train_nums = 15
    20 iterations = math.ceil(train_nums/batch_size)
    21 
    22 # 使用tf.train.slice_input_producer将所有数据放入队列,使用tf.train.batch划分队列中的数据
    23 input_queue = tf.train.slice_input_producer([images, labels], shuffle=False)
    24 image_batch, label_batch = tf.train.batch(input_queue, batch_size=batch_size, num_threads=1, capacity=32)
    25 print('image_batch.shape={0}, label_batch.shape={1}'.format(image_batch.shape, label_batch.shape))
    26 
    27 
    28 with tf.Session() as sess:
    29     tf.global_variables_initializer().run()
    30     # 启动队列线程
    31     coord = tf.train.Coordinator()
    32     threads = tf.train.start_queue_runners(sess, coord)
    33     # 打印信息
    34     for epoch in range(n_epochs):       
    35         for iteration in range(iterations):
    36             cu_image_batch, cu_label_batch = sess.run([image_batch, label_batch])
    37             print('The {0} epoch, the {1} iteration, current batch is {2}'.format(epoch+1,iteration+1,cu_label_batch))
    38     # 接收线程
    39     coord.request_stop()
    40     coord.join(threads)    
    41 
    42 
    43 # 打印结果如下
    44 images.shape=(15, 5, 5, 3), labels.shape=(15,)
    45 image_batch.shape=(6, 5, 5, 3), label_batch.shape=(6,)
    46 The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
    47 The 1 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
    48 The 1 epoch, the 3 iteration, current batch is [12 13 14  0  1  2]
    49 The 2 epoch, the 1 iteration, current batch is [3 4 5 6 7 8]
    50 The 2 epoch, the 2 iteration, current batch is [ 9 10 11 12 13 14]
    51 The 2 epoch, the 3 iteration, current batch is [0 1 2 3 4 5]
    52 The 3 epoch, the 1 iteration, current batch is [ 6  7  8  9 10 11]
    53 The 3 epoch, the 2 iteration, current batch is [12 13 14  0  1  2]
    54 The 3 epoch, the 3 iteration, current batch is [3 4 5 6 7 8]

    如果tf.train.slice_input_producer(shuffle=True),输出为乱序,结果如下:

     1 images.shape=(15, 5, 5, 3), labels.shape=(15,)
     2 image_batch.shape=(6, 5, 5, 3), label_batch.shape=(6,)
     3 The 1 epoch, the 1 iteration, current batch is [ 2  5  8 11  3 10]
     4 The 1 epoch, the 2 iteration, current batch is [ 9 12  7  1 14 13]
     5 The 1 epoch, the 3 iteration, current batch is [0 6 4 2 3 6]
     6 The 2 epoch, the 1 iteration, current batch is [11 10 12 14 13  5]
     7 The 2 epoch, the 2 iteration, current batch is [8 1 0 9 4 7]
     8 The 2 epoch, the 3 iteration, current batch is [10 13  1  4 12  3]
     9 The 3 epoch, the 1 iteration, current batch is [ 2  8  5  9 14  7]
    10 The 3 epoch, the 2 iteration, current batch is [ 0 11  6  1 14  9]
    11 The 3 epoch, the 3 iteration, current batch is [11  6 12  7  0 13]

    如果tf.train.batch(allow_smaller_final_batch=True),则会返回不足批次数目的数据,结果如下:

     1 images.shape=(15, 5, 5, 3), labels.shape=(15,)
     2 The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
     3 The 1 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
     4 The 1 epoch, the 3 iteration, current batch is [12 13 14]
     5 The 2 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
     6 The 2 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
     7 The 2 epoch, the 3 iteration, current batch is [12 13 14]
     8 The 3 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
     9 The 3 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
    10 The 3 epoch, the 3 iteration, current batch is [12 13 14]

    2.2 使用tf.data.Dataset类

     1 import tensorflow as tf
     2 import numpy as np
     3 import math
     4 
     5 # 生成样例数据集
     6 def generate_data():
     7     num = 15
     8     labels = np.asarray(range(num))
     9     images = np.random.random([num, 5, 5, 3])
    10     return images, labels
    11 # 打印样例信息
    12 images, labels = generate_data()
    13 print('images.shape={0}, labels.shape={1}'.format(images.shape, labels.shape))
    14 
    15 # 定义周期、批次、数据总数、遍历一次所有数据需的迭代次数
    16 n_epochs = 3
    17 batch_size = 6
    18 train_nums = 15
    19 iterations = math.ceil(train_nums/batch_size)
    20 
    21 # 使用from_tensor_slices将数据放入队列,使用batch和repeat划分数据批次,且让数据序列无限延续
    22 dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    23 dataset = dataset.batch(batch_size).repeat()
    24 
    25 # 使用生成器make_one_shot_iterator和get_next取数据
    26 iterator = dataset.make_one_shot_iterator()
    27 next_iterator = iterator.get_next()
    28 
    29 with tf.Session() as sess:
    30     for epoch in range(n_epochs):
    31         for iteration in range(iterations):
    32             cu_image_batch, cu_label_batch = sess.run(next_iterator)
    33             print('The {0} epoch, the {1} iteration, current batch is {2}'.format(epoch+1,iteration+1,cu_label_batch))
    34 
    35 
    36 # 结果如下:
    37 images.shape=(15, 5, 5, 3), labels.shape=(15,)
    38 The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
    39 The 1 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
    40 The 1 epoch, the 3 iteration, current batch is [12 13 14]
    41 The 2 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
    42 The 2 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
    43 The 2 epoch, the 3 iteration, current batch is [12 13 14]
    44 The 3 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
    45 The 3 epoch, the 2 iteration, current batch is [ 6  7  8  9 10 11]
    46 The 3 epoch, the 3 iteration, current batch is [12 13 14]

    使用shuffle(),第23行修改为dataset = dataset.shuffle(100).batch(batch_size).repeat(),结果如下:

     1 images.shape=(15, 5, 5, 3), labels.shape=(15,)
     2 The 1 epoch, the 1 iteration, current batch is [ 7  4 10  8  3 11]
     3 The 1 epoch, the 2 iteration, current batch is [ 0  2 12 13 14  5]
     4 The 1 epoch, the 3 iteration, current batch is [6 9 1]
     5 The 2 epoch, the 1 iteration, current batch is [ 6 14  7  9  3  8]
     6 The 2 epoch, the 2 iteration, current batch is [13  5 12  1 11  2]
     7 The 2 epoch, the 3 iteration, current batch is [ 0  4 10]
     8 The 3 epoch, the 1 iteration, current batch is [10  8 13 12  3 14]
     9 The 3 epoch, the 2 iteration, current batch is [ 6  9  2  5  1 11]
    10 The 3 epoch, the 3 iteration, current batch is [0 4 7]

    !!!

  • 相关阅读:
    CentOS/Linux安装VNCserver
    vncserver的安装和使用
    linux下常用FTP命令 1. 连接ftp服务器
    linux下安装dovecot
    教你如何架设linux邮件服务器postfix
    vim打开文件时显示行号
    VirtualBox 配置虚拟网卡(桥接),实现主机-虚拟机网络互通
    Linux文件权限详解
    虚拟机下CentOS 6.5配置IP地址的三种方法
    Linux基础知识之man手册的使用
  • 原文地址:https://www.cnblogs.com/jfl-xx/p/9945967.html
Copyright © 2020-2023  润新知