• 【机器学习实战】第二章k-近邻算法的完整代码


    《机器学习实战》第二章--k近邻算法的完整代码如下:

      1 from numpy import *
      2 import operator
      3 import matplotlib
      4 import matplotlib.pyplot as plt
      5 from os import listdir
      6 
      7 
      8 def createDataSet():
      9     group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
     10     labels = ['A', 'A', 'B', 'B']
     11     return group, labels
     12 
     13 
     14 def classify0(inX, dataSet, labels, k):
     15     '''
     16     k-近邻算法
     17     :param inX:用于分类的输入向量
     18     :param dataSet: 输入的训练样本集
     19     :param labels: 标签向量
     20     :param k: 用于选择最近邻居的数目
     21     :return: 返回k个邻居中距离最近且数量最多的类别作为预测类别
     22     '''
     23     dataSetSize = dataSet.shape[0]
     24     diffMat = tile(inX, (dataSetSize, 1)) - dataSet
     25     sqDiffMat = diffMat ** 2
     26     sqDistances = sqDiffMat.sum(axis=1)
     27     distances = sqDistances ** 0.5
     28     # 以上为计算输入向量与已有标签样本的欧式距离
     29     sortedDistIndicies = distances.argsort()  # argsort函数返回的是数组值从小到大的索引值,距离需要从小到大排序
     30     classCount = {}
     31     for i in range(k):
     32         voteIlabel = labels[sortedDistIndicies[i]]
     33         classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
     34         # Python 字典(Dictionary) get() 函数返回指定键的值,如果值不在字典中返回默认值。
     35         # get(voteIlabel,0)表示当能查询到相匹配的字典时,就会显示相应key对应的value,如果不能的话,就会显示后面的这个参数。
     36     sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
     37     # 按照元祖中第2个值的大小降序排序
     38     # python2中的iteritems()方法需改为items()
     39     return sortedClassCount[0][0]
     40 
     41 
     42 def file2matrix(filename):
     43     # 将文本记录转换为NumPy的解析程序
     44     fr = open(filename)
     45     arrayOLines = fr.readlines()
     46     numberOfLines = len(arrayOLines)
     47     print(numberOfLines)
     48     returnMat = zeros((numberOfLines,3))  # 存放3种特征
     49     classLabelVector = []  # 存放标签
     50     index = 0
     51     for line in arrayOLines:
     52         line = line.strip()  # strip() 方法用于移除字符串头尾指定的字符(默认为空格或换行符)或字符序列。
     53         listFromLine = line.split('	')  # split() 通过指定分隔符对字符串进行切片,组成列表
     54         returnMat[index, :] = listFromLine[0:3]  # 将当前列表的前3个值赋予returnMat的当前行
     55         classLabelVector.append(int(listFromLine[-1]))  # 将标签添加到classLabelVector中
     56         index += 1
     57     return returnMat, classLabelVector
     58 
     59 
     60 def autoNum(dataSet):
     61     minVals = dataSet.min(0)  # A.min(0) : 返回A每一列最小值组成的一维数组;
     62     maxVals = dataSet.max(0)  # A.max(0):返回A每一列最大值组成的一维数组;
     63     # https://blog.csdn.net/qq_41800366/article/details/86313052
     64     ranges = maxVals - minVals
     65     normDataSet = zeros(shape(dataSet))
     66     m = dataSet.shape[0]
     67     normDataSet = dataSet - tile(minVals, (m,1))
     68     # tile将minVals的行数乘以m次重复,列数乘以1次重复,每一行都减掉minVals
     69     normDataSet = normDataSet/tile(ranges,(m,1))
     70     # 每一行都除以ranges以是实现数据归一化
     71     return normDataSet,ranges, minVals
     72 
     73 
     74 def datingClassTest():
     75     hoRatio = 0.10  # 测试集比重
     76     m = normMat.shape[0]
     77     numTestVecs = int(m*hoRatio)  # 测试集数量
     78     errorCount = 0.0
     79     for i in range(numTestVecs):
     80         classfierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
     81         print("the classifierResult came back with: %d,the real answer is: %d"%(classfierResult,datingLabels[i]))
     82         if(classfierResult != datingLabels[i]):errorCount += 1.0
     83         print("the total error rate is: %f"%(errorCount/float(numTestVecs)))
     84 
     85 def classifyPerson():
     86     resultList = ['not at all', 'in small doses', 'in large doses']
     87     percentTats = float(input("percentage of time spent playing video games?"))
     88     # 在 Python3.x 中 raw_input( ) 和 input( ) 进行了整合,去除了 raw_input( ),仅保留了 input( ) 函数,
     89     # 其接收任意任性输入,将所有输入默认为字符串处理,并返回字符串类型。
     90     ffMiles = float(input("frequent flier miles earned per year?"))
     91     iceCream = float(input("liters of ice cream consumed per year?"))
     92     inArr = array([ffMiles, percentTats, iceCream])  # 输入测试向量
     93     classifierResult = classify0((inArr-minVals)/ranges,normMat,datingLabels,3)  # 得到分类结果
     94     print("You will probably like this person:",resultList[classifierResult-1])
     95 
     96 
     97 def img2vector(filename):
     98     returnVect = zeros((1, 1024))  # 创建1行1024列的零数组
     99     fr = open(filename)
    100     for i in range(32):
    101         lineStr = fr.readline()
    102         for j in range(32):
    103             returnVect[0, 32*i+j] = int(lineStr[j])
    104     return returnVect
    105 
    106 
    107 def handwritingClassTest():
    108     hwLabels = []  # 存放手写数字的类别
    109     trainingFileList = listdir('../machinelearninginaction/Ch02/digits/trainingDigits')
    110     m = len(trainingFileList)
    111     trainingMat = zeros((m, 1024))  # 创建m行1024列的零数组
    112     for i in range(m):
    113         fileNameStr = trainingFileList[i]
    114         fileStr = fileNameStr.split('.')[0]  # 文件名字
    115         classNumStr = int(fileStr.split('_')[0])  # 数字类别
    116         hwLabels.append(classNumStr)
    117         trainingMat[i,:]= img2vector('../machinelearninginaction/Ch02/digits/trainingDigits/%s'%fileNameStr)
    118     testFileList = listdir('../machinelearninginaction/Ch02/digits/testDigits')
    119     errorCount = 0.0
    120     mTest = len(testFileList)
    121     for i in range(mTest):
    122         fileNameStr = testFileList[i]
    123         fileStr = fileNameStr.split('.')[0]
    124         classNumStr = int(fileStr.split('_')[0])
    125         vectorUnderTest = img2vector('../machinelearninginaction/Ch02/digits/testDigits/%s'%fileNameStr)
    126         classifierResult = classify0(vectorUnderTest, trainingMat,hwLabels,3)  # 对测试数据集进行knn计算
    127         print('the classifier came back with: %d, the real answer is :%d'%(classifierResult,classNumStr))
    128         if(classifierResult != classNumStr):errorCount += 1.0
    129     print('
    the total number of error is: %d'%errorCount)
    130     print('
    the total error rate is: %f'%(errorCount/float(mTest)))
    131 
    132 
    133 if __name__ == "__main__":
    134     '''
    135     group,labels = createDataSet()
    136     result = classify0([0,0],group,labels,3)
    137     print(result)
    138     '''
    139     datingDataMat, datingLabels = file2matrix("./datingTestSet2.txt")  # 数据转换
    140     fig = plt.figure()
    141     ax = fig.add_subplot()
    142     ax.scatter(datingDataMat[:,0],datingDataMat[:,1],15.0*array(datingLabels),15.0*array(datingLabels))
    143     plt.show()
    144     normMat, ranges, minVals = autoNum(datingDataMat)  # 输入数据归一化
    145     # datingClassTest()
    146     # classifyPerson()  # 分类
    147     # testVector = img2vector("../machinelearninginaction/Ch02/digits/trainingDigits/0_1.txt")
    148     # print(testVector[0,0:31])
    149     handwritingClassTest()

    手写数字识别的运行结果如下:

    the total number of error is: 12

    the total error rate is: 0.012685
  • 相关阅读:
    二:dot语言语法及使用
    一:安装graphviz
    一个程序的前世今生(四)——延迟绑定和GOT与PLT
    一个程序的前世今生(三)——动态链接库和静态链接库
    一个程序的前世今生(二)——可执行文件如何加载进内存
    更新mysql驱动5.1-47 Generated keys not requested. You need to specify Statement.RETURN_GENERATED_KEY
    The superclass javax servlet http HttpServlet was not found on the Java Build Path
    vue-router地址栏URL全局参数拼接
    Canvas签字画图板
    Vue 表单拖拽排序
  • 原文地址:https://www.cnblogs.com/DJames23/p/13062429.html
Copyright © 2020-2023  润新知