• ID3 Decision Tree Implementation


    The ID3 decision tree algorithm selects the best feature at each split by information gain. I implemented it myself to get a feel for it; the code follows.

    """
    CreateTime    : 2019/3/3 22:19
    Author        : X
    Filename      : decision_tree.py
    """

    import pandas as pd
    from math import log2


    def create_data_set():
        """Create an 8 x 3 data set: two features plus a class label."""
        data_set = [['long', 'thick', 'man'],
                    ['short', 'thick', 'man'],
                    ['short', 'thick', 'man'],
                    ['long', 'thin', 'woman'],
                    ['short', 'thin', 'woman'],
                    ['short', 'thick', 'woman'],
                    ['long', 'thick', 'woman'],
                    ['long', 'thick', 'woman']]
        labels = ['hair', 'sound']
        return data_set, labels


    def calculate_entropy(data_set):
        """Calculate the entropy of the data set's class label.
           For this data set: H(X) = -3/8*log2(3/8) - 5/8*log2(5/8)"""
        data_len = data_set.shape[0]
        entropy = 0
        for size in data_set.groupby(data_set.iloc[:, -1]).size():
            p_label = size / data_len
            entropy -= p_label * log2(p_label)
        return entropy


    def get_best_feature(data_set):
        """Pick the feature with the largest information gain.
           formula: InfoGain(X, Y) = H(X) - H(X|Y)
                    H(X|Y) = sum(P(y) * H(X|Y=y))"""
        # Default to the first feature in case no split improves entropy.
        best_feature = 0
        base_entropy = calculate_entropy(data_set)
        best_info_gain = 0
        len_data = data_set.shape[0]
        for i in range(data_set.shape[1] - 1):
            new_entropy = 0
            for _, group in data_set.groupby(data_set.iloc[:, i]):
                p_label = group.shape[0] / len_data
                new_entropy += p_label * calculate_entropy(group)
            info_gain = base_entropy - new_entropy
            if info_gain > best_info_gain:
                best_feature = i
                best_info_gain = info_gain
        return best_feature


    def majority_cnt(class_list):
        """When only the class column is left, return the majority label."""
        majority_class = class_list.groupby(
            class_list.iloc[:, -1]).size().sort_values().index[-1]
        return majority_class


    def create_tree(data_set, labels):
        """data_set: DataFrame whose last column is the class label."""
        class_list = data_set.values[:, -1]
        if len(set(class_list)) == 1:
            return class_list[0]
        if data_set.shape[1] == 1:
            return majority_cnt(data_set)
        best_feature = get_best_feature(data_set)
        best_feature_label = labels[best_feature]
        # Pass a copy of the label list so sibling branches are not affected.
        sub_labels = labels[:best_feature] + labels[best_feature + 1:]
        my_tree = {best_feature_label: {}}
        for name, group in data_set.groupby(data_set.iloc[:, best_feature]):
            # Drop the chosen feature column by position, not by name, since
            # column names no longer match positions after earlier drops.
            group = group.drop(columns=group.columns[best_feature])
            my_tree[best_feature_label][name] = create_tree(group, sub_labels)
        return my_tree


    def classify(test_data, my_tree):
        """test_data: feature values in the order the tree tests them,
           e.g. ['thick', 'long'] for sound first, then hair."""
        if not isinstance(my_tree, dict):
            return my_tree
        # Each tree level has the shape {feature_label: {feature_value: subtree}}.
        branches = next(iter(my_tree.values()))
        if not test_data or test_data[0] not in branches:
            return 'Not found class.'
        return classify(test_data[1:], branches[test_data[0]])


    if __name__ == '__main__':
        import json
        DATA_SET, LABELS = create_data_set()
        TREE = create_tree(pd.DataFrame(DATA_SET), LABELS)
        print(json.dumps(TREE, indent=4))
        print(classify(['thick', 'long'], TREE))
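    As a quick sanity check on the entropy formula in the docstring (the data set has 3 'man' rows and 5 'woman' rows):

```python
from math import log2

# H(X) = -3/8*log2(3/8) - 5/8*log2(5/8)
h = -3/8 * log2(3/8) - 5/8 * log2(5/8)
print(round(h, 4))  # 0.9544
```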

    The C4.5 algorithm instead selects the best feature by information gain ratio: on top of ID3, divide each feature's information gain by the entropy of that feature's own value distribution (its split information).
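    The extra step can be sketched roughly as follows, assuming the same DataFrame layout as above; `split_info` and `gain_ratio` are names I made up for illustration:

```python
import pandas as pd
from math import log2


def split_info(data_set, i):
    """Entropy of feature column i's own value distribution
    (C4.5's 'split information')."""
    n = data_set.shape[0]
    return -sum(s / n * log2(s / n)
                for s in data_set.groupby(data_set.iloc[:, i]).size())


def gain_ratio(data_set, i, info_gain):
    """GainRatio(i) = InfoGain(i) / SplitInfo(i)."""
    si = split_info(data_set, i)
    return info_gain / si if si else 0.0


# Two equally likely values for feature 0 -> split info of 1 bit.
df = pd.DataFrame([['long', 'thick', 'man'],
                   ['short', 'thin', 'woman']])
print(split_info(df, 0))  # 1.0
```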

    I will not give a full implementation here; the point of writing it yourself is to deepen your understanding.

  • Original post: https://www.cnblogs.com/xu-xiaofeng/p/10473087.html