• Keras用动态数据生成器(DataGenerator)和fitgenerator动态训练模型


     

     最近做Kaggle的图像分类比赛:RSNA Intracranial Hemorrhage Detection (https://www.kaggle.com/c/rsna-intracranial-hemorrhage-detection/overview)以及阅读Yolov3

    源码的时候接触到深度学习训练时一个有趣的技巧,那就是构造生成器generator 并且用Keras 的fit_generator来批量生成数据,释放内存,该方法适合于大规模数据集的训练。一个DataGenerator是keras的Sequence类的继承类,一般要包含__len__,__getitem__, on_epoch_end等方法,例如下面的批量图片数据生成器:

    class DataGenerator(keras.utils.Sequence):
          
          
          def __init__(self, list_IDs, labels, batch_size=1, img_size=(512, 512), 
                       img_dir, *args, **kwargs):
    
             """
                self.list_IDs:存放所有需要训练的图片文件名的列表。
                self.labels:记录图片标注的分类信息的pandas.DataFrame数据类型,已经预先给定。
                self.batch_size:每次批量生成,训练的样本大小。
                self.img_size:训练的图片尺寸。
                self.img_dir:图片在电脑中存放的路径。
          
          
             """
    
              
              self.list_IDs = list_IDs
              self.labels = labels
              self.batch_size = batch_size
              self.img_size = img_size
              self.img_dir = img_dir
              self.on_epoch_end()
    
          def __len__(self):
              
              """
                 返回生成器的长度,也就是总共分批生成数据的次数。
                 
              """
              return int(ceil(len(self.list_IDs) / self.batch_size))
    
         def __getitem__(self, index):
             
             """
                该函数返回每次我们需要的经过处理的数据。
             """
             
             indices = self.indices[index*self.batch_size:(index+1)*self.batch_size]
             list_IDs_temp = [self.list_IDs[k] for k in indices]
             X, Y = self.__data_generation(list_IDs_temp)
             return X, Y
    
         def on_epoch_end(self):
             
             """
                该函数将在训练时每一个epoch结束的时候自动执行,在这里是随机打乱索引次序以方便下一batch运行。
    
             """
             self.indices = np.arange(len(self.list_IDs))
             np.random.shuffle(self.indices)
    
         def __data_generation(self, list_IDs_temp):
    
            """
               给定文件名,生成数据。
            """
            X = np.empty((self.batch_size, *self.img_size, 1))
            Y = np.empty((self.batch_size, 6), dtype=np.float32)
    
           for i, ID in enumerate(list_IDs_temp):
           X[i,] = mpimg.imread(self.img_dir+ID+".png")
           Y[i,] = self.labels.loc[ID].values
    
           return X, Y

    有了这个生成器,我们就可以用fit_generator 方法进行训练,格式套路如下:

    model.fit_generator(generator,

    steps_per_epoch=...,

    epochs=...,

    verbose=...,

    callbacks=...,

    validation_data=...,

    validation_steps=...,

    validation_freq=...,

    class_weight=None=...,

    max_queue_size=...

    workers=...,

    use_multiprocessing=...,

    )

    除此以外我们还可以搞批量预测:

    model.predict_generator()

  • 相关阅读:
    基于lua语言实现面向对象编程
    一.Linux常用命令
    获取线程名称、设置线程名称、获取当前所有线程
    关系型数据库和非关系数据库区别
    Java基础类型之间的转换
    初始化 List 的几种方法
    谷歌浏览器打不开网页,但Opera可以打开网页
    遍历List和Map的几种方法
    java对数组进行排序
    MySQL实现事务隔离的原理:MVCC
  • 原文地址:https://www.cnblogs.com/szqfreiburger/p/11621261.html
Copyright © 2020-2023  润新知