• Tensorflow datasets.shuffle repeat batch方法


    机器学习中数据读取是很重要的一个环节,TensorFlow也提供了很多实用的方法,为了避免以后时间久了又忘记,所以写下笔记以备日后查看。

    最普通的正常情况

    首先我们看看最普通的情况:

    # 创建0-10的数据集,每个batch取个数。
    dataset = tf.data.Dataset.range(10).batch(6)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    
    with tf.Session() as sess:
        for i in range(2):
            value = sess.run(next_element)
            print(value)
    

    输出结果

    [0 1 2 3 4 5]
    [6 7 8 9]
    

    由结果我们可以知道TensorFlow能很好地帮我们自动处理最后一个batch的数据。

    datasets.batch(batch_size)与迭代次数的关系

    但是如果上面for循环次数超过2会怎么样呢?也就是说如果 循环次数*批数量 > 数据集数量 会怎么样?我们试试看:

    dataset = tf.data.Dataset.range(10).batch(6)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    
    with tf.Session() as sess:
        >>==for i in range(3):==<<
            value = sess.run(next_element)
            print(value)
    

    输出结果

    [0 1 2 3 4 5]
    [6 7 8 9]
    ---------------------------------------------------------------------------
    OutOfRangeError                           Traceback (most recent call last)
    D:Continuumanaconda3libsite-packages	ensorflowpythonclientsession.py in _do_call(self, fn, *args)
       1277     try:
       
      ...
      ...省略若干信息...
      ...
      
    OutOfRangeError (see above for traceback): End of sequence
    	 [[Node: IteratorGetNext_64 = IteratorGetNext[output_shapes=[[?]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_28)]]
    

    可以知道超过范围了,所以报错了。

    datasets.repeat()

    为了解决上述问题,repeat方法登场。还是直接看例子吧:

    dataset = tf.data.Dataset.range(10).batch(6)
    dataset = dataset.repeat(2)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    
    with tf.Session() as sess:
        for i in range(4):
            value = sess.run(next_element)
            print(value)
    

    输出结果

    [0 1 2 3 4 5]
    [6 7 8 9]
    [0 1 2 3 4 5]
    [6 7 8 9]
    

    可以知道repeat其实就是将数据集重复了指定次数,上面代码将数据集重复了2次,所以这次即使for循环次数是4也依旧能正常读取数据,并且都能完整把数据读取出来。同理,如果把for循环次数设置为大于4,那么也还是会报错,这么一来,我每次还得算repeat的次数,岂不是很心累?所以更简便的办法就是对repeat方法不设置重复次数,效果见如下:

    dataset = tf.data.Dataset.range(10).batch(6)
    dataset = dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    
    with tf.Session() as sess:
        for i in range(6):
            value = sess.run(next_element)
            print(value)
    

    输出结果:

    [0 1 2 3 4 5]
    [6 7 8 9]
    [0 1 2 3 4 5]
    [6 7 8 9]
    [0 1 2 3 4 5]
    [6 7 8 9]
    

    此时无论for循环多少次都不怕啦~~

    datasets.shuffle(buffer_size)

    仔细看可以知道上面所有输出结果都是有序的,这在机器学习中用来训练模型是浪费资源且没有意义的,所以我们需要将数据打乱,这样每批次训练的时候所用到的数据集是不一样的,这样啊可以提高模型训练效果。

    另外shuffle前需要设置buffer_size:

    • 不设置会报错,
    • buffer_size=1:不打乱顺序,既保持原序
    • buffer_size越大,打乱程度越大,演示效果见如下代码:
    dataset = tf.data.Dataset.range(10).shuffle(2).batch(6)
    dataset = dataset.repeat(2)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    
    with tf.Session() as sess:
        for i in range(4):
            value = sess.run(next_element)
            print(value)
    

    输出结果:

    [1 0 2 4 3 5]
    [7 8 9 6]
    [1 2 3 4 0 6]
    [7 8 9 5]
    

    注意:shuffle的顺序很重要,一般建议是最开始执行shuffle操作,因为如果是先执行batch操作的话,那么此时就只是对batch进行shuffle,而batch里面的数据顺序依旧是有序的,那么随机程度会减弱。不信你看:

    dataset = tf.data.Dataset.range(10).batch(6).shuffle(10)
    dataset = dataset.repeat(2)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()
    
    with tf.Session() as sess:
        for i in range(4):
            value = sess.run(next_element)
            print(value)
    

    输出结果:

    [0 1 2 3 4 5]
    [6 7 8 9]
    [0 1 2 3 4 5]
    [6 7 8 9]
    



    MARSGGBO原创





    2018-8-5



  • 相关阅读:
    laravel 使用 php artisan make:model到指定目录(controller同理)
    Mysql常见的优化策略
    laravel路由别名
    laravel whereNotIn where子查詢
    phpstorm界面不停的indexing,不停的闪烁
    Linux下Redis开机自启(Centos6)
    数据结构常用算法
    困惑的前置操作与后置操作
    SSH框架整合中Hibernate实现Dao层常用结构
    过滤器与拦截器区别
  • 原文地址:https://www.cnblogs.com/marsggbo/p/9603789.html
Copyright © 2020-2023  润新知