为什么使用tfrecord?
正常情况下我们训练文件夹经常会生成 train, test 或者val文件夹,这些文件夹内部往往会存着成千上万的图片或文本等文件,这些文件被散列存着,这样不仅占用磁盘空间,并且再被一个个读取的时候会非常慢,繁琐。占用大量内存空间(有的大型数据不足以一次性加载)。此时我们TFRecord格式的文件存储形式会很合理的帮我们存储数据。TFRecord内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率。
来自 <https://www.jianshu.com/p/b480e5fcb638>
这个就是队列,也就不用placeholder了,FIFO等队列了。
TFrecord write
#各个数据集生成tfrecord文件 def tfrecord_write(args): #获取数据集分割后的txt文件名字 txt_names = [txt_name for txt_name in os.listdir(args.out_path) if txt_name.split('.')[1]=='txt'] #txt路径 txt_paths = [os.path.join(args.out_path,txt_path) for txt_path in txt_names] #tfrecord文件名字 tfrecord_names = [name.split('.')[0] for name in txt_names] #tfrecord路径 tfrecord_paths = [os.path.join(args.out_path,tfrecord_path+'.tfrecord') for tfrecord_path in tfrecord_names] #产生txt文件数目个txrecord for txt_path, tfrecord_path in zip(txt_paths, tfrecord_paths): print(tfrecord_path) writer = tf.python_io.TFRecordWriter(tfrecord_path) with open(txt_path, 'r') as f: for line in f.readlines(): name,num = line.strip().split(' ') #print(name) if name == 'George_W_Bush': print("Because George_W_Bush has 530 pictures so we will not use it to save time ") #尽可能一次writer 尽可能多的写进去数据,否则会很慢 continue pics, len_pics = _get_pics(os.path.join(args.path,name)) for i, pic in enumerate(pics): _store2tfrecord(pic, i, writer) assert int(num) == len_pics writer.close() #读取某个人名文件夹下的所有人脸,以及图片个数 def _get_pics(path): pic_list = [os.path.join(path, pic) for pic in os.listdir(path)] pics = [] for pic_path in pic_list: pics.append(cv2.imread(pic_path)) return np.asarray(pics), len(pic_list) #将某个文件夹下的图片和个数tfrecord保存 def _store2tfrecord(pic, index, writer): pic_shape = list(pic.shape) print(pic_shape) pic_string = pic.tostring() example = tf.train.Example(features=tf.train.Features( feature={ 'index': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'shape':tf.train.Feature(int64_list=tf.train.Int64List(value=pic_shape)), 'pic': tf.train.Feature(bytes_list=tf.train.BytesList(value=[pic_string])) } )) serialized = example.SerializeToString() writer.write(serialized) #这里不能有writer.close 否则就会关闭。 #一个writer will make a tfrecord file ,if exists it will remake tfrecord_write(parsed) 1、设置存放.tfrecord文件的位置 2、在该位置生成tfrecord文件 不要每个example都生成一个writer否则只会存储最后一个数据。 • writer = tf.python_io.TFRecordWriter("位置信息")#方法一 writer.close() • with tf.python_io.TFRecordWriter(位置信息) as writer#方法二 3、组成example example = tf.train.Example(features=tf.train.Features( feature={ 'label': tf.train.Feature(int64_list = tf.train.Int64List(value=[i])), 'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw])) })) 值得注意的是赋值给example的数据格式。从前面tf.train.Example的定义可知, tfrecord支持整型、浮点数和二进制三种格式,分别是value必须是列表 tf.train.Feature(int64_list = tf.train.Int64List(value=[int_scalar])) tf.train.Feature(bytes_list = tf.train.BytesList(value=[array_string_or_byte])) tf.train.Feature(float_list = tf.train.FloatList(value=[float_scalar])) 如果单个保存的是list长度,value=[len_list] 如果单个保存的是数组的形状,value=array_array,这个在读取解析的时候[]里面就需要指定个数 如果是array的数据(比如图片)可以通过array_data.tostring()转化为string。再用byte_list 这样可以节省空间。矩阵会失去维度。所以还要保存维度信息 如果读取的时候是批量读取的,每个‘label’的形状必须一样否则只能一个一个的读取 如果数据可以存储成(25,160,160,3)的形式,就不要存储成25个(160,160,3)的格式,后面一种存储空间大,并且时间长 10G的tfrecord格式文件速度比较快,在1080Ti上 4、序列化,减少内存 example = example.SerializeToString() 5、写进tfrecords writer.write(example) 6、关闭tfrecord文件 writer.close()
TFrecord read
1、创建文件队列 files_queue = tf.train.string_input_producer(tfrecord_paths) 2、创建reader reader = tf.TFRecordReader() 3、读取序列化后的文件名和example _, serialized_example = reader.read(路径) 4、反序列化 features = tf.parse_single_example( serialized_example, features={ 'a': tf.FixedLenFeature([], tf.float32), 'b': tf.FixedLenFeature([2], tf.int64), 'c': tf.FixedLenFeature([], tf.string) } ) • 如果序列化时value = ['somgthing'],这里的[]内就不用写数字了 • 如果序列化时value=something,这个something就要在[]指定这一个是由几个元素组成的了 5、获取数值 a = features['a'] b = features['b'] c_raw = features['c'] 如果是to_string过的,还必须经过三步 • pic = tf.decode_raw(pic,tf.uint8)转换成指定格式 • pic = tf.reshape(pic,pic_shape) • pic.set_shape([182,182,3])如果是tf.train.batch或者是shuffle_batch都必须用第三个,如果是一个一个的读取就不用第三个了 sess使用 一个一个读取 sess = tf.Session() glo = tf.global_variables_initializer() loc = tf.local_variables_initializer() sess.run(glo) sess.run(loc) coord = tf.train.Coordinator() tf.train.start_queue_runners(sess=sess, coord=coord) print('put the mouse in the window,press "q" for continue.') for i in range(20): real_pic, real_index, real_shape = sess.run([pic, index,shape_pic]) 一个批量读取 with tf.Graph().as_default(): pic, index, shape = _read_from_record(['./brief_test_name.tfrecord']) #这一句千万不要放进里面,否则tensorflow就会挂起来不动,不执行也不报错。 pic_batch, indexs, shapes = tf.train.batch([pic, index, shape], batch_size=16, num_threads=2, capacity=16 * 2 ) init = tf.initialize_all_variables() with tf.Session() as sess: sess.run(init) coord = tf.train.Coordinator() tf.train.start_queue_runners(sess, coord) for i in range(32): print(sess.run([pic_batch,indexs,shapes])) sess.close() ----------------------------代码--------------------------- #一个批次一个批次的读取 def tfrecord_read_batch(pic,index, batch_size, num_threads, capacity): pic_batch,indexs = tf.train.batch([pic, index], batch_size=batch_size, num_threads= num_threads, capacity= capacity) return pic_batch, indexs #一张一张的读取tfrecord图片,主要是用于测试 def tfrecord_read_one(): pic, index , shape_pic= _read_from_record(['./brief_test_name.tfrecord']) sess = tf.Session() glo = tf.global_variables_initializer() loc = tf.local_variables_initializer() sess.run(glo) sess.run(loc) coord = tf.train.Coordinator() tf.train.start_queue_runners(sess=sess, coord=coord) print('put the mouse in the window,press "q" for continue.') for i in range(20): real_pic, real_index, real_shape = sess.run([pic, index,shape_pic]) cv2.imshow('%s %s'%(real_index,list(real_shape)), real_pic) if cv2.waitKey(0) & 0xff == ord('q'): continue cv2.destroyAllWindows() #读取tfrecord数据 def _read_from_record(tfrecord_paths): files_queue = tf.train.string_input_producer(tfrecord_paths) reader = tf.TFRecordReader() _,serialized_example = reader.read(files_queue) features = tf.parse_single_example( serialized_example, features={ 'index':tf.FixedLenFeature([],tf.int64), 'shape':tf.FixedLenFeature([3],tf.int64), 'pic' :tf.FixedLenFeature([],tf.string) } ) index = features['index'] pic_shape = features['shape'] pic = features['pic'] pic = tf.decode_raw(pic,tf.uint8) #pics = tf.image.resize_images(pics, pics_shape) #下面这个就不行否则会 出错All shapes must be fully defined: pic = tf.reshape(pic,pic_shape) print(pic.get_shape()) pic.set_shape([182,182,3]) print(pic.get_shape()) return pic, index, pic_shape if __name__ == "__main__": parsed = parse(sys.argv[1:]) #默认应该执行1,2,3 flg = 0 with tf.Graph().as_default(): pic, index, shape = _read_from_record(['./brief_test_name.tfrecord']) pic_batch, indexs, shapes = tf.train.batch([pic, index, shape], batch_size=16, num_threads=2, capacity=16 * 2 ) init = tf.initialize_all_variables() with tf.Session() as sess: sess.run(init) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess, coord) #tf.train.start_queue_runners(sess = sess) try: for i in range(32): print(sess.run([pic_batch,indexs,shapes])) sess.close()