• k-近邻算法2(kNN)手写识别系统


    这里构造的系统只能识别数字0-9

    目录trainingDigits中包含了1934个文件

    目录testDigits中包含了946个文件

    文件形式:每个文件是一幅32×32的文本图像(32行,每行32个数字字符);文件名形如“数字_序号.txt”,下划线前的数字就是该图像的类别。

    (1)准备数据:将图像转换为向量(训练图像和测试图像都要转换)

     

    # Flatten a 32x32 text image into a single row vector.
    def img2vector(filename):
        """Flatten a 32x32 text image of digit characters into a 1x1024 row vector.

        Args:
            filename: path to a text file whose first 32 lines each start with
                32 digit characters.

        Returns:
            A (1, 1024) float array; the character at (line i, column j) is
            stored at index 32 * i + j.
        """
        returnVect = zeros((1, 1024))
        # "with" guarantees the file is closed, even if a short or malformed
        # line makes int()/indexing raise below.
        with open(filename) as fr:
            for i in range(32):
                lineStr = fr.readline()
                for j in range(32):
                    returnVect[0, 32 * i + j] = int(lineStr[j])
        return returnVect

     

    (2)测试算法:使用k-近邻算法识别手写数字

    def _digitLabel(fileNameStr):
        """Return the class digit encoded in a file name such as '7_45.txt' (-> 7)."""
        fileStr = fileNameStr.split('.')[0]  # drop the extension
        return int(fileStr.split('_')[0])  # the part before '_' is the digit


    def handwritingClassTest(trainingDir='digits/trainingDigits', testDir='digits/testDigits', k=3):
        """Train on every image in ``trainingDir`` and report kNN's error on ``testDir``.

        Each file is read with ``img2vector`` and its class is parsed from the
        file name (``<digit>_<index>.<ext>``). Every test image is classified by
        ``classify0`` against all training images; each prediction is printed,
        followed by the total error count and the error rate.

        Args:
            trainingDir: directory of training images (the default is the
                original hard-coded path).
            testDir: directory of test images (default: original path).
            k: number of neighbours that vote in ``classify0`` (default 3).
        """
        hwLabels = []
        # Build the training matrix: one flattened 1x1024 image per row.
        trainingFileList = listdir(trainingDir)
        trainingMat = zeros((len(trainingFileList), 1024))
        for i, fileNameStr in enumerate(trainingFileList):
            hwLabels.append(_digitLabel(fileNameStr))
            trainingMat[i, :] = img2vector('%s/%s' % (trainingDir, fileNameStr))

        testFileList = listdir(testDir)
        errorCount = 0.0
        mTest = len(testFileList)
        for fileNameStr in testFileList:
            classNumStr = _digitLabel(fileNameStr)
            vectorUnderTest = img2vector('%s/%s' % (testDir, fileNameStr))
            classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
            print("the classifier came back with: %d, the real answer is:%d" % (classifierResult, classNumStr))
            if classifierResult != classNumStr:
                errorCount += 1.0

        print("\nthe total number of errors is:%d" % errorCount)
        # An empty test directory would otherwise raise ZeroDivisionError.
        if mTest:
            print("\nthe total error rate is:%f" % (errorCount / float(mTest)))
    def classify0(inX, dataSet, labels, k):
        """Classify ``inX`` by majority vote among its ``k`` nearest training points.

        Args:
            inX: the sample to classify; a 1-D feature vector or a
                (1, n_features) row (as returned by ``img2vector``).
            dataSet: training samples, one per row, shape (m, n_features).
            labels: sequence of m class labels aligned with the rows of ``dataSet``.
            k: number of neighbours that vote; values above m use all m points.

        Returns:
            The label with the most votes. On a tie, the label whose first voter
            is closest to ``inX`` wins (stable descending sort).

        Raises:
            ValueError: if ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError("k must be at least 1")
        # 1. Euclidean distance from inX to every training row; broadcasting
        #    replaces the explicit tile() copy of inX.
        diffMat = asarray(inX) - dataSet
        sqDistances = (diffMat ** 2).sum(axis=1)
        distances = sqDistances ** 0.5
        sortedDistIndicies = distances.argsort()
        # 2. The k closest points vote. Slicing stops at the end of the array,
        #    so k larger than the data set no longer raises IndexError.
        classCount = {}
        for idx in sortedDistIndicies[:k]:
            voteIlabel = labels[idx]
            classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
        # 3. Sort by vote count, highest first, and return the winning label.
        sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]

    上篇+这篇的全部代码

    #!/usr/bin/env python3
    # -*-coding:utf-8 -*-
    
    from numpy import *
    import operator
    import matplotlib
    import matplotlib.pyplot as plt
    
    
    def createDataSet():
        """Return a toy training set: a 4x2 float feature array and its labels."""
        labels = ['A'] * 2 + ['B'] * 2
        group = array([
            [1.0, 1.1],
            [1.0, 1.0],
            [0.0, 0.0],
            [0.0, 0.1],
        ])
        return group, labels
    
    
    def classify0(inX, dataSet, labels, k):
        """Classify ``inX`` by majority vote among its ``k`` nearest training points.

        Args:
            inX: the sample to classify; a 1-D feature vector or a
                (1, n_features) row (as returned by ``img2vector``).
            dataSet: training samples, one per row, shape (m, n_features).
            labels: sequence of m class labels aligned with the rows of ``dataSet``.
            k: number of neighbours that vote; values above m use all m points.

        Returns:
            The label with the most votes. On a tie, the label whose first voter
            is closest to ``inX`` wins (stable descending sort).

        Raises:
            ValueError: if ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError("k must be at least 1")
        # 1. Euclidean distance from inX to every training row; broadcasting
        #    replaces the explicit tile() copy of inX.
        diffMat = asarray(inX) - dataSet
        sqDistances = (diffMat ** 2).sum(axis=1)
        distances = sqDistances ** 0.5
        sortedDistIndicies = distances.argsort()
        # 2. The k closest points vote. Slicing stops at the end of the array,
        #    so k larger than the data set no longer raises IndexError.
        classCount = {}
        for idx in sortedDistIndicies[:k]:
            voteIlabel = labels[idx]
            classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
        # 3. Sort by vote count, highest first, and return the winning label.
        sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]
    
    
    def file2matrix(filename):
        """Parse a tab-separated data file into a feature matrix and a label list.

        In each non-blank line, the first three tab-separated fields are numeric
        features and the last field is an integer class label.

        Args:
            filename: path of the text file to read.

        Returns:
            (returnMat, classLabelVector): a (rows, 3) float array of features
            and a list of int labels, one entry per non-blank line.
        """
        # "with" closes the file; strip() removes the newline and surrounding
        # whitespace from every line.
        with open(filename) as fr:
            strippedLines = [line.strip() for line in fr]
        # Skip blank lines (e.g. a trailing empty line): they would otherwise
        # become a malformed row and make int() fail on an empty field.
        dataLines = [line for line in strippedLines if line]
        returnMat = zeros((len(dataLines), 3))
        classLabelVector = []
        for index, line in enumerate(dataLines):
            listFromLine = line.split('\t')
            returnMat[index, :] = listFromLine[0:3]
            classLabelVector.append(int(listFromLine[-1]))
        return returnMat, classLabelVector
    
    
    # newValue=(oldValue-min)/(max-min)
    def autoNorm(dataSet):
        """Min-max scale every column of ``dataSet`` into [0, 1].

        Applies newValue = (oldValue - min) / (max - min) column by column.

        Args:
            dataSet: 2-D array, one sample per row.

        Returns:
            (normDataSet, ranges, minVals): the scaled array, the per-column
            divisor and the per-column minimum. New samples can be scaled the
            same way with (x - minVals) / ranges. A constant column (max == min)
            gets a divisor of 1, so it maps to 0 instead of NaN.
        """
        # Axis 0 takes the min/max of each column.
        minVals = dataSet.min(0)
        maxVals = dataSet.max(0)
        # Substitute 1 for zero ranges to avoid 0/0 on constant columns.
        ranges = where(maxVals - minVals == 0, 1, maxVals - minVals)
        # Broadcasting applies the per-column vectors to every row.
        normDataSet = (dataSet - minVals) / ranges
        return normDataSet, ranges, minVals
    
    
    def datingClassTest():
        """Hold out the first 10% of datingTestSet2.txt and report kNN's error rate.

        After min-max normalisation, each held-out row is classified (k=3)
        against the remaining rows. Every prediction is printed, then the
        overall error rate.
        """
        hoRatio = 0.1  # fraction of rows used as the test set
        datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
        normMat = autoNorm(datingDataMat)[0]
        m = normMat.shape[0]
        numTestVecs = int(m * hoRatio)
        # Rows [numTestVecs, m) form the training set for every query.
        trainMat = normMat[numTestVecs:m, :]
        trainLabels = datingLabels[numTestVecs:m]
        errorCount = 0.0
        for i, realLabel in enumerate(datingLabels[:numTestVecs]):
            classifierResult = classify0(normMat[i, :], trainMat, trainLabels, 3)
            print("the classifier came back with: %d,the real answer is: %d" % (classifierResult, realLabel))
            if classifierResult != realLabel:
                errorCount += 1.0
        print("the total error rate is: %f" % (errorCount / float(numTestVecs)))
    
    
    def classifyPerson():
        """Read three feature values from stdin and print the predicted preference.

        The query is scaled with the same per-column minimum and range as the
        training data from datingTestSet2.txt, then classified (k=3) against
        the normalised training rows.
        """
        resultList = ['not at all', 'in small doses', 'in large doses']
        percentTats = float(input("percentage of time spent playing video games?"))
        ffMiles = float(input("frequent flier miles earned per year?"))
        iceCream = float(input("liters of ice cream consumed per year?"))
        datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
        normMat, ranges, minVals = autoNorm(datingDataMat)
        # Features in the training file's column order.
        inArr = array([ffMiles, percentTats, iceCream])
        # Normalise the query exactly like the training rows and compare it with
        # the normalised data; using raw values lets the largest-magnitude
        # column dominate the distance.
        classifierResult = classify0((inArr - minVals) / ranges, normMat, datingLabels, 3)
        # resultList is 0-based; labels are presumably 1..3 — confirm against the data file.
        print("You will probably like this person: ", resultList[classifierResult - 1])
    
    
    def draw():
        """Scatter-plot columns 0 and 1 of datingTestSet2.txt, coloured by class label.

        Labels 1, 2 and 3 are drawn in red, green and blue with growing marker
        sizes. The axis labels and legend use a Chinese font loaded from the
        Windows font directory.
        """
        fig = plt.figure()  # create a figure object
        # 111 -> 1x1 grid, first cell. (349 would split the canvas into 3 rows x 4
        # columns and draw in the 9th cell, counting left-to-right, top-to-bottom.)
        ax = fig.add_subplot(111)
        datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')

        # ax.scatter(x, y, s=size, c=colour): x/y are point coordinates, s is the
        # marker size. (label, marker size, colour) per class, in legend order:
        # 1 = dislike, 2 = average charm, 3 = very charming.
        classStyles = ((1, 20, 'red'), (2, 30, 'green'), (3, 40, 'blue'))
        handles = []
        for label, size, colour in classStyles:
            rows = [row for row, rowLabel in zip(datingDataMat, datingLabels) if rowLabel == label]
            xs = [row[0] for row in rows]
            ys = [row[1] for row in rows]
            handles.append(ax.scatter(xs, ys, s=size, c=colour))

        # Load a Chinese font so the labels below render correctly. The raw
        # string keeps the Windows path's backslashes; without them the path
        # degenerates to the non-existent 'C:WindowsFontsSTXINGKA.TTF'.
        zhfont = matplotlib.font_manager.FontProperties(fname=r'C:\Windows\Fonts\STXINGKA.TTF')
        plt.xlabel('每年获取的飞行常客里程数', fontproperties=zhfont)
        plt.ylabel('玩视频游戏所耗时间百分比', fontproperties=zhfont)
        ax.legend(tuple(handles), (u'不喜欢', u'魅力一般', u'极具魅力'), loc=2, prop=zhfont)
        plt.show()
    
    
    from os import listdir
    
    
    # Flatten a 32x32 text image into a single row vector.
    def img2vector(filename):
        """Flatten a 32x32 text image of digit characters into a 1x1024 row vector.

        Args:
            filename: path to a text file whose first 32 lines each start with
                32 digit characters.

        Returns:
            A (1, 1024) float array; the character at (line i, column j) is
            stored at index 32 * i + j.
        """
        returnVect = zeros((1, 1024))
        # "with" guarantees the file is closed, even if a short or malformed
        # line makes int()/indexing raise below.
        with open(filename) as fr:
            for i in range(32):
                lineStr = fr.readline()
                for j in range(32):
                    returnVect[0, 32 * i + j] = int(lineStr[j])
        return returnVect
    
    
    def _digitLabel(fileNameStr):
        """Return the class digit encoded in a file name such as '7_45.txt' (-> 7)."""
        fileStr = fileNameStr.split('.')[0]  # drop the extension
        return int(fileStr.split('_')[0])  # the part before '_' is the digit


    def handwritingClassTest(trainingDir='digits/trainingDigits', testDir='digits/testDigits', k=3):
        """Train on every image in ``trainingDir`` and report kNN's error on ``testDir``.

        Each file is read with ``img2vector`` and its class is parsed from the
        file name (``<digit>_<index>.<ext>``). Every test image is classified by
        ``classify0`` against all training images; each prediction is printed,
        followed by the total error count and the error rate.

        Args:
            trainingDir: directory of training images (the default is the
                original hard-coded path).
            testDir: directory of test images (default: original path).
            k: number of neighbours that vote in ``classify0`` (default 3).
        """
        hwLabels = []
        # Build the training matrix: one flattened 1x1024 image per row.
        trainingFileList = listdir(trainingDir)
        trainingMat = zeros((len(trainingFileList), 1024))
        for i, fileNameStr in enumerate(trainingFileList):
            hwLabels.append(_digitLabel(fileNameStr))
            trainingMat[i, :] = img2vector('%s/%s' % (trainingDir, fileNameStr))

        testFileList = listdir(testDir)
        errorCount = 0.0
        mTest = len(testFileList)
        for fileNameStr in testFileList:
            classNumStr = _digitLabel(fileNameStr)
            vectorUnderTest = img2vector('%s/%s' % (testDir, fileNameStr))
            classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
            print("the classifier came back with: %d, the real answer is:%d" % (classifierResult, classNumStr))
            if classifierResult != classNumStr:
                errorCount += 1.0

        print("\nthe total number of errors is:%d" % errorCount)
        # An empty test directory would otherwise raise ZeroDivisionError.
        if mTest:
            print("\nthe total error rate is:%f" % (errorCount / float(mTest)))
    
    
    if __name__ == "__main__":
        # Default demo: the handwritten-digit test. Call classifyPerson() here
        # instead for the interactive dating-site prediction.
        handwritingClassTest()
    kNN.py
    handwritingClassTest()运行结果,只截取了最后的部分

    实际使用这个算法时,执行效率并不高:算法需要为每个测试向量做约2000次距离计算(训练集有1934个样本),每次距离计算涉及1024个维度的浮点运算,而这样的测试总共要执行约900次(测试集有946个样本)。

    此外,还要为测试向量准备2MB的存储空间。(kd树(k-d tree)是对k-近邻算法的一种优化,可以节省大量的计算开销)

    小结

    k-近邻算法是分类数据最简单、最有效的算法之一。它是基于实例的学习,使用算法时我们必须有接近实际数据的训练样本数据。

    此算法必须保存全部数据集,如果训练集很大,必须使用大量的存储空间。此外,由于必须对数据集中的每个数据计算距离值,实际使用时可能非常耗时。

    它无法给出任何数据的基础结构信息,也无法知晓平均实例样本和典型实例样本具有什么特征。

  • 相关阅读:
    Jpa 一对多级联查询 排序设置
    Spring Data Jpa Specification 调用Oracle 函数/方法
    Spring boot 集成 阿里 Mqtt
    将Jquery序列化后的表单值转换成Json
    Linux安装和卸载MySQL5.7
    NoNodeAvailableException[None of the configured nodes are available: [{#transport#-1}{3bFuKD5MTOWOCfJ1ZFrfdw}{192.168.0.105}{192.168.0.105:9301}]]
    Docker下安装RabbitMQ
    JAVA数据结构与算法-稀疏数组
    第一篇博客
    测试用例编写方法:边界值分析方法
  • 原文地址:https://www.cnblogs.com/wangkaipeng/p/7879783.html
Copyright © 2020-2023  润新知