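I have just started working through "Machine Learning in Action", and the first algorithm the book introduces is the k-nearest neighbors (k-NN) algorithm.
Machine learning algorithms can be divided into supervised and unsupervised learning. Supervised learning solves two kinds of problems, classification and regression. Unsupervised learning has no target values or class labels, so it clusters the data set instead. I don't know much about unsupervised learning yet; the neural networks I used in an earlier research project did not involve any unsupervised methods either.
A classification problem assigns a given data instance to the appropriate class; a regression problem predicts a numeric value. The simplest regression problem is probably the one-variable linear regression (a least-squares line fit) done in physics lab courses.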
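For classification, k-NN is a simple and effective algorithm built on a very simple idea. Suppose we have a labeled sample set S in which each sample si belongs to a class cj, and the set of classes C is finite. A new data point d is then classified as follows:
- Compute the Euclidean distance between d and every sample si in S;
- Sort all the distances in ascending order;
- Take the k nearest samples s1 to sk, whose classes are c1 to ck;
- The most frequent class among c1 to ck is the class of d.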
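Written as formulas, the steps above look like this. The notation is my own: $n$ is the number of features, $d_m$ and $s_{i,m}$ are the $m$-th features of $d$ and $s_i$, $N_k(d)$ is the set of the $k$ samples closest to $d$, and $c(s_i)$ is the class of $s_i$.

$$
\operatorname{dist}(d, s_i) = \sqrt{\sum_{m=1}^{n} \left(d_m - s_{i,m}\right)^2},
\qquad
\hat{c}(d) = \mathop{\arg\max}_{c_j \in C} \sum_{s_i \in N_k(d)} \mathbf{1}\left[c(s_i) = c_j\right]
$$

Because the square root is monotonic, ranking by the squared distance yields the same k neighbors; the square root only matters if the actual distance values are needed.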
The algorithm is just as easy to implement. Here is a Python version:
```python
# knn.py (the file name is my assumption; the later examples import classify0 from it)
# Python 3 port of the original Python 2 script
import operator

import numpy as np


def create_data_set():
    group = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


def classify0(inX, data_set, labels, k):
    data_set_size = data_set.shape[0]  # number of samples
    # tile repeats inX data_set_size times along the rows (once along the columns),
    # so the query can be subtracted from every sample at once
    diff_mat = np.tile(inX, (data_set_size, 1)) - data_set
    diff_mat2 = diff_mat ** 2
    distances2 = diff_mat2.sum(axis=1)
    distances = distances2 ** 0.5  # Euclidean distance to every sample
    sorted_dist_index = distances.argsort()  # sample indices, nearest first
    class_count = {}
    for i in range(k):  # each of the k nearest samples votes for its label
        vote_ilabel = labels[sorted_dist_index[i]]
        class_count[vote_ilabel] = class_count.get(vote_ilabel, 0) + 1
    sorted_class_count = sorted(class_count.items(),
                                key=operator.itemgetter(1), reverse=True)
    return sorted_class_count[0][0]  # label with the most votes


def test_classify0():
    group, labels = create_data_set()
    res1 = classify0([0, 0], group, labels, 3)
    print('Classification result:', res1)


if __name__ == '__main__':
    test_classify0()
```
classify0 is the k-NN implementation itself; it relies on the numpy package for the array arithmetic.
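The np.tile call in classify0 only exists to line inX up with every row of data_set. NumPy broadcasting does the same thing implicitly, so the distance step could also be written as below. This is a sketch of an alternative, not the book's code, and the name euclidean_distances is my own.

```python
import numpy as np


def euclidean_distances(inX, data_set):
    """Distance from one query vector to every row of data_set.

    Broadcasting stretches the query across all rows, which is what
    np.tile(inX, (data_set.shape[0], 1)) does explicitly in classify0.
    """
    diff = np.asarray(inX, dtype=float) - data_set  # shape (n_samples, n_features)
    return np.sqrt((diff ** 2).sum(axis=1))


group = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
tiled = np.sqrt(((np.tile([0, 0], (4, 1)) - group) ** 2).sum(axis=1))
print(np.allclose(euclidean_distances([0, 0], group), tiled))  # True
```

Broadcasting avoids materializing the repeated copy of the query as a separate named array, and it also works when the query has shape (1, n_features), as in the handwriting example.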
[Figure: test results]
I also reworked the examples given in "Machine Learning in Action". They are largely the same; most of the effort goes into importing the external data :)
Dating site example
```python
import numpy as np

from knn import classify0  # assumes the first script was saved as knn.py (hypothetical name)


def file2matrix(filename):
    with open(filename) as fr:
        read_lines = fr.readlines()
    sample_count = len(read_lines)
    print('%d lines in "%s"' % (sample_count, filename))
    sample_matrix = np.zeros((sample_count, 3))  # 3 features per sample
    label_vector = []
    for isample, line in enumerate(read_lines):
        # split() with no argument splits on any whitespace (tabs or spaces)
        one_sample_list = line.strip().split()
        sample_matrix[isample, :] = [float(item) for item in one_sample_list[0:3]]
        label_vector.append(int(one_sample_list[-1]))  # the last value on each line is the class label
    return sample_matrix, label_vector
```
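Since most of the work is parsing, NumPy's loadtxt can replace the manual loop, provided every line holds three numeric features followed by an integer label separated by whitespace (the same layout file2matrix assumes). A sketch under that assumption; file2matrix_loadtxt is a name I made up.

```python
import numpy as np


def file2matrix_loadtxt(filename):
    # loadtxt splits on whitespace by default and returns a float array;
    # ndmin=2 keeps the result 2-D even for a one-line file
    data = np.loadtxt(filename, ndmin=2)
    sample_matrix = data[:, 0:3]  # first three columns are the features
    label_vector = data[:, -1].astype(int).tolist()  # last column is the class label
    return sample_matrix, label_vector
```

The trade-off is that loadtxt reports a malformed line with its own error message instead of the loop failing on a specific index, and it does not print the line count.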
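```python
def auto_normalize(data_set):
    # min-max normalization: rescale every feature column to [0, 1]
    min_val = data_set.min(0)  # per-column minimum
    max_val = data_set.max(0)  # per-column maximum
    ranges = max_val - min_val
    m = data_set.shape[0]
    norm_set = data_set - np.tile(min_val, (m, 1))
    norm_set = norm_set / np.tile(ranges, (m, 1))
    return norm_set, ranges, min_val
```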
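auto_normalize applies min-max scaling, new = (old - min) / (max - min), to each column, so that a feature measured in large units does not dominate the Euclidean distance. The same computation with broadcasting instead of tile could look like this; the name min_max_normalize is hypothetical, and the guard for constant columns (where max - min = 0 would otherwise divide by zero) is my own addition.

```python
import numpy as np


def min_max_normalize(data_set):
    # (x - min) / (max - min), computed column by column via broadcasting
    min_val = data_set.min(axis=0)
    ranges = data_set.max(axis=0) - min_val
    # constant column: divide by 1 so it becomes all zeros instead of NaN
    safe_ranges = np.where(ranges == 0, 1.0, ranges)
    return (data_set - min_val) / safe_ranges, ranges, min_val


data = np.array([[10.0, 1.0, 5.0],
                 [20.0, 3.0, 5.0],
                 [30.0, 2.0, 5.0]])
norm, ranges, min_val = min_max_normalize(data)
print(norm)
```

Returning min_val and ranges matters: a new sample has to be scaled with the training set's values, (new - min_val) / ranges, before it is passed to classify0.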
```python
def test_dating_classify():
    dating_matrix, dating_label = file2matrix('datingTestSet2.txt')
    # Optional scatter plot of the first two features, sized and colored by class:
    # import matplotlib.pyplot as pyplot
    # fig = pyplot.figure()
    # ax = fig.add_subplot(111)
    # ax.scatter(dating_matrix[:, 0], dating_matrix[:, 1], 15.0 * np.array(dating_label), 15.0 * np.array(dating_label))
    # pyplot.show()
    norm_matrix, _, _ = auto_normalize(dating_matrix)
    verify_ratio = 0.1  # hold out the first 10% of the samples for testing
    samples_count = norm_matrix.shape[0]
    verify_count = int(verify_ratio * samples_count)
    error_count = 0.0
    for i in range(verify_count):
        # classify each held-out sample against the remaining 90%, with k = 9
        classify_result = classify0(norm_matrix[i, :],
                                    norm_matrix[verify_count:samples_count, :],
                                    dating_label[verify_count:samples_count], 9)
        print('classifier returned %d, true class is %d' % (classify_result, dating_label[i]))
        if classify_result != dating_label[i]:
            error_count += 1
    print('error rate: %.2f' % (error_count / float(verify_count)))


if __name__ == '__main__':
    test_dating_classify()
```
[Figure: test results]
Handwriting recognition example
```python
import os

import numpy as np

from knn import classify0  # assumes the first script was saved as knn.py (hypothetical name)


def img2vector(filename):
    # flatten a 32x32 text image of '0'/'1' characters into a 1x1024 vector
    img_vector = np.zeros((1, 1024))
    with open(filename) as fr:
        for i in range(32):
            line_str = fr.readline().strip()
            for j in range(32):
                img_vector[0, 32 * i + j] = int(line_str[j])
    return img_vector


def test_handwritting_classify():
    handwritting_labels = []
    training_files = os.listdir('trainingDigits')
    samples_count = len(training_files)
    training_matrix = np.zeros((samples_count, 1024))
    # construct the training matrix: one row per image
    for i in range(samples_count):
        file_name_str = training_files[i]
        file_str = file_name_str.split('.')[0]
        label_str = int(file_str.split('_')[0])  # the digit before '_' in the file name is the label
        handwritting_labels.append(label_str)
        training_matrix[i, :] = img2vector(os.path.join('trainingDigits', file_name_str))
    # test
    test_files = os.listdir('testDigits')
    error_count = 0
    tests_count = len(test_files)
    for i in range(tests_count):
        file_name_str = test_files[i]
        file_str = file_name_str.split('.')[0]
        label_str = int(file_str.split('_')[0])
        vector_under_test = img2vector(os.path.join('testDigits', file_name_str))
        classify_result = classify0(vector_under_test, training_matrix, handwritting_labels, 3)
        print('recognized as %d, actually %d' % (classify_result, label_str))
        if classify_result != label_str:
            error_count += 1
    print('total errors: %d, error rate: %.2f' % (error_count, error_count / float(tests_count)))


if __name__ == '__main__':
    test_handwritting_classify()
```
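img2vector fills the vector one character at a time. Under the same assumption the loop already makes (the first 32 lines each contain at least 32 characters, all '0' or '1'), the conversion can be done with a single array operation. A sketch; img2vector_np is a hypothetical name.

```python
import numpy as np


def img2vector_np(filename, size=32):
    with open(filename) as fr:
        rows = [fr.readline().strip()[:size] for _ in range(size)]
    # the ASCII codes of '0' and '1' are 48 and 49, so subtracting ord('0')
    # turns each character into 0 or 1; assumes no other characters appear
    text = ''.join(rows)  # row-major order, matching index 32 * i + j
    digits = np.frombuffer(text.encode('ascii'), dtype=np.uint8) - ord('0')
    return digits.astype(float).reshape(1, size * size)
```

If a line is shorter than 32 characters, the reshape raises a ValueError rather than the IndexError the original loop would raise.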
[Figure: test results]
Summary
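In principle, k-NN is precise and effective, and it matches how people naturally classify things: put simply, a sample belongs to whichever class it is closest to.
In practice, however, k-NN is very slow (high time complexity) and needs a lot of storage (high space complexity): every query computes its distance to every training sample, and the entire training set has to be kept around.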
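Rough costs, in my own notation ($N$ training samples with $n$ features each; $n = 1024$ in the handwriting example). k-NN has no training step, so all the work happens at query time, and classify0 scans the whole training set for each query:

$$
\underbrace{O(Nn)}_{\text{distances}} + \underbrace{O(N \log N)}_{\text{argsort}} + \underbrace{O(k)}_{\text{vote}} \quad \text{time per query},
\qquad
O(Nn) \quad \text{memory for the stored training set}
$$

Testing $T$ samples therefore costs $O(TNn)$ for the distances alone. Only the $k$ smallest distances are needed, so the full sort is not strictly necessary (NumPy's argpartition selects them without sorting everything), but the $O(Nn)$ distance pass per query remains.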