• Tensorflow 中(批量)读取数据的案列分析及TFRecord文件的打包与读取


    内容概要:

    单一数据读取方式:

      第一种:slice_input_producer()

    # 返回值可以直接通过 Session.run([images, labels])查看,且第一个参数必须放在列表中,如[...]
    [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)

      第二种:string_input_producer()

    # 需要定义文件读取器,然后通过读取器中的 read()方法来获取数据(返回值类型 key,value),再通过 Session.run(value)查看
    file_queue = tf.train.string_input_producer(filename, num_epochs=None, shuffle=True)

    reader = tf.WholeFileReader() # 定义文件读取器
    key, value = reader.read(file_queue)    # key:文件名;value:文件中的内容

      !!!num_epochs=None,不指定迭代次数,这样文件队列中元素个数也不限定(None*数据集大小)。

      !!!如果不是None,则此函数创建本地计数器 epochs,需要使用local_variables_initializer()初始化局部变量

      !!!以上两种方法都可以生成文件名队列。

    (随机)批量数据读取方式:

    batchsize=2  # 每次读取的样本数量
    tf.train.batch(tensors, batch_size=batchsize)
    tf.train.shuffle_batch(tensors, batch_size=batchsize, capacity=batchsize*10, min_after_dequeue=batchsize*5) # capacity > min_after_dequeue

      !!!以上所有读取数据的方法,在Session.run()之前必须开启文件队列线程 tf.train.start_queue_runners()

     TFRecord文件的打包与读取

     一、单一数据读取方式

    第一种:slice_input_producer()

    def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None)

    案例1:

    import tensorflow as tf
    
    images = ['image1.jpg', 'image2.jpg', 'image3.jpg', 'image4.jpg']
    labels = [1, 2, 3, 4]
    
    # [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)
    
    # 当num_epochs=2时,此时文件队列中只有 2*4=8个样本,所有在取第9个样本时会出错
    # [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=2, shuffle=True)
    
    data = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True)
    print(type(data))   # <class 'list'>
    
    with tf.Session() as sess:
        # sess.run(tf.local_variables_initializer())
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()  # 线程的协调器
        threads = tf.train.start_queue_runners(sess, coord)  # 开始在图表中收集队列运行器
    
        for i in range(10):
            print(sess.run(data))
    
        coord.request_stop()
        coord.join(threads)
    
    """
    运行结果:
    [b'image2.jpg', 2]
    [b'image1.jpg', 1]
    [b'image3.jpg', 3]
    [b'image4.jpg', 4]
    [b'image2.jpg', 2]
    [b'image1.jpg', 1]
    [b'image3.jpg', 3]
    [b'image4.jpg', 4]
    [b'image2.jpg', 2]
    [b'image3.jpg', 3]
    """

      !!!slice_input_producer() 中的第一个参数需要放在一个列表中,列表中的每个元素可以是 List 或 Tensor,如 [images,labels],

      !!!num_epochs设置

     第二种:string_input_producer()

    def string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None, cancel_op=None)

    文件读取器

      不同类型的文件对应不同的文件读取器,我们称为 reader对象

      该对象的 read 方法自动读取文件,并创建数据队列,输出key/文件名,value/文件内容;

    reader = tf.TextLineReader()      ### 一行一行读取,适用于所有文本文件

    reader = tf.TFRecordReader() ### A Reader that outputs the records from a TFRecords file
    reader = tf.WholeFileReader() ### 一次读取整个文件,适用图片

     案例2:读取csv文件

    iimport tensorflow as tf
    
    filename = ['data/A.csv', 'data/B.csv', 'data/C.csv']
    
    file_queue = tf.train.string_input_producer(filename, shuffle=True, num_epochs=2)   # 生成文件名队列
    reader = tf.WholeFileReader()           # 定义文件读取器(一次读取整个文件)
    # reader = tf.TextLineReader()            # 定义文件读取器(一行一行的读)
    key, value = reader.read(file_queue)    # key:文件名;value:文件中的内容
    print(type(file_queue))
    
    init = [tf.global_variables_initializer(), tf.local_variables_initializer()]
    with tf.Session() as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            while not coord.should_stop():
                for i in range(6):
                    print(sess.run([key, value]))
                break
        except tf.errors.OutOfRangeError:
            print('read done')
        finally:
            coord.request_stop()
        coord.join(threads)
    
    """
    reader = tf.WholeFileReader()           # 定义文件读取器(一次读取整个文件)
    运行结果:
    [b'data/C.csv', b'7.jpg,7
    8.jpg,8
    9.jpg,9
    ']
    [b'data/B.csv', b'4.jpg,4
    5.jpg,5
    6.jpg,6
    ']
    [b'data/A.csv', b'1.jpg,1
    2.jpg,2
    3.jpg,3
    ']
    [b'data/A.csv', b'1.jpg,1
    2.jpg,2
    3.jpg,3
    ']
    [b'data/B.csv', b'4.jpg,4
    5.jpg,5
    6.jpg,6
    ']
    [b'data/C.csv', b'7.jpg,7
    8.jpg,8
    9.jpg,9
    ']
    """
    """
    reader = tf.TextLineReader()           # 定义文件读取器(一行一行的读)
    运行结果:
    [b'data/B.csv:1', b'4.jpg,4']
    [b'data/B.csv:2', b'5.jpg,5']
    [b'data/B.csv:3', b'6.jpg,6']
    [b'data/C.csv:1', b'7.jpg,7']
    [b'data/C.csv:2', b'8.jpg,8']
    [b'data/C.csv:3', b'9.jpg,9']
    """

    案例3:读取图片(每次读取全部图片内容,不是一行一行)

    import tensorflow as tf
    
    filename = ['1.jpg', '2.jpg']
    filename_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=1)
    reader = tf.WholeFileReader()              # 文件读取器
    key, value = reader.read(filename_queue)   # 读取文件 key:文件名;value:图片数据,bytes
    
    with tf.Session() as sess:
        tf.local_variables_initializer().run()
        coord = tf.train.Coordinator()      # 线程的协调器
        threads = tf.train.start_queue_runners(sess, coord)
    
        for i in range(filename.__len__()):
            image_data = sess.run(value)
            with open('img_%d.jpg' % i, 'wb') as f:
                f.write(image_data)
        coord.request_stop()
        coord.join(threads)

     二、(随机)批量数据读取方式:

      功能:shuffle_batch() 和 batch() 这两个API都是从文件队列中批量获取数据,使用方式类似;

    案例4:slice_input_producer() 与 batch()

    import tensorflow as tf
    import numpy as np
    
    images = np.arange(20).reshape([10, 2])
    label = np.asarray(range(0, 10))
    images = tf.cast(images, tf.float32)  # 可以注释掉,不影响运行结果
    label = tf.cast(label, tf.int32)     # 可以注释掉,不影响运行结果
    
    batchsize = 6   # 每次获取元素的数量
    input_queue = tf.train.slice_input_producer([images, label], num_epochs=None, shuffle=False)
    image_batch, label_batch = tf.train.batch(input_queue, batch_size=batchsize)
    
    # 随机获取 batchsize个元素,其中,capacity:队列容量,这个参数一定要比 min_after_dequeue 大
    # image_batch, label_batch = tf.train.shuffle_batch(input_queue, batch_size=batchsize, capacity=64, min_after_dequeue=10)
    
    with tf.Session() as sess:
        coord = tf.train.Coordinator()      # 线程的协调器
        threads = tf.train.start_queue_runners(sess, coord)     # 开始在图表中收集队列运行器
        for cnt in range(2):
            print("第{}次获取数据,每次batch={}...".format(cnt+1, batchsize))
            image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
            print(image_batch_v, label_batch_v, label_batch_v.__len__())
    
        coord.request_stop()
        coord.join(threads)
    
    """
    运行结果:
    第1次获取数据,每次batch=6...
    [[ 0.  1.]
     [ 2.  3.]
     [ 4.  5.]
     [ 6.  7.]
     [ 8.  9.]
     [10. 11.]] [0 1 2 3 4 5] 6
    第2次获取数据,每次batch=6...
    [[12. 13.]
     [14. 15.]
     [16. 17.]
     [18. 19.]
     [ 0.  1.]
     [ 2.  3.]] [6 7 8 9 0 1] 6
    """

     案例5:从本地批量的读取图片 --- string_input_producer() 与 batch()

     1 import tensorflow as tf
     2 import glob
     3 import cv2 as cv
     4 
     5 def read_imgs(filename, picture_format, input_image_shape, batch_size=1):
     6     """
     7     从本地批量的读取图片
     8     :param filename: 图片路径(包括图片的文件名),[]
     9     :param picture_format: 图片的格式,如 bmp,jpg,png等; string
    10     :param input_image_shape: 输入图像的大小; (h,w,c)或[]
    11     :param batch_size: 每次从文件队列中加载图片的数量; int
    12     :return: batch_size张图片数据, Tensor
    13     """
    14     global new_img
    15     # 创建文件队列
    16     file_queue = tf.train.string_input_producer(filename, num_epochs=1, shuffle=True)
    17     # 创建文件读取器
    18     reader = tf.WholeFileReader()
    19     # 读取文件队列中的文件
    20     _, img_bytes = reader.read(file_queue)
    21     # print(img_bytes)    # Tensor("ReaderReadV2_19:1", shape=(), dtype=string)
    22     # 对图片进行解码
    23     if picture_format == ".bmp":
    24         new_img = tf.image.decode_bmp(img_bytes, channels=1)
    25     elif picture_format == ".jpg":
    26         new_img = tf.image.decode_jpeg(img_bytes, channels=3)
    27     else:
    28         pass
    29     # 重新设置图片的大小
    30     # new_img = tf.image.resize_images(new_img, input_image_shape)
    31     new_img = tf.reshape(new_img, input_image_shape)
    32     # 设置图片的数据类型
    33     new_img = tf.image.convert_image_dtype(new_img, tf.uint8)
    34 
    35     # return new_img
    36     return tf.train.batch([new_img], batch_size)
    37 
    38 
    39 def main():
    40     image_path = glob.glob(r'F:demoFaceRecognition人脸库ORL*.bmp')
    41     image_batch = read_imgs(image_path, ".bmp", (112, 92, 1), 5)
    42     print(type(image_batch))
    43     # image_path = glob.glob(r'.*.jpg')
    44     # image_batch = read_imgs(image_path, ".jpg", (313, 500, 3), 1)
    45 
    46     sess = tf.Session()
    47     sess.run(tf.local_variables_initializer())
    48     tf.train.start_queue_runners(sess=sess)
    49 
    50     image_batch = sess.run(image_batch)
    51     print(type(image_batch))    # <class 'numpy.ndarray'>
    52 
    53     for i in range(image_batch.__len__()):
    54         cv.imshow("win_"+str(i), image_batch[i])
    55     cv.waitKey()
    56     cv.destroyAllWindows()
    57 
    58 def start():
    59     image_path = glob.glob(r'F:demoFaceRecognition人脸库ORL*.bmp')
    60     image_batch = read_imgs(image_path, ".bmp", (112, 92, 1), 5)
    61     print(type(image_batch))    # <class 'tensorflow.python.framework.ops.Tensor'>
    62 
    63 
    64     with tf.Session() as sess:
    65         sess.run(tf.local_variables_initializer())
    66         coord = tf.train.Coordinator()      # 线程的协调器
    67         threads = tf.train.start_queue_runners(sess, coord)     # 开始在图表中收集队列运行器
    68         image_batch = sess.run(image_batch)
    69         print(type(image_batch))    # <class 'numpy.ndarray'>
    70 
    71         for i in range(image_batch.__len__()):
    72             cv.imshow("win_"+str(i), image_batch[i])
    73         cv.waitKey()
    74         cv.destroyAllWindows()
    75 
    76         # 若使用 with 方式打开 Session,且没加如下2行语句,则会出错
    77         # ERROR:tensorflow:Exception in QueueRunner: Enqueue operation was cancelled;
    78         # 原因:文件队列线程还处于工作状态(队列中还有图片数据),而加载完batch_size张图片会话就会自动关闭,同时关闭文件队列线程
    79         coord.request_stop()
    80         coord.join(threads)
    81 
    82 
    83 if __name__ == "__main__":
    84     # main()
    85     start()
    从本地批量的读取图片案例

    案列6:TFRecord文件打包与读取

     1 def write_TFRecord(filename, data, labels, is_shuffler=True):
     2     """
     3     将数据打包成TFRecord格式
     4     :param filename: 打包后路径名,默认在工程目录下创建该文件;String
     5     :param data: 需要打包的文件路径名;list
     6     :param labels: 对应文件的标签;list
     7     :param is_shuffler:是否随机初始化打包后的数据,默认:True;Bool
     8     :return: None
     9     """
    10     im_data = list(data)
    11     im_labels = list(labels)
    12 
    13     index = [i for i in range(im_data.__len__())]
    14     if is_shuffler:
    15         np.random.shuffle(index)
    16 
    17     # 创建写入器,然后使用该对象写入样本example
    18     writer = tf.python_io.TFRecordWriter(filename)
    19     for i in range(im_data.__len__()):
    20         im_d = im_data[index[i]]    # im_d:存放着第index[i]张图片的路径信息
    21         im_l = im_labels[index[i]]  # im_l:存放着对应图片的标签信息
    22 
    23         # # 获取当前的图片数据  方式一:
    24         # data = cv2.imread(im_d)
    25         # # 创建样本
    26         # ex = tf.train.Example(
    27         #     features=tf.train.Features(
    28         #         feature={
    29         #             "image": tf.train.Feature(
    30         #                 bytes_list=tf.train.BytesList(
    31         #                     value=[data.tobytes()])), # 需要打包成bytes类型
    32         #             "label": tf.train.Feature(
    33         #                 int64_list=tf.train.Int64List(
    34         #                     value=[im_l])),
    35         #         }
    36         #     )
    37         # )
    38         # 获取当前的图片数据  方式二:相对于方式一,打包文件占用空间小了一半多
    39         data = tf.gfile.FastGFile(im_d, "rb").read()
    40         ex = tf.train.Example(
    41             features=tf.train.Features(
    42                 feature={
    43                     "image": tf.train.Feature(
    44                         bytes_list=tf.train.BytesList(
    45                             value=[data])), # 此时的data已经是bytes类型
    46                     "label": tf.train.Feature(
    47                         int64_list=tf.train.Int64List(
    48                             value=[im_l])),
    49                 }
    50             )
    51         )
    52 
    53         # 写入将序列化之后的样本
    54         writer.write(ex.SerializeToString())
    55     # 关闭写入器
    56     writer.close()
    TFRecord文件打包案列
     1 import tensorflow as tf
     2 import cv2
     3 
     4 def read_TFRecord(file_list, batch_size=10):
     5     """
     6     读取TFRecord文件
     7     :param file_list: 存放TFRecord的文件名,List
     8     :param batch_size: 每次读取图片的数量
     9     :return: 解析后图片及对应的标签
    10     """
    11     file_queue = tf.train.string_input_producer(file_list, num_epochs=None, shuffle=True)
    12     reader = tf.TFRecordReader()
    13     _, ex = reader.read(file_queue)
    14     batch = tf.train.shuffle_batch([ex], batch_size, capacity=batch_size * 10, min_after_dequeue=batch_size * 5)
    15 
    16     feature = {
    17         'image': tf.FixedLenFeature([], tf.string),
    18         'label': tf.FixedLenFeature([], tf.int64)
    19     }
    20     example = tf.parse_example(batch, features=feature)
    21 
    22     images = tf.decode_raw(example['image'], tf.uint8)
    23     images = tf.reshape(images, [-1, 32, 32, 3])
    24 
    25     return images, example['label']
    26 
    27 
    28 
    29 def main():
    30     # filelist = ['data/train.tfrecord']
    31     filelist = ['data/test.tfrecord']
    32     images, labels = read_TFRecord(filelist, 2)
    33     with tf.Session() as sess:
    34         sess.run(tf.local_variables_initializer())
    35         coord = tf.train.Coordinator()
    36         threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    37 
    38         try:
    39             while not coord.should_stop():
    40                 for i in range(1):
    41                     image_bth, _ = sess.run([images, labels])
    42                     print(_)
    43 
    44                     cv2.imshow("image_0", image_bth[0])
    45                     cv2.imshow("image_1", image_bth[1])
    46                 break
    47         except tf.errors.OutOfRangeError:
    48             print('read done')
    49         finally:
    50             coord.request_stop()
    51         coord.join(threads)
    52         cv2.waitKey(0)
    53         cv2.destroyAllWindows()
    54 
    55 if __name__ == "__main__":
    56     main()
    TFReord文件的读取案列
  • 相关阅读:
    js 对象数组 排序
    sql 时间条件查询
    idea和Pycharm 等系列产品激活激活方法和激活码 100 年
    开源协议简介
    面试题
    VIM|基础命令
    git|基础命令
    VIM|复制
    lua|基础教程
    Printf格式输出详解
  • 原文地址:https://www.cnblogs.com/nbk-zyc/p/13159986.html
Copyright © 2020-2023  润新知