对鸢尾花数据集进行分类并交叉验证
from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler from sklearn.neighbors import KNeighborsClassifier from sklearn.model_selection import GridSearchCV def kNN_iris_gscv(): """ 用kNN对鸢尾花进行分类,添加网格搜索和交叉验证 :return: """ #1.获取数据 iris=load_iris() #2.划分数据集 x_train,x_test,y_train,y_test=train_test_split(iris.data,iris.target,random_state=1) #3.特征工程:标准化 transfer=StandardScaler() x_train=transfer.fit_transform(x_train) x_test=transfer.transform(x_test) #使用训练集的平均值和标准差 #4.模型训练 estimator=KNeighborsClassifier() #加入网格搜索和交叉验证 #参数准备 param_dict={"n_neighbors":[1,3,5,7,9,11]} estimator=GridSearchCV(estimator,param_grid=param_dict,cv=10) #对estimator预估器进行10折交叉验证 estimator.fit(x_train,y_train) #模型拟合 #5.模型评估 #方法1:比对真实值和预测值 y_predict=estimator.predict(x_test) print(y_predict) print("直接比对真实值和预测值: ",y_predict==y_test) #方法2:直接计算准确率 score=estimator.score(x_test,y_test) print("准确率为:",score) #最佳参数:best_params print("最佳参数: ",estimator.best_params_) #最佳结果:best_score_ print("最佳结果: ", estimator.best_score_) #最佳估计器:best_estimator_ print("最佳估计器: ", estimator.best_estimator_) #交叉验证结果:cv_results_ print("交叉验证结果: ", estimator.cv_results_) return None if __name__=="__main__": kNN_iris_gscv()