• k邻近算法(KNN)实例


    一 k近邻算法原理

    k近邻算法是一种基本分类和回归方法.

    原理:K近邻算法,即是给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的K个实例,这K个实例的
    多数属于某个类,就把该输入实例分类到这个类中。

    如上图所示,有两类不同的样本数据,分别用蓝色的小正方形和红色的小三角形表示,而图正中间的那个绿色的圆所标示的数据则是待分类的数据。这也就是我们的目的,来了一个新的数据点,我要得到它的类别是什么?好的,下面我们根据k近邻的思想来给绿色圆点进行分类。

    • 如果K=3,绿色圆点的最邻近的3个点是2个红色小三角形和1个蓝色小正方形,少数从属于多数,基于统计的方法,判定绿色的这个待分类点属于红色的三角形一类。
    • 如果K=5,绿色圆点的最邻近的5个邻居是2个红色三角形和3个蓝色的正方形,还是少数从属于多数,基于统计的方法,判定绿色的这个待分类点属于蓝色的正方形一类。

    参考一文搞懂k近邻(k-NN)算法(一) https://zhuanlan.zhihu.com/p/25994179

    二  特点

    优点:精度高(计算距离)、对异常值不敏感(单纯根据距离进行分类,会忽略特殊情况)、无数据输入假定
      (不会对数据预先进行判定)。 缺点:时间复杂度高、空间复杂度高。 适用数据范围:数值型和标称型。

    三 欧氏距离(Euclidean Distance)

    欧氏距离是最常见的距离度量,衡量的是多维空间中各个点之间的绝对距离。公式如下:

    四  sklearn库中使用k邻近算法

    • 分类问题:from sklearn.neighbors import KNeighborsClassifier
    • 回归问题:from sklearn.neighbors import KNeighborsRegressor

    五 使用sklearn的K邻近简单实例

    1 数据蓝蝴蝶

    #导包
    import numpy as np
    import pandas  as pd
    from pandas import DataFrame,Series
    from sklearn.neighbors import KNeighborsClassifier #k邻近算法模型
    
    #使用datasets创建数据
    import sklearn.datasets as datasets
    iris = datasets.load_iris()
    
    feature = iris['data']
    target = iris['target']
    
    #将样本打乱,符合真实情况
    
    np.random.seed(1)
    np.random.shuffle(feature)
    np.random.seed(1)
    np.random.shuffle(target)
    
    #训练数据
    x_train = feature[:140]
    y_train = target[:140]
    #测试数据
    x_test = feature[-10:]
    y_test =target[-10:]
    
    #实例化模型对象&训练模型
    knn = KNeighborsClassifier(n_neighbors=10)
    knn.fit(x_train,y_train)
    knn.score(x_train,y_train)
    
    print('预测分类:',knn.predict(x_test))
    print('真实分类:',y_test)

    2 根据身高、体重、鞋子尺码,预测性别

    #导包
    import numpy as np
    import pandas  as pd
    from pandas import DataFrame,Series
    
    #手动创建训练数据集
    feature = np.array([[170,65,41],[166,55,38],[177,80,39],[179,80,43],[170,60,40],[170,60,38]])
    target = np.array(['','','','','',''])
    
    from sklearn.neighbors import KNeighborsClassifier #k邻近算法模型
    
    #实例k邻近模型,指定k值=3
    knn = KNeighborsClassifier(n_neighbors=3)
    
    #训练数据
    knn.fit(feature,target)
    
    #模型评分
    knn.score(feature,target)
    
    #预测
    knn.predict(np.array([[176,71,38]]))

     3 手写数字识别

    • 导包
    import numpy as np 
    import pandas as pd
    from pandas import DataFrame,Series
    import matplotlib.pyplot as plt
    
    from sklearn.neighbors import KNeighborsClassifier
    • 查看单一图片特征
    img=plt.imread('data/0/0_2.bmp')
    plt.imshow(img)

    • 提炼样本数据
    feature=[]
    target=[]
    for i in range(10):
        for j in range(500):
            img_arr=plt.imread(f'data/{i}/{i}_{j+1}.bmp')
            feature.append(img_arr)
            target.append(i)
    
    #构建特征数据格式
    feature=np.array(feature)
    target=np.array(target)
    
    feature.shape #(5000, 28, 28)
    
    #输入数据必须是二维数组,必须对feature降维
    #(1)降维方式一:mean() (2)降维方式二:reshape()
    feature=feature.reshape(5000,28*28)
    
    #将样本打乱 (必须使用多个seed)
    np.random.seed(5)
    np.random.shuffle(feature)
    np.random.seed(5)
    np.random.shuffle(target)
    
    #数据分割为训练数据和测试数据
    x_train=feature[:4950]
    y_train=target[:4950]
    x_test=feature[-50:]
    y_test=target[-50:]
    • KNN模型建立和评分
    #训练模型
    knn.fit(x_train,y_train)
    
    #评分
    knn.score(x_train,y_train)
    
    #预测
    # knn.predict(x_test)
    • 真实预测手写数字图片的一般流程
    # 读取图片数据
    num_img_arr=plt.imread('../../数字.jpg')
    plt.imshow(num_img_arr)

    #图片截取数字5
    five_arr=num_img_arr[90:158,80:132]
    plt.imshow(five_arr)

    #降维操作(five数组是三维的,需要进行降维,舍弃第三个表示颜色的维度)
    print(five_arr.shape) #(65, 56, 3)
    five=five_arr.mean(axis=2)
    print(five.shape) #(65, 56)
    plt.imshow(five)

    # 图片压缩为像素28*28
    import scipy.ndimage as ndimage
    five = ndimage.zoom(five,zoom = (28/68,28/52))
    five.shape #(28, 28)
    
    # 压缩后的5的显示
    plt.imshow(five)

    # 把数据降维为feature 数据格式
    five.reshape(1,28*28)
    #预测
    knn.predict(five.reshape(1,28*28))

    下载源数据和代码:https://github.com/angleboygo/data_ansys

     

  • 相关阅读:
    APP高级抓包
    Linux使用日志
    ffmpeg使用记录
    win7 远程连接ubuntu
    v-2-r-a-y使用
    adb
    golang mysql 模糊查询
    交互式批量删除指定目录下指定类型文件
    golang打包和部署到centos7
    Nginx unknown directive ""
  • 原文地址:https://www.cnblogs.com/angle6-liu/p/10416736.html
Copyright © 2020-2023  润新知