• 机器学习:决策树算法(简单尝试)


    这里只写一下用C++简单实现的ID3算法决策树

    ID3算法是基于信息熵和信息获取量

    每次建立新节点时,选取一个信息获取量最大(以信息熵为衡量)的属性进行分割

    决策树还有很多其他算法,不过都只是衡量标准不同

    实质都是按照贪心自上而下地建树

    如果深度过深,还要采取剪枝的手段

    #include <iostream>
    #include <cstdio>
    #include <cstring>
    #include <vector>
    #include <cmath>
    using namespace std;
    typedef unsigned int ui;
    typedef vector< vector<int>> dv;
    const int maxm = 100, maxn = 1000;
    const double eps = 1e-7;
    struct Node
    {
        bool flag[maxm];
        int st, yes, no;
    }node[maxn];   //结点,flag表示已采用的属性,st为此次划分的标准 
    
    double cal_entropy(double p) //计算信息熵
    {
        if(abs(p) <= eps || abs(p-1) <= eps) return 0;
        return -(p*log(p)/log(2)+(1-p)*log(1-p)/log(2));
    }
    
    double split(dv v, int k) //算出如果以第k个属性分割得到的信息获取量
    {
        int v1, v2, n1, n2;
        v1 = v2 = n1 = n2 = 0;
        for(ui i = 0; i < v.size(); i++)
        {
            if(v[i][k])
            {
                n1++;
                if(v[i][v[i].size()-1]) v1++;
            }
            else
            {
                n2++;
                if(v[i][v[i].size()-1]) v2++;
            }
        }
        int n = n1+n2;
        double ans = (double)n1/n*cal_entropy((double)v1/n1) + (double)n2/n*cal_entropy((double)v2/n2);
        return cal_entropy((double)(v1+v2)/n) - ans;
    }
    
    
    
    void build(int x, dv vnode) //按照贪心算法建树
    {
        double ans = -1;
        int k = -1;
        for(ui i = 0; i < vnode.size(); i++)
            if(vnode[i][vnode[i].size()-1]) node[x].yes++;
        node[x].no = vnode.size() - node[x].yes;
        for(ui i = 0; i < vnode[0].size()-1; i++)
            if(!node[x].flag[i] && (split(vnode, i) - ans > eps))
            {
                ans = split(vnode, i);
                k = i;
            }
        node[x].st = k;
        printf("%d %d %d %d
    ", x, node[x].yes, node[x].no, node[x].st);  //先序遍历输出树的结构
        if(k == -1) return;
        dv v1, v2;
        for(ui i = 0; i < vnode.size(); i++)
            if(vnode[i][k]) v1.push_back(vnode[i]);
            else v2.push_back(vnode[i]);
        for(ui i = 0; i < v1[0].size(); i++)
        {
            node[x*2].flag[i] = node[x].flag[i];
            node[x*2+1].flag[i] = node[x].flag[i];
        }
        node[x*2].flag[k] = node[x*2+1].flag[k] = 1;
        build(x*2, v1); build(x*2+1, v2);
    }
    
    int n, m, x;
    dv v;
    
    int dfs(int x, vector<int> vv)  //用于测试集
    {
        if(node[x].st == -1) return node[x].yes > node[x].no;
        if(vv[node[x].st]) return dfs(2*x, vv);
        else return dfs(2*x+1, vv);
    }
    
    vector <int> vv;
    int main()
    {
        freopen("a.txt", "r", stdin);
        cin>>n>>m;
        v.resize(n);
        for(int i = 0; i < n; i++)
            for(int j = 0; j < m; j++)
            {
                cin>>x;
                v[i].push_back(x);
            }
        build(1, v);
        for(int i = 0; i < m; i++) cin>>x, vv.push_back(x);
        cout<<dfs(1, vv)<<endl;
    }
  • 相关阅读:
    form组件进阶_django
    form组件_django
    django的数据库ORM进阶操作
    内网安装python模块_python
    Redhat7.4安装oracle11.2.0.4版本数据库遇见的问题_oracle
    Redis基础数据类型与对象
    SpringIOC容器——ApplicationContext和BeanFactory
    AQS源码解析
    Java内存模型(一)
    面试准备笔记
  • 原文地址:https://www.cnblogs.com/Saurus/p/6308601.html
Copyright © 2020-2023  润新知