• 每日一题 为了工作 2020 0504 第六十二题


    package data.bjsj.fjjb;
    
    
    import org.apache.spark.Accumulator;
    import org.apache.spark.SparkContext;
    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.api.java.function.VoidFunction;
    import org.apache.spark.mllib.classification.LogisticRegressionModel;
    import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
    import org.apache.spark.mllib.linalg.Vectors;
    import org.apache.spark.mllib.regression.LabeledPoint;
    
    import org.apache.spark.rdd.RDD;
    import org.apache.spark.sql.SparkSession;
    import scala.Tuple2;
    
    
    /**
     * Trains and evaluates a binary logistic-regression model with Spark MLlib.
     *
     * <p>The program reads tab-separated records from {@code ./save/rootData},
     * converts each one to a {@link LabeledPoint}, splits the data 70% / 30% into a
     * training and a test set, trains a {@link LogisticRegressionWithLBFGS} model on
     * the training set, reports the accuracy on the test set to {@code System.err},
     * and saves the model to {@code ./save/model} when the accuracy reaches the
     * required percentage.
     *
     * <p>Author's motto: "Even the clock keeps moving forward; how could a person
     * simply stop here!"
     *
     * @author 雪瞳
     */
    public class LogisticModel {

        /**
         * Program entry point; runs the whole train / evaluate / save pipeline.
         *
         * @param args command-line arguments (unused)
         */
        public static void main(String[] args) {

            SparkSession session = SparkSession.builder().appName("logistic").master("local").getOrCreate();
            try {
                JavaSparkContext jsc = JavaSparkContext.fromSparkContext(session.sparkContext());
                SparkContext sc = JavaSparkContext.toSparkContext(jsc);

                // Use the documented upper-case level name rather than relying on
                // case-insensitive handling inside Spark.
                jsc.setLogLevel("ERROR");
                JavaRDD<String> fileRDD = jsc.textFile("./save/rootData");
                JavaRDD<LabeledPoint> labeledPointJavaRDD = fileRDD.map(new Function<String, LabeledPoint>() {
                    @Override
                    public LabeledPoint call(String line) throws Exception {
                        return parseLine(line);
                    }
                });

                // 70% training / 30% test; the fixed seed keeps the split reproducible.
                double[] doubles = new double[]{0.7, 0.3};
                RDD<LabeledPoint> rdd = labeledPointJavaRDD.rdd();
                RDD<LabeledPoint>[] metaDataSource = rdd.randomSplit(doubles, 100L);

                RDD<LabeledPoint> traingData = metaDataSource[0];
                RDD<LabeledPoint> testData = metaDataSource[1];

                LogisticRegressionWithLBFGS lr = new LogisticRegressionWithLBFGS();
                lr.setNumClasses(2);
                lr.setIntercept(true);
                LogisticRegressionModel model = lr.run(traingData);

                // Build (prediction, label) pairs in a single pass over the test set.
                // This replaces zipping two separately mapped RDDs.
                JavaRDD<Tuple2<Double, Double>> predictionAndLabel = testData.toJavaRDD().map(
                        new Function<LabeledPoint, Tuple2<Double, Double>>() {
                            @Override
                            public Tuple2<Double, Double> call(LabeledPoint labeledPoint) throws Exception {
                                double predict = model.predict(labeledPoint.features());
                                return new Tuple2<Double, Double>(predict, labeledPoint.label());
                            }
                        });
                // The pairs are consumed by two actions below; cache them so the
                // predictions are not recomputed.
                predictionAndLabel.cache();

                long count = predictionAndLabel.count();
                // Count matches with a filter instead of the deprecated Accumulator.
                long value = predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
                    @Override
                    public Boolean call(Tuple2<Double, Double> tp) throws Exception {
                        Double label = tp._2();
                        Double predict = tp._1();
                        return Double.compare(label, predict) == 0;
                    }
                }).count();

                System.err.println("总数目是:"+count);
                System.err.println("正确数目是:"+value);
                // Guard against an empty test split, which would otherwise give NaN.
                double rate = count == 0 ? 0.0 : value / (double) count;
                System.err.println("正确率是:"+rate*100+"%");

                String path ="./save/model";
                // Minimum accuracy, expressed as a percentage, required to keep the model.
                double stand = 80.00;
                // Compare percentage with percentage. The original compared the
                // fraction in [0, 1] with 80 using "<", which was always true.
                // NOTE(review): model.save presumably fails if the path already
                // exists; confirm and clear the directory between runs if needed.
                if (count > 0 && Double.compare(rate * 100, stand) >= 0) {
                    model.save(sc, path);
                }
            } finally {
                // Release the local Spark resources even when the pipeline fails.
                session.stop();
            }
        }

        /**
         * Converts one tab-separated input record into a labeled point.
         *
         * <p>Field 0 is skipped (a timestamp in the sample record below), the next
         * {@code fields - 3} fields are parsed as features, the second-to-last field
         * is skipped (empty in the sample record), and the last field is the label.
         * Sample record:
         * {@code "2015-11-01 20:20:16"	1.85999330468501	1.22359452534749	2.51578969727773	-0.403918740333512	0.0149184125297424		0}
         *
         * @param line one line of the input file
         * @return the labeled point built from the line's label and features
         * @throws IllegalArgumentException if the line has fewer than three fields
         * @throws NumberFormatException if a feature or the label is not a number
         */
        private static LabeledPoint parseLine(String line) {
            String[] splits = line.split("\t");
            if (splits.length < 3) {
                throw new IllegalArgumentException(
                        "Malformed record, expected at least 3 tab-separated fields: " + line);
            }
            String label = splits[splits.length - 1];

            double[] wd = new double[splits.length - 3];
            for (int i = 0; i < wd.length; i++) {
                // Offset by one to skip the leading timestamp field.
                wd[i] = Double.parseDouble(splits[i + 1]);
            }
            return new LabeledPoint(Double.parseDouble(label), Vectors.dense(wd));
        }
    }
    

      

  • 相关阅读:
    可序列化serializable的作用是什么
    HelloWorld编译正常运行报noclassdeffounderror
    dtd对xml没有起到约束作用
    Ajax发送XML请求案例
    Ajax发送GET和POST请求案例
    Ajax发送简单请求案例
    初识Ajax
    数据库设计
    数据库和实例的区别
    Flask
  • 原文地址:https://www.cnblogs.com/walxt/p/12825682.html
Copyright © 2020-2023  润新知