本文代码均来自《机器学习实战》
这里讲了两个例子,datingclass 和 figureclass,用到的都是KNN,要调用这两个例子的话就在代码末尾加datingClassTest()
和handwritingClassTest()
至于第二个例子中用到的图片,是指那种字符点阵的图片,但是对于同样的原理,灰度图片应该也是可以的,虽然准确率就不一定了吧
图片长这个样子:
0_0.txt
00000000000001111000000000000000
00000000000011111110000000000000
00000000001111111111000000000000
00000001111111111111100000000000
00000001111111011111100000000000
00000011111110000011110000000000
00000011111110000000111000000000
00000011111110000000111100000000
00000011111110000000011100000000
00000011111110000000011100000000
00000011111100000000011110000000
00000011111100000000001110000000
00000011111100000000001110000000
00000001111110000000000111000000
00000001111110000000000111000000
00000001111110000000000111000000
00000001111110000000000111000000
00000011111110000000001111000000
00000011110110000000001111000000
00000011110000000000011110000000
00000001111000000000001111000000
00000001111000000000011111000000
00000001111000000000111110000000
00000001111000000001111100000000
00000000111000000111111000000000
00000000111100011111110000000000
00000000111111111111110000000000
00000000011111111111110000000000
00000000011111111111100000000000
00000000001111111110000000000000
00000000000111110000000000000000
00000000000011000000000000000000
'''
Created on Sep 16, 2010
kNN: k Nearest Neighbors
Input: inX: vector to compare to existing dataset (1xN)
dataSet: size m data set of known vectors (NxM)
labels: data set labels (1xM vector)
k: number of neighbors to use for comparison (should be an odd number)
Output: the most popular class label
@author: pbharrin
'''
from numpy import *
#NumPy是Python语言的一个扩展程序库。支持高端大量的维度数组与矩阵运算,此外也针对数组运算提供大量的数学函数库。
import pdb
pdb.set_trace()#用于调试
import operator#operator 模块是 Python 中内置的操作符函数接口,它定义了算术,比较和与标准对象 API 相对应的其他操作的内置函数。
#operator 模块是用 C 实现的,所以执行速度比 Python 代码快。
from os import listdir#os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。这个列表以字母顺序。 它不包括 '.' 和'..' 即使它在文件夹中。
def classify0(inX, dataSet, labels, k):
#这个方法每次只能处理一个样本
#这里的dataSet是一个数组,inX是待分类的样本,K是neighbor的数量
#inX是以行向量的方式储存的,dataSet也是一行表示一个样本
#KNN算法几乎不需要“训练”,属于即开即用那种的
dataSetSize = dataSet.shape[0]#这是样本个数
diffMat = tile(inX, (dataSetSize,1)) - dataSet#ile()函数内括号中的参数代表扩展后的维度,而扩展是通过复制A来运作的,最终得到一个与括号内的参数(reps)维度一致的数组(矩阵)
#将inX复制为和样本一样多的行数
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)#sum对array求和,如果参数是0,就按列求和,返回一个行向量;如果参数是1,就按行求和,但是也返回一个行向量(从计算的角度来看,是列向量转置之后的)
distances = sqDistances**0.5#**是python中的幂运算,用在矩阵上的效果的对应位置相乘而不是矩阵乘法中的A*A
##现在distances中的每一个元素代表了待求目标点和每一个样本点之间的距离
sortedDistIndicies = distances.argsort() #argsort是numpy的方法,从小到大排序(不加参数的话),返回的是index而不是排序后的元素本身
classCount={}#这是个字典类型
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1#给这个类型加一
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)#选出k中数量最大的label
return sortedClassCount[0][0]##输出最大的label
def createDataSet():
group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
#[[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]]是list类型的二维向量,转成array可以方便进行向量化计算(array是numpy封装的)
labels = ['A','A','B','B']
return group, labels
def file2matrix(filename):#将文件数据(data,label)转换为矩阵
fr = open(filename)
numberOfLines = len(fr.readlines()) #get the number of lines (行数)in the file
returnMat = zeros((numberOfLines,3)) #prepare matrix to return,这里将矩阵的列数硬编码为3了,需要的时候可以改
classLabelVector = [] #prepare labels return
index = 0
fr = open(filename)#为啥这里要再读一次呢?因为上面的fr.readlines()为了获取数据行数已经把全文读完了
for line in fr.readlines():
line = line.strip()#移除字符串头尾指定的字符(默认为空格或换行符)或字符序列
listFromLine = line.split(' ')
returnMat[index,:] = listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat,classLabelVector
def autoNorm(dataSet):#归一化,使用公式为 newValue=(oldValue-min)/(max-min)
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals, (m,1))
normDataSet = normDataSet/tile(ranges, (m,1)) #element wise divide
return normDataSet, ranges, minVals
#第一个KNN例子,classify date
def datingClassTest():
hoRatio = 0.50 #hold out 10%
datingDataMat,datingLabels = file2matrix('datingTestSet2.txt') #load data setfrom file
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m*hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
if (classifierResult != datingLabels[i]): errorCount += 1.0
print("the total error rate is: %f" % (errorCount/float(numTestVecs)))
print(errorCount)
def img2vector(filename):
#将路径中文件转换为行向量进行存储,说到底干的就是一个char转int的活,
returnVect = zeros((1,1024))#行向量,这里不好的一点就是特征数也是写死的,要实现泛用需要修改
fr = open(filename)
for i in range(32):
lineStr = fr.readline()#读一行
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect
#第二个例子,归类由字符串组成的数字
def handwritingClassTest():
hwLabels = []#存储所有样本的label
trainingFileList = listdir('trainingDigits') #load the training set,返回的是一个字符串数组,里面是该文件夹中所有文件的名称
print(trainingFileList)
m = len(trainingFileList)#m代表训练集样本个数
trainingMat = zeros((m,1024))#1024是特征个数
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0] #take off .txt
classNumStr = int(fileStr.split('_')[0])
#为什么要分这个?因为这里的样本比较特殊,文件名的第一个数组就代表了label
hwLabels.append(classNumStr)
trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)#将路径输入,返回转换好的矩阵
testFileList = listdir('testDigits') #iterate through the test set
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):#对验证集一个一个进行运算,虽然这种for比较慢吧`````
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0] #take off .txt
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
if (classifierResult != classNumStr): errorCount += 1.0
print("
the total number of errors is: %d" % errorCount)
print("
the total error rate is: %f" % (errorCount/float(mTest)))