'''
Created on Nov 06, 2017
kNN: k Nearest Neighbors

Input:      inX: vector to compare to existing dataset (1xN)
            dataSet: data set of m known vectors, one per row (MxN)
            labels: data set labels (length-M sequence aligned with dataSet rows)
            k: number of neighbors to use for comparison (should be an odd number)

Output:     the most popular class label

@author: Liu Chuanfeng
'''
import operator
import os
from os import listdir

import numpy as np

try:
    import matplotlib.pyplot as plt
except ImportError:  # plotting is optional; the classifiers only need numpy
    plt = None

# Default locations of the handwritten-digit data. Raw strings are required:
# in a normal literal "\t" in "\traingDigits" / "\testDigits" becomes a TAB.
# NOTE(review): 'traingDigits' looks like a typo of 'trainingDigits'; kept
# as-is because it must match the directory on disk -- verify.
TRAINING_DIGITS_DIR = r'C:\Private\PycharmProjects\Algorithm\kNNdigits\traingDigits'
TEST_DIGITS_DIR = r'C:\Private\PycharmProjects\Algorithm\kNNdigits\testDigits'


def classify0(inX, dataSet, labels, k):
    """Classify inX by majority vote among its k nearest rows of dataSet.

    Distance is Euclidean.

    Args:
        inX: query vector; any shape that broadcasts against one row of
            dataSet, e.g. (n,) or (1, n). Lists are accepted.
        dataSet: (m, n) array of known samples, one per row.
        labels: length-m sequence of labels aligned with dataSet's rows.
        k: number of neighbours that vote; clamped to m when larger.

    Returns:
        The label with the most votes. On a tie, the label whose first vote
        came from the nearer neighbour wins.

    Raises:
        ValueError: if dataSet has no rows or k < 1.
    """
    dataSetSize = dataSet.shape[0]
    if dataSetSize == 0:
        raise ValueError('dataSet must contain at least one sample')
    if k < 1:
        raise ValueError('k must be at least 1, got %r' % (k,))
    k = min(k, dataSetSize)
    # Broadcasting subtracts inX from every row; no need to np.tile it first.
    diffMat = np.asarray(inX) - dataSet
    distances = np.sqrt((diffMat ** 2).sum(axis=1))
    sortedDistIndicies = distances.argsort()
    classCount = {}
    for i in sortedDistIndicies[:k]:
        voteIlabel = labels[i]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # max() returns the first maximal item, so ties resolve in vote order
    # (nearest neighbour first), same as a stable descending sort by count.
    return max(classCount.items(), key=operator.itemgetter(1))[0]


def file2matrix(filename, numFeatures=3):
    """Load a sample file into a feature matrix and a label list.

    Each non-blank line holds whitespace-separated fields (spaces and/or
    tabs). The first numFeatures fields are the features and the last field
    is an integer class label. Blank lines are skipped.

    Args:
        filename: path of the text file to read.
        numFeatures: number of leading feature columns to keep (default 3).

    Returns:
        (returnMat, classLabelVector): a (rows, numFeatures) float array and
        a list of int labels, one per non-blank line.
    """
    with open(filename) as fr:
        rows = [line.split() for line in fr if line.strip()]
    returnMat = np.zeros((len(rows), numFeatures))
    classLabelVector = []
    for index, listFromLine in enumerate(rows):
        returnMat[index, :] = [float(x) for x in listFromLine[0:numFeatures]]
        classLabelVector.append(int(listFromLine[-1]))
    return returnMat, classLabelVector


def autoNorm(dataSet):
    """Min-max scale every column of dataSet into [0, 1].

    Columns have very different value ranges, so without scaling the
    widest-range column would dominate the distance. Scaling gives every
    column the same weight.

    Args:
        dataSet: (m, n) numeric array.

    Returns:
        (normDataSet, ranges, minVals):
        - normDataSet is (dataSet - minVals) / ranges.
        - ranges is the per-column max - min, except that constant columns
          report 1 so dividing by it is safe; those columns normalise to 0.
        - minVals is the per-column minimum.
        Apply (x - minVals) / ranges to scale new samples consistently.
    """
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    # Avoid 0/0 -> NaN for columns whose values are all identical.
    ranges = np.where(ranges == 0, 1, ranges)
    normDataSet = (dataSet - minVals) / ranges
    return normDataSet, ranges, minVals


def datingClassTest(filename='datingTestSet2.txt', hoRatio=0.10, k=3):
    """Hold-out evaluation of classify0 on the dating data set.

    The first hoRatio fraction of rows is classified against the remaining
    rows. Each prediction is printed, followed by the total error rate.
    Min/max scaling is computed over all rows, including the held-out ones.

    Args:
        filename: data file in the format read by file2matrix.
        hoRatio: fraction of rows held out for testing.
        k: neighbour count passed to classify0.

    Returns:
        The error rate as a fraction in [0, 1].

    Raises:
        ValueError: if hoRatio leaves no test rows or no training rows.
    """
    datingDataMat, datingLabels = file2matrix(filename)
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    if not 0 < numTestVecs < m:
        raise ValueError('hoRatio=%r gives %d test and %d training rows; both must be > 0'
                         % (hoRatio, numTestVecs, m - numTestVecs))
    errorCount = 0.0
    for i in range(numTestVecs):
        classifyResult = classify0(normMat[i, :], normMat[numTestVecs:m, :],
                                   datingLabels[numTestVecs:m], k)
        print('theclassifier came back with: %d, the real answer is: %d'
              % (classifyResult, datingLabels[i]))
        if classifyResult != datingLabels[i]:
            errorCount += 1.0
    errorRate = errorCount / float(numTestVecs)
    # Report the total once, after every test vector has been classified.
    print('the total error rate is: %.1f%%' % (errorRate * 100))
    return errorRate


