• 决策树ID3算法实现


    决策树的ID3算法基于信息增益来选择最优特征,于是自己实现了一把,直接上代码。

      1 """
      2 CreateTime    : 2019/3/3 22:19
      3 Author        : X
      4 Filename      : decision_tree.py
      5 """
      6 
      7 import pandas as pd
      8 from math import log2
      9 
     10 
     11 def create_data_set():
     12     """Create 8 * 3 data set. two feature."""
     13     data_set = [['long', 'thick', 'man'],
     14                 ['short', 'thick', 'man'],
     15                 ['short', 'thick', 'man'],
     16                 ['long', 'thin', 'woman'],
     17                 ['short', 'thin', 'woman'],
     18                 ['short', 'thick', 'woman'],
     19                 ['long', 'thick', 'woman'],
     20                 ['long', 'thick', 'woman']]
     21     labels = ['hair', 'sound']
     22     return data_set, labels
     23 
     24 
     25 def calculate_entropy(data_set):
     26     """Calculate entropy by data set label.
     27        formula: H(X) = -3/8*log(3/8, 2) - -5/8*log(5/8, 2)"""
     28     data_len = data_set.shape[0]
     29     entropy = 0
     30     for size in data_set.groupby(data_set.iloc[:, -1]).size():
     31         p_label = size/data_len
     32         entropy -= p_label * log2(p_label)
     33     return entropy
     34 
     35 
     36 def get_best_feature(data_set):
     37     """Get the best feature by infoGain.
     38        formula: InfoGain(X, Y) = H(X) - H(X|Y)
     39                 H(X|Y) = sum(P(X) * H(Yx))"""
     40     best_feature = -1
     41     base_entropy = calculate_entropy(data_set)
     42     best_info_gain = 0
     43     len_data = data_set.shape[0]
     44     for i in range(data_set.shape[1] - 1):
     45         new_entropy = 0
     46         for _, group in data_set.groupby(data_set.iloc[:, i]):
     47             p_label = group.shape[0]/len_data
     48             new_entropy += p_label * calculate_entropy(group)
     49         info_gain = base_entropy - new_entropy
     50         if info_gain > best_info_gain:
     51             best_feature = i
     52             best_info_gain = info_gain
     53     return best_feature
     54 
     55 
     56 def majority_cnt(class_list):
     57     """When only class label, return the max label."""
     58     majority_class = class_list.groupby(
     59         class_list.iloc[:, -1]).size().sort_values().index[-1]
     60     return majority_class
     61 
     62 
     63 def create_tree(data_set, labels):
     64     """data_set: DataFrame"""
     65     class_list = data_set.values[:, -1]
     66     class_list_set = set(class_list)
     67     if len(class_list_set) == 1:
     68         return list(class_list)[0]
     69     if len(data_set.values[0]) == 1:
     70         return majority_cnt(data_set)
     71     best_feature = get_best_feature(data_set)
     72     best_feature_label = labels[best_feature]
     73     del labels[best_feature]
     74     my_tree = {best_feature_label: {}}
     75     for name, group in data_set.groupby(data_set.iloc[:, best_feature]):
     76         group.drop(columns=[best_feature], axis=1, inplace=True)
     77         my_tree[best_feature_label][name] = create_tree(group, labels)
     78     return my_tree
     79 
     80 
     81 def classify(test_data, my_tree):
     82     if not test_data:
     83         return 'Not found class.'
     84     for key, tree in my_tree.items():
     85         if key != test_data[0]:
     86             return classify(test_data, tree)
     87         else:
     88             if isinstance(tree, dict):
     89                 del test_data[0]
     90                 return classify(test_data, tree)
     91             else:
     92                 return tree
     93 
     94 
     95 if __name__ == '__main__':
     96     DATA_SET, LABELS = create_data_set()
     97     TREE = create_tree(pd.DataFrame(DATA_SET), LABELS)
     98     import json
     99     print(json.dumps(TREE, indent=4))
    100     print(classify(["thick", "long"], TREE))

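    C4.5算法基于信息增益率（信息增益比）来选择最优特征：在ID3算法的基础上，将特征X的信息增益除以数据集D关于特征X取值的熵（即分裂信息 $H_X(D) = -\sum_{v} \frac{|D_v|}{|D|}\log_2\frac{|D_v|}{|D|}$，其中 $D_v$ 为特征X取值为 $v$ 的样本子集），得到 $\mathrm{GainRatio}(D, X) = \frac{\mathrm{Gain}(D, X)}{H_X(D)}$。注意分母是按特征X的取值划分样本得到的熵，而不是按类别标签计算的熵。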

    此处就不再给出实现代码,自己实现一遍意在加深理解。

  • 相关阅读:
    取指定长度的字符串(包括中英文),以"..."的方式显示
    js 常用函数
    js 规范
    js高级编程笔记2
    js高级编程笔记
    WinJS开发div中元素的水平和垂直居中metro
    WinJS开发iframe中Javascript执行错误metro
    MySql乱码
    正则表达式符号系统
    Java替换字符串中的回车换行
  • 原文地址:https://www.cnblogs.com/xu-xiaofeng/p/10473087.html
Copyright © 2020-2023  润新知