• tensorflow基础【7】-loss function


     cross entropy

    交叉熵,tensorflow  对 cross entropy 进行了集成:

    1. 二分类和多分类公式集成,共用一个 API;

    p(x) 真实标签,q(x) 预测概率;

    2. 把 sigmoid 、softmax 等集成到 cross entropy 中;

    正常情况下,神经网络最后的输出需要通过 softmax 转换成概率,然后再套用公式计算交叉熵,tf 的集成 API 直接输入神经网络的输出即可 

    tf.nn.softmax_cross_entropy_with_logits

    集成了 softmax 和 cross entropy 的 API

    def softmax_cross_entropy_with_logits(
        _sentinel=None,  # pylint: disable=invalid-name
        labels=None,
        logits=None,
        dim=-1,
        name=None,
        axis=None)

    示例

    #our NN's output
    logits = tf.constant([[1.0,2.0,3.0],[1.0,2.0,3.0],[1.0,2.0,3.0]])
    #step1:do softmax
    y = tf.nn.softmax(logits)
    
    #true label
    y_= tf.constant([[0.0,0.0,1.0],[0.0,0.0,1.0],[0.0,0.0,1.0]])
    
    #step2:do cross_entropy
    cross_entropy = -tf.reduce_sum(y_*tf.log(y))
    
    #do cross_entropy just one step
    cross_entropy2 = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))   # dont forget tf.reduce_sum()!!
    
    with tf.Session() as sess:
        softmax_value=sess.run(y)
        c_e = sess.run(cross_entropy)
        c_e2 = sess.run(cross_entropy2)
        print(softmax_value)
        print(c_e)      # 1.222818
        print(c_e2)     # 1.2228179

    可以看到手动计算 和 API 计算的结果是一样的

    tf.nn.sparse_softmax_cross_entropy_with_logits

    API 参数同上;

    sparse,稀疏编码,把类别进行稀疏编码,如共 3 个类别,样本属于第 2 个,则需要编码为 [0,1,0];    【对实际 label 的 sparse】

    集成了 稀疏编码、softmax 和交叉熵;

    # our NN's output
    logits = tf.constant([[1.0,2.0,3.0],[1.0,2.0,3.0],[1.0,2.0,3.0]])
    
    # true label
    # 注意这里标签必须是浮点数,不然在后面计算tf.multiply时就会因为类型不匹配tf_log的float32数据类型而出错
    y_= tf.constant([[0,0,1.0],[0,0,1.0],[0,0,1.0]])     # 这个是稀疏的标签
    
    # 手算交叉熵
    y = tf.nn.softmax(logits)
    tf_log = tf.log(y)
    pixel_wise_mult = tf.multiply(y_,tf_log)
    cross_entropy = -tf.reduce_sum(pixel_wise_mult)
    
    #将标签稠密化
    dense_y = tf.argmax(y_,1)       # [2 2 2]
    cross_entropy2_step1 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=dense_y,logits=logits)
    cross_entropy2_step2 = tf.reduce_sum(cross_entropy2_step1)
    
    with tf.Session() as sess:
        cross_entropy_value=sess.run(cross_entropy)
        sparse_cross_entropy2_step2_value=sess.run([cross_entropy2_step2])
        print(sess.run(dense_y))                        # [2 2 2]
        print("step4:cross_entropy result=
    %s
    "%(cross_entropy_value))                        # 1.222818
        print("Function(tf.reduce_sum) result=
    %s
    "%(sparse_cross_entropy2_step2_value))      # 1.2228179

    tf.nn.sigmoid_cross_entropy_with_logits

    API 参数同上;

    这个 API 适用于 一个样本有多个 label 的情况,如在目标检测中,一张图像上可能有 猫,可能有狗,输出的 label 可能为 [0,1,1,0];

    它的本质不是多分类,而是多个二分类;  

    def sigmoid(x):
        return 1.0/(1+np.exp(-x))
    
    labels = np.array([[1.,0.,0.],[0.,1.,0.],[0.,0.,1.]])
    logits = np.array([[11.,8.,7.],[10.,14.,3.],[1.,2.,4.]])
    y_pred = sigmoid(logits)
    prob_error1 = -labels * np.log(y_pred) - (1 - labels) * np.log(1 - y_pred)
    
    labels1 = np.array([[0.,1.,0.],[1.,1.,0.],[0.,0.,1.]])  # 不一定只属于一个类别
    logits1 = np.array([[1.,8.,7.],[10.,14.,3.],[1.,2.,4.]])
    y_pred1 = sigmoid(logits1)
    prob_error11 = -labels1 * np.log(y_pred1) - (1 - labels1) * np.log(1 - y_pred1)
    
    with tf.Session() as sess:
        print(prob_error1)
        # [[1.67015613e-05 8.00033541e+00 7.00091147e+00]
        #  [1.00000454e+01 8.31528373e-07 3.04858735e+00]
        #  [1.31326169e+00 2.12692801e+00 1.81499279e-02]]
        print(sess.run(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,logits=logits)))
        # [[1.67015613e-05 8.00033541e+00 7.00091147e+00]
        #  [1.00000454e+01 8.31528373e-07 3.04858735e+00]
        #  [1.31326169e+00 2.12692801e+00 1.81499279e-02]]
        print(prob_error11)
        # [[1.31326169e+00 3.35406373e-04 7.00091147e+00]
        #  [4.53988992e-05 8.31528373e-07 3.04858735e+00]
        #  [1.31326169e+00 2.12692801e+00 1.81499279e-02]]
        print(sess.run(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels1,logits=logits1)))
        ### 同上

    tf.nn.weighted_cross_entropy_with_logits

    它是 sigmoid_cross_entropy_with_logits 的扩展

    def weighted_cross_entropy_with_logits(labels=None,
                                           logits=None,
                                           pos_weight=None,
                                           name=None,
                                           targets=None):
      """Computes a weighted cross entropy.
    
          labels * -log(sigmoid(logits)) * pos_weight +
              (1 - labels) * -log(1 - sigmoid(logits))
      
        pos_weight: A coefficient to use on the positive examples
      """

    tf.losses.softmax_cross_entropy

    增加了一个权重,当权重为 1 时,等价于 tf.nn.softmax_cross_entropy_with_logits

    #our NN's output
    logits = tf.constant([[1.0,2.0,3.0],[1.0,2.0,3.0],[1.0,2.0,3.0]])
    #step1:do softmax
    y = tf.nn.softmax(logits)
    
    #true label
    y_= tf.constant([[0.0,0.0,1.0],[0.0,0.0,1.0],[0.0,0.0,1.0]])
    
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_))
    # tf.losses.softmax_cross_entropy(y_, logits, weights=1.)
    tf.losses.softmax_cross_entropy(y_, logits, weights=0.5)
    
    with tf.Session() as sess:
        print(sess.run(loss))                           # 0.40760598
        print(sess.run(tf.losses.get_total_loss()))     # 0.40760598        weights=1 时想等, weights=0.5 时为 0.20380299

    均方差

    tensorflow 其实没有提供这个 API,自己实现也很方便  

    y = tf.constant([0.9, 2.1, 2.8])
    y_pred = tf.constant([1, 2, 3], dtype=tf.float32)
    err1 = tf.reduce_sum(tf.square(y - y_pred)) / 3
    err2 = tf.reduce_mean(tf.square(y - y_pred))
    
    sess = tf.Session()
    print(sess.run(err1))       # 0.020000001
    print(sess.run(err2))       # 0.020000001

    参考资料:

    https://blog.csdn.net/marsjhao/article/details/72630147

    https://blog.csdn.net/weixin_42561002/article/details/87802096  tf.losses.softmax_cross_entropy()及相邻函数中weights参数的设置

  • 相关阅读:
    io工具类
    并发高级知识
    HashMap相关源码阅读
    ArrayList和LinkedList部分源码分析性能差异
    我自己的JdbcTemplate
    mysql5.7.20靠谱安装步骤
    NG 转发配置
    SQLite总结
    算是不常用的东西,java中的ResultSet转List
    不常用的技能-【手动编译java类】
  • 原文地址:https://www.cnblogs.com/yanshw/p/10518253.html
Copyright © 2020-2023  润新知