• 【Keras案例学习】 sklearn包装器使用示范(mnist_sklearn_wrapper)


    import numpy as np 
    from keras.datasets import mnist
    from keras.models import Sequential
    from keras.layers import Dense, Dropout, Activation, Flatten
    from keras.layers import Convolution2D, MaxPooling2D
    from keras.utils import np_utils
    # sklean接口的包装器KerasClassifier,作为sklearn的分类器接口
    from keras.wrappers.scikit_learn import KerasClassifier
    # 穷搜所有特定的参数值选出最好的模型参数
    from sklearn.grid_search import GridSearchCV
    
    Using TensorFlow backend.
    
    # 类别的数目
    nb_classes = 10
    # 输入图像的维度
    img_rows, img_cols = 28, 28
    # 读取数据
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    # 读取的数据不包含通道维,因此shape为(60000,28,28)
    # 为了保持和后端tensorflow的数据格式一致,将数据补上通道维
    X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
    X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
    # 新的数据shape为 (60000,28,28,1), 1代表通道是1,也就是灰阶图片
    # 指明输入数据的大小,便于后面搭建网络的第一层传入该参数
    input_shape = (img_rows, img_cols, 1)
    # 数据类型改为float32,单精度浮点数
    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    # 数据归一化(图像数据常用)
    X_train /= 255
    X_test /= 255
    # 将类别标签转换为one-hot编码
    y_train = np_utils.to_categorical(y_train, nb_classes)
    y_test = np_utils.to_categorical(y_test, nb_classes)
    
    # 定义配置卷积网络模型的函数
    def make_model(dense_layer_sizes, nb_filters, nb_conv, nb_pool):
        model = Sequential()
        model.add(Convolution2D(nb_filters, nb_conv, nb_conv,
                                border_mode='valid',
                                input_shape=input_shape))
        model.add(Activation('relu'))
        model.add(Convolution2D(nb_filters, nb_conv, nb_conv))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(nb_pool, nb_pool)))
        model.add(Dropout(0.25))
        
        model.add(Flatten())
        for layer_size in dense_layer_sizes:
            model.add(Dense(layer_size))
        model.add(Activation('relu'))
        model.add(Dropout(0.5))
        model.add(Dense(nb_classes))
        model.add(Activation('softmax'))
        
        model.compile(loss='categorical_crossentropy',
                      optimizer='adadelta',
                      metrics=['accuracy'])
        return model 
    
    # 全连接层的备选参数列表
    dense_size_candidates = [[32], [64], [32, 32], [64, 64]]
    # 实现为Keras准备的sklearn分类器接口,创建一个分类器/评估器对象
    # 传入的参数为:
    # build_fn: callable function or class instance
    # **sk_params: model parameters & fitting parameters
    # 具体分析如下:
    # 传入的第一个参数(build_fn)为可回调的函数,该函数建立、配置并返回一个Keras model,
    # 该model将被用来训练/预测,这里我们传入了刚刚定义好的make_model函数
    # 传入的第二个参数(**sk_params)为关键字参数(关键字参数在函数内部自动组装为一个dict),
    # 既可以是模型的参数,也可以是训练的参数,合法的模型参数就是build_fn的参数,
    # 注意,像所有sklearn中其他的评估器(estimator)一样,build_fn应当为其参数提供默认值,
    # 以便我们在建立estimator的时候不用向sk_params传入任何值。
    # sk_params也可以接收用来调用fit/predict/predict_proba/score方法的参数,
    # 例如'nb_epoch','batch_size'
    # fit/predict/predict_proba/score方法的参数将会优先从传入fit/predict/predict_proba/score
    # 的字典参数中选择,其次才从传入sk_params的参数中选,最后才选择keras的Sequential模型的默认参数中选择
    # 这里我们传入了用于调用fit方法的batch_size参数
    my_classifier = KerasClassifier(make_model, batch_size=32)
    # 当调用sklearn的grid_search接口时,合法的可调参数就是传给sk_params的参数,包括训练参数
    # 换句话说,就是可以用grid_search来选择最佳的batch_size/nb_epoch,或者其他的一些模型参数
    
    # GridSearchCV类,穷搜(Exhaustive search)评估器中所有特定的参数,
    # 其重要的两类方法为fit和predict
    # 传入参数为评估器对象my_classifier,由每一个grid point实例化一个estimator
    # 参数网格param_grid,类型为dict,需要尝试的参数名称以及对应的数值
    # 评估方式scoring,这里采用对数损失来评估
    validator = GridSearchCV(my_classifier,
                             param_grid={'dense_layer_sizes': dense_size_candidates,
                                         'nb_epoch': [3,6],
                                         'nb_filters': [8],
                                         'nb_conv': [3],
                                         'nb_pool': [2]},
                             scoring='log_loss')
    # 根据各个参数值的不同组合在(X_train, y_train)上训练模型
    validator.fit(X_train, y_train)
    # 打印出训练过程中找到的最佳参数
    print('Yhe parameters of the best model are: ')
    print(validator.best_params_)
    
    Epoch 1/3
    40000/40000 [==============================] - 14s - loss: 0.8058 - acc: 0.7335    
    Epoch 2/3
    40000/40000 [==============================] - 10s - loss: 0.4620 - acc: 0.8545    
    Epoch 3/3
    40000/40000 [==============================] - 10s - loss: 0.3958 - acc: 0.8747    
    19776/20000 [============================>.] - ETA: 0sEpoch 1/3
    40000/40000 [==============================] - 11s - loss: 0.9589 - acc: 0.6804    
    Epoch 2/3
    40000/40000 [==============================] - 10s - loss: 0.5885 - acc: 0.8116    
    Epoch 3/3
    40000/40000 [==============================] - 10s - loss: 0.5021 - acc: 0.8429    
    19488/20000 [============================>.] - ETA: 0sEpoch 1/3
    40000/40000 [==============================] - 11s - loss: 0.9141 - acc: 0.6958    
    Epoch 2/3
    40000/40000 [==============================] - 10s - loss: 0.5716 - acc: 0.8136    
    Epoch 3/3
    40000/40000 [==============================] - 10s - loss: 0.4515 - acc: 0.8547    
    19584/20000 [============================>.] - ETA: 0sEpoch 1/6
    40000/40000 [==============================] - 11s - loss: 0.8968 - acc: 0.6983    
    Epoch 2/6
    40000/40000 [==============================] - 10s - loss: 0.5692 - acc: 0.8130    
    Epoch 3/6
    40000/40000 [==============================] - 10s - loss: 0.4600 - acc: 0.8494    
    Epoch 4/6
    40000/40000 [==============================] - 10s - loss: 0.4091 - acc: 0.8694    
    Epoch 5/6
    40000/40000 [==============================] - 10s - loss: 0.3717 - acc: 0.8790    
    Epoch 6/6
    40000/40000 [==============================] - 10s - loss: 0.3461 - acc: 0.8898    
    20000/20000 [==============================] - 1s     
    Epoch 1/6
    40000/40000 [==============================] - 11s - loss: 0.8089 - acc: 0.7310    
    Epoch 2/6
    40000/40000 [==============================] - 10s - loss: 0.4770 - acc: 0.8498    
    Epoch 3/6
    40000/40000 [==============================] - 10s - loss: 0.4086 - acc: 0.8704    
    Epoch 4/6
    40000/40000 [==============================] - 10s - loss: 0.3657 - acc: 0.8860    
    Epoch 5/6
    40000/40000 [==============================] - 10s - loss: 0.3383 - acc: 0.8938    
    Epoch 6/6
    40000/40000 [==============================] - 10s - loss: 0.3164 - acc: 0.9027    
    19520/20000 [============================>.] - ETA: 0sEpoch 1/6
    40000/40000 [==============================] - 11s - loss: 0.8393 - acc: 0.7214    
    Epoch 2/6
    40000/40000 [==============================] - 10s - loss: 0.5132 - acc: 0.8379    
    Epoch 3/6
    40000/40000 [==============================] - 10s - loss: 0.4331 - acc: 0.8635    
    Epoch 4/6
    40000/40000 [==============================] - 10s - loss: 0.3813 - acc: 0.8808    
    Epoch 5/6
    40000/40000 [==============================] - 10s - loss: 0.3530 - acc: 0.8902    
    Epoch 6/6
    40000/40000 [==============================] - 10s - loss: 0.3278 - acc: 0.8986    
    19936/20000 [============================>.] - ETA: 0sEpoch 1/3
    40000/40000 [==============================] - 11s - loss: 0.5975 - acc: 0.8099    
    Epoch 2/3
    40000/40000 [==============================] - 10s - loss: 0.3181 - acc: 0.9048    
    Epoch 3/3
    40000/40000 [==============================] - 10s - loss: 0.2673 - acc: 0.9199    
    19808/20000 [============================>.] - ETA: 0sEpoch 1/3
    40000/40000 [==============================] - 11s - loss: 0.6155 - acc: 0.8040    
    Epoch 2/3
    40000/40000 [==============================] - 11s - loss: 0.3500 - acc: 0.8951    
    Epoch 3/3
    40000/40000 [==============================] - 11s - loss: 0.2864 - acc: 0.9156    
    20000/20000 [==============================] - 1s     
    Epoch 1/3
    40000/40000 [==============================] - 11s - loss: 0.7519 - acc: 0.7560    
    Epoch 2/3
    40000/40000 [==============================] - 10s - loss: 0.4660 - acc: 0.8580    
    Epoch 3/3
    40000/40000 [==============================] - 10s - loss: 0.3553 - acc: 0.8936    
    19776/20000 [============================>.] - ETA: 0sEpoch 1/6
    40000/40000 [==============================] - 11s - loss: 0.5869 - acc: 0.8162    
    Epoch 2/6
    40000/40000 [==============================] - 11s - loss: 0.3279 - acc: 0.9014    
    Epoch 3/6
    40000/40000 [==============================] - 11s - loss: 0.2725 - acc: 0.9187    
    Epoch 4/6
    40000/40000 [==============================] - 11s - loss: 0.2366 - acc: 0.9291    
    Epoch 5/6
    40000/40000 [==============================] - 11s - loss: 0.2102 - acc: 0.9386    
    Epoch 6/6
    40000/40000 [==============================] - 16s - loss: 0.1954 - acc: 0.9423    
    19840/20000 [============================>.] - ETA: 0sEpoch 1/6
    40000/40000 [==============================] - 11s - loss: 0.5526 - acc: 0.8262    
    Epoch 2/6
    40000/40000 [==============================] - 11s - loss: 0.2903 - acc: 0.9142    
    Epoch 3/6
    40000/40000 [==============================] - 11s - loss: 0.2361 - acc: 0.9302    
    Epoch 4/6
    40000/40000 [==============================] - 11s - loss: 0.2064 - acc: 0.9396    
    Epoch 5/6
    40000/40000 [==============================] - 10s - loss: 0.1886 - acc: 0.9443    
    Epoch 6/6
    40000/40000 [==============================] - 10s - loss: 0.1755 - acc: 0.9496    
    19808/20000 [============================>.] - ETA: 0sEpoch 1/6
    40000/40000 [==============================] - 11s - loss: 0.7275 - acc: 0.7677    
    Epoch 2/6
    40000/40000 [==============================] - 10s - loss: 0.4141 - acc: 0.8772    
    Epoch 3/6
    40000/40000 [==============================] - 10s - loss: 0.3136 - acc: 0.9056    
    Epoch 4/6
    40000/40000 [==============================] - 10s - loss: 0.2651 - acc: 0.9210    
    Epoch 5/6
    40000/40000 [==============================] - 10s - loss: 0.2363 - acc: 0.9306    
    Epoch 6/6
    40000/40000 [==============================] - 10s - loss: 0.2092 - acc: 0.9380    
    19552/20000 [============================>.] - ETA: 0sEpoch 1/3
    40000/40000 [==============================] - 12s - loss: 0.7849 - acc: 0.7334    
    Epoch 2/3
    40000/40000 [==============================] - 11s - loss: 0.4506 - acc: 0.8587    
    Epoch 3/3
    40000/40000 [==============================] - 11s - loss: 0.3741 - acc: 0.8813    
    19872/20000 [============================>.] - ETA: 0sEpoch 1/3
    40000/40000 [==============================] - 12s - loss: 0.8744 - acc: 0.7068    
    Epoch 2/3
    40000/40000 [==============================] - 11s - loss: 0.5231 - acc: 0.8312    
    Epoch 3/3
    40000/40000 [==============================] - 11s - loss: 0.4305 - acc: 0.8635    
    19552/20000 [============================>.] - ETA: 0sEpoch 1/3
    40000/40000 [==============================] - 12s - loss: 0.7567 - acc: 0.7473    
    Epoch 2/3
    40000/40000 [==============================] - 11s - loss: 0.4200 - acc: 0.8685    
    Epoch 3/3
    40000/40000 [==============================] - 11s - loss: 0.3604 - acc: 0.8887    
    19712/20000 [============================>.] - ETA: 0sEpoch 1/6
    40000/40000 [==============================] - 12s - loss: 0.7111 - acc: 0.7676    
    Epoch 2/6
    40000/40000 [==============================] - 11s - loss: 0.4243 - acc: 0.8669    
    Epoch 3/6
    40000/40000 [==============================] - 11s - loss: 0.3638 - acc: 0.8873    
    Epoch 4/6
    40000/40000 [==============================] - 11s - loss: 0.3223 - acc: 0.8995    
    Epoch 5/6
    40000/40000 [==============================] - 11s - loss: 0.2994 - acc: 0.9073    
    Epoch 6/6
    40000/40000 [==============================] - 11s - loss: 0.2823 - acc: 0.9135    
    20000/20000 [==============================] - 2s     
    Epoch 1/6
    40000/40000 [==============================] - 12s - loss: 0.7588 - acc: 0.7513    
    Epoch 2/6
    40000/40000 [==============================] - 11s - loss: 0.4568 - acc: 0.8570    
    Epoch 3/6
    40000/40000 [==============================] - 12s - loss: 0.3757 - acc: 0.8819    
    Epoch 4/6
    40000/40000 [==============================] - 11s - loss: 0.3256 - acc: 0.8969    
    Epoch 5/6
    40000/40000 [==============================] - 11s - loss: 0.2996 - acc: 0.9060    
    Epoch 6/6
    40000/40000 [==============================] - 11s - loss: 0.2702 - acc: 0.9146    
    19904/20000 [============================>.] - ETA: 0sEpoch 1/6
    40000/40000 [==============================] - 12s - loss: 0.7798 - acc: 0.7464    
    Epoch 2/6
    40000/40000 [==============================] - 11s - loss: 0.4625 - acc: 0.8571    
    Epoch 3/6
    40000/40000 [==============================] - 11s - loss: 0.3869 - acc: 0.8814    
    Epoch 4/6
    40000/40000 [==============================] - 11s - loss: 0.3429 - acc: 0.8959    
    Epoch 5/6
    40000/40000 [==============================] - 11s - loss: 0.3143 - acc: 0.9035    
    Epoch 6/6
    40000/40000 [==============================] - 11s - loss: 0.2889 - acc: 0.9122    
    19840/20000 [============================>.] - ETA: 0sEpoch 1/3
    40000/40000 [==============================] - 13s - loss: 0.5828 - acc: 0.8161    
    Epoch 2/3
    40000/40000 [==============================] - 11s - loss: 0.3009 - acc: 0.9099    
    Epoch 3/3
    40000/40000 [==============================] - 11s - loss: 0.2393 - acc: 0.9291    
    19680/20000 [============================>.] - ETA: 0sEpoch 1/3
    40000/40000 [==============================] - 12s - loss: 0.5584 - acc: 0.8246    
    Epoch 2/3
    40000/40000 [==============================] - 12s - loss: 0.2862 - acc: 0.9152    
    Epoch 3/3
    40000/40000 [==============================] - 11s - loss: 0.2334 - acc: 0.9319    
    19488/20000 [============================>.] - ETA: 0sEpoch 1/3
    40000/40000 [==============================] - 13s - loss: 0.6253 - acc: 0.8020    
    Epoch 2/3
    40000/40000 [==============================] - 11s - loss: 0.3054 - acc: 0.9093    
    Epoch 3/3
    40000/40000 [==============================] - 12s - loss: 0.2463 - acc: 0.9278    
    19808/20000 [============================>.] - ETA: 0sEpoch 1/6
    40000/40000 [==============================] - 13s - loss: 0.5753 - acc: 0.8200    
    Epoch 2/6
    40000/40000 [==============================] - 12s - loss: 0.2827 - acc: 0.9170    
    Epoch 3/6
    40000/40000 [==============================] - 11s - loss: 0.2217 - acc: 0.9339    
    Epoch 4/6
    40000/40000 [==============================] - 11s - loss: 0.1863 - acc: 0.9455    
    Epoch 5/6
    40000/40000 [==============================] - 12s - loss: 0.1663 - acc: 0.9516    
    Epoch 6/6
    40000/40000 [==============================] - 12s - loss: 0.1535 - acc: 0.9550    
    19680/20000 [============================>.] - ETA: 0sEpoch 1/6
    40000/40000 [==============================] - 13s - loss: 0.5670 - acc: 0.8247    
    Epoch 2/6
    40000/40000 [==============================] - 12s - loss: 0.2728 - acc: 0.9204    
    Epoch 3/6
    40000/40000 [==============================] - 12s - loss: 0.2134 - acc: 0.9383    
    Epoch 4/6
    40000/40000 [==============================] - 12s - loss: 0.1890 - acc: 0.9459    
    Epoch 5/6
    40000/40000 [==============================] - 12s - loss: 0.1695 - acc: 0.9501    
    Epoch 6/6
    40000/40000 [==============================] - 12s - loss: 0.1570 - acc: 0.9535    
    19712/20000 [============================>.] - ETA: 0sEpoch 1/6
    40000/40000 [==============================] - 13s - loss: 0.6227 - acc: 0.7986    
    Epoch 2/6
    40000/40000 [==============================] - 12s - loss: 0.3322 - acc: 0.9007    
    Epoch 3/6
    40000/40000 [==============================] - 12s - loss: 0.2469 - acc: 0.9258    
    Epoch 4/6
    40000/40000 [==============================] - 12s - loss: 0.2029 - acc: 0.9409    
    Epoch 5/6
    40000/40000 [==============================] - 12s - loss: 0.1748 - acc: 0.9496    
    Epoch 6/6
    40000/40000 [==============================] - 12s - loss: 0.1558 - acc: 0.9542    
    19872/20000 [============================>.] - ETA: 0sEpoch 1/6
    60000/60000 [==============================] - 19s - loss: 0.4922 - acc: 0.8482    
    Epoch 2/6
    60000/60000 [==============================] - 24s - loss: 0.2342 - acc: 0.9318    
    Epoch 3/6
    60000/60000 [==============================] - 24s - loss: 0.1843 - acc: 0.9485    
    Epoch 4/6
    60000/60000 [==============================] - 25s - loss: 0.1556 - acc: 0.9549    
    Epoch 5/6
    60000/60000 [==============================] - 24s - loss: 0.1450 - acc: 0.9581    
    Epoch 6/6
    60000/60000 [==============================] - 25s - loss: 0.1312 - acc: 0.9624    
    Yhe parameters of the best model are: 
    {'nb_conv': 3, 'nb_epoch': 6, 'nb_pool': 2, 'dense_layer_sizes': [64, 64], 'nb_filters': 8}
    
    # validator.best_estimator_返回sklearn-warpped版本的最佳模型
    # validator.best_estimator_.model返回未包装的最佳模型
    best_model = validator.best_estimator_.model
    # 度量值的名称
    metric_names = best_model.metrics_names 
    # metric_names = ['loss', 'acc']
    # 度量值的数值
    metric_values = best_model.evaluate(X_test, y_test)
    # metric_values = [0.0550, 0.9826]
    print()
    for metric, value in zip(metric_names, metric_values):
        print(metric, ': ', value)
    
     9984/10000 [============================>.] - ETA: 0s
    loss :  0.0550105490824
    acc :  0.9826
    
    
    
    
    
    
  • 相关阅读:
    Markdown学习笔记
    带下划线点域名解析失败
    前端工程师学习之路
    Java 调用 WebService 客户端代码 含通过代理调用
    MySQL 日期函数 时间函数 总结 (MySQL 5_X)
    Apache、Tomcat整合环境搭建
    201671010142 <java程序设计>初次学习心得与感悟
    201671010142 Java基本程序设计结构学习的感悟
    201671010142.第五章的学习总结
    201671010142 继承定义与使用 感悟与总结
  • 原文地址:https://www.cnblogs.com/surfzjy/p/6445404.html
Copyright © 2020-2023  润新知