• k近邻算法(kNN)


    一、引入

    现有肿瘤大小与时间的关系散点图,其中红色为良性,蓝色为非良性

    现在获得了一条新数据(绿点),怎样根据现有的统计数据,分析其是否为良性?

    首先,我们取一个k值(k可以暂时理解为根据以往经验取得的一个最好值),如:k=3,那么k近邻算法所做的就是寻找离新数据点(绿点)最近的三个点

    并根据其所属类别进行投票(蓝:红=3:0),那么新数据点极有可能属于蓝色(非良性)

    又如,新数据点如下:

    此时数据点的k近邻结果为(蓝:红=1:2),那么新数据点极有可能属于红色(良性)

    二、kNN基础

    1.数据模拟

    现有测试数据如下:

    每个样本的特征集合:

    raw_data_X = [[3.393533211, 2.331273381],
                  [3.110073483, 1.781539638],
                  [1.343808831, 3.368360954],
                  [3.582294042, 4.679179110],
                  [2.280362439, 2.866990263],
                  [7.423436942, 4.696522875],
                  [5.745051997, 3.533989803],
                  [9.172168622, 2.511101045],
                  [7.792783481, 3.424088941],
                  [7.939820817, 0.791637231]
                 ]

    每个样本所属的类别:

    raw_data_y = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

    并将其存入numpy数组,命名为X_train和y_train

    X_train = np.array(raw_data_X)
    y_train = np.array(raw_data_y)

     现绘制该数据集的散点图

    plt.scatter(X_train[y_train==0, 0], X_train[y_train==0, 1], color='g')
    plt.scatter(X_train[y_train==1, 0], X_train[y_train==1, 1], color='r')

    现存在新数据x,需判断其属于哪一类(红或绿)?

    x = np.array([8.093607318, 3.365731514])

    在图中其位置如下(蓝点),按照kNN算法,该点应该属于红色一类的

    2.kNN的过程

    获取kNN临近点,最简单的方式就是求点之间的欧拉距离

    下面分别为2维、3维以及n维的欧拉距离公式

    欧拉距离通用公式:

    其代码实现如下:

    from math import sqrt
    distances = []
    for x_train in X_train:
        d = sqrt(np.sum((x_train - x)**2))
        distances.append(d)

    或使用生成表达式

    distances = [sqrt(np.sum((x_train - x)**2)) for x_train in X_train]

    求得所有点与新点的欧拉距离数组

    调用np.argsort方法获取排序后数组的元素在原数据集中的索引

    np.argsort(distances)

    先假设k=6,那么在原数据集中,距离x最近的六个点的索引分别为8、7、5、6、9、3

    nearest = np.argsort(distances)
    k = 6
    topK_y = [y_train[i] for i in nearest[:k]]

    可以获取它们的所属类别数组

    现对类别数组中的类别进行统计

    from collections import Counter
    votes = Counter(topK_y)

    并获取票数最多的类别作为结果

    votes.most_common(1)[0][0]

    由此可得,新数据x所属的类别很可能为1

    python实现代码:

    import numpy as np
    from math import sqrt
    from collections import Counter
    
    
    def kNN_classify(k, X_train, y_train, x):
    
        assert 1 <= k <= X_train.shape[0], "k must be valid"
        assert X_train.shape[0] == y_train.shape[0], \
            "the size of X_train must equal to the size of y_train"
        assert X_train.shape[1] == x.shape[0], \
            "the feature number of x must be equal to X_train"
    
        distances = [sqrt(np.sum((x_train - x)**2)) for x_train in X_train]
        nearest = np.argsort(distances)
    
        topK_y = [y_train[i] for i in nearest[:k]]
        votes = Counter(topK_y)
    
        return votes.most_common(1)[0][0]

  • 相关阅读:
    视音频开发测试文件下载
    H.264 中的Annex B格式和AVCC格式
    FFmpeg——命令笔记
    Gamma 矫正
    头文件 <string.h> <cstring> <string> 区别
    Serializable
    Oracle学习
    JDBC
    Servlet为主理解cookie,session,filter
    javaweb复习-环境篇
  • 原文地址:https://www.cnblogs.com/jizhiqiliao/p/15837554.html
Copyright © 2020-2023  润新知