• 基于sklearn的常用分类任务指标Python实现


    基于sklearn的常用分类任务指标Python实现

    一、摘要

    分类任务常用指标包含混淆矩阵、每类分类精度、平均分类精度、总体分类精度、f1-score等。 Python的sklearn.metrics 模块覆盖了分类任务中大部分常用的验证指标, 本文选择其中几种评价指标展示代码片段,供读者使用。 基于tensorflow-1.0与mnist数据集做demo展示并列举实验结果。 文末附有sklearn.metrics模块的相关资料链接,方便高端玩家深入探索。

    二、本文包含的评价指标

    混淆矩阵(Confusion Matrix,CM) 
    每类别分类精度 
    每类别召回率 
    平均分类精度(Average Accuracy,AA) 
    总体分类精度(Overall Accuracy,OA)

    三、功能代码片段展示

    代码在tensorflow-1.0、Python3.5环境下通过测试,tf1.0版本API改动较大,1.0以下版本tensorflow可能不能通过测试,精力有限,其他环境尚未做测试。

     1 from sklearn import metrics
     2 import numpy as np
     3 #####
     4 # Do classification task, 
     5 # then get the ground truth and the predict label named y_true and y_pred
     6 classify_report = metrics.classification_report(y_true, y_pred)
     7 confusion_matrix = metrics.confusion_matrix(y_true, y_pred)
     8 overall_accuracy = metrics.accuracy_score(y_true, y_pred)
     9 acc_for_each_class = metrics.precision_score(y_true, y_pred, average=None)
    10 average_accuracy = np.mean(acc_for_each_class)
    11 score = metrics.accuracy_score(y_true, y_pred)
    12 print('classify_report : 
    ', classify_report)
    13 print('confusion_matrix : 
    ', confusion_matrix)
    14 print('acc_for_each_class : 
    ', acc_for_each_class)
    15 print('average_accuracy: {0:f}'.format(average_accuracy))
    16 print('overall_accuracy: {0:f}'.format(overall_accuracy))
    17 print('score: {0:f}'.format(score))

    四、实验结果展示

    本文基于tensorflow-1.0框架与mnist数据集,使用线性分类器与卷积神经网络分类并使用上文提到的代码片段展示分类性能。

    分类性能结果直观,排列清晰,便于二次使用。

    1. 线性分类器分类报告:

    2. 线性分类器混淆矩阵与其他分类指标展示:

    3. 卷积神经网络每层参数显示:

    4. 卷积神经网络分类报告:

    5. 卷积神经网络混淆矩阵与其他分类指标展示:

     

    五、代码示例

    使用类似如下的代码片段可以直观查看tensor相关内容

     1 print(some_tensor.op.name, ' ', some_tensor.get_shape().as_list()) 

    代码太长,这里就不粘贴了。代码来源:https://github.com/JiJingYu/tensorflow-exercise/blob/master/mnist_test.py

     六、总结

    不得不说sklearn是个全面的Python模块,常用的机器学习方法以及评价准则都能从中找到函数与例程。同时,tensorflow作为google亲儿子,发展劲头势不可挡。

    sklearn 官网:http://scikit-learn.org/stable/index.html

    一份高质量 sklearn tutorial:https://github.com/jakevdp/sklearn_tutorial

  • 相关阅读:
    五种线程池的分类与作用
    什么是死锁?
    事务隔离级别区分,未提交读,提交读,可重复读
    共享锁(读锁)和排他锁(写锁)
    java中的成员变量和全局变量的区别
    Algorithm
    6
    5
    4
    3
  • 原文地址:https://www.cnblogs.com/nwpuxuezha/p/6539054.html
Copyright © 2020-2023  润新知