• 03机器学习实战之决策树CART算法


    1.CART生成

    CART假设决策树是二叉树,内部结点特征的取值为“是”和“否”,左分支是取值为“是”的分支,右分支是取值为“否”的分支。这样的决策树等价于递归地二分每个特征,将输入空间即特征空间划分为有限个单元,并在这些单元上确定预测的概率分布,也就是在输入给定的条件下输出的条件概率分布。

    CART算法由以下两步组成:

    1. 决策树生成:基于训练数据集生成决策树,生成的决策树要尽量大;
    2. 决策树剪枝:用验证数据集对已生成的树进行剪枝并选择最优子树,这时损失函数最小作为剪枝的标准。

    CART决策树的生成就是递归地构建二叉决策树的过程。CART决策树既可以用于分类也可以用于回归。本文我们仅讨论用于分类的CART。对分类树而言,CART用Gini系数最小化准则来进行特征选择,生成二叉树。 CART生成算法如下:

    根据训练数据集,从根结点开始,递归地对每个结点进行以下操作,构建二叉决策树:

    算法停止计算的条件是结点中的样本个数小于预定阈值,或样本集的Gini系数小于预定阈值(样本基本属于同一类),或者没有更多特征。

    2.一个具体的例子

    首先对数据集非类标号属性{是否有房,婚姻状况,年收入}分别计算它们的Gini系数增益,取Gini系数增益值最大的属性作为决策树的根节点属性。根节点的Gini系数 

     

    当根据是否有房来进行划分时,Gini系数增益计算过程为 

    若按婚姻状况属性来划分,属性婚姻状况有三个可能的取值{married,single,divorced},分别计算划分后的

    • {married} | {single,divorced}
    • {single} | {married,divorced}
    • {divorced} | {single,married}

     的Gini系数增益。 
    当分组为{married} | {single,divorced}时,SlSl表示婚姻状况取值为married的分组,SrSr表示婚姻状况取值为single或者divorced的分组 

     

    当分组为{single} | {married,divorced}时, 

     

     

    当分组为{divorced} | {single,married}时, 

     
     

     对比计算结果,根据婚姻状况属性来划分根节点时取Gini系数增益最大的分组作为划分结果,也就是{married} | {single,divorced}。

     最后考虑年收入属性,我们发现它是一个连续的数值类型。我们在前面的文章里已经专门介绍过如何应对这种类型的数据划分了。对此还不是很清楚的朋友可以参考之前的文章,这里不再赘述。

     对于年收入属性为数值型属性,首先需要对数据按升序排序,然后从小到大依次用相邻值的中间值作为分隔将样本划分为两组。例如当面对年收入为60和70这两个值时,我们算得其中间值为65。倘若以中间值65作为分割点。

     

    其他值的计算同理可得,我们不再逐一给出计算过程,仅列出结果如下(最终我们取其中使得增益最大化的那个二分准则来作为构建二叉树的准则): 

     

    最大化增益等价于最小化子女结点的不纯性度量(Gini系数)的加权平均值,之前的表里我们列出的是Gini系数的加权平均值,现在的表里给出的是Gini系数增益。现在我们希望最大化Gini系数的增益。根据计算知道,三个属性划分根节点的增益最大的有两个:年收入属性和婚姻状况,他们的增益都为0.12。此时,选取首先出现的属性作为第一次划分

     接下来,采用同样的方法,分别计算剩下属性,其中根节点的Gini系数为(此时是否拖欠贷款的各有3个records) 

    与前面的计算过程类似,对于是否有房属性,可得 

     

    对于年收入属性则有:

    最后我们构建的CART如下图所示:

     

    最后我们总结一下,CART和C4.5的主要区别:

    • C4.5采用信息增益率来作为分支特征的选择标准,而CART则采用Gini系数;
    • C4.5不一定是二叉树,但CART一定是二叉树。

    3.关于过拟合以及剪枝

    决策树很容易发生过拟合,也就是由于对train数据集适应得太好,反而在test数据集上表现得不好。这个时候我们要么是

     通过阈值控制终止条件避免树形结构分支过细,要么就是通过对已经形成的决策树进行剪枝来避免过拟合。另外一个克服过拟合的手段就是基于Bootstrap的思想建立随机森林(Random Forest)。关于剪枝的内容可以参考文献【2】以了解更多。

    参考文献
    【1】Wu, X., Kumar, V., Quinlan, J.R., Ghosh, J., Yang, Q., Motoda, H., McLachlan, G.J., Ng, A., Liu, B., Philip, S.Y. and Zhou, Z.H., 2008. Top 10 algorithms in data mining. Knowledge and information systems, 14(1), pp.1-37. (http://www.cs.uvm.edu/~icdm/algorithms/10Algorithms-08.pdf
    【2】李航,统计学习方法,清华大学出版社


    4.代码实现

    import numpy as np
    
    
    # 定义树结构,采用的二叉树,左子树:条件为true,右子树:条件为false
    # leftBranch:左子树结点
    # rightBranch:右子树结点
    # col:信息增益最大时对应的列索引
    # value:最优列索引下,划分数据类型的值
    # results:分类结果
    # summary:信息增益最大时样本信息
    # data:信息增益最大时数据集
    class Tree:
        def __init__(self, leftBranch=None, rightBranch=None, col=-1, value=None, results=None, summary=None, data=None):
            self.leftBranch = leftBranch
            self.rightBranch = rightBranch
            self.col = col
            self.value = value
            self.results = results
            self.summary = summary
            self.data = data
    
        def __str__(self):
            print(u"列号:%d" % self.col)
            print(u"列划分值:%s" % self.value)
            print(u"样本信息:%s" % self.summary)
            return ""
    
    
    # 划分数据集
    def splitDataSet(dataSet, value, column):
        leftList = []
        rightList = []
        # 判断value是否是数值型
        if (isinstance(value, int) or isinstance(value, float)):
            # 遍历每一行数据
            for rowData in dataSet:
                # 如果某一行指定列值>=value,则将该行数据保存在leftList中,否则保存在rightList中
                if (rowData[column] >= value):
                    leftList.append(rowData)
                else:
                    rightList.append(rowData)
        # value为标称型
        else:
            # 遍历每一行数据
            for rowData in dataSet:
                # 如果某一行指定列值==value,则将该行数据保存在leftList中,否则保存在rightList中
                if (rowData[column] == value):
                    leftList.append(rowData)
                else:
                    rightList.append(rowData)
        return leftList, rightList
    
    
    # 统计标签类每个样本个数
    '''
    该函数是计算gini值的辅助函数,假设输入的dataSet为为['A', 'B', 'C', 'A', 'A', 'D'],
    则输出为['A':3,' B':1, 'C':1, 'D':1],这样分类统计dataSet中每个类别的数量
    '''
    
    
    def calculateDiffCount(dataSet):
        results = {}
        for data in dataSet:
            # data[-1] 是数据集最后一列,也就是标签类
            if data[-1] not in results:
                results.setdefault(data[-1], 1)
            else:
                results[data[-1]] += 1
        return results
    
    
    # 基尼指数公式实现
    def gini(dataSet):
        # 计算gini的值(Calculate GINI)
        # 数据所有行
        length = len(dataSet)
        # 标签列合并后的数据集
        results = calculateDiffCount(dataSet)
        imp = 0.0
        for i in results:
            imp += results[i] / length * results[i] / length
        return 1 - imp
    
    
    # 生成决策树
    '''算法步骤'''
    '''根据训练数据集,从根结点开始,递归地对每个结点进行以下操作,构建二叉决策树:
    1 设结点的训练数据集为D,计算现有特征对该数据集的信息增益。此时,对每一个特征A,对其可能取的
      每个值a,根据样本点对A >=a 的测试为“是”或“否”将D分割成D1和D2两部分,利用基尼指数计算信息增益。
    2 在所有可能的特征A以及它们所有可能的切分点a中,选择信息增益最大的特征及其对应的切分点作为最优特征
      与最优切分点,依据最优特征与最优切分点,从现结点生成两个子结点,将训练数据集依特征分配到两个子结点中去。
    3 对两个子结点递归地调用1,2,直至满足停止条件。
    4 生成CART决策树。
    '''''''''''''''''''''
    
    
    # evaluationFunc= gini :采用的是基尼指数来衡量信息关注度
    def buildDecisionTree(dataSet, evaluationFunc=gini):
        # 计算基础数据集的基尼指数
        baseGain = evaluationFunc(dataSet)
        # 计算每一行的长度(也就是列总数)
        columnLength = len(dataSet[0])
        # 计算数据项总数
        rowLength = len(dataSet)
        # 初始化
        bestGain = 0.0  # 信息增益最大值
        bestValue = None  # 信息增益最大时的列索引,以及划分数据集的样本值
        bestSet = None  # 信息增益最大,听过样本值划分数据集后的数据子集
        # 标签列除外(最后一列),遍历每一列数据
        for col in range(columnLength - 1):
            # 获取指定列数据
            colSet = [example[col] for example in dataSet]
            # 获取指定列样本唯一值
            uniqueColSet = set(colSet)
            # 遍历指定列样本集
            for value in uniqueColSet:
                # 分割数据集
                leftDataSet, rightDataSet = splitDataSet(dataSet, value, col)
                # 计算子数据集概率,python3 "/"除号结果为小数
                prop = len(leftDataSet) / rowLength
                # 计算信息增益
                infoGain = baseGain - prop * evaluationFunc(leftDataSet) - (1 - prop) * evaluationFunc(rightDataSet)
                # 找出信息增益最大时的列索引,value,数据子集
                if (infoGain > bestGain):
                    bestGain = infoGain
                    bestValue = (col, value)
                    bestSet = (leftDataSet, rightDataSet)
        # 结点信息
        #    nodeDescription = {'impurity:%.3f'%baseGain,'sample:%d'%rowLength}
        nodeDescription = {'impurity': '%.3f' % baseGain, 'sample': '%d' % rowLength}
        # 数据行标签类别不一致,可以继续分类
        # 递归必须有终止条件
        if bestGain > 0:
            # 递归,生成左子树结点,右子树结点
            leftBranch = buildDecisionTree(bestSet[0], evaluationFunc)
            rightBranch = buildDecisionTree(bestSet[1], evaluationFunc)
            # print(Tree(leftBranch=leftBranch, rightBranch=rightBranch, col=bestValue[0]
            #             , value=bestValue[1], summary=nodeDescription, data=bestSet))
            return Tree(leftBranch=leftBranch, rightBranch=rightBranch, col=bestValue[0]
                        , value=bestValue[1], summary=nodeDescription, data=bestSet)
        else:
            # 数据行标签类别都相同,分类终止
            return Tree(results=calculateDiffCount(dataSet), summary=nodeDescription, data=dataSet)
    
    
    # def createTree(dataSet, evaluationFunc=gini):
    #     # 递归建立决策树, 当gain=0,时停止回归
    #     # 计算基础数据集的基尼指数
    #     baseGain = evaluationFunc(dataSet)
    #     # 计算每一行的长度(也就是列总数)
    #     columnLength = len(dataSet[0])
    #     # 计算数据项总数
    #     rowLength = len(dataSet)
    #     # 初始化
    #     bestGain = 0.0  # 信息增益最大值
    #     bestValue = None  # 信息增益最大时的列索引,以及划分数据集的样本值
    #     bestSet = None  # 信息增益最大,听过样本值划分数据集后的数据子集
    #     # 标签列除外(最后一列),遍历每一列数据
    #     for col in range(columnLength - 1):
    #         # 获取指定列数据
    #         colSet = [example[col] for example in dataSet]
    #         # 获取指定列样本唯一值
    #         uniqueColSet = set(colSet)
    #         # 遍历指定列样本集
    #         for value in uniqueColSet:
    #             # 分割数据集
    #             leftDataSet, rightDataSet = splitDataSet(dataSet, value, col)
    #             # 计算子数据集概率,python3 "/"除号结果为小数
    #             prop = len(leftDataSet) / rowLength
    #             # 计算信息增益
    #             infoGain = baseGain - prop * evaluationFunc(leftDataSet) - (1 - prop) * evaluationFunc(rightDataSet)
    #             # 找出信息增益最大时的列索引,value,数据子集
    #             if (infoGain > bestGain):
    #                 bestGain = infoGain
    #                 bestValue = (col, value)
    #                 bestSet = (leftDataSet, rightDataSet)
    #
    #     impurity = u'%.3f' % baseGain
    #     sample = '%d' % rowLength
    #
    #     if bestGain > 0:
    #         bestFeatLabel = u'serial:%s
    impurity:%s
    sample:%s' % (bestValue[0], impurity, sample)
    #         myTree = {bestFeatLabel: {}}
    #         myTree[bestFeatLabel][bestValue[1]] = createTree(bestSet[0], evaluationFunc)
    #         myTree[bestFeatLabel]['no'] = createTree(bestSet[1], evaluationFunc)
    #         return myTree
    #     else:  # 递归需要返回值
    #         bestFeatValue = u'%s
    impurity:%s
    sample:%s' % (str(calculateDiffCount(dataSet)), impurity, sample)
    #         return bestFeatValue
    
    
    # 分类测试:
    '''根据给定测试数据遍历二叉树,找到符合条件的叶子结点'''
    '''例如测试数据为[5.9,3,4.2,1.75],按照训练数据生成的决策树分类的顺序为
       第2列对应测试数据4.2 =>与决策树根结点(2)的value(3)比较,>=3则遍历左子树,否则遍历右子树,
       叶子结点就是结果'''
    
    
    def classify(data, tree):
        # 判断是否是叶子结点,是就返回叶子结点相关信息,否就继续遍历
        if tree.results != None:
            return u"%s
    %s" % (tree.results, tree.summary)
        else:
            branch = None
            v = data[tree.col]
            # 数值型数据
            if isinstance(v, int) or isinstance(v, float):
                if v >= tree.value:
                    branch = tree.leftBranch
                else:
                    branch = tree.rightBranch
            else:  # 标称型数据
                if v == tree.value:
                    branch = tree.leftBranch
                else:
                    branch = tree.rightBranch
            return classify(data, branch)
    
    
    def loadCSV(fileName):
        def convertTypes(s):
            s = s.strip()
            try:
                return float(s) if '.' in s else int(s)
            except ValueError:
                return s
    
        data = np.loadtxt(fileName, dtype='str', delimiter=',')
        data = data[1:, :]
        dataSet = ([[convertTypes(item) for item in row] for row in data])
        return dataSet
    
    
    # 多数表决器
    # 列中相同值数量最多为结果
    # def majorityCnt(classList):
    #     import operator
    #     classCounts = {}
    #     for value in classList:
    #         if (value not in classCounts.keys()):
    #             classCounts[value] = 0
    #         classCounts[value] += 1
    #     sortedClassCount = sorted(classCounts.items(), key=operator.itemgetter(1), reverse=True)
    #     return sortedClassCount[0][0]
    
    
    # 剪枝算法(前序遍历方式:根=>左子树=>右子树)
    '''算法步骤
    1. 从二叉树的根结点出发,递归调用剪枝算法,直至左、右结点都是叶子结点
    2. 计算父节点(子结点为叶子结点)的信息增益infoGain
    3. 如果infoGain < miniGain,则选取样本多的叶子结点来取代父节点
    4. 循环1,2,3,直至遍历完整棵树
    '''''''''
    # def prune(tree, miniGain, evaluationFunc=gini):
    #     print(u"当前结点信息:")
    #     print(str(tree))
    #     # 如果当前结点的左子树不是叶子结点,遍历左子树
    #     if (tree.leftBranch.results == None):
    #         print(u"左子树结点信息:")
    #         print(str(tree.leftBranch))
    #         prune(tree.leftBranch, miniGain, evaluationFunc)
    #     # 如果当前结点的右子树不是叶子结点,遍历右子树
    #     if (tree.rightBranch.results == None):
    #         print(u"右子树结点信息:")
    #         print(str(tree.rightBranch))
    #         prune(tree.rightBranch, miniGain, evaluationFunc)
    #     # 左子树和右子树都是叶子结点
    #     if (tree.leftBranch.results != None and tree.rightBranch.results != None):
    #         # 计算左叶子结点数据长度
    #         leftLen = len(tree.leftBranch.data)
    #         # 计算右叶子结点数据长度
    #         rightLen = len(tree.rightBranch.data)
    #         # 计算左叶子结点概率
    #         leftProp = leftLen / (leftLen + rightLen)
    #         # 计算该结点的信息增益(子类是叶子结点)
    #         infoGain = (evaluationFunc(tree.leftBranch.data + tree.rightBranch.data) -
    #                     leftProp * evaluationFunc(tree.leftBranch.data) - (1 - leftProp) * evaluationFunc(
    #                     tree.rightBranch.data))
    #         # 信息增益 < 给定阈值,则说明叶子结点与其父结点特征差别不大,可以剪枝
    #         if (infoGain < miniGain):
    #             # 合并左右叶子结点数据
    #             dataSet = tree.leftBranch.data + tree.rightBranch.data
    #             # 获取标签列
    #             classLabels = [example[-1] for example in dataSet]
    #             # 找到样本最多的标签值
    #             keyLabel = majorityCnt(classLabels)
    #             # 判断标签值是左右叶子结点哪一个
    #             if keyLabel in tree.leftBranch.results:
    #                 # 左叶子结点取代父结点
    #                 tree.data = tree.leftBranch.data
    #                 tree.results = tree.leftBranch.results
    #                 tree.summary = tree.leftBranch.summary
    #             else:
    #                 # 右叶子结点取代父结点
    #                 tree.data = tree.rightBranch.data
    #                 tree.results = tree.rightBranch.results
    #                 tree.summary = tree.rightBranch.summary
    #             tree.leftBranch = None
    #             tree.rightBranch = None
    
    
    def printTree(myTree):
        print("当前结点信息:")
        print(myTree)
        # 如果当前结点的左子树不是叶子结点,遍历左子树
        if (myTree.leftBranch.results == None):
            print("左子树结点信息:")
            print(myTree.leftBranch)
            printTree(myTree.leftBranch)
        # 如果当前结点的右子树不是叶子结点,遍历右子树
        if (myTree.rightBranch.results == None):
            print("右子树结点信息:")
            print(myTree.rightBranch)
            printTree(myTree.rightBranch)
    
    
    if __name__ == '__main__':
        dataSet = loadCSV("D:\mlInAction\irisData.csv")
        # print(dataSet)
        myTree = buildDecisionTree(dataSet, evaluationFunc=gini)
        printTree(myTree)
        testData = [5.9, 3, 4.2, 1.75]
        result = classify(testData,myTree)
        print("预测结果为:")
        print(result)

    5.输出结果

    当前结点信息:
    列号:2
    列划分值:3
    样本信息:{'impurity': '0.667', 'sample': '150'}
    
    左子树结点信息:
    列号:3
    列划分值:1.8
    样本信息:{'impurity': '0.500', 'sample': '100'}
    
    当前结点信息:
    列号:3
    列划分值:1.8
    样本信息:{'impurity': '0.500', 'sample': '100'}
    
    左子树结点信息:
    列号:2
    列划分值:4.9
    样本信息:{'impurity': '0.043', 'sample': '46'}
    
    当前结点信息:
    列号:2
    列划分值:4.9
    样本信息:{'impurity': '0.043', 'sample': '46'}
    
    右子树结点信息:
    列号:0
    列划分值:6
    样本信息:{'impurity': '0.444', 'sample': '3'}
    
    当前结点信息:
    列号:0
    列划分值:6
    样本信息:{'impurity': '0.444', 'sample': '3'}
    
    右子树结点信息:
    列号:2
    列划分值:5
    样本信息:{'impurity': '0.168', 'sample': '54'}
    
    当前结点信息:
    列号:2
    列划分值:5
    样本信息:{'impurity': '0.168', 'sample': '54'}
    
    左子树结点信息:
    列号:3
    列划分值:1.6
    样本信息:{'impurity': '0.444', 'sample': '6'}
    
    当前结点信息:
    列号:3
    列划分值:1.6
    样本信息:{'impurity': '0.444', 'sample': '6'}
    
    左子树结点信息:
    列号:0
    列划分值:7.2
    样本信息:{'impurity': '0.444', 'sample': '3'}
    
    当前结点信息:
    列号:0
    列划分值:7.2
    样本信息:{'impurity': '0.444', 'sample': '3'}
    
    右子树结点信息:
    列号:3
    列划分值:1.7
    样本信息:{'impurity': '0.041', 'sample': '48'}
    
    当前结点信息:
    列号:3
    列划分值:1.7
    样本信息:{'impurity': '0.041', 'sample': '48'}
    
    预测结果为:
    {'virginica': 1}
    {'impurity': '0.000', 'sample': '1'}

    https://www.cnblogs.com/further-further-further/p/9482885.html

    https://baimafujinji.blog.csdn.net/article/details/53269040

  • 相关阅读:
    Selenium + WebDriver 各浏览器驱动下载地址
    selenium之 文件上传所有方法整理总结【转】
    FakeUserAgentError('Maximum amount of retries reached') 彻底解决办法
    git关联远程仓库
    通过chrome console 快速获取网页连接
    【转】Selenium
    【转】fiddler抓包HTTPS请求
    【转】Wireshark和Fiddler分析Android中的TLS协议包数据(附带案例样本)
    php 通过 create user 和grant 命令无法创建数据库用户和授权的解决办法
    差等生也是需要交卷的
  • 原文地址:https://www.cnblogs.com/xinmomoyan/p/10768611.html
Copyright © 2020-2023  润新知