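A decision tree is a tree-structured model for classifying instances. It consists of nodes and directed edges: internal nodes test a feature, and leaf nodes hold a class. It looks much like the flowcharts we draw all the time.
Before studying decision trees, you need to understand a few concepts: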
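Entropy: a measure of the uncertainty of a random variable. For a discrete random variable $X$ with $P(X = x_i) = p_i$, $i = 1, 2, \dots, n$, it is defined as $H(p) = -\sum_{i=1}^{n} p_i \log p_i$.
Information gain: the difference between the empirical entropy $H(D)$ of a dataset $D$ and the empirical conditional entropy $H(D|A)$ of $D$ given feature $A$, that is, $g(D,A) = H(D) - H(D|A)$ (look up the expanded formulas yourself).
Information gain ratio: the ratio of the information gain $g(D,A)$ to the entropy $H_A(D)$ of the training set $D$ with respect to the values of feature $A$, that is, $g_R(D,A) = g(D,A) / H_A(D)$.
Gini index: (formula omitted)
Memorize the formulas above and learn to derive them.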
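To make the definitions concrete, here is a minimal sketch of each quantity. It assumes every sample is a list whose last element is the class label (the layout the code later in this section uses), and it uses base-2 logs, so entropies are in bits. For the Gini index it uses the standard definition $\mathrm{Gini}(D) = 1 - \sum_k p_k^2$. The function names `entropy`, `conditional_entropy`, `info_gain`, `gain_ratio` and `gini` are illustrative, not from any library.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Empirical entropy H(D) in bits: -sum_k p_k * log2(p_k)."""
    n = len(labels)
    return sum(-(c / n) * log2(c / n) for c in Counter(labels).values())


def conditional_entropy(rows, i):
    """H(D|A) = sum over values v of feature i of |D_v|/|D| * H(D_v)."""
    n = len(rows)
    groups = {}
    for row in rows:
        groups.setdefault(row[i], []).append(row[-1])  # class labels grouped by A's value
    return sum(len(g) / n * entropy(g) for g in groups.values())


def info_gain(rows, i):
    """g(D, A) = H(D) - H(D|A)."""
    return entropy([row[-1] for row in rows]) - conditional_entropy(rows, i)


def gain_ratio(rows, i):
    """g_R(D, A) = g(D, A) / H_A(D), where H_A(D) is the entropy of A's own values."""
    h_a = entropy([row[i] for row in rows])
    return info_gain(rows, i) / h_a if h_a > 0 else 0.0  # A is constant on D: no split


def gini(labels):
    """Gini(D) = 1 - sum_k p_k^2."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


# Toy data (made up for illustration): two binary features, class label last
D = [[1, 1, "yes"], [1, 1, "yes"], [1, 0, "no"], [0, 1, "no"], [0, 1, "no"]]
print(entropy([r[-1] for r in D]))  # about 0.971 (2 "yes", 3 "no")
print(info_gain(D, 0), gain_ratio(D, 0), gini([r[-1] for r in D]))
```

Working a value by hand is good practice for the derivations: feature 0 splits $D$ into $\{yes, yes, no\}$ and $\{no, no\}$, so $H(D|A) = \tfrac{3}{5} H(\tfrac{2}{3}, \tfrac{1}{3}) \approx 0.551$ and $g(D,A) \approx 0.971 - 0.551 = 0.420$.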
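How each algorithm uses these quantities:
ID3: choose the feature with the largest information gain.
C4.5: choose the feature with the largest information gain ratio.
There is also the CART algorithm, which uses the Gini index to choose features.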
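A minimal sketch of the three selection rules side by side, on the same kind of toy data (last column is the class). ID3 and C4.5 are shown exactly as described above (argmax of gain or gain ratio). For CART the sketch follows its binary-tree form: it tries every "feature == value" versus "feature != value" split and keeps the one with the smallest weighted Gini index. All function names here are hypothetical, and real implementations also handle continuous features, missing values and so on.

```python
from collections import Counter
from math import log2


def _entropy(values):
    n = len(values)
    return sum(-(c / n) * log2(c / n) for c in Counter(values).values())


def _gini(values):
    n = len(values)
    return 1.0 - sum((c / n) ** 2 for c in Counter(values).values())


def _labels_by_value(rows, i):
    """Group the class labels (last column) by the value of feature i."""
    groups = {}
    for row in rows:
        groups.setdefault(row[i], []).append(row[-1])
    return groups


def id3_score(rows, i):
    """ID3 criterion: information gain g(D, A)."""
    n = len(rows)
    cond = sum(len(g) / n * _entropy(g) for g in _labels_by_value(rows, i).values())
    return _entropy([row[-1] for row in rows]) - cond


def c45_score(rows, i):
    """C4.5 criterion: information gain ratio g(D, A) / H_A(D)."""
    h_a = _entropy([row[i] for row in rows])
    return id3_score(rows, i) / h_a if h_a > 0 else 0.0


def best_feature(rows, score):
    """Index of the feature (every column but the last) with the highest score."""
    return max(range(len(rows[0]) - 1), key=lambda i: score(rows, i))


def cart_best_split(rows):
    """Binary split 'feature i == v' vs 'feature i != v' with the smallest
    weighted Gini index; returns (gini, i, v), or None if nothing can be split."""
    n = len(rows)
    best = None
    for i in range(len(rows[0]) - 1):
        for v in dict.fromkeys(row[i] for row in rows):  # distinct values, in order seen
            left = [row[-1] for row in rows if row[i] == v]
            right = [row[-1] for row in rows if row[i] != v]
            if not right:  # all rows share this value: not a real split
                continue
            score = len(left) / n * _gini(left) + len(right) / n * _gini(right)
            if best is None or score < best[0]:
                best = (score, i, v)
    return best


D = [[1, 1, "yes"], [1, 1, "yes"], [1, 0, "no"], [0, 1, "no"], [0, 1, "no"]]
print(best_feature(D, id3_score), best_feature(D, c45_score), cart_best_split(D))
```

On this small dataset all three rules pick feature 0; on other data the gain ratio can prefer a different feature, because dividing by $H_A(D)$ penalizes features with many distinct values.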
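Tree pruning comes in two forms, pre-pruning and post-pruning; both aim to prevent overfitting.
Pre-pruning: while the tree is being built, a new split is not made if adding that branch does not increase accuracy (usually measured on a held-out validation set).
Post-pruning: after the full tree has been built, start from the lowest branches; if removing a branch (replacing it with a leaf labeled with the majority class) increases classification accuracy, remove it, otherwise keep it.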
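A minimal sketch of the post-pruning rule just described, assuming the nested-dict tree format that `createTree` below produces (`{feature_name: {value: subtree_or_class}}`) and a separate validation set. The names `classify`, `accuracy` and `prune`, and the toy tree and data, are hypothetical. This is a simple reduced-error style of pruning; C4.5 and CART use their own pruning methods (error-based pruning and cost-complexity pruning, respectively).

```python
from collections import Counter


def classify(tree, feat_labels, sample):
    """Follow a nested-dict tree {feature_name: {value: subtree_or_class}}."""
    while isinstance(tree, dict):
        feat = next(iter(tree))
        branches = tree[feat]
        value = sample[feat_labels.index(feat)]
        if value not in branches:
            return None  # value never seen in training: counts as an error
        tree = branches[value]
    return tree


def accuracy(tree, feat_labels, rows):
    return sum(classify(tree, feat_labels, r) == r[-1] for r in rows) / len(rows)


def prune(tree, feat_labels, train, val):
    """Bottom-up post-pruning: collapse a subtree into a majority-class leaf only
    if that strictly increases accuracy on the validation rows reaching it.
    Returns a new tree; the input tree is not modified."""
    if not isinstance(tree, dict) or not train or not val:
        return tree  # a leaf, or no data to judge this node with
    feat = next(iter(tree))
    idx = feat_labels.index(feat)
    # Prune the children first, each using only the rows that reach it
    branches = {
        v: prune(sub, feat_labels,
                 [r for r in train if r[idx] == v],
                 [r for r in val if r[idx] == v])
        for v, sub in tree[feat].items()
    }
    pruned = {feat: branches}
    leaf = Counter(r[-1] for r in train).most_common(1)[0][0]
    if accuracy(leaf, feat_labels, val) > accuracy(pruned, feat_labels, val):
        return leaf
    return pruned


feat_labels = ["a", "b"]
tree = {"a": {0: "no", 1: {"b": {0: "yes", 1: "no"}}}}
train = [[0, 0, "no"], [0, 1, "no"], [1, 0, "yes"], [1, 1, "no"], [1, 0, "yes"]]
val = [[1, 1, "yes"], [1, 0, "yes"], [0, 0, "no"]]
print(prune(tree, feat_labels, train, val))  # {'a': {0: 'no', 1: 'yes'}}
```

Here the `b` subtree gets only half of its validation rows right, while a single "yes" leaf gets both, so it is collapsed; the root split is kept because replacing it with a "no" leaf would lower validation accuracy.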
# Key part of the code: find the best feature, i.e. the one that maximizes information gain (ID3)
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # the last column is used for the labels
    baseEntropy = calcShannonEnt(dataSet)  # H(D)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):  # iterate over all the features
        featList = [example[i] for example in dataSet]  # values of feature i in every example
        uniqueVals = set(featList)  # distinct values of feature i
        newEntropy = 0.0  # H(D|A) for A = feature i
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # info gain, i.e. reduction in entropy
        if infoGain > bestInfoGain:  # compare this to the best gain so far
            bestInfoGain = infoGain  # if better than current best, set to best
            bestFeature = i
    return bestFeature  # index of the best feature, or -1 if no split reduces entropy
# Build the tree: pick the best feature as the root, split the data by that feature's values,
# and recurse on each subset until the complete tree is formed
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1:  # stop splitting when there are no more features in dataSet
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    if bestFeat == -1:  # no feature reduces entropy: make a majority-class leaf
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # labels without the chosen feature; copying avoids modifying the caller's list
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels[:])
    return myTree
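The two functions above call three helpers that are not shown: `calcShannonEnt`, `splitDataSet` and `majorityCnt`. Below is a minimal sketch of them, assuming each example is a list with the class label in the last column; the bodies are an assumption about what the helpers should do, not the original implementation.

```python
from collections import Counter
from math import log2


def calcShannonEnt(dataSet):
    """Empirical entropy H(D), in bits, of the class labels (last column)."""
    n = len(dataSet)
    counts = Counter(example[-1] for example in dataSet)
    return sum(-(c / n) * log2(c / n) for c in counts.values())


def splitDataSet(dataSet, axis, value):
    """Examples whose feature `axis` equals `value`, with that column removed."""
    return [example[:axis] + example[axis + 1:]
            for example in dataSet if example[axis] == value]


def majorityCnt(classList):
    """Most frequent class label (ties go to the label seen first)."""
    return Counter(classList).most_common(1)[0][0]


dataSet = [[1, 1, "yes"], [1, 1, "yes"], [1, 0, "no"], [0, 1, "no"], [0, 1, "no"]]
print(calcShannonEnt(dataSet))        # about 0.971
print(splitDataSet(dataSet, 0, 1))    # [[1, 'yes'], [1, 'yes'], [0, 'no']]
print(majorityCnt(["yes", "no", "no"]))  # no
```

With these helpers defined alongside `chooseBestFeatureToSplit` and `createTree`, calling `createTree(dataSet, ["no surfacing", "flippers"])` on this toy data should first split on "no surfacing" (information gain about 0.420 versus 0.171 for "flippers") and then on "flippers" inside the branch where "no surfacing" is 1.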
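The corresponding scikit-learn API:
DecisionTreeClassifier, for classification (scikit-learn's trees use an optimized version of CART). Its growth limits, such as max_depth, min_samples_split and min_samples_leaf, act as pre-pruning. Older versions offered only these pre-pruning options; since scikit-learn 0.22 the class also supports minimal cost-complexity post-pruning through the ccp_alpha parameter.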
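A minimal sketch of both pruning styles with `DecisionTreeClassifier`, assuming scikit-learn 0.22 or later (needed for `ccp_alpha` and `cost_complexity_pruning_path`) and using the bundled iris dataset. The parameter values are arbitrary choices for illustration, not tuned settings.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0)

# Unpruned tree, for comparison
full = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# Pre-pruning: limit growth while the tree is being built
# (criterion="entropy" splits by information gain instead of the default Gini index)
pre = DecisionTreeClassifier(criterion="entropy", max_depth=3,
                             min_samples_leaf=5, random_state=0)
pre.fit(X_train, y_train)

# Post-pruning: minimal cost-complexity pruning. The path lists the effective
# alphas at which subtrees get pruned; larger alpha means a smaller tree.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)
alpha = path.ccp_alphas[len(path.ccp_alphas) // 2]  # a middle value, for illustration only
post = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(X_train, y_train)

print("full:", full.get_n_leaves(), "leaves, test acc", full.score(X_test, y_test))
print("pre: depth", pre.get_depth(), "test acc", pre.score(X_test, y_test))
print("post:", post.get_n_leaves(), "leaves, test acc", post.score(X_test, y_test))
```

In practice `ccp_alpha` (like `max_depth`) would be chosen by cross-validation rather than picked from the middle of the path.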