• 决策树ID3算法python实现 -- 《机器学习实战》


     1 from math import log
     2 import numpy as np
     3 import matplotlib.pyplot as plt
     4 import operator
     5 
     6 #计算给定数据集的香农熵
     7 def calcShannonEnt(dataSet):
     8     numEntries = len(dataSet)
     9     labelCounts = {}
    10     for featVec in dataSet:                         #|
    11         currentLabel = featVec[-1]                  #|
    12         if currentLabel not in labelCounts.keys():  #|获取标签类别取值空间(key)及出现的次数(value)
    13             labelCounts[currentLabel] = 0           #|
    14         labelCounts[currentLabel] += 1              #|
    15     shannonEnt = 0.0
    16     for key in labelCounts:                         #|
    17         prob = float(labelCounts[key])/numEntries   #|计算香农熵
    18         shannonEnt -= prob * log(prob, 2)           #|
    19     return shannonEnt
    20 
    21 #创建数据集
    22 def createDataSet():
    23     dataSet = [[1,1,'yes'],
    24                [1,1,'yes'],
    25                [1,0,'no'],
    26                [0,1,'no'],
    27                [0,1,'no']]
    28     labels = ['no surfacing', 'flippers']
    29     return dataSet, labels
    30 
    31 #按照给定特征划分数据集
    32 def splitDataSet(dataSet, axis, value):
    33     retDataSet = []
    34     for featVec in dataSet:                         #|
    35         if featVec[axis] == value:                  #|
    36             reducedFeatVec = featVec[:axis]         #|抽取出符合特征的数据
    37             reducedFeatVec.extend(featVec[axis+1:]) #|
    38             retDataSet.append(reducedFeatVec)       #|
    39     return retDataSet
    40 
    41 #选择最好的数据集划分方式
    42 def chooseBestFeatureToSplit(dataSet):
    43     numFeatures = len(dataSet[0]) - 1
    44     basicEntropy = calcShannonEnt(dataSet)
    45     bestInfoGain = 0.0; bestFeature = -1
    46     for i in range(numFeatures):        #计算每一个特征的熵增益
    47         featlist = [example[i] for example in dataSet]
    48         uniqueVals = set(featlist)
    49         newEntropy = 0.0
    50         for value in uniqueVals:        #计算每一个特征的不同取值的熵增益
    51             subDataSet = splitDataSet(dataSet, i, value)
    52             prob = len(subDataSet)/float(len(dataSet))
    53             newEntropy += prob * calcShannonEnt(subDataSet) #不同取值的熵增加起来就是整个特征的熵增益
    54         infoGain = basicEntropy - newEntropy
    55         if (infoGain > bestInfoGain):   #选择最高的熵增益作为划分方式
    56             bestInfoGain = infoGain
    57             bestFeature = i
    58     return bestFeature
    59 #挑选出现次数最多的类别
    60 def majorityCnt(classList):
    61     classCount={}
    62     for vote in classList:
    63         if vote not in classCount.keys():
    64             classCount[vote] = 0
    65         classCount[vote] += 1
    66     sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse=True)
    67     return sortedClassCount[0][0]
    68 
    69 def createTree(dataSet, labels):
    70     classList = [example[-1] for example in dataSet]
    71     if classList.count(classList[0]) == len(classList): #停止条件一:判断所有类别标签是否相同,完全相同则停止继续划分
    72         return classList[0]
    73     if len(dataSet[0]) == 1:    #停止条件二:遍历完所有特征时返回出现次数最多的
    74         return majorityCnt(classList)
    75     bestFeat = chooseBestFeatureToSplit(dataSet)    #得到列表包含的所有属性值
    76     bestFeatLabel = labels[bestFeat]
    77     myTree = {bestFeatLabel:{}}
    78     del(labels[bestFeat])
    79     featValues = [example[bestFeat] for example in dataSet]
    80     uniqueVals = set(featValues)
    81     for value in uniqueVals:
    82         subLabels = labels[:]
    83         myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    84     return myTree
    85 
    86 # Simple unit test of func: createDataSet()
    87 myDat, labels = createDataSet()
    88 print (myDat)
    89 #print (labels)
    90 # Simple unit test of func: splitDataSet()
    91 splitData = splitDataSet(myDat,0,1)
    92 print (splitData)
    93 # Simple unit test of func: chooseBestFeatureToSplit()
    94 chooseResult = chooseBestFeatureToSplit(myDat)
    95 print (chooseResult)
    96 # Simple unit test of func: createTree(
    97 myDat, labels = createDataSet()
    98 myTree = createTree(myDat, labels)
    99 print(myTree)

    Output:

    [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    [[1, 'yes'], [1, 'yes'], [0, 'no']]
    0
    {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

    Reference:

    《机器学习实战》

  • 相关阅读:
    详解Oracle临时表的几种用法及意义
    Testing and Debugging Procedures using SQL Developer 3.1
    ORACLE 流复制
    ORA01017 invalid username/password; logon denied
    oracle数据类型
    使用Pls_Integer的好处
    js取得上传图片大小
    高效整洁CSS代码原则
    在线压缩js和css
    图片等比例缩放后裁切
  • 原文地址:https://www.cnblogs.com/knownx/p/7825068.html
Copyright © 2020-2023  润新知