简介
- TFRecord是TensorFlow官方推荐使用的数据格式化存储工具。
- 它规范了数据的读写方式。
- 只要生成一次TFRecord,之后的数据读取和加工处理的效率都会得到提高。
将图片转换成TFRecord
本例将fashion-MNIST数据转换成TFRecord。需要先下载fashion数据集(train和t10k的images、labels共四个.gz文件)到当前目录下的data/fashion/子目录中(代码从./data/fashion/读取),参考:https://github.com/zalandoresearch/fashion-mnist/tree/master/data/fashion
import gzip
import os

import numpy as np
import tensorflow as tf

# Directory that holds the four gzip-compressed fashion-MNIST files.
fashion_mnist_directory = './data/fashion/'


def load_mnist(path, kind='train'):
    """Load one split of the gzip-compressed fashion-MNIST files.

    Args:
        path: Directory containing the ``*-ubyte.gz`` files.
        kind: File-name prefix of the split; ``main`` uses 'train' and 't10k'.

    Returns:
        A tuple ``(images, labels)`` of uint8 NumPy arrays. ``images`` is
        reshaped to ``(-1, 784)`` (one flattened image per row) and
        ``labels`` is one-dimensional.
    """
    labels_path = os.path.join(path, '%s-labels-idx1-ubyte.gz' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte.gz' % kind)

    with gzip.open(labels_path, 'rb') as lbpath:
        # Skip the first 8 bytes (file header) before the label bytes.
        labels = np.frombuffer(lbpath.read(), dtype=np.uint8, offset=8)

    with gzip.open(images_path, 'rb') as imgpath:
        # Skip the first 16 bytes (file header); the rest is pixel data.
        images = np.frombuffer(
            imgpath.read(), dtype=np.uint8, offset=16).reshape(-1, 784)

    print(labels_path, "shape =", labels.shape)
    print(images_path, "shape =", images.shape)
    return images, labels


def make_example(image, label):
    """Wrap one image/label pair in a ``tf.train.Example``.

    Args:
        image: NumPy array whose raw bytes are stored under 'image_raw'.
        label: Integer-like class id, stored under 'label' as an int64.

    Returns:
        The populated ``tf.train.Example`` protocol buffer.
    """
    return tf.train.Example(features=tf.train.Features(feature={
        'image_raw': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[image.tobytes()])),
        'label': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[int(label)])),
    }))


def write_tfrecord(images, labels, filename):
    """Serialize every image/label pair into a TFRecord file.

    Args:
        images: Sequence of image arrays, iterated in lockstep with labels.
        labels: NumPy array of labels; its first dimension is used as the
            total count for the progress display.
        filename: Path of the TFRecord file to create.
    """
    total = labels.shape[0]
    # The context manager closes the writer even when serialization or a
    # write raises, so the file handle is never leaked.
    with tf.python_io.TFRecordWriter(filename) as writer:
        for k, (image, label) in enumerate(zip(images, labels)):
            writer.write(make_example(image, label).SerializeToString())
            if k % 100 == 0:
                # '\r' rewinds to the start of the line so the progress is
                # updated in place; flush because no newline is emitted.
                print("\r writing", filename,
                      "%6.2f%% complited." % (100.0 * (k + 1) / total),
                      end='', flush=True)
    print("\r writing", filename, "%6.2f%% complited." % (100.0))


def main():
    """Convert the train and test splits into two TFRecord files."""
    train_images, train_labels = load_mnist(fashion_mnist_directory, 'train')
    test_images, test_labels = load_mnist(fashion_mnist_directory, 't10k')
    write_tfrecord(train_images, train_labels, 'fashion_mnist_train.tfrecords')
    write_tfrecord(test_images, test_labels, 'fashion_mnist_test.tfrecords')


if __name__ == '__main__':
    main()
读取TFRecord数据来训练
以下代码读取TFRecord数据用于训练,该代码改编自官方例程:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/how_tos/reading_data
原始代码运行时报错,已修复。
注意:在这个例子中,每调用一次 _, loss_value = sess.run([train_op, loss]),输入管道只会读取一个batch(即Batch Input只执行一次),无论[]中是哪些操作、有多少个操作。
import argparse
import os.path
import sys
import time

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import mnist

# Parsed command-line flags; filled in by the __main__ section below.
FLAGS = None
TRAIN_FILE = 'fashion_mnist_train.tfrecords'
VALIDATION_FILE = 'fashion_mnist_test.tfrecords'


def decode(serialized_example):
    """Parse one serialized Example into an (image, label) tensor pair.

    The image comes back as a uint8 vector with its static shape set from
    ``mnist.IMAGE_PIXELS``; the label is cast to int32.
    """
    feature_spec = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    pixels = tf.decode_raw(parsed['image_raw'], tf.uint8)
    pixels.set_shape((mnist.IMAGE_PIXELS))

    class_id = tf.cast(parsed['label'], tf.int32)
    return pixels, class_id


def augment(image, label):
    """Hook for data augmentation; currently passes the pair through."""
    # OPTIONAL: a 28x28 reshape plus distortions could be applied here.
    return image, label


def normalize(image, label):
    """Map `image` from [0, 255] to floats in [-0.5, 0.5]; label unchanged."""
    as_float = tf.cast(image, tf.float32)
    return as_float * (1. / 255) - 0.5, label


def inputs(train, batch_size, num_epochs):
    """Build the input pipeline and return its next-batch tensors.

    Args:
        train: Selects TRAIN_FILE when truthy, VALIDATION_FILE otherwise.
        batch_size: Number of examples per batch.
        num_epochs: Number of passes over the data; a falsy value is turned
            into None before being handed to ``Dataset.repeat``.

    Returns:
        The result of ``iterator.get_next()`` for the batched dataset.
    """
    epochs = num_epochs or None
    record_file = TRAIN_FILE if train else VALIDATION_FILE
    filename = os.path.join(FLAGS.train_dir, record_file)

    with tf.name_scope('input'):
        iterator = (
            tf.data.TFRecordDataset(filename)
            .map(decode)
            .map(augment)
            .map(normalize)
            .shuffle(1000 + 3 * batch_size)
            .repeat(epochs)
            .batch(batch_size)
            .make_one_shot_iterator()
        )
    return iterator.get_next()


def run_training():
    """Build the model graph and train until the input is exhausted."""
    with tf.Graph().as_default():
        image_batch, label_batch = inputs(
            train=True,
            batch_size=FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs,
        )

        logits = mnist.inference(image_batch, FLAGS.hidden1, FLAGS.hidden2)
        loss = mnist.loss(logits, label_batch)
        train_op = mnist.training(loss, FLAGS.learning_rate)

        init_op = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer(),
        )

        with tf.Session() as sess:
            sess.run(init_op)
            try:
                step = 0
                # Loop ends when the iterator raises OutOfRangeError.
                while True:
                    tick = time.time()
                    _, loss_value = sess.run([train_op, loss])
                    elapsed = time.time() - tick
                    if step % 100 == 0:
                        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, elapsed))
                    step += 1
            except tf.errors.OutOfRangeError:
                print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))


def main(_):
    """Entry point invoked by ``tf.app.run``."""
    run_training()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--learning_rate', type=float, default=0.01,
        help='Initial learning rate.')
    parser.add_argument(
        '--num_epochs', type=int, default=2,
        help='Number of epochs to run trainer.')
    parser.add_argument(
        '--hidden1', type=int, default=128,
        help='Number of units in hidden layer 1.')
    parser.add_argument(
        '--hidden2', type=int, default=32,
        help='Number of units in hidden layer 2.')
    parser.add_argument(
        '--batch_size', type=int, default=100,
        help='Batch size.')
    parser.add_argument(
        '--train_dir', type=str, default='./',
        help='Directory with the training data.')

    # Flags not recognized here are forwarded to tf.app.run.
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
参考了:
- https://blog.csdn.net/gg_18826075157/article/details/78449104
- https://github.com/zalandoresearch/fashion-mnist/blob/master/utils/mnist_reader.py