• Optimizing a Spark Linear Regression Implementation


      import org.apache.log4j.{Level, Logger}
      import org.apache.spark.ml.feature.VectorAssembler
      import org.apache.spark.ml.regression.LinearRegression
      import org.apache.spark.sql.SparkSession

      /**
        * Linear regression example.
        * Created by zhen on 2018/11/12.
        */
      object LinearRegression {
        Logger.getLogger("org").setLevel(Level.WARN) // reduce Spark log noise

        def main(args: Array[String]): Unit = {
          val spark = SparkSession
            .builder()
            .appName("LinearRegression")
            .master("local[2]")
            .getOrCreate()

          // Load the raw data: each line holds eight comma-separated numeric fields
          val train_data = spark.sparkContext.textFile("E:/BDS/newsparkml/src/train.txt")
          val train_map_data = train_data.map { row =>
            val split = row.split(",")
            (split(0).toDouble, split(1).toDouble, split(2).toDouble, split(3).toDouble,
              split(4).toDouble, split(5).toDouble, split(6).toDouble, split(7).toDouble)
          }

          // Turn the tuple RDD into a DataFrame; "Murder" is the label,
          // the remaining seven columns are the features
          val df = spark.createDataFrame(train_map_data)
          val colArray = Array("Population", "Income", "Illiteracy", "LifeExp", "HSGrad", "Frost", "Area")
          val train_df = df.toDF(colArray(0), colArray(1), colArray(2), colArray(3), "Murder",
            colArray(4), colArray(5), colArray(6))

          // Assemble the feature columns into a single vector column
          val assembler = new VectorAssembler()
            .setInputCols(colArray)
            .setOutputCol("features")
          val vectDF = assembler.transform(train_df)

          // Split into training and test sets (80% / 20%)
          val weights = Array(0.8, 0.2)
          val split_data = vectDF.randomSplit(weights)

          // Configure the estimator
          val linearRegression = new LinearRegression()
            .setFeaturesCol("features")
            .setLabelCol("Murder")
            .setFitIntercept(true)
            .setMaxIter(10)
            .setRegParam(0.3)        // regularization strength
            .setElasticNetParam(0.8) // elastic-net mixing (0 = L2, 1 = L1)

          // Train the model on the 80% split
          val lrModel = linearRegression.fit(split_data(0))

          // Inspect the fitted parameters
          // lrModel.extractParamMap()
          println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

          // Training summary
          val trainingSummary = lrModel.summary
          println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}")
          println(s"r2: ${trainingSummary.r2}")

          // Predict on the test split, rounding predictions to one decimal place
          val predictions = lrModel.transform(split_data(1))
          val predict_result = predictions.selectExpr("features", "Murder", "round(prediction,1) as prediction")
          println("Test features ------------------------------ actual -- prediction")
          predict_result.collect().foreach(println)

          spark.stop()
        }
      }

    Result:
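
    The listing above only reports metrics from the training summary (the objective history and r2 on the 80% training split); the 20% held-out split is used only to print predictions. Below is a minimal sketch of also scoring that held-out split, reusing lrModel and split_data from the code above together with Spark ML's RegressionEvaluator; "rmse" and "r2" are built-in metric names of that evaluator.

      import org.apache.spark.ml.evaluation.RegressionEvaluator

      // Predictions on the held-out split produced by randomSplit above
      val testPredictions = lrModel.transform(split_data(1))

      // Root-mean-square error on the test rows
      val testRmse = new RegressionEvaluator()
        .setLabelCol("Murder")
        .setPredictionCol("prediction")
        .setMetricName("rmse")
        .evaluate(testPredictions)

      // R2 on the test rows (same evaluator, different metric)
      val testR2 = new RegressionEvaluator()
        .setLabelCol("Murder")
        .setPredictionCol("prediction")
        .setMetricName("r2")
        .evaluate(testPredictions)

      println(s"Test RMSE: $testRmse, Test r2: $testR2")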

  • Original article: https://www.cnblogs.com/yszd/p/9952268.html