• 【十大算法实现之KNN】KNN算法实例(含测试数据和源码)


    KNN算法基本的思路是比较好理解的,今天根据它的特点写了一个实例,我会把所有的数据和代码都写在下面供大家参考,不足之处,请指正。谢谢!

    update:工程代码全部在本页面中,测试数据已丢失,建议去UCI Dataset中找一个自行测试一下。

    几点说明:

    1.KNN中的K=5;

    2.在计算权重时,采用的是减去函数{1,0.8,0.6,0.4,0.2},当然你也可以采用反函数或高斯函数;

    3.5%作为测试集(decision.txt),95%作为训练集(training.txt);

    4.在计算costfun之前,对所有的属性进行了归一化,由于这里不知道数据集每个属性代表的含义,所以就一视同仁,实际情况下,应该具体问题具体分析;

    image

    XBWKNN.java

    package XBWKNN;
    
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.Comparator;
    import java.util.List;
    
    /**
     * KNN算法
     * @author XBW
     * @date 2014年8月16日
     */
    
    
    public class XBWKNN{
        public final static int KofKNN=5;
        public final static double weight[]={1,0.9,0.7,0.4,0.1};                //减法函数y=1-0.2*x
        
    
        /**
         * knn
         * @param data
         * @param ds
         * @return ans
         */
        public static int knn(Data data,DataSet ds){
            int ans = 0;
            List<Data> dis=calcDis(data,ds);
            ans=calcKDis(data,dis);
            return ans;
        }
        
        /**
         * 计算训练集中所有向量的距离,排序之后取前K个
         * @param data
         * @param ds
         * @return
         */
        @SuppressWarnings("null")
        public static List<Data>calcDis(Data data,DataSet ds){
            List<Data> anslist =new ArrayList<Data>();
            double dx1=data.x1;
            double dx2=data.x2;
            double dx3=data.x3;
            for(int i=0;i<ds.ds.size();i++){
                double x1=ds.ds.get(i).x1;
                double x2=ds.ds.get(i).x2;
                double x3=ds.ds.get(i).x3;
                ds.ds.get(i).costfun=Math.sqrt((dx1-x1)*(dx1-x1)+(dx2-x2)*(dx2-x2)+(dx3-x3)*(dx3-x3));
                anslist.add(ds.ds.get(i));
            }
            Collections.sort(anslist,new Comparator<Data>(){
                   public int compare(Data o1, Data o2) {
                       Double s=o1.costfun-o2.costfun;
                       if(s<0)
                           return -1;
                       else
                           return 1; 
                    }
            });
            return anslist;
        }
        
        
        /**
         * 按一定的权重计算出前K个
         * @param data
         * @param ds
         * @return
         */
        public static int calcKDis(Data data,List<Data> anslist){
            Double[] anstype={0.0,0.0,0.0,0.0};
            for(int i=0;i<KofKNN;i++){
                if(anslist.get(i).type==1){
                    anstype[1]+=weight[i];
                }
                else if(anslist.get(i).type==2){
                    anstype[2]+=weight[i];
                }
                if(anslist.get(i).type==3){
                    anstype[3]+=weight[i];
                }
            }
            Double maxt=-1.0;
            int tag=1;
            for(int i=1;i<=3;i++){
                if(maxt<anstype[i]){
                    tag=i;
                    maxt=anstype[i];
                }
            }
            return tag;
        }
        
        public static void main(String[] args) throws IOException{
            DataSet ds=new DataSet();
            DataTest dt=new DataTest();
            
            int correct=0;
            for(int i=0;i<dt.dt.size();i++){
                Data data=dt.dt.get(i);
                int result=knn(data,ds);
                if(result==data.type){
                    correct++;
                }
            }
            System.out.println("total test num :"+dt.dt.size());
            System.out.println("correct test num :"+correct);
            System.out.println("ratio :"+correct/(double)dt.dt.size());
        }
    }

    Datatest.java

    package XBWKNN;
    
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileReader;
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.List;
    
    
    
    
    /**
     * 测试数据
     * @author XBW
     * @date 2014年8月16日
     */
    
    public class DataTest{
        String defaultpath="D:\MachineLearning\十大算法\KNN\knncode\decision.txt";
        List<Data> dt;
        
        @SuppressWarnings("null")
        public DataTest() throws IOException{
            List<Data> dset = new ArrayList<Data>();
            File ds=new File(defaultpath);
            @SuppressWarnings("resource")
            BufferedReader br = new BufferedReader(new FileReader(ds));
            String tsing;
            double max1=-1;
            double max2=-1;
            double max3=-1;
            while((tsing=br.readLine())!=null){
                String[] dlist=tsing.split("    ");
                Data data=new Data();
                data.x1=Double.parseDouble(dlist[0]);
                data.x2=Double.parseDouble(dlist[1]);
                data.x3=Double.parseDouble(dlist[2]);
                data.type=Integer.parseInt(dlist[3]);
                dset.add(data);
                
                if(data.x1>max1){
                    max1=data.x1;
                }
                if(data.x2>max2){
                    max2=data.x2;
                }
                if(data.x3>max3){
                    max3=data.x3;
                }
            }
            dset=normalization(dset,max1,max2,max3);
            this.dt=dset;
        }
        
        public List<Data> normalization(List<Data> dset,double m1,double m2,double m3){
            for(int i=0;i<dset.size();i++){
                dset.get(i).x1/=m1;
                dset.get(i).x2/=m2;
                dset.get(i).x3/=m3;
            }
            return dset;
        }
    }

    DataSet.java

    package XBWKNN;
    
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileReader;
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.List;
    
    
    
    
    /**
     * 训练数据
     * @author XBW
     * @date 2014年8月16日
     */
    
    public class DataSet{
        String defaultpath="D:\MachineLearning\十大算法\KNN\knncode\training.txt";
        List<Data> ds;
        
        @SuppressWarnings("null")
        public DataSet() throws IOException{
            List<Data> dset =new ArrayList<Data>();
            File ds=new File(defaultpath);
            @SuppressWarnings("resource")
            BufferedReader br = new BufferedReader(new FileReader(ds));
            String tsing;
            double max1=-1;
            double max2=-1;
            double max3=-1;
            while((tsing=br.readLine())!=null){
                String[] dlist=tsing.split("    ");
                Data data=new Data();
                data.x1=Double.parseDouble(dlist[0]);
                data.x2=Double.parseDouble(dlist[1]);
                data.x3=Double.parseDouble(dlist[2]);
                data.type=Integer.parseInt(dlist[3]);
                dset.add(data);
                
                if(data.x1>max1){
                    max1=data.x1;
                }
                if(data.x2>max2){
                    max2=data.x2;
                }
                if(data.x3>max3){
                    max3=data.x3;
                }
            }
            dset=normalization(dset,max1,max2,max3);
            this.ds=dset;
        }
        
        public List<Data> normalization(List<Data> dset,double m1,double m2,double m3){
            for(int i=0;i<dset.size();i++){
                dset.get(i).x1/=m1;
                dset.get(i).x2/=m2;
                dset.get(i).x3/=m3;
            }
            return dset;
        }
    }

    Data.java

    package XBWKNN;
    
    /**
     * 一条数据
     * @author XBW
     * @date 2014年8月16日
     */
    
    public class Data{
        Double x1;
        Double x2;
        Double x3;
        Double costfun;
        int type;
    }

    output:

    image







                If you have any questions about this article, welcome to leave a message on the message board.



    Brad(Bowen) Xu
    E-Mail : maxxbw1992@gmail.com


  • 相关阅读:
    JVM(六)——垃圾回收算法
    JVM(五)——执行引擎、String
    JVM(四)——方法区
    JVM(三)——堆
    JVM(二)——虚拟机栈
    JVM(一)——概述和类加载子系统
    Java EE入门(二十二)——Linux和Nginx
    操作系统(六)——磁盘和IO
    【03】RNN
    【02】时间复杂度
  • 原文地址:https://www.cnblogs.com/XBWer/p/3916884.html
Copyright © 2020-2023  润新知