• Keras函数式API介绍


    参考文献:Géron, Aurélien. Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow: Concepts, Tools, and Techniques to Build Intelligent Systems. O'Reilly Media, 2019.

    Keras的Sequential顺序模型可以快速搭建简易的神经网络,同时Keras也提供函数式API(Functional API)用于定制各种不同类型的网格结构。

    Concatenate

    在搭建Wide & Deep神经网格的时候,需要进行层融合(concatenation)。图中的融合层将Input层和hidden layer最后一层相加在一起。

    input_ = keras.layers.Input(shape=X_train.shape[1:])
    hidden1 = keras.layers.Dense(30, activation="relu")(input_)
    hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
    concat = keras.layers.concatenate([input_, hidden2])
    output = keras.layers.Dense(1)(concat)
    model = keras.models.Model(inputs=[input_], outputs=[output])
    

    Multi-inputs

    可以将输入特征先分成多组(可以有重叠部分),让它们分别通过神经网络中的不同路径。

    input_A = keras.layers.Input(shape=[5], name="wide_input")
    input_B = keras.layers.Input(shape=[6], name="deep_input")
    hidden1 = keras.layers.Dense(30, activation="relu")(input_B)
    hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
    concat = keras.layers.concatenate([input_A, hidden2])
    output = keras.layers.Dense(1, name="output")(concat)
    model = keras.models.Model(inputs=[input_A, input_B], outputs=[output])
    

    Multi-outputs

    input_A = keras.layers.Input(shape=[5], name="wide_input")
    input_B = keras.layers.Input(shape=[6], name="deep_input")
    hidden1 = keras.layers.Dense(30, activation="relu")(input_B)
    hidden2 = keras.layers.Dense(30, activation="relu")(hidden1)
    concat = keras.layers.concatenate([input_A, hidden2])
    output = keras.layers.Dense(1, name="main_output")(concat)
    aux_output = keras.layers.Dense(1, name="aux_output")(hidden2)
    model = keras.models.Model(inputs=[input_A, input_B],
                               outputs=[output, aux_output])
    

    每个output可以单独设置损失函数

    model.compile(loss=[“mse”,”mse”], loss_weights=[0.9, 0.1], optimizer=“sgd”)
    

    如果不设置的话,Keras默认使用相同的损失函数。训练中,Keras会单独计算两个损失函数,相加一起得到作为最后的损失值。

  • 相关阅读:
    微信小程序踩坑(二)——微信小程序recorderManager和innerAudioContext相关
    log4j:WARN Please initialize the log4j system properly解决办法
    如何跳过登录验证码
    Fiddler Mock长度变化的response不成功
    解决Chrome浏览器访问https提示“您的连接不是私密连接”的问题
    Fiddler抓不到https的解决办法
    互联网项目流程
    测试人员参与线下问题处理须知
    Mac下安装证书fiddlerRoot.cer
    MySQL存储过程中实现执行动态SQL语句
  • 原文地址:https://www.cnblogs.com/yaos/p/14014172.html
Copyright © 2020-2023  润新知