import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

/**
 * Logistic regression
 * Created by zhen on 2018/11/20.
 */
object LogisticRegressionDemo { // renamed to avoid confusion with the imported LogisticRegression estimator
  Logger.getLogger("org").setLevel(Level.WARN) // set the log level
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("LogisticRegression")
      .master("local[2]")
      .getOrCreate()
    // Load the training and test data
    val data = spark.createDataFrame(Seq(
      (1.0, Vectors.dense(0.0, 1.1, 0.1)),
      (0.0, Vectors.dense(2.0, 1.0, -1.1)),
      (1.0, Vectors.dense(1.0, 2.1, 0.1)),
      (0.0, Vectors.dense(2.0, -1.3, 1.1)),
      (0.0, Vectors.dense(2.0, 1.0, -1.1)),
      (1.0, Vectors.dense(1.0, 2.1, 0.1)),
      (1.0, Vectors.dense(2.0, 1.3, 1.1)),
      (0.0, Vectors.dense(-2.0, 1.0, -1.1)),
      (1.0, Vectors.dense(1.0, 2.1, 0.1)),
      (0.0, Vectors.dense(2.0, -1.3, 1.1)),
      (1.0, Vectors.dense(2.0, 1.0, -1.1)),
      (1.0, Vectors.dense(1.0, 2.1, 0.1)),
      (0.0, Vectors.dense(-2.0, 1.3, 1.1)),
      (1.0, Vectors.dense(0.0, 1.2, -0.4))
    )).toDF("label", "features")
    val weights = Array(0.8, 0.2) // ratio of training data to test data
    val splitData = data.randomSplit(weights) // split into training and test sets
    // Create the logistic regression estimator
    val lr = new LogisticRegression()
    // Set hyperparameters
    lr.setMaxIter(10).setRegParam(0.01)
    // Train the model on the training split
    val model = lr.fit(splitData(0))
    // Score the test split and print each prediction
    model.transform(splitData(1))
      .select("label", "features", "probability", "prediction")
      .collect()
      .foreach(println)
    // Stop Spark
    spark.stop()
  }
}
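The listing only prints the raw prediction rows and never quantifies model quality. As a possible extension (not part of the original post), Spark's BinaryClassificationEvaluator can score the held-out split; a minimal sketch, assuming `model` and `splitData` from the listing above are still in scope before `spark.stop()`:

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

// Hypothetical extension, not in the original post.
// Assumes `model` and `splitData` from the listing above are in scope.
val predictions = model.transform(splitData(1))
val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol("label")                 // ground-truth label column
  .setRawPredictionCol("rawPrediction") // column emitted by LogisticRegressionModel.transform
  .setMetricName("areaUnderROC")        // the evaluator's default metric, stated explicitly
println(s"Test AUC = ${evaluator.evaluate(predictions)}")

With only a handful of test rows out of 14 samples, the AUC will be noisy from run to run; passing a seed to the split (e.g. data.randomSplit(weights, 11L)) would make the result reproducible.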
Result: