from collections import Counter
from math import log


def calcShannonEnt(dataSet):
    """Return the Shannon entropy (base 2) of the class labels in dataSet.

    Each row of dataSet is a feature vector whose LAST element is the class
    label. An empty dataSet yields 0.0.
    """
    numEntries = len(dataSet)
    # Class label -> number of rows carrying it (insertion order = first seen).
    labelCounts = Counter(featVec[-1] for featVec in dataSet)
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = count / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def createDataSet():
    """Return a small toy dataset and the names of its two feature columns.

    Every row is [no surfacing, flippers, class label].
    """
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def splitDataSet(dataSet, axis, value):
    """Return the rows whose feature `axis` equals `value`, with that column removed.

    New row lists are built; the rows of dataSet are not modified.
    """
    return [featVec[:axis] + featVec[axis + 1:]
            for featVec in dataSet
            if featVec[axis] == value]


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the largest information gain.

    Information gain of feature i = entropy(dataSet) minus the conditional
    entropy after splitting on i. The conditional entropy is the
    size-weighted sum of the entropies of the subsets, one subset per
    distinct value of i.

    Returns -1 when no feature has a strictly positive gain, including when
    dataSet is empty.
    """
    if not dataSet:
        return -1
    numFeatures = len(dataSet[0]) - 1  # last column is the class label
    basicEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        uniqueVals = set(example[i] for example in dataSet)
        newEntropy = 0.0  # conditional entropy of the split on feature i
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = basicEntropy - newEntropy
        # Strict '>' keeps the lowest index on ties and rejects zero-gain features.
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the most frequent label in classList.

    Ties go to the label encountered first. Raises IndexError on an empty list.
    """
    return Counter(classList).most_common(1)[0][0]


def createTree(dataSet, labels):
    """Build an ID3-style decision tree as nested dicts.

    The result has the shape {featureLabel: {featureValue: subtree}}, where
    each subtree is either another such dict or a class label (a leaf).

    Args:
        dataSet: rows of feature values followed by a class label.
        labels: names of the feature columns, in column order. This list is
            not modified.

    Raises IndexError if dataSet is empty.
    """
    classList = [example[-1] for example in dataSet]
    # Stop 1: every row has the same class label.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop 2: no features left, only the label column remains.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # Stop 3: no feature reduces entropy. Without this check, index -1 would
    # split on the class-label column itself.
    if bestFeat == -1:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Build a fresh list rather than deleting from the caller's labels.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    for value in set(example[bestFeat] for example in dataSet):
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


if __name__ == "__main__":
    # Smoke checks of each function on the toy dataset.
    myDat, labels = createDataSet()
    print(myDat)
    # Rows where feature 0 == 1, with feature 0 removed.
    splitData = splitDataSet(myDat, 0, 1)
    print(splitData)
    chooseResult = chooseBestFeatureToSplit(myDat)
    print(chooseResult)
    myDat, labels = createDataSet()
    myTree = createTree(myDat, labels)
    print(myTree)
Output:
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
[[1, 'yes'], [1, 'yes'], [0, 'no']]
0
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
Reference:
《机器学习实战》 (Machine Learning in Action, Peter Harrington)