• Python实现决策树


    Python实现无剪枝的决策树

    import math
    import numpy as np
    import pydot
    
    """
    自己实现决策树:无剪枝
    """
    
    
    class Node:
        """A single node of the decision tree.

        A node is a leaf exactly when ``ans`` is not ``None``; internal nodes
        keep ``ans`` as ``None`` and route samples through ``sons``.
        """

        def __init__(self, data):
            """Create a node holding ``data``, a list of sample indices."""
            self.attr = -1  # index of the attribute this node splits on
            self.sons = {}  # child nodes, keyed by the value of the split attribute
            self.ans = None  # predicted label; only leaves have one
            self.data = data  # indices of the samples that reached this node

        def is_leaf(self):
            """Return True when this node carries an answer."""
            return self.ans is not None

        def __str__(self):
            """Return ``node_cnt:<n> `` followed by ``ans:<label>`` for leaves."""
            # Compare against None rather than relying on truthiness: class
            # label 0 is a valid answer and must still be shown.
            answer = "" if self.ans is None else "ans:%d" % self.ans
            return "node_cnt:{} {}".format(len(self.data), answer)
    
    
    class DicisionTree:
        """Unpruned decision tree for integer-coded categorical features.

        The tree is built eagerly in the constructor. Each internal node splits
        on one attribute and has one child per attribute value seen in the
        training samples that reached it. Supported split criteria are
        ``"id3"``, ``"c45"`` and ``"gini"``.
        """

        def __init__(self, x, y, creteria="id3"):
            """Fit a tree on samples ``x`` (2-D, integer) and labels ``y`` (1-D, integer).

            Raises:
                ValueError: if ``x`` or ``y`` is not of an integer dtype.
            """
            self.x = np.array(x)
            self.y = np.array(y)
            # ``np.int`` was removed in NumPy 1.24; issubdtype works on every
            # NumPy version and accepts all integer widths.
            if not (np.issubdtype(self.x.dtype, np.integer)
                    and np.issubdtype(self.y.dtype, np.integer)):
                raise ValueError("这里的决策树只能处理整数类型")
            self.num_class = len(np.unique(self.y))
            self.num_attr = len(x[0])
            self.creteria = creteria
            self.root = self._build(list(range(len(self.y))), list(range(self.num_attr)))

        def _split(self, x, attr):
            """Group the sample indices in ``x`` by their value of attribute ``attr``."""
            groups = {}
            for i in x:
                groups.setdefault(self.x[i][attr], []).append(i)
            return groups

        def _buildTable(self, x, attrs):
            """Count samples per (attribute, attribute value, class).

            Returns a list indexed by attribute; entry ``a`` maps
            ``value -> {class: count}``. Only attributes listed in ``attrs``
            are filled in, the other entries stay empty.
            """
            table = [{} for _ in range(self.num_attr)]
            for i in x:
                label = self.y[i]
                for attr in attrs:
                    by_class = table[attr].setdefault(self.x[i][attr], {})
                    by_class[label] = by_class.get(label, 0) + 1
            return table

        def _id3(self, table):
            """Return ``-N * H(Y|A)`` (natural log); larger means a better split."""
            aloga = 0
            rlogr = 0
            for by_class in table.values():
                r = 0
                for count in by_class.values():
                    aloga += count * math.log(count)
                    r += count
                rlogr += r * math.log(r)
            return aloga - rlogr

        def _c45(self, table, tlogt, slogs):
            """Return the C4.5 gain ratio shifted by -1; larger means a better split.

            ``tlogt`` is the sum of ``n_c * log(n_c)`` over the class counts of
            the node and ``slogs`` is ``N * log(N)`` for the node size ``N``.
            """
            if len(table) <= 1:
                # A constant attribute has zero split information, so its gain
                # ratio is undefined; rank it below every real split instead of
                # giving it the best possible score.
                return float("-inf")
            aloga = 0
            rlogr = 0
            for by_class in table.values():
                r = 0
                for count in by_class.values():
                    aloga += count * math.log(count)
                    r += count
                rlogr += r * math.log(r)
            return (tlogt - aloga) / (rlogr - slogs)

        def _gini(self, table):
            """Return the sum over values of ``sum_c(n_vc ** 2) / n_v``.

            Maximising this quantity minimises the weighted Gini impurity.
            """
            gain = 0
            for by_class in table.values():
                a2 = 0
                r = 0
                for count in by_class.values():
                    a2 += count ** 2
                    r += count
                gain += a2 / r
            return gain

        # Only C4.5 needs slogs and tlogt; id3 and gini do not use them.
        def _c45_tlogt(self, data):
            """Return ``(tlogt, slogs)`` for the samples whose indices are in ``data``."""
            slogs = len(data) * math.log(len(data))
            cnt = {}
            for i in data:
                label = self.y[i]
                cnt[label] = cnt.get(label, 0) + 1
            tlogt = sum(n * math.log(n) for n in cnt.values())
            return tlogt, slogs

        def _selectAttr(self, x, attrs):
            """Pick the attribute in ``attrs`` with the highest score on samples ``x``.

            Ties go to the attribute that appears first in ``attrs``.

            Raises:
                ValueError: if ``self.creteria`` is not a supported criterion.
            """
            if self.creteria == "id3":
                score = self._id3
            elif self.creteria == "gini":
                score = self._gini
            elif self.creteria == "c45":
                # The node-level totals do not depend on the attribute, so
                # compute them once instead of once per candidate.
                tlogt, slogs = self._c45_tlogt(x)

                def score(attr_table):
                    return self._c45(attr_table, tlogt=tlogt, slogs=slogs)
            else:
                raise ValueError(
                    "unkown creterial{},the 3 suported creteria are id3,c45,gini".format(self.creteria))
            table = self._buildTable(x, attrs)
            # No numeric sentinel: scores are unbounded below (id3 scales with
            # the sample count), so the first candidate is always accepted.
            ans_attr, ans_gain = None, None
            for attr in attrs:
                gain = score(table[attr])
                if ans_gain is None or gain > ans_gain:
                    ans_gain = gain
                    ans_attr = attr
            return ans_attr

        def _allsame(self, array):
            """Return True when every element of the non-empty ``array`` is equal."""
            first = array[0]
            return all(item == first for item in array)

        def _majority_label(self, data):
            """Return the most frequent label among samples ``data`` (smallest label on ties)."""
            labels, counts = np.unique(self.y[data], return_counts=True)
            return labels[np.argmax(counts)]

        def _build(self, data, attrs):
            """Recursively build the subtree for sample indices ``data`` using ``attrs``."""
            node = Node(data)
            if self._allsame(self.y[data]):
                node.ans = self.y[data[0]]
                return node
            if not attrs:
                # Labels still disagree but nothing is left to split on:
                # answer with the majority label rather than the first sample's.
                node.ans = self._majority_label(data)
                return node
            node.attr = self._selectAttr(data, attrs)
            # Give the children a fresh list so the caller's list is neither
            # mutated nor reordered.
            remaining = [a for a in attrs if a != node.attr]
            for value, indices in self._split(data, node.attr).items():
                node.sons[value] = self._build(indices, remaining)
            return node

        def predict(self, data_x):
            """Predict a label for every sample in ``data_x``.

            Returns a NumPy array; an entry is ``None`` when the sample has an
            attribute value that the tree never saw on its path.
            """
            def _predict_one(sample):
                node = self.root
                while not node.is_leaf():
                    child = node.sons.get(sample[node.attr])
                    if child is None:
                        return None  # no branch for this value: no answer
                    node = child
                return node.ans

            return np.array([_predict_one(sample) for sample in data_x])

        def get_node_count(self):
            """Return the total number of nodes in the tree."""
            def dfs(node):
                return 1 + sum(dfs(child) for child in node.sons.values())

            return dfs(self.root)

        def export_graphviz(self, path="haha.jpg"):
            """Render the tree with Graphviz ``dot`` and write it as a JPEG to ``path``."""
            g = pydot.Dot(graph_type="digraph")
            next_id = 0

            def dfs(node, parent, label):
                nonlocal next_id
                me = pydot.Node(str(next_id), label=str(node))
                next_id += 1
                g.add_node(me)
                if parent is not None:
                    g.add_edge(pydot.Edge(parent, me, label=label))
                for k, v in node.sons.items():
                    dfs(v, me, "attr{}={}".format(node.attr, k))

            dfs(self.root, None, "")
            g.write(path, prog='dot', format="jpg")
    
    
    if __name__ == '__main__':
        # Demo: fit on a tiny data set and report accuracy on the training data.
        # Split criterion to use: id3, gini or c45.
        gain_f = "id3"
        x = np.array([[0, 3, 0], [0, 2, 1], [1, 1, 2], [1, 2, 2], [2, 3, 0], [2, 1, 1]])
        y = np.array([0, 0, 1, 2, 0, 1])
        tree = DicisionTree(x, y, gain_f)
        ans = tree.predict(x)
        print(ans)
        cnt = np.count_nonzero(y == ans)
        print('正确的个数,正确率', cnt, cnt / len(x))
        # predict() marks unanswerable samples with None, not the string
        # 'not found', so count the None entries.
        print('不确定的个数', sum(1 for a in ans if a is None))
        print('结点总数', tree.get_node_count())
    
    
  • 相关阅读:
    11月2号
    2020.9.22 (Java测试)
    2020.8.30 + 每周周报(8)
    2020.8.31
    2020.8.29
    2020.9.23
    伪分布式hbase从0.94.11版本升级stable的1.4.9版本
    python 将列表(也可以是file.readlines())输出多个文件
    hbase0.94.11版本和hbase1.4.9版本的benchmark区别
    idea 配置
  • 原文地址:https://www.cnblogs.com/weiyinfu/p/8488124.html
Copyright © 2020-2023  润新知