• java编写ID3决策树


    说明:每个样本都会装入Data样本对象,决策树生成算法接收的是一个Array<Data>样本列表,所以构建测试数据时也要符合格式,最后生成的决策树是树的根节点,通过里面提供的showTree()方法可查看整个树结构,下面奉上源码。

    Data.java

    package ai.tree.data;
    
    import java.util.HashMap;
    
    /**
     * 样本类
     * @author ChenLuyang
     * @date 2019/2/21
     */
    public class Data implements Cloneable{
        /**
         * K是特征描述,V是特征值
         */
        private HashMap<String,String> feature = new HashMap<String, String>();
    
        /**
         * 该样本结论
         */
        private String result;
    
        public Data(HashMap<String,String> feature,String result){
            this.feature = feature;
            this.result = result;
        }
    
        public HashMap<String, String> getFeature() {
            return feature;
        }
    
        public String getResult() {
            return result;
        }
    
        private void setFeature(HashMap<String, String> feature) {
            this.feature = feature;
        }
    
        @Override
        public Data clone()
        {
            Data object=null;
            try {
                object = (Data) super.clone();
                object.setFeature((HashMap<String, String>) this.feature.clone());
            } catch (CloneNotSupportedException e) {
                e.printStackTrace();
            }
    
            return object;
        }
    }
    

      

    DecisionTree.java

    package ai.tree.algorithm;
    
    import ai.tree.data.Data;
    
    import java.math.BigDecimal;
    import java.util.*;
    
    /**
     * @author ChenLuyang
     * @date 2019/2/21
     */
    public class DecisionTree {
        /**
         * 递归构建决策树
         *
         * @param dataList 样本集合
         * @return ai.tree.algorithm.DecisionTree.TreeNode 使用传入样本构建的决策节点
         * @author ChenLuyang
         * @date 2019/2/21 16:05
         */
        public TreeNode createTree(List<Data> dataList) {
            //创建当前节点
            TreeNode<String, String, String> nowTreeNode = new TreeNode<String, String, String>();
            //当前节点的各个分支节点
            Map<String, TreeNode> featureDecisionMap = new HashMap<String, TreeNode>();
    
            //统计当前样本集中所有的分类结果
            Set<String> resultSet = new HashSet<String>();
            for (Data data :
                    dataList) {
                resultSet.add(data.getResult());
            }
    
            //如果当前样本集只有一种类别,则表示不用分类了,返回当前节点
            if (resultSet.size() == 1) {
                String resultClassify = resultSet.iterator().next();
    
                nowTreeNode.setResultNode(resultClassify);
    
                return nowTreeNode;
            }
    
            //如果数据集中特征为空,则选择整个集合中出现次数最多的分类,作为分类结果
            if (dataList.get(0).getFeature().size() == 0) {
                Map<String, Integer> countMap = new HashMap<String, Integer>();
                for (Data data :
                        dataList) {
                    Integer num = countMap.get(data.getResult());
                    if (num == null) {
                        countMap.put(data.getResult(), 1);
                    } else {
                        countMap.put(data.getResult(), num + 1);
                    }
                }
    
                String tmpResult = "";
                Integer tmpNum = 0;
                for (String res :
                        countMap.keySet()) {
                    if (countMap.get(res) > tmpNum) {
                        tmpNum = countMap.get(res);
                        tmpResult = res;
                    }
                }
    
                nowTreeNode.setResultNode(tmpResult);
    
                return nowTreeNode;
            }
    
            //寻找当前最优分类
            String bestLabel = chooseBestFeatureToSplit(dataList);
    
            //提取最优特征的所有可能值
            Set<String> bestLabelInfoSet = new HashSet<String>();
            for (Data data :
                    dataList) {
                bestLabelInfoSet.add(data.getFeature().get(bestLabel));
            }
    
            //使用最优特征的各个特征值进行分类
            for (String labelInfo :
                    bestLabelInfoSet) {
                for (Data data :
                        dataList) {
                }
                List<Data> branchDataList = splitDataList(dataList, bestLabel, labelInfo);
    
                //最优特征下该特征值的节点
                TreeNode branchTreeNode = createTree(branchDataList);
                featureDecisionMap.put(labelInfo, branchTreeNode);
            }
    
            nowTreeNode.setDecisionNode(bestLabel, featureDecisionMap);
    
            return nowTreeNode;
        }
    
        /**
         * 计算传入数据集中的最优分类特征
         *
         * @param dataList
         * @return int 最优分类特征的描述
         * @author ChenLuyang
         * @date 2019/2/21 14:12
         */
        public String chooseBestFeatureToSplit(List<Data> dataList) {
            //目前数据集中的特征集合
            Set<String> futureSet = dataList.get(0).getFeature().keySet();
    
            //未分类时的熵
            BigDecimal baseEntropy = calcShannonEnt(dataList);
    
            //熵差
            BigDecimal bestInfoGain = new BigDecimal("0");
            //最优特征
            String bestFeature = "";
    
            //按照各特征分类
            for (String future :
                    futureSet) {
                //该特征分类后的熵
                BigDecimal futureEntropy = new BigDecimal("0");
    
                //该特征的所有特征值去重集合
                Set<String> futureInfoSet = new HashSet<String>();
                for (Data data :
                        dataList) {
                    futureInfoSet.add(data.getFeature().get(future));
                }
    
                //按照该特征的特征值一一分类
                for (String futureInfo :
                        futureInfoSet) {
                    List<Data> splitResultDataList = splitDataList(dataList, future, futureInfo);
    
                    //分类后样本数占总样本数的比例
                    BigDecimal tmpProb = new BigDecimal(splitResultDataList.size() + "").divide(new BigDecimal(dataList.size() + ""), 5, BigDecimal.ROUND_HALF_DOWN);
    
                    //所占比例乘以分类后的样本熵,然后再进行熵的累加
                    futureEntropy = futureEntropy.add(tmpProb.multiply(calcShannonEnt(splitResultDataList)));
                }
    
                BigDecimal subEntropy = baseEntropy.subtract(futureEntropy);
    
                if (subEntropy.compareTo(bestInfoGain) >= 0) {
                    bestInfoGain = subEntropy;
                    bestFeature = future;
                }
            }
    
            return bestFeature;
        }
    
        /**
         * 计算传入样本集的熵值
         *
         * @param dataList 样本集
         * @return java.math.BigDecimal 熵
         * @author ChenLuyang
         * @date 2019/2/22 9:41
         */
        public BigDecimal calcShannonEnt(List<Data> dataList) {
            //样本总数
            BigDecimal sumEntries = new BigDecimal(dataList.size() + "");
            //香农熵
            BigDecimal shannonEnt = new BigDecimal("0");
            //统计各个分类结果的样本数量
            Map<String, Integer> resultCountMap = new HashMap<String, Integer>();
            for (Data data :
                    dataList) {
                Integer dataResultCount = resultCountMap.get(data.getResult());
                if (dataResultCount == null) {
                    resultCountMap.put(data.getResult(), 1);
                } else {
                    resultCountMap.put(data.getResult(), dataResultCount + 1);
                }
            }
    
            for (String resultCountKey :
                    resultCountMap.keySet()) {
                BigDecimal resultCountValue = new BigDecimal(resultCountMap.get(resultCountKey).toString());
    
                BigDecimal prob = resultCountValue.divide(sumEntries, 5, BigDecimal.ROUND_HALF_DOWN);
                shannonEnt = shannonEnt.subtract(prob.multiply(new BigDecimal(Math.log(prob.doubleValue()) / Math.log(2) + "")));
            }
    
            return shannonEnt;
        }
    
        /**
         * 根据某个特征的特征值,进行样本数据的划分,将划分后的样本数据集返回
         *
         * @param dataList 待划分的样本数据集
         * @param future   筛选的特征依据
         * @param info     筛选的特征值依据
         * @return java.util.List<ai.tree.data.Data> 按照指定特征值分类后的数据集
         * @author ChenLuyang
         * @date 2019/2/21 18:26
         */
        public List<Data> splitDataList(List<Data> dataList, String future, String info) {
            List<Data> resultDataList = new ArrayList<Data>();
            for (Data data :
                    dataList) {
                if (data.getFeature().get(future).equals(info)) {
                    Data newData = (Data) data.clone();
                    newData.getFeature().remove(future);
                    resultDataList.add(newData);
                }
            }
    
            return resultDataList;
        }
    
        /**
         * L:每一个特征的描述信息的类型
         * F:特征的类型
         * R:最终分类结果的类型
         */
        public class TreeNode<L, F, R> {
            /**
             * 该节点的最优特征的描述信息
             */
            private L label;
    
            /**
             * 根据不同的特征作出响应的决定。
             * K为特征值,V为该特征值作出的决策节点
             */
            private Map<F, TreeNode> featureDecisionMap;
    
            /**
             * 是否为最终分类节点
             */
            private boolean isFinal;
    
            /**
             * 最终分类结果信息
             */
            private R resultClassify;
    
            /**
             * 设置叶子节点
             *
             * @param resultClassify 最终分类结果
             * @return void
             * @author ChenLuyang
             * @date 2019/2/22 18:31
             */
            public void setResultNode(R resultClassify) {
                this.isFinal = true;
                this.resultClassify = resultClassify;
            }
    
            /**
             * 设置分支节点
             *
             * @param label              当前分支节点的描述信息(特征)
             * @param featureDecisionMap 当前分支节点的各个特征值,与其对应的子节点
             * @return void
             * @author ChenLuyang
             * @date 2019/2/22 18:31
             */
            public void setDecisionNode(L label, Map<F, TreeNode> featureDecisionMap) {
                this.isFinal = false;
                this.label = label;
                this.featureDecisionMap = featureDecisionMap;
            }
    
            /**
             * 展示当前节点的树结构
             *
             * @return void
             * @author ChenLuyang
             * @date 2019/2/22 16:54
             */
            public String showTree() {
                HashMap<String, String> treeMap = new HashMap<String, String>();
                if (isFinal) {
                    String key = "result";
                    R value = resultClassify;
                    treeMap.put(key, value.toString());
                } else {
                    String key = label.toString();
                    HashMap<F, String> showFutureMap = new HashMap<F, String>();
                    for (F f :
                            featureDecisionMap.keySet()) {
                        showFutureMap.put(f, featureDecisionMap.get(f).showTree());
                    }
                    String value = showFutureMap.toString();
    
                    treeMap.put(key, value);
                }
    
                return treeMap.toString();
            }
    
            public L getLabel() {
                return label;
            }
    
            public Map<F, TreeNode> getFeatureDecisionMap() {
                return featureDecisionMap;
            }
    
            public R getResultClassify() {
                return resultClassify;
            }
    
            public boolean getFinal() {
                return isFinal;
            }
        }
    }
    

      

    Start.java

    package ai.tree.algorithm;
    
    import ai.tree.data.Data;
    
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    
    /**
     * @author ChenLuyang
     * @date 2019/2/22
     */
    public class Start {
        /**
         * 构建测试样本集,测试样本如下:
         样本特征:{头发长短=短发, 身材=胖, 是否戴眼镜=有眼镜} 分类:男
         样本特征:{头发长短=长发, 身材=瘦, 是否戴眼镜=有眼镜} 分类:女
         样本特征:{头发长短=短发, 身材=胖, 是否戴眼镜=有眼镜} 分类:女
         样本特征:{头发长短=长发, 身材=胖, 是否戴眼镜=没眼镜} 分类:男
         样本特征:{头发长短=短发, 身材=瘦, 是否戴眼镜=没眼镜} 分类:男
         样本特征:{头发长短=长发, 身材=瘦, 是否戴眼镜=有眼镜} 分类:女
         样本特征:{头发长短=长发, 身材=胖, 是否戴眼镜=有眼镜} 分类:男
         * @author ChenLuyang
         * @date 2019/2/21 15:34
         * @return java.util.List<ai.tree.data.DecisionTreeTestData.Data> 样本集
         */
        public static List<Data> createDataList(){
            /**
             * 样本特征描述
             * @author ChenLuyang
             * @date 2019/2/22 18:55
             * @return java.util.List<ai.tree.data.Data>
             */
            String[] labels = new String[]{"是否戴眼镜", "头发长短", "身材"};
    
            List<Data> dataList = new ArrayList<Data>();
    
            HashMap<String,String> feature1 = new HashMap<String, String>();
            feature1.put(labels[0],"有眼镜");
            feature1.put(labels[1].toString(),"短发");
            feature1.put(labels[2].toString(),"胖");
            dataList.add(new Data(feature1,"男"));
    
            HashMap<String,String> feature2 = new HashMap<String, String>();
            feature2.put(labels[0],"有眼镜");
            feature2.put(labels[1],"长发");
            feature2.put(labels[2],"瘦");
            dataList.add(new Data(feature2,"女"));
    
            HashMap<String,String> feature3 = new HashMap<String, String>();
            feature3.put(labels[0],"有眼镜");
            feature3.put(labels[1],"短发");
            feature3.put(labels[2],"胖");
            dataList.add(new Data(feature3,"女"));
    
            HashMap<String,String> feature4 = new HashMap<String, String>();
            feature4.put(labels[0],"没眼镜");
            feature4.put(labels[1],"长发");
            feature4.put(labels[2],"胖");
            dataList.add(new Data(feature4,"男"));
    
            HashMap<String,String> feature5 = new HashMap<String, String>();
            feature5.put(labels[0],"没眼镜");
            feature5.put(labels[1],"短发");
            feature5.put(labels[2],"瘦");
            dataList.add(new Data(feature5,"男"));
    
            HashMap<String,String> feature6 = new HashMap<String, String>();
            feature6.put(labels[0],"有眼镜");
            feature6.put(labels[1],"长发");
            feature6.put(labels[2],"瘦");
            dataList.add(new Data(feature6,"女"));
    
            HashMap<String,String> feature7 = new HashMap<String, String>();
            feature7.put(labels[0],"有眼镜");
            feature7.put(labels[1],"长发");
            feature7.put(labels[2],"胖");
            dataList.add(new Data(feature7,"男"));
    
            return dataList;
        }
    
        public static void main(String[] args) {
            DecisionTree decisionTree = new DecisionTree();
    
            //使用测试样本生成决策树
            DecisionTree.TreeNode tree = decisionTree.createTree(createDataList());
    
            //展示决策树
            System.out.println(tree.showTree());
        }
    }
    

      

    生成树结构:{是否戴眼镜={没眼镜={result=男}, 有眼镜={身材={胖={头发长短={长发={result=男}, 短发={result=女}}}, 瘦={result=女}}}}}

  • 相关阅读:
    MySQL binlog中 format_desc event格式解析
    位bit和字节Byte
    MySQL利用mysqlbinlog模拟增量恢复
    mysqldump参数 --master-data详解
    开启MySQL二进制日志
    设置花里胡哨的Xshell字体与背景颜色(超全)
    Python操作MySQL数据库
    给定一个由括号([{)]}其中之一或多个组成的字符串判断是否符合左右括号成对标准,不同括号可任意嵌套
    给定一个字符串str,将str中连续两个字符为a的字符替换为b(一个或连续超过多个字符a则不替换)
    不使用局部变量和for循环或其它循环打印出如m=19,n=2結果为2 4 8 16 16 8 4 2形式的串
  • 原文地址:https://www.cnblogs.com/red-code/p/10420107.html
Copyright © 2020-2023  润新知