• 机器学习-逻辑回归算法


    代码:

    package com.test
    
    import org.apache.spark.SparkConf
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.linalg.DenseVector
    import org.apache.spark.sql.SparkSession
    
    object Test03 {
      /**
       * Trains a binary logistic-regression model on the breast-cancer CSV
       * data and reports test-set accuracy twice: once with Spark's default
       * 0.5 decision threshold and once with a custom threshold of 0.3.
       *
       * Expects `data/breast_cancer.csv` where every row is N numeric feature
       * columns followed by a 0.0/1.0 label in the last column.
       */
      def main(args: Array[String]): Unit = {

        val conf = new SparkConf()
        conf.setMaster("local")
        val spark = SparkSession.builder().config(conf).appName("Logistic linear regression").getOrCreate()

        import spark.implicits._

        val dataRdd = spark.sparkContext.textFile("data/breast_cancer.csv")
        val data = dataRdd.map(x => {
          val arr = x.split(",")
          // arr.init = every column except the last (the label); replaces the
          // original's manual Array allocation + copyToArray bookkeeping.
          (new DenseVector(arr.init.map(_.toDouble)), arr.last.toDouble)
        }).toDF("features", "label")

        // 70/30 train/test split; the fixed seed keeps the split reproducible.
        val splits = data.randomSplit(Array(0.7, 0.3), seed = 11L)

        val (trainingData, testData) = (splits(0), splits(1))

        val lr = new LogisticRegression().setMaxIter(100)

        val lrModel = lr.fit(trainingData)

        // Learned weights w1..wn and intercept w0.
        println(s"w1~wn: ${lrModel.coefficients} w0: ${lrModel.intercept}")

        // Score the test set once and cache it: the result is reused below by
        // show() and two separate RDD aggregations, so without cache() the
        // transform would be recomputed for each action.
        val testRest = lrModel.transform(testData).cache()
        testRest.show(false)

        // Hoisted: testData.count() was evaluated twice in the original.
        val testCount = testData.count()

        // Accuracy = 1 - mean(|label - prediction|). Valid because both label
        // and prediction are 0.0 or 1.0, so |label - prediction| is exactly 1
        // on misclassified rows and 0 otherwise.
        val mean = testRest.rdd.map(row => {
          // True class of this sample.
          val label = row.getAs[Double]("label")
          // Class predicted by the model for this sample's features.
          val prediction = row.getAs[Double]("prediction")
          // 0: correct prediction, 1: wrong prediction (abs = absolute value).
          math.abs(label - prediction)
        }).sum()
        println("正确率:" + (1 - (mean / testCount)))
        // Same accuracy, computed by Spark's built-in evaluator.
        println("正确率:" + lrModel.evaluate(testData).accuracy)


        // Re-classify with a custom decision threshold: predict class 1 when
        // P(class 1) > 0.3 instead of the default 0.5.
        val count = testRest.rdd.map(row => {
          val probability = row.getAs[DenseVector]("probability")
          val label = row.getAs[Double]("label")
          val score = probability(1)
          val prediction = if (score > 0.3) 1 else 0
          math.abs(label - prediction)
        }).sum()
        println("自定义分类阈值 正确率:" + (1 - (count / testCount)))
        spark.close()
      }

    }

    结果:

    +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-----+----------------------------------------+-----------+----------+
    |features                                                                                                                                                                                                                   |label|rawPrediction                           |probability|prediction|
    +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-----+----------------------------------------+-----------+----------+
    |[7.729,25.49,47.98,178.8,0.08098,0.04878,0.0,0.0,0.187,0.07285,0.3777,1.462,2.492,19.14,0.01266,0.009692,0.0,0.0,0.02882,0.006872,9.077,30.92,57.17,248.0,0.1256,0.0834,0.0,0.0,0.3058,0.09938]                            |1.0  |[-6902.544911715737,6902.544911715737]  |[0.0,1.0]  |1.0       |
    |[7.76,24.54,47.92,181.0,0.05263,0.04362,0.0,0.0,0.1587,0.05884,0.3857,1.428,2.548,19.15,0.007189,0.00466,0.0,0.0,0.02676,0.002783,9.456,30.37,59.16,268.6,0.08996,0.06444,0.0,0.0,0.2871,0.07039]                          |1.0  |[-4229.951065882061,4229.951065882061]  |[0.0,1.0]  |1.0       |
    |[8.618,11.79,54.34,224.5,0.09752,0.05272,0.02061,0.007799,0.1683,0.07187,0.1559,0.5796,1.046,8.322,0.01011,0.01055,0.01981,0.005742,0.0209,0.002788,9.507,15.4,59.9,274.9,0.1733,0.1239,0.1168,0.04419,0.322,0.09026]      |1.0  |[-10129.616436167826,10129.616436167826]|[0.0,1.0]  |1.0       |
    |[8.671,14.45,54.42,227.2,0.09138,0.04276,0.0,0.0,0.1722,0.06724,0.2204,0.7873,1.435,11.36,0.009172,0.008007,0.0,0.0,0.02711,0.003399,9.262,17.04,58.36,259.2,0.1162,0.07057,0.0,0.0,0.2592,0.07848]                        |1.0  |[-13336.888611143484,13336.888611143484]|[0.0,1.0]  |1.0       |
    |[8.726,15.83,55.84,230.9,0.115,0.08201,0.04132,0.01924,0.1649,0.07633,0.1665,0.5864,1.354,8.966,0.008261,0.02213,0.03259,0.0104,0.01708,0.003806,9.628,19.62,64.48,284.4,0.1724,0.2364,0.2456,0.105,0.2926,0.1017]         |1.0  |[-9268.251080440577,9268.251080440577]  |[0.0,1.0]  |1.0       |
    |[8.734,16.84,55.27,234.3,0.1039,0.07428,0.0,0.0,0.1985,0.07098,0.5169,2.079,3.167,28.85,0.01582,0.01966,0.0,0.0,0.01865,0.006736,10.17,22.8,64.01,317.0,0.146,0.131,0.0,0.0,0.2445,0.08865]                                |1.0  |[-10929.784584048108,10929.784584048108]|[0.0,1.0]  |1.0       |
    |[9.295,13.9,59.96,257.8,0.1371,0.1225,0.03332,0.02421,0.2197,0.07696,0.3538,1.13,2.388,19.63,0.01546,0.0254,0.02197,0.0158,0.03997,0.003901,10.57,17.84,67.84,326.6,0.185,0.2097,0.09996,0.07262,0.3681,0.08982]           |1.0  |[-9972.504809621163,9972.504809621163]  |[0.0,1.0]  |1.0       |
    |[9.333,21.94,59.01,264.0,0.0924,0.05605,0.03996,0.01282,0.1692,0.06576,0.3013,1.879,2.121,17.86,0.01094,0.01834,0.03996,0.01282,0.03759,0.004623,9.845,25.05,62.86,295.8,0.1103,0.08298,0.07993,0.02564,0.2435,0.07393]    |1.0  |[-10874.187824751454,10874.187824751454]|[0.0,1.0]  |1.0       |
    |[9.397,21.68,59.75,268.8,0.07969,0.06053,0.03735,0.005128,0.1274,0.06724,0.1186,1.182,1.174,6.802,0.005515,0.02674,0.03735,0.005128,0.01951,0.004583,9.965,27.99,66.61,301.0,0.1086,0.1887,0.1868,0.02564,0.2376,0.09206]  |1.0  |[-9687.163679996896,9687.163679996896]  |[0.0,1.0]  |1.0       |
    |[9.676,13.14,64.12,272.5,0.1255,0.2204,0.1188,0.07038,0.2057,0.09575,0.2744,1.39,1.787,17.67,0.02177,0.04888,0.05189,0.0145,0.02632,0.01148,10.6,18.04,69.47,328.1,0.2006,0.3663,0.2913,0.1075,0.2848,0.1364]              |1.0  |[-10425.819568377874,10425.819568377874]|[0.0,1.0]  |1.0       |
    |[9.742,19.12,61.93,289.7,0.1075,0.08333,0.008934,0.01967,0.2538,0.07029,0.6965,1.747,4.607,43.52,0.01307,0.01885,0.006021,0.01052,0.031,0.004225,11.21,23.17,71.79,380.9,0.1398,0.1352,0.02085,0.04589,0.3196,0.08009]     |1.0  |[-7470.623970453418,7470.623970453418]  |[0.0,1.0]  |1.0       |
    |[9.777,16.99,62.5,290.2,0.1037,0.08404,0.04334,0.01778,0.1584,0.07065,0.403,1.424,2.747,22.87,0.01385,0.02932,0.02722,0.01023,0.03281,0.004638,11.05,21.47,71.68,367.0,0.1467,0.1765,0.13,0.05334,0.2533,0.08468]          |1.0  |[-9631.107650152822,9631.107650152822]  |[0.0,1.0]  |1.0       |
    |[9.787,19.94,62.11,294.5,0.1024,0.05301,0.006829,0.007937,0.135,0.0689,0.335,2.043,2.132,20.05,0.01113,0.01463,0.005308,0.00525,0.01801,0.005667,10.92,26.29,68.81,366.1,0.1316,0.09473,0.02049,0.02381,0.1934,0.08988]    |1.0  |[-10574.948203901578,10574.948203901578]|[0.0,1.0]  |1.0       |
    |[10.03,21.28,63.19,307.3,0.08117,0.03912,0.00247,0.005159,0.163,0.06439,0.1851,1.341,1.184,11.6,0.005724,0.005697,0.002074,0.003527,0.01445,0.002411,11.11,28.94,69.92,376.3,0.1126,0.07094,0.01235,0.02579,0.2349,0.08061]|1.0  |[-8893.557477503311,8893.557477503311]  |[0.0,1.0]  |1.0       |
    |[10.08,15.11,63.76,317.5,0.09267,0.04695,0.001597,0.002404,0.1703,0.06048,0.4245,1.268,2.68,26.43,0.01439,0.012,0.001597,0.002404,0.02538,0.00347,11.87,21.18,75.39,437.0,0.1521,0.1019,0.00692,0.01042,0.2933,0.07697]    |1.0  |[-5668.804179132235,5668.804179132235]  |[0.0,1.0]  |1.0       |
    |[10.17,14.88,64.55,311.9,0.1134,0.08061,0.01084,0.0129,0.2743,0.0696,0.5158,1.441,3.312,34.62,0.007514,0.01099,0.007665,0.008193,0.04183,0.005953,11.02,17.45,69.86,368.6,0.1275,0.09866,0.02168,0.02579,0.3557,0.0802]    |1.0  |[-11852.030184143723,11852.030184143723]|[0.0,1.0]  |1.0       |
    |[10.32,16.35,65.31,324.9,0.09434,0.04994,0.01012,0.005495,0.1885,0.06201,0.2104,0.967,1.356,12.97,0.007086,0.007247,0.01012,0.005495,0.0156,0.002606,11.25,21.77,71.12,384.9,0.1285,0.08842,0.04384,0.02381,0.2681,0.07399]|1.0  |[-9073.703438558807,9073.703438558807]  |[0.0,1.0]  |1.0       |
    |[10.48,19.86,66.72,337.7,0.107,0.05971,0.04831,0.0307,0.1737,0.0644,0.3719,2.612,2.517,23.22,0.01604,0.01386,0.01865,0.01133,0.03476,0.00356,11.48,29.46,73.68,402.8,0.1515,0.1026,0.1181,0.06736,0.2883,0.07748]          |1.0  |[-4389.849643347756,4389.849643347756]  |[0.0,1.0]  |1.0       |
    |[10.51,23.09,66.85,334.2,0.1015,0.06797,0.02495,0.01875,0.1695,0.06556,0.2868,1.143,2.289,20.56,0.01017,0.01443,0.01861,0.0125,0.03464,0.001971,10.93,24.22,70.1,362.7,0.1143,0.08614,0.04158,0.03125,0.2227,0.06777]      |1.0  |[-10834.90580709216,10834.90580709216]  |[0.0,1.0]  |1.0       |
    |[10.57,20.22,70.15,338.3,0.09073,0.166,0.228,0.05941,0.2188,0.0845,0.1115,1.231,2.363,7.228,0.008499,0.07643,0.1535,0.02919,0.01617,0.0122,10.85,22.82,76.51,351.9,0.1143,0.3619,0.603,0.1465,0.2597,0.12]                 |1.0  |[-9212.082801976405,9212.082801976405]  |[0.0,1.0]  |1.0       |
    +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-----+----------------------------------------+-----------+----------+
    only showing top 20 rows
    正确率:0.9766081871345029
    正确率:0.9766081871345029
    
    自定义分类阈值 正确率:0.9766081871345029
  • 相关阅读:
    事务管理思考
    sleep、yield、wait的区别
    线程异常
    线程
    JAVA线程中断
    volatile synchronized在线程安全上的区别
    jms amqp activemq rabbitmq的区别
    servlet不是线程安全的
    雪花算法
    个人税收申报时候对于“全年一次性奖金“的处理
  • 原文地址:https://www.cnblogs.com/bigdata-familyMeals/p/14643759.html
Copyright © 2020-2023  润新知