• Decision Trees


    Building a decision tree with the ID3 algorithm

    # Author: Qian Chenglong
    # labels: the feature names;  dataSet: rows of n feature values plus a target label

    from math import log
    import operator


    '''Compute the Shannon entropy'''
    def calcShannonEnt(dataSet):
        numEntries = len(dataSet)
        labelCounts = {}
        for featVec in dataSet:  # tally the labels into a dict, counting how often each appears
            currentLabel = featVec[-1]
            if currentLabel not in labelCounts.keys():
                labelCounts[currentLabel] = 0
            labelCounts[currentLabel] += 1
        shannonEnt = 0.0
        for key in labelCounts:
            prob = float(labelCounts[key]) / numEntries  # probability of each label
            shannonEnt -= prob * log(prob, 2)
        return shannonEnt
    '''The higher the entropy, the more mixed up the data'''

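    # Usage sketch (the toy rows below are my own illustration, not from the
    # original post): with the labels split 2/3 'yes' vs 1/3 'no',
    #   calcShannonEnt([[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no']])
    # returns -(2/3)*log2(2/3) - (1/3)*log2(1/3) ≈ 0.918.
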
    '''Split the dataset on a given feature'''
    def splitDataSet(dataSet, axis, value):  # dataset to split, index of the splitting feature, value of that feature
        retDataSet = []
        for featVec in dataSet:
            if featVec[axis] == value:
                reducedFeatVec = featVec[:axis]  # keep all values except the feature used for splitting
                reducedFeatVec.extend(featVec[axis+1:])
                retDataSet.append(reducedFeatVec)
        return retDataSet
    '''Extracts the rows that match the given feature value'''

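    # Usage sketch (my own toy rows again): splitting on feature 0 with value 1
    # keeps the matching rows and drops column 0.
    #   splitDataSet([[1, 1, 'yes'], [1, 1, 'yes'], [0, 1, 'no']], 0, 1)
    #   # -> [[1, 'yes'], [1, 'yes']]
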
    '''Iterate over all features and choose the split with the lowest entropy (i.e. the highest information gain)'''
    def chooseBestFeatureToSplit(dataSet):
        numFeatures = len(dataSet[0]) - 1  # number of features; the last column is the label, hence the -1
        baseEntropy = calcShannonEnt(dataSet)  # entropy of the unsplit dataset
        bestInfoGain = 0.0; bestFeature = -1
        for i in range(numFeatures):
            featList = [example[i] for example in dataSet]  # collect every value of the current feature (i is the feature index)
            uniqueVals = set(featList)  # build a set, which removes duplicates
            newEntropy = 0.0
            for value in uniqueVals:  # iterate over each value of the current feature
                subDataSet = splitDataSet(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy += prob * calcShannonEnt(subDataSet)  # entropy after the split, weighted by subset size
            infoGain = baseEntropy - newEntropy  # the entropy reduction, i.e. the information gain
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature

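    # Worked example (my own numbers, not from the original post): for
    #   [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    # the base entropy is ≈ 0.971. Splitting on feature 0 leaves a weighted
    # entropy of (3/5)*0.918 + (2/5)*0 ≈ 0.551, a gain of ≈ 0.420; feature 1
    # only gains ≈ 0.171, so chooseBestFeatureToSplit returns 0.
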
    '''Return the class label that occurs most often'''
    def majorityCnt(classList):
        classCount = {}
        for vote in classList:
            if vote not in classCount.keys():
                classCount[vote] = 0
            classCount[vote] += 1
        sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)  # reverse=True sorts descending, reverse=False ascending (the default)
        return sortedClassCount[0][0]

    def createTree(dataSet, labels):
        classList = [example[-1] for example in dataSet]  # list of the class labels
        if classList.count(classList[0]) == len(classList):  # every label is the same, i.e. only one class left:
            return classList[0]                              # stop splitting
        if len(dataSet[0]) == 1:                             # all features used up, only the label column remains:
            return majorityCnt(classList)                    # return the most frequent class
        bestFeat = chooseBestFeatureToSplit(dataSet)
        bestFeatLabel = labels[bestFeat]
        myTree = {bestFeatLabel: {}}  # build the tree as a dict keyed by the feature name
        del labels[bestFeat]          # remove the feature name that was just used
        featValues = [example[bestFeat] for example in dataSet]
        uniqueVals = set(featValues)
        for value in uniqueVals:
            subLabels = labels[:]  # copy all of labels, so recursion doesn't mess up the existing list
            myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
        return myTree

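    # Shape sketch (with the toy rows from the worked example above and the
    # hypothetical feature names ['no surfacing', 'flippers']): createTree
    # returns a nested dict such as
    #   {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    # where inner dicts are decision nodes and strings are leaf class labels.
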
    '''Count the leaf nodes'''
    def getNumLeafs(myTree):
        numLeafs = 0
        firstStr = list(myTree.keys())[0]  # dict views are not indexable in Python 3, so convert to a list first
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':  # an inner dict is a subtree, anything else is a leaf
                numLeafs += getNumLeafs(secondDict[key])
            else:
                numLeafs += 1
        return numLeafs

    '''Compute the depth of the tree'''
    def getTreeDepth(myTree):
        maxDepth = 0
        firstStr = list(myTree.keys())[0]
        secondDict = myTree[firstStr]
        for key in secondDict.keys():
            if type(secondDict[key]).__name__ == 'dict':
                thisDepth = 1 + getTreeDepth(secondDict[key])
            else:
                thisDepth = 1
            if thisDepth > maxDepth:
                maxDepth = thisDepth
        return maxDepth

    '''Classification function that walks the decision tree'''
    def classify(inputTree, featLabels, testVec):
        firstStr = list(inputTree.keys())[0]  # the first key in the dict: the feature tested at this node
        secondDict = inputTree[firstStr]      # the second-level dict: one subtree or leaf per feature value
        featIndex = featLabels.index(firstStr)
        key = testVec[featIndex]
        valueOfFeat = secondDict[key]
        if isinstance(valueOfFeat, dict):  # internal node: keep descending
            classLabel = classify(valueOfFeat, featLabels, testVec)
        else:
            classLabel = valueOfFeat       # leaf: this is the predicted class
        return classLabel

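    # Usage sketch (tree and feature names from the shape sketch above): a sample
    # with 'no surfacing' = 1 and 'flippers' = 0 walks down to 'no'.
    #   classify(myTree, ['no surfacing', 'flippers'], [1, 0])  # -> 'no'
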
    '''Store the tree'''
    def storeTree(inputTree, filename):
        import pickle
        fw = open(filename, 'wb')  # pickle requires binary mode
        pickle.dump(inputTree, fw)
        fw.close()

    '''Load the tree'''
    def grabTree(filename):
        import pickle
        fr = open(filename, 'rb')  # binary mode, matching storeTree
        return pickle.load(fr)
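
    Putting the pieces together, here is a minimal end-to-end sketch. The toy dataset, the feature names, and the output file name are my own illustration (the original post ships no driver script); the rows follow the format the functions above expect, n feature values plus a class label.

    if __name__ == '__main__':
        # Hypothetical toy data: survives without surfacing? / has flippers? -> is a fish?
        dataSet = [[1, 1, 'yes'],
                   [1, 1, 'yes'],
                   [1, 0, 'no'],
                   [0, 1, 'no'],
                   [0, 1, 'no']]
        labels = ['no surfacing', 'flippers']

        tree = createTree(dataSet, labels[:])  # pass a copy: createTree deletes entries from its labels argument
        print(tree)                            # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
        print(getNumLeafs(tree))               # 3
        print(getTreeDepth(tree))              # 2
        print(classify(tree, labels, [1, 0]))  # 'no'

        storeTree(tree, 'tree.pkl')            # round-trip through pickle
        print(grabTree('tree.pkl') == tree)    # True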
• Original post: https://www.cnblogs.com/long5683/p/9340083.html