• 每日一题 为了工作 2020 0507 第六十五题


    package data.bysj.tree;
    
    import org.apache.spark.Accumulator;
    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.api.java.function.VoidFunction;
    import org.apache.spark.mllib.linalg.Vectors;
    import org.apache.spark.mllib.regression.LabeledPoint;
    import org.apache.spark.mllib.tree.RandomForest;
    import org.apache.spark.mllib.tree.model.RandomForestModel;
    import org.apache.spark.sql.SparkSession;
    import scala.Tuple2;
    
    import java.util.HashMap;
    import java.util.Map;
    
    /**
     * Trains a Spark MLlib random-forest binary classifier on tab-separated
     * data read from {@code ./save/rootData}, evaluates accuracy on a held-out
     * 30% split, and dumps the tree structure when accuracy falls below 80%.
     *
     * @author 雪瞳
     * @Slogan 时钟尚且前行,人怎能就此止步!
     * @Function train and evaluate a RandomForest model
     */
    public class RandomForestTrees {
        public static void main(String[] args) {
            String name = "forest";
            String master = "local[3]";
            SparkSession session = SparkSession.builder().master(master).appName(name).getOrCreate();
            JavaSparkContext jsc = JavaSparkContext.fromSparkContext(session.sparkContext());
            // Canonical Log4j level name; the original "Error" only worked via
            // case-insensitive parsing.
            jsc.setLogLevel("ERROR");
            JavaRDD<String> input = jsc.textFile("./save/rootData");

            // Expected line layout (note the DOUBLE tab before the label, which
            // split produces as one empty field — hence "length - 3" features):
            // "2015-11-01 20:20:16"\t f1 \t f2 \t f3 \t f4 \t f5 \t\t label
            JavaRDD<LabeledPoint> metaData = input.map((Function<String, LabeledPoint>) line -> {
                String[] splits = line.split("\t");
                // Last field is the class label (0 or 1).
                double label = Double.parseDouble(splits[splits.length - 1]);
                // Skip the timestamp at index 0; the empty field and the label
                // account for the other two excluded columns.
                double[] features = new double[splits.length - 3];
                for (int i = 0; i < features.length; i++) {
                    features[i] = Double.parseDouble(splits[i + 1]);
                }
                return new LabeledPoint(label, Vectors.dense(features));
            });

            // 70/30 train/test split with a fixed seed for reproducibility.
            JavaRDD<LabeledPoint>[] metaDataSource = metaData.randomSplit(new double[]{0.7, 0.3}, 10L);
            JavaRDD<LabeledPoint> trainingData = metaDataSource[0];
            JavaRDD<LabeledPoint> testData = metaDataSource[1];

            int numClasses = 2;                          // binary classification
            // Empty map: every feature is treated as continuous.
            Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
            int numTrees = 3;
            String featureSubsetStrategy = "auto";       // let MLlib choose per-split feature count
            String impurity = "entropy";
            int maxDepth = 4;
            int maxBins = 32;
            int seed = 1;
            RandomForestModel model = RandomForest.trainClassifier(
                    trainingData,
                    numClasses,
                    categoricalFeaturesInfo,
                    numTrees,
                    featureSubsetStrategy,
                    impurity,
                    maxDepth,
                    maxBins,
                    seed);

            // Pair each prediction (left) with its true label (right).
            JavaRDD<Double> predictRdd = testData.map(
                    (Function<LabeledPoint, Double>) lp -> model.predict(lp.features()));
            JavaPairRDD<Double, Double> resultRDD = predictRdd.zip(
                    testData.map((Function<LabeledPoint, Double>) LabeledPoint::label));

            long count = resultRDD.count();
            // Count matches via filter/count instead of the deprecated
            // Accumulator API; this is also correct if the RDD is recomputed.
            long correct = resultRDD.filter(
                    (Function<Tuple2<Double, Double>, Boolean>) tp ->
                            Double.compare(tp._2(), tp._1()) == 0).count();

            System.err.println("数目是:" + count);
            // Distinct label for the second quantity — the original printed the
            // same "数目是:" prefix for both, making the output ambiguous.
            System.err.println("正确数目是:" + correct);
            // Guard against an empty test split (the original produced NaN).
            double ratePercent = count == 0 ? 0.0 : correct * 100.0 / count;
            System.err.println("正确率是:" + ratePercent + "%");

            String path = "./save/model";
            double stand = 80.00; // threshold, in percent
            // BUG FIX: the original compared the 0..1 fraction against 80, so
            // this branch always executed. Compare on the percent scale.
            if (Double.compare(ratePercent, stand) < 0) {
    //            model.save(jsc.sc(), path);
                System.out.println(model.toDebugString());
            }

            // Release Spark resources (the original leaked the session).
            session.stop();
        }
    }
    

      

     
  • 相关阅读:
    51nod 1463 找朋友 (扫描线+线段树)
    51nod 1295 XOR key (可持久化Trie树)
    51nod 1494 选举拉票 (线段树+扫描线)
    51Nod 1199 Money out of Thin Air (树链剖分+线段树)
    51Nod 1287 加农炮 (线段树)
    51Nod 1175 区间中第K大的数 (可持久化线段树+离散)
    Codeforces Round #426 (Div. 1) B The Bakery (线段树+dp)
    前端基础了解
    git 教程
    HIVE 默认分隔符 以及linux系统中特殊字符的输入和查看方式
  • 原文地址:https://www.cnblogs.com/walxt/p/12843902.html
Copyright © 2020-2023  润新知