• 数据存储方式tfrecord


    为什么使用tfrecord?

    正常情况下我们训练文件夹经常会生成 train, test 或者val文件夹,这些文件夹内部往往会存着成千上万的图片或文本等文件,这些文件被散列存着,这样不仅占用磁盘空间,并且再被一个个读取的时候会非常慢,繁琐。占用大量内存空间(有的大型数据不足以一次性加载)。此时我们TFRecord格式的文件存储形式会很合理的帮我们存储数据。TFRecord内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件的方式即可,简单,快速,尤其对大型训练数据很友好。而且当我们的训练数据量比较大的时候,可以将数据分成多个TFRecord文件,来提高处理效率。

    来自 <https://www.jianshu.com/p/b480e5fcb638>

    这个就是队列,也就不用placeholder了,FIFO等队列了。

    TFrecord write

    #各个数据集生成tfrecord文件
    def tfrecord_write(args):
        #获取数据集分割后的txt文件名字
        txt_names = [txt_name for txt_name in os.listdir(args.out_path) if txt_name.split('.')[1]=='txt']
        #txt路径
        txt_paths = [os.path.join(args.out_path,txt_path) for txt_path in txt_names]
    
        #tfrecord文件名字
        tfrecord_names = [name.split('.')[0] for name in txt_names]
        #tfrecord路径
        tfrecord_paths = [os.path.join(args.out_path,tfrecord_path+'.tfrecord') for tfrecord_path in tfrecord_names]
    
        #产生txt文件数目个txrecord
        for txt_path, tfrecord_path in zip(txt_paths, tfrecord_paths):
            print(tfrecord_path)
            writer = tf.python_io.TFRecordWriter(tfrecord_path)
            with open(txt_path, 'r') as f:
                for line in f.readlines():
                    name,num = line.strip().split('	')
                    #print(name)
                    if name == 'George_W_Bush':
                        print("Because George_W_Bush has 530 pictures so we will not use it to save time ")
                        #尽可能一次writer 尽可能多的写进去数据,否则会很慢
                        continue
                    pics, len_pics = _get_pics(os.path.join(args.path,name))
                    for i, pic in enumerate(pics):
                        _store2tfrecord(pic, i, writer)
                    assert int(num) == len_pics
            writer.close()
    
    #读取某个人名文件夹下的所有人脸,以及图片个数
    def _get_pics(path):
        pic_list = [os.path.join(path, pic) for pic in os.listdir(path)]
        pics = []
    
        for pic_path in pic_list:
            pics.append(cv2.imread(pic_path))
    
        return np.asarray(pics), len(pic_list)
    
    #将某个文件夹下的图片和个数tfrecord保存
    def _store2tfrecord(pic, index, writer):
        pic_shape = list(pic.shape)
        print(pic_shape)
        pic_string = pic.tostring()
    
        example = tf.train.Example(features=tf.train.Features(
            feature={
                'index': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                'shape':tf.train.Feature(int64_list=tf.train.Int64List(value=pic_shape)),
                'pic': tf.train.Feature(bytes_list=tf.train.BytesList(value=[pic_string]))
            }
        ))
        serialized = example.SerializeToString()
        writer.write(serialized)
        #这里不能有writer.close 否则就会关闭。
        #一个writer will make a tfrecord file ,if exists it will remake
    
    tfrecord_write(parsed)
    
    1、设置存放.tfrecord文件的位置
    2、在该位置生成tfrecord文件
        不要每个example都生成一个writer否则只会存储最后一个数据。
        • writer = tf.python_io.TFRecordWriter("位置信息")#方法一
        writer.close() 
        • with tf.python_io.TFRecordWriter(位置信息)  as writer#方法二
    3、组成example
    example = tf.train.Example(features=tf.train.Features(
                    feature={
                    'label': tf.train.Feature(int64_list = tf.train.Int64List(value=[i])),     
                    'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw]))
                    }))
    值得注意的是赋值给example的数据格式。从前面tf.train.Example的定义可知,
    tfrecord支持整型、浮点数和二进制三种格式,分别是value必须是列表    
    tf.train.Feature(int64_list = tf.train.Int64List(value=[int_scalar]))    
    tf.train.Feature(bytes_list = tf.train.BytesList(value=[array_string_or_byte]))    
    tf.train.Feature(float_list = tf.train.FloatList(value=[float_scalar]))    
        
    如果单个保存的是list长度,value=[len_list]    
    如果单个保存的是数组的形状,value=array_array,这个在读取解析的时候[]里面就需要指定个数    
    如果是array的数据(比如图片)可以通过array_data.tostring()转化为string。再用byte_list
    这样可以节省空间。矩阵会失去维度。所以还要保存维度信息    
    如果读取的时候是批量读取的,每个‘label’的形状必须一样否则只能一个一个的读取    
    如果数据可以存储成(25,160,160,3)的形式,就不要存储成25个(160,160,3)的格式,后面一种存储空间大,并且时间长    
    10G的tfrecord格式文件速度比较快,在1080Ti上    
    
    4、序列化,减少内存
    example = example.SerializeToString()
    5、写进tfrecords
    writer.write(example)
    6、关闭tfrecord文件
    writer.close()

    TFrecord read

    1、创建文件队列
    files_queue = tf.train.string_input_producer(tfrecord_paths)
    2、创建reader
     reader = tf.TFRecordReader()  
    3、读取序列化后的文件名和example
      _, serialized_example = reader.read(路径) 
    4、反序列化
     features = tf.parse_single_example(  
            serialized_example,  
            features={  
                'a': tf.FixedLenFeature([], tf.float32),  
                'b': tf.FixedLenFeature([2], tf.int64),  
                'c': tf.FixedLenFeature([], tf.string)  
            }  
        ) 
        • 如果序列化时value = ['somgthing'],这里的[]内就不用写数字了
        • 如果序列化时value=something,这个something就要在[]指定这一个是由几个元素组成的了
    5、获取数值
         a = features['a']  
     b = features['b']  
     c_raw = features['c']
        如果是to_string过的,还必须经过三步
        • pic = tf.decode_raw(pic,tf.uint8)转换成指定格式
        • pic = tf.reshape(pic,pic_shape)
        • pic.set_shape([182,182,3])如果是tf.train.batch或者是shuffle_batch都必须用第三个,如果是一个一个的读取就不用第三个了
        
    
    sess使用
    一个一个读取
        sess = tf.Session()
        glo = tf.global_variables_initializer()
        loc = tf.local_variables_initializer()
        sess.run(glo)
        sess.run(loc)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(sess=sess, coord=coord)
        print('put the mouse in the window,press "q" for continue.')
        for i in range(20):
            real_pic, real_index, real_shape = sess.run([pic, index,shape_pic])
    一个批量读取
     with tf.Graph().as_default():
                pic, index, shape = _read_from_record(['./brief_test_name.tfrecord'])
                #这一句千万不要放进里面,否则tensorflow就会挂起来不动,不执行也不报错。
                pic_batch, indexs, shapes = tf.train.batch([pic, index, shape],
                                                           batch_size=16,
                                                           num_threads=2,
                                                           capacity=16 * 2
                                                           )
                init = tf.initialize_all_variables()
                with tf.Session() as sess:
                    sess.run(init)
                    coord = tf.train.Coordinator()
                    tf.train.start_queue_runners(sess, coord)
                    
                   
                        for i in range(32):
                            
                            print(sess.run([pic_batch,indexs,shapes]))
                            sess.close()
    
    ----------------------------代码---------------------------
    #一个批次一个批次的读取
    def tfrecord_read_batch(pic,index, batch_size, num_threads, capacity):
    
    
        pic_batch,indexs = tf.train.batch([pic, index],
                                             batch_size=batch_size,
                                             num_threads= num_threads,
                                             capacity= capacity)
    
        return pic_batch, indexs
    
    
    #一张一张的读取tfrecord图片,主要是用于测试
    def tfrecord_read_one():
        pic, index , shape_pic= _read_from_record(['./brief_test_name.tfrecord'])
    
        sess = tf.Session()
        glo = tf.global_variables_initializer()
        loc = tf.local_variables_initializer()
        sess.run(glo)
        sess.run(loc)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(sess=sess, coord=coord)
        print('put the mouse in the window,press "q" for continue.')
        for i in range(20):
            real_pic, real_index, real_shape = sess.run([pic, index,shape_pic])
            cv2.imshow('%s %s'%(real_index,list(real_shape)), real_pic)
            if cv2.waitKey(0) & 0xff == ord('q'):
                continue
    
        cv2.destroyAllWindows()
    
    #读取tfrecord数据
    def _read_from_record(tfrecord_paths):
        files_queue = tf.train.string_input_producer(tfrecord_paths)
        reader = tf.TFRecordReader()
        _,serialized_example = reader.read(files_queue)
        features = tf.parse_single_example(
            serialized_example,
            features={
                'index':tf.FixedLenFeature([],tf.int64),
                'shape':tf.FixedLenFeature([3],tf.int64),
                'pic'  :tf.FixedLenFeature([],tf.string)
            }
        )
        index = features['index']
        pic_shape = features['shape']
        pic = features['pic']
        pic = tf.decode_raw(pic,tf.uint8)
        #pics = tf.image.resize_images(pics, pics_shape)
        #下面这个就不行否则会 出错All shapes must be fully defined:
        pic = tf.reshape(pic,pic_shape)
    
        print(pic.get_shape())
        pic.set_shape([182,182,3])
        print(pic.get_shape())
        return pic, index, pic_shape
    
    if __name__ == "__main__":
        parsed = parse(sys.argv[1:])
        #默认应该执行1,2,3
        flg = 0
        
            
            with tf.Graph().as_default():
                pic, index, shape = _read_from_record(['./brief_test_name.tfrecord'])
                pic_batch, indexs, shapes = tf.train.batch([pic, index, shape],
                                                           batch_size=16,
                                                           num_threads=2,
                                                           capacity=16 * 2
                                                           )
                init = tf.initialize_all_variables()
                with tf.Session() as sess:
                    sess.run(init)
                    coord = tf.train.Coordinator()
                    threads = tf.train.start_queue_runners(sess, coord)
                    #tf.train.start_queue_runners(sess = sess)
                    try:
                        for i in range(32):
                            
                            print(sess.run([pic_batch,indexs,shapes]))
                            sess.close()
  • 相关阅读:
    尝试加载 Oracle 客户端库时引发 BadImageFormatException。如果在安装 32 位 Oracle 客户端组件的情况下以 64 位模式运行 已解决!
    iis 无法在Web服务器上启动调试。打开的URL的IIS辅助进程当前没有运行
    aspx页面,Page_Load 无人进入,解决
    Ajax后台传数组参数,接收不到报错!
    FusionCharts和highcharts 饼图区别!
    redis
    Hibernate不同数据库的连接及SQL方言
    Kafka
    Zookeeper
    BaseDao+万能方法 , HibernateDaoSupport
  • 原文地址:https://www.cnblogs.com/yunshangyue71/p/13611247.html
Copyright © 2020-2023  润新知