• tensorflow2.0 使用fit实现复杂自定义loss函数


    import tensorflow as tf
    from tensorflow.python.keras import backend as K
    from tensorflow.python.keras import layers as KL
    from tensorflow.python.keras import models as KM
    import numpy as np
    
    class ComplicatedLoss(KL.Layer):
        def __init__(self, **kwargs):
            super(ComplicatedLoss, self).__init__(**kwargs)
        def call(self, inputs, **kwargs):
                # 父类KL.Layer的call方法明确要求inputs为一个tensor,或者包含多个tensor的列表/元组        这里为多个tensor组成的列表        """        # 解包入参
             y_true, y_weight, y_pred = inputs        # 复杂的损失函数
             bce_loss = K.binary_crossentropy(y_true, y_pred)
             wbce_loss = K.mean(bce_loss * y_weight)        # 把自定义的loss添加进层使其生效
             self.add_loss(wbce_loss, inputs=True)        # 将每个loss加入metric方便在KERAS的进度条上实时追踪
             self.add_metric(wbce_loss, aggregation="mean", name="wbce_loss")
             self.add_metric(bce_loss, aggregation="mean", name="bce_loss")
             return wbce_loss
    
    def my_model():
    # input layers
        input_img = KL.Input([32, 32, 3], name="img1")
        input_lbl = KL.Input([32, 32, 1], name="lbl")
        input_weight = KL.Input([32, 32, 1], name="weight")
        predict = KL.Conv2D(2, [1, 1], padding="same")(input_img)
        my_loss = ComplicatedLoss()([input_lbl, input_weight, predict])
        model = KM.Model(inputs=[input_img, input_lbl, input_weight], outputs=[predict, my_loss])
        model.compile(optimizer="adam")
        return model
    
    def get_fake_dataset():
        def map_fn(img, lbl, weight):
            inputs = {"img1": img, "lbl": lbl, "weight": weight}
            # inputs = [img, lbl, weight]
            targets = {}
            return inputs, targets
        fake_imgs = np.ones([100, 32, 32, 3])
        fake_lbls = np.ones([100, 32, 32, 1])
        fake_weights = np.zeros([100, 32, 32, 1])
        fake_dataset = tf.data.Dataset.from_tensor_slices((fake_imgs, fake_lbls, fake_weights)).map(map_fn).batch(10)
        return fake_dataset
    if __name__ == '__main__':
        model = my_model()
        my_dataset = get_fake_dataset()
        model.fit(my_dataset,epochs=2)
  • 相关阅读:
    生成淘宝在线旺旺页面入口
    IE6下的fixed实现
    HTML和XHTML的区别
    各大浏览器内核介绍(Rendering Engine)
    [导入]从架构设计到系统实施——基于.NET 3.0的全新企业应用系列课程(5):设计基于WPF的客户端.zip(6.98 MB)
    Java核心类库——java中的包装类
    Java语言基础——运算符
    Java核心类库——集合的迭代(遍历) Iterator接口
    Java语言基础——循环控制语句while for
    Java语言基础——方法
  • 原文地址:https://www.cnblogs.com/cxhzy/p/16311367.html
Copyright © 2020-2023  润新知