• 机器学习(十) 评价分类结果 (下)


    五、精准率和召回率的平衡

    Precision-Recall 的平衡

    六、精准率-召回率曲线

    七、ROC曲线

     Receiver Operating Characteristic Curve

    描述 TPR（真正例率，TPR = TP / (TP + FN)）和 FPR（假正例率，FPR = FP / (FP + TN)）之间的关系

     metrics.py

    import numpy as np
    from math import sqrt
    
    
    def accuracy_score(y_true, y_predict):
        """Return the accuracy: the fraction of positions where y_true equals y_predict.

        Raises AssertionError if the two inputs differ in length.
        """
        # NOTE(review): expects array-like inputs with elementwise ``==``
        # (e.g. NumPy arrays); two plain lists would compare as a single bool.
        # The message is parenthesized so the statement spans two lines legally
        # (a bare line break after the comma is a SyntaxError).
        assert len(y_true) == len(y_predict), (
            "the size of y_true must be equal to the size of y_predict"
        )

        return np.sum(y_true == y_predict) / len(y_true)
    
    
    def mean_squared_error(y_true, y_predict):
        """Return the mean squared error (MSE) between y_true and y_predict.

        Raises AssertionError if the two inputs differ in length.
        """
        # Parenthesized message: a bare newline after the comma is a SyntaxError.
        assert len(y_true) == len(y_predict), (
            "the size of y_true must be equal to the size of y_predict"
        )

        # NOTE(review): relies on elementwise subtraction, i.e. NumPy-array inputs.
        return np.sum((y_true - y_predict)**2) / len(y_true)
    
    
    def root_mean_squared_error(y_true, y_predict):
        """Return the root mean squared error (RMSE): the square root of the MSE."""
        mse = mean_squared_error(y_true, y_predict)
        return sqrt(mse)
    
    
    def mean_absolute_error(y_true, y_predict):
        """Return the mean absolute error (MAE) between y_true and y_predict.

        Raises AssertionError if the two inputs differ in length.
        """
        # Parenthesized message: a bare newline after the comma is a SyntaxError.
        assert len(y_true) == len(y_predict), (
            "the size of y_true must be equal to the size of y_predict"
        )

        # NOTE(review): relies on elementwise subtraction, i.e. NumPy-array inputs.
        return np.sum(np.absolute(y_true - y_predict)) / len(y_true)
    
    
    def r2_score(y_true, y_predict):
        """Return the coefficient of determination R^2 = 1 - MSE / Var(y_true).

        The division is not guarded: a constant y_true (zero variance) is not
        special-cased here.
        """
        mse = mean_squared_error(y_true, y_predict)
        variance = np.var(y_true)
        return 1 - mse / variance
    
    
    def TN(y_true, y_predict):
        """Count true negatives: samples whose true and predicted labels are both 0."""
        assert len(y_true) == len(y_predict)
        actual_negative = (y_true == 0)
        predicted_negative = (y_predict == 0)
        return np.sum(actual_negative & predicted_negative)
    
    
    def FP(y_true, y_predict):
        """Count false positives: true label 0 but predicted label 1."""
        assert len(y_true) == len(y_predict)
        actual_negative = (y_true == 0)
        predicted_positive = (y_predict == 1)
        return np.sum(actual_negative & predicted_positive)
    
    
    def FN(y_true, y_predict):
        """Count false negatives: true label 1 but predicted label 0."""
        assert len(y_true) == len(y_predict)
        actual_positive = (y_true == 1)
        predicted_negative = (y_predict == 0)
        return np.sum(actual_positive & predicted_negative)
    
    
    def TP(y_true, y_predict):
        """Count true positives: samples whose true and predicted labels are both 1."""
        assert len(y_true) == len(y_predict)
        actual_positive = (y_true == 1)
        predicted_positive = (y_predict == 1)
        return np.sum(actual_positive & predicted_positive)
    
    
    def confusion_matrix(y_true, y_predict):
        """Return the 2x2 binary confusion matrix laid out as [[TN, FP], [FN, TP]].

        Rows correspond to the true label (0, 1); columns to the predicted label.
        """
        tn = TN(y_true, y_predict)
        fp = FP(y_true, y_predict)
        fn = FN(y_true, y_predict)
        tp = TP(y_true, y_predict)
        return np.array([[tn, fp], [fn, tp]])
    
    
    def precision_score(y_true, y_predict):
        """Return precision = TP / (TP + FP) for binary labels.

        Returns 0.0 when nothing is predicted positive (TP + FP == 0).
        Raises AssertionError if the inputs differ in length.
        """
        assert len(y_true) == len(y_predict)
        tp = TP(y_true, y_predict)
        fp = FP(y_true, y_predict)
        denominator = tp + fp
        # tp and fp are NumPy scalars (from np.sum), so 0 / 0 would produce nan
        # with a RuntimeWarning instead of raising ZeroDivisionError; a
        # try/except cannot catch it, so test the denominator explicitly.
        if denominator == 0:
            return 0.0
        return tp / denominator
    
    
    def recall_score(y_true, y_predict):
        """Return recall = TP / (TP + FN) for binary labels.

        Returns 0.0 when there are no actual positives (TP + FN == 0).
        Raises AssertionError if the inputs differ in length.
        """
        assert len(y_true) == len(y_predict)
        tp = TP(y_true, y_predict)
        fn = FN(y_true, y_predict)
        denominator = tp + fn
        # tp and fn are NumPy scalars (from np.sum), so 0 / 0 would produce nan
        # with a RuntimeWarning instead of raising; check explicitly.
        if denominator == 0:
            return 0.0
        return tp / denominator
    
    
    def f1_score(y_true, y_predict):
        """Return the F1 score: the harmonic mean of precision and recall.

        Returns 0.0 when precision + recall == 0.
        """
        precision = precision_score(y_true, y_predict)
        recall = recall_score(y_true, y_predict)

        denominator = precision + recall
        # precision/recall are normally NumPy floats, whose 0 / 0 yields nan
        # (with a RuntimeWarning) rather than raising, so test explicitly.
        if denominator == 0:
            return 0.0
        return 2. * precision * recall / denominator
    
    
    def TPR(y_true, y_predict):
        """Return the true positive rate TPR = TP / (TP + FN) (equal to recall).

        Returns 0.0 when there are no actual positives (TP + FN == 0).
        """
        tp = TP(y_true, y_predict)
        fn = FN(y_true, y_predict)
        denominator = tp + fn
        # NumPy-scalar 0 / 0 gives nan with a warning instead of raising, so a
        # try/except would never fire; check the denominator explicitly.
        if denominator == 0:
            return 0.0
        return tp / denominator
    
    
    def FPR(y_true, y_predict):
        """Return the false positive rate FPR = FP / (FP + TN).

        Returns 0.0 when there are no actual negatives (FP + TN == 0).
        """
        fp = FP(y_true, y_predict)
        tn = TN(y_true, y_predict)
        denominator = fp + tn
        # NumPy-scalar 0 / 0 gives nan with a warning instead of raising, so a
        # try/except would never fire; check the denominator explicitly.
        if denominator == 0:
            return 0.0
        return fp / denominator
        

    八、多分类问题中的混淆矩阵

     

     我写的文章只是我自己对bobo老师讲课内容的理解和整理，也只是我自己的鄙见。bobo老师的课是慕课网出品的。欢迎大家一起学习。

  • 相关阅读:
    VMware下安装Ubuntu虚拟机
    py3+urllib+bs4+反爬,20+行代码教你爬取豆瓣妹子图
    老铁,这年头得玩玩这个:Git基本操作【github】
    本地Git与GitHub服务器建立连接(SSH方式通信)
    python开启httpserver服务在自动化测试中的一个小运用
    python测试webservice接口
    Xcache3.2.0不支持php7.0.11
    Nginx设置alias别名目录访问phpmyadmin
    CentOS 7.2.1511编译安装Nginx1.10.1+MySQL5.7.15+PHP7.0.11
    CentOS平滑更新nginx版本
  • 原文地址:https://www.cnblogs.com/zhangtaotqy/p/9571289.html
Copyright © 2020-2023  润新知