• TensorFlow TFRecord封装不定长的序列数据(文本)


    TensorFlow TFRecord封装不定长的序列数据(文本)

    在实验室环境中,通常数据都是一次性导入内存的,然后使用手工写的数据mini-batch函数来切分数据,但是这样的做法在海量数据下显得不太合适:1)内存太小不足以将全部数据一次性导入;2)数据切分和模型训练之间无法异步,训练过程易受到数据mini-batch切分耗时阻塞。3)无法部署到分布式环境中去

    下面的代码片段采取了TFrecord的数据文件格式,并且支持不定长序列,支持动态填充,基本可以满足处理NLP等具有序列要求的任务需求。

    import tensorflow as tf
    
    
    def generate_tfrecords(tfrecod_filename):
        sequences = [[1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5],
                     [1], [2, 2], [3, 3, 3], [4, 4, 4, 4]]
        labels = [1, 2, 3, 4, 5, 1, 2, 3, 4]
    
        with tf.python_io.TFRecordWriter(tfrecod_filename) as f:
            for feature, label in zip(sequences, labels):
                frame_feature = list(map(lambda id: tf.train.Feature(int64_list=tf.train.Int64List(value=[id])), feature))
    
                example = tf.train.SequenceExample(
                    context=tf.train.Features(feature={
                        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}),
                    feature_lists=tf.train.FeatureLists(feature_list={
                        'sequence': tf.train.FeatureList(feature=frame_feature)
                    })
                )
                f.write(example.SerializeToString())
    
    
    
    def single_example_parser(serialized_example):
        context_features = {
            "label": tf.FixedLenFeature([], dtype=tf.int64)
        }
        sequence_features = {
            "sequence": tf.FixedLenSequenceFeature([], dtype=tf.int64)
        }
    
        context_parsed, sequence_parsed = tf.parse_single_sequence_example(
            serialized=serialized_example,
            context_features=context_features,
            sequence_features=sequence_features
        )
    
        labels = context_parsed['label']
        sequences = sequence_parsed['sequence']
        return sequences, labels
    
    def batched_data(tfrecord_filename, single_example_parser, batch_size, padded_shapes, num_epochs=1, buffer_size=1000):
        dataset = tf.data.TFRecordDataset(tfrecord_filename)
            .map(single_example_parser)
            .padded_batch(batch_size, padded_shapes=padded_shapes)
            .shuffle(buffer_size)
            .repeat(num_epochs)
        return dataset.make_one_shot_iterator().get_next()
    
    
    if __name__ == "__main__":
        def model(features, labels):
            return features, labels
    
    
        tfrecord_filename = 'test.tfrecord'
        generate_tfrecords(tfrecord_filename)
        out = model(*batched_data(tfrecord_filename, single_example_parser, 2, ([None], [])))
    
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                while not coord.should_stop():
                    print(sess.run(out))
    
            except tf.errors.OutOfRangeError:
                print("done training")
            finally:
                coord.request_stop()
            coord.join(threads)
    
    
  • 相关阅读:
    Windows Server 2008关闭internet explorer增强的安全配置
    【转载并整理】mysql分页方法
    Mysql:MyIsam和InnoDB的区别
    【转载】web网站css,js更新后客户浏览器缓存问题,需要刷新才能正常展示的解决办法
    【转载】java前后端 动静分离,JavaWeb项目为什么我们要放弃jsp?
    Redis命令汇总
    Redis介绍及安装
    【转载】Spring Cache介绍
    简单示例:Spring4 整合 单个Redis服务
    【转载整理】Hibernater的锁机制
  • 原文地址:https://www.cnblogs.com/crackpotisback/p/9013712.html
Copyright © 2020-2023  润新知