• 自定义封装KNN算法


    一:首先定义一个类,定义构造函数,对训练集进行赋值,fit和predict过程。

    #自己手写KNN算法进行封装
    
    import numpy as np
    from math import sqrt
    from collections import Counter
    
    class KNN_classifier:
        #构造函数,指定key
        def __init__(self,k):
            self.k = k
            self._X_train = None
            self._Y_train = None
        #赋值
        def fit(self,x_train,y_train):
            self._X_train = x_train
            self._Y_train = y_train
            return self
        #predict
        def predict(self,x):
            assert self._X_train is not  None
            assert self._Y_train is not None
            assert self._X_train.shape[0]==self._Y_train.shape[0] ,'列数不一样'
            assert self._X_train.shape[1]==x.shape[0]
    
            arr = [i for i in x]
            pre = np.array(arr)
    
            prediction = self._prediction(pre)
    
            return prediction
    
        def _prediction(self,x):
            distice = [sqrt(np.sum(i-x)**2) for i in self._X_train]
    
            nearest = np.argsort(distice)
    
            topK = [self._Y_train[i] for i in nearest[:self.k]]
    
            final = Counter(topK)
    
            pre = final.most_common(1)[0][0]
    
            return pre
    
    #测试
    if __name__=='__main__':
        raw_data_X = [[1.232422,1.22324],
                      [2.324232,1.3224],
                     [2.3435353,2.3232342],
                     [3.434353,3.434353],
                     [4.54546,3.54544],
                     [7.42422,6.764353],
                     [6.42224534,7.533232],
                     [8.435353,8.5433],
                     [9.423534,9.422224],
                     [8.544444,9.4564454]]
    
        raw_data_y=[0,0,0,0,0,1,1,1,1,1]
    
        x_train = np.array(raw_data_X)
        y_train = np.array(raw_data_y)
    
        x = np.array([7.5353343,8.53324232])
    
        knf = KNN_classifier(5)
    
        knf.fit(x_train,y_train)
    
        print(knf.predict(x))
    
  • 相关阅读:
    树莓派安装realvnc_server
    python中#!含义
    树莓派无显示屏连接wifi
    转载_fread函数详解
    树莓派3b+更改静态IP
    linux命令语法格式
    python-Arduino串口传输数据到电脑并保存至excel表格
    mysql的sql_mode合理设置
    Mysql 配置参数性能调优
    Kubernetes 部署 gitlab
  • 原文地址:https://www.cnblogs.com/lyr999736/p/10654202.html
Copyright © 2020-2023  润新知