• sklearn实现kNN


    对鸢尾花数据集进行分类并交叉验证

    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler
    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.model_selection import GridSearchCV
    def kNN_iris_gscv():
        """
        用kNN对鸢尾花进行分类,添加网格搜索和交叉验证
        :return:
     """
        #1.获取数据
        iris=load_iris()
        #2.划分数据集
        x_train,x_test,y_train,y_test=train_test_split(iris.data,iris.target,random_state=1)
        #3.特征工程:标准化
        transfer=StandardScaler()
        x_train=transfer.fit_transform(x_train)
        x_test=transfer.transform(x_test) #使用训练集的平均值和标准差
        #4.模型训练
        estimator=KNeighborsClassifier()
        #加入网格搜索和交叉验证
        #参数准备
        param_dict={"n_neighbors":[1,3,5,7,9,11]}
        estimator=GridSearchCV(estimator,param_grid=param_dict,cv=10) #对estimator预估器进行10折交叉验证
        estimator.fit(x_train,y_train) #模型拟合
        #5.模型评估
        #方法1:比对真实值和预测值
        y_predict=estimator.predict(x_test)
        print(y_predict)
        print("直接比对真实值和预测值:
    ",y_predict==y_test)
        #方法2:直接计算准确率
        score=estimator.score(x_test,y_test)
        print("准确率为:",score)
    
        #最佳参数:best_params
        print("最佳参数:
    ",estimator.best_params_)
        #最佳结果:best_score_
        print("最佳结果:
    ", estimator.best_score_)
        #最佳估计器:best_estimator_
        print("最佳估计器:
    ", estimator.best_estimator_)
        #交叉验证结果:cv_results_
        print("交叉验证结果:
    ", estimator.cv_results_)
        return None
    
    
    if __name__=="__main__":
        kNN_iris_gscv()
  • 相关阅读:
    centos通过yum安装mongodb
    js基于另一个数组排序数组
    centos 7 安装emule客户端
    typescript中interface和type的区别
    nodejs安装管理工具nvm的安装和使用
    PM2的参数配置
    centOS添加ipv6支持(仅限已分配ipv6地址和网关)
    linux执行计划任务at命令
    mysql中获取本月第一天、本月最后一天、上月第一天、上月最后一天等等
    win10子系统ubuntu内的nginx启动问题
  • 原文地址:https://www.cnblogs.com/sclu/p/11759730.html
Copyright © 2020-2023  润新知