• 手写数字识别


    四、 任务分析

      sklearn.neighbors模块实现了k-近邻算法。如图1所示。

    图1 sklearn.neighbors模块

      KNeighborsClassifier函数有8个参数。如图2所示。

    图2 参数

      n_neighbors: int,可选,默认值为5,最近的k个点

      weights: 权重,默认是uniform

      algorithm: 快速k近邻搜索算法,默认参数为auto

      leaf_size: int,可选,默认值为30,构造kd树和ball树的大小

      p: 整数,可选,默认值为2,距离度量公式

      metric: 用于距离度量

      metric_params: 距离公式的其他关键参数

      n_jobs: 并行处理设置

     ♥ 知识链接
    sklearn
    sklearn实现KNN的方法是sklearn.neighbors.KNeighborsClassifier。

    五、 任务实施

    步骤1、环境准备

      右击Ubuntu操作系统桌面,从弹出菜单中选择【Open in Terminal】命令 打开终端。

      通过【cd /home】切换到home目录下。【ls】查看该目录下的所有内容。如图3所示。

    图3 切换目录

      【mkdir KNN】在home目录下创建KNN文件夹。如图4所示。

    图4 创建文件夹

    步骤2、数据集

      【cd KNN】切换至KNN目录下,【cp -R /home/soft/digits/ /home/KNN/】将数据从/home/soft目录下拷贝至/home/KNN目录下,【cd digits】切换至数据目录下查看,该目录下分别放置训练数据集和测试数据集的文件夹。如图5所示。

    图5 拷贝数据集

      【cd testDigits】切换至测试数据集的文件夹中查看,所有的文本格式存储的数字文件命名格式为:数字的值_该数字的样本序号。训练数据集和测试数据集是一样的格式。可自行查看。如图6所示。

    图6 文本格式

      每一个文本中都包含32像素x32像素的数字。可通过【cat】命令查看任意文件。如图7所示。

    图7 文本内容

    步骤3、K-近邻算法

      【cd ../..】切换至KNN目录下,【vim kNN_digits.py】回车后创建并编辑名为kNN_digits的Python文件。如图8所示。

    图8 创建Python文件

      回车后进入编辑框内,按键盘【i】进入编辑状态,编译如下程序。如图9所示。

      【import numpy as np】导入numpy矩阵库

      【import operator】导入运算符模块

      【from os import listdir】导入os模块,操作文件夹

      【from sklearn.neighbors import KNeighborsClassifier as kNN】导入sklearn库

    图9 模块

      将32x32的二进制图像转换为1x1024向量。如图10所示。

    图10 向量转换

      创建手写数字分类测试函数hangwritingClassTest,分别得到训练集训练的矩阵及类别,通过sklear的KNN近邻算法训练模型。如图11所示。

    图11 训练模型

      在hangwritingClassTest内继续得到测试数据集的矩阵,通过neigh.predict对测试数据进行预测。最后通过出现的错误数除以总的个数得到错误率。如图12所示。

    图12 预测

      在main函数内调用hangwritingClassTest函数。如图13所示。

    图13 main方法

      编辑完毕后,按【esc】退出编辑状态,【:wq】保存并退出编辑框,【python kNN_digits.py】执行kNN_digits的Python文件,结果以实际预测为准。如图14所示。

    图14 运行Python文件

    步骤4、源码

     1 #coding:utf-8
     2 import numpy as np
     3 import operator
     4 from os import listdir
     5 from sklearn.neighbors import KNeighborsClassifier as kNN
     6 """
     7 将32x32的二进制图像转换为1x1024向量
     8 """
     9 def img2vector(filename):
    10     returnVect = np.zeros((1,1024)) #生成1x1024零向量
    11     fr = open(filename) #打开文件
    12     for i in range(32): #文本格式是32x32的,读取所有行
    13         lineStr = fr.readline() #读一行数据
    14         for j in range(32): #读取行中所有元素
    15             returnVect[0, 32*i+j] = int(lineStr[j]) #将所有的元素添加到returnVect中
    16     return returnVect #返回转换后的1x1024向量
    17 """
    18 手写数字分类测试
    19 """
    20 def handwritingClassTest():
    21     hwLabels = [] #训练集的Labels
    22     trainingFileList = listdir('digits/trainingDigits') #返回trainingDigits目录下的文件名
    23     m = len(trainingFileList) #返回文件夹下文件的个数
    24     trainingMat = np.zeros(((m,1024))) #初始化训练的矩阵
    25     for i in range(m): 
    26         fileNameStr = trainingFileList[i] #获得文件的名字
    27         classNumber = int(fileNameStr.split('_')[0]) #获得分类的数字
    28         hwLabels.append(classNumber) #将获得的类别添加到hwLabels中
    29         #将每一个文件的1x1024数据存储到trainingMat矩阵中
    30         trainingMat[i,:] = img2vector("digits/trainingDigits/%s" % (fileNameStr))
    31     neigh = kNN(n_neighbors=3,algorithm="auto") #构建kNN分类器
    32     neigh.fit(trainingMat,hwLabels) #训练模型
    33     testFileList = listdir("digits/testDigits") #返回TestDigits目录下的文件列表
    34     errorCount = 0.0 #错误检测计数
    35     mTest = len(testFileList) #测试数据的数量
    36     for i in range(mTest): #从文件中解析出测试集的类别并进行分类测试
    37         fileNameStr = testFileList[i] #获得文件名字
    38         classNumber = int(fileNameStr.split("_")[0]) #获得分类的数字
    39         vectorUnderTest = img2vector("digits/testDigits/%s" % (fileNameStr)) #获得测试集的1x1024向量,用于训练
    40         classifierResult = neigh.predict(vectorUnderTest) #获取预测结果
    41         print "分类返回结果为%d	真实结果为%d" % (classifierResult,classNumber)
    42         if (classifierResult != classNumber):
    43             errorCount += 1.0
    44     print "总共错了%d个数据
    错误率为%f%%" % (errorCount,errorCount/mTest * 100)
    45 if __name__ == '__main__':
    46     handwritingClassTest()
  • 相关阅读:
    Ubuntu 18.04.4 系统优化
    Ubuntu 18.04.4 LTS 更换国内系统源
    django 数据库迁移
    django2.0解决跨域问题
    python requests get请求传参
    python 常用排序方法
    python 电脑说话
    centos6.x配置虚拟主机名及域名hosts
    php 合并,拆分,追加,查找,删除数组教程
    PHP统计在线用户数量
  • 原文地址:https://www.cnblogs.com/yu-1104/p/9050564.html
Copyright © 2020-2023  润新知