• ROC曲线绘制


    1. 引入相关包

    使用matplotlib包作为绘图库,故要引入相关的包

    为了使画出的图更为符合期刊要求,这里引入SciencePlots。

    它是一个基于Matplotlib的补充包,里面主要包含了一些以.mplstyle为后缀的图表样式的配置文件。这样,你画图的时候只需要通过调用这些配置文件,就能画出比较好看的数据可视化图表,也避免了你每次画图时都要从头开始手动配置图表的格式。

    pip install SciencePlots
    

    还要引入numpy对数据进行处理

    要计算AUC,还应该引入sklearn中计算相关值的包。

    import matplotlib.pyplot as plt
    plt.style.use(['science'])
    import numpy as np
    from sklearn.metrics import roc_curve, auc
    

    然后导入相关数据

    # 真实值
    y = np.load('npy\\y_test.npy')
    # 各种预测值 并非0或1 而是概率
    yp_ann = np.load('npy\\ann.npy')
    yp_lstm = np.load('npy\\lstm.npy')
    yp_lr = np.load('npy\\lr.npy')
    yp_rf = np.load('npy\\rf.npy')
    yp_xgb = np.load('npy\\xgb.npy')
    yp_lgbm = np.load('npy\\lgbm.npy')
    yp_catb = np.load('npy\\catb.npy')
    

    2. 计算AUC值

    AUC,即AUROC,指的是由TPRFPR围成的ROC曲线下的面积

    将分类任务的实际值和预测值作为参数输入给roc_curve()方法可以得到FPR、TPR和对应的阈值。

    auc()方法可以计算曲线下的面积,将FPR和TPR作为参数输入,即可获得AUC值。

    fpr_1, tpr_1, threshold_1 = roc_curve(y, yp_ann)  # 计算FPR和TPR
    auc_1 = auc(fpr_1, tpr_1)  # 计算AUC值
    
    fpr_2, tpr_2, threshold_2 = roc_curve(y, yp_lstm)
    auc_2 = auc(fpr_2, tpr_2)
    
    fpr_3, tpr_3, threshold_3 = roc_curve(y, yp_lr)
    auc_3 = auc(fpr_3, tpr_3)
    
    fpr_4, tpr_4, threshold_4 = roc_curve(y, yp_rf)
    auc_4 = auc(fpr_4, tpr_4)
    
    fpr_5, tpr_5, threshold_5 = roc_curve(y, yp_xgb)
    auc_5 = auc(fpr_5, tpr_5)
    
    fpr_6, tpr_6, threshold_6 = roc_curve(y, yp_lgbm)
    auc_6 = auc(fpr_6, tpr_6)
    
    fpr_7, tpr_7, threshold_7 = roc_curve(y, yp_catb)
    auc_7 = auc(fpr_7, tpr_7)
    

    3. 绘制曲线

    首先定义曲线的宽度和图的大小,如下所示。

    line_width = 1  # 曲线的宽度
    plt.figure(figsize=(16, 10))  # 图的大小
    

    使用plt的plot()方法可以绘制曲线,通常可以传入的参数有以下几种:

    • x轴的数据
    • y轴的数据
    • lw:线条的宽度
    • label:曲线的标签(曲线标签甚至支持LaTex公式,例如$K_{d,1}$
    • color:曲线的颜色(如果不指定,plt会自动选择)
    • linestyle:线型,包括“-”代表实线,“--”代表虚线,“-.”代表中间有点的虚线,“:”点型虚线
    plt.plot(fpr_1, tpr_1, lw=line_width, label='Ann (AUC = %0.4f)' % auc_1,)
    plt.plot(fpr_2, tpr_2, lw=line_width, label='Lstm (AUC = %0.4f)' % auc_2,)
    plt.plot(fpr_3, tpr_3, lw=line_width, label='LogisticRegression (AUC = %0.4f)' % auc_3,)
    plt.plot(fpr_4, tpr_4, lw=line_width, label='RandomForest (AUC = %0.4f)' % auc_4,)
    plt.plot(fpr_5, tpr_5, lw=line_width, label='XGboost (AUC = %0.4f)' % auc_5,)
    plt.plot(fpr_6, tpr_6, lw=line_width, label='LightGBM (AUC = %0.4f)' % auc_6,)
    plt.plot(fpr_7, tpr_7, lw=line_width, label='Catboost (AUC = %0.4f)' % auc_7,)
    

    4. 坐标轴范围和标题

    限定x轴和y轴的范围,如下所示。

    plt.xlim([0.0, 1.0])  # 限定x轴的范围
    plt.ylim([0.0, 1.0])  # 限定y轴的范围
    

    也可以通过xticks()和yticks()直接调整坐标轴的刻度,如下所示。

    # plt.xticks(range(0, 10, 1)) # 修改x轴的刻度
    # plt.yticks(range(0, 10, 1)) # 修改y轴的刻度
    

    指定坐标轴的标题,如下所示。

    plt.xlabel('False Positive Rate')  # x坐标轴标题
    plt.ylabel('True Positive Rate')  # y坐标轴标题
    

    使用grid()方法在图中添加网格,如下所示。

    plt.grid()  # 在图中添加网格
    

    显示图例并指定图例位置,常见位置包括{upper,center,lower} {left,center,right},如下所示。

    plt.legend(loc="lower right")  # 显示图例并指定图例位置
    

    5. 中文处理问题

    如果在坐标轴、标题等地方出现了中文,plt会显示乱码,添加以下两条语句可以解决中文处理问题。

    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    

    6. 展示图片和保存

    TIFF格式(Tag Image File Format,TIFF)是常见的论文图片投稿格式,TIFF格式能够制作质量非常高的图像,多数出版社(如Springer、Elsevier)都接受并推荐使用dpi=300的TIFF格式的插图。

    plt.savefig('AUC.tif', dpi=300)
    

    使用plt的show方法展示曲线,如下所示。

    plt.show()
    

    7. 示例代码

    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.metrics import roc_curve, auc
    from sklearn.metrics import precision_recall_curve
    
    y = np.load('npy\\y_test.npy')
    yp_ann = np.load('npy\\ann.npy')
    yp_lstm = np.load('npy\\lstm.npy')
    yp_lr = np.load('npy\\lr.npy')
    yp_rf = np.load('npy\\rf.npy')
    yp_xgb = np.load('npy\\xgb.npy')
    yp_lgbm = np.load('npy\\lgbm.npy')
    yp_catb = np.load('npy\\catb.npy')
    
    fpr_1, tpr_1, threshold_1 = roc_curve(y, yp_ann)  # 计算FPR和TPR
    auc_1 = auc(fpr_1, tpr_1)  # 计算AUC值
    
    fpr_2, tpr_2, threshold_2 = roc_curve(y, yp_lstm)
    auc_2 = auc(fpr_2, tpr_2)
    
    fpr_3, tpr_3, threshold_3 = roc_curve(y, yp_lr)
    auc_3 = auc(fpr_3, tpr_3)
    
    fpr_4, tpr_4, threshold_4 = roc_curve(y, yp_rf)
    auc_4 = auc(fpr_4, tpr_4)
    
    fpr_5, tpr_5, threshold_5 = roc_curve(y, yp_xgb)
    auc_5 = auc(fpr_5, tpr_5)
    
    fpr_6, tpr_6, threshold_6 = roc_curve(y, yp_lgbm)
    auc_6 = auc(fpr_6, tpr_6)
    
    fpr_7, tpr_7, threshold_7 = roc_curve(y, yp_catb)
    auc_7 = auc(fpr_7, tpr_7)
    
    plt.style.use(['science'])
    line_width = 2  # 曲线的宽度
    plt.figure(figsize=(8, 5))  # 图的大小
    
    plt.plot(fpr_1, tpr_1, lw=line_width, label='Ann (AUC = %0.4f)' % auc_1,)
    plt.plot(fpr_2, tpr_2, lw=line_width, label='Lstm (AUC = %0.4f)' % auc_2,)
    plt.plot(fpr_3, tpr_3, lw=line_width, label='LogisticRegression (AUC = %0.4f)' % auc_3,)
    plt.plot(fpr_4, tpr_4, lw=line_width, label='RandomForest (AUC = %0.4f)' % auc_4,)
    plt.plot(fpr_5, tpr_5, lw=line_width, label='XGboost (AUC = %0.4f)' % auc_5,)
    plt.plot(fpr_6, tpr_6, lw=line_width, label='LightGBM (AUC = %0.4f)' % auc_6,)
    plt.plot(fpr_7, tpr_7, lw=line_width, label='Catboost (AUC = %0.4f)' % auc_7,)
    
    
    plt.xlim([0.0, 1.0])  # 限定x轴的范围
    plt.ylim([0.0, 1.0])  # 限定y轴的范围
    plt.xlabel('False Positive Rate', fontsize=16)  # x坐标轴标题
    plt.ylabel('True Positive Rate', fontsize=16)  # y坐标轴标题
    plt.title('ROC', fontsize=16)  # 标题
    plt.grid()  # 在图中添加网格
    plt.legend(loc="lower right", fontsize=16)  # 显示图例并指定图例位置
    
    plt.savefig('ROC.tif', dpi=300)
    plt.show()
    

    image-20221026202136784

  • 相关阅读:
    html-css___table属性(设置细线边框)
    简单的jquery表单验证+添加+删除+全选/反选
    CKEditor5 使用第二天 获取回传数据,图片上传
    ckeditor5 使用第一天 下载并加载居中,居左,居右功能
    Android studio 3.4 新建项目报错Error:unable to resolve dependency for app@。。。解决办法
    IDEA 运行后乱码问题解决
    tomcat9启动后控制台输出乱码问题
    springboot架构下运用shiro后在configuration,通过@Value获取不到值,总是为null
    IDEA org.apache.ibatis.binding.BindingException: Invalid bound statement (not found):
    查找 oracle 数据库中包含某一字段的所有表的表名
  • 原文地址:https://www.cnblogs.com/wkfvawl/p/16829917.html
Copyright © 2020-2023  润新知