• 机器学习——k近邻算法(KNN)


    import math
    import numpy as np
    from collections import Counter
    class KNNClassfiy(object):
        def __init__(self,k):
        #判断k有效
            assert k>=1,'k must be valid'
            self.k=k
            self._xTrain=None
            self._yTrain=None
    
    
        def fit(self,xTrain,yTrain):
        #判断输入的训练集有效
            assert xTrain.shape[0]==yTrain.shape[0],
                'The size of xTrain must be equals to the size of yTrain'
        #判断K有效   
            assert self.k<=xTrain.shape[0],
                'The size of xTrain must be least at k'
            self._xTrain=xTrain
            self._yTrain=yTrain
            return self
    
        def predict(self,X_predict):
            # X_predict是预测数据数组,判断预测数据合法性,必须是二维数组
            assert X_predict.shape[1]==self._xTrain.shape[1],
                'The feature of x must be equal to xTrain'
            assert self._xTrain is not None and self._yTrain is not None,
                'must fit before predict'
            y_predict=[self._predict(x) for  x in X_predict]
            return np.array(y_predict)
    
        def _predict(self,x):
            distances=[math.sqrt(np.sum((xTrain-x)**2)) for xTrain in self._xTrain]
            nearest=np.argsort(distances)
            top_y=[self._yTrain[i] for i in nearest[:self.k]]
            votes=Counter(top_y)
            print(votes.most_common(1))
            return votes.most_common(1)[0][0]
        def __repr__(self):
            return self.k
    
    KNN_clf=KNNClassfiy(k=6);
    #先训练后预测
    xTrain=np.array([[4.5,3.2],
                     [5.8,4.1],
                     [6.7,5.3],
                     [8.6,7.1],
                     [3.8,2.5],
                     [5.3,4.4],
                     [9.4,8.6],
                     [11.8,9.4],
                     [3.8,3.2],
                     [12.8,10.1]])
    yTrain=np.array([0,0,1,1,0,0,1,1,0,1])
    KNN_clf.fit(xTrain=xTrain,yTrain=yTrain)
    x_predict=np.array([[6.9,5.7],[3.4,2.8]])
    a=KNN_clf.predict(x_predict)
    print(a[0],a[1])
    

    代码比较简单,主要逻辑在于预测部分。

    调用matplotlib绘制图形分布图

    在这里插入图片描述

    步骤可简化如下:

    • 确定k值
    • 训练数据集
    • 预测函数

    K近邻算法主要解决分类问题,是机器学习中最简单的最基础的一种算法。

  • 相关阅读:
    第一行DOCTYPE 的作用
    es6 proxy、handler.get()
    vue router-link 默认a标签去除下划线
    打开记事本
    JS数组遍历的方法
    vue项目中使用proxy解决跨域
    封装axios
    postMessage vue iframe传值
    input限制只能输入数字,且保留小数后两位
    axios封装
  • 原文地址:https://www.cnblogs.com/hzcya1995/p/13309462.html
Copyright © 2020-2023  润新知