• 决策树(2)


    package TreeStructure;
    
    import java.util.ArrayList;
    import java.util.List;
    
    public class testClass {

        /**
         * Demo driver: builds a decision tree from a small hard-coded training set, then walks it
         * with a partial query and prints the builder's trace and the final decision.
         *
         * @param args unused
         */
        public static void main(String[] args) {
            double[][] exercise = {{1,1,0,0},{1,3,1,1},{3,2,0,0},{3,2,1,10},{3,2,1,10},{3,2,1,10},{2,2,1,1},{3,2,1,9},{2,3,0,1},{2,1,0,0},{3,2,0,1},{2,1,0,1},{1,1,0,1}};
            String[] attributeNames = {"weather", "thin", "cloth", "target"};
            // Column j of exerciseData is copied from column index[j] of exercise
            // (with {1, 0, 2, 3} the first two columns are swapped).
            // NOTE(review): attributeNames is not permuted alongside the columns -- confirm the
            // names already describe the reordered layout.
            int[] index = {1, 0, 2, 3};
            double[][] exerciseData = new double[exercise.length][];
            for (int i = 0; i < exercise.length; i++) {
                exerciseData[i] = new double[exercise[i].length];
                for (int j = 0; j < exerciseData[i].length; j++) {
                    exerciseData[i][j] = exercise[i][index[j]];
                }
            }

            for (double[] row : exerciseData) {
                for (double value : row) {
                    System.out.print("  " + value);
                }
                System.out.println();
            }

            DecisionTree dt = new DecisionTree();

            // The tree builder works on string cells; String.valueOf(double) yields the same text
            // as `value + ""` (e.g. "1.0").
            List<ArrayList<String>> data = new ArrayList<ArrayList<String>>(exerciseData.length);
            for (double[] row : exerciseData) {
                ArrayList<String> cells = new ArrayList<String>(row.length);
                for (double value : row) {
                    cells.add(String.valueOf(value));
                }
                data.add(cells);
            }

            List<String> attribute = new ArrayList<String>(attributeNames.length);
            for (String name : attributeNames) {
                attribute.add(name);
            }

            // Passing null lets createDT create the "start" root itself.
            TreeNode node = dt.createDT(data, attribute, null);

            // Query: values for the leading (reordered) attribute columns, matched level by level.
            double[] dataExercise = {2, 3};
            List<Double> query = new ArrayList<Double>(dataExercise.length);
            for (double value : dataExercise) {
                query.add(value);
            }

            node.traverse(query);

            System.out.println();
        }
    }
    package TreeStructure;
    
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.Iterator;
    import java.util.List;
    import java.util.Map;
    import java.util.Set;
    
    public class DecisionTree {

        /**
         * Recursively builds a tree by splitting on the attributes in list order.
         *
         * <p>Each call first prints the current rows and attribute names. While more than one
         * attribute remains, the rows are partitioned by the distinct values of column 0. Each
         * value gets a child node named after the attribute, and the recursion continues on that
         * partition with column 0 and the first attribute name removed. When exactly one attribute
         * remains, one "target" leaf is added per remaining row, so duplicate values are kept. With
         * no attributes at all, the node is returned unchanged.
         *
         * @param data rows of string cells; column 0 corresponds to {@code attributeList.get(0)}
         * @param attributeList names of the remaining columns, in column order
         * @param node node to attach children to, or {@code null} to create a new "start" root
         * @return the node children were attached to (the new root when {@code node} was null)
         */
        public TreeNode createDT(List<ArrayList<String>> data, List<String> attributeList, TreeNode node) {
            printState(data, attributeList);

            if (node == null) {
                node = new TreeNode();
                node.setAttributeValue("start");
                node.setNodeName("start");
            }

            if (attributeList.isEmpty()) {
                // Nothing left to split on or to label leaves with; without this guard the split
                // branch below would call attributeList.get(0) and throw.
                return node;
            }

            if (attributeList.size() == 1) {
                addTargetLeaves(data, node);
            } else {
                splitOnFirstAttribute(data, attributeList, node);
            }
            return node;
        }

        /** Prints the rows and attribute names this recursion level is working on. */
        private static void printState(List<ArrayList<String>> data, List<String> attributeList) {
            System.out.println("当前的DATA为");
            for (ArrayList<String> row : data) {
                for (String cell : row) {
                    System.out.print(cell + " ");
                }
                System.out.println();
            }
            System.out.println("---------------------------------");
            System.out.println("当前的ATTR为");
            for (String attr : attributeList) {
                System.out.print(attr + " ");
            }
            System.out.println();
            System.out.println("---------------------------------");
        }

        /** Adds one "target" leaf per row, holding that row's first (only remaining) cell. */
        private static void addTargetLeaves(List<ArrayList<String>> data, TreeNode node) {
            for (ArrayList<String> row : data) {
                TreeNode leafNode = new TreeNode();
                leafNode.setAttributeValue(row.get(0));
                leafNode.setNodeName("target");
                node.getChildTreeNode().add(leafNode);
            }
        }

        /** Creates one child per distinct value of column 0 and recurses into each partition. */
        private void splitOnFirstAttribute(List<ArrayList<String>> data, List<String> attributeList, TreeNode node) {
            String attributeName = attributeList.get(0);
            // NOTE(review): despite this message, no gain ratio is computed; the first remaining
            // attribute is always the one split on.
            System.out.println("选择出的最大增益率属性为: " + attributeName);
            InfoGain gain = new InfoGain(data, attributeList);
            Map<String, Long> attrvalueMap = gain.getAttributeValue(0);

            for (String value : attrvalueMap.keySet()) {
                List<ArrayList<String>> resultData = gain.getData4Value(value, 0);
                TreeNode branchNode = new TreeNode();
                branchNode.setAttributeValue(value);
                branchNode.setNodeName(attributeName);
                node.getChildTreeNode().add(branchNode);

                System.out.println("当前为" + attributeName + "的" + value + "分支。");
                // Drop the column just split on; getData4Value returned copies of the rows.
                for (ArrayList<String> row : resultData) {
                    row.remove(0);
                }
                ArrayList<String> resultAttr = new ArrayList<String>(attributeList);
                resultAttr.remove(0);
                createDT(resultData, resultAttr, branchNode);
            }
        }
    }
                
                
                
                
                
                
                
                
                
        
        
    
        
    package TreeStructure;
    
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.Iterator;
    import java.util.List;
    import java.util.Map;
    import java.util.Set;
    
    public class InfoGain {
        /** Private deep copy of the rows passed to the constructor; each row is a list of cells. */
        private List<ArrayList<String>> data;
        /** Private copy of the attribute (column) names passed to the constructor. */
        private List<String> attribute;

        /**
         * Creates an instance over copies of {@code data} (row by row) and {@code attribute}, so
         * later changes to the caller's lists do not affect this object, and vice versa.
         *
         * @param data rows of string cells
         * @param attribute column names
         */
        public InfoGain(List<ArrayList<String>> data, List<String> attribute) {
            this.data = new ArrayList<ArrayList<String>>(data.size());
            for (ArrayList<String> row : data) {
                this.data.add(new ArrayList<String>(row));
            }
            this.attribute = new ArrayList<String>(attribute);
        }

        /**
         * Counts how many rows hold each distinct value in the given column.
         *
         * @param attributeIndex column to count
         * @return map from cell value to number of rows containing it (HashMap, unordered)
         */
        public Map<String, Long> getAttributeValue(int attributeIndex) {
            Map<String, Long> attributeValueMap = new HashMap<String, Long>();
            for (ArrayList<String> row : data) {
                String key = row.get(attributeIndex);
                Long count = attributeValueMap.get(key);
                attributeValueMap.put(key, count == null ? 1L : count + 1);
            }
            return attributeValueMap;
        }

        /**
         * Returns copies of the rows whose cell at {@code attrIndex} equals {@code attrValue},
         * ignoring case. The copies can be modified without affecting this instance.
         *
         * @param attrValue value to match
         * @param attrIndex column to compare
         * @return matching rows, in their original order
         */
        public List<ArrayList<String>> getData4Value(String attrValue, int attrIndex) {
            List<ArrayList<String>> resultData = new ArrayList<ArrayList<String>>();
            for (ArrayList<String> row : data) {
                if (row.get(attrIndex).equalsIgnoreCase(attrValue)) {
                    resultData.add(new ArrayList<String>(row));
                }
            }
            return resultData;
        }

        /**
         * Collects the last cell of every row, in row order.
         *
         * <p>Empty rows have no last cell and are skipped. Previously the first empty row
         * stopped the scan, silently dropping all rows after it.
         *
         * @param data rows of string cells
         * @return the last cell of each non-empty row
         */
        public static List<String> getTarget(List<ArrayList<String>> data) {
            List<String> list = new ArrayList<String>();
            for (ArrayList<String> row : data) {
                if (row.isEmpty()) {
                    continue;
                }
                list.add(row.get(row.size() - 1));
            }
            return list;
        }

        /**
         * Checks whether the values are 100% pure, i.e. all identical.
         *
         * @param list values to check, e.g. the result of {@link #getTarget(List)}
         * @return the shared value if every element is equal to the first; {@code null} if the
         *         list is null, empty, or contains differing values (a list of all-null values
         *         also yields {@code null})
         */
        public static String IsPure(List<String> list) {
            if (list == null || list.isEmpty()) {
                return null;
            }
            String first = list.get(0);
            for (String value : list) {
                boolean same = (first == null) ? value == null : first.equals(value);
                if (!same) {
                    return null;
                }
            }
            return first;
        }

    }
    package TreeStructure;
    
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    
    
     class TreeNode{
    
    private String attributeValue;
            private List<TreeNode> childTreeNode;
            private List<String> pathName;
            private String targetFunValue;
            private String nodeName;
            
            public TreeNode(String nodeName){
                
                this.nodeName = nodeName;
                this.childTreeNode = new ArrayList<TreeNode>();
                this.pathName = new ArrayList<String>();
            }
            
            public TreeNode(){
                this.childTreeNode = new ArrayList<TreeNode>();
                this.pathName = new ArrayList<String>();
            }
    
            public String getAttributeValue() {
                return attributeValue;
            }
    
            public void setAttributeValue(String attributeValue) {
                this.attributeValue = attributeValue;
            }
    
            public List<TreeNode> getChildTreeNode() {
                return childTreeNode;
            }
    
            public void setChildTreeNode(List<TreeNode> childTreeNode) {
                this.childTreeNode = childTreeNode;
            }
    
            public String getTargetFunValue() {
                return targetFunValue;
            }
    
            public void setTargetFunValue(String targetFunValue) {
                this.targetFunValue = targetFunValue;
            }
    
            public String getNodeName() {
                return nodeName;
            }
    
            public void setNodeName(String nodeName) {
                this.nodeName = nodeName;
            }
    
            public List<String> getPathName() {
                return pathName;
            }
    
            public void setPathName(List<String> pathName) {
                this.pathName = pathName;
            }
            
            public void traverse() {  
                System.out.println(this.getNodeName()+":   "+this.getAttributeValue());
                int childNumber = this.childTreeNode.size(); 
                System.out.println(childNumber);
                for (int i = 0; i < childNumber; i++) {  
                    TreeNode child = this.childTreeNode.get(i);  
                    child.traverse();  
                }  
            }  
            
            
            public List getTarget(TreeNode node){
                List a = new ArrayList();;
                int childNum = node.getChildTreeNode().size();
                if(node.childTreeNode.get(0).childTreeNode.size()==0){//表示node孩子的孩子为空,即node下一层为目标层
                    for(int i = 0;i<childNum;i++){
                        a.add(node.getChildTreeNode().get(i).getAttributeValue());
                        
                    }
                    
                }else{
                    for(int i = 0;i<childNum;i++){
                        a.addAll(getTarget(node.getChildTreeNode().get(i)));
                    }
                }
                return a;
            }
            public void traverse(List list) {
                if(list.size()==0){
                    List target = getTarget(this);
    //                int childlistNumber = this.childTreeNode.size(); 
    //                List a = new ArrayList();
    //                for(int i = 0;i<childlistNumber;i++){
    //                TreeNode child = this.childTreeNode.get(i);
    //                a.add(child.getAttributeValue());
    //                }
                    List b = new ArrayList();
    //                Map result = new HashMap();
                    for(int i = 0;i<target.size();i++){
                        if(!b.contains(target.get(i))){
                        b.add(target.get(i));
                        }
                    }
                    int []count = new int [b.size()];
                    for(int i = 0;i<b.size();i++){
                        
                        for(int j = 0;j<target.size();j++){
                            if(b.get(i).equals(target.get(j))){
                                count[i] = count[i]+1;
                            }
                        }
                        System.out.println(b.get(i)+"的数量是:   "+count[i]);
                    }
                    int maxIndex = 0;
                    for(int i = 1;i<count.length;i++){
                        if(count[maxIndex]<count[i]){
                            maxIndex = i;
                        }
                    }
                    System.out.println("选择"+b.get(maxIndex)+"为最终决策");
                    
                    
                    
                    
                }else{
                List a = new ArrayList();
                double temp = (Double)list.get(0);
                int childlistNumber = this.childTreeNode.size(); 
                System.out.println(childlistNumber);
                for(int i = 0;i<childlistNumber;i++){
                    TreeNode child = this.childTreeNode.get(i);  
                    double tempchild = Double.valueOf(child.getAttributeValue());
                    if(temp==tempchild){
                        System.out.println(child.getNodeName()+":   "+child.getAttributeValue());
                        list.remove(0);
                        child.traverse(list);
                    }
                }
                }
            }
     }
            
        
     
  • 相关阅读:
    python 判断返回结果 in用法
    关于requests的session方法保持不了cookie的问题。(session的意思是保持一个会话,比如 登录后继续操作(记录身份信息) 而requests是单次请求的请求,身份信息不会被记录)
    python-selenium并发执行测试用例(方法一 各模块每一条并发执行)
    python 正则表达提取方法 (提取不来的信息print不出来 加个输出type 再print信息即可)
    unittest框架 assertEqual 报错 让其出现中文的方法(这个问题出现时 我找了老半天) 还追加了 报错信息自定义的方法
    python 指定文件编码的方法
    解决python中路径中包含中文无法找到文件的问题
    python 字符转换记录
    python-selenium 并发执行用例的问题
    深度影响价值
  • 原文地址:https://www.cnblogs.com/yunerlalala/p/6119833.html
Copyright © 2020-2023  润新知