• Spark MLlib 之线性回归


    /**
     * Trains three SGD-based linear models (plain linear regression, ridge, lasso) on the
     * LPSA sample data, prints each model's training MSE, then prints each model's
     * prediction for one hand-written feature vector.
     *
     * @param args optional; {@code args[0]} overrides the input file path
     *             (defaults to {@code /home/yurnom/lpsa.txt})
     */
    public static void main(String[] args) {
        String inputPath = args.length > 0 ? args[0] : "/home/yurnom/lpsa.txt";
        SparkConf sparkConf = new SparkConf()
              .setAppName("Regression")
              .setMaster("local[2]");
        // JavaSparkContext is Closeable: try-with-resources stops the context even if
        // parsing or training throws, instead of leaking it.
        try (JavaSparkContext sc = new JavaSparkContext(sparkConf)) {
            JavaRDD<String> data = sc.textFile(inputPath);
            // Cached because the RDD is scanned by three training runs and three MSE passes.
            JavaRDD<LabeledPoint> parsedData = data
                  .filter(line -> !line.trim().isEmpty()) // skip blank/trailing lines
                  .map(line -> parseLabeledPoint(line))
                  .cache();

            int numIterations = 100; // number of SGD iterations
            LinearRegressionModel model = LinearRegressionWithSGD.train(parsedData.rdd(), numIterations);
            RidgeRegressionModel model1 = RidgeRegressionWithSGD.train(parsedData.rdd(), numIterations);
            LassoModel model2 = LassoWithSGD.train(parsedData.rdd(), numIterations);

            print(parsedData, model);
            print(parsedData, model1);
            print(parsedData, model2);

            // Predict a single new sample; it must have the same number of features (8)
            // as the training data.
            double[] d = new double[]{1.0, 1.0, 2.0, 1.0, 3.0, -1.0, 1.0, -2.0};
            Vector v = Vectors.dense(d);
            System.out.println("Prediction of linear: " + model.predict(v));
            System.out.println("Prediction of ridge: " + model1.predict(v));
            System.out.println("Prediction of lasso: " + model2.predict(v));
        }
    }

    /**
     * Parses one input line of the form {@code label,f1 f2 ... fn} into a LabeledPoint.
     * Features may be separated by any run of whitespace; surrounding whitespace is ignored.
     */
    private static LabeledPoint parseLabeledPoint(String line) {
        String[] parts = line.split(",");
        double[] features = Arrays.stream(parts[1].trim().split("\\s+"))
              .mapToDouble(Double::parseDouble)
              .toArray();
        return new LabeledPoint(Double.parseDouble(parts[0].trim()), Vectors.dense(features));
    }
     
    /**
     * Evaluates {@code model} on the data it was trained on and prints the training
     * mean squared error, prefixed with the model's simple class name.
     *
     * @param parsedData labeled points to score (here: the training set)
     * @param model      a trained generalized linear model
     */
    public static void print(JavaRDD<LabeledPoint> parsedData, GeneralizedLinearModel model) {
        // Pair each actual label with the model's prediction for that point's features.
        JavaPairRDD<Double, Double> valuesAndPreds = parsedData.mapToPair(point -> {
            double prediction = model.predict(point.features());
            return new Tuple2<>(point.label(), prediction);
        });

        // Mean of the squared differences between actual and predicted values.
        double mse = valuesAndPreds.mapToDouble((Tuple2<Double, Double> t) -> {
            double error = t._1() - t._2();
            return error * error;
        }).mean();
        System.out.println(model.getClass().getSimpleName() + " training Mean Squared Error = " + mse);
    }
    
     
     
    
    
    运行结果
    
     
    
    LinearRegressionModel training Mean Squared Error = 6.206807793307759
    RidgeRegressionModel training Mean Squared Error = 6.416002077543526
    LassoModel training Mean Squared Error = 6.972349839013683
    Prediction of linear: 0.805390219777772
    Prediction of ridge: 1.0907608111865237
    Prediction of lasso: 0.18652645118913225

    测试数据（lpsa.txt 片段，每行格式为：标签,特征1 特征2 … 特征8，特征之间以空格分隔）:

    -0.4307829,-1.63735562648104 -2.00621178480549 -1.86242597251066 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
    -0.1625189,-1.98898046126935 -0.722008756122123 -0.787896192088153 -1.02470580167082 -0.522940888712441 -0.863171185425945 -1.04215728919298 -0.864466507337306
    -0.1625189,-1.57881887548545 -2.1887840293994 1.36116336875686 -1.02470580167082 -0.522940888712441 -0.863171185425945 0.342627053981254 -0.155348103855541

    参考:
    http://blog.selfup.cn/747.html

  • 相关阅读:
    php中向mysql插入数据
    W3Cschool菜鸟教程离线版下载链接
    Call to undefined function mysqli_connect()
    Windows下MySQL 5.6安装及配置详细图解
    请不要再责怪你的程序员“太慢”
    工欲善其事必先利其器
    PHP正则表达式
    matlab画柱状图
    matlab 把数组中的NaN去除掉
    建模2017A题 角度lingo代码
  • 原文地址:https://www.cnblogs.com/rigid/p/5564455.html
Copyright © 2020-2023  润新知