Tensorflow学习过程中tfrecord的简单理解
1 TFRecord的介绍:
一般使用直接将数据加载到内存的方式来存储数据量较小的数据,然后再分batch输入网络进行训练。如果数据量太大,这种方法是十分消耗内存的,这时可以使用tensorflow提供的队列queue从文件中提取数据(比如csv文件等)。还有一种较为常用的,高效的读取方法,即使用tensorflow内定标准格式——TFRecords.作者也是刚接触tensorflow,对日常学习遇到的问题做简单记录,有不对地方需要指正。
1.1 什么是TFRecord?
TFRecord是谷歌推荐的一种常用的存储二进制序列数据的文件格式,理论上它可以保存任何格式的信息。下面是Tensorflow的官网给出的文档结构,整个文件由文件长度信息,长度校验码,数据,数据校验码组成。
uint64 length
uint32 masked_crc32_of_length
byte data[length]
uint32 masked_crc32_of_data
2 代码及相关简介
2.1 构建写入数据的writer
import numpy as np
import tensorflow as tf
writer = tf.python_io.TFRecordWriter('test.tfrecord')
2.2 TFRecord
TensorFlow经常使用 tf.Example 来写入,读取TFRecord数据。
通常tf.example有下面几种数据结构:
- tf.train.FloatList: 可以使用的类型包括 float和double
- tf.train.Int64List: 可以使用的类型包括 enum,bool, int32, uint32, int64
- f.train.BytesList: 可以使用的类型包括 string和byte
TFRecord支持写入三种格式的数据:string,int64,float32,以列表的形式分别通过tf.train.BytesList,tf.train.Int64List,tf.train.FloatList 写入 tf.train.Feature,如下所示:
#feature一般是多维数组,要先转为list
tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))
#tostring函数后feature的形状信息会丢失,把shape也写入
tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape)))
tf.train.Feature(float_list=tf.train.FloatList(value=[label]))
下面以一个具体的简单例子来介绍tf.example
for k in range(0, 3):
x = 0.1712 + k
y = [1+k, 2+k]
z = np.array([[1,2,3],[4,5,6]]) + k
z = z.astype(np.uint8)
z_raw = z.tostring()
example = tf.train.Example(
features = tf.train.Features(
feature = {'x':tf.train.Feature(float_list = tf.train.FloatList(value = [x])),
'y':tf.train.Feature(int64_list = tf.train.Int64List(value = y)),
'z':tf.train.Feature(bytes_list = tf.train.BytesList(value = [z_raw]))}))
serialized = example.SerializeToString()
writer.write(serialized)
writer.close()
x,y,z分别是以float,int64和string的形式存储的,注意观察下面语句:
feature = {'x':tf.train.Feature(float_list = tf.train.FloatList(value = [x])),
'y':tf.train.Feature(int64_list = tf.train.Int64List(value = y)),
'z':tf.train.Feature(bytes_list = tf.train.BytesList(value = [z_raw]))}
value的值是一个list形式,x定义的为一个数,value的值应为[x],同样y定义的格式就是一个list所以value的值直接为y即可,z_raw是由z转换过来的string形式,对应的value值与x的形式应该是一样的。
2.3 创建文件读取队列并读取其中内容(字典格式)
#output file name string to a queue
filename_queue = tf.train.string_input_producer(['test.tfrecord'], num_epochs = None)
#Create a reader from file queue
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
#Get feature from serialized example
features = tf.parse_single_example(serialized_example,
features = {'x': tf.FixedLenFeature([],tf.float32),
'y': tf.FixedLenFeature([2],tf.int64),
'z': tf.FixedLenFeature([],tf.string)})
2.4 读取数据
x_out = features['x']
y_out = features['y']
z_raw_out = features['z']
z_out = tf.decode_raw(z_raw_out,tf.uint8)
z_out = tf.reshape(z_out, [2,3])
print(x_out)
print(y_out)
print(z_out)
显示结果为:
Tensor("ParseSingleExample_2/ParseSingleExample:0", shape=(), dtype=float32)
Tensor("ParseSingleExample_2/ParseSingleExample:1", shape=(2,), dtype=int64)
Tensor("Reshape_1:0", shape=(2, 3), dtype=uint8)
3 以存储图片为例理解TFRecord的应用
使用Tensorflow训练网络时,为提高数据的读取效率,一般都采用TFRecords格式。初学CNN我们使用了手写数字数据集学习,这些都是做好的数据集,我们可以直接使用,比如MNIST,CIFAR_10等。现在我们还不是很清楚怎样输入训练的图片,此时就要用到TFRecord来制作自己的数据集。
3.1 将图片转换成tfrecords格式
假设我们的输入的图片需要三种信息,图片的名字,图片维度以及图片的内容:name shape content
输入图片以及输出tfrecord文件:
input_photo = r'D:Furhjupyter codeTensorflow Tipsdatadog.jpg'
output_file = r'D:Furhjupyter codeTensorflow Tipsdog.tfrecord'
# 使用 TFRecordWriter 将信息写入到 TFRecord 文件
writer = tf.python_io.TFRecordWriter(output_file)
#读取图片进行解码
image = tf.read_file(input_photo)
image = tf.image.decode_jpeg(image)
with tf.Session() as sess:
image_new = sess.run(image)
shape = image_new.shape
#将图片转换成string
image_data = image_new.tostring()
print(type(image_new))
print(len(image_data))
name = bytes('dog',encoding = 'utf-8')
print(type(name))
# 创建Example对象,将所有的Features填充进去
example = tf.train.Example(
features = tf.train.Features(
feature = {
'name': tf.train.Feature(bytes_list = tf.train.BytesList(value = [name])),
'shape': tf.train.Feature(int64_list = tf.train.Int64List(value = [shape[0],shape[1],shape[2]])),
'data': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image_data]))
}
))
# 将example序列化成string类型写入
writer.write(example.SerializeToString())
writer.close()
Note:
- Feature 中value应该是列表形式,当数据不是列表时,加上[]
- 解码后的图片要转化成string数据,再填充
- example需要使用SerializeToString()进行序列化
3.2 TFRecord 文件读取成图片
#解析数据
def parse_record(example_photo):
features = {
'name': tf.FixedLenFeature((),tf.string),
'shape': tf.FixedLenFeature([3],tf.int64), #这里制定维度3
'data' : tf.FixedLenFeature((),tf.string)
}
#在解析example时,用现成的API: tf.parse_single_example
parsed_features = tf.parse_single_example(example_photo,features = features)
return parsed_features
def read_test(input_file):
#使用dataset读取TFRecord文件
dataset = tf.data.TFRecordDataset(input_file)
dataset = dataset.map(parse_record)
iterator = dataset.make_one_shot_iterator()
with tf.Session() as sess:
features = sess.run(iterator.get_next())
name = features['name']
name = name.decode
img_data = features['data']
shape = features['shape']
#从bytes数组中加载图片原始数据,并重新reshape,结果是ndarray数组
img_data = np.fromstring(img_data, dtype=np.uint8) #获取解析后的string数据,并把数据还原成unit8
image_data = np.reshape(img_data,shape)
plt.figure()
plt.imshow(image_data)
plt.show()
#将数据重新编码成jpg图片保存
img = tf.image.encode_jpeg(image_data)
#把图片保存到本地
tf.gfile.GFile('dog_encode,jpg', 'wb').write(img.eval())
read_test('dog.tfrecord')
Note:
在使用dataset进行样本解析之前,我们需要按照先定义一个解析字典,告诉dataset如何去解析每个样本,这个字典就是用来指定对于每条输入样本的每一列应该用什么的feature去解析,dataset默认提供了FixedLenFeature,VarLenFeature,FixedLenSequenceFeature等。
FixedLenFeature() 函数有三个参数:
- shape:输入数据的shape。
- dtype:输入的数据类型。
- default_value:如果示例缺少此功能,则使用该值。它必须与dtype和指定shape兼容。
代码注释:
主要参考:
TensorFlow学习记录-- 7.TensorFlow高效读取数据之tfrecord详细解读
tensorflow学习笔记——高效读取数据的方法(TFRecord