CART模型 :即Classification And Regression Trees。它和一般回归分析类似,是用来对变量进行解释和预测的工具,也是数据挖掘中的一种常用算法。如果因变量是连续数据,相对应的分析称为回归树,如果因变量是分类数据,则相应的分析称为分类树。决策树是一种倒立的树结构,它由内部节点、叶子节点和边组成。其中最上面的一个节点叫根节点。 构造一棵决策树需要一个训练集,一些例子组成,每个例子用一些属性(或特征)和一个类别标记来描述。构造决策树的目的是找出属性和类别间的关系,一旦这种关系找出,就能用它来预测将来未知类别的记录的类别。这种具有预测功能的系统叫决策树分类器。
CART算法是一种二分递归分割技术,把当前样本划分为两个子样本,使得生成的每个非叶子结点都有两个分支,因此CART算法生成的决策树是结构简洁的二叉树。由于CART算法构成的是一个二叉树,它在每一步的决策时只能 是“是”或者“否”,即使一个feature有多个取值,也是把数据分为两部分。在CART算法中主要分为两个步骤
- 将样本递归划分进行建树过程
- 用验证数据进行剪枝
在R包中,有如下的算法包可完成CART 分类计算,如下,分别以鸢尾花数据集为例进行验证
- rpart::rpart
- tree::tree
rpart::rpart
- rpart包中有针对CART决策树算法提供的函数,比如rpart函数,以及用于剪枝的prune函数
- rpart函数的基本形式:rpart(formula,data,subset,na.action=na.rpart,method.parms,control,...)
- 安装所需R包
install.packages("mboost") install.packages("rpart") install.packages("maptree")
- 数据集划分训练集和测试,比例是2:1
set.seed(1234) index <-sample(1:nrow(iris),100) iris.train <-iris[index,] iris.test <-iris[-index,]
-
构建CART模型,查看模型结构,在结构中能看到很多有意思的内容
library(rpart) model.CART <-rpart(Species~.,data=iris.train) str(model.CART)
-
-
control:对树进行一些设置
- minsplit是最小分支节点数,这里指大于等于20,那么该节点会继续分划下去,否则停止
- minbucket:树中叶节点包含的最小样本数
- maxdepth:决策树最大深度
- xval:交叉验证的次数
- cp (complexity parameter),指某个点的复杂度,对每一步拆分,模型的拟合优度必须提高的程度。(即是每次分割对应的复杂度系数)
- variable.importance:变量的重要性
> model.CART$variable.importance Petal.Width Petal.Length Sepal.Length Sepal.Width 60.58917 56.38914 39.79006 26.00328
- 预测数据: vector: 预测数值 class: 预测类别 prob: 预测类别的概率
> p.rpart <- predict(model.CART, iris.test,type="class") > table(p.rpart,iris.test$Species) p.rpart setosa versicolor virginica 1 12 0 0 2 0 21 3 3 0 0 14
- 可视化,需要rpart.plot包
#可视化决策树
#install.packages("rpart.plot")
library(rpart.plot)
rpart.plot(model.CART)
- 效果如下图:
- CART剪枝:
- prune函数可以实现最小代价复杂度剪枝法,对于CART的结果,每个节点均输出一个对应的cp
- prune函数通过设置cp参数来对决策树进行修剪,cp为复杂度系数
- 可以用下面的办法选择具有最小xerror的cp的办法:
model.CART.pru<-prune(model.CART, cp= model.CART$cptable[which.min(model.CART$cptable[,"xerror"]),"CP"]) model.CART.pru$cp
-
CART剪枝后的模型进行预测
p.rpart1<-predict(model.CART.pru,iris.test,type="class") table(p.rpart1,iris.test$Species)
-
tree::tree
- 数据集划分训练集和测试见上节
- 构建模型,查看生成模型结构,如下图,错误率为:0.02667
> #install.packages("tree") > library(tree) > ir.tr <- tree(Species~., iris) > summary(ir.tr) Classification tree: tree(formula = Species ~ ., data = iris) Variables actually used in tree construction: [1] "Petal.Length" "Petal.Width" "Sepal.Length" Number of terminal nodes: 6 Residual mean deviance: 0.1253 = 18.05 / 144 Misclassification error rate: 0.02667 = 4 / 150
- 查看生成决策树及图例
plot(ir.tr) text(ir.tr,pretty = 0)
- 结果验证
> tree_predict <- predict(ir.tr,iris.test,type="class") > table(tree_predict,iris.test$Species) tree_predict setosa versicolor virginica setosa 12 0 0 versicolor 0 20 1 virginica 0 1 16
- 用误分类率来剪枝,做交叉验证,代码如下:
> cv.carseats=cv.tree(ir.tr, FUN=prune.misclass) > str(cv.carseats) List of 4 $ size : int [1:5] 6 4 3 2 1 $ dev : num [1:5] 11 11 10 96 121 $ k : num [1:5] -Inf 0 2 44 50 $ method: chr "misclass" - attr(*, "class")= chr [1:2] "prune" "tree.sequence"
-
可视化模型
par(mfrow=c(1, 2)) plot(cv.carseats$size, cv.carseats$dev, type="b") plot(cv.carseats$k, cv.carseats$dev, type="b")
-
图表示例
- 随着树的节点越来越多(树越来越复杂),deviance逐渐减小,然后又开始增大
- 随着对模型复杂程度的惩罚越来越重(k越来越大),deviance逐渐减小,然后又开始增大 (此图暂看不起来)
- 从左边的图可以看出,当树的节点个数为 3 时,deviance达到最小,画出3个叶子节点的树
#画出3个叶子节点的树 par(new = TRUE) prune.carseats <- prune.misclass(ir.tr, best=3) plot(prune.carseats) text(prune.carseats, pretty=0)
- 图示例
- 测试及结果
> tree.pred <- predict(prune.carseats, iris.test, type="class") > summary(tree.pred) setosa versicolor virginica 12 24 14 > table(tree.pred,iris.test$Species) tree.pred setosa versicolor virginica setosa 12 0 0 versicolor 0 21 3 virginica 0 0 14
-