• 机器学习分类算法之KNN算法


    KNN算法为按距离进行分类的,对于已知的分类,根据欧式距离,最靠近那个分类就被预测为那个分类。

    本文只是简单展示一下实现代码,具体的特征和分类,还得自己根据实际场景去调整。

    在开始之前注意看看导入的包是否都存在,如不存在的化,请先安装相应的包

    # -*- coding:utf-8 -*-
    import numpy as np
    from sklearn import datasets
    from sklearn.model_selection import train_test_split
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.model_selection import cross_val_score
    import matplotlib.pyplot as plt
    import joblib
    import pandas as pd
    
    #加载iris数据集
    iris = datasets.load_iris()
    iris_X = iris.data
    iris_y = iris.target
    #print(iris_X[:4,:])
    
    #数据分割
    X_train,X_test,y_train,y_test = train_test_split(iris_X,iris_y,test_size=1/3,random_state=3)
    
    #==========交叉验证============================
    #cv_scores = []
    #k_range = range(1,31)
    
    '''
    #此处为交叉验证,看KNN的k取什么值的时候效果最好
    for n in k_range:
        knn = KNeighborsClassifier(n_neighbors=n)
        scores = cross_val_score(knn,X_train,y_train,cv=10,scoring='accuracy')
        cv_scores.append(scores.mean())
    plt.plot(k_range,cv_scores)
    plt.xlabel('K')
    plt.ylabel('Accuracy')
    plt.show()
    '''
    
    #模型训练
    '''
    best_knn = KNeighborsClassifier(n_neighbors=3)    # 选择最优的K=3传入模型
    best_knn.fit(X_train,y_train)            #训练模型
    print(best_knn.score(X_test,y_test))    #看看评分
    
    #模型本地保存
    joblib.dump(best_knn, 'D:/Users/wangkangren729/PycharmProjects/iris/model/best_knn.pkl',compress=3)
    #load model
    '''
    bknn = joblib.load('D:/Users/wangkangren729/PycharmProjects/iris/model/best_knn.pkl')
    
    #读取本地新数据
    data = pd.read_csv('predict.data')
    #print(data.head(5))
    
    attributes=data[['sl','sw','pl','pw']]  #前四列属性简化为sl,sw,pl,pw
    types=data['type'] #第5列属性为鸢尾花的类别
    
    #print(type(attributes))
    #data_frame = attributes.loc[0,:].to_frame()
    
    #print(attributes)
    #print(type(attributes[i]))    
    #预测新数据
    print(bknn.predict(attributes))
    #print(type([[4.1, 2.2, 2.3, 5.4]]))
    #print([[4.1, 2.2, 2.3, 5.4]])
    #print(bknn.predict([[4.1, 2.2, 2.3, 5.4]]))
    #print(types)
    #print(bknn.predict(attributes))
        
  • 相关阅读:
    jQuery Asynchronous
    Best Pratices——Make the Web Faster
    Asynchronous
    Deferred
    w3 protocol
    Android 设置wifi共享电脑服务器资源
    VC++ 6.0创建MFC工程时的初级备要点(二)
    【LeetCode】Pascal's Triangle II (杨辉三角)
    POJ 1564 Sum It Up(DFS)
    CSS写表格
  • 原文地址:https://www.cnblogs.com/xinyumuhe/p/12605255.html
Copyright © 2020-2023  润新知