• Tensorflow学习记录 --TensorFlow高效读取数据tfrecord


    Tensorflow学习过程中tfrecord的简单理解

    1 TFRecord的介绍:

    一般使用直接将数据加载到内存的方式来存储数据量较小的数据,然后再分batch输入网络进行训练。如果数据量太大,这种方法是十分消耗内存的,这时可以使用tensorflow提供的队列queue从文件中提取数据(比如csv文件等)。还有一种较为常用的,高效的读取方法,即使用tensorflow内定标准格式——TFRecords.作者也是刚接触tensorflow,对日常学习遇到的问题做简单记录,有不对地方需要指正。

    1.1 什么是TFRecord?

    TFRecord是谷歌推荐的一种常用的存储二进制序列数据的文件格式,理论上它可以保存任何格式的信息。下面是Tensorflow的官网给出的文档结构,整个文件由文件长度信息,长度校验码,数据,数据校验码组成。

    uint64 length
    uint32 masked_crc32_of_length
    byte   data[length]
    uint32 masked_crc32_of_data
    

    2 代码及相关简介

    2.1 构建写入数据的writer

    import numpy as np 
    import tensorflow as tf 
    writer = tf.python_io.TFRecordWriter('test.tfrecord')
    

    2.2 TFRecord

    TensorFlow经常使用 tf.Example 来写入,读取TFRecord数据。

    通常tf.example有下面几种数据结构:

    • tf.train.FloatList: 可以使用的类型包括 float和double
    • tf.train.Int64List: 可以使用的类型包括 enum,bool, int32, uint32, int64
    • f.train.BytesList: 可以使用的类型包括 string和byte

    TFRecord支持写入三种格式的数据:string,int64,float32,以列表的形式分别通过tf.train.BytesList,tf.train.Int64List,tf.train.FloatList 写入 tf.train.Feature,如下所示:

    #feature一般是多维数组,要先转为list
    tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))
     
    #tostring函数后feature的形状信息会丢失,把shape也写入
    tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape))) 
     
    tf.train.Feature(float_list=tf.train.FloatList(value=[label]))
    

    下面以一个具体的简单例子来介绍tf.example

    for k in range(0, 3):
        x = 0.1712 + k
        y = [1+k, 2+k]
        z = np.array([[1,2,3],[4,5,6]]) + k
        z = z.astype(np.uint8)
        z_raw = z.tostring()
        example = tf.train.Example(
            features = tf.train.Features(
                feature = {'x':tf.train.Feature(float_list = tf.train.FloatList(value = [x])),
                           'y':tf.train.Feature(int64_list = tf.train.Int64List(value = y)),
                           'z':tf.train.Feature(bytes_list = tf.train.BytesList(value = [z_raw]))}))
        serialized = example.SerializeToString()
        writer.write(serialized)
    writer.close()
    

    x,y,z分别是以float,int64和string的形式存储的,注意观察下面语句:

    feature = {'x':tf.train.Feature(float_list = tf.train.FloatList(value = [x])),
               'y':tf.train.Feature(int64_list = tf.train.Int64List(value = y)),
               'z':tf.train.Feature(bytes_list = tf.train.BytesList(value = [z_raw]))}
    

    value的值是一个list形式,x定义的为一个数,value的值应为[x],同样y定义的格式就是一个list所以value的值直接为y即可,z_raw是由z转换过来的string形式,对应的value值与x的形式应该是一样的。

    2.3 创建文件读取队列并读取其中内容(字典格式)

    #output file name string to a queue
    filename_queue = tf.train.string_input_producer(['test.tfrecord'], num_epochs = None)
    #Create a reader from file queue
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    #Get feature from serialized example
    features = tf.parse_single_example(serialized_example,
                    features = {'x': tf.FixedLenFeature([],tf.float32),
                                'y': tf.FixedLenFeature([2],tf.int64),
                                'z': tf.FixedLenFeature([],tf.string)})
    

    2.4 读取数据

    x_out  = features['x']
    y_out  = features['y']
    z_raw_out = features['z']
    z_out = tf.decode_raw(z_raw_out,tf.uint8)
    z_out = tf.reshape(z_out, [2,3])
    print(x_out)
    print(y_out)
    print(z_out)
    

    显示结果为:

    Tensor("ParseSingleExample_2/ParseSingleExample:0", shape=(), dtype=float32)
    Tensor("ParseSingleExample_2/ParseSingleExample:1", shape=(2,), dtype=int64)
    Tensor("Reshape_1:0", shape=(2, 3), dtype=uint8)
    

    3 以存储图片为例理解TFRecord的应用

    使用Tensorflow训练网络时,为提高数据的读取效率,一般都采用TFRecords格式。初学CNN我们使用了手写数字数据集学习,这些都是做好的数据集,我们可以直接使用,比如MNIST,CIFAR_10等。现在我们还不是很清楚怎样输入训练的图片,此时就要用到TFRecord来制作自己的数据集。

    3.1 将图片转换成tfrecords格式

    假设我们的输入的图片需要三种信息,图片的名字,图片维度以及图片的内容:name shape content
    输入图片以及输出tfrecord文件:

    input_photo = r'D:Furhjupyter codeTensorflow Tipsdatadog.jpg'
    output_file = r'D:Furhjupyter codeTensorflow Tipsdog.tfrecord'
    
    # 使用 TFRecordWriter 将信息写入到 TFRecord 文件
    writer = tf.python_io.TFRecordWriter(output_file)
    #读取图片进行解码
    image = tf.read_file(input_photo)
    image = tf.image.decode_jpeg(image)
    
    with tf.Session() as sess:
        image_new = sess.run(image)
        shape = image_new.shape
        #将图片转换成string 
        image_data = image_new.tostring()
        print(type(image_new))
        print(len(image_data))
        name = bytes('dog',encoding = 'utf-8')
        print(type(name))
        # 创建Example对象,将所有的Features填充进去
        example = tf.train.Example(
                        features = tf.train.Features(
                            feature = {
                                'name': tf.train.Feature(bytes_list = tf.train.BytesList(value = [name])),
                                'shape': tf.train.Feature(int64_list = tf.train.Int64List(value = [shape[0],shape[1],shape[2]])),
                                'data': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image_data]))
                            }
                        ))
        # 将example序列化成string类型写入
        writer.write(example.SerializeToString())
    writer.close()
    

    Note:

    • Feature 中value应该是列表形式,当数据不是列表时,加上[]
    • 解码后的图片要转化成string数据,再填充
    • example需要使用SerializeToString()进行序列化

    3.2 TFRecord 文件读取成图片

    #解析数据 
    def parse_record(example_photo):
        features = {
            'name': tf.FixedLenFeature((),tf.string),
            'shape': tf.FixedLenFeature([3],tf.int64), #这里制定维度3
            'data' : tf.FixedLenFeature((),tf.string)
        }
        #在解析example时,用现成的API: tf.parse_single_example
        parsed_features = tf.parse_single_example(example_photo,features = features)
        return parsed_features
    
    def read_test(input_file):
        #使用dataset读取TFRecord文件
        dataset = tf.data.TFRecordDataset(input_file)
        dataset = dataset.map(parse_record)
        iterator = dataset.make_one_shot_iterator()
        
        with tf.Session() as sess:
            features = sess.run(iterator.get_next())
            name = features['name']
            name = name.decode
            img_data = features['data']
            shape = features['shape']
            
            #从bytes数组中加载图片原始数据,并重新reshape,结果是ndarray数组
            img_data = np.fromstring(img_data, dtype=np.uint8) #获取解析后的string数据,并把数据还原成unit8
            image_data = np.reshape(img_data,shape)
            
            plt.figure()
            plt.imshow(image_data)
            plt.show()
            
            #将数据重新编码成jpg图片保存
            img = tf.image.encode_jpeg(image_data)
            #把图片保存到本地    
            tf.gfile.GFile('dog_encode,jpg', 'wb').write(img.eval())
    
    read_test('dog.tfrecord')
    

    Note:
    在使用dataset进行样本解析之前,我们需要按照先定义一个解析字典,告诉dataset如何去解析每个样本,这个字典就是用来指定对于每条输入样本的每一列应该用什么的feature去解析,dataset默认提供了FixedLenFeature,VarLenFeature,FixedLenSequenceFeature等。

    FixedLenFeature() 函数有三个参数:

    • shape:输入数据的shape。
    • dtype:输入的数据类型。
    • default_value:如果示例缺少此功能,则使用该值。它必须与dtype和指定shape兼容。

    代码注释:

    主要参考:
    TensorFlow学习记录-- 7.TensorFlow高效读取数据之tfrecord详细解读
    tensorflow学习笔记——高效读取数据的方法(TFRecord

  • 相关阅读:
    asp的多国语言构思
    制作IE和FF都支持的无限级关联菜单
    破解网络尖兵(真正对付限制ADSL路由共享的方法)
    Asp透过系统DSN链接mysql数据库
    找到了一首曾经很喜欢的老歌
    生意人应具备的性格
    简单的操作让你的迅雷变的清爽
    线路分流自动跳转代码
    通过regsvr32注册DLL可以解决的一些疑难杂症
    页面无刷新超时自动退出
  • 原文地址:https://www.cnblogs.com/ysfurh/p/14127941.html
Copyright © 2020-2023  润新知