• 关于spark的mllib学习总结(Java版)


    本篇博客主要讲述如何利用spark的mliib构建机器学习模型并预测新的数据,具体的流程如下图所示:

     

    加载数据 对于数据的加载或保存,mllib提供了MLUtils包,其作用是Helper methods to load,save and pre-process data used in MLLib.博客中的数据是采用spark中提供的数据sample_libsvm_data.txt,其有一百个数据样本,658个特征。具体的数据形式如图所示: 

    加载libsvm 

    JavaRDD<LabeledPoint> lpdata = MLUtils.loadLibSVMFile(sc, this.libsvmFile).toJavaRDD();

    LabeledPoint数据类型是对应与libsvmfile格式文件, 具体格式为: Lable(double类型),vector(Vector类型) 转化dataFrame数据类型 

    JavaRDD<Row> jrow = lpdata.map(new LabeledPointToRow());
    StructType schema = new StructType(new StructField[]{
                        new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                        new StructField("features", new VectorUDT(), false, Metadata.empty()),
            });
    SQLContext jsql = new SQLContext(sc);
    DataFrame df = jsql.createDataFrame(jrow, schema);

    DataFrame:DataFrame是一个以命名列方式组织的分布式数据集。在概念上,它跟关系型数据库中的一张表或者1个Python(或者R)中的data frame一样,但是比他们更优化。DataFrame可以根据结构化的数据文件、hive表、外部数据库或者已经存在的RDD构造。 SQLContext:spark sql所有功能的入口是SQLContext类,或者SQLContext的子类。为了创建一个基本的SQLContext,需要一个SparkContext。 特征提取 特征归一化处理 

    StandardScaler scaler = new StandardScaler().setInputCol("features").setOutputCol("normFeatures").setWithStd(true);
    DataFrame scalerDF = scaler.fit(df).transform(df);
    scaler.save(this.scalerModelPath);

    利用卡方统计做特征提取 

    ChiSqSelector selector = new ChiSqSelector().setNumTopFeatures(500).setFeaturesCol("normFeatures").setLabelCol("label").setOutputCol("selectedFeatures");
    ChiSqSelectorModel chiModel = selector.fit(scalerDF);
    DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");
    chiModel.save(this.featureSelectedModelPath);

    训练机器学习模型(以SVM为例)

    //转化为LabeledPoint数据类型, 训练模型
    JavaRDD<Row> selectedrows = selectedDF.javaRDD();
    JavaRDD<LabeledPoint> trainset = selectedrows.map(new RowToLabel());
    
    //训练SVM模型, 并保存
    int numIteration = 200;
    SVMModel model = SVMWithSGD.train(trainset.rdd(), numIteration);
    model.clearThreshold();
    model.save(sc, this.mlModelPath);
    
    // LabeledPoint数据类型转化为Row
    static class LabeledPointToRow implements Function<LabeledPoint, Row> {
    
            public Row call(LabeledPoint p) throws Exception {
                double label = p.label();
                Vector vector = p.features();
                return RowFactory.create(label, vector);
            }
        }
    
    //Rows数据类型转化为LabeledPoint
    static class RowToLabel implements Function<Row, LabeledPoint> {
    
            public LabeledPoint call(Row r) throws Exception {
                Vector features = r.getAs(1);
                double label = r.getDouble(0);
                return new LabeledPoint(label, features);
            }
        }

    测试新的样本 测试新的样本前,需要将样本做数据的转化和特征提取的工作,所有刚刚训练模型的过程中,除了保存机器学习模型,还需要保存特征提取的中间模型。具体代码如下:

    //初始化spark
    SparkConf conf = new SparkConf().setAppName("SVM").setMaster("local");
    conf.set("spark.testing.memory", "2147480000");
    SparkContext sc = new SparkContext(conf);
    
    //加载测试数据
    JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, this.predictDataPath).toJavaRDD();
    
    //转化DataFrame数据类型
    JavaRDD<Row> jrow =testData.map(new LabeledPointToRow());
            StructType schema = new StructType(new StructField[]{
                        new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                        new StructField("features", new VectorUDT(), false, Metadata.empty()),
            });
    SQLContext jsql = new SQLContext(sc);
    DataFrame df = jsql.createDataFrame(jrow, schema);
    
            //数据规范化
    StandardScaler scaler = StandardScaler.load(this.scalerModelPath);
    DataFrame scalerDF = scaler.fit(df).transform(df);
    
            //特征选取
    ChiSqSelectorModel chiModel = ChiSqSelectorModel.load( this.featureSelectedModelPath);
    DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");

    测试数据集

    SVMModel svmmodel = SVMModel.load(sc, this.mlModelPath);
    JavaRDD<Tuple2<Double, Double>> predictResult = testset.map(new Prediction(svmmodel)) ;
    predictResult.collect();
    
    static class Prediction implements Function<LabeledPoint, Tuple2<Double , Double>> {
            SVMModel model;
            public Prediction(SVMModel model){
                this.model = model;
            }
            public Tuple2<Double, Double> call(LabeledPoint p) throws Exception {
                Double score = model.predict(p.features());
                return new Tuple2<Double , Double>(score, p.label());
            }
        }

    计算准确率

    double accuracy = predictResult.filter(new PredictAndScore()).count() * 1.0 / predictResult.count();
    System.out.println(accuracy);
    
    static class PredictAndScore implements Function<Tuple2<Double, Double>, Boolean> {
            public Boolean call(Tuple2<Double, Double> t) throws Exception {
                double score = t._1();
                double label = t._2();
                System.out.print("score:" + score + ", label:"+ label);
                if(score >= 0.0 && label >= 0.0) return true;
                else if(score < 0.0 && label < 0.0) return true;
                else return false;
            }
        }
  • 相关阅读:
    node-webkit 不支持html5_video播放mp4的解决方法
    node-webkit(Windows系统) 打包成exe文件后,被360杀毒软件误报木马的解决方法
    剑指 Offer 36. 二叉搜索树与双向链表
    剑指 Offer 33. 二叉搜索树的后序遍历序列
    Leetcode96. 不同的二叉搜索树
    Leetcode95. 不同的二叉搜索树 II
    leetcode1493. 删掉一个元素以后全为 1 的最长子数组
    Leetcode1052. 爱生气的书店老板
    Leetcode92. 反转链表 II
    Leetcode495. 提莫攻击
  • 原文地址:https://www.cnblogs.com/itboys/p/9692594.html
Copyright © 2020-2023  润新知