• tensorflow2.0 评估函数


    一,常用的内置评估指标

    • MeanSquaredError(平方差误差,用于回归,可以简写为MSE,函数形式为mse)

    • MeanAbsoluteError (绝对值误差,用于回归,可以简写为MAE,函数形式为mae)

    • MeanAbsolutePercentageError (平均百分比误差,用于回归,可以简写为MAPE,函数形式为mape)

    • RootMeanSquaredError (均方根误差,用于回归)

    • Accuracy (准确率,用于分类,可以用字符串"Accuracy"表示,Accuracy=(TP+TN)/(TP+TN+FP+FN),要求y_true和y_pred都为类别序号编码)

    • Precision (精确率,用于二分类,Precision = TP/(TP+FP))

    • Recall (召回率,用于二分类,Recall = TP/(TP+FN))

    • TruePositives (真正例,用于二分类)

    • TrueNegatives (真负例,用于二分类)

    • FalsePositives (假正例,用于二分类)

    • FalseNegatives (假负例,用于二分类)

    • AUC(ROC曲线(TPR vs FPR)下的面积,用于二分类,直观解释为随机抽取一个正样本和一个负样本,正样本的预测值大于负样本的概率)

    • CategoricalAccuracy(分类准确率,与Accuracy含义相同,要求y_true(label)为onehot编码形式)

    • SparseCategoricalAccuracy (稀疏分类准确率,与Accuracy含义相同,要求y_true(label)为序号编码形式)

    • MeanIoU (Intersection-Over-Union,常用于图像分割)

    • TopKCategoricalAccuracy (多分类TopK准确率,要求y_true(label)为onehot编码形式)

    • SparseTopKCategoricalAccuracy (稀疏多分类TopK准确率,要求y_true(label)为序号编码形式)

    • Mean (平均值)

    • Sum (求和)

    • https://tensorflow.google.cn/api_docs/python/tf/keras/metrics

    二,自定义品函数及使用

    import numpy as np
    import pandas as pd
    import tensorflow as tf
    from tensorflow.keras import layers,models,losses,metrics
     
    # 函数形式的自定义评估指标
    @tf.function
    def ks(y_true,y_pred):
        y_true = tf.reshape(y_true,(-1,))
        y_pred = tf.reshape(y_pred,(-1,))
        length = tf.shape(y_true)[0]
        t = tf.math.top_k(y_pred,k = length,sorted = False)
        y_pred_sorted = tf.gather(y_pred,t.indices)
        y_true_sorted = tf.gather(y_true,t.indices)
        cum_positive_ratio = tf.truediv(
            tf.cumsum(y_true_sorted),tf.reduce_sum(y_true_sorted))
        cum_negative_ratio = tf.truediv(
            tf.cumsum(1 - y_true_sorted),tf.reduce_sum(1 - y_true_sorted))
        ks_value = tf.reduce_max(tf.abs(cum_positive_ratio - cum_negative_ratio)) 
        return ks_value
    y_true = tf.constant([[1],[1],[1],[0],[1],[1],[1],[0],[0],[0],[1],[0],[1],[0]])
    y_pred = tf.constant([[0.6],[0.1],[0.4],[0.5],[0.7],[0.7],[0.7],
                          [0.4],[0.4],[0.5],[0.8],[0.3],[0.5],[0.3]])
    tf.print(ks(y_true,y_pred))
    model.compile(
        loss="categorical_crossentropy",
        optimizer=keras.optimizers.Adam(lr=0.001),
        metrics=[keras.metrics.MeanIoU(num_classes=2),ks]
        
    )
    

      

  • 相关阅读:
    excel导出
    分页工具类
    orcale生成订单号---订单号的要求为:yyyyMMddHHmmss+000N
    spring data和spring嵌套使用 applictionContext.xml配置文件
    spring data jpa和spring嵌套使用 依赖引用
    spring data jpa 代码实现 增删改查
    一些常用的正则表达式
    mysql的两道实验题 涵盖sql语句基本操作方向
    Spring MVC 上传文件---配置文件
    Spring MVC 上传文件---代码实现
  • 原文地址:https://www.cnblogs.com/Dean0731/p/12918654.html
Copyright © 2020-2023  润新知