• ML(4.2): R CART


         CART模型 :即Classification And Regression Trees。它和一般回归分析类似,是用来对变量进行解释和预测的工具,也是数据挖掘中的一种常用算法。如果因变量是连续数据,相对应的分析称为回归树,如果因变量是分类数据,则相应的分析称为分类树。决策树是一种倒立的树结构,它由内部节点、叶子节点和边组成。其中最上面的一个节点叫根节点。 构造一棵决策树需要一个训练集,一些例子组成,每个例子用一些属性(或特征)和一个类别标记来描述。构造决策树的目的是找出属性和类别间的关系,一旦这种关系找出,就能用它来预测将来未知类别的记录的类别。这种具有预测功能的系统叫决策树分类器。

        CART算法是一种二分递归分割技术,把当前样本划分为两个子样本,使得生成的每个非叶子结点都有两个分支,因此CART算法生成的决策树是结构简洁的二叉树。由于CART算法构成的是一个二叉树,它在每一步的决策时只能 是“是”或者“否”,即使一个feature有多个取值,也是把数据分为两部分。在CART算法中主要分为两个步骤

    • 将样本递归划分进行建树过程
    • 用验证数据进行剪枝

      在R包中,有如下的算法包可完成CART 分类计算,如下,分别以鸢尾花数据集为例进行验证

    •  rpart::rpart
    •  tree::tree

    rpart::rpart


    • rpart包中有针对CART决策树算法提供的函数,比如rpart函数,以及用于剪枝的prune函数
    • rpart函数的基本形式:rpart(formula,data,subset,na.action=na.rpart,method.parms,control,...)
    • 安装所需R包
      install.packages("mboost")
      install.packages("rpart")
      install.packages("maptree")
    • 数据集划分训练集和测试,比例是2:1
      set.seed(1234)
      index <-sample(1:nrow(iris),100)
      iris.train <-iris[index,]
      iris.test <-iris[-index,]
    •  构建CART模型,查看模型结构,在结构中能看到很多有意思的内容

      library(rpart)
      model.CART <-rpart(Species~.,data=iris.train)
      str(model.CART)
    •  

    •  control:对树进行一些设置 

      1. minsplit是最小分支节点数,这里指大于等于20,那么该节点会继续分划下去,否则停止
      2. minbucket:树中叶节点包含的最小样本数 
      3. maxdepth:决策树最大深度
      4. xval:交叉验证的次数
      5. cp (complexity parameter),指某个点的复杂度,对每一步拆分,模型的拟合优度必须提高的程度。(即是每次分割对应的复杂度系数)
    • variable.importance:变量的重要性
      > model.CART$variable.importance
       Petal.Width Petal.Length Sepal.Length  Sepal.Width 
          60.58917     56.38914     39.79006     26.00328 
    • 预测数据: vector: 预测数值   class: 预测类别  prob: 预测类别的概率
      > p.rpart <- predict(model.CART, iris.test,type="class") 
      > table(p.rpart,iris.test$Species)
             
      p.rpart setosa versicolor virginica
            1     12          0         0
            2      0         21         3
            3      0          0        14
    •  可视化,需要rpart.plot包
    #可视化决策树
    #install.packages("rpart.plot")
    library(rpart.plot)
    rpart.plot(model.CART) 
    • 效果如下图:
    • CART剪枝:
      1. prune函数可以实现最小代价复杂度剪枝法,对于CART的结果,每个节点均输出一个对应的cp
      2. prune函数通过设置cp参数来对决策树进行修剪,cp为复杂度系数
      3. 可以用下面的办法选择具有最小xerror的cp的办法:
        model.CART.pru<-prune(model.CART, cp= model.CART$cptable[which.min(model.CART$cptable[,"xerror"]),"CP"]) 
        model.CART.pru$cp
    • CART剪枝后的模型进行预测 

      p.rpart1<-predict(model.CART.pru,iris.test,type="class")
      table(p.rpart1,iris.test$Species)
    •  

    tree::tree


    • 数据集划分训练集和测试见上节
    • 构建模型,查看生成模型结构,如下图,错误率为:0.02667
      > #install.packages("tree")
      > library(tree)  
      > ir.tr <- tree(Species~., iris)  
      > summary(ir.tr)
      
      Classification tree:
      tree(formula = Species ~ ., data = iris)
      Variables actually used in tree construction:
      [1] "Petal.Length" "Petal.Width"  "Sepal.Length"
      Number of terminal nodes:  6 
      Residual mean deviance:  0.1253 = 18.05 / 144 
      Misclassification error rate: 0.02667 = 4 / 150
    • 查看生成决策树及图例
      plot(ir.tr)
      text(ir.tr,pretty = 0) 
    • 结果验证
      > tree_predict <- predict(ir.tr,iris.test,type="class")
      > table(tree_predict,iris.test$Species)
                  
      tree_predict setosa versicolor virginica
        setosa         12          0         0
        versicolor      0         20         1
        virginica       0          1        16
    • 用误分类率来剪枝,做交叉验证,代码如下:
      > cv.carseats=cv.tree(ir.tr, FUN=prune.misclass)
      > str(cv.carseats)
      List of 4
       $ size  : int [1:5] 6 4 3 2 1
       $ dev   : num [1:5] 11 11 10 96 121
       $ k     : num [1:5] -Inf 0 2 44 50
       $ method: chr "misclass"
       - attr(*, "class")= chr [1:2] "prune" "tree.sequence"
    •  可视化模型

      par(mfrow=c(1, 2))
      plot(cv.carseats$size, cv.carseats$dev, type="b")
      plot(cv.carseats$k, cv.carseats$dev, type="b")
    •  图表示例

    • 随着树的节点越来越多(树越来越复杂),deviance逐渐减小,然后又开始增大
    • 随着对模型复杂程度的惩罚越来越重(k越来越大),deviance逐渐减小,然后又开始增大 (此图暂看不起来)
    • 从左边的图可以看出,当树的节点个数为 3 时,deviance达到最小,画出3个叶子节点的树
      #画出3个叶子节点的树
      par(new = TRUE) 
      prune.carseats <- prune.misclass(ir.tr, best=3)
      plot(prune.carseats)
      text(prune.carseats, pretty=0)
    • 图示例
    • 测试及结果
      > tree.pred  <- predict(prune.carseats, iris.test, type="class")
      > summary(tree.pred)
          setosa versicolor  virginica 
              12         24         14 
      > table(tree.pred,iris.test$Species)
                  
      tree.pred    setosa versicolor virginica
        setosa         12          0         0
        versicolor      0         21         3
        virginica       0          0        14
    •  

  • 相关阅读:
    create_project.py报错问题,建议用回python2.7
    windows下执行build_native.sh报权限问题
    编辑器CocoStudio和CocosBuilder的对比
    双击判断
    Web文件的ContentType类型大全
    Java四类八种数据类型
    自己写的通过ADO操作mysql数据库
    使用Cout输出String和CString对象
    CString和string头文件
    C++连接mysql数据库的两种方法
  • 原文地址:https://www.cnblogs.com/tgzhu/p/6697564.html
Copyright © 2020-2023  润新知