• K-近邻算法的Python实现 : 源代码分析


    网上介绍K-近邻算法的样例非常多。其Python实现版本号基本都是来自于机器学习的入门书籍《机器学习实战》,尽管K-近邻算法本身非常easy,但非常多刚開始学习的人对其Python版本号的源码理解不够,所以本文将对其源码进行分析。


    什么是K-近邻算法?

    简单的说,K-近邻算法採用不同特征值之间的距离方法进行分类。所以它是一个分类算法。

    长处:无数据输入假定,对异常值不敏感

    缺点:复杂度高


    好了,直接先上代码,等会在分析:(这份代码来自《机器学习实战》)

    def classify0(inx, dataset, lables, k):
        dataSetSize = dataset.shape[0]
        diffMat = tile(inx, (dataSetSize, 1)) - dataset
        sqDiffMat = diffMat**2
        sqDistance = sqDiffMat.sum(axis=1)
        distances = sqDistance**0.5
        sortedDistances = distances.argsort()
        classCount={}
        for i in range(k):
            label = lables[sortedDistances[i]]
            classCount[label] = classCount.get(label, 0) + 1
        sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]
    

     

    该函数的原理是:

    存在一个样本数据集合,也称为训练集,在样本集中每一个数据都存在标签。在我们输入没有标签的新数据后,将新数据的每一个特征与样本集中相应的特征进行比較,然后提取最相似(近期邻)的分类标签。

    一般我们仅仅选样本数据集中前K 个最相似的数据。最后。出现次数最多的分类就是新数据的分类。


    classify0函数的參数意义例如以下:

    inx : 是输入没有标签的新数据,表示为一个向量。

    dataset: 是样本集。

    表示为向量数组。

    labels:相应样本集的标签。

    k:即所选的前K。


    用于产生数据样本的简单函数:


    def create_dataset():
        group = array([[1.0, 1.1], [1.0, 1.1], [0, 0], [0, 0.1]])
        labels = ['A', 'A', 'B', 'B']
        return group, labels


    注意,array是numpy里面的。

    我们须要实现import进来。

    from numpy import *
    import operator


    我们在调用时。

    group,labels = create_dataset()
    result = classify0([0,0], group, labels, 3)
    print result

    显然,[0,0]特征向量肯定是属于B 的,上面也将打印B。


    知道了这些。刚開始学习的人应该对实际代码还是非常陌生。不急,正文開始了!


    源代码分析


    dataSetSize = dataset.shape[0]

    shape是array的属性,它描写叙述了一个数组的“形状”,也就是它的维度。比方,

    In [2]: dataset = array([[1.0, 1.1], [1.0, 1.1], [0, 0], [0, 0.1]])
    
    In [3]: print dataset.shape
    (4, 2)
    

    所以,dataset.shape[0] 就是样本集的个数。


    diffMat = tile(inx, (dataSetSize, 1)) - dataset

    tile(A,rep)函数是基于数组A来构造数组的,详细怎么构造就看第二个參数了。其API介绍有点绕,但简单的使用方法相信几个样例就能明确。

    我们看看tile(inx, (4, 1))的结果,

    In [5]: tile(x, (4, 1))
    Out[5]: 
    array([[0, 0],
           [0, 0],
           [0, 0],
           [0, 0]])
    

    你看。4扩展的是数组的个数(本来1个。如今4个),1扩展的是每一个数组元素的个数(原来是2个,如今还是两个)。

    为证实上面的结论,

    In [6]: tile(x,(4,2))
    Out[6]: 
    array([[0, 0, 0, 0],
           [0, 0, 0, 0],
           [0, 0, 0, 0],
           [0, 0, 0, 0]])
    

    和。

    In [7]: tile(x,(2,2))
    Out[7]: 
    array([[0, 0, 0, 0],
           [0, 0, 0, 0]])
    

    关于,tile的详细使用方法。请自行查阅API DOC。


    得到tile后,减去dataset。

    这类似一个矩阵的减法。结果仍是一个 4 * 2的数组。

    In [8]: tile(x, (4, 1)) - dataset
    Out[8]: 
    array([[-1. , -1.1],
           [-1. , -1.1],
           [ 0. ,  0. ],
           [ 0. , -0.1]])
    

    结合欧式距离的求法,后面的代码就清晰些,对上面结果平方运算,求和。开方。

    我们看看求和的方法,

    sqDiffMat.sum(axis=1)

    当中。

    In [14]: sqDiffmat
    Out[14]: 
    array([[ 1.  ,  1.21],
           [ 1.  ,  1.21],
           [ 0.  ,  0.  ],
           [ 0.  ,  0.01]])
    


    求和的结果是对行求和,是一个N*1的数组。

    假设要对列求和,

    sqlDiffMat.sum(axis=0)

    argsort()是对数组升序排序的。


    classCount是一个字典,key是标签。value是该标签出现的次数。


    这样。算法的一些详细代码细节就清楚了。




  • 相关阅读:
    shell数组
    Apache HTTP Server 与 Tomcat 的三种连接方式介绍
    实现Java动态类载入机制
    Tomcat 阀
    MYSQL 常用命令
    MYSQL字符数字转换
    主题:MySQL数据库操作实战
    日本手机三大代理商的UA
    Java解析XML文档——dom解析xml (转载)
    MS sql server和mysql中update多条数据的例子
  • 原文地址:https://www.cnblogs.com/clnchanpin/p/6802617.html
Copyright © 2020-2023  润新知