• 决策树


       决策树

       在这里决策树就是,通过事物的特征,判断一个事物是否属于一个分类的算法。

       比如,通过下面的一组数据通过条件 “不浮出水面是否可以生存” 和 “是否有脚蹼” 来判读是否是鱼的例子

               不浮出水面是否可以生存	     是否有脚蹼	       属于鱼类
    1	            是	                  是	         是
    2	            是	                  是	         是
    3	            是	                  否	         否
    4	            否	                  是	         否
    5	            否	                  否	         否
    

        直接观察可以得出,“当浮出水面是否可以生存” 和 “是否有脚蹼” 两个条件都是 “是” 的时候,这时候就可以确定是鱼类

           下面通过算法实现这个过程, 思路是

            1、首先,根据顺序应该先使用那个特征进行分类,也就是先用 “当浮出水面是否可以生存”  还是 “是否有脚蹼” 进行分类;

            2、其次,根据分类特征,对数据进行分类。也就比如,上一步是决定好要先使用 “当浮出水面是否可以生存” 做为特征,然后根据这个特征来对数据分类,分完类,再使用 “是否有脚蹼” 这个特征分类;

            3、完成上面那两步之后,就可以用递归算法,得出一个决策方式;

            在说明算法前先把上面的例子用下面的代码形式的数据表示(代码1):

    def createDataSet():
        dataSet = [[1, 1, 'yes'],
                   [1, 1, 'yes'],
                   [1, 0, 'no'],
                   [0, 1, 'no'],
                   [0, 1, 'no']]
        labels = ['no surfacing','flippers']
        #change to discrete values
        return dataSet, labels 

     一、     

          首先使用哪个特征分类,是通过“香农”定义 “熵” 公式决定的,这个公式是 l(x) = -log2P(x)    (ps:手打不好看)

            这段就是计算熵的代码(代码2)

    def calcShannonEnt(dataSet):
        numEntries = len(dataSet)
        labelCounts = {}
        for featVec in dataSet: #the the number of unique elements and their occurance
            currentLabel = featVec[-1]
            if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
            labelCounts[currentLabel] += 1
        shannonEnt = 0.0
        for key in labelCounts:
            prob = float(labelCounts[key])/numEntries
            shannonEnt -= prob * log(prob,2)   #log base 2
        return shannonEnt  

            在这里,先选择的特征值 l(x),和对应特征划分的数据计算熵,哪个熵越小,就先使用哪个作为特征来判断 ( ps:只是看算法理解的,不知道是不是 )

            首先说明,根据特征划分数据集的意思就是,如果用 “不浮出水面是否可以生存”来划 分数据就是,只取数据中 “不浮出水面是否可以生存” 是的条件为前提,选择数据。也就是再代码1,中的数据集就变为,这个数据就只剩下 “是否有脚蹼” 和 ”是否为鱼“ 这两个维度了 。(ps:可能写的不太好理解)

    [1,yes], [1,yes], [0,no]
    

        根据这个特征划分,数据集的代码(代码3),”axis“这个参数是 第几列也就是哪个特征;value 是 这个特征值是否为存在,在代码1中的数据集里就是 1 就是有,0就是没有;

    def splitDataSet(dataSet, axis, value):
        retDataSet = []
        for featVec in dataSet:
            if featVec[axis] == value:
                reducedFeatVec = featVec[:axis]     #chop out axis used for splitting
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
        return retDataSet
    

         然后返回了根据这个特征作为判断的数据集后,通过计算这个数据集的熵来决定先用哪个特征(代码4)

    def chooseBestFeatureToSplit(dataSet):
        numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels
        baseEntropy = calcShannonEnt(dataSet)
        bestInfoGain = 0.0; bestFeature = -1
        for i in range(numFeatures):        #iterate over all the features
            featList = [example[i] for example in dataSet]#create a list of all the examples of this feature
            uniqueVals = set(featList)       #get a set of unique values
            newEntropy = 0.0
            for value in uniqueVals:
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet)/float(len(dataSet))
                newEntropy += prob * calcShannonEnt(subDataSet)     
            infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy
            if (infoGain > bestInfoGain):       #compare this to the best gain so far
                bestInfoGain = infoGain         #if better than current best, set to best
                bestFeature = i
        return bestFeature                      #returns an integer
    

      

     二、总的算法

              递归构建决策树

    def createTree(dataSet,labels):
        classList = [example[-1] for example in dataSet]
        if classList.count(classList[0]) == len(classList): 
            return classList[0]#stop splitting when all of the classes are equal
        if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet
            return majorityCnt(classList)
        bestFeat = chooseBestFeatureToSplit(dataSet)
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel:{}}
        del(labels[bestFeat])
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featValues)
        for value in uniqueVals:
            subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels
            myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
        return myTree   
    

      

    -------------------------------------------------------------------

    参考:Mechine learning in Action

             

           

  • 相关阅读:
    Qt 数据库篇
    js字符串函数(转)
    如何解决IE无法识别html5中的新标签(article、abbr、header等)
    web多页打印问题
    诡异的Spinner级联样式
    discuz x2用户删除了,帖子不能用了,恢复帖子的办法
    创业公司如何招聘优秀工程师
    清除目录下的SVN信息
    .NET 项目SVN 全局排除设置
    编程技术面试的五大要点
  • 原文地址:https://www.cnblogs.com/Jomini/p/11298888.html
Copyright © 2020-2023  润新知