• DecisionTree_python


    #coding:utf-8
    from numpy import *
    from math import *
    import operator
    def file2matrix(filename):
        """Load a whitespace-separated numeric data file into a matrix.

        Every non-blank line of the file must contain exactly seven numeric
        fields.  Blank lines are skipped.

        Args:
            filename: path of the text file to read.

        Returns:
            A tuple ``(rematrix, label)`` where ``rematrix`` is a float array
            of shape ``(number_of_data_lines, 7)`` and ``label`` is a fresh
            list holding the six feature names.

        Raises:
            ValueError: if a field is not numeric or a line does not have
                exactly seven fields.
        """
        # The context manager closes the file even if parsing fails.
        with open(filename) as fr:
            rows = [line.split() for line in fr if line.strip()]
        rematrix = zeros((len(rows), 7))
        label = ["seze", "gendi", "qiaoshen", "wenli", "qibu", "chugan"]  # feature names of the watermelon data set
        for index, fields in enumerate(rows):
            # Fill exactly one row; the slice ``index:`` would overwrite every
            # remaining row on each iteration.
            rematrix[index, :] = [float(field) for field in fields]
        return rematrix, label
    def singlesplit(data,axis,value):
        """Select the rows whose feature ``axis`` equals ``value``.

        The matching rows are returned with column ``axis`` removed and all
        other columns (including the trailing class label) kept in order, so
        the result can be split again on the remaining features.

        Args:
            data: sequence of rows; the last element of each row is the
                class label.
            axis: index of the feature column to test and drop.
            value: feature value a row must have to be selected.

        Returns:
            A new list of new row lists; ``data`` itself is not modified.
        """
        subset = []
        for feat in data:
            if feat[axis] == value:
                # Copy first so the input row is left untouched, then drop
                # only the column that was split on.
                reduced = list(feat)
                del reduced[axis]
                subset.append(reduced)
        return subset
    def allsplit(data):
        """Return the index of the feature with the largest information gain.

        For every column except the last (the class label) the data is
        partitioned by the column's distinct values, and the weighted entropy
        of the partitions is subtracted from the entropy of the whole set.

        Returns:
            The index of the column with the highest gain, or -1 when no
            column has a gain strictly greater than zero.
        """
        base_entropy = calcshannon(data)
        total = float(len(data))
        best_gain = 0.0
        best_index = -1
        for col in range(len(data[0]) - 1):
            weighted_entropy = 0.0
            # Distinct values taken by this feature.
            for val in set(row[col] for row in data):
                part = singlesplit(data, col, val)
                weighted_entropy += len(part) / total * calcshannon(part)
            gain = base_entropy - weighted_entropy
            if gain > best_gain:
                best_gain, best_index = gain, col
        return best_index
    def calcshannon(data):
        """Compute the Shannon entropy (base 2) of the class labels in ``data``.

        The class label of a row is its last element.  An empty ``data``
        yields 0.0.
        """
        total = float(len(data))
        class_counts = {}
        for row in data:
            cls = row[-1]
            class_counts[cls] = class_counts.get(cls, 0) + 1
        entropy = 0.0
        for count in class_counts.values():
            p = count / total
            entropy -= p * log(p, 2)
        return entropy
    def selectbigger(label):
        """Return the value that occurs most often in ``label``.

        Args:
            label: non-empty sequence of hashable class labels.

        Returns:
            The most frequent element of ``label``.

        Raises:
            ValueError: if ``label`` is empty.
        """
        calcdict = {}
        for line in label:
            # Increment this label's counter; ``calcdict += 1`` on the dict
            # itself raises TypeError.
            calcdict[line] = calcdict.get(line, 0) + 1
        # items() exists on both Python 2 and 3, unlike iteritems().
        return max(calcdict.items(), key=operator.itemgetter(1))[0]
    def createTree(data,label):
        """Recursively build a decision tree as nested dictionaries.

        Args:
            data: non-empty sequence of rows; the last element of each row
                is the class label, the preceding elements are feature values.
            label: list of feature names, one per feature column of ``data``.
                The list is not modified.

        Returns:
            A class label when the node is a leaf, otherwise a dict of the
            form ``{feature_name: {feature_value: subtree_or_class, ...}}``.
        """
        labellist = [tt[-1] for tt in data]
        if labellist.count(labellist[0]) == len(labellist):  # all samples belong to one class
            return labellist[0]
        if len(data[0]) == 1:  # no features left to split on
            return selectbigger(labellist)
        bestfuture = allsplit(data)
        if bestfuture < 0:
            # No feature gives a positive information gain.  Indexing
            # label[-1] here would split on an unrelated feature, so fall
            # back to the majority class instead.
            return selectbigger(labellist)
        bestlabel = label[bestfuture]
        tree = {bestlabel: {}}  # the tree is built recursively as a dict
        # Build the reduced name list as a new object: the caller's list stays
        # intact and sibling branches cannot affect each other's names.
        remaining = list(label[:bestfuture]) + list(label[bestfuture + 1:])
        for value in set(tt[bestfuture] for tt in data):
            subset = singlesplit(data, bestfuture, value)
            tree[bestlabel][value] = createTree(subset, remaining)
        return tree
    def classifier(inputree,featurelabel,clsdata):
        """Classify one sample by walking a nested-dict decision tree.

        Args:
            inputree: tree of the form ``{feature_name: {value: subtree}}``,
                where a subtree is either another such dict or a class label.
            featurelabel: list of feature names; the position of a name is
                the index of that feature in ``clsdata``.
            clsdata: feature values of the sample to classify.

        Returns:
            The class label reached in the tree, or ``''`` when the sample's
            value for a tested feature has no branch.

        Raises:
            ValueError: if a feature name in the tree is not in
                ``featurelabel``.
        """
        # next(iter(...)) gets the single root key on both Python 2 and 3;
        # keys()[0] fails on Python 3 where keys() is a view.
        firststr = next(iter(inputree))
        secondict = inputree[firststr]
        featindex = featurelabel.index(firststr)
        for key, subtree in secondict.items():
            if clsdata[featindex] == key:
                # A dict node means more features to test; anything else is a leaf.
                if isinstance(subtree, dict):
                    return classifier(subtree, featurelabel, clsdata)
                return subtree
        return ''
    dataset,label=file2matrix("out.txt")
    # Hand createTree its own copy of the feature names so that `label` is
    # still complete for classifier, without reading the file a second time.
    mytree=createTree(dataset,list(label))
    # Parenthesised single-argument print works as a statement on Python 2
    # and as a function call on Python 3.
    print(classifier(mytree,label,[3,1,1,3,3,1]))
  • 相关阅读:
    Junit。。。
    TCP
    InetAddress
    URL
    【转】Hello SDL
    批量移动文件
    在阿里云Ubuntu 14.04.5 LTS下安装nethogs0.8.5
    十二银元分三次找一假
    SQL解析
    POI
  • 原文地址:https://www.cnblogs.com/semen/p/6959056.html
Copyright © 2020-2023  润新知