1. Introduction
To split data into batches, you can use either the functions in tf.train or the tf.data.Dataset class.
1.1 tf.train
tf.train.slice_input_producer(tensor_list,
                              shuffle=True,
                              seed=None,
                              capacity=32)

tf.train.batch(tensors,
               batch_size,
               num_threads=1,
               capacity=32,
               allow_smaller_final_batch=False)
Parameter notes:
shuffle: when True, the samples are shuffled
allow_smaller_final_batch: when True, a final batch containing fewer than batch_size samples is still returned
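Both ops are queue-based: slice_input_producer enqueues individual samples and tf.train.batch dequeues them in groups, so nothing is produced until the queue-runner threads are started inside a session. A minimal, self-contained sketch of that boilerplate (toy data and illustrative names; a complete example follows in section 2.1):

import numpy as np
import tensorflow as tf

# Toy data: ten scalar "labels" (illustrative only)
labels = np.arange(10)

# slice_input_producer enqueues one sample at a time; tf.train.batch groups them into batches of 4
input_queue = tf.train.slice_input_producer([labels], shuffle=True)
label_batch = tf.train.batch(input_queue, batch_size=4)[0]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()                       # coordinates the queue threads
    threads = tf.train.start_queue_runners(sess, coord)  # nothing is dequeued until these start
    print(sess.run(label_batch))                         # e.g. a batch of 4 shuffled labels
    coord.request_stop()
    coord.join(threads)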
-------------------------------------------------------------------------------------------------------------------------
1.2 tf.data.Dataset
tf.data.Dataset is a class; the following methods are used here:
from_tensor_slices(tensors)
batch(batch_size, drop_remainder=False)
shuffle(buffer_size, seed=None, reshuffle_each_iteration=None)
repeat(count=None)
make_one_shot_iterator() / get_next()
Note: make_one_shot_iterator() / get_next() provide the iterator used to step through a Dataset.
Parameter notes:
tensors: can be a list, dictionary, tuple, etc.
drop_remainder: when False, a final batch containing fewer than batch_size samples is kept; when True, it is dropped
buffer_size: size of the buffer used for shuffling
count: the number of epochs; when None, the data sequence repeats indefinitely
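As a quick illustration of drop_remainder, here is a minimal sketch (assuming TF 1.x graph mode; the data and sizes are made up): with 15 elements and a batch size of 6, the trailing batch of 3 is discarded when drop_remainder=True.

import tensorflow as tf

# Minimal sketch: 15 scalar elements, batch size 6, drop_remainder=True
ds = tf.data.Dataset.from_tensor_slices(list(range(15)))
ds = ds.batch(6, drop_remainder=True)      # yields [0..5] and [6..11]; the last 3 elements are dropped
nxt = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    try:
        while True:
            print(sess.run(nxt))
    except tf.errors.OutOfRangeError:
        pass                               # iterator exhausted after the two full batches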
2. Examples
2.1 Using tf.train.slice_input_producer and tf.train.batch
import tensorflow as tf
import numpy as np
import math

# Generate a sample dataset
def generate_data():
    num = 15
    labels = np.asarray(range(num))
    images = np.random.random([num, 5, 5, 3])
    return images, labels

# Print the sample info
images, labels = generate_data()
print('images.shape={0}, labels.shape={1}'.format(images.shape, labels.shape))

# Define the number of epochs, the batch size, the total number of samples,
# and the number of iterations needed to go through all the data once
n_epochs = 3
batch_size = 6
train_nums = 15
iterations = math.ceil(train_nums/batch_size)

# Use tf.train.slice_input_producer to put all the data into a queue,
# and tf.train.batch to split the queued data into batches
input_queue = tf.train.slice_input_producer([images, labels], shuffle=False)
image_batch, label_batch = tf.train.batch(input_queue, batch_size=batch_size, num_threads=1, capacity=32)
print('image_batch.shape={0}, label_batch.shape={1}'.format(image_batch.shape, label_batch.shape))

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    # Start the queue threads
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    # Print the batches
    for epoch in range(n_epochs):
        for iteration in range(iterations):
            cu_image_batch, cu_label_batch = sess.run([image_batch, label_batch])
            print('The {0} epoch, the {1} iteration, current batch is {2}'.format(epoch+1, iteration+1, cu_label_batch))
    # Stop and join the queue threads
    coord.request_stop()
    coord.join(threads)

# The output is as follows:
images.shape=(15, 5, 5, 3), labels.shape=(15,)
image_batch.shape=(6, 5, 5, 3), label_batch.shape=(6,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14 0 1 2]
The 2 epoch, the 1 iteration, current batch is [3 4 5 6 7 8]
The 2 epoch, the 2 iteration, current batch is [ 9 10 11 12 13 14]
The 2 epoch, the 3 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 1 iteration, current batch is [ 6 7 8 9 10 11]
The 3 epoch, the 2 iteration, current batch is [12 13 14 0 1 2]
The 3 epoch, the 3 iteration, current batch is [3 4 5 6 7 8]
If shuffle=True is passed to tf.train.slice_input_producer, the batches come out in random order:
images.shape=(15, 5, 5, 3), labels.shape=(15,)
image_batch.shape=(6, 5, 5, 3), label_batch.shape=(6,)
The 1 epoch, the 1 iteration, current batch is [ 2 5 8 11 3 10]
The 1 epoch, the 2 iteration, current batch is [ 9 12 7 1 14 13]
The 1 epoch, the 3 iteration, current batch is [0 6 4 2 3 6]
The 2 epoch, the 1 iteration, current batch is [11 10 12 14 13 5]
The 2 epoch, the 2 iteration, current batch is [8 1 0 9 4 7]
The 2 epoch, the 3 iteration, current batch is [10 13 1 4 12 3]
The 3 epoch, the 1 iteration, current batch is [ 2 8 5 9 14 7]
The 3 epoch, the 2 iteration, current batch is [ 0 11 6 1 14 9]
The 3 epoch, the 3 iteration, current batch is [11 6 12 7 0 13]
If allow_smaller_final_batch=True is passed to tf.train.batch, the final batch with fewer than batch_size samples is also returned:
images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14]
The 2 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 2 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 2 epoch, the 3 iteration, current batch is [12 13 14]
The 3 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 3 epoch, the 3 iteration, current batch is [12 13 14]
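slice_input_producer also takes a num_epochs argument that limits how many passes the queue makes over the data; once the queue is exhausted, sess.run raises tf.errors.OutOfRangeError, which removes the need to count epochs and iterations by hand. A sketch of that variant, reusing the images and labels from generate_data() above (note that num_epochs is tracked in a local variable, so local variables must be initialized as well):

# Sketch: let the queue itself stop after 3 passes instead of counting iterations manually
input_queue = tf.train.slice_input_producer([images, labels], num_epochs=3, shuffle=False)
image_batch, label_batch = tf.train.batch(input_queue, batch_size=6, allow_smaller_final_batch=True)

with tf.Session() as sess:
    # num_epochs is stored in a local variable, so initialize local variables too
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        while not coord.should_stop():
            _, lbls = sess.run([image_batch, label_batch])
            print(lbls)
    except tf.errors.OutOfRangeError:
        pass  # raised once the queue has produced 3 full epochs
    finally:
        coord.request_stop()
        coord.join(threads)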
2.2 Using the tf.data.Dataset class
import tensorflow as tf
import numpy as np
import math

# Generate a sample dataset
def generate_data():
    num = 15
    labels = np.asarray(range(num))
    images = np.random.random([num, 5, 5, 3])
    return images, labels

# Print the sample info
images, labels = generate_data()
print('images.shape={0}, labels.shape={1}'.format(images.shape, labels.shape))

# Define the number of epochs, the batch size, the total number of samples,
# and the number of iterations needed to go through all the data once
n_epochs = 3
batch_size = 6
train_nums = 15
iterations = math.ceil(train_nums/batch_size)

# Use from_tensor_slices to load the data, batch to split it into batches,
# and repeat to let the data sequence continue indefinitely
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.batch(batch_size).repeat()

# Use make_one_shot_iterator and get_next to fetch the data
iterator = dataset.make_one_shot_iterator()
next_iterator = iterator.get_next()

with tf.Session() as sess:
    for epoch in range(n_epochs):
        for iteration in range(iterations):
            cu_image_batch, cu_label_batch = sess.run(next_iterator)
            print('The {0} epoch, the {1} iteration, current batch is {2}'.format(epoch+1, iteration+1, cu_label_batch))

# The output is as follows:
images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 1 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 1 epoch, the 3 iteration, current batch is [12 13 14]
The 2 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 2 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 2 epoch, the 3 iteration, current batch is [12 13 14]
The 3 epoch, the 1 iteration, current batch is [0 1 2 3 4 5]
The 3 epoch, the 2 iteration, current batch is [ 6 7 8 9 10 11]
The 3 epoch, the 3 iteration, current batch is [12 13 14]
If shuffle() is used, i.e. the line dataset = dataset.batch(batch_size).repeat() is changed to dataset = dataset.shuffle(100).batch(batch_size).repeat(), the result is:
images.shape=(15, 5, 5, 3), labels.shape=(15,)
The 1 epoch, the 1 iteration, current batch is [ 7 4 10 8 3 11]
The 1 epoch, the 2 iteration, current batch is [ 0 2 12 13 14 5]
The 1 epoch, the 3 iteration, current batch is [6 9 1]
The 2 epoch, the 1 iteration, current batch is [ 6 14 7 9 3 8]
The 2 epoch, the 2 iteration, current batch is [13 5 12 1 11 2]
The 2 epoch, the 3 iteration, current batch is [ 0 4 10]
The 3 epoch, the 1 iteration, current batch is [10 8 13 12 3 14]
The 3 epoch, the 2 iteration, current batch is [ 6 9 2 5 1 11]
The 3 epoch, the 3 iteration, current batch is [0 4 7]
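Since repeat(count) can also produce a finite number of passes, the manual epoch/iteration bookkeeping above can be replaced by catching tf.errors.OutOfRangeError. A sketch of that variant, reusing images, labels, batch_size and n_epochs from the example above:

# Sketch: rely on repeat(n_epochs) and OutOfRangeError instead of counting epochs/iterations manually
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.shuffle(100).batch(batch_size).repeat(n_epochs)  # finite: n_epochs passes over the data
next_iterator = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    try:
        while True:
            _, cu_label_batch = sess.run(next_iterator)
            print(cu_label_batch)
    except tf.errors.OutOfRangeError:
        pass  # the dataset is exhausted after n_epochs epochs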