K-近邻算法API
- sklearn.neighbors.KNeighborsClassifier(n_neighbors=5,algorithm='auto')
- n_neighbors:int,可选(默认 = 5),kneighbors 查询默认使用的邻居数
- algorithm:{'auto','ball_tree','kd_tree','brute'},可选,用于计算最近邻居的算法:'ball_tree' 将会使用 BallTree,'kd_tree' 将使用 KDTree,'brute' 将使用暴力搜索。'auto' 将尝试根据传递给 fit 方法的值来决定最合适的算法。(不同实现方式影响效率)
#KNN算法
def knn_iris():
    """Classify the iris data set with a 3-nearest-neighbours model.

    Loads the iris data, splits it into train and test parts with a fixed
    seed (``random_state=6``), standardizes the features, fits a
    ``KNeighborsClassifier(n_neighbors=3)`` and prints, in order:

    * the predicted labels for the test part,
    * the element-wise comparison of true and predicted labels,
    * the value of ``score`` on the test part.

    Returns:
        None. All results are written to standard output.
    """
    # 1. Load the data set.
    dataset = load_iris()

    # 2. Split into train and test parts; the seed keeps the split repeatable.
    features_train, features_test, labels_train, labels_test = train_test_split(
        dataset.data, dataset.target, random_state=6
    )

    # 3. Feature engineering: standardization. The scaler is fitted on the
    #    training part only and then reused unchanged on the test part.
    scaler = StandardScaler().fit(features_train)
    features_train = scaler.transform(features_train)
    features_test = scaler.transform(features_test)

    # 4. KNN estimator.
    classifier = KNeighborsClassifier(n_neighbors=3).fit(
        features_train, labels_train
    )

    # 5. Model evaluation.
    predicted = classifier.predict(features_test)
    print(predicted)
    print(labels_test == predicted)
    print(classifier.score(features_test, labels_test))
#KNN算法网格搜索与交叉验证
def knn_iris_gscv():
    """Tune KNN on the iris data set with grid search and 10-fold cross-validation.

    Loads the iris data, splits it into train and test parts with a fixed
    seed (``random_state=6``), and grid-searches ``n_neighbors`` over
    ``[1, 3, 5, 7, 9, 11]`` with ``cv=10``.

    Standardization and the classifier are wrapped in a single ``Pipeline``
    so the scaler is re-fitted inside every cross-validation fold. Fitting
    the scaler on the whole training part *before* the search would let each
    validation fold influence the scaling statistics and make the reported
    cross-validation scores optimistic.

    Because the searched estimator is a pipeline, the tuned parameter is
    reported under the key ``knn__n_neighbors`` and the best estimator is
    the whole pipeline.

    Prints, in order: the predicted test labels, the element-wise comparison
    of true and predicted labels, the test ``score``, the best parameters,
    the best cross-validation score, the best estimator and the full
    cross-validation results.

    Returns:
        None. All results are written to standard output.
    """
    # Local import keeps this function independent of the module's import block.
    from sklearn.pipeline import Pipeline

    # 1. Load the data set.
    iris = load_iris()

    # 2. Split into train and test parts; the seed keeps the split repeatable.
    x_train, x_test, y_train, y_test = train_test_split(
        iris.data, iris.target, random_state=6
    )

    # 3. + 4. Standardization and the KNN estimator as one pipeline, so the
    #         raw (unscaled) features can be handed to the search directly.
    pipeline = Pipeline(
        [
            ("scaler", StandardScaler()),
            ("knn", KNeighborsClassifier()),
        ]
    )

    # Add grid search with cross-validation; "<step>__<param>" addresses a
    # parameter of a named pipeline step.
    param_grid = {"knn__n_neighbors": [1, 3, 5, 7, 9, 11]}
    estimator = GridSearchCV(pipeline, param_grid=param_grid, cv=10)
    estimator.fit(x_train, y_train)

    # 5. Model evaluation on the held-out test part.
    y_predict = estimator.predict(x_test)
    print(y_predict)
    print(y_test == y_predict)
    print(estimator.score(x_test, y_test))
    print("最佳参数:", estimator.best_params_)
    print("最佳结果:", estimator.best_score_)
    print("最佳估计器:", estimator.best_estimator_)
    print("交叉验证结果:", estimator.cv_results_)