• Machine Learning in Action(2) 决策树算法


            决策树也是有监督机器学习方法。 电影《无耻混蛋》里有一幕游戏,在德军小酒馆里有几个人在玩20问题游戏,游戏规则是一个设迷者在纸牌中抽出一个目标(可以是人,也可以是物),而猜谜者可以提问题,设迷者只能回答是或者不是,在几个问题(最多二十个问题)之后,猜谜者通过逐步缩小范围就准确的找到了答案。这就类似于决策树的工作原理。(图一)是一个判断邮件类别的工作方式,可以看出判别方法很简单,基本都是阈值判断,关键是如何构建决策树,也就是如何训练一个决策树。

     

    (图一)

    构建决策树的伪代码如下:

    Check if every item in the dataset is in the same class:

           If so return the class label

           Else

                 find the best feature to split the data

                 split the dataset

                 create a branch node

                 for each split

                        call create Branch and add the result to the branch node

                return branch node

             原则只有一个,尽量使得每个节点的样本标签尽可能少,注意上面伪代码中一句说:find the best feature to split the data,那么如何find thebest feature?一般有个准则就是尽量使得分支之后节点的类别纯一些,也就是分的准确一些。如(图二)中所示,从海洋中捞取的5个动物,我们要判断他们是否是鱼,先用哪个特征?

     

    (图二)

             为了提高识别精度,我们是先用“能否在陆地存活”还是“是否有蹼”来判断?我们必须要有一个衡量准则,常用的有信息论、基尼纯度等,这里使用前者。我们的目标就是选择使得分割后数据集的标签信息增益最大的那个特征,信息增益就是原始数据集标签基熵减去分割后的数据集标签熵,换句话说,信息增益大就是熵变小,使得数据集更有序。熵的计算如(公式一)所示:

     

    (公式一)

           有了指导原则,那就进入代码实战阶段,先来看看熵的计算代码:

     1 def calcShannonEnt(dataSet):
     2     numEntries = len(dataSet)
     3     labelCounts = {}
     4     for featVec in dataSet: #the the number of unique elements and their occurance
     5         currentLabel = featVec[-1]
     6         if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
     7         labelCounts[currentLabel] += 1  #收集所有类别的数目,创建字典
     8     shannonEnt = 0.0
     9     for key in labelCounts:
    10         prob = float(labelCounts[key])/numEntries
    11         shannonEnt -= prob * log(prob,2) #log base 2  计算熵
    12     return shannonEnt


     

             有了熵的计算代码,接下来看依照信息增益变大的原则选择特征的代码:

     1 def splitDataSet(dataSet, axis, value):
     2     retDataSet = []
     3     for featVec in dataSet:
     4         if featVec[axis] == value:
     5             reducedFeatVec = featVec[:axis]     #chop out axis used for splitting
     6             reducedFeatVec.extend(featVec[axis+1:])
     7             retDataSet.append(reducedFeatVec)
     8     return retDataSet
     9     
    10 def chooseBestFeatureToSplit(dataSet):
    11     numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels
    12     baseEntropy = calcShannonEnt(dataSet)
    13     bestInfoGain = 0.0; bestFeature = -1
    14     for i in range(numFeatures):        #iterate over all the features
    15         featList = [example[i] for example in dataSet]#create a list of all the examples of this feature
    16         uniqueVals = set(featList)       #get a set of unique values
    17         newEntropy = 0.0
    18         for value in uniqueVals:
    19             subDataSet = splitDataSet(dataSet, i, value)
    20             prob = len(subDataSet)/float(len(dataSet))
    21             newEntropy += prob * calcShannonEnt(subDataSet)     
    22         infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy
    23         if (infoGain > bestInfoGain):       #compare this to the best gain so far    #选择信息增益最大的代码在此
    24             bestInfoGain = infoGain         #if better than current best, set to best
    25             bestFeature = i
    26     return bestFeature                      #returns an integer

            从最后一个if可以看出,选择使得信息增益最大的特征作为分割特征,现在有了特征分割准则,继续进入一下个环节,如何构建决策树,其实就是依照最上面的伪代码写下去,采用递归的思想依次分割下去,直到执行完成就构建了决策树。代码如下:

     1 def majorityCnt(classList):
     2     classCount={}
     3     for vote in classList:
     4         if vote not in classCount.keys(): classCount[vote] = 0
     5         classCount[vote] += 1
     6     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
     7     return sortedClassCount[0][0]
     8 
     9 def createTree(dataSet,labels):
    10     classList = [example[-1] for example in dataSet]
    11     if classList.count(classList[0]) == len(classList): 
    12         return classList[0]#stop splitting when all of the classes are equal
    13     if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet
    14         return majorityCnt(classList)
    15     bestFeat = chooseBestFeatureToSplit(dataSet)
    16     bestFeatLabel = labels[bestFeat]
    17     myTree = {bestFeatLabel:{}}
    18     del(labels[bestFeat])
    19     featValues = [example[bestFeat] for example in dataSet]
    20     uniqueVals = set(featValues)
    21     for value in uniqueVals:
    22         subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels
    23         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    24     return myTree                        


          用图二的样本构建的决策树如(图三)所示:

    (图三)

    有了决策树,就可以用它做分类咯,分类代码如下:

     

     1 def classify(inputTree,featLabels,testVec):
     2     firstStr = inputTree.keys()[0]
     3     secondDict = inputTree[firstStr]
     4     featIndex = featLabels.index(firstStr)
     5     key = testVec[featIndex]
     6     valueOfFeat = secondDict[key]
     7     if isinstance(valueOfFeat, dict): 
     8         classLabel = classify(valueOfFeat, featLabels, testVec)
     9     else: classLabel = valueOfFeat
    10     return classLabel


    最后给出序列化决策树(把决策树模型保存在硬盘上)的代码:

     1 def storeTree(inputTree,filename):
     2     import pickle
     3     fw = open(filename,'w')
     4     pickle.dump(inputTree,fw)
     5     fw.close()
     6     
     7 def grabTree(filename):
     8     import pickle
     9     fr = open(filename)
    10     return pickle.load(fr)
    11     

    优点:检测速度快

    缺点:容易过拟合,可以采用修剪的方式来尽量避免

    以上内容来至群友博客http://blog.csdn.net/marvin521/article/details/9255977

    Ps:决策树算法以其简单、清晰、高效的性能,在传统行业应用非常广泛,常见流失预警,客户的目标响应模型等,同时也是我最喜欢的一个算法,经典的决策树算法有CART和C5.0,前者是二叉树,适合并行,之前在学校的时候也在cuda架构上写过这个算法的并行程序,后者可以是多叉树。在此算法基础上的ensemble集成(bagging,adaboost、gbdt,bagging+random subspace = random forest、(pca+subspace)+tree = rotation forest)性能有很大的提升,至于决策树剪枝则可以选用不同的优化指标,采用前向还是后向剪。

  • 相关阅读:
    Java操作redis
    Ajax & Json
    【转载】K8s-Pod时区与宿主时区时区同步
    【转载】Python中如何将字符串作为变量名
    【转载】python实现dubbo接口的调用
    机器学习避坑指南:训练集/测试集分布一致性检查
    机器学习深度研究:特征选择中几个重要的统计学概念
    机器学习数学基础:学习线性代数,千万不要误入歧途!推荐一个正确学习路线
    被 Pandas read_csv 坑了
    print('Hello World!')的新玩法
  • 原文地址:https://www.cnblogs.com/kobedeshow/p/3337478.html
Copyright © 2020-2023  润新知