• 序贯模型


     

    模型搭建

    举一个最简单的MLP例子,这下面我们添加的都是全连接层

    from keras.models import Sequential

    from keras.layers import Dense, Activation

    model = Sequential()  #序贯模型

    model.add(Dense(units=64, input_dim=100))

    model.add(Activation("relu"))     

    model.add(Dense(units=10))

    model.add(Activation("softmax"))

    #或者使用一次性搭建的方式

    model = Sequential([Dense(32, units=784),Activation('relu'),Dense(10),Activation('softmax')])

    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])   #通过compile来编译模型

    from keras.optimizers import SGD  #定制损失函数。Keras里也封装好了很多优化器和损失函数

    model.compile(loss='categorical_crossentropy', optimizer=SGD(lr=0.01, momentum=0.9, nesterov=True))

    输入数据并训练

    注意:batch_size太大可能不能收敛到最低点,batch_size太小测试的准确率会剧烈震荡

    1)model.fit(x_train, y_train, epochs=5, batch_size=32)

    2)model.train_on_batch(x_batch, y_batch)  #自己定义batch训练

    3)如果你的数据量很大,你可能要用到fit_generator

    def generate_arrays_from_file(path):

        while 1:

            f = open(path)

            for line in f:

                x, y = process_line(line)

                img = load_images(x)

                yield (img, y)

            f.close()

    model.fit_generator(generate_arrays_from_file('/my_file.txt'), samples_per_epoch=10000, nb_epoch=10)

    测试集评估与预测

    在测试集上评估效果

    loss_and_metrics = model.evaluate(x_test, y_test, batch_size=128)

    实际预测

    classes = model.predict(x_test, batch_size=128)

    优化器optimizer、损失函数loss、评估指标metrics

    # 多分类问题

    model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

    # 二分类问题

    model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])

    # 回归问题

    model.compile(optimizer='rmsprop', loss='mse')

    # 自定义metrics

    import keras.backend as K

    def mean_pred(y_true, y_pred):

        return K.mean(y_pred)

    model.compile(optimizer='rmsprop',

                  loss='binary_crossentropy',

                  metrics=['accuracy', mean_pred])

     

  • 相关阅读:
    修改python的pip下载源
    MySQL_Sql_打怪升级_进阶篇_进阶12: DDL常见数据类型
    【Xshell】SFTP子系统申请已拒绝,请确保SSH连接的SFTP子系统设置有效
    MySQL_Sql_打怪升级_进阶篇_进阶11: DDL数据定义语言
    MySQL_Sql_打怪升级_进阶篇_进阶10: DML数据操纵语言
    MySQL_Sql_打怪升级_进阶篇_进阶9:联合查询
    MySQL_Sql_打怪升级_进阶篇_ 进阶8:分页查询
    MySQL_Sql_打怪升级_进阶篇_进阶7:子查询
    MySQL_Sql_打怪升级_进阶篇_进阶6:连接查询
    Ubuntu出现E: Failed to fetch
  • 原文地址:https://www.cnblogs.com/yongfuxue/p/10095895.html
Copyright © 2020-2023  润新知