最近开始啃《机器学习实战》,把其中第二章的K近邻算法的代码笔记整理如下。
kNN分类算法:
对未知类别属性的数据集中的每个点依次执行以下操作:
1.计算已知类别数据集中的点与当前点的距离;
2.按照距离递增次序排序;
3.选择与当前点距离最小的k个点;
4.确定前k个点所在类别的出现频率;
5.返回前k个点出现频率最高的类别作为当前点的预测分类;
具体实现代码如下:
import numpy as np import operator import matplotlib.pyplot as plt def creatDataSet(): group = np.array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]]) labels = ['A','A','B','B'] return group, labels def classify0(inX, dataSet, labels, k): ''' :param inX: 用于分类的输入向量 :param dataSet: 输入的训练样本集 :param labels: 标签向量 :param k: 选择最近邻的数目 :return: 发生频率最高的元素标签 ''' # 训练样本的第一维度数量 dataSetSize = dataSet.shape[0] # Construct an array by repeating A the number of times given by reps # tile函数根据第二参数reps确定复制几次第一个参数(数组)的元素并添加到 # 第一个参数(数组)中,返回新的数组 ''' 例如 In [10]: arr = np.array([1,2,3,4]) In [11]: max_num = 2 In [12]: arr_new = np.tile(arr,(1,max_num)) In [13]: arr_new Out[13]: array([[1, 2, 3, 4, 1, 2, 3, 4]]) In [14]: arr_new = np.tile(arr,(max_num,3)) In [15]: arr_new Out[15]: array([[1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4], [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]]) ''' # k近邻法的算法核心就是将输入的的向量与训练样本的向量的 # 差的平方和开根得到两个向量点之间的距离 # d = √ ̄(xA0 - xB0)^2 + (xA1 - xB1)^2 # 得到新的数组后减去dataSet,相当于对将输入的数据与样本数据进行差值计算 diffMat = np.tile(inX, (dataSetSize,1)) - dataSet # 距离差的平方 sqDiffMat = diffMat**2 # 差的平方求和 sqDistances = sqDiffMat.sum(axis=1) # 开根得到距离值 distances = sqDistances**0.5 # 对结果进行排序,返回排序后的索引组成的数组 ''' argsort方法实例: In [32]: arr = np.array([9,3,12,7,4]) In [33]: arr.argsort() Out[33]: array([1, 4, 3, 0, 2]) ''' sortedDistIndicies = distances.argsort() classCount={} # 通过循环将标签 for i in range(k): # 循环取出已排序的索引数组的索引值,通过索引值取出labels的标签 voteIlabel = labels[sortedDistIndicies[i]] # 以标签值作为classCount的键,如果claaCount已经存在该标签键的值,就取出加1 classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 # 使用sorted返回一个经过排序的列表,其中key=operator.itemgetter: # key这个参数为指定一个接收一个参数的函数,这个函数用于从每个元素中提取一个用于比较的关键字, # operator.itemgetter返回一个函数,用于获取对象的哪些维的数据,参数为一些序号 ''' operator.itemgette用法: Out[47]: ce = operator.itemgetter(2) In [48]: d = [4,5,6] In [49]: ce(d) Out[49]: 6 ''' # 生成一个新的排序列表 sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) # 返回数据 return sortedClassCount[0][0] if __name__ == '__main__': group, labels = creatDataSet() # plt.figure() # for i in range(len(group)): # plt.scatter(group[i][0],group[i][1]) # plt.show() x = classify0([0.3,0.3],group,labels,3) print(x)
输出结果:B
因为 [0.3,0.3] 在与样本数据中B类数据更为接近。
如果输入数据换为[1.2,1.2],输出结果为A。
总结起来就是:距离哪个类近,就输出哪个类。