def classifyPerson(filename='datingTestSet2.txt', k=3):
    """Interactively read one person's features and print the predicted class.

    The three features are read from stdin, scaled with the same min/max as
    the data set, and classified with classify0.
    """
    resultList = ['not at all', 'in small doses', 'in large doses']
    percentTats = float(input("percentage of time spent playing video games?"))
    ffMiles = float(input("frequent flier miles earned per year?"))
    iceCream = float(input("liters of ice cream consumed per year?"))
    datingDataMat, datingLabels = file2matrix(filename)
    normMat, ranges, minVals = autoNorm(datingDataMat)
    # NOTE(review): this order (miles, game-time %, ice cream) is assumed to
    # match the data file's column order -- verify against the data set.
    inArr = np.array([ffMiles, percentTats, iceCream])
    classifyResult = classify0((inArr - minVals) / ranges, normMat, datingLabels, k)
    # Presumes labels are 1..3, mapping onto resultList[0..2].
    print("You will probably like this persoon:", resultList[classifyResult - 1])


def plotDatingData(datingDataMat, datingLabels):
    """Scatter-plot feature columns 1 and 2 of the dating data.

    Marker size and colour both scale with the label.

    Raises:
        ImportError: if matplotlib is not installed.
    """
    if plt is None:
        raise ImportError('matplotlib is required for plotDatingData')
    scaled = 15.0 * np.array(datingLabels)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2], scaled, scaled)
    plt.show()


# Handwritten-digit recognition =============================================

def img2vector(filename):
    """Flatten a 32x32 text image into a (1, 1024) row vector.

    Reads the first 32 characters of each of the first 32 lines. Every
    character must be a digit.
    """
    returnVect = np.zeros((1, 1024))
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect


def _labelFromFileName(fileNameStr):
    """Return the integer before the first '_' of a name such as '7_42.txt'."""
    fileName = fileNameStr.split('.')[0]
    return int(fileName.split('_')[0])


def _loadDigitDir(dirPath):
    """Load every entry of dirPath as a 1024-wide image row plus its file-name label."""
    fileList = listdir(dirPath)
    mat = np.zeros((len(fileList), 1024))
    labels = []
    for i, fileNameStr in enumerate(fileList):
        labels.append(_labelFromFileName(fileNameStr))
        mat[i, :] = img2vector(os.path.join(dirPath, fileNameStr))
    return mat, labels


def handwritingClassTest(trainingDir=TRAINING_DIGITS_DIR, testDir=TEST_DIGITS_DIR, k=3):
    """Classify every test digit against the training digits and report errors.

    Each file's label is the number before the first '_' in its name. Each
    prediction is printed, followed by the error count and error rate.

    Args:
        trainingDir: directory of 32x32 text images used as the reference set.
        testDir: directory of 32x32 text images to classify.
        k: neighbour count passed to classify0.

    Returns:
        The error rate as a fraction in [0, 1].

    Raises:
        ValueError: if testDir contains no files.
    """
    trainingMat, hwLabels = _loadDigitDir(trainingDir)
    testMat, testLabels = _loadDigitDir(testDir)
    mTest = len(testLabels)
    if mTest == 0:
        raise ValueError('no test files found in %r' % (testDir,))
    errorCount = 0.0
    for vectorUnderTest, classNumber in zip(testMat, testLabels):
        # Euclidean-distance kNN vote against the whole training set.
        classifyResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
        print('The classifier came back with: %d, the real answer is: %d'
              % (classifyResult, classNumber))
        if classifyResult != classNumber:
            errorCount += 1.0
    errorRate = errorCount / float(mTest)
    print('The total number of errors is: %d' % (errorCount))
    print('The total error rate is: %.1f%%' % (errorRate * 100))
    return errorRate


if __name__ == '__main__':
    # Dating-site hold-out test
    datingClassTest()

    # Dating-site interactive prediction
    classifyPerson()

    # Handwritten-digit recognition test
    handwritingClassTest()
Output:
theclassifier came back with: 3, the real answer is: 3
the total error rate is: 0.0%
theclassifier came back with: 2, the real answer is: 2
the total error rate is: 0.0%
theclassifier came back with: 1, the real answer is: 1
the total error rate is: 0.0%
...
theclassifier came back with: 2, the real answer is: 2
the total error rate is: 4.0%
theclassifier came back with: 1, the real answer is: 1
the total error rate is: 4.0%
theclassifier came back with: 3, the real answer is: 1
the total error rate is: 5.0%
percentage of time spent playing video games?10
frequent flier miles earned per year?10000
liters of ice cream consumed per year?0.5
You will probably like this persoon: in small doses
...
The classifier came back with: 9, the real answer is: 9
The total number of errors is: 27
The total error rate is: 6.8%
Reference:
《机器学习实战》 (Chinese edition of *Machine Learning in Action*, Peter Harrington)