• (2)tf.data.Dataset喂数据给模型


    一般而言把数据喂给模型的方式有三种:

    1.建立placeholder,然后使用feed_dict将数据feed进placeholder进行使用。使用这种方法十分灵活,可以一下子将所有数据读入内存,然后分batch进行feed;也可以建立一个Python的generator,一个batch一个batch的将数据读入,并将其feed进placeholder。这种方法很直观,用起来也比较方便灵活,但是这种方法的效率较低,难以满足高速计算的需求。

     

    2.使用TensorFlow的QueueRunner,通过一系列的Tensor操作,将磁盘上的数据分批次读入并送入模型进行使用。这种方法效率很高,但因为其牵涉到Tensor操作,不够直观,也不方便调试,所有有时候会显得比较困难。使用这种方法时,常用的一些操作包括tf.TextLineReader,tf.FixedLengthRecordReader以及tf.decode_raw等等。如果需要循环,条件操作,还需要使用TensorFlow的tf.while_loop,tf.case等操作,更是难上加难。

     1 import tensorflow as tf
     2 filename_queue=tf.train.string_input_producer(["./data/all_c_dev.en"])
     3 
     4 reader=tf.TextLineReader()
     5 key,value=reader.read(filename_queue)
     6 
     7 
     8 with tf.Session() as sess:
     9     tf.train.start_queue_runners()
    10     for i in range(10):
    11         print(sess.run([key,value]))

    3.自1.x版本开始,逐步开发引入了tf.data.Dataset模块,使其数据读入的操作变得更为方便,而支持多线程(进程)的操作,也在效率上获得了一定程度的提高。

      a.一次将所有数据读入内存

     1 images = ...                                                 #图像数据images读入内存;
     2 labels = ...                                                 #对应的标签数据labels读入内存;
     3 data = tf.data.Dataset.from_tensor_slices((images, labels))  #使用读入内存的数据images、labels构建Dataset;
     4 data = data.batch(batch_size)                                #设置batchsize大小
     5 iterator=tf.data.Iterator.from_structure(data.output_types,
     6                         data.output_shapes)  #基于此前构建的Dataset的数据类型和结构,构建一个可重新初始化iterator
     7 init_op = iterator.make_initializer(data)                    #基于此前构建的Dataset构建一个iterator初始化op。
     8 with tf.Session()  as sess:                                  #展开会话
     9     sess.run(init_op)                                        #初始化iterator
    10     try:
    11         images, labels = iterator.get_next()                 #获取一个batchsize的数据
    12     except tf.errors.OutOfRangeError:                        #iterator中的元素取完之后,会抛出OutOfRangeError异常,TensorFlow没有对这个异常进行处理,我们需要对其进行捕捉和处理。
    13         sess.run(init_op)

      b.包装一个generator

    def gen():                                                            #定义一个生成器函数
        with  open('train.csv')  as f:
            lines = [line.strip().split(',')  for line in f.readlines()]
            index = 0
            while  True:
                image = cv2.imread(lines[index][0])
                image = cv2.resize(image, (224, 224))
                label = lines[index][1]
                yield  (image, label)
                index += 1
                if index == len(lines):
               index = 0
    
    batch_size = 2
    data = tf.data.Dataset.from_generator( gen,                           #指定通过gen构建Dataset
                           (tf.float32, tf.int32), #指定数据类型                        (tf.TensorShape([
    224, 224, 3]),tf.TensorShape([]))) #指定shape
                         
    data = data.batch(batch_size)                          #设置batchsize
    iter = data.make_one_shot_iterator()                                 #创建单次迭代器,是最简单的迭代器形式,仅支持对数据集进行一次迭代,不需要显式初始化。 
    with tf.Session()  as sess:
        images, labels = iter.get_next()

      c.使用tensor读取数据

    def _parse_function(filename, label):                  #定义解析数据的函数
        image_string = tf.read_file(filename)
        image_decoded = tf.image.decode_jpeg(image_string, channels=3)
        image = tf.cast(image_decoded, tf.float32)
        image = tf.image.resize_images(image, [224, 224])
        return image, filename, label
    
    images = tf.constant(image_names)                    #转化为tensor
    labels = tf.constant(labels)
    images = tf.random_shuffle(images, seed=0)
    labels = tf.random_shuffle(labels, seed=0)
    data = tf.data.Dataset.from_tensor_slices((images, labels))    #利用tensor构建dataset
    data = data.map(_parse_function, num_parallel_calls=4)       #利用map函数处理tensor得到新的dataset,num_parallel_calls表示并行处理
    data
    = data.prefetch(buffer_size=batch_size * 10)          #prefetch可以充分利用时间,预准备 data = data.batch(batch_size)                      #设置batchsize iterator = tf.data.Iterator.from_structure(data.output_types,data.output_shapes) #构建iterator init_op = iterator.make_initializer(data)               #初始化 with tf.Session() as sess: sess.run(init_op) try: images, filenames, labels = iterator.get_next() except tf.errors.OutOfRangeError: sess.run(init_op)

    使用tf,data是一种管道pipeline机制,他有很多的特色,比如prefetch和map,能够充分利用cpu的时间,这篇博客介绍的很好。

     tf.data.Dataset.from_generator
  • 相关阅读:
    第三节:模板模式——在Spring框架应用
    第二节:模板模式——模板模式应用
    idea ---- intelij IDEA安装
    计算机基础 ---- 编码(er)
    preg_match一些问题
    php 两个值进行比较的问题
    php中in_array一些问题
    配置完php.ini中的扩展库后,重启apache出现错误1067
    基于Intel 174;E810 的OVS-DPDK VXLAN TUNNEL性能优化
    tc filter 工作模式:传统模式和 direct-action 模式
  • 原文地址:https://www.cnblogs.com/super-zheng/p/13215425.html
Copyright © 2020-2023  润新知