TensorFlow------TFRecords的分析与存储实例:
import os
import tensorflow as tf
# 定义cifar的数据等命令行参数
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('cifar_dir','./data/cifar10/cifar-10-batches-bin','文件的目录')
tf.app.flags.DEFINE_string('cifar_tfrecords','./tmp/cifar.tfrecords','存储tfrecords的文件')
class CifarRead(object):
'''
完成读取二进制文件,写进tfrecords,读取tfrecords
:param object:
:return:
'''
def __init__(self,filelist):
# 文件列表
self.file_list = filelist
# 定义读取的图片的一些属性
self.height = 32
self.width = 32
self.channel = 3
# 二进制文件每张图片的字节
self.label_bytes = 1
self.image_bytes = self.height * self.width * self.channel
self.bytes = self.label_bytes + self.image_bytes
def read_and_decode(self):
# 1. 构建文件队列
file_queue = tf.train.string_input_producer(self.file_list)
# 2. 构建二进制文件读取器,读取内容,每个样本的字节数
reader = tf.FixedLengthRecordReader(self.bytes)
key,value = reader.read(file_queue)
# 3. 解码内容,二进制文件内容的解码 label_image包含目标值和特征值
label_image = tf.decode_raw(value,tf.uint8)
print(label_image)
# 4.分割出图片和标签数据,特征值和目标值
label = tf.slice(label_image,[0],[self.label_bytes])
image = tf.slice(label_image,[self.label_bytes],[self.image_bytes])
print('---->')
print(image)
# 5. 可以对图片的特征数据进行形状的改变 [3072]-->[32,32,3]
image_reshape = tf.reshape(image,[self.height,self.width,self.channel])
print('======>')
print(label)
print('======>')
# 6. 批处理数据
image_batch,label_batch = tf.train.batch([image_reshape,label],batch_size=10,num_threads=1,capacity=10)
print(image_batch,label_batch)
return image_batch,label_batch
def write_ro_tfrecords(self,image_batch,label_batch):
'''
将图片的特征值和目标值存进tfrecords
:param image_batch: 10张图片的特征值
:param label_batch: 10张图片的目标值
:return: None
'''
# 1.建立TFRecord存储器
writer = tf.python_io.TFRecordWriter(FLAGS.cifar_tfrecords)
# 2. 循环将所有样本写入文件,每张图片样本都要构造example协议
for i in range(10):
# 取出第i个图片数据的特征值和目标值
image = image_batch[i].eval().tostring()
label = int(label_batch[i].eval()[0])
# 构造一个样本的example
example = tf.train.Example(features=tf.train.Features(feature={
'image':tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
}))
# 写入单独的样本
writer.write(example.SerializeToString())
# 关闭
writer.close()
return None
if __name__ == '__main__':
# 找到文件,构建列表 路径+名字 ->列表当中
file_name = os.listdir(FLAGS.cifar_dir)
# 拼接路径 重新组成列表
filelist = [os.path.join(FLAGS.cifar_dir,file) for file in file_name if file[-3:] == 'bin']
# 调用函数传参
cf = CifarRead(filelist)
image_batch,label_batch = cf.read_and_decode()
# 开启会话
with tf.Session() as sess:
# 定义一个线程协调器
coord = tf.train.Coordinator()
# 开启读文件的线程
threads = tf.train.start_queue_runners(sess,coord=coord)
# 存进tfrecords文件
print('开始存储')
cf.write_ro_tfrecords(image_batch,label_batch)
print('结束存储')
# 打印读取的内容
# print(sess.run([image_batch,label_batch]))
# 回收子线程
coord.request_stop()
coord.join(threads)