• ID3算法 决策树 C++实现


    人工智能课的实验。

    数据结构:多叉树

    这个实验我写了好久,开始的时候从数据的读入和表示入手,写到递归建树的部分时遇到了瓶颈,更新样例集和属性集的办法过于繁琐;

    于是参考网上的代码后重新写,建立决策树类,把属性集、样例集作为数据成员加入类中,并设立访问数组,这样每次更新属性集、样例集时只是标记访问数组的对应元素即可,不必实际拷贝。

    主函数:

     1 #include "Decision_tree.h"
     2 using namespace std;
     3 int main()
     4 {
     5     int num_attr,num_example;
     6     char filename[30];
     7     cout << "请输入训练集文件名:" << endl;
     8     cin >> filename;
     9     freopen(filename, "r", stdin);//从样例文件读入训练内容
    10     cin >> num_attr >> num_example;//读入属性个数、例子个数
    11     Decision_tree my_tree=Decision_tree(num_attr,num_example);
    12     fclose(stdin);
    13     freopen("CON", "r", stdin);//重定向标准输入到控制台
    14     my_tree.display_attr();
    15     cout << "决策树已建成,按深度优先遍历结果如下:" << endl;
    16     my_tree.traverse();
    17     do{
    18         cout << "请输入测试数据,格式:属性1值 属性2值..." << endl;
    19         Example test;
    20         for (int i = 0; i < num_attr; i++)
    21             cin >> test.values[i];
    22         int result = my_tree.judge(test);
    23         if (result == 1) cout << "分类结果为P" << endl;
    24         else if (result == -1) cout << "分类结果为N" << endl;
    25         else if (result == -2) cout << "无法根据已有样例集判断" << endl;
    26         cout << "继续吗?(y/n)";
    27         fflush(stdin);
    28     } while (getchar() == 'y');
    29 }

    属性结构体

    struct Attribute//属性
    {
        string name;
        int count;//属性值个数
        int number;//属性的秩
        string values[MAX_VAL];
    };

    样例结构体

    struct Example//样例
    {
        string values[MAX];
        int pn;
        Example(){ pn = 0; }//默认为未分类的
    };

    决策树的结点

    typedef struct Node//树的结点
    {
        Attribute attr;
        Node* children[MAX_VAL];
        int classification[MAX_VAL];
        Node(){}
    }Node;

    决策树类的实现

      1 class Decision_tree//决策树
      2 {
      3     Node *root;
      4     Example e[MAX];//样例全集
      5     Attribute a[MAX_ATTR];//属性全集
      6     int num_attr, num_example;
      7     int visited_exams[MAX];//样例集的访问情况
      8     int visited_attrs[MAX_ATTR];//属性集的访问情况
      9     Node* recursive_build_tree(int left_e[], int left_a[])//递归建树
     10     {
     11         double max = 0;
     12         int max_attr=-1;
     13         for (int i = 0;i<num_attr;i++)
     14         {//求信息增益最大的属性
     15             if (left_a[i]) continue;
     16             double temp = Gain(left_e, i);
     17             if (max<temp)
     18             {
     19                 max = temp;
     20                 max_attr = i;
     21             }
     22         }
     23         if (max_attr == -1) return NULL;//已没有可判的属性,返回空指针
     24         //cout << a[max_attr].name << endl;
     25         //以这个属性为结点,以各属性值为分支递归建树
     26         int p = 0, n = 0;
     27         Node *new_node=new Node();
     28         new_node->attr = a[max_attr];
     29         for (int i = 0; i<a[max_attr].count;i++)
     30         {//遍历这个属性的所有属性值
     31             for (int j = 0; j < num_example;j++)
     32             {//得到第i个属性值的正反例总数
     33                 if (left_e[j]) continue;
     34                 if (!e[j].values[max_attr].compare(a[max_attr].values[i]))
     35                 {//例子和属性都是循秩访问的,所以向量元素的顺序不能变
     36                     if (e[j].pn) p++;
     37                     else n++;
     38                 }
     39             }
     40             //cout << a[max_attr].values[i] << " ";
     41             //cout << p << " " << n << endl;
     42             if (p && !n)//全是正例,不再分
     43             {
     44                 //cout << "P" << endl;
     45                 new_node->classification[i] = 1;
     46                 new_node->children[i] = NULL;
     47             }
     48             else if (n && !p)//全是反例,不再分
     49             {
     50                 //cout << "N" << endl;
     51                 new_node->classification[i] = -1;
     52                 new_node->children[i] = NULL;
     53             }
     54             else if (!p && !n)//例子集已空
     55             {
     56                 //cout << "none" << endl;
     57                 new_node->classification[i] = -2;//表示未训练到这种分类,无法判断
     58                 new_node->children[i] = NULL;
     59             }
     60             else//例子集不空,且尚未能区分正反,更新访问情况,递归
     61             {
     62                 new_node->classification[i] = 0;
     63                 left_a[max_attr] = 1;//更新属性访问情况
     64                 int left_e_next[MAX];//下一轮的例子集(为便于回溯,不修改原例子集)
     65                 for (int k = 0; k < num_example; k++)
     66                     left_e_next[k] = left_e[k];
     67                 for (int j = 0; j < num_example; j++)
     68                 {
     69                     if (left_e[j]) continue;
     70                     if (!e[j].values[max_attr].compare(a[max_attr].values[i]))
     71                         left_e_next[j] = 0;//属性值匹配的例子,入选下一轮例子集
     72                     else left_e_next[j] = 1;//属性值不匹配,筛除
     73                 }
     74                 new_node->children[i] = recursive_build_tree(left_e_next, left_a);//递归
     75                 left_a[max_attr] = 0;//恢复属性访问情况
     76             }
     77             p = 0;
     78             n = 0;
     79         }
     80         return new_node;
     81     }
     82     double I(int p, int n)
     83     {
     84         double a = p / (p + (double)n);
     85         double b = n / (p + (double)n);
     86         if (a == 0 || b == 0) return 0;
     87         return -a*log(a) / log(2) - b*log(b) / log(2);
     88     }
     89     double Gain(int left_e[], int cur_attr)//计算信息增益
     90     {
     91         int sum_p=0, sum_n=0;
     92         int p[10] = { 0 }, n[10] = { 0 };
     93         for (int i = 0; i < num_example; i++)
     94         {//求样例集的p,n
     95             if (left_e[i]) continue;
     96             if (e[i].pn) sum_p++;
     97             else sum_n++;
     98         }
     99         if (!sum_p && !sum_n)
    100         {
    101             //cout << "no more examples!" << endl;
    102             return -1;//样例集是空集
    103         }
    104             
    105         double sum_Ipn = I(sum_p, sum_n);
    106         for (int i = 0; i < a[cur_attr].count; i++)
    107         {//求第i个属性值的p,n
    108             for (int j = 0; j < num_example; j++)
    109             {
    110                 if (left_e[j]) continue;
    111                 if (!e[j].values[cur_attr].compare(a[cur_attr].values[i]))
    112                     if (e[j].pn) p[i]++;
    113                     else n[i]++;
    114             }
    115         }
    116         double E = 0;
    117         for (int i = 0; i < a[cur_attr].count; i++)//计算属性的期望
    118             E += (p[i] + n[i])*I(p[i], n[i]);
    119         E /= (sum_p + sum_n);
    120         //cout << a[cur_attr].name <<sum_Ipn - E << endl;
    121         return sum_Ipn - E;
    122     }
    123     void recursive_traverse(Node *current)//DFS递归遍历
    124     {
    125         if (current == NULL) return;
    126         cout << current->attr.name << endl;
    127         for (int i = 0; i < current->attr.count; i++)
    128         {
    129             cout << current->attr.values[i] << " " << current->classification[i] << endl;
    130             recursive_traverse(current->children[i]);
    131         }
    132     }
    133     int recursive_judge(Example exa, Node *current)
    134     {
    135         for (int i = 0; i < current->attr.count; i++)
    136         {
    137             if (!exa.values[current->attr.number].compare(current->attr.values[i]))
    138             {
    139                 if (current->children[i]==NULL) return current->classification[i];
    140                 else return recursive_judge(exa, current->children[i]);        
    141             }        
    142         }
    143         return 0;
    144     }
    145 public:
    146     Decision_tree(int num1,int num2)
    147     {
    148         
    149         //通过读文件初始化
    150         num_attr = num1;
    151         num_example = num2;
    152 
    153         for (int i = 0; i<num_attr; i++)
    154         {
    155             a[i].number = i;//属性的秩
    156             cin>>a[i].name;//读入属性名
    157             cin>>a[i].count;//读入此属性的属性值个数
    158             for (int j = 0; j<a[i].count; j++)
    159             {
    160                 cin>>a[i].values[j];//读入各属性值
    161             }
    162         }
    163         
    164         for (int i = 0; i<num_example; i++)
    165         {
    166             string temp;
    167             for (int j = 0; j < num_attr; j++)
    168             {
    169                 cin>>e[i].values[j];
    170             }
    171             cin >> temp;
    172             if (!temp.compare("P")) e[i].pn = 1;
    173             else e[i].pn = 0;
    174         }
    175         //检查
    176         /*for (int i = 0; i<num_attr; i++)
    177         {
    178             cout << a[i].name << endl;//读入属性名
    179             for (int j = 0; j<a[i].count; j++)
    180             {
    181                 cout<<a[i].values[j]<<" ";//读入各属性值
    182             }
    183             cout << endl;
    184         }
    185         for (int i = 0; i<num_example; i++)
    186         {
    187             for (int j = 0; j < num_attr; j++)
    188                 cout<<e[i].values[j]<<" ";
    189             cout<<e[i].pn<<endl;
    190             
    191         }
    192         */
    193         memset(visited_exams, 0, sizeof(visited_exams));
    194         memset(visited_attrs, 0, sizeof(visited_attrs));
    195         root = recursive_build_tree(visited_exams,visited_attrs);
    196     }
    197     void traverse()
    198     {
    199         recursive_traverse(root);
    200     }
    201     int judge(Example exa)//判断
    202     {
    203         int result=recursive_judge(exa,root);
    204         return result;
    205     }
    206     void display_attr()//显示属性
    207     {
    208         cout << "There are " << num_attr << " attributes, they are" << endl;
    209         for (int i = 0; i < num_attr; i++)
    210         {
    211             cout << "[" << a[i].name << "]" << endl;
    212             for (int j = 0; j < a[i].count; j++)
    213                 cout << a[i].values[j] << " ";
    214             cout << endl;
    215         }
    216     }
    217 };
    Decision_tree

    现在这个版本的代码用了10小时完成,去检查时被研究生贬得一文不值。。。也的确,现在我们写的实验题目面向的都是规模非常小的问题,自然体会不到自己的代码在大数据面前的劣势。不过我现在确实学得太少了,很多数据结构都没有动手实现过,算法也是。对C++也只能算入了门。俗话说“磨刀不误砍柴工”,“工欲善其事,必先利其器”,先把基础知识学好,多做基本练习,学到的数据结构和算法都动手实现一遍,这样遇到实际问题也好对应到合适的数据结构和算法。

    另外,参照一本书学习好的代码风格和习惯也是很重要的,因为写代码的习惯是思维习惯的反映,而我现在还处于初学者阶段,按照一种典型的流派模仿,构建起自己的思维模式后再谈其他的。

    忽然觉得自己学了快两年编程还这么水实在是不能忍,都怪大一时年少不懂事没好好学基础。。。

    不过,“悟已往之不谏,知来者之可追”,有了方向,一步步走下去就好,不求优于别人,但一定要“优于过去的自己”。

  • 相关阅读:
    PAT (Basic Level) Practise 1013 数素数
    PAT (Basic Level) Practise 1014 福尔摩斯的约会
    codeforces 814B.An express train to reveries 解题报告
    KMP算法
    rsync工具
    codeforces 777C.Alyona and Spreadsheet 解题报告
    codeforces 798C.Mike and gcd problem 解题报告
    nginx + tomcat多实例
    MongoDB副本集
    指针的艺术(转载)
  • 原文地址:https://www.cnblogs.com/helenawang/p/4582081.html
Copyright © 2020-2023  润新知