• 《机器学习实战》第2章——k近邻算法(笔记)


    一、KNN(k-近邻算法)工作原理

    一句话:从训练集中找出k个最接近测试数据的训练样本,再从这k个样本中找出出现次数最多的分类,作为测试数据的分类。

    存在一个样本数据集合(训练样本集)且样本集中每个数据都存在标签(即我们知道样本集中每一数据与所属分类的对应关系);

    输入没有标签的数据后,将该数据的每个特征与样本集中数据对应的特征进行比较,并提取样本集中特征最相似数据(最近邻)的分类标签;

    选择样本集中前k个最相似的数据(这就是k-近邻算法中k的出处,通常k是不大于20的整数);

    选择k个最相似数据中出现次数最多的分类,作为新数据的分类。

    二、优缺点

    优点:

    1、理论成熟、思想简单、易理解和实现;

    2、可用于分类(包括非线性分类)、回归;

    3、时间和空间开销与训练集的规模呈线性关系(训练时间复杂度为O(n),比支持向量机之类的算法低);

    4、适合单标签多分类和多标签分类问题;

    5、对于类域的交叉或重叠较多的待分类样本集较为适合;

    缺点:

    1、计算量大(尤其是特征数比较多的时候)

    2、对不平衡数据集(数据集中各个类别的样本量极不均衡)效果差(可采用加权投票法改进)

    3、k值的选择对分类效果有很大影响(较小的话对噪声敏感,需估计最佳k值)

    4、可解释性不强

    三、代码

    代码来自于《机器学习实战》,添加了一些小注释和一些测试代码

      1 #encoding:utf-8
      2 '''
      3 Created on Sep 16, 2010
      4 kNN: k Nearest Neighbors
      5 
      6 Input:      inX: vector to compare to existing dataset (1xN)
      7             dataSet: size m data set of known vectors (NxM)
      8             labels: data set labels (1xM vector)
      9             k: number of neighbors to use for comparison (should be an odd number)
     10             
     11 Output:     the most popular class label
     12 
     13 @author: pbharrin
     14 '''
     15 from numpy import *
     16 import operator
     17 from os import listdir
     18 
     def classify0(inX, dataSet, labels, k):
         """Classify ``inX`` by majority vote among its k nearest training rows.

         Distance is the Euclidean distance between ``inX`` and each row of
         ``dataSet``.

         Args:
             inX: the input vector to classify.
             dataSet: training samples, one sample per row.
             labels: label vector; ``labels[i]`` is the class of row ``i``.
             k: number of nearest neighbours that vote. Values larger than the
                 number of training rows are clamped to that number.

         Returns:
             The label with the most votes among the k nearest rows. Ties go to
             the label met first while walking the neighbours from nearest to
             farthest (relies on dict insertion order, Python 3.7+).

         Raises:
             ValueError: if ``k`` is smaller than 1 or ``dataSet`` has no rows.
         """
         dataSet = asarray(dataSet)
         dataSetSize = dataSet.shape[0]  # number of training samples
         if k < 1 or dataSetSize == 0:
             raise ValueError("classify0 needs k >= 1 and a non-empty dataSet")
         if k > dataSetSize:
             # Asking for more neighbours than exist used to raise IndexError.
             k = dataSetSize
         # Repeat inX once per training row so it can be subtracted row-wise.
         diffMat = tile(inX, (dataSetSize, 1)) - dataSet
         sqDistances = (diffMat ** 2).sum(axis=1)  # squared distance per row
         distances = sqDistances ** 0.5
         sortedDistIndices = distances.argsort()  # row indices, nearest first
         classCount = {}  # label -> number of votes
         for i in range(k):
             voteIlabel = labels[sortedDistIndices[i]]
             classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
         # sorted() is used instead of max() because the file star-imports numpy,
         # which may shadow the builtin max.
         sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
         return sortedClassCount[0][0]
     35 
     def createDataSet():
         """Return a small hard-coded sample data set.

         Returns:
             (group, labels): a 4x2 array of points and a list with one class
             label ('A' or 'B') per row of ``group``.
         """
         group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
         labels = ['A', 'A', 'B', 'B']
         return group, labels
     40 
     def file2matrix(filename):
         """Load a tab-separated data file into a feature matrix and a label list.

         Every non-blank line is split on tabs. Its first three fields become
         one row of the feature matrix and its last field is parsed as an
         integer class label.

         Args:
             filename: path of the text file to read.

         Returns:
             (returnMat, classLabelVector): an (n, 3) float array and a list of
             n ints, where n is the number of non-blank lines in the file.
         """
         # Read the file once inside a context manager so the handle is always
         # closed. Blank lines are skipped instead of crashing the parser.
         with open(filename) as fr:
             rows = [line.strip().split('\t') for line in fr if line.strip()]
         returnMat = zeros((len(rows), 3))
         classLabelVector = []
         for index, listFromLine in enumerate(rows):
             returnMat[index, :] = [float(value) for value in listFromLine[0:3]]
             classLabelVector.append(int(listFromLine[-1]))
         return returnMat, classLabelVector
     56     
     def autoNorm(dataSet):
         """Min-max scale every column of ``dataSet`` into the range [0, 1].

         Args:
             dataSet: numeric array with one sample per row.

         Returns:
             (normDataSet, ranges, minVals): the scaled data, the per-column
             ``max - min`` and the per-column minimum. ``ranges`` is returned
             exactly as computed, zeros included.
         """
         minVals = dataSet.min(0)  # per-column minimum (axis 0 runs down the rows)
         maxVals = dataSet.max(0)  # per-column maximum
         ranges = maxVals - minVals
         # A constant column has range 0. Divide it by 1 instead so it scales to
         # 0 rather than NaN.
         safeRanges = where(ranges == 0, 1, ranges)
         # Broadcasting applies the per-column vectors to every row.
         normDataSet = (dataSet - minVals) / safeRanges
         return normDataSet, ranges, minVals
     66    
     def datingClassTest(hoRatio=0.10, filename='datingTestSet2.txt', k=3):
         """Run a hold-out test of classify0 on a data file and print the results.

         The file is loaded with file2matrix and scaled with autoNorm. The first
         ``int(m * hoRatio)`` rows are used as test vectors and the remaining
         rows as the training set. One line is printed per test vector, followed
         by the error rate and the raw error count.

         Args:
             hoRatio: fraction of the rows held out for testing.
             filename: data file passed to file2matrix.
             k: number of neighbours passed to classify0.
         """
         datingDataMat, datingLabels = file2matrix(filename)
         normMat, _ranges, _minVals = autoNorm(datingDataMat)
         m = normMat.shape[0]
         numTestVecs = int(m * hoRatio)
         errorCount = 0.0
         # The training slice is the same for every test vector, so build it once.
         trainingMat = normMat[numTestVecs:m, :]
         trainingLabels = datingLabels[numTestVecs:m]
         for i in range(numTestVecs):
             classifierResult = classify0(normMat[i, :], trainingMat, trainingLabels, k)
             print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
             if classifierResult != datingLabels[i]:
                 errorCount += 1.0
         # With no test vectors there is nothing to divide by; report 0.0 instead
         # of raising ZeroDivisionError.
         errorRate = errorCount / numTestVecs if numTestVecs else 0.0
         print("the total error rate is: %f" % errorRate)
         print(errorCount)
     80     
     def img2vector(filename):
         """Read a 32x32 grid of digit characters into a 1x1024 row vector.

         The first 32 characters of each of the first 32 lines are converted
         with ``int`` and stored row after row.

         Args:
             filename: path of the text file to read.

         Returns:
             A (1, 1024) float array; element ``32*i + j`` holds the character at
             line ``i``, column ``j``.
         """
         returnVect = zeros((1, 1024))
         # The context manager closes the file even if a line is malformed.
         with open(filename) as fr:
             for i in range(32):
                 lineStr = fr.readline()
                 for j in range(32):
                     returnVect[0, 32 * i + j] = int(lineStr[j])
         return returnVect
     89 
     def handwritingClassTest(trainingDir='trainingDigits', testDir='testDigits', k=3):
         """Classify every file in a test directory with classify0 and print the results.

         Each file in ``trainingDir`` is loaded with img2vector to build the
         training matrix. Each file in ``testDir`` is then classified. The part
         of a file name before the first '_' is parsed as its integer label.
         One line is printed per test file, followed by the error count and the
         error rate.

         Args:
             trainingDir: directory holding the training files.
             testDir: directory holding the test files.
             k: number of neighbours passed to classify0.
         """
         hwLabels = []
         trainingFileList = listdir(trainingDir)  # load the training set
         m = len(trainingFileList)
         trainingMat = zeros((m, 1024))
         for i, fileNameStr in enumerate(trainingFileList):
             fileStr = fileNameStr.split('.')[0]  # take off the extension
             hwLabels.append(int(fileStr.split('_')[0]))
             trainingMat[i, :] = img2vector('%s/%s' % (trainingDir, fileNameStr))
         testFileList = listdir(testDir)  # iterate through the test set
         errorCount = 0.0
         mTest = len(testFileList)
         for fileNameStr in testFileList:
             fileStr = fileNameStr.split('.')[0]  # take off the extension
             classNumStr = int(fileStr.split('_')[0])
             vectorUnderTest = img2vector('%s/%s' % (testDir, fileNameStr))
             classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
             print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
             if classifierResult != classNumStr:
                 errorCount += 1.0
         print("\nthe total number of errors is: %d" % errorCount)
         # An empty test directory would otherwise raise ZeroDivisionError.
         errorRate = errorCount / mTest if mTest else 0.0
         print("\nthe total error rate is: %f" % errorRate)
    114 
    115 
    if __name__ == '__main__':
        # Scratch experiments with the numpy operations used by the functions
        # above, followed by the two real classification tests.

        # Row-by-row assignment into a pre-allocated matrix.
        trainingMat = zeros((10, 4))
        print(trainingMat)
        m = len(trainingMat)
        for i in range(m):
            trainingMat[i, :] = [1, 2, 3, 4]
            print('-----------------')
            print(trainingMat)

        # Element-wise square, then sums along each axis.
        a = array([[3, 2, 3, 4], [3, 4, 5, 6]])
        a = a ** 2
        print(a)
        print('-----------')
        sqDistances = a.sum(axis=0)  # sum down each column
        print(sqDistances)
        sqDistances2 = a.sum(axis=1)  # sum across each row
        print(sqDistances2)
        distances = sqDistances ** 0.5
        print(distances)
        print(distances.argsort())
        print(distances[0])

        # argsort returns the indices that would sort the array ascending.
        x = array([1, 4, 3, -1, 6, 9])
        print(x.argsort()[-1])
        y = x.argsort()
        print(x)
        print(y)
        print(x[y[0]])

        # Sorting dict items by value, largest first, as classify0 does.
        classCount = {0: 3, 5: 2, 4: 6}
        print(classCount)
        sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
        print(sortedClassCount)
        print(sortedClassCount[0][0])

        handwritingClassTest()  # handwritten-digit classification test
        datingClassTest()  # dating-data classification test

    四、绘制里程数与玩视频游戏所占比例的数据散点图

    代码如下:

    '''
    Created on Oct 6, 2010
    
    @author: Peter
    '''
    from numpy import *
    import matplotlib
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
    
    
    # Generate n random (miles, game-time) points in three classes and show
    # them as a scatter plot.
    n = 1000  # number of points to create
    xcord1 = []; ycord1 = []  # class 1, 'didntLike'
    xcord2 = []; ycord2 = []  # class 2, 'smallDoses'
    xcord3 = []; ycord3 = []  # class 3, 'largeDoses'

    # NOTE(review): the original opened 'testSet.txt' for writing, never wrote
    # anything to it and closed it after the loop. The net effect is that the
    # file is created or truncated to empty. That side effect is kept as-is;
    # confirm whether the samples were meant to be written here.
    with open('testSet.txt', 'w'):
        pass

    for _ in range(n):
        r0, r1 = random.standard_normal(2)  # two standard-normal draws
        myClass = random.uniform(0, 1)  # uniform draw that selects the cluster
        if myClass <= 0.16:
            # class 1: miles drawn uniformly, game time centred on 3
            fFlyer = random.uniform(22000, 60000)
            tats = 3 + 1.6*r1
            xcord, ycord = xcord1, ycord1
        elif myClass <= 0.33:
            # class 1 again: miles centred on 70000, game time centred on 10
            fFlyer = 6000*r0 + 70000
            tats = 10 + 3*r1 + 2*r0
            xcord, ycord = xcord1, ycord1
        elif myClass <= 0.66:
            # class 2: miles centred on 10000, game time centred on 3
            fFlyer = 5000*r0 + 10000
            tats = 3 + 2.8*r1
            xcord, ycord = xcord2, ycord2
        else:
            # class 3: miles centred on 35000, game time centred on 10
            fFlyer = 10000*r0 + 35000
            tats = 10 + 2.0*r1
            xcord, ycord = xcord3, ycord3
        # Clamp negative values to 0. The first cluster used to skip this, so
        # its game time could go negative. Plain `if` is used rather than
        # max() because the numpy star-import may shadow the builtin.
        if tats < 0: tats = 0
        if fFlyer < 0: fFlyer = 0
        xcord.append(fFlyer); ycord.append(tats)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    type1 = ax.scatter(xcord1, ycord1, s=20, c='red')
    type2 = ax.scatter(xcord2, ycord2, s=30, c='green')
    type3 = ax.scatter(xcord3, ycord3, s=50, c='blue')
    ax.legend([type1, type2, type3], ["Did Not Like", "Liked in Small Doses", "Liked in Large Doses"], loc=2)
    ax.axis([-5000,100000,-2,25])  # x-axis and y-axis limits
    plt.xlabel('Frequent Flyier Miles Earned Per Year')  # x-axis label
    plt.ylabel('Percentage of Time Spent Playing Video Games')  # y-axis label
    plt.show()

    参考:数据挖掘-各种分类算法的优缺点

  • 相关阅读:
    优化总结文章链接
    帧同步、状态同步
    ecs
    AStarPathFinding
    unity 热更方案对比
    C#数据类型
    JavaScript基础
    CSS中margin和padding的区别
    css选择器
    hadoop中使用shell判断HDFS文件是否存在
  • 原文地址:https://www.cnblogs.com/gwzz/p/13175601.html
Copyright © 2020-2023  润新知