• Daily problem, for work, 2020-04-29, Problem 58


    // Linear regression prediction code, Java (Spark MLlib) version

    package com.swust.machine;
    
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.api.java.function.Function2;
    import org.apache.spark.mllib.linalg.Vector;
    import org.apache.spark.mllib.linalg.Vectors;
    import org.apache.spark.mllib.regression.LabeledPoint;
    import org.apache.spark.mllib.regression.LinearRegressionModel;
    import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
    import org.apache.spark.rdd.RDD;
    import scala.Tuple2;
    
    import java.util.List;
    
    
    /**
     *
     * @author 雪瞳
     * @Slogan The clock keeps moving forward; how could one stop here!
     * @Function Linear regression algorithm implementation
     *
     */
    public class LinearRegression {
        public static void main(String[] args) {
            SparkConf conf = new SparkConf();
            conf.setMaster("local").setAppName("line");
            JavaSparkContext jsc = new JavaSparkContext(conf);
            jsc.setLogLevel("ERROR");

            // Read the sample data; each line is "label,feature1 feature2 ..."
            JavaRDD<String> data = jsc.textFile("./data/lpsa.data");
    
            JavaRDD<LabeledPoint> examples = data.map(new Function<String, LabeledPoint>() {
                @Override
                public LabeledPoint call(String line) throws Exception {
                    String[] splits = line.split(",");
                    String y = splits[0];
                    String xs = splits[1];
                    String[] words = xs.split(" ");
                    double[] wd = new double[words.length];
                    for (int i = 0; i < words.length; i++) {
                        wd[i] = Double.parseDouble(words[i]);
                    }
                    return new LabeledPoint(Double.parseDouble(y),
                            Vectors.dense(wd));
                }
            });
            // Split the dataset into training and test sets (80% / 20%)
            double[] weights = new double[]{0.8, 0.2};
            RDD<LabeledPoint> rdd = examples.rdd();
            RDD<LabeledPoint>[] TestData = rdd.randomSplit(weights, 1L);
    
            // Number of iterations
            int numIterations = 100;
            // Step size of the gradient descent updates during training
            // typical values: 0.1, 0.2, 0.3, 0.4
            double stepSize = 1.0;
            double miniBatchFraction = 1.0;
            LinearRegressionWithSGD lrs = new LinearRegressionWithSGD();
            // Whether the model should fit an intercept
            lrs.setIntercept(true);
            // Set the step size
            lrs.optimizer().setStepSize(stepSize);
            // Set the number of iterations
            lrs.optimizer().setNumIterations(numIterations);
            // Fraction of samples used per iteration; 1.0 means all samples (the default)
            lrs.optimizer().setMiniBatchFraction(miniBatchFraction);
            // run() is inherited from GeneralizedLinearAlgorithm; train on the training split
            LinearRegressionModel model = lrs.run(TestData[0]);
            System.err.println(model.weights());
            System.err.println(model.intercept());
    
            // Predict on the test split
            JavaRDD<Double> prediction = model.predict(TestData[1].toJavaRDD().map(new Function<LabeledPoint, Vector>() {
                @Override
                public Vector call(LabeledPoint labeledPoint) throws Exception {
                    return labeledPoint.features();
                }
            }));
            // Zip the predictions with the true labels
            JavaPairRDD<Double, Double> predictionAndLabel = prediction.zip(TestData[1].toJavaRDD().map(new Function<LabeledPoint, Double>() {
                @Override
                public Double call(LabeledPoint labeledPoint) throws Exception {
                    return labeledPoint.label();
                }
            }));
            // Take 20 (prediction, label) pairs for inspection
            List<Tuple2<Double, Double>> take = predictionAndLabel.take(20);
            // prediction vs. label
            System.err.println("prediction" + "\t" + "label");
            for (Tuple2<Double, Double> elem : take) {
                System.out.println(elem._1() + "\t" + elem._2());
            }
            // Compute the mean absolute error over the test set
            JavaRDD<Double> dataLoss = predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
                @Override
                public Double call(Tuple2<Double, Double> one) throws Exception {
                    double err = one._1() - one._2();
                    return Math.abs(err);
                }
            });
            Double lossResult = dataLoss.reduce(new Function2<Double, Double, Double>() {
                @Override
                public Double call(Double aDouble, Double aDouble2) throws Exception {
                    return aDouble + aDouble2;
                }
            });
            // this is the mean absolute error (MAE), not an RMSE
            double err = lossResult / TestData[1].count();
            System.err.println("Test MAE = " + err);
            jsc.stop();
    
    
        }
    }
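
Note that the error printed above is a mean absolute error, not an RMSE. If a root-mean-square error is wanted as well, a minimal sketch could look like the following; it assumes the predictionAndLabel pair RDD and the TestData split from the listing above are in scope, and any other names are illustrative only:

    // Sketch: root-mean-square error over the same (prediction, label) pairs
    JavaRDD<Double> squaredErrors = predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
        @Override
        public Double call(Tuple2<Double, Double> one) throws Exception {
            double diff = one._1() - one._2();
            return diff * diff; // squared error for one sample
        }
    });
    double sumSquared = squaredErrors.reduce(new Function2<Double, Double, Double>() {
        @Override
        public Double call(Double a, Double b) throws Exception {
            return a + b;
        }
    });
    double rmse = Math.sqrt(sumSquared / TestData[1].count());
    System.err.println("Test RMSE = " + rmse);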
    

      

    // Since the dataset itself has only 100 records, the prediction quality is relatively poor, but what matters here is the approach.
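
One natural follow-up to this workflow is persisting the trained model so it does not have to be retrained on every run. A minimal sketch, assuming the model and jsc from the listing above are still in scope; the output path "./model/linearRegression" is a made-up example, not something from the original post:

    // Sketch: save the trained model and load it back (this would go before jsc.stop())
    model.save(jsc.sc(), "./model/linearRegression");
    LinearRegressionModel sameModel =
            LinearRegressionModel.load(jsc.sc(), "./model/linearRegression");
    System.err.println(sameModel.weights());
    System.err.println(sameModel.intercept());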

    // With the Way but without technique, technique can still be sought; with technique but without the Way, one stops at technique. Learning the underlying idea matters more.
