• A simple decision tree in Python (information gain), based on the watermelon dataset from Zhou Zhihua's Machine Learning ("西瓜书")


    The dataset is as follows:

    编号  色泽    根蒂    敲声    纹理    脐部    触感    好瓜
     1   青绿    蜷缩    浊响    清晰    凹陷    硬滑    是
     2   乌黑    蜷缩    沉闷    清晰    凹陷    硬滑    是
     3   乌黑    蜷缩    浊响    清晰    凹陷    硬滑    是
     4   青绿    蜷缩    沉闷    清晰    凹陷    硬滑    是
     5   浅白    蜷缩    浊响    清晰    凹陷    硬滑    是
     6   青绿    稍蜷    浊响    清晰    稍凹    软粘    是
     7   乌黑    稍蜷    浊响    稍糊    稍凹    软粘    是
     8   乌黑    稍蜷    浊响    清晰    稍凹    硬滑    是
     9   乌黑    稍蜷    沉闷    稍糊    稍凹    硬滑    否
    10   青绿    硬挺    清脆    清晰    平坦    软粘    否
    11   浅白    硬挺    清脆    模糊    平坦    硬滑    否
    12   浅白    蜷缩    浊响    模糊    平坦    软粘    否
    13   青绿    稍蜷    浊响    稍糊    凹陷    硬滑    否
    14   浅白    稍蜷    沉闷    稍糊    凹陷    硬滑    否
    15   乌黑    稍蜷    浊响    清晰    稍凹    软粘    否
    16   浅白    蜷缩    浊响    模糊    平坦    硬滑    否
    17   青绿    蜷缩    沉闷    稍糊    稍凹    硬滑    否

    (Columns: 色泽 color, 根蒂 root, 敲声 knock sound, 纹理 texture, 脐部 navel, 触感 touch, 好瓜 good melon. The values stay in Chinese because the code reads exactly these strings from the gbk-encoded CSV.)

    The principle of the ID3 decision tree based on information gain is not covered again here; readers unfamiliar with it can refer to the watermelon book's treatment of this material.
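    For reference, the two quantities the code computes are the information entropy of a sample set $D$ and the information gain of splitting $D$ on a discrete attribute $a$ with values $a^1, \dots, a^V$:

    $$\mathrm{Ent}(D) = -\sum_{k} p_k \log_2 p_k, \qquad \mathrm{Gain}(D, a) = \mathrm{Ent}(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|}\,\mathrm{Ent}(D^v)$$

    where $p_k$ is the proportion of class $k$ in $D$ and $D^v$ is the subset of samples taking value $a^v$ on $a$. For the table above, with 8 positive and 9 negative samples, $\mathrm{Ent}(D) = -\frac{8}{17}\log_2\frac{8}{17} - \frac{9}{17}\log_2\frac{9}{17} \approx 0.998$. ID3 splits each node on the attribute with the largest gain; in the book's worked example on this dataset, 纹理 (texture) has the largest gain at the root.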

    The Python implementation is as follows:

    from math import log2
    import pandas as pd
    import matplotlib.pyplot as plt
    from matplotlib.font_manager import FontProperties

    # Count how many times each label occurs; callers pass the label column itself
    def get_counts(labels):
        total = len(labels)
        results = {}
        for label in labels:
            results[label] = results.get(label, 0) + 1
        return results, total

    # Information entropy of a sequence of labels
    def calcu_entropy(labels):
        results, total = get_counts(labels)
        ent = sum([-1.0 * v / total * log2(v / total) for v in results.values()])
        return ent
    # Information gain of splitting on a single feature column
    def calcu_each_gain(column, update_data):
        total = len(column)
        grouped = update_data.iloc[:, -1].groupby(by=column)
        # Weighted entropy of the subsets induced by each attribute value
        temp = sum([len(g[1]) / total * calcu_entropy(g[1]) for g in grouped])
        return calcu_entropy(update_data.iloc[:, -1]) - temp

    # Feature with the largest information gain
    def get_max_gain(temp_data):
        columns_entropy = [(col, calcu_each_gain(temp_data[col], temp_data)) for col in temp_data.columns[:-1]]
        columns_entropy = sorted(columns_entropy, key=lambda f: f[1], reverse=True)
        return columns_entropy[0]

    # Split the data on the chosen feature and drop that feature's column
    def drop_exist_feature(data, best_feature):
        attr = pd.unique(data[best_feature])
        new_data = [(nd, data[data[best_feature] == nd]) for nd in attr]
        new_data = [(n[0], n[1].drop([best_feature], axis=1)) for n in new_data]
        return new_data

    # Most frequent label
    def get_most_label(label_list):
        label_dict = {}
        for label in label_list:
            label_dict[label] = label_dict.get(label, 0) + 1
        sorted_label = sorted(label_dict.items(), key=lambda item: item[1], reverse=True)
        return sorted_label[0][0]
    # Build the decision tree recursively
    def create_tree(data_set, column_count):
        label_list = data_set.iloc[:, -1]
        # All samples share one label: return it as a leaf
        if len(pd.unique(label_list)) == 1:
            return label_list.values[0]
        # Every remaining feature takes a single value: return the majority label
        if all([len(pd.unique(data_set[i])) == 1 for i in data_set.iloc[:, :-1].columns]):
            return get_most_label(label_list)
        best_attr = get_max_gain(data_set)[0]
        tree = {best_attr: {}}
        # Attribute values absent from the current subset become leaves
        # labelled with the majority label of the subset
        exist_attr = pd.unique(data_set[best_attr])
        if len(exist_attr) != len(column_count[best_attr]):
            no_exist_attr = set(column_count[best_attr]) - set(exist_attr)
            for nea in no_exist_attr:
                tree[best_attr][nea] = get_most_label(label_list)
        for item in drop_exist_feature(data_set, best_attr):
            tree[best_attr][item[0]] = create_tree(item[1], column_count)
        return tree
    # The tree plotting below largely follows Machine Learning in Action and this post:
    # http://blog.csdn.net/c406495762/article/details/76262487
    # Number of leaf nodes in the tree
    def get_num_leafs(decision_tree):
        num_leafs = 0
        first_str = next(iter(decision_tree))
        second_dict = decision_tree[first_str]
        for k in second_dict.keys():
            if isinstance(second_dict[k], dict):
                num_leafs += get_num_leafs(second_dict[k])
            else:
                num_leafs += 1
        return num_leafs

    # Depth of the tree
    def get_tree_depth(decision_tree):
        max_depth = 0
        first_str = next(iter(decision_tree))
        second_dict = decision_tree[first_str]
        for k in second_dict.keys():
            if isinstance(second_dict[k], dict):
                this_depth = 1 + get_tree_depth(second_dict[k])
            else:
                this_depth = 1
            if this_depth > max_depth:
                max_depth = this_depth
        return max_depth
    # Draw a node with an arrow from its parent
    def plot_node(node_txt, center_pt, parent_pt, node_type):
        arrow_args = dict(arrowstyle='<-')
        # A font that can render Chinese; adjust the path for your system
        font = FontProperties(fname=r'C:\Windows\Fonts\STXINGKA.TTF', size=15)
        create_plot.ax1.annotate(node_txt, xy=parent_pt, xycoords='axes fraction', xytext=center_pt,
                                 textcoords='axes fraction', va="center", ha="center", bbox=node_type,
                                 arrowprops=arrow_args, fontproperties=font)

    # Label the edge with the attribute value used for the split
    def plot_mid_text(cntr_pt, parent_pt, txt_str):
        font = FontProperties(fname=r'C:\Windows\Fonts\MSYH.TTC', size=10)
        x_mid = (parent_pt[0] - cntr_pt[0]) / 2.0 + cntr_pt[0]
        y_mid = (parent_pt[1] - cntr_pt[1]) / 2.0 + cntr_pt[1]
        create_plot.ax1.text(x_mid, y_mid, txt_str, va="center", ha="center", color='red', fontproperties=font)
    # Recursively draw the tree
    def plot_tree(decision_tree, parent_pt, node_txt):
        d_node = dict(boxstyle="sawtooth", fc="0.8")
        leaf_node = dict(boxstyle="round4", fc='0.8')
        num_leafs = get_num_leafs(decision_tree)
        first_str = next(iter(decision_tree))
        cntr_pt = (plot_tree.xoff + (1.0 + float(num_leafs)) / 2.0 / plot_tree.totalW, plot_tree.yoff)
        plot_mid_text(cntr_pt, parent_pt, node_txt)
        plot_node(first_str, cntr_pt, parent_pt, d_node)
        second_dict = decision_tree[first_str]
        plot_tree.yoff = plot_tree.yoff - 1.0 / plot_tree.totalD
        for k in second_dict.keys():
            if isinstance(second_dict[k], dict):
                plot_tree(second_dict[k], cntr_pt, k)
            else:
                plot_tree.xoff = plot_tree.xoff + 1.0 / plot_tree.totalW
                plot_node(second_dict[k], (plot_tree.xoff, plot_tree.yoff), cntr_pt, leaf_node)
                plot_mid_text((plot_tree.xoff, plot_tree.yoff), cntr_pt, k)
        plot_tree.yoff = plot_tree.yoff + 1.0 / plot_tree.totalD

    def create_plot(dtree):
        fig = plt.figure(1, facecolor='white')
        fig.clf()
        axprops = dict(xticks=[], yticks=[])
        create_plot.ax1 = plt.subplot(111, frameon=False, **axprops)
        plot_tree.totalW = float(get_num_leafs(dtree))
        plot_tree.totalD = float(get_tree_depth(dtree))
        plot_tree.xoff = -0.5 / plot_tree.totalW
        plot_tree.yoff = 1.0
        plot_tree(dtree, (0.5, 1.0), '')
        plt.show()
    if __name__ == '__main__':
        my_data = pd.read_csv('./watermelon2.0.csv', encoding='gbk')
        # All possible values of each attribute, so missing branches can be filled in
        column_count = dict([(ds, list(pd.unique(my_data[ds]))) for ds in my_data.iloc[:, :-1].columns])
        d_tree = create_tree(my_data, column_count)
        create_plot(d_tree)
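    If watermelon2.0.csv is not at hand, here is a minimal sketch of an alternative driver that builds the same DataFrame inline from the table above; only the CSV read is replaced, and the functions above are used unchanged:

    import pandas as pd

    columns = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感', '好瓜']
    rows = [
        ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],
        ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '是'],
        ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],
        ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '是'],
        ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],
        ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '是'],
        ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '是'],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '是'],
        ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '否'],
        ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '否'],
        ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '否'],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '否'],
        ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '否'],
        ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '否'],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '否'],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '否'],
        ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '否'],
    ]
    my_data = pd.DataFrame(rows, columns=columns)

    column_count = {col: list(pd.unique(my_data[col])) for col in my_data.columns[:-1]}
    d_tree = create_tree(my_data, column_count)
    print(d_tree)   # nested dict, e.g. {'纹理': {'清晰': {...}, ...}}
    create_plot(d_tree)

    Printing the tree before drawing it makes the nested-dict representation easy to inspect: each inner dict maps an attribute to its value branches, and each non-dict value is a leaf label.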

    The resulting decision tree plot:

  • Original article: https://www.cnblogs.com/laresh/p/7676170.html