• Linux下用matplotlib画决策树


    1、trees = {'no surfacing': { 0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

    2、从我的文件trees.txt里读的决策树,也是一个递归字典表示

    #coding=utf-8
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt  # 载入 pyplot API
    import os, sys
    import time
    
    decisionNode = dict(boxstyle="sawtooth", fc="0.8") # 注(a)
    leafNode = dict(boxstyle="round4", fc="0.8")
    arrow_args = dict(arrowstyle="<-")  # 箭头样式
    
    def plotNode(Nodename, centerPt, parentPt, nodeType):  #  centerPt节点中心坐标  parentPt 起点坐标
        creatPlot.ax1.annotate(Nodename, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args) # 注(b)
    
    def getNumleafs(mytree): # 获得叶节点数目,输入为我们前面得到的树(字典)
        Numleafs = 0 # 初始化
        firstStr = list(mytree.keys())[0] # 注(a) 获得第一个key值(根节点) 'no surfacing'
        secondDict = mytree[firstStr]  # 获得value值 {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}
        for key in secondDict.keys(): #  键值:0 和 1
            if type(secondDict[key]).__name__=='dict': # 判断如果里面的一个value是否还是dict
                Numleafs += getNumleafs(secondDict[key]) # 递归调用
            else:
                Numleafs += 1
        return Numleafs
    
    def getTreeDepth(mytree):
        maxDepth = 0
        firstStr = list(mytree.keys())[0]
        secondDict = mytree[firstStr]
        for key in secondDict.keys(): #  键值:0 和 1
            thisDepth = 0
            if type(secondDict[key]).__name__=='dict': # 判断如果里面的一个value是否还是dict
                thisDepth = 1 + getTreeDepth(secondDict[key]) # 递归调用
            else:
                thisDepth = 1
            if thisDepth > maxDepth:
                maxDepth = thisDepth
        return maxDepth
    
    def plotMidText(cntrPt, parentPt, txtString):   #  在两个节点之间的线上写上字
        xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
        yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
        creatPlot.ax1.text(xMid, yMid, txtString)  # text() 的使用
    
    def plotTree(myTree, parentPt, nodeName):  # 画树
        numleafs = getNumleafs(myTree)
        depth = getTreeDepth(myTree)
        firstStr = myTree.keys()[0]
        cntrPt = (plotTree.xOff+(0.5/plotTree.totalw+float(numleafs)/2.0/plotTree.totalw), plotTree.yOff)
        plotMidText(cntrPt, parentPt, nodeName) 
        plotNode(firstStr, cntrPt, parentPt, decisionNode)
        secondDict = myTree[firstStr]
        plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD # 减少y的值,将树的总深度平分,每次减少移动一点(向下,因为树是自顶向下画的)
        for key in secondDict.keys():
            if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key], cntrPt, str(key))
            else:
                plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalw
                plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
                plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
        plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
    
    def creatPlot(inTree):  # 使用的主函数
        fig = plt.figure(figsize=(200,200), facecolor='white')
        fig.clf()  # 清空绘图区
        axprops = dict(xticks=[], yticks=[]) # 创建字典 存储=====有疑问???=====
        creatPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #  ===参数的意义?===
        plotTree.totalw = float(getNumleafs(inTree))
        plotTree.totalD = float(getTreeDepth(inTree))  # 创建两个全局变量存储树的宽度和深度
        print 'tree width =', plotTree.totalw 
        print 'tree height =', plotTree.totalD 
        plotTree.xOff = -0.5/plotTree.totalw # 追踪已经绘制的节点位置 初始值为 将总宽度平分 在取第一个的一半 
        plotTree.yOff = 1.0
        plotTree(inTree, (0.5,1.0), '')  # 调用函数,并指出根节点源坐标 
        plt.savefig('images/tree2.png', format='png',  dpi=100)
    
    trees = []
    try:
            fin = open(sys.argv[1])
            line = fin.readline()
            trees = eval(line)
            #print trees
    except:
            print 'load tree error'
            raise
    if(len(sys.argv) == 1):
        trees = {'no surfacing': { 0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    t1 = time.clock()
    creatPlot(trees)
    t2 = time.clock()
    print t2 - t1

    ps:参考博客[http://blog.csdn.net/ifruoxi/article/details/53150129]

  • 相关阅读:
    51Nod 1352 集合计数(扩展欧几里德)
    莫比乌斯函数
    Codefroces 919D Substring(拓扑排序+DP)
    Codeforces 918C The Monster(括号匹配+思维)
    平面分割类问题总结
    01字典树(待更新)
    进程同步和互斥??
    进程间的八种通信方式----共享内存是最快的 IPC 方式??
    super() 函数??
    HTTP协议详解??
  • 原文地址:https://www.cnblogs.com/vincent-vg/p/6730773.html
Copyright © 2020-2023  润新知