• Machine Learning: The Idea Behind the KNN Algorithm and Its Implementation


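    An example to build intuition for KNN

    In the figure, which class should the green circle be assigned to: the red triangles or the blue squares? With K=3, red triangles make up 2/3 of the nearest neighbours, so the green circle is assigned to the red-triangle class. With K=5, blue squares make up 3/5 of the nearest neighbours, so the green circle is assigned to the blue-square class.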
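
    The effect of K on the vote can be sketched in a few lines of Java. This is only an illustration: the class name VoteByK, the method vote, and the nearest-first ordering of the five neighbours are assumptions chosen to match the 2/3 and 3/5 ratios described above, not data taken from the figure.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class VoteByK {

    // Returns the majority label among the first k entries of a list
    // that is assumed to be sorted by distance already (nearest first).
    static String vote(List<String> labelsByDistance, int k) {
        Map<String, Integer> counts = new HashMap<>();
        String best = null;
        int bestCount = 0;
        for (String label : labelsByDistance.subList(0, k)) {
            int c = counts.merge(label, 1, Integer::sum);
            if (c > bestCount) {
                bestCount = c;
                best = label;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Assumed ordering: two triangles and one square among the nearest
        // three, then two more squares further out.
        List<String> neighbours =
                Arrays.asList("triangle", "triangle", "square", "square", "square");
        System.out.println("K=3 -> " + vote(neighbours, 3)); // triangle (2 of 3)
        System.out.println("K=5 -> " + vote(neighbours, 5)); // square (3 of 5)
    }
}
```

    The same neighbours give a different answer once K grows from 3 to 5, which is why K has to be chosen with care.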

                        

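    With this example in mind, the KNN idea comes down to the following steps:

    1, Compute the distance between the current point and every point in the labelled data set, using the Euclidean distance formula: d = sqrt(pow(x - x1, 2) + pow(y - y1, 2))

    2, Sort the points by increasing distance (nearest first)

    3, Take the K points closest to the current point (K=3 and K=5 in the example above)

    4, Count how often each class occurs among those K points

    5, Use the most frequent class as the predicted class of the current point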
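
    The five steps map quite directly onto a Java stream pipeline. The following is a minimal sketch under stated assumptions, not the implementation used in this article: the names KnnSketch and Sample and the sample coordinates are made up for illustration.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class KnnSketch {

    // Hypothetical minimal sample type: 2-D coordinates plus a class label.
    static final class Sample {
        final double x;
        final double y;
        final String label;

        Sample(double x, double y, String label) {
            this.x = x;
            this.y = y;
            this.label = label;
        }
    }

    // Step 1: Euclidean distance from the query point (x, y) to a sample.
    static double distance(double x, double y, Sample s) {
        return Math.hypot(s.x - x, s.y - y);
    }

    static String classify(List<Sample> data, double x, double y, int k) {
        Map<String, Long> freq = data.stream()
                // Steps 1-2: compute distances and sort nearest first.
                .sorted(Comparator.comparingDouble((Sample s) -> distance(x, y, s)))
                // Step 3: keep the k nearest samples.
                .limit(k)
                // Step 4: count how often each label occurs among them.
                .collect(Collectors.groupingBy(s -> s.label, Collectors.counting()));
        // Step 5: return the most frequent label (ties are broken arbitrarily).
        return freq.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElseThrow(() -> new IllegalArgumentException("no neighbours to vote"));
    }

    public static void main(String[] args) {
        List<Sample> data = Arrays.asList(
                new Sample(1.0, 1.0, "A"),
                new Sample(1.2, 0.9, "A"),
                new Sample(0.9, 1.3, "A"),
                new Sample(0.0, 0.0, "B"),
                new Sample(0.1, 0.2, "B"));
        // The two B samples are far closer to (0.2, 0.1) than any A sample.
        System.out.println(classify(data, 0.2, 0.1, 3)); // B
    }
}
```

    Sorting the whole data set costs O(n log n) per query, which is fine for small examples like this one.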

    Implementation code:

    package com.data.knn;

    /**
     * *********************************************************
     * <p/>
     * Author:     XiJun.Gong
     * Date:       2016-09-06 12:02
     * Version:    default 1.0.0
     * Class description: a labelled point in the 2-D plane
     * <p/>
     * *********************************************************
     */
    public class Point {

        private double x;     // x coordinate
        private double y;     // y coordinate
        private double dist;  // cached distance to the point being classified
        private String label; // class this point belongs to

        public Point() {
            this(0d, 0d, "");
        }

        public Point(double x, double y, String label) {
            this.x = x;
            this.y = y;
            this.label = label;
        }

        /* Euclidean distance between this point and point a */
        public double distance(final Point a) {
            return Math.sqrt((a.x - x) * (a.x - x) + (a.y - y) * (a.y - y));
        }

        public double getX() {
            return x;
        }

        public void setX(double x) {
            this.x = x;
        }

        public double getY() {
            return y;
        }

        public void setY(double y) {
            this.y = y;
        }

        public String getLabel() {
            return label;
        }

        public void setLabel(String label) {
            this.label = label;
        }

        public double getDist() {
            return dist;
        }

        public void setDist(double dist) {
            this.dist = dist;
        }
    }

    KNN implementation

    package com.data.knn;

    import com.google.common.base.Preconditions;
    import com.google.common.collect.Maps;

    import java.util.Collections;
    import java.util.Comparator;
    import java.util.List;
    import java.util.Map;

    /**
     * *********************************************************
     * <p/>
     * Author:     XiJun.Gong
     * Date:       2016-09-06 11:59
     * Version:    default 1.0.0
     * Class description: K-nearest-neighbour classifier
     * <p/>
     * *********************************************************
     */
    public class KNN {

        // Classify newPoint by majority vote among its K nearest neighbours.
        // Note: dataSet is sorted in place, so it must be a mutable list.
        public String classify(List<Point> dataSet, final Point newPoint, int k) {

            Preconditions.checkArgument(k > 0 && k <= dataSet.size(),
                    "K must be between 1 and the number of elements in dataSet");
            // Compute the distance from every point to the new point
            for (Point point : dataSet) {
                point.setDist(newPoint.distance(point));
            }
            // Sort by increasing distance (nearest first)
            Collections.sort(dataSet, new Comparator<Point>() {
                @Override
                public int compare(Point o1, Point o2) {
                    return Double.compare(o1.getDist(), o2.getDist());
                }
            });

            // Count label frequencies among the first K points
            Map<String, Integer> map = Maps.newHashMap();
            int maxCnt = 0;    // highest frequency seen so far
            String label = ""; // label with the highest frequency
            for (Point point : dataSet.subList(0, k)) {
                int currentCnt = 1; // frequency of the current label
                if (map.containsKey(point.getLabel())) {
                    currentCnt += map.get(point.getLabel());
                }
                if (maxCnt < currentCnt) {
                    maxCnt = currentCnt;
                    label = point.getLabel();
                }
                map.put(point.getLabel(), currentCnt);
            }
            return label;
        }
    }

    Test program

    package com.data.knn;

    import com.google.common.collect.Lists;

    import java.util.List;

    /**
     * *********************************************************
     * <p/>
     * Author:     XiJun.Gong
     * Date:       2016-09-06 14:45
     * Version:    default 1.0.0
     * Class description: small driver for the KNN classifier
     * <p/>
     * *********************************************************
     */
    public class Main {

        public static void main(String[] args) {
            List<Point> list = Lists.newArrayList();
            list.add(new Point(1., 1.1, "A"));
            list.add(new Point(1., 1., "A"));
            list.add(new Point(0., 0., "B"));
            list.add(new Point(0., 0.1, "B"));
            Point point = new Point(0.5, 0.5, null); // point to classify
            KNN knn = new KNN();
            System.out.println(knn.classify(list, point, 3));
        }
    }

    Result:

    B

    The nearest neighbour of (0.5, 0.5) is (0, 0.1) with label B, followed by (1, 1) with label A and (0, 0) with label B at equal distance, so B wins the vote 2 to 1.
    

      

  • Original article: https://www.cnblogs.com/gongxijun/p/5845764.html