• KNN分类算法--python实现


    一、kNN算法分析

           K最近邻(k-Nearest Neighbor,KNN)分类算法可以说是最简单的机器学习算法了。它采用测量不同特征值之间的距离方法进行分类。它的思想很简单:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。

          KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。

           该算法在分类时有个主要的不足是,当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数。因此可以采用权值的方法(和该样本距离小的邻居权值大)来改进。该方法的另一个不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点。目前常用的解决方法是事先对已知样本点进行剪辑,事先去除对分类作用不大的样本。该算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分[参考机器学习十大算法]。

          总的来说就是我们已经存在了一个带标签的数据库,然后输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似(最近邻)的分类标签。一般来说,只选择样本数据库中前k个最相似的数据。最后,选择k个最相似数据中出现次数最多的分类。其算法描述如下:

    1)计算已知类别数据集中的点与当前点之间的距离;

    2)按照距离递增次序排序;

    3)选取与当前点距离最小的k个点;

    4)确定前k个点所在类别的出现频率;

    5)返回前k个点出现频率最高的类别作为当前点的预测分类。

    代码:

    #########################################
    # kNN: k Nearest Neighbors
    
    # Input:      inX: vector to compare to existing dataset (1xN)
    #             dataSet: size m data set of known vectors (NxM)
    #             labels: data set labels (1xM vector)
    #             k: number of neighbors to use for comparison 
                
    # Output:     the most popular class label
    #########################################
    
    from numpy import *
    import operator
    import os
    from Canvas import Line
    
    
    # classify using kNN
    def kNNClassify(newInput, dataSet, labels, k):
        numSamples = dataSet.shape[0] # shape[0] stands for the num of row
    
        ## step 1: calculate Euclidean distance
        # tile(A, reps): Construct an array by repeating A reps times
        # the following copy numSamples rows for dataSet
        diff = tile(newInput, (numSamples, 1)) - dataSet # Subtract element-wise
        squaredDiff = diff ** 2 # squared for the subtract
        squaredDist = sum(squaredDiff, axis = 1) # sum is performed by row
        distance = squaredDist ** 0.5
    
        ## step 2: sort the distance
        # argsort() returns the indices that would sort an array in a ascending order
        sortedDistIndices = argsort(distance)
    
        classCount = {} # define a dictionary (can be append element)
        for i in xrange(k):
            ## step 3: choose the min k distance
            voteLabel = labels[sortedDistIndices[i]]
    
            ## step 4: count the times labels occur
            # when the key voteLabel is not in dictionary classCount, get()
            # will return 0
            classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
    
        ## step 5: the max voted class will return
        maxCount = 0
        for key, value in classCount.items():
            if value > maxCount:
                maxCount = value
                maxIndex = key
    
        return maxIndex    
    
    # convert image to vector
    def  img2vector(filename):
         rows = 32
         cols = 32
         imgVector = zeros((1, rows * cols))
         
         fileIn = open(filename)
         for row in xrange(rows):
             lineStr = fileIn.readline()
             for col in xrange(cols):
                 imgVector[0, row * 32 + col] = int(lineStr[col])
    
         return imgVector
    
    # load dataSet
    def loadDataSet():
        ## step 1: Getting training set
        print "---Getting training set..."
        dataSetDir = 'F:/eclipse/workspace/KnnTest/'
        trainingFileList = os.listdir(dataSetDir + 'trainingDigits') # load the training set
        numSamples = len(trainingFileList)
    
        train_x = zeros((numSamples, 1024))
        train_y = []
        for i in xrange(numSamples):
            filename = trainingFileList[i]
    
            # get train_x
            train_x[i, :] = img2vector(dataSetDir + 'trainingDigits/%s' % filename) 
    
            # get label from file name such as "1_18.txt"
            label = int(filename.split('_')[0]) # return 1
            train_y.append(label)
    
        ## step 2: Getting testing set
        print "---Getting testing set..."
        testingFileList = os.listdir(dataSetDir + 'testDigits') # load the testing set
        numSamples = len(testingFileList)
        test_x = zeros((numSamples, 1024))
        test_y = []
        for i in xrange(numSamples):
            filename = testingFileList[i]
    
            # get train_x
            test_x[i, :] = img2vector(dataSetDir + 'testDigits/%s' % filename) 
    
            # get label from file name such as "1_18.txt"
            label = int(filename.split('_')[0]) # return 1
            test_y.append(label)
    
        return train_x, train_y, test_x, test_y
    
    # test hand writing class
    def testHandWritingClass():
        ## step 1: load data
        print "step 1: load data..."
        train_x, train_y, test_x, test_y = loadDataSet()
    
        ## step 2: training...
        print "step 2: training..."
        pass
    
        ## step 3: testing
        print "step 3: testing..."
        numTestSamples = test_x.shape[0]
        matchCount = 0
        for i in xrange(numTestSamples):
            predict = kNNClassify(test_x[i], train_x, train_y, 3)
            if predict == test_y[i]:
                matchCount += 1
        accuracy = float(matchCount) / numTestSamples
    
        ## step 4: show the result
        print "step 4: show the result..."
        print 'The classify accuracy is: %.2f%%' % (accuracy * 100)

    另外创建一个脚本knnTest.py

    import KNN
    KNN.testHandWritingClass()

    其中数据集下载链接为:http://download.csdn.net/detail/zouxy09/6610571

  • 相关阅读:
    jNotify:操作结果信息提示条
    jqurey datatable tableTools 自定义button元素 以及按钮定义事件
    jqurey datatable mRender FnRender 不起作用问题
    VS2013 ViewData ViewBag Ajax等关键词报错(当前上下文不存在名称)而且不提示也点不出来,但是可以正常运行,
    关于 update别名 与update select
    EF 保证线程内唯一 上下文的创建
    文件接收上传
    lucence.net+盘古分词
    log4net 入门教程
    MVC+EF OA观看视频记录
  • 原文地址:https://www.cnblogs.com/graceting/p/4166016.html
Copyright © 2020-2023  润新知