• Predicting the values of a missing field with Spark LinearRegression


    Recently, while building fintech credit models, we found that when many fields contain missing values, the model's predictions for new users become quite unstable, i.e., the PSI (Population Stability Index) gets large.

    Although we keep adjusting fields and developing new variables based on each field's IV (Information Value), simply replacing missing values with the mean or with 0 is clearly unreasonable for fields with a large IV.

    So we tried treating a field's missing values as the quantity to predict: the rows where the field is not missing supply the label y, the other fields supply the features X, and a regression model predicts the field's missing values. Unlike a classification task, the prediction here is a concrete continuous value that can in principle range from negative infinity to positive infinity.
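    As a toy, single-feature illustration of this idea (plain Scala rather than Spark, all names hypothetical): fit a least-squares line on the rows where the target is observed, then use the fitted line to fill in the rows where it is missing.

```scala
// Regression-based imputation sketch: ordinary least squares on one feature.
// Rows with an observed y train the model; rows missing y get predicted values.
def fitOls(xs: Seq[Double], ys: Seq[Double]): (Double, Double) = {
  val n  = xs.length.toDouble
  val mx = xs.sum / n
  val my = ys.sum / n
  val slope = xs.zip(ys).map { case (x, y) => (x - mx) * (y - my) }.sum /
              xs.map(x => math.pow(x - mx, 2)).sum
  (slope, my - slope * mx)                               // (slope, intercept)
}

val known    = Seq((1.0, 2.1), (2.0, 3.9), (3.0, 6.0))  // (x, y) with y observed
val missingX = Seq(4.0, 5.0)                            // rows where y is missing
val (a, b)   = fitOls(known.map(_._1), known.map(_._2))
val imputed  = missingX.map(x => a * x + b)             // continuous fill-in values
```

    The same pattern scales up in the Spark code below: the "observed" rows become the training DataFrame and the "missing" rows become the scoring DataFrame.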

      The data is read from, and written back to, Hive. The code is as follows:

    import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.linalg.{Vector, Vectors}
    import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.sql.hive.HiveContext
    import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
    import scala.collection.mutable.ArrayBuffer

    // Sanity check (run in Hive): correlation between predicted and actual values
    // select corr(cast(p.cnt_addbook_one as double), cast(l.cnt_addbook_one as double)) as corrs
    // from lkl_card_score.predictcnt_addbook_one20180201 p
    // join lkl_card_score.fieldValuePredictModel3 l on p.order_id = l.order_src
    // where l.cnt_addbook_one <> 0
    object predictcnt_addbook_one20180201 {
      def main(args: Array[String]): Unit = {
        val cf = new SparkConf().setAppName("ass").setMaster("local")
        val sc = new SparkContext(cf)
        val sqlContext = new SQLContext(sc)
        val hc = new HiveContext(sc)
        import sqlContext.implicits._
    
        // Training set: rows where cnt_addbook_one is present (non-zero);
        // odd values train the model, even values are held out for validation.
        val data = hc.sql(s"select * from lkl_card_score.fieldValuePredictModel3 where cnt_addbook_one<>0 and cnt_addbook_one%2=1").map {
          row =>
            val arr = new ArrayBuffer[Double]()
            // skip the first four columns (label, order ids, phone)
            for (i <- 4 until row.size) {
              if (row.isNullAt(i)) {
                arr += 0.0
              }
              else if (row.get(i).isInstanceOf[Int])
                arr += row.getInt(i).toDouble
              else if (row.get(i).isInstanceOf[Double])
                arr += row.getDouble(i)
              else if (row.get(i).isInstanceOf[Long])
                arr += row.getLong(i).toDouble
              else if (row.get(i).isInstanceOf[String])
                arr += 0.0
            }
            LabeledPoint(row.getLong(0).toDouble, Vectors.dense(arr.toArray))
        }.toDF("label", "features")

        // Configure the linear regression (the label column was originally named
        // "Murder", a leftover from a textbook example; renamed to "label" here)
        val lr = new LinearRegression()
          .setFeaturesCol("features")
          .setLabelCol("label")
          .setFitIntercept(true)
          .setMaxIter(50)
          .setRegParam(0.3)        // regularization strength
          .setElasticNetParam(0.8) // elastic-net mixing (0 = L2, 1 = L1)
        // Fit the model on the training set
        val lrModel = lr.fit(data)
        // Print the trained parameters
        println(lrModel.extractParamMap())
        println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
        lrModel.write.overwrite().save(s"hdfs://ns1/user/songchunlin/model/predictcnt_addbook_one20180202")
        // Evaluate the model on the training set
        val trainingSummary = lrModel.summary
        println(s"numIterations: ${trainingSummary.totalIterations}")
        println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}")
        trainingSummary.residuals.show()
        println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
        println(s"r2: ${trainingSummary.r2}")
        // Held-out rows (even values of cnt_addbook_one) are scored for validation
        val predict = hc.sql(s"select * from lkl_card_score.fieldValuePredictModel3 where cnt_addbook_one<>0 and cnt_addbook_one%2=0").map {
          row =>
            val arr = new ArrayBuffer[Double]()
            // skip the first four columns (label, order ids, phone)
            for (i <- 4 until row.size) {
              if (row.isNullAt(i)) {
                arr += 0.0
              }
              else if (row.get(i).isInstanceOf[Int])
                arr += row.getInt(i).toDouble
              else if (row.get(i).isInstanceOf[Double])
                arr += row.getDouble(i)
              else if (row.get(i).isInstanceOf[Long])
                arr += row.getLong(i).toDouble
              else if (row.get(i).isInstanceOf[String])
                arr += 0.0
            }
            (row.getString(2), Vectors.dense(arr.toArray))
        }.toDF("order_src", "features")
    
    
       val models=LinearRegressionModel.load("hdfs://ns1/user/songchunlin/model/predictcnt_addbook_one20180202")
        val prediction =models.transform(predict)
    
    
        //    val predictions = lrModel.transform(vecDF)
        println("Writing prediction results")
        val predict_result: DataFrame = prediction.selectExpr("order_src", "prediction")
        val pre2=prediction.map(row=>Row(row.get(0).toString,row.get(2).toString))
        val schema = StructType(
          List(
            StructField("order_id", StringType, true),
            StructField("cnt_addbook_one", StringType, true)
          )
        )
        val scoreDataFrame = hc.createDataFrame(pre2,schema)
        scoreDataFrame.count()
        scoreDataFrame.write.mode(SaveMode.Overwrite).saveAsTable("lkl_card_score.predictcnt_addbook_one20180202")
    
    //    predict_result.write.mode(SaveMode.Overwrite).saveAsTable("lkl_card_score.fieldValuePredictModel3_prediction20180131")
    //    predict_result.foreach(println(_))
    //    sc.stop()
    
    
    
      }
    }
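
    The per-column coercion repeated in both map blocks above can be factored into a small helper, sketched here as a standalone function (a hypothetical refactoring; the original inlines the if/else chain): nulls and strings become 0.0, and numeric types are widened to Double.

```scala
// Coerce a raw column value from a Hive row into a Double feature:
// null and String -> 0.0, Int/Long widened, Double passed through.
def toFeature(v: Any): Double = v match {
  case null      => 0.0
  case i: Int    => i.toDouble
  case l: Long   => l.toDouble
  case d: Double => d
  case _: String => 0.0
  case _         => 0.0   // any unexpected type also defaults to 0.0
}
```

    Note that mapping both missing and string-valued cells to 0.0 conflates them: the model cannot distinguish a true zero from a missing entry, which is a limitation of this scheme.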

       Scoring the held-out rows that did not participate in training, the correlation between the predicted and the real values is 0.99553818714507836, so the predictions carry real value.

    select  corr(cast(l.cnt_addbook_one as double),cast(p.cnt_addbook_one as double)) from    lkl_card_score.predictcnt_addbook_one20180202  l
    join lkl_card_score.fieldValuePredictModel3 p  on l.order_id=p.order_src
    ;
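
    Hive's corr() is a Pearson correlation; for reference, the same statistic can be computed directly on the paired (actual, predicted) values, sketched here in plain Scala with toy numbers (hypothetical names):

```scala
// Pearson correlation of paired values -- the statistic Hive's corr() returns.
def pearson(xs: Seq[Double], ys: Seq[Double]): Double = {
  val n  = xs.length.toDouble
  val mx = xs.sum / n
  val my = ys.sum / n
  val cov = xs.zip(ys).map { case (x, y) => (x - mx) * (y - my) }.sum
  val sx  = math.sqrt(xs.map(x => math.pow(x - mx, 2)).sum)
  val sy  = math.sqrt(ys.map(y => math.pow(y - my, 2)).sum)
  cov / (sx * sy)
}

val actual    = Seq(1.0, 2.0, 3.0, 4.0)
val predicted = Seq(1.1, 1.9, 3.2, 3.8)
val r = pearson(actual, predicted)   // close to 1.0 when predictions track actuals
```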
    

     

  • Original post: https://www.cnblogs.com/canyangfeixue/p/8397607.html