• 采用tfrecord形式读写训练数据


    tfrecord数据文件是一种将图像数据和标签统一存储的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储等。尤其在面对海量数据时,使用常用的内存读取方式变得不切实际,tfrecored方式为我们带来了更大的便捷,同时还可以配合shuffe大大提高model的train效率。

    示例代def convert_tfrecord(data, label):

    """保存为tfrecord形式
        :param data:
        :param label:
        :return:
        """
        record_path = './resources/train.tfrecord'
        # 调用example和features函数将数据格式化保存起来
        cnt = 0
        writer = tf.python_io.TFRecordWriter(record_path)
        for d, s, l in zip(data[0], data[1], label):
            if cnt % 100 == 0:
                print('write example {}'.format(cnt))
            cnt += 1
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'sample': tf.train.Feature(int64_list=tf.train.Int64List(value=d)),
                        'score': tf.train.Feature(float_list=tf.train.FloatList(value=s)),
                        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[l]))
                    }
                )
            )
    
            writer.write(example.SerializeToString())
        writer.close()
        print('写入ok')
    
        # 读取,batch 取
        filename_queue = tf.train.string_input_producer([record_path],)
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
    features
    = tf.io.parse_single_example(serialized_example, features={ 'sample': tf.io.FixedLenFeature([9], tf.int64), 'score': tf.io.FixedLenFeature([9], tf.float32), 'label': tf.io.FixedLenFeature([1], tf.int64), }) is_batch = True if is_batch: batch_size = 3 min_after_dequeue = 10 capacity = min_after_dequeue + 3 * batch_size samples, scores, labels = tf.train.shuffle_batch([features['sample'], features['score'], features['label']], batch_size=batch_size, num_threads=3, capacity=capacity, min_after_dequeue=min_after_dequeue) with tf.compat.v1.Session() as sess: init_op = tf.initialize_all_variables() sess.run(init_op) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for i in range(1000): # 从会话中取出数据 sample, score, label = sess.run([samples, scores, labels]) print(sample) print(score) print('###########') coord.request_stop() coord.join(threads) print('ok')
  • 相关阅读:
    HTML 标题
    HTML 属性
    点云配准的端到端深度神经网络:ICCV2019论文解读
    人脸真伪验证与识别:ICCV2019论文解析
    人体姿态和形状估计的视频推理:CVPR2020论文解析
    FPGA最全科普总结
    深度人脸识别:CVPR2020论文要点
    视频教学动作修饰语:CVPR2020论文解析
    分层条件关系网络在视频问答VideoQA中的应用:CVPR2020论文解析
    慢镜头变焦:视频超分辨率:CVPR2020论文解析
  • 原文地址:https://www.cnblogs.com/demo-deng/p/13789061.html
Copyright © 2020-2023  润新知