基本思路:
通过香农熵来决定每一层使用哪一种标签做分类,分类后,通过多数表决法来决定该层两个节点的类别。每次消耗一个标签,所以一共需要递归“标签个数”层。
# -*- coding:utf-8 -*-
# ID3 decision tree: at each level, split on the feature with the highest
# information gain (drop in Shannon entropy); each split consumes one
# feature, and leaves are decided by majority vote.
import math
from collections import Counter


def shannon_ent(dat):
    """Return the Shannon entropy of the class labels (last column) of *dat*.

    dat: list of rows; each row's last element is its class label.
    An empty dataset has entropy 0.0 (guards the recursive splits).
    """
    total = len(dat)
    if total == 0:
        return 0.0
    counts = Counter(row[-1] for row in dat)
    return -sum((c / total) * math.log2(c / total) for c in counts.values())


def split_dataset(dat, axis, val):
    """Return the rows of *dat* whose column *axis* equals *val*,
    with that column removed (the feature is consumed by the split)."""
    return [row[:axis] + row[axis + 1:] for row in dat if row[axis] == val]


def choose_best_feature(dat):
    """Return the index of the feature with the largest information gain,
    or -1 when no split improves on the base entropy."""
    feature_num = len(dat[0]) - 1       # last column is the class label
    base_ent = shannon_ent(dat)
    best_info_gain = 0.0
    best_feature = -1
    total = float(len(dat))
    for i in range(feature_num):
        # Weighted entropy of the partitions induced by feature i.
        cur_ent = 0.0
        for val in {row[i] for row in dat}:
            subset = split_dataset(dat, i, val)   # split once, reuse
            cur_ent += len(subset) / total * shannon_ent(subset)
        info_gain = base_ent - cur_ent
        if info_gain > best_info_gain:
            best_info_gain, best_feature = info_gain, i
    return best_feature


def majority_count(class_list):
    """Return the most frequent label in *class_list* (majority vote).

    Tie-breaking among equally frequent labels is unspecified.
    """
    return Counter(class_list).most_common(1)[0][0]


def create_tree(dat, label):
    """Recursively build an ID3 decision tree as nested dicts.

    dat:   list of rows; the last column of each row is its class label.
    label: feature names aligned with the feature columns. The caller's
           list is NOT modified (a copy is taken before features are
           consumed).
    Returns a class label (leaf) or {feature_name: {feature_value: subtree}}.
    """
    class_list = [row[-1] for row in dat]
    # Leaf: every row shares one class.
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]
    # Leaf: no features left to split on -- fall back to majority vote.
    if len(dat[0]) == 1:
        return majority_count(class_list)
    label = label[:]                    # don't mutate the caller's list
    best_feature = choose_best_feature(dat)
    best_label = label.pop(best_feature)
    d_tree = {best_label: {}}
    for val in {row[best_feature] for row in dat}:
        d_tree[best_label][val] = create_tree(
            split_dataset(dat, best_feature, val), label[:])
    return d_tree


d = [[1, 1, 'y'], [1, 1, 'y'], [1, 0, 'n'], [0, 1, 'n'], [0, 1, 'n']]
l = ['no surfacing', 'flippers']

if __name__ == "__main__":
    print(create_tree(d, l))