• day-8 Implementing the ID3 decision tree algorithm with Python's standard library


      The previous day we implemented an ID3 decision tree on top of the sklearn library; in this post we implement the ID3 algorithm using only Python's standard library.

     I. Basic knowledge used by the code

      1. For convenient plotting, a standalone treePlotter helper module (source given in Section IV) handles the drawing. It is simple to use: call its createPlot function with a tree object, and it plots the corresponding figure.

      2. How to define a tree structure in Python

        We can use Python's built-in dict type to represent a tree. For example, the following code defines a root node with left and right children:

        rootNode = {'rootNode': {}}
        leftNode = {'leftNode': {'yes':'yes'}}
        rightNode = {'rightNode': {'no':'no'}}
        rootNode['rootNode']['left'] = leftNode
        rootNode['rootNode']['right'] = rightNode
        treePlotter.createPlot(rootNode)

        Calling the treePlotter module then draws the following tree:

        [Figure: the two-leaf tree rendered by treePlotter.createPlot]

      3. Recursion

        To find each node's children we use recursion; the basic idea is the same as traversing a binary tree. We will implement a binary tree in Python in a later post, so the details are skipped here.
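        As a minimal sketch of this recursive pattern (assuming the nested-dict tree format defined above, where a leaf is any value that is not a dict; countLeaves is just an illustrative helper, not part of the program below), counting the leaves of a tree looks like this:

        def countLeaves(tree):
            # the single key is the node name; its value maps branch labels to children
            root = list(tree.keys())[0]
            count = 0
            for child in tree[root].values():
                if isinstance(child, dict):
                    count += countLeaves(child)  # an internal node: recurse
                else:
                    count += 1                   # a leaf
            return count

        # countLeaves(rootNode) for the tree defined above returns 2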

      4. Finally, you need to be reasonably familiar with Python's common data types and their operations, such as dicts, lists, and sets.
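        For example, the idioms below (a quick refresher, not part of the program itself) all reappear in the source code later:

        counts = {}                                # dict: tally class labels
        counts['yes'] = counts.get('yes', 0) + 1
        row = ['youth', 'high', 'no', 'fair']      # list: drop column 1 by slicing
        reduced = row[:1] + row[2:]                # -> ['youth', 'no', 'fair']
        uniqueVals = set(['high', 'low', 'high'])  # set: unique feature values -> {'high', 'low'}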

    II. Main program flow

    The program reads the training set from a CSV file, computes the entropy of the whole set and the information gain of every feature, picks the feature with the largest gain as the current node, recursively repeats this on each split subset until the samples are pure or the features are exhausted, and finally plots the resulting tree and uses it to classify test samples.

    III. Test dataset

    age          income   student   credit_rating   class_buys_computer
    youth        high     no        fair            no
    youth        high     no        excellent       no
    middle_aged  high     no        fair            yes
    senior       medium   no        fair            yes
    senior       low      yes       fair            yes
    senior       low      yes       excellent       no
    middle_aged  low      yes       excellent       yes
    youth        medium   no        fair            no
    youth        low      yes       fair            yes
    senior       medium   yes       fair            yes
    youth        medium   yes       excellent       yes
    middle_aged  medium   no        excellent       yes
    middle_aged  high     yes       fair            yes
    senior       medium   no        excellent       no

    IV. Program code

             1. Computing the dataset's entropy and each feature's information gain

    # choose the best feature to split on (the one with the largest information gain)
    def chooseBestFeatureToSplit(dataset,headerList):
        # initial values
        bestInfoGain = 0.0
        bestFeature = 0
        # number of feature columns (the last column is the class label)
        numFeatures = len(dataset[0]) - 1
        # entropy of the whole dataset
        baseShannonEnt = calcShannonEnt(dataset)
        print("total's shannonEnt = %f" % (baseShannonEnt))
        # walk every feature column and compute its information gain
        for i in range(numFeatures):
            # all values appearing in column i
            featureList = [example[i] for example in dataset]
            uniqueVals = set(featureList)
            # conditional entropy after splitting on column i
            newShannonEnt = 0.0
            for value in uniqueVals:
                # rows whose column i equals this value
                subDataset = splitDataSet(dataset,i,value)
                # probability of this value times the subset's entropy
                newProbability = len(subDataset) / float(len(dataset))
                newShannonEnt += newProbability*calcShannonEnt(subDataset)
            infoGain = baseShannonEnt - newShannonEnt
            print("%s's infoGain = %f - %f = %f"%(headerList[i],baseShannonEnt,newShannonEnt,infoGain))
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature
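
      For reference, the two quantities computed by this function are the standard Shannon entropy and the ID3 information gain (the variable is named infoGain because ID3 uses the plain gain; C4.5's gain ratio would additionally divide by the split information):

    $$ H(D) = -\sum_{k} p_k \log_2 p_k $$

    $$ \operatorname{Gain}(D, a) = H(D) - \sum_{v \in \operatorname{values}(a)} \frac{|D_v|}{|D|}\, H(D_v) $$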

      This matches the result computed the previous day: the age feature has the largest information gain, so it becomes the root node.
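
      As a sanity check (these are the well-known values for this classic dataset, quoted here as expected output rather than a transcript of the program): the full set holds 9 yes and 5 no samples, and age splits it into youth (2 yes / 3 no), middle_aged (4 yes / 0 no) and senior (3 yes / 2 no):

    $$ H(D) = -\tfrac{9}{14}\log_2\tfrac{9}{14} - \tfrac{5}{14}\log_2\tfrac{5}{14} \approx 0.940 $$

    $$ \operatorname{Gain}(D,\text{age}) = 0.940 - \left( \tfrac{5}{14}\cdot 0.971 + \tfrac{4}{14}\cdot 0 + \tfrac{5}{14}\cdot 0.971 \right) \approx 0.246 $$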

             2. Full source code

             treePlotter.py        

    import matplotlib.pyplot as plt
    
    # node and arrow styles for the decision tree plot
    descisionNode = dict(boxstyle='sawtooth', fc='0.8')
    leafNode = dict(boxstyle='round4', fc='0.8')
    arrow_args = dict(arrowstyle='<-')
    
    def plotNode(nodeTxt, centerPt, parentPt, nodeType):
        # nodeTxt: the text to display; centerPt: where the text box is placed; parentPt: where the arrow starts; nodeType: the box style of the node
        createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                                 xytext=centerPt, textcoords='axes fraction',
                                  va='center', ha='center', bbox=nodeType, arrowprops=arrow_args)
    def getNumLeafs(myTree):
        numLeafs = 0
        firstStr = list(myTree.keys())[0]
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                numLeafs += getNumLeafs(secondDict[key])
            else:
                numLeafs += 1
        return numLeafs
    
    def getTreeDepth(myTree):
        maxDepth = 0
        firstStr = list(myTree.keys())[0]     # changed from the book's code: in Python 3, myTree.keys() returns a dict_keys view, not a list, so it must be wrapped in list() before indexing (the same fix appears in several places)
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                thisDepth = 1 + getTreeDepth(secondDict[key])
            else:
                thisDepth = 1
            if thisDepth > maxDepth:
                maxDepth = thisDepth
        return maxDepth
    
    def createPlot(inTree):
        fig = plt.figure(1, facecolor='white')
        fig.clf()
        axprops = dict(xticks=[], yticks=[])   # hide the axis ticks
        createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
        plotTree.totalW = float(getNumLeafs(inTree))   # global width = number of leaves
        plotTree.totalD = float(getTreeDepth(inTree))  # global height = tree depth
        plotTree.xOff = -0.5/plotTree.totalW
        plotTree.yOff = 1.0
        plotTree(inTree, (0.5, 1.0), '')
        plt.show()
    
    def plotTree(myTree, parentPt, nodeTxt):
        numLeafs = getNumLeafs(myTree)
        depth = getTreeDepth(myTree)
        firstStr = list(myTree.keys())[0]
        # cntrPt: center point of the text; parentPt: the point the arrow comes from
        cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
        plotMidText(cntrPt, parentPt, nodeTxt)
        plotNode(firstStr, cntrPt, parentPt, descisionNode)
        seconDict = myTree[firstStr]
        plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
        for key in seconDict.keys():
            if type(seconDict[key]).__name__ == 'dict':
                plotTree(seconDict[key], cntrPt, str(key))
            else:
                plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
                plotNode(seconDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
                plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
        plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
    
    def plotMidText(cntrPt, parentPt, txtString):
        xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
        yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
        createPlot.ax1.text(xMid, yMid, txtString, va='center', ha='center', rotation=30)

             decision_tree_ID3.py

    # imports
    import csv
    import math
    import operator
    import treePlotter
    
    
    # load the dataset from a CSV file
    def readDataset(file_path,file_mode):
        with open(file_path, file_mode) as allElectronicsData:
            reader = csv.reader(allElectronicsData)
            # the first row holds the feature names
            headers = next(reader)
            # the remaining rows are the training samples
            dataset = [row for row in reader]
        return headers,dataset
    
    # compute the Shannon entropy of a dataset
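    # e.g. with 9 'yes' and 5 'no' labels this returns
    # -(9/14)*log2(9/14) - (5/14)*log2(5/14) ≈ 0.940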
    def calcShannonEnt(dataset):
        shannonEnt = 0.0
        labelList = {}
        for vec_now in dataset:
            labelValue = vec_now[-1]
            if vec_now[-1] not in labelList.keys():
                labelList[labelValue] = 0
            labelList[labelValue] += 1
        for labelKey in labelList:
            probability = float(labelList[labelKey] / len(dataset))
            shannonEnt -= probability*math.log(probability,2)
        return shannonEnt
    
    # return the rows whose given feature column equals value, with that column removed
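    # e.g. splitDataSet(dataset, 0, 'youth') keeps only the 'youth' rows and
    # drops column 0, so each returned row is one element shorter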
    def splitDataSet(dataset,feature_seq,value):
        new_dataset = []
        for vec_row in dataset:
            if vec_row[feature_seq] == value:
                # copy the row without the split column
                temp_vec = vec_row[:feature_seq]
                temp_vec.extend(vec_row[feature_seq+1:])
                new_dataset.append(temp_vec)
        return new_dataset
    
    # choose the best feature to split on (the one with the largest information gain)
    def chooseBestFeatureToSplit(dataset,headerList):
        # initial values
        bestInfoGain = 0.0
        bestFeature = 0
        # number of feature columns (the last column is the class label)
        numFeatures = len(dataset[0]) - 1
        # entropy of the whole dataset
        baseShannonEnt = calcShannonEnt(dataset)
        #print("total's shannonEnt = %f" % (baseShannonEnt))
        # walk every feature column and compute its information gain
        for i in range(numFeatures):
            # all values appearing in column i
            featureList = [example[i] for example in dataset]
            uniqueVals = set(featureList)
            # conditional entropy after splitting on column i
            newShannonEnt = 0.0
            for value in uniqueVals:
                # rows whose column i equals this value
                subDataset = splitDataSet(dataset,i,value)
                # probability of this value times the subset's entropy
                newProbability = len(subDataset) / float(len(dataset))
                newShannonEnt += newProbability*calcShannonEnt(subDataset)
            infoGain = baseShannonEnt - newShannonEnt
            #print("%s's infoGain = %f - %f = %f"%(headerList[i],baseShannonEnt,newShannonEnt,infoGain))
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature
    
    # majority vote: return the label that occurs most often
    def majorityCnt(classList):
        classcount = {}
        for cl in classList:
            if cl not in classcount.keys():
                classcount[cl] = 0
            classcount[cl] += 1
        sortedClassCount = sorted(classcount.items(),key = operator.itemgetter(1),reverse= True)
        return sortedClassCount[0][0]
    
    # build the decision tree recursively
    def createTree(dataSet, labels):
        classList = [example[-1] for example in dataSet]
        # 1. all samples have the same class label: return it directly
        if classList.count(classList[0]) == len(classList):
            return classList[0]
        # 2. all features have been used up: return the majority label by voting
        if len(dataSet[0])  == 1:
            return majorityCnt(classList)
        # 3. otherwise choose the best feature to split on
        bestFeature = chooseBestFeatureToSplit(dataSet,labels)
        bestFeatureLabel = labels[bestFeature]
        myTree = {bestFeatureLabel: {}}
        del (labels[bestFeature])
        featurValues = [example[bestFeature] for example in dataSet]
        uniqueVals = set(featurValues)
        # recurse to build the subtree for each value of the best feature
        for value in uniqueVals:
            subLabels = labels[:]
            myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet, bestFeature, value), subLabels)
        return myTree
    
    def classify(inputTree, featLabels, testVec):
        classLabel = None   # stays None if the test value was never seen in training
        firstStr = list(inputTree.keys())[0]
        secondDict = inputTree[firstStr]
        featIndex = featLabels.index(firstStr)
        for key in secondDict.keys():
            if testVec[featIndex] == key:
                if type(secondDict[key]).__name__ == 'dict':
                    classLabel = classify(secondDict[key], featLabels, testVec)
                else:
                    classLabel = secondDict[key]
        return classLabel
    
    def classifyAll(inputTree, featLabels, testDataSet):
        classLabelAll = []
        for testVec in testDataSet:
            classLabelAll.append(classify(inputTree, featLabels, testVec))
        return classLabelAll
    
    def storeTree(inputTree, filename):
        import pickle
        fw = open(filename, 'wb')
        pickle.dump(inputTree, fw)
        fw.close()
    
    def grabTree(filename):
        import pickle
        fr = open(filename, 'rb')
        return pickle.load(fr)
    
    def main():
        # read the dataset
        labels, dataSet = readDataset(file_path=r'D:\test.csv', file_mode='r')
        labels_tmp = labels[:]  # copy: createTree modifies the label list it is given
        desicionTree = createTree(dataSet, labels_tmp)
        storeTree(desicionTree, 'classifierStorage.txt')
        desicionTree = grabTree('classifierStorage.txt')
        treePlotter.createPlot(desicionTree)
        testSet = [['youth', 'high', 'no', 'fair', 'no']]
        print('classifyResult:\n', classifyAll(desicionTree, labels, testSet))
    
    if __name__ == '__main__':
        main()

    V. Test results and conclusion

    [Figure: the decision tree plotted for the test dataset]
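
      For reference, running the code above on this dataset typically produces the following nested-dict tree (this is the known result for this classic dataset, quoted as a hedged expectation rather than captured output). The test vector ['youth', 'high', 'no', 'fair', 'no'] then walks age -> youth, student -> no, so classifyResult prints ['no']:

    {'age': {'youth': {'student': {'no': 'no', 'yes': 'yes'}},
             'middle_aged': 'yes',
             'senior': {'credit_rating': {'fair': 'yes', 'excellent': 'no'}}}}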

      From the information-gain formula above, we can see that the information-gain criterion is biased toward attributes with many distinct values.
      Suppose we added each sample's unique "ID" to the dataset as a candidate splitting attribute. Its information gain would come out equal to the entropy of the whole dataset (about 0.940 here),
      because every sample has a different ID: each branch would contain exactly one sample, every node would be perfectly pure, and the conditional entropy would be 0. In other words, once you know a sample's ID, no other feature is needed; yet a tree built this way obviously has no ability to generalize.
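
      Concretely, with 14 distinct IDs every branch holds exactly one sample, so each conditional-entropy term vanishes:

    $$ \operatorname{Gain}(D,\text{ID}) = H(D) - \sum_{i=1}^{14} \tfrac{1}{14} \cdot 0 = H(D) \approx 0.940 $$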

      References:

      http://www.cnblogs.com/wsine/p/5180310.html

      https://zhuanlan.zhihu.com/p/26760551
