• ID3算法(Java实现)


    数据存储文件:buycomputer.properties

    #数据个数(程序只读取 D1 到 D<datanum> 这些行;当前为14,所以下面的 D15 不会被读取)
    datanum=14
    #属性及属性值
    nodeAndAttribute=年龄:青/中/老,收入:高/中/低,学生:是/否,信誉:良/优,归类:买/不买
    #数据
    D1=青,高,否,良,不买
    D2=青,高,否,优,不买
    D3=中,高,否,良,买
    D4=老,中,否,良,买
    D5=老,低,是,良,买
    D6=老,低,是,优,不买
    D7=中,低,是,优,买
    D8=青,中,否,良,不买
    D9=青,低,是,良,买
    D10=老,中,是,良,买
    D11=青,中,是,优,买
    D12=中,中,否,优,买
    D13=中,高,是,良,买
    D14=老,中,否,优,不买
    D15=老,中,否,优,买
    View Code

    实体类:TreeNode.java

    package com.id3.node;
    
    import java.util.HashMap;
    import java.util.Map;
    
    /**
     * One attribute (column) of the data set.
     *
     * <p>Holds the attribute name, a map from each possible value of the attribute to the
     * {@link Attributes} object describing that value, and a numeric score in {@code gain}.
     *
     * <p>NOTE(review): the ID3 code in this project stores the weighted conditional entropy
     * of the attribute in {@code gain} (smaller is better), not the information gain itself.
     */
    public class TreeNode {

        private String nodeName;
        private Map<String,Attributes> attributes;
        private double gain;

        /** Returns the attribute name. */
        public String getNodeName() {
            return nodeName;
        }

        /** Sets the attribute name. */
        public void setNodeName(String nodeName) {
            this.nodeName = nodeName;
        }

        /** Returns the map from attribute value to its {@link Attributes}. */
        public Map<String, Attributes> getAttributes() {
            return attributes;
        }

        /** Replaces the map from attribute value to its {@link Attributes}. */
        public void setAttributes(Map<String, Attributes> attributes) {
            this.attributes = attributes;
        }

        /** Returns the stored score. */
        public double getGain() {
            return gain;
        }

        /** Sets the stored score. */
        public void setGain(double gain) {
            this.gain = gain;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (obj == null || obj.getClass() != getClass()) {
                return false;
            }
            TreeNode that = (TreeNode) obj;
            // Same comparison order as before: attributes, gain (bitwise), name.
            return same(attributes, that.attributes)
                    && Double.doubleToLongBits(gain) == Double.doubleToLongBits(that.gain)
                    && same(nodeName, that.nodeName);
        }

        @Override
        public int hashCode() {
            long gainBits = Double.doubleToLongBits(gain);
            int hash = 31 + hashOf(attributes);
            hash = 31 * hash + (int) (gainBits ^ (gainBits >>> 32));
            return 31 * hash + hashOf(nodeName);
        }

        @Override
        public String toString() {
            return new StringBuilder("TreeNode [nodeName=").append(nodeName)
                    .append(", attributes=").append(attributes)
                    .append(", gain=").append(gain)
                    .append(']')
                    .toString();
        }

        /** Null-safe equality test. */
        private static boolean same(Object a, Object b) {
            return a == null ? b == null : a.equals(b);
        }

        /** Null-safe hash code: 0 for {@code null}. */
        private static int hashOf(Object o) {
            return o == null ? 0 : o.hashCode();
        }
    }
    
    
    /**
     * One possible value of an attribute.
     *
     * <p>It carries two kinds of data:
     * <ul>
     *   <li>statistics used while computing entropies: {@code attrNum} (rows having this
     *       value), {@code resultNum} (class value -> row count among those rows) and
     *       {@code h} (entropy of that class distribution);</li>
     *   <li>the decision-tree link for this value: {@code nextNode} (attribute split on next)
     *       or {@code leafName} (the final decision when there is no next node).</li>
     * </ul>
     */
    class Attributes{

        private String attrName;
        private TreeNode nextNode;
        private String leafName;
        private int attrNum;
        private double h;
        Map<String, Integer> resultNum = new HashMap<String, Integer>();

        /** Returns the value this object describes. */
        public String getAttrName() {
            return attrName;
        }

        /** Sets the value this object describes. */
        public void setAttrName(String attrName) {
            this.attrName = attrName;
        }

        /** Returns the next attribute node of the tree, or {@code null} for a leaf. */
        public TreeNode getNextNode() {
            return nextNode;
        }

        /** Sets the next attribute node of the tree. */
        public void setNextNode(TreeNode nextNode) {
            this.nextNode = nextNode;
        }

        /** Returns the decision stored for a leaf. */
        public String getLeafName() {
            return leafName;
        }

        /** Sets the decision stored for a leaf. */
        public void setLeafName(String leafName) {
            this.leafName = leafName;
        }

        /** Returns the row count for this value. */
        public int getAttrNum() {
            return attrNum;
        }

        /** Sets the row count for this value. */
        public void setAttrNum(int attrNum) {
            this.attrNum = attrNum;
        }

        /** Returns the entropy stored for this value. */
        public double getH() {
            return h;
        }

        /** Sets the entropy stored for this value. */
        public void setH(double h) {
            this.h = h;
        }

        /** Returns the live class-value -> count map. */
        public Map<String, Integer> getResultNum() {
            return resultNum;
        }

        /** Replaces the class-value -> count map. */
        public void setResultNum(Map<String, Integer> resultNum) {
            this.resultNum = resultNum;
        }

        @Override
        public int hashCode() {
            long hBits = Double.doubleToLongBits(h);
            // Folded in the same order as before so the hash value is unchanged.
            int[] parts = { hashOf(attrName), attrNum, (int) (hBits ^ (hBits >>> 32)),
                    hashOf(leafName), hashOf(nextNode), hashOf(resultNum) };
            int hash = 1;
            for (int part : parts) {
                hash = 31 * hash + part;
            }
            return hash;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (obj == null || obj.getClass() != getClass()) {
                return false;
            }
            Attributes that = (Attributes) obj;
            return same(attrName, that.attrName)
                    && attrNum == that.attrNum
                    && Double.doubleToLongBits(h) == Double.doubleToLongBits(that.h)
                    && same(leafName, that.leafName)
                    && same(nextNode, that.nextNode)
                    && same(resultNum, that.resultNum);
        }

        @Override
        public String toString() {
            return new StringBuilder("Attributes [attrName=").append(attrName)
                    .append(", nextNode=").append(nextNode)
                    .append(", leafName=").append(leafName)
                    .append(", attrNum=").append(attrNum)
                    .append(", h=").append(h)
                    .append(", resultNum=").append(resultNum)
                    .append(']')
                    .toString();
        }

        /** Null-safe equality test. */
        private static boolean same(Object a, Object b) {
            return a == null ? b == null : a.equals(b);
        }

        /** Null-safe hash code: 0 for {@code null}. */
        private static int hashOf(Object o) {
            return o == null ? 0 : o.hashCode();
        }
    }
    View Code

    ID3算法:ID3Alogo.java

    package com.id3.node;
    
    import java.io.BufferedReader;
    import java.io.Closeable;
    import java.io.File;
    import java.io.FileInputStream;
    import java.io.FileNotFoundException;
    import java.io.FileWriter;
    import java.io.IOException;
    import java.io.InputStream;
    import java.io.InputStreamReader;
    import java.text.DecimalFormat;
    import java.text.DecimalFormatSymbols;
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Locale;
    import java.util.Map;
    import java.util.Properties;
    
    /**
     * ID3算法
     * @author JoMint
     *
     */
    public class ID3Alogo {
    
        //存每个节点及其属性等相关变量
        private List<TreeNode> treeList;
        //存数据集
        private List<Map<String, String>> dataList;
        //遍历决策树时的开始节点
        private Attributes startNode;
        //决策结果变量的值
        private List<String> resultList;
        //结果属性节点
        private TreeNode resultNode;
        //决策树
        private String str;
    
        //构建决策树的开始调用方法
        public void ID3(String id3Name,String readPath,String printPath){
            
            //初始化成员变量
            initElement(id3Name);
            //读数据
            readData(readPath);
            //构建决策树
            cusTree(dataList, treeList, startNode);
            //System.out.println(startNode.getNextNode().getAttributes().get("Overcast").getLeafName());
            //遍历决策树,并把结果存入str中
            
            printTree(startNode,"");
            //打印决策树
            System.out.println(str);
            //输出决策树到文件
            printTreetoTxt(printPath);
            
        }
        
        /**
         * 初始化成员变量
         */
        private void initElement(String id3Name) {
            
            //存每个节点及其属性等相关变量
            treeList = new ArrayList<TreeNode>();
            //存数据集
            dataList = new ArrayList<Map<String,String>>();
            //遍历决策树时的开始节点
            startNode = new Attributes();
            //决策结果变量的值
            resultList = new ArrayList<String>();
            //结果属性节点
            TreeNode resultNode = null;
            //决策树
            str = id3Name+"决策树:
    ";
            
        }
    
    
        /**
         * 读数据
         */
        private void readData(String path) {
            
            Map<String, String> dataMap;
            Map<String,Attributes> attrMap;
            TreeNode treeNode;
            int num;
            
            //创建读取properties文件的对象
            Properties pro = new Properties();
            
            
            try {
                //为了读取中文字符,将读取文件的类型改为字符流读取
                InputStream inputStream = new FileInputStream(path);
                BufferedReader bf = new BufferedReader(new InputStreamReader(inputStream));
                //加载数据文件
                pro.load(bf);
                //读取数据总个数
                num = Integer.parseInt(pro.getProperty("datanum"));
                //读取属性及属性值
                String attribute = pro.getProperty("nodeAndAttribute");
                //将每个属性分开,用数组存,遍历每个属性,再把每个属性的属性值分开,存到treeList中
                String[] attArray = attribute.split(",");
                for (int i = 0; i < attArray.length; i++) {
                    
                    treeNode = new TreeNode();
                    String[] temp = attArray[i].split(":");
                    String nodeName = temp[0];
                    String[] attr = temp[1].split("/");
                    treeNode.setNodeName(nodeName);
                    attrMap = new HashMap<String, Attributes>();
                    Attributes attributes;
                    for (int j = 0; j < attr.length; j++) {
                        //Map<String, Integer> map = new HashMap<String, Integer>();
                        attributes = new Attributes();
                        //map.put(attr[j], 0);
                        attributes.setAttrName(attr[j]);
                        attrMap.put(attr[j], attributes);
                        
                        //存入结果变量的值,为最后的判断做铺垫
                        if(i == attArray.length-1){
                            
                            resultList.add(attr[j]);
                            
                        }
                        
                    }
                    treeNode.setAttributes(attrMap);
                    treeList.add(treeNode);
                }
                
                //遍历数据集,将数据按行存入dataList中
                for (int i = 1; i <= num; i++) {
                    
                    dataMap = new HashMap<String, String>();
                    String key = "D"+i;
                    String[] colline = pro.getProperty(key).split(",");
                    //System.out.println(key+"=="+colline.length);
                    for (int j = 0; j < treeList.size(); j++) {
                        //System.out.println(treeList.size());
                        dataMap.put(treeList.get(j).getNodeName(), colline[j]);
                    }
                    dataList.add(dataMap);
                }
                
                //得到结果属性的名字
                resultNode = treeList.get(treeList.size()-1);
                
                
    //            System.out.println("************************resultNode==" + resultNode + "***********************");
            } catch (FileNotFoundException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            } catch (IOException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
    
        }
    
        
        /**
         * 数据处理
         * @param cdataList
         * @param ctreeList
         */
        private List<List> dealData(List<Map<String, String>> dataList, List<TreeNode> treeList){
            
            
            List<List> returnList= new ArrayList<List>();
            int num = dataList.size();
            
            /*
             * 统计数据集中每个属性的属性值个数
             */
            Map<String, Attributes> attrMap = new HashMap<String, Attributes>();
            Map<String, Integer> resultMap;
            for (int i = 0; i < treeList.size(); i++) {
                for (int j = 0; j < dataList.size(); j++) {
                    //获得当前数据集中当前列当前行的属性值
                    String key = dataList.get(j).get(treeList.get(i).getNodeName()); 
                    attrMap = treeList.get(i).getAttributes();
                    //System.out.println(attrMap.get(key)+"=="+key);
                    //计算样本中对应的属性变量的个数
                    attrMap.get(key).setAttrNum(attrMap.get(key).getAttrNum()+1); 
                    
                    //System.out.println("->"+attrMap.get(key));
                    
                    //获得结果变量值
                    String result = dataList.get(j).get(treeList.get(treeList.size()-1).getNodeName()); 
                    resultMap = attrMap.get(key).getResultNum();
                    //如果包含这个结果变量,则数量上加1; 如果不包含,赋初值为1
                    if (resultMap.containsKey(result)) {      
                        resultMap.put(result, resultMap.get(result)+1);
                    }else{
                        resultMap.put(result, 1);
                    }
                }
            }
            /*
             * 计算熵
             */
            DecimalFormat df = new DecimalFormat("#.###");   
            for (int i = 0; i < treeList.size(); i++) {
                //遍历 Attributes
                //计算属性熵: gain
                double gain = 0.0;
                for (Map.Entry<String, Attributes> element : treeList.get(i).getAttributes().entrySet()) {
                    Attributes attr = treeList.get(i).getAttributes().get(element.getKey());
                    Map<String, Integer> result = attr.getResultNum();
                    //遍历每个 Attributes 的 resultNum
                    //计算属性值的熵 :h
                    double h = 0.0;
                    for (Map.Entry<String, Integer> element2 : result.entrySet()) {
                        double resultNum = (double)result.get(element2.getKey());
                        double attrNum = (double)attr.getAttrNum();
                        resultNum = resultNum/attrNum;
                        h -= (resultNum*(Math.log(resultNum)/Math.log((double)2)));
                        h = Double.parseDouble(df.format(h));
                        attr.setH(h);
                        //System.out.println("resultNum=========="+resultNum);
                    }
                    //System.out.println(" attr==>"+attr);
                    gain += ((double)attr.getAttrNum()/num)*attr.getH();
                    gain = Double.parseDouble(df.format(gain));
                    
                    //System.out.println("gain=="+gain);
                }
                
                treeList.get(i).setGain(gain);
                //System.out.println(" gain-->"+treeList.get(i));
                
            }
            
            //将处理好的dataList和treeList放在returnList中返回
            returnList.add(dataList);
            returnList.add(treeList);
            
            return returnList;
            
    //        System.out.println("***************************************************+++++++↓");
    //        for (int i = 0; i < treeList.size(); i++) {
    //            System.out.println(treeList.get(i));
    //        }
    //        System.out.println();
    //        for (int i = 0; i < dataList.size(); i++) {
    //            System.out.println(dataList.get(i));
    //        }
    //        
    //        System.out.println("================================================="+num+"条数据=="+treeList.size()+"个属性");
    //        System.out.println("***************************************************+++++++↑");
            
        }
        
        /**
         * 构建决策树
         * @param dataList
         * @param treeList
         */
        @SuppressWarnings("unchecked")
        private void cusTree(List<Map<String, String>> dataList, List<TreeNode> treeList, Attributes cAttr){
            
            List<List> curryList= new ArrayList<List>();
            
            //处理数据
            
            curryList = dealData(dataList, treeList);
            
            
            //从 curryList 中得到 dataList 和 treeList
            dataList = (List<Map<String, String>>)curryList.get(0);
            treeList = (List<TreeNode>)curryList.get(1);
            
            
            //判断当前处理的数据集中的决策结果,若决策结果相同的个数等于总的当前处理的数据集的条数,则遍历结束
            //将当前的决策结果放入当前判断的属性值的后边
            //返回到调用这个函数的父函数
            for (TreeNode treeNode : treeList) {
                if (treeNode.getNodeName().equals(resultNode.getNodeName())) {
                    for (String attr : resultList) {
                        if (treeNode.getAttributes().get(attr).getAttrNum() == dataList.size()) {
                            cAttr.setLeafName(attr);
                            return;
                        }
                    }
                }
            }
            
    //        System.out.println("=_=_=_=_=_=_=datalist==="+dataList);
    //        System.out.println("=_=_=_=_=_=_=treelist==="+treeList);
            
            //寻找最优解
            
            //得到根节点
            TreeNode rootNode = treeList.get(0);
            
            for (TreeNode treeNode : treeList) {
                
                if(!treeNode.getNodeName().equals(treeList.get(treeList.size()-1).getNodeName())){
                    if(treeNode.getGain() < rootNode.getGain()){
                        rootNode = treeNode;
                    }
                }
                
            }
        //    System.out.println("*********↓↓↓↓↓↓↓↓***********当前根节点为:"+rootNode.getNodeName()+"***********↓↓↓↓↓↓↓↓*********");
            
            cAttr.setNextNode(rootNode);
            
            //对当前根节点的属性进行遍历,寻找下一个节点
            
            //节点名
            String nodeName = rootNode.getNodeName();
            //属性名
            String attrName = "";
            //属性节点
            Attributes attr = new Attributes();
            //当前节点的属性值集合
            Map<String, Attributes> attrMap = rootNode.getAttributes();
            
            
            //遍历节点的每个属性值
            for (Map.Entry<String, Attributes> entry : attrMap.entrySet()) {
                
                attr = attrMap.get(entry.getKey());
                attrName = attr.getAttrName();
                
                
    //            System.out.println("*****************attrName========"+attrName+"******************");
                
                //得到新的data集合对象
                
                List<Map<String, String>> newDataList = new ArrayList<Map<String,String>>();
                Map<String, String> newMap = new HashMap<String, String>();
                //String attrName = rootNode.getAttributes().get("Sunny").getAttrName();
                newMap.clear();
                
                //删除dataList中已处理过的节点数据
                //遍历dataList
                for (Map<String, String> map : dataList) {
                    
                    if(map.containsKey(nodeName)){
                        
                        if(map.get(nodeName).equals(attrName)){
                            newMap = new HashMap<String, String>();
                            for (Map.Entry<String, String> m : map.entrySet()) {
                                
                                //如果该节点不是已处理过的节点
                                if(!m.getKey().equals(nodeName)){
                                    //得到新的节点
                                    newMap.put(m.getKey(), map.get(m.getKey()));
                                }
                                
                            }
                            
                            //将新的节点存入newDataList中
                            newDataList.add(newMap);
                        }
                        
                    }
                    
                } 
    //            System.out.println("↓↓↓↓↓↓*******************新的data集合:*******************↓↓↓↓↓↓");
    //            for (Map<String, String> map : newDataList) {
    //                System.out.println(map);
    //            }
                
                //获得新的tree集合对象,而且值为初值
                
                List<TreeNode> newTreeList = new ArrayList<TreeNode>();
                
                //将treeList中的数据清空
                clearTree(treeList);
                
                //删除treeList中已处理过的节点
                for (TreeNode treeNode : treeList) {
                    if(!treeNode.getNodeName().equals(nodeName)){
                        newTreeList.add(treeNode);
                    }
                }
    //            System.out.println("↓↓↓↓↓↓*******************新的tree集合:*******************↓↓↓↓↓↓");
    //            for (TreeNode treeNode : newTreeList) {
    //                System.out.println(treeNode);
    //            }
                
                //递归调用当前函数,继续找节点
                cusTree(newDataList, newTreeList,attr);
            }
        }
        
        /**
         * 输出决策树
         * @param attr
         */
        private void printTree(Attributes attr, String ceil) {
            
            String nodeName = attr.getNextNode().getNodeName();
            Map<String, Attributes> attrMap = attr.getNextNode().getAttributes();
            
            str += ceil+"----"+nodeName+"
    ";
            for (Map.Entry<String, Attributes> nextAttr : attrMap.entrySet()) {
                
                //如果当前属性值没有下一个节点,则将当前属性值的名称及决策结果输出
                if(attrMap.get(nextAttr.getKey()).getNextNode() == null){
                    
                    str += ceil+"-------"+attrMap.get(nextAttr.getKey()).getAttrName()+"
    ";
                    str += ceil+"----------"+attrMap.get(nextAttr.getKey()).getLeafName()+"
    ";
                    
                }else{
                    
                    str += ceil+"-------"+attrMap.get(nextAttr.getKey()).getAttrName()+"
    ";
                    printTree(attrMap.get(nextAttr.getKey()),"------");
                }
            }
    
        }
        
        /**
         * 打印决策树到txt文本
         * @param path
         */
        private void printTreetoTxt(String path){
            
            if(path == null || path.equals("")) return;
            File file = new File(path);
            File folder = file.getParentFile();
            FileWriter fw;
            try {
                
                if(!folder.exists()){
                    folder.mkdirs();
                    file.createNewFile();
                }
                
                fw = new FileWriter(file);
                fw.write(str);
                
                fw.flush();
                fw.close();
            } catch (IOException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
        }
        
        
        /**
         * 还原初始数据
         * @param treeList
         */
        private void clearTree(List<TreeNode> treeList){
            
            for (TreeNode treeNode : treeList) {
                Map<String, Attributes> map = treeNode.getAttributes();
                
                for (Map.Entry<String, Attributes> entry : map.entrySet()) {
                    Attributes attr = map.get(entry.getKey());
                    attr.setAttrNum(0);
                    attr.setH(0);
                    Map<String, Integer> map2 = attr.getResultNum();
                    map2.clear();
                }
                treeNode.setGain(0);
            }
        }
        
    }
    View Code

    主函数:ID3Main.java

    package com.id3.node;
    
    
    /** Command-line entry point for {@link ID3Alogo}. */
    public class ID3Main {

        /**
         * Builds and prints an ID3 decision tree.
         *
         * @param args optional {@code [treeName, dataFilePath, outputFilePath]}; each missing
         *             argument falls back to the placeholder the sample shipped with, so the
         *             program can be run without editing the source
         */
        public static void main(String[] args) {
            String treeName = args.length > 0 ? args[0] : "决策树名";
            String readPath = args.length > 1 ? args[1] : "数据文件地址";
            String printPath = args.length > 2 ? args[2] : "输出文件地址";

            ID3Alogo id3Alogo = new ID3Alogo();
            id3Alogo.ID3(treeName, readPath, printPath);
        }
    }
    View Code
  • 相关阅读:
    CCF模拟题 窗口
    CSUOJ 1541 There is No Alternative
    MySQL数据库优化的八种方式(经典必看)
    PHP面向对象-----魔术方法
    PHP面向对象(OOP)----分页类
    2017最新PHP初级经典面试题目汇总(下篇)
    2017最新PHP经典面试题目汇总(上篇)
    原型模式
    适配器模式
    策略模式
  • 原文地址:https://www.cnblogs.com/mymint/p/4426147.html
Copyright © 2020-2023  润新知