• Spark学习笔记——手写数字识别


    import org.apache.spark.ml.classification.RandomForestClassifier
    import org.apache.spark.ml.regression.RandomForestRegressor
    import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, NaiveBayes, SVMWithSGD}
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.optimization.L1Updater
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.{DecisionTree, RandomForest}
    import org.apache.spark.mllib.tree.configuration.Algo
    import org.apache.spark.mllib.tree.impurity.Entropy
    
    /**
      * Created by common on 17-5-17.
      */
    
    /**
      * A picture paired with an integer label.
      *
      * @param label integer tag attached to the picture
      * @param pic   the picture's values as doubles (presumably pixel intensities —
      *              TODO confirm); defaults to an empty list
      */
    case class LabeledPic(label: Int, pic: List[Double] = Nil)
    
    /**
      * Entry point for handwritten-digit recognition experiments with Spark MLlib.
      *
      * The active code loads a header-less CSV, splits each line on commas and turns
      * it into a `LabeledPoint` whose label is the first column and whose features
      * are the remaining columns. The individual model experiments (naive Bayes,
      * logistic regression, decision tree, random forest) are kept below as
      * commented-out sections; re-enable one at a time to run it.
      */
    object DigitRecognizer {

      // Default location of the training CSV with its header row removed,
      // e.g. produced by: sed 1d train.csv > train_noheader.csv
      private val DefaultTrainFile =
        "file:///media/common/工作/kaggle/DigitRecognizer/train_noheader.csv"

      /**
        * Builds the training RDD and (when a section below is re-enabled) trains and
        * evaluates a model.
        *
        * @param args optional command-line arguments; `args(0)`, when present,
        *             overrides the default training CSV location
        */
      def main(args: Array[String]): Unit = {

        val conf = new SparkConf().setAppName("DigitRecgonizer").setMaster("local")
        val sc = new SparkContext(conf)

        // The context is always stopped in `finally`, even if a stage fails.
        try {
          val trainFile = args.headOption.getOrElse(DefaultTrainFile)
          val trainRawData = sc.textFile(trainFile)
          // Split every line on commas, giving an RDD of string arrays.
          val trainRecords = trainRawData.map(line => line.split(","))

          // Column 0 is the label; every remaining column is a feature value.
          val trainData = trainRecords.map { r =>
            val label = r(0).toInt
            val features = r.slice(1, r.size).map(d => d.toDouble)
            LabeledPoint(label, Vectors.dense(features))
          }


          //    // Naive Bayes model
          //    val nbModel = NaiveBayes.train(trainData)
          //
          //    // NOTE: accuracy here is measured on the training data itself.
          //    val nbTotalCorrect = trainData.map { point =>
          //      if (nbModel.predict(point.features) == point.label) 1 else 0
          //    }.sum
          //    val nbAccuracy = nbTotalCorrect / trainData.count
          //
          //    println("贝叶斯模型正确率:" + nbAccuracy)
          //
          //    // Predict on the test data
          //    val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
          //    // Split every line on commas, giving an RDD of string arrays
          //    val testRecords = testRawData.map(line => line.split(","))
          //
          //    val testData = testRecords.map { r =>
          //      val features = r.map(d => d.toDouble)
          //      Vectors.dense(features)
          //    }
          //    val predictions = nbModel.predict(testData).map(p => p.toInt)
          //    // Save the predictions
          //    predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict")


          //    // Logistic regression model (multiclass, 10 classes)
          //    val lrModel = new LogisticRegressionWithLBFGS()
          //      .setNumClasses(10)
          //      .run(trainData)
          //
          //    val lrTotalCorrect = trainData.map { point =>
          //      if (lrModel.predict(point.features) == point.label) 1 else 0
          //    }.sum
          //    val lrAccuracy = lrTotalCorrect / trainData.count
          //
          //    println("线性回归模型正确率:" + lrAccuracy)
          //
          //    // Predict on the test data
          //    val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
          //    // Split every line on commas, giving an RDD of string arrays
          //    val testRecords = testRawData.map(line => line.split(","))
          //
          //    val testData = testRecords.map { r =>
          //      val features = r.map(d => d.toDouble)
          //      Vectors.dense(features)
          //    }
          //    val predictions = lrModel.predict(testData).map(p => p.toInt)
          //    // Save the predictions
          //    predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict1")


          //    // Decision tree model
          //    val maxTreeDepth = 10
          //    val numClass = 10
          //    val dtModel = DecisionTree.train(trainData, Algo.Classification, Entropy, maxTreeDepth, numClass)
          //
          //    val dtTotalCorrect = trainData.map { point =>
          //      if (dtModel.predict(point.features) == point.label) 1 else 0
          //    }.sum
          //    val dtAccuracy = dtTotalCorrect / trainData.count
          //
          //    println("决策树模型正确率:" + dtAccuracy)
          //
          //    // Predict on the test data
          //    val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
          //    // Split every line on commas, giving an RDD of string arrays
          //    val testRecords = testRawData.map(line => line.split(","))
          //
          //    val testData = testRecords.map { r =>
          //      val features = r.map(d => d.toDouble)
          //      Vectors.dense(features)
          //    }
          //    val predictions = dtModel.predict(testData).map(p => p.toInt)
          //    // Save the predictions
          //    predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict2")


          //    // Random forest model
          //    // NOTE(review): the other experiments use 10 classes; 30 looks larger
          //    // than needed here — confirm before re-enabling.
          //    val numClasses = 30
          //    val categoricalFeaturesInfo = Map[Int, Int]()
          //    val numTrees = 50
          //    val featureSubsetStrategy = "auto"
          //    val impurity = "gini"
          //    val maxDepth = 10
          //    val maxBins = 32
          //    val rtModel = RandomForest.trainClassifier(trainData, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
          //
          //    val rtTotalCorrect = trainData.map { point =>
          //      if (rtModel.predict(point.features) == point.label) 1 else 0
          //    }.sum
          //    val rtAccuracy = rtTotalCorrect / trainData.count
          //
          //    println("随机森林模型正确率:" + rtAccuracy)
          //
          //    // Predict on the test data
          //    val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
          //    // Split every line on commas, giving an RDD of string arrays
          //    val testRecords = testRawData.map(line => line.split(","))
          //
          //    val testData = testRecords.map { r =>
          //      val features = r.map(d => d.toDouble)
          //      Vectors.dense(features)
          //    }
          //    val predictions = rtModel.predict(testData).map(p => p.toInt)
          //    // Save the predictions
          //    // NOTE(review): same output directory as the naive Bayes section above.
          //    predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict")

        } finally {
          sc.stop()
        }

      }

    }
    
  • 相关阅读:
    LIST组件使用总结
    openlayers之interaction(地图交互功能)
    vbind:class绑定样式,决定样式的显示与否
    cesium之measure功能实现
    Cesium渲染效果差,锯齿明显,解决办法
    CSS让DIV层叠 两个DIV或多个DIV顺序重叠加
    ES6之import/export命令
    vantui:
    Openlayers简单要素的添加
    Vue中的this表示?
  • 原文地址:https://www.cnblogs.com/tonglin0325/p/6906524.html
Copyright © 2020-2023  润新知