• 使用GridSearchCV寻找最佳参数组合——机器学习工具箱代码


    # -*- coding: utf-8 -*-
    import numpy as np
    from sklearn.feature_extraction import FeatureHasher
    from sklearn import datasets
    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.neighbors import KNeighborsClassifier
    import xgboost as xgb
    from sklearn.model_selection import GridSearchCV
    from sklearn.model_selection import train_test_split
    from sklearn import metrics
    from matplotlib import pyplot as plt
    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.model_selection import GridSearchCV
    
    def report(test_Y, pred_Y, pred_score=None):
        """Print common binary-classification metrics and plot the ROC curve.

        Args:
            test_Y: ground-truth binary labels.
            pred_Y: predicted class labels (e.g. the output of ``clf.predict``).
            pred_score: optional continuous score for the positive class
                (e.g. ``clf.predict_proba(X)[:, 1]``). ROC/AUC computed from
                hard 0/1 labels collapse to a single operating point, so pass
                scores whenever they are available. Defaults to ``pred_Y`` to
                keep the original two-argument behaviour.
        """
        # Backward compatibility: fall back to the hard labels for ROC/AUC.
        if pred_score is None:
            pred_score = pred_Y

        print("accuracy_score:")
        print(metrics.accuracy_score(test_Y, pred_Y))
        print("f1_score:")
        print(metrics.f1_score(test_Y, pred_Y))
        print("recall_score:")
        print(metrics.recall_score(test_Y, pred_Y))
        print("precision_score:")
        print(metrics.precision_score(test_Y, pred_Y))
        print("confusion_matrix:")
        print(metrics.confusion_matrix(test_Y, pred_Y))
        print("AUC:")
        print(metrics.roc_auc_score(test_Y, pred_score))

        # Threshold values are not needed for the plot.
        f_pos, t_pos, _ = metrics.roc_curve(test_Y, pred_score)
        auc_area = metrics.auc(f_pos, t_pos)
        # Start a fresh figure so successive calls don't draw onto one another.
        plt.figure()
        plt.plot(f_pos, t_pos, 'darkorange', lw=2, label='AUC = %.2f' % auc_area)
        plt.legend(loc='lower right')
        # Diagonal reference line of a random classifier.
        plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
        plt.title('ROC')
        plt.ylabel('True Pos Rate')
        plt.xlabel('False Pos Rate')
        plt.show()
    
    
    
    if __name__ == '__main__':
        # Synthetic binary classification task: 1000 samples, 100 features,
        # no redundant features.
        x, y = datasets.make_classification(n_samples=1000, n_features=100,
                                            n_redundant=0, random_state=1)
        # Hold out 20% so the tuned model can be judged on unseen data.
        train_X, test_X, train_Y, test_Y = train_test_split(x,
                                                            y,
                                                            test_size=0.2,
                                                            random_state=66)
        # Baseline without tuning:
        #clf = GradientBoostingClassifier(n_estimators=100)
        #clf.fit(train_X, train_Y)
        #pred_Y = clf.predict(test_X)
        #report(test_Y, pred_Y)

        # Metric used both for CV model selection and for the held-out check.
        # "f1" is an alternative worth trying on imbalanced data.
        scoring = "accuracy"
        parameters = {'n_estimators': range(50, 200, 25),
                      'max_depth': range(2, 10, 2)}
        # `iid` was removed in scikit-learn 0.24 (plain mean over folds is
        # now the only behaviour), so it is no longer passed. n_jobs=-1 runs
        # the 24 x 5 fits on all available cores.
        gsearch = GridSearchCV(estimator=GradientBoostingClassifier(),
                               param_grid=parameters,
                               scoring=scoring,
                               cv=5,
                               n_jobs=-1)
        # Tune on the training split only, so the test split stays unseen.
        gsearch.fit(train_X, train_Y)
        print("gsearch.best_params_")
        print(gsearch.best_params_)
        print("gsearch.best_score_")
        print(gsearch.best_score_)
        # With the default refit=True, best_estimator_ is refit on the whole
        # training split; score it with the same metric on held-out data.
        print("held-out test score:")
        print(gsearch.score(test_X, test_Y))
    

     运行效果（best_score_ 为 5 折交叉验证的平均准确率）：

    gsearch.best_params_
    {'max_depth': 4, 'n_estimators': 100}
    gsearch.best_score_
    0.868142228555714

  • 相关阅读:
    蓝桥杯---打印回型嵌套(简单递归)
    蓝桥杯---分酒
    蓝桥杯---简单试题集锦
    蓝桥杯---黑洞数
    2013蓝桥杯B组 预赛试题
    2012蓝桥杯预赛--取球博弈
    2012第三届蓝桥杯预赛题
    C中的动态开辟(malloc)
    文件的输入输出
    hdoj 1233 还是畅通工程
  • 原文地址:https://www.cnblogs.com/bonelee/p/9154171.html
Copyright © 2020-2023  润新知