• 使用KNN对MNIST数据集进行实验


    由于KNN的计算量太大,还没有使用KD-tree进行优化,所以对于60000训练集,10000测试集的数据计算比较慢。这里只是想测试观察一下KNN的效果而已,不调参。

    K选择之前看过貌似最好不要超过20,因此,此处选择了K=10,距离为欧式距离。如果需要改进,可以再调整K来选择最好的成绩。

    先跑了一遍不经过scale的,也就是直接使用像素灰度值来计算欧式距离进行比较。发现开始基本稳定在95%的正确率上,吓了一跳。因为本来觉得KNN算是没有怎么“学习”的机器学习算法了,猜测它的特点可能会是在任何情况下都可以用,但都表现的不是最好。所以估计在60%~80%都可以接受。没想到能基本稳定在95%上,确定算法和代码没什么问题后,突然觉得是不是这个数据集比较没挑战性。。。

    去MNIST官网(http://yann.lecun.com/exdb/mnist/),上面挂了以该数据集为数据的算法的结果比较。查看了一下KNN,发现有好多,而且错误率基本都在5%以内,甚至能做到1%以内。唔。

    跑的结果是,正确率:96.687%。也就是说,错误率error rate为3.31%左右。

    再跑一下经过scale的数据,即对灰度数据归一化到[0,1]范围内。看看效果是否有所提升。

    经过scale,最终跑的结果是,正确率:竟然也是96.687%! 也就是说,对于该数据集下,对KNN的数据是否进行归一化并无效果!

    在跑scale之前,个人猜测:由于一般对数据进行处理之前都进行归一化,防止高维诅咒(在784维空间中很容易受到高维诅咒)。因此,预测scale后会比前者要好一些的。但是,现在看来二者结果相同。也就是说,对于K=10的KNN算法中,对MNIST的预测一样的。

    对scale前后的正确率相同的猜测:由于在训练集合中有60000个数据点,因此0-9每个分类平均都有6000个数据点,在这样的情况下,对于测试数据集中的数据点,相临近的10个点中大部分都是其他分类而导致分类错误的概率会比较地(毕竟10相对与6000来说很小),所以,此时,KNN不仅可以取得较好的分类效果,而且对于是否scale并不敏感,效果相同。

    代码如下:

    1. #KNN for MNIST  
    2. from numpy import *  
    3. import operator  
    4.   
    5. def line2Mat(line):  
    6.     line = line.strip().split(' ')  
    7.     label = line[0]  
    8.     mat = []  
    9.     for pixel in line[1:]:  
    10.         pixel = pixel.split(':')[1]  
    11.         mat.append(float(pixel))  
    12.     return mat, label  
    13.   
    14. #matrix should be type: array. Or classify() will get error.  
    15. def file2Mat(fileName):  
    16.     f = open(fileName)  
    17.     lines = f.readlines()  
    18.     matrix = []  
    19.     labels = []  
    20.     for line in lines:  
    21.         mat, label = line2Mat(line)  
    22.         matrix.append(mat)  
    23.         labels.append(label)  
    24.     print 'Read file '+str(fileName) + ' to matrix done!'  
    25.     return array(matrix), labels  
    26.   
    27. #classify mat with trained data: matrix and labels. With KNN's K set.  
    28. def classify(mat, matrix, labels, k):  
    29.     diffMat = tile(mat, (shape(matrix)[0], 1)) - matrix  
    30.     #diffMat = array(diffMat)  
    31.     sqDiffMat = diffMat ** 2  
    32.     sqDistances = sqDiffMat.sum(axis=1)  
    33.     distances = sqDistances ** 0.5  
    34.     sortedDistanceIndex = distances.argsort()  
    35.     classCount = {}  
    36.     for i in range(k):  
    37.         voteLabel = labels[sortedDistanceIndex[i]]  
    38.         classCount[voteLabel] = classCount.get(voteLabel,0) + 1  
    39.     sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1),reverse=True)  
    40.     return sortedClassCount[0][0]  
    41.       
    42. def classifyFiles(trainMatrix, trainLabels, testMatrix, testLabels, K):  
    43.     rightCnt = 0  
    44.     for i in range(len(testMatrix)):  
    45.         if i % 100 == 0:  
    46.             print 'num '+str(i)+'. ratio: '+ str(float(rightCnt)/(i+1))  
    47.         label = testLabels[i]  
    48.         predictLabel = classify(testMatrix[i], trainMatrix, trainLabels, K)  
    49.         if label == predictLabel:  
    50.             rightCnt += 1  
    51.     return float(rightCnt)/len(testMatrix)  
    52.   
    53. trainFile = 'train_60k.txt'  
    54. testFile = 'test_10k.txt'  
    55. trainMatrix, trainLabels = file2Mat(trainFile)  
    56. testMatrix, testLabels = file2Mat(testFile)  
    57. K = 10  
    58. rightRatio = classifyFiles(trainMatrix, trainLabels, testMatrix, testLabels, K)  
    59. print 'classify right ratio:' +str(right)  
  • 相关阅读:
    shell变量/环境变量和set/env/export用法_转
    常用英语短语累积
    可执行文件格式elf和bin
    spring boot 配置文件application
    (转)Linux命令grep
    plsql 数据迁移——导出表结构,表数据,表序号
    (转)logback 打印Mybitis中的sql执行过程
    (转)PLSQL Developer导入Excel数据
    Linux时间设置
    (转)Oracle中的rownum,ROWID的 用法
  • 原文地址:https://www.cnblogs.com/wt869054461/p/5030862.html
Copyright © 2020-2023  润新知