• Spark机器学习(6):决策树算法


    1. 决策树基本知识

    决策树就是通过一系列规则对数据进行分类的一种算法,可以分为分类树和回归树两类,分类树处理离散变量的,回归树是处理连续变量。

    样本一般都有很多个特征,有的特征对分类起很大的作用,有的特征对分类作用很小,甚至没有作用。如决定是否对一个人贷款是,这个人的信用记录、收入等就是主要的判断依据,而性别、婚姻状况等等就是次要的判断依据。决策树构建的过程,就是根据特征的决定性程度,先使用决定性程度高的特征分类,再使用决定性程度低的特征分类,这样构建出一棵倒立的树,就是我们需要的决策树模型,可以用来对数据进行分类。

    决策树学习的过程可以分为三个步骤:1)特征选择,即从众多特征中选择出一个作为当前节点的分类标准;2)决策树生成,从上到下构建节点;3)剪枝,为了预防和消除过拟合,需要对决策树剪枝。

    2. 决策树算法

    主要的决策树算法包括ID3、C4.5和CART。

    ID3把信息增益作为选择特征的标准。由于取值较多的特征(如学号)的信息增益比较大,这种算法会偏向于取值较多的特征。而且该算法只能用于离散型的数据,优点是不需要剪枝。

    C4.5和ID3比较类似,区别在于使用信息增益比替代信息增益作为选择特征的标准,因此比ID3更加科学,并且可以用于连续型的数据,但是需要剪枝。

    CART(Classification And Regression Tree)采用的是Gini作为选择的标准。Gini越大,说明不纯度越大,这个特征就越不好。

    3. MLlib的决策树算法

    MLlib的决策树算法使用的随机森林RandomForest的方法,不过并不是真正的随机森林,因为实际上只有一棵决策树。

    直接上代码:

    import org.apache.log4j.{ Level, Logger }
    import org.apache.spark.{ SparkConf, SparkContext }
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.model.DecisionTreeModel
    import org.apache.spark.mllib.util.MLUtils
    
    /**
      * Created by Administrator on 2017/7/6.
      */
    object DecisionTreeTest {
    
      def main(args: Array[String]): Unit = {
    
        // 设置运行环境
        val conf = new SparkConf().setAppName("Decision Tree")
          .setMaster("spark://master:7077").setJars(Seq("E:\Intellij\Projects\MachineLearning\MachineLearning.jar"))
        val sc = new SparkContext(conf)
        Logger.getRootLogger.setLevel(Level.WARN)
    
        // 读取样本数据并解析
        val dataRDD = MLUtils.loadLibSVMFile(sc, "hdfs://master:9000/ml/data/sample_dt_data.txt")
        // 样本数据划分,训练样本占0.8,测试样本占0.2
        val dataParts = dataRDD.randomSplit(Array(0.8, 0.2))
        val trainRDD = dataParts(0)
        val testRDD = dataParts(1)
    
        // 决策树参数
        val numClasses = 5
        val categoricalFeaturesInfo = Map[Int, Int]()
        val impurity = "gini"
        val maxDepth = 5
        val maxBins = 32
        // 建立决策树模型并训练
        val model = DecisionTree.trainClassifier(trainRDD, numClasses, categoricalFeaturesInfo,
          impurity, maxDepth, maxBins)
    
        // 对测试样本进行测试
        val predictionAndLabel = testRDD.map { point =>
          val score = model.predict(point.features)
          (score, point.label, point.features)
        }
        val showPredict = predictionAndLabel.take(50)
        println("Prediction" + "	" + "Label" + "	" + "Data")
        for (i <- 0 to showPredict.length - 1) {
          println(showPredict(i)._1 + "	" + showPredict(i)._2 + "	" + showPredict(i)._3)
        }
    
        // 误差计算
        val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / testRDD.count()
        println("Accuracy = " + accuracy)
    
        // 保存模型
        val ModelPath = "hdfs://master:9000/ml/model/Decision_Tree_Model"
        model.save(sc, ModelPath)
        val sameModel = DecisionTreeModel.load(sc, ModelPath)
    
      }

    运行结果:

  • 相关阅读:
    C#基础视频教程5.3 如何编写简单的超级热键
    spring boot中注入jpa时报could not autowire.No beans of 'PersonRepository' type found
    SpringBoot中常用注解@Controller/@RestController/@RequestMapping的区别
    idea如何搭建springboot框架
    Fiddler建好代理后,能连到手机,但手机不能上网了是什么原因
    如何用Fiddler对Android应用进行抓包
    【fiddler】抓取https数据失败,全部显示“Tunnel to......443”
    将excel的数据导入到数据库后都乱码了是怎么回事
    java保存繁体字到数据库时就报错Incorrect string value: 'xF0xA6x8Dx8BxE5xA4...' for column 'name' at row 1
    将爬取的网页数据保存到数据库时报错不能提交JPA,Caused by: java.sql.SQLException: Incorrect string value: 'xF0x9Fx98xB6 xE2...' for column 'content' at row 1
  • 原文地址:https://www.cnblogs.com/mstk/p/7128540.html
Copyright © 2020-2023  润新知