# A bit messy — I'll come back and tidy this up once I've fully understood it.
"""ID3 decision-tree construction (Machine Learning in Action, ch. 3).

Each sample in a dataset is a list whose LAST element is the class label;
all preceding elements are discrete feature values.  The tree is built
recursively by choosing, at every node, the feature with the highest
information gain (largest drop in Shannon entropy).
"""

from math import log
import operator


def calcShannonEnt(dataSet):
    """Return the Shannon entropy (base 2) of the class labels in *dataSet*.

    The more mixed/conflicting the labels, the larger the entropy; a
    single repeated class yields 0.  An empty dataset also yields 0.0
    (the original crashed with ZeroDivisionError here).
    """
    numEntries = len(dataSet)
    if numEntries == 0:
        return 0.0
    labelCounts = {}  # class label -> occurrence count
    for featVec in dataSet:
        currentLabel = featVec[-1]  # last column is the class label
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = count / float(numEntries)  # probability of this class
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def createDataSet():
    """Return a small demo dataset and its feature-name list.

    Columns: 'no surfacing', 'flippers', class label.  Appending
    conflicting samples (same features, different labels) raises the
    entropy; making every label identical drives it to 0.
    """
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def testCalcShannonEnt():
    """Smoke-test entropy on the demo dataset."""
    myDat, labels = createDataSet()
    print(calcShannonEnt(myDat))


def splitDataSet(dataSet, axis, value):
    """Return the samples whose feature *axis* equals *value*,
    with that feature column removed from each returned sample.
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # Concatenate the slices before and after the chosen column.
            retDataSet.append(featVec[:axis] + featVec[axis + 1:])
    return retDataSet


def testSplitDataSet():
    """Smoke-test splitting on the demo dataset."""
    myDat, labels = createDataSet()
    print(splitDataSet(myDat, 0, 0))


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the highest information gain.

    Returns -1 when no split improves on the base entropy.
    """
    numFeatures = len(dataSet[0]) - 1  # last column is the class label, not a feature
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # set() is the quickest way to get the unique values of feature i.
        uniqueVals = set(example[i] for example in dataSet)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        # Information gain: the lower the post-split entropy, the higher the gain.
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def testChooseBestFeatureToSplit():
    """Smoke-test feature selection on the demo dataset."""
    myDat, labels = createDataSet()
    print(chooseBestFeatureToSplit(myDat))


def majorityCnt(classList):
    """Return the most frequent label in *classList* (majority vote).

    Ties are broken by first occurrence, as in the original (stable sort).
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # itemgetter(1) sorts by count; reverse=True puts the majority first.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def testMajorityCnt():
    """Smoke-test majority voting."""
    list1 = ['a', 'b', 'a', 'a', 'b', 'c', 'd', 'd', 'd', 'e',
             'a', 'a', 'a', 'a', 'c', 'c', 'c', 'c', 'c', 'c', 'c', 'c']
    print(majorityCnt(list1))


def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree.

    Returns either a class label (leaf) or a nested dict of the form
    {feature_name: {feature_value: subtree, ...}}.

    BUG FIX: the original did ``del labels[bestFeat]``, destructively
    mutating the caller's *labels* list; we now build a copy instead,
    so *labels* is left intact after the call.
    """
    classList = [example[-1] for example in dataSet]
    # Stop: every sample carries the same class -> leaf node.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop: only the class column remains (all features used) -> majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Remove the consumed feature name WITHOUT mutating the caller's list.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    uniqueVals = set(example[bestFeat] for example in dataSet)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


def testCreateTree():
    """Build and print the tree for the demo dataset."""
    myDat, labels = createDataSet()
    myTree = createTree(myDat, labels)
    print(myTree)


if __name__ == '__main__':
    # Other smoke tests: testCalcShannonEnt(), testSplitDataSet(),
    # testChooseBestFeatureToSplit(), testMajorityCnt()
    testCreateTree()