• 基于OpenCV的KNN算法实现手写数字识别


    基于OpenCV的KNN算法实现手写数字识别

    一、数据预处理

    # 导入所需模块
    import cv2
    import numpy as np
    import matplotlib.pyplot as plt
    
    # 显示灰度图
    def plt_show(img):
        plt.imshow(img,cmap='gray')
        plt.show()
    
    # 加载数据集图片数据
    digits = cv2.imread('./image/digits.png',0)
    print(digits.shape)
    plt_show(digits)
    
    (1000, 2000)
    

    # 划分数据
    cells = [np.hsplit(row,100) for row in np.vsplit(digits,50)] 
    
    
    len(cells)
    
    50
    
    # 转换为numpy数组
    x = np.array(cells)
    
    x.shape
    
    (50, 100, 20, 20)
    
    plt_show(x[5][0])
    

    # 生成训练数据标签和测试数据标签
    k = np.arange(10)
    train_label = np.repeat(k,250)
    test_label = train_label.copy()
    
    # 图片数据转换为特征矩阵,划分训练数据集
    train = x[:,:50].reshape(-1,400).astype(np.float32)
    
    # 图片数据转换为特征矩阵,划分测试数据集
    test = x[:,50:100].reshape(-1,400).astype(np.float32)
    test.shape
    
    (2500, 400)
    

    二、knn算法预测

    # 生成模型
    knn = cv2.ml.KNearest_create()
    
    # 训练数据
    knn.train(train,cv2.ml.ROW_SAMPLE,train_label)
    
    True
    
    # 传入n值,和测试数据,返回结果
    ret,result,neighbours,dist = knn.findNearest(test, 3)
    
    # 统计正确的个数
    res = 0
    for i in range(2500):
        if result[i]==test_label[i]:
            res = res+1
    res
    
    2439
    
    # 计算模型准确率
    accuracy = res/result.size
    print('识别测试数据的准确率为:',accuracy)
    
    识别测试数据的准确率为: 0.9756
    

    三、导入图片预测

    # 在测试集中随便找一张图片
    test_image = test[2400].reshape(20,20)
    plt_show(test_image)
    test_label[2400]
    

    # 将图片转换为特征矩阵
    testImage = test[2400].reshape(-1,400).astype(np.float32)
    testImage.shape
    
    (1, 400)
    
    # 使用训练好的模型预测
    ret,result,neighbours,dist = knn.findNearest(testImage, 3)
    
    # 预测结果
    print('识别出的数字为:',result[0][0])
    
    识别出的数字为: 9.0
    
    # 传入一张自己找的图片进行识别尺寸(20*20)
    te = cv2.imread('test2.jpg',0)
    plt_show(te)
    te.shape
    

    (20, 20)

    testImage = te.reshape(-1,400).astype(np.float32)
    testImage.shape
    
    (1, 400)
    
    ret,result,neighbours,dist = knn.findNearest(testImage, 3)
    result
    
    array([[2.]], dtype=float32)
    
    print('识别出的数字为:',result[0][0])
    
    识别出的数字为: 2.0
    

    用自己写的一张图片预测

    # 用所有数据作为训练数据
    knn = cv2.ml.KNearest_create()
    k = np.arange(10)
    labels = np.repeat(k,500)
    knn.train(x.reshape(-1,400).astype(np.float32),cv2.ml.ROW_SAMPLE,labels)
    
    True
    
    te = cv2.imread('test1.jpg',0)
    plt_show(te)
    te.shape
    

    (20, 20)

    # 自适应阈值处理
    ret, image = cv2.threshold(te, 0, 255, cv2.THRESH_OTSU | cv2.THRESH_BINARY_INV)
    plt_show(image)
    

    # 将图片转换为特征矩阵
    testImage = image.reshape(-1,400).astype(np.float32)
    testImage.shape
    
    (1, 400)
    
    # 使用训练好的模型预测
    ret,result,neighbours,dist = knn.findNearest(testImage, 3)
    
    neighbours
    
    array([[5., 5., 5.]], dtype=float32)
    
    print('识别出的数字为:',result[0][0])
    
    识别出的数字为: 5.0
    

    资源地址:

    链接:https://pan.baidu.com/s/1sUgKBvex43-Yf-Ul2DQSIA
    提取码:t1sd

    视频地址:https://www.bilibili.com/video/BV14A411t7tk/

  • 相关阅读:
    最长上升子序列
    system call filters failed to install; check the logs and fix your configuration or disable system c
    linux centos 7 安装vnc远程服务
    Delphi XE 错误提示: [MySQL]-314. Cannot load vendor library [libmysql.dll orlibmysqlld.dll]
    MYSQL 修改密码的几种方式
    MySQL 常用操作和字段类型
    Java 获取GUID
    C# 获取GUID
    C++ 获取GUID
    Delphi GUID[2] 获取GUID值的方式
  • 原文地址:https://www.cnblogs.com/zq98/p/12844612.html
Copyright © 2020-2023  润新知