• 使用Keras训练大规模数据集


             官方提供的.flow_from_directory(directory)函数可以读取并训练大规模训练数据,基本可以满足大部分需求。但是在有些场合下,需要自己读取大规模数据以及对应标签,下面提供一种方法。

    步骤0:导入相关

    import random
    import numpy as np
    from keras.preprocessing.image import load_img,img_to_array
    from keras.preprocessing.image import ImageDataGenerator
    from keras.models import Model

    步骤1:准备数据

    #训练集样本路径
    train_X = ["train/cat_1.jpg",
               "train/cat_2.jpg",
               "train/cat_3.jpg",
               "train/dog_1.jpg",
               "train/dog_2.jpg",
               "train/dog_3.jpg"]
    #验证集样本路径
    val_X =   ["val/cat_1.jpg",
               "val/cat_2.jpg",
               "val/cat_3.jpg",
               "val/dog_1.jpg",
               "val/dog_2.jpg",
               "val/dog_3.jpg"]
    
    # 根据图片路径获取图片标签
    def get_img_label(img_paths):
        img_labels = []
        
        for img_path in img_paths:     
            animal = img_path.split("/")[-1].split('_')[0]
            if animal=='cat':
                img_labels.append(0)
            else:
                img_labels.append(1)
            
        return img_labels 
    
    # 读取图片
    def load_batch_image(img_path, train_set = True, target_size=(224, 224)):
        im = load_img(img_path, target_size=target_size)
        if train_set:
            return img_to_array(im) #converts image to numpy array
        else:
            return img_to_array(im)/255.0
    # 建立一个数据迭代器
    def GET_DATASET_SHUFFLE(X_samples, batch_size, train_set = True):
        random.shuffle(X_samples)
            
        batch_num = int(len(X_samples) / batch_size)
        max_len = batch_num * batch_size
        X_samples = np.array(X_samples[:max_len])
        y_samples = get_img_label(X_samples)
        print(X_samples.shape)
         
        X_batches = np.split(X_samples, batch_num)
        y_batches = np.split(y_samples, batch_num)
    
        for i in range(len(X_batches)):
            if train_set:
                x = np.array(list(map(load_batch_image, X_batches[i], [True for _ in range(batch_size)])))
            else:
                x = np.array(list(map(load_batch_image, X_batches[i], [False for _ in range(batch_size)])))
            #print(x.shape)
            y = np.array(y_batches[i])
            yield x,y

    步骤2:对训练数据进行数据增强处理

    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        rotation_range=10,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)        
            

    步骤3:定义模型

    model = Model(...)

    步骤4:模型训练

    n_epoch = 12
    batch_size = 16
    for e in range(n_epoch):
        print("epoch", e)
        batch_num = 0
        loss_sum=np.array([0.0,0.0])
        for X_train, y_train in GET_DATASET_SHUFFLE(train_X, batch_size, True): # chunks of 100 images 
            for X_batch, y_batch in train_datagen.flow(X_train, y_train, batch_size=batch_size): # chunks of 32 samples
                loss = model.train_on_batch(X_batch, y_batch)
                loss_sum += loss 
                batch_num += 1
                break #手动break
            if batch_num%200==0:
                print("epoch %s, batch %s: train_loss = %.4f, train_acc = %.4f"%(e, batch_num, loss_sum[0]/200, loss_sum[1]/200))
                loss_sum=np.array([0.0,0.0])
        res = model.evaluate_generator(GET_DATASET_SHUFFLE(val_X, batch_size, False),int(len(val_X)/batch_size))
        print("val_loss = %.4f, val_acc = %.4f: "%( res[0], res[1]))
    
        model.save("weight.h5")

    另外,如果在训练的时候不需要做数据增强处理,那么训练就更加简单了,如下:

    model.fit_generator(
      GET_DATASET_SHUFFLE(train_X, batch_size, True),
      epochs=10,
      steps_per_epoch=int(len(train_X)/batch_size))

    参考文献:

    Training on Large Scale Image Datasets with Keras

  • 相关阅读:
    jQuery选择器大全(48个代码片段+21幅图演示)
    抽象和模型
    什么叫做精通,我来给大家解释一下
    设置浏览器固定大小的固定位置的方法
    selenium对浏览器属性操作的方法
    selenium 最大化浏览器是解决浏览器和驱动不匹配的方法如下
    java selenium手动最大化chrome浏览器的方法
    java selenium启动火狐浏览器报错:Cannot find firefox binary in PATH. Make sure firefox is installed. OS appears to be: VISTA Build info: version: '3.8.1', revision: '6e95a6684b', time: '2017-12-01T19:05:14.666Z
    mygenerator().next() AttributeError: 'generator' object has no attribute 'next'
    马的遍历 搜索
  • 原文地址:https://www.cnblogs.com/hejunlin1992/p/9371078.html
Copyright © 2020-2023  润新知