• tf.keras自定义损失函数


    自定义损失函数

    In statistics, the Huber loss is a loss function used in robust regression, that is less sensitive to outliers in data than the squared error loss. A variant for classification is also sometimes used.

    def huber_fn(y_true, y_pred):
        error = y_true - y_pred
        is_small_error = tf.abs(error) < 1
        squared_loss = tf.square(error) / 2
        linear_loss  = tf.abs(error) - 0.5
        return tf.where(is_small_error, squared_loss, linear_loss)
    

    注意,自定义损失函数的返回值是一个向量而不是损失平均值,每个元素对应一个实例。这样的好处是Keras可以通过class_weightsample_weight调整权重。

    huber_fn(y_valid, y_pred)
    <tf.Tensor: id=4894, shape=(3870, 1), dtype=float64, numpy=
    array([[0.10571115],
           [0.03953311],
           [0.02417886],
           ...,
           [0.00039475],
           [0.00245003],
           [0.12238744]])>
    

    导入损失函数

    model = keras.models.load_model("my_model_with_a_custom_loss.h5",
                                    custom_objects={"huber_fn": huber_fn})
    

    带参数的自定义损失函数

    def create_huber(threshold=1.0):
        def huber_fn(y_true, y_pred):
            error = y_true - y_pred
            is_small_error = tf.abs(error) < threshold
            squared_loss = tf.square(error) / 2
            linear_loss  = threshold * tf.abs(error) - threshold**2 / 2
            return tf.where(is_small_error, squared_loss, linear_loss)
        return huber_fn
    
    model.compile(loss=create_huber(2.0), optimizer="nadam", metrics=["mae"])
    

    导入模型的时候注意

    model = keras.models.load_model("my_model_with_a_custom_loss_threshold_2.h5",
                                    custom_objects={"huber_fn": create_huber(2.0)})
    

    导入的是带有参数的create_huber(2.0),而不是create_huber。如果想要保留参数设置,必须自定义

  • 相关阅读:
    HDU 1828 Picture (线段树:扫描线周长)
    [USACO18OPEN] Multiplayer Moo (并查集+维护并查集技巧)
    NOIP2016 天天爱跑步 (树上差分+dfs)
    NOIP2013 华容道 (棋盘建图+spfa最短路)
    NOIP2015 运输计划 (树上差分+二分答案)
    NOIP2018 前流水账
    luogu P2331 [SCOI2005]最大子矩阵
    luogu P2627 修剪草坪
    CF101D Castle
    luogu P2473 [SCOI2008]奖励关
  • 原文地址:https://www.cnblogs.com/yaos/p/14014163.html
Copyright © 2020-2023  润新